agentctl-auth 0.1.0

Unified auth pool and LLM API client for Claude Max Plan, OpenAI, and more
Documentation
//! Auth pool: load, save, manage credentials.

use crate::credential::{Credential, UsageStats};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

/// The complete auth pool (serialized to TOML).
///
/// Every map carries `#[serde(default)]`, so a TOML document that omits a
/// table deserializes with that map empty.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthPool {
    /// Credentials keyed by their pool name.
    #[serde(default)]
    pub pool: HashMap<String, Credential>,
    /// Provider -> name of that provider's default credential.
    #[serde(default)]
    pub defaults: HashMap<String, String>,
    /// Provider -> credential names in the order they are listed and rotated.
    #[serde(default)]
    pub order: HashMap<String, Vec<String>>,
    /// Credential name -> usage bookkeeping (last use, error count, cooldown).
    #[serde(default)]
    pub usage_stats: HashMap<String, UsageStats>,
}

impl AuthPool {
    /// Load auth pool from a TOML file.
    pub fn load(path: &Path) -> Result<Self> {
        if !path.exists() {
            return Ok(Self::default());
        }
        let content = std::fs::read_to_string(path)
            .with_context(|| format!("Failed to read auth pool: {}", path.display()))?;
        let pool: AuthPool = toml::from_str(&content)
            .with_context(|| format!("Failed to parse auth pool: {}", path.display()))?;
        Ok(pool)
    }

