1use crate::credential::{Credential, UsageStats};
4use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::Path;
8
9#[derive(Debug, Clone, Serialize, Deserialize, Default)]
11pub struct AuthPool {
12 #[serde(default)]
13 pub pool: HashMap<String, Credential>,
14 #[serde(default)]
15 pub defaults: HashMap<String, String>,
16 #[serde(default)]
17 pub order: HashMap<String, Vec<String>>,
18 #[serde(default)]
19 pub usage_stats: HashMap<String, UsageStats>,
20}
21
22impl AuthPool {
23 pub fn load(path: &Path) -> Result<Self> {
25 if !path.exists() {
26 return Ok(Self::default());
27 }
28 let content = std::fs::read_to_string(path)
29 .with_context(|| format!("Failed to read auth pool: {}", path.display()))?;
30 let pool: AuthPool = toml::from_str(&content)
31 .with_context(|| format!("Failed to parse auth pool: {}", path.display()))?;
32 Ok(pool)
33 }
34
35 pub fn save(&self, path: &Path) -> Result<()> {
37 if let Some(parent) = path.parent() {
38 std::fs::create_dir_all(parent)?;
39 }
40 let content =
41 toml::to_string_pretty(self).context("Failed to serialize auth pool to TOML")?;
42 std::fs::write(path, &content)
43 .with_context(|| format!("Failed to write auth pool: {}", path.display()))?;
44 #[cfg(unix)]
45 {
46 use std::os::unix::fs::PermissionsExt;
47 std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;
48 }
49 Ok(())
50 }
51
52 pub fn add(&mut self, name: &str, credential: Credential) {
54 let provider = credential.provider.clone();
55 self.pool.insert(name.to_string(), credential);
56
57 let order = self.order.entry(provider.clone()).or_default();
58 if !order.contains(&name.to_string()) {
59 order.push(name.to_string());
60 }
61
62 self.defaults
63 .entry(provider)
64 .or_insert_with(|| name.to_string());
65 }
66
67 pub fn remove(&mut self, name: &str) -> Result<()> {
69 let cred = self
70 .pool
71 .remove(name)
72 .ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
73
74 if let Some(order) = self.order.get_mut(&cred.provider) {
75 order.retain(|n| n != name);
76 }
77
78 if self.defaults.get(&cred.provider).map(|s| s.as_str()) == Some(name) {
79 if let Some(order) = self.order.get(&cred.provider) {
80 if let Some(next) = order.first() {
81 self.defaults.insert(cred.provider.clone(), next.clone());
82 } else {
83 self.defaults.remove(&cred.provider);
84 }
85 } else {
86 self.defaults.remove(&cred.provider);
87 }
88 }
89
90 self.usage_stats.remove(name);
91 Ok(())
92 }
93
94 pub fn get(&self, name: &str) -> Option<&Credential> {
96 self.pool.get(name)
97 }
98
99 pub fn get_default(&self, provider: &str) -> Option<(&str, &Credential)> {
101 self.defaults
102 .get(provider)
103 .and_then(|name| self.pool.get(name).map(|c| (name.as_str(), c)))
104 }
105
106 pub fn set_default(&mut self, name: &str) -> Result<()> {
108 let cred = self
109 .pool
110 .get(name)
111 .ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
112 let provider = cred.provider.clone();
113 self.defaults.insert(provider.clone(), name.to_string());
114
115 if let Some(order) = self.order.get_mut(&provider) {
116 order.retain(|n| n != name);
117 order.insert(0, name.to_string());
118 }
119
120 Ok(())
121 }
122
123 pub fn credentials_for_provider(&self, provider: &str) -> Vec<(&str, &Credential)> {
125 if let Some(order) = self.order.get(provider) {
126 let mut result: Vec<(&str, &Credential)> = Vec::new();
127 for name in order {
128 if let Some(cred) = self.pool.get(name) {
129 result.push((name.as_str(), cred));
130 }
131 }
132 for (name, cred) in &self.pool {
133 if cred.provider == provider && !order.contains(name) {
134 result.push((name.as_str(), cred));
135 }
136 }
137 result
138 } else {
139 self.pool
140 .iter()
141 .filter(|(_, c)| c.provider == provider)
142 .map(|(n, c)| (n.as_str(), c))
143 .collect()
144 }
145 }
146
147 pub fn providers(&self) -> Vec<String> {
149 let mut providers: Vec<String> = self
150 .pool
151 .values()
152 .map(|c| c.provider.clone())
153 .collect::<std::collections::HashSet<_>>()
154 .into_iter()
155 .collect();
156 providers.sort();
157 providers
158 }
159
160 pub fn all_credentials(&self) -> Vec<(&str, &Credential)> {
162 let mut creds: Vec<(&str, &Credential)> =
163 self.pool.iter().map(|(n, c)| (n.as_str(), c)).collect();
164 creds.sort_by_key(|(n, _)| n.to_string());
165 creds
166 }
167
168 pub fn next_credential(&self, provider: &str, failed_name: &str) -> Option<(&str, &Credential)> {
172 let order = self.order.get(provider)?;
173 let now = std::time::SystemTime::now()
174 .duration_since(std::time::UNIX_EPOCH)
175 .unwrap_or_default()
176 .as_millis() as u64;
177
178 let failed_pos = order.iter().position(|n| n == failed_name).unwrap_or(0);
180
181 for i in 1..order.len() {
183 let idx = (failed_pos + i) % order.len();
184 let name = &order[idx];
185
186 if let Some(stats) = self.usage_stats.get(name) {
188 if let Some(cooldown) = stats.cooldown_until {
189 if now < cooldown {
190 continue;
191 }
192 }
193 }
194
195 if let Some(cred) = self.pool.get(name) {
196 return Some((name.as_str(), cred));
197 }
198 }
199
200 None
201 }
202
203 pub fn record_usage(&mut self, name: &str, success: bool) {
205 let now = std::time::SystemTime::now()
206 .duration_since(std::time::UNIX_EPOCH)
207 .unwrap_or_default()
208 .as_millis() as u64;
209
210 let stats = self.usage_stats.entry(name.to_string()).or_default();
211 stats.last_used = Some(now);
212
213 if !success {
214 let count = stats.error_count.unwrap_or(0) + 1;
215 stats.error_count = Some(count);
216 let cooldown_ms = if count >= 3 { 300_000 } else { 30_000 };
218 stats.cooldown_until = Some(now + cooldown_ms);
219 } else {
220 stats.error_count = Some(0);
221 stats.cooldown_until = None;
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain token credential for `provider`.
    fn make_cred(provider: &str, token: &str) -> Credential {
        Credential {
            provider: provider.to_owned(),
            cred_type: String::from("token"),
            token: Some(token.to_owned()),
            keychain_service: None,
        }
    }

    /// Builds a pool holding `anthropic:<suffix>` credentials, added in order.
    fn anthropic_pool(suffixes: &[&str]) -> AuthPool {
        let mut pool = AuthPool::default();
        for suffix in suffixes {
            pool.add(
                &format!("anthropic:{}", suffix),
                make_cred("anthropic", &format!("sk-{}", suffix)),
            );
        }
        pool
    }

    /// Name currently recorded as the default for `provider`, if any.
    fn default_for<'a>(pool: &'a AuthPool, provider: &str) -> Option<&'a str> {
        pool.defaults.get(provider).map(String::as_str)
    }

    #[test]
    fn test_add_and_get() {
        let pool = anthropic_pool(&["a"]);

        assert!(pool.get("anthropic:a").is_some());
        assert_eq!(default_for(&pool, "anthropic"), Some("anthropic:a"));
    }

    #[test]
    fn test_remove() {
        let mut pool = anthropic_pool(&["a", "b"]);

        pool.remove("anthropic:a").unwrap();

        assert!(pool.get("anthropic:a").is_none());
        // Removing the default promotes the next credential in order.
        assert_eq!(default_for(&pool, "anthropic"), Some("anthropic:b"));
    }

    #[test]
    fn test_set_default() {
        let mut pool = anthropic_pool(&["a", "b"]);

        pool.set_default("anthropic:b").unwrap();

        assert_eq!(default_for(&pool, "anthropic"), Some("anthropic:b"));
    }

    #[test]
    fn test_next_credential() {
        let pool = anthropic_pool(&["a", "b", "c"]);

        let (name, _) = pool
            .next_credential("anthropic", "anthropic:a")
            .expect("a fallback credential should be available");

        assert_eq!(name, "anthropic:b");
    }

    #[test]
    fn test_record_usage_cooldown() {
        let mut pool = anthropic_pool(&["a"]);

        pool.record_usage("anthropic:a", false);
        let after_failure = &pool.usage_stats["anthropic:a"];
        assert_eq!(after_failure.error_count, Some(1));
        assert!(after_failure.cooldown_until.is_some());

        pool.record_usage("anthropic:a", true);
        let after_success = &pool.usage_stats["anthropic:a"];
        assert_eq!(after_success.error_count, Some(0));
        assert!(after_success.cooldown_until.is_none());
    }

    #[test]
    fn test_roundtrip_toml() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:default", make_cred("anthropic", "sk-ant-test"));

        let serialized = toml::to_string_pretty(&pool).unwrap();
        let restored: AuthPool = toml::from_str(&serialized).unwrap();

        assert!(restored.get("anthropic:default").is_some());
    }
}
311}