Skip to main content

agentctl_auth/
pool.rs

1//! Auth pool: load, save, manage credentials.
2
3use crate::credential::{Credential, UsageStats};
4use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::Path;
8
9/// The complete auth pool (serialized to TOML).
10#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct AuthPool {
12    #[serde(default)]
13    pub pool: HashMap<String, Credential>,
14    #[serde(default)]
15    pub defaults: HashMap<String, String>,
16    #[serde(default)]
17    pub order: HashMap<String, Vec<String>>,
18    #[serde(default)]
19    pub usage_stats: HashMap<String, UsageStats>,
20}
21
22impl AuthPool {
23    /// Load auth pool from a TOML file.
24    pub fn load(path: &Path) -> Result<Self> {
25        if !path.exists() {
26            return Ok(Self::default());
27        }
28        let content = std::fs::read_to_string(path)
29            .with_context(|| format!("Failed to read auth pool: {}", path.display()))?;
30        let pool: AuthPool = toml::from_str(&content)
31            .with_context(|| format!("Failed to parse auth pool: {}", path.display()))?;
32        Ok(pool)
33    }
34
35    /// Save auth pool to a TOML file (permissions 600).
36    pub fn save(&self, path: &Path) -> Result<()> {
37        if let Some(parent) = path.parent() {
38            std::fs::create_dir_all(parent)?;
39        }
40        let content =
41            toml::to_string_pretty(self).context("Failed to serialize auth pool to TOML")?;
42        std::fs::write(path, &content)
43            .with_context(|| format!("Failed to write auth pool: {}", path.display()))?;
44        #[cfg(unix)]
45        {
46            use std::os::unix::fs::PermissionsExt;
47            std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;
48        }
49        Ok(())
50    }
51
52    /// Add a credential to the pool.
53    pub fn add(&mut self, name: &str, credential: Credential) {
54        let provider = credential.provider.clone();
55        self.pool.insert(name.to_string(), credential);
56
57        let order = self.order.entry(provider.clone()).or_default();
58        if !order.contains(&name.to_string()) {
59            order.push(name.to_string());
60        }
61
62        self.defaults
63            .entry(provider)
64            .or_insert_with(|| name.to_string());
65    }
66
67    /// Remove a credential from the pool.
68    pub fn remove(&mut self, name: &str) -> Result<()> {
69        let cred = self
70            .pool
71            .remove(name)
72            .ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
73
74        if let Some(order) = self.order.get_mut(&cred.provider) {
75            order.retain(|n| n != name);
76        }
77
78        if self.defaults.get(&cred.provider).map(|s| s.as_str()) == Some(name) {
79            if let Some(order) = self.order.get(&cred.provider) {
80                if let Some(next) = order.first() {
81                    self.defaults.insert(cred.provider.clone(), next.clone());
82                } else {
83                    self.defaults.remove(&cred.provider);
84                }
85            } else {
86                self.defaults.remove(&cred.provider);
87            }
88        }
89
90        self.usage_stats.remove(name);
91        Ok(())
92    }
93
94    /// Get a credential by name.
95    pub fn get(&self, name: &str) -> Option<&Credential> {
96        self.pool.get(name)
97    }
98
99    /// Get the default credential for a provider.
100    pub fn get_default(&self, provider: &str) -> Option<(&str, &Credential)> {
101        self.defaults
102            .get(provider)
103            .and_then(|name| self.pool.get(name).map(|c| (name.as_str(), c)))
104    }
105
106    /// Set a credential as the default for its provider.
107    pub fn set_default(&mut self, name: &str) -> Result<()> {
108        let cred = self
109            .pool
110            .get(name)
111            .ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
112        let provider = cred.provider.clone();
113        self.defaults.insert(provider.clone(), name.to_string());
114
115        if let Some(order) = self.order.get_mut(&provider) {
116            order.retain(|n| n != name);
117            order.insert(0, name.to_string());
118        }
119
120        Ok(())
121    }
122
123    /// Get all credentials for a provider, in order.
124    pub fn credentials_for_provider(&self, provider: &str) -> Vec<(&str, &Credential)> {
125        if let Some(order) = self.order.get(provider) {
126            let mut result: Vec<(&str, &Credential)> = Vec::new();
127            for name in order {
128                if let Some(cred) = self.pool.get(name) {
129                    result.push((name.as_str(), cred));
130                }
131            }
132            for (name, cred) in &self.pool {
133                if cred.provider == provider && !order.contains(name) {
134                    result.push((name.as_str(), cred));
135                }
136            }
137            result
138        } else {
139            self.pool
140                .iter()
141                .filter(|(_, c)| c.provider == provider)
142                .map(|(n, c)| (n.as_str(), c))
143                .collect()
144        }
145    }
146
147    /// List all unique providers.
148    pub fn providers(&self) -> Vec<String> {
149        let mut providers: Vec<String> = self
150            .pool
151            .values()
152            .map(|c| c.provider.clone())
153            .collect::<std::collections::HashSet<_>>()
154            .into_iter()
155            .collect();
156        providers.sort();
157        providers
158    }
159
160    /// List all credentials sorted by name.
161    pub fn all_credentials(&self) -> Vec<(&str, &Credential)> {
162        let mut creds: Vec<(&str, &Credential)> =
163            self.pool.iter().map(|(n, c)| (n.as_str(), c)).collect();
164        creds.sort_by_key(|(n, _)| n.to_string());
165        creds
166    }
167
168    /// Get the next credential to try for a provider after a failure.
169    ///
170    /// Rotates through the order list, skipping credentials in cooldown.
171    pub fn next_credential(&self, provider: &str, failed_name: &str) -> Option<(&str, &Credential)> {
172        let order = self.order.get(provider)?;
173        let now = std::time::SystemTime::now()
174            .duration_since(std::time::UNIX_EPOCH)
175            .unwrap_or_default()
176            .as_millis() as u64;
177
178        // Find the failed credential's position
179        let failed_pos = order.iter().position(|n| n == failed_name).unwrap_or(0);
180
181        // Try credentials after the failed one
182        for i in 1..order.len() {
183            let idx = (failed_pos + i) % order.len();
184            let name = &order[idx];
185
186            // Skip if in cooldown
187            if let Some(stats) = self.usage_stats.get(name) {
188                if let Some(cooldown) = stats.cooldown_until {
189                    if now < cooldown {
190                        continue;
191                    }
192                }
193            }
194
195            if let Some(cred) = self.pool.get(name) {
196                return Some((name.as_str(), cred));
197            }
198        }
199
200        None
201    }
202
203    /// Record a usage event for a credential.
204    pub fn record_usage(&mut self, name: &str, success: bool) {
205        let now = std::time::SystemTime::now()
206            .duration_since(std::time::UNIX_EPOCH)
207            .unwrap_or_default()
208            .as_millis() as u64;
209
210        let stats = self.usage_stats.entry(name.to_string()).or_default();
211        stats.last_used = Some(now);
212
213        if !success {
214            let count = stats.error_count.unwrap_or(0) + 1;
215            stats.error_count = Some(count);
216            // Cooldown: 30s after first error, 5min after 3+
217            let cooldown_ms = if count >= 3 { 300_000 } else { 30_000 };
218            stats.cooldown_until = Some(now + cooldown_ms);
219        } else {
220            stats.error_count = Some(0);
221            stats.cooldown_until = None;
222        }
223    }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    fn make_cred(provider: &str, token: &str) -> Credential {
        Credential {
            provider: provider.to_string(),
            cred_type: "token".to_string(),
            token: Some(token.to_string()),
            keychain_service: None,
        }
    }

    /// Pool holding `anthropic:a`, `anthropic:b` and `anthropic:c`, added in that order.
    fn three_anthropic() -> AuthPool {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.add("anthropic:c", make_cred("anthropic", "sk-c"));
        pool
    }

    fn default_for<'a>(pool: &'a AuthPool, provider: &str) -> Option<&'a str> {
        pool.defaults.get(provider).map(|s| s.as_str())
    }

    fn order_for(pool: &AuthPool, provider: &str) -> Vec<String> {
        pool.order.get(provider).cloned().unwrap_or_default()
    }

    fn next_name<'a>(pool: &'a AuthPool, provider: &str, failed: &str) -> Option<&'a str> {
        pool.next_credential(provider, failed).map(|(name, _)| name)
    }

    #[test]
    fn test_add_and_get() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        assert!(pool.get("anthropic:a").is_some());
        assert_eq!(default_for(&pool, "anthropic"), Some("anthropic:a"));
        assert_eq!(order_for(&pool, "anthropic"), vec!["anthropic:a"]);

        let (name, cred) = pool.get_default("anthropic").unwrap();
        assert_eq!(name, "anthropic:a");
        assert_eq!(cred.token.as_deref(), Some("sk-a"));
    }

    #[test]
    fn test_add_same_name_replaces_without_duplicating_order() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-old"));
        pool.add("anthropic:a", make_cred("anthropic", "sk-new"));
        assert_eq!(
            pool.get("anthropic:a").unwrap().token.as_deref(),
            Some("sk-new")
        );
        assert_eq!(order_for(&pool, "anthropic"), vec!["anthropic:a"]);
    }

    #[test]
    fn test_remove() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.record_usage("anthropic:a", false);
        pool.remove("anthropic:a").unwrap();
        assert!(pool.get("anthropic:a").is_none());
        assert_eq!(default_for(&pool, "anthropic"), Some("anthropic:b"));
        assert_eq!(order_for(&pool, "anthropic"), vec!["anthropic:b"]);
        assert!(pool.usage_stats.get("anthropic:a").is_none());
    }

    #[test]
    fn test_remove_last_clears_default() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.remove("anthropic:a").unwrap();
        assert_eq!(default_for(&pool, "anthropic"), None);
        assert!(order_for(&pool, "anthropic").is_empty());
        assert!(pool.get_default("anthropic").is_none());
    }

    #[test]
    fn test_remove_missing_is_error() {
        let mut pool = AuthPool::default();
        assert!(pool.remove("anthropic:missing").is_err());
    }

    #[test]
    fn test_set_default() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.set_default("anthropic:b").unwrap();
        assert_eq!(default_for(&pool, "anthropic"), Some("anthropic:b"));
        // The new default is moved to the front of the rotation.
        assert_eq!(
            order_for(&pool, "anthropic"),
            vec!["anthropic:b", "anthropic:a"]
        );
        assert!(pool.set_default("anthropic:missing").is_err());
    }

    #[test]
    fn test_credentials_for_provider_and_listing() {
        let mut pool = three_anthropic();
        pool.add("openai:x", make_cred("openai", "sk-x"));
        pool.set_default("anthropic:c").unwrap();

        let names: Vec<&str> = pool
            .credentials_for_provider("anthropic")
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(names, vec!["anthropic:c", "anthropic:a", "anthropic:b"]);
        assert!(pool.credentials_for_provider("missing").is_empty());

        assert_eq!(pool.providers(), vec!["anthropic", "openai"]);

        let all: Vec<&str> = pool
            .all_credentials()
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(
            all,
            vec!["anthropic:a", "anthropic:b", "anthropic:c", "openai:x"]
        );
    }

    #[test]
    fn test_credentials_for_provider_includes_unordered() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        // Inserted directly, bypassing `add`, so it is missing from the order list.
        pool.pool
            .insert("anthropic:z".to_string(), make_cred("anthropic", "sk-z"));

        let names: Vec<&str> = pool
            .credentials_for_provider("anthropic")
            .into_iter()
            .map(|(name, _)| name)
            .collect();
        assert_eq!(names, vec!["anthropic:a", "anthropic:z"]);
    }

    #[test]
    fn test_next_credential() {
        let pool = three_anthropic();
        assert_eq!(
            next_name(&pool, "anthropic", "anthropic:a"),
            Some("anthropic:b")
        );
        // Wraps around past the end of the order.
        assert_eq!(
            next_name(&pool, "anthropic", "anthropic:c"),
            Some("anthropic:a")
        );
        assert_eq!(next_name(&pool, "openai", "openai:x"), None);
    }

    #[test]
    fn test_next_credential_skips_cooldown() {
        let mut pool = three_anthropic();
        pool.record_usage("anthropic:b", false);
        assert_eq!(
            next_name(&pool, "anthropic", "anthropic:a"),
            Some("anthropic:c")
        );
    }

    #[test]
    fn test_next_credential_single_credential_is_none() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        assert_eq!(next_name(&pool, "anthropic", "anthropic:a"), None);
    }

    #[test]
    fn test_record_usage_cooldown() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));

        pool.record_usage("anthropic:a", false);
        let stats = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(stats.error_count, Some(1));
        assert!(stats.cooldown_until.is_some());
        assert_eq!(
            stats.cooldown_until.unwrap() - stats.last_used.unwrap(),
            30_000
        );

        // From the third consecutive failure the cooldown grows to 5 minutes.
        pool.record_usage("anthropic:a", false);
        pool.record_usage("anthropic:a", false);
        let stats = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(stats.error_count, Some(3));
        assert_eq!(
            stats.cooldown_until.unwrap() - stats.last_used.unwrap(),
            300_000
        );

        pool.record_usage("anthropic:a", true);
        let stats = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(stats.error_count, Some(0));
        assert!(stats.cooldown_until.is_none());
    }

    #[test]
    fn test_roundtrip_toml() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:default", make_cred("anthropic", "sk-ant-test"));
        pool.record_usage("anthropic:default", false);

        let toml_str = toml::to_string_pretty(&pool).unwrap();
        let loaded: AuthPool = toml::from_str(&toml_str).unwrap();

        let cred = loaded.get("anthropic:default").unwrap();
        assert_eq!(cred.provider, "anthropic");
        assert_eq!(cred.token.as_deref(), Some("sk-ant-test"));
        assert_eq!(loaded.defaults, pool.defaults);
        assert_eq!(loaded.order, pool.order);
        assert_eq!(
            loaded
                .usage_stats
                .get("anthropic:default")
                .and_then(|stats| stats.error_count),
            Some(1)
        );
    }

    #[test]
    fn test_save_and_load() {
        let dir = std::env::temp_dir().join(format!(
            "agentctl_auth_pool_test_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        // A nested path also exercises creation of missing parent directories.
        let path = dir.join("nested").join("auth.toml");

        // Loading a file that does not exist yields an empty pool.
        let empty = AuthPool::load(&path).unwrap();
        assert!(empty.pool.is_empty());

        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.save(&path).unwrap();
        // Saving again must overwrite the existing file.
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.save(&path).unwrap();

        let loaded = AuthPool::load(&path).unwrap();
        assert_eq!(
            loaded.get("anthropic:b").and_then(|c| c.token.as_deref()),
            Some("sk-b")
        );
        assert_eq!(loaded.order, pool.order);
        assert_eq!(loaded.defaults, pool.defaults);

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&path).unwrap().permissions().mode();
            assert_eq!(mode & 0o777, 0o600);
        }

        std::fs::remove_dir_all(&dir).unwrap();
    }
}