    /// Save auth pool to a TOML file (permissions 600).
    pub fn save(&self, path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let content =
            toml::to_string_pretty(self).context("Failed to serialize auth pool to TOML")?;
        std::fs::write(path, &content)
            .with_context(|| format!("Failed to write auth pool: {}", path.display()))?;
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;
        }
        Ok(())
    }

    /// Add a credential to the pool.
    pub fn add(&mut self, name: &str, credential: Credential) {
        let provider = credential.provider.clone();
        self.pool.insert(name.to_string(), credential);

        let order = self.order.entry(provider.clone()).or_default();
        if !order.contains(&name.to_string()) {
            order.push(name.to_string());
        }

        self.defaults
            .entry(provider)
            .or_insert_with(|| name.to_string());
    }

    /// Remove a credential from the pool.
    pub fn remove(&mut self, name: &str) -> Result<()> {
        let cred = self
            .pool
            .remove(name)
            .ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;

        if let Some(order) = self.order.get_mut(&cred.provider) {
            order.retain(|n| n != name);
        }

        if self.defaults.get(&cred.provider).map(|s| s.as_str()) == Some(name) {
            if let Some(order) = self.order.get(&cred.provider) {
                if let Some(next) = order.first() {
                    self.defaults.insert(cred.provider.clone(), next.clone());
                } else {
                    self.defaults.remove(&cred.provider);
                }
            } else {
                self.defaults.remove(&cred.provider);
            }
        }

        self.usage_stats.remove(name);
        Ok(())
    }

    /// Get a credential by name.
    pub fn get(&self, name: &str) -> Option<&Credential> {
        self.pool.get(name)
    }

    /// Get the default credential for a provider.
    pub fn get_default(&self, provider: &str) -> Option<(&str, &Credential)> {
        self.defaults
            .get(provider)
            .and_then(|name| self.pool.get(name).map(|c| (name.as_str(), c)))
    }

    /// Set a credential as the default for its provider.
    pub fn set_default(&mut self, name: &str) -> Result<()> {
        let cred = self
            .pool
            .get(name)
            .ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
        let provider = cred.provider.clone();
        self.defaults.insert(provider.clone(), name.to_string());

        if let Some(order) = self.order.get_mut(&provider) {
            order.retain(|n| n != name);
            order.insert(0, name.to_string());
        }

        Ok(())
    }

    /// Get all credentials for a provider, in order.
    pub fn credentials_for_provider(&self, provider: &str) -> Vec<(&str, &Credential)> {
        if let Some(order) = self.order.get(provider) {
            let mut result: Vec<(&str, &Credential)> = Vec::new();
            for name in order {
                if let Some(cred) = self.pool.get(name) {
                    result.push((name.as_str(), cred));
                }
            }
            for (name, cred) in &self.pool {
                if cred.provider == provider && !order.contains(name) {
                    result.push((name.as_str(), cred));
                }
            }
            result
        } else {
            self.pool
                .iter()
                .filter(|(_, c)| c.provider == provider)
                .map(|(n, c)| (n.as_str(), c))
                .collect()
        }
    }

    /// List all unique providers.
    pub fn providers(&self) -> Vec<String> {
        let mut providers: Vec<String> = self
            .pool
            .values()
            .map(|c| c.provider.clone())
            .collect::<std::collections::HashSet<_>>()
            .into_iter()
            .collect();
        providers.sort();
        providers
    }

    /// List all credentials sorted by name.
    pub fn all_credentials(&self) -> Vec<(&str, &Credential)> {
        let mut creds: Vec<(&str, &Credential)> =
            self.pool.iter().map(|(n, c)| (n.as_str(), c)).collect();
        creds.sort_by_key(|(n, _)| n.to_string());
        creds
    }

    /// Get the next credential to try for a provider after a failure.
    ///
    /// Rotates through the order list, skipping credentials in cooldown.
    pub fn next_credential(&self, provider: &str, failed_name: &str) -> Option<(&str, &Credential)> {
        let order = self.order.get(provider)?;
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;

        // Find the failed credential's position
        let failed_pos = order.iter().position(|n| n == failed_name).unwrap_or(0);

        // Try credentials after the failed one
        for i in 1..order.len() {
            let idx = (failed_pos + i) % order.len();
            let name = &order[idx];

            // Skip if in cooldown
            if let Some(stats) = self.usage_stats.get(name) {
                if let Some(cooldown) = stats.cooldown_until {
                    if now < cooldown {
                        continue;
                    }
                }
            }

            if let Some(cred) = self.pool.get(name) {
                return Some((name.as_str(), cred));
            }
        }

        None
    }

    /// Record a usage event for a credential.
    pub fn record_usage(&mut self, name: &str, success: bool) {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64;

        let stats = self.usage_stats.entry(name.to_string()).or_default();
        stats.last_used = Some(now);

        if !success {
            let count = stats.error_count.unwrap_or(0) + 1;
            stats.error_count = Some(count);
            // Cooldown: 30s after first error, 5min after 3+
            let cooldown_ms = if count >= 3 { 300_000 } else { 30_000 };
            stats.cooldown_until = Some(now + cooldown_ms);
        } else {
            stats.error_count = Some(0);
            stats.cooldown_until = None;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a plain token credential for `provider`.
    fn make_cred(provider: &str, token: &str) -> Credential {
        Credential {
            provider: String::from(provider),
            cred_type: String::from("token"),
            token: Some(String::from(token)),
            keychain_service: None,
        }
    }

    /// Build a pool of `anthropic` credentials from `(name, token)` pairs,
    /// added in the given order.
    fn anthropic_pool(entries: &[(&str, &str)]) -> AuthPool {
        let mut pool = AuthPool::default();
        for &(name, token) in entries {
            pool.add(name, make_cred("anthropic", token));
        }
        pool
    }

    /// Name currently stored as the default for `provider`, if any.
    fn default_name<'a>(pool: &'a AuthPool, provider: &str) -> Option<&'a str> {
        pool.defaults.get(provider).map(String::as_str)
    }

    #[test]
    fn test_add_and_get() {
        let pool = anthropic_pool(&[("anthropic:a", "sk-a")]);

        assert!(pool.get("anthropic:a").is_some());
        assert_eq!(default_name(&pool, "anthropic"), Some("anthropic:a"));
    }

    #[test]
    fn test_remove() {
        let mut pool = anthropic_pool(&[("anthropic:a", "sk-a"), ("anthropic:b", "sk-b")]);

        pool.remove("anthropic:a").unwrap();

        assert!(pool.get("anthropic:a").is_none());
        // The next credential in order takes over as default.
        assert_eq!(default_name(&pool, "anthropic"), Some("anthropic:b"));
    }

    #[test]
    fn test_set_default() {
        let mut pool = anthropic_pool(&[("anthropic:a", "sk-a"), ("anthropic:b", "sk-b")]);

        pool.set_default("anthropic:b").unwrap();

        assert_eq!(default_name(&pool, "anthropic"), Some("anthropic:b"));
    }

    #[test]
    fn test_next_credential() {
        let pool = anthropic_pool(&[
            ("anthropic:a", "sk-a"),
            ("anthropic:b", "sk-b"),
            ("anthropic:c", "sk-c"),
        ]);

        let next = pool.next_credential("anthropic", "anthropic:a");

        assert!(next.is_some());
        let (next_name, _) = next.unwrap();
        assert_eq!(next_name, "anthropic:b");
    }

    #[test]
    fn test_record_usage_cooldown() {
        let mut pool = anthropic_pool(&[("anthropic:a", "sk-a")]);

        // A failure bumps the error count and starts a cooldown.
        pool.record_usage("anthropic:a", false);
        let after_failure = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(after_failure.error_count, Some(1));
        assert!(after_failure.cooldown_until.is_some());

        // A success resets both.
        pool.record_usage("anthropic:a", true);
        let after_success = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(after_success.error_count, Some(0));
        assert!(after_success.cooldown_until.is_none());
    }

    #[test]
    fn test_roundtrip_toml() {
        let original = anthropic_pool(&[("anthropic:default", "sk-ant-test")]);

        let serialized = toml::to_string_pretty(&original).unwrap();
        let restored: AuthPool = toml::from_str(&serialized).unwrap();

        assert!(restored.get("anthropic:default").is_some());
    }
}