use crate::credential::{Credential, UsageStats};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthPool {
#[serde(default)]
pub pool: HashMap<String, Credential>,
#[serde(default)]
pub defaults: HashMap<String, String>,
#[serde(default)]
pub order: HashMap<String, Vec<String>>,
#[serde(default)]
pub usage_stats: HashMap<String, UsageStats>,
}
impl AuthPool {
pub fn load(path: &Path) -> Result<Self> {
if !path.exists() {
return Ok(Self::default());
}
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read auth pool: {}", path.display()))?;
let pool: AuthPool = toml::from_str(&content)
.with_context(|| format!("Failed to parse auth pool: {}", path.display()))?;
Ok(pool)
}
pub fn save(&self, path: &Path) -> Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
let content =
toml::to_string_pretty(self).context("Failed to serialize auth pool to TOML")?;
std::fs::write(path, &content)
.with_context(|| format!("Failed to write auth pool: {}", path.display()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))?;
}
Ok(())
}
pub fn add(&mut self, name: &str, credential: Credential) {
let provider = credential.provider.clone();
self.pool.insert(name.to_string(), credential);
let order = self.order.entry(provider.clone()).or_default();
if !order.contains(&name.to_string()) {
order.push(name.to_string());
}
self.defaults
.entry(provider)
.or_insert_with(|| name.to_string());
}
pub fn remove(&mut self, name: &str) -> Result<()> {
let cred = self
.pool
.remove(name)
.ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
if let Some(order) = self.order.get_mut(&cred.provider) {
order.retain(|n| n != name);
}
if self.defaults.get(&cred.provider).map(|s| s.as_str()) == Some(name) {
if let Some(order) = self.order.get(&cred.provider) {
if let Some(next) = order.first() {
self.defaults.insert(cred.provider.clone(), next.clone());
} else {
self.defaults.remove(&cred.provider);
}
} else {
self.defaults.remove(&cred.provider);
}
}
self.usage_stats.remove(name);
Ok(())
}
pub fn get(&self, name: &str) -> Option<&Credential> {
self.pool.get(name)
}
pub fn get_default(&self, provider: &str) -> Option<(&str, &Credential)> {
self.defaults
.get(provider)
.and_then(|name| self.pool.get(name).map(|c| (name.as_str(), c)))
}
pub fn set_default(&mut self, name: &str) -> Result<()> {
let cred = self
.pool
.get(name)
.ok_or_else(|| anyhow::anyhow!("Credential '{}' not found in pool", name))?;
let provider = cred.provider.clone();
self.defaults.insert(provider.clone(), name.to_string());
if let Some(order) = self.order.get_mut(&provider) {
order.retain(|n| n != name);
order.insert(0, name.to_string());
}
Ok(())
}
pub fn credentials_for_provider(&self, provider: &str) -> Vec<(&str, &Credential)> {
if let Some(order) = self.order.get(provider) {
let mut result: Vec<(&str, &Credential)> = Vec::new();
for name in order {
if let Some(cred) = self.pool.get(name) {
result.push((name.as_str(), cred));
}
}
for (name, cred) in &self.pool {
if cred.provider == provider && !order.contains(name) {
result.push((name.as_str(), cred));
}
}
result
} else {
self.pool
.iter()
.filter(|(_, c)| c.provider == provider)
.map(|(n, c)| (n.as_str(), c))
.collect()
}
}
pub fn providers(&self) -> Vec<String> {
let mut providers: Vec<String> = self
.pool
.values()
.map(|c| c.provider.clone())
.collect::<std::collections::HashSet<_>>()
.into_iter()
.collect();
providers.sort();
providers
}
pub fn all_credentials(&self) -> Vec<(&str, &Credential)> {
let mut creds: Vec<(&str, &Credential)> =
self.pool.iter().map(|(n, c)| (n.as_str(), c)).collect();
creds.sort_by_key(|(n, _)| n.to_string());
creds
}
pub fn next_credential(&self, provider: &str, failed_name: &str) -> Option<(&str, &Credential)> {
let order = self.order.get(provider)?;
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let failed_pos = order.iter().position(|n| n == failed_name).unwrap_or(0);
for i in 1..order.len() {
let idx = (failed_pos + i) % order.len();
let name = &order[idx];
if let Some(stats) = self.usage_stats.get(name) {
if let Some(cooldown) = stats.cooldown_until {
if now < cooldown {
continue;
}
}
}
if let Some(cred) = self.pool.get(name) {
return Some((name.as_str(), cred));
}
}
None
}
pub fn record_usage(&mut self, name: &str, success: bool) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
let stats = self.usage_stats.entry(name.to_string()).or_default();
stats.last_used = Some(now);
if !success {
let count = stats.error_count.unwrap_or(0) + 1;
stats.error_count = Some(count);
let cooldown_ms = if count >= 3 { 300_000 } else { 30_000 };
stats.cooldown_until = Some(now + cooldown_ms);
} else {
stats.error_count = Some(0);
stats.cooldown_until = None;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a token credential for `provider` with the given secret.
    fn make_cred(provider: &str, token: &str) -> Credential {
        Credential {
            provider: provider.to_string(),
            cred_type: "token".to_string(),
            token: Some(token.to_string()),
            keychain_service: None,
        }
    }

    /// Returns a per-test scratch directory under the system temp dir (not yet created).
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "auth-pool-test-{}-{}",
            std::process::id(),
            tag
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    /// Extracts just the names from a `(name, credential)` listing.
    fn names<'a>(entries: &[(&'a str, &Credential)]) -> Vec<&'a str> {
        entries.iter().map(|&(name, _)| name).collect()
    }

    #[test]
    fn test_add_and_get() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        assert!(pool.get("anthropic:a").is_some());
        assert_eq!(
            pool.defaults.get("anthropic").map(|s| s.as_str()),
            Some("anthropic:a")
        );
    }

    #[test]
    fn test_remove() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.remove("anthropic:a").unwrap();
        assert!(pool.get("anthropic:a").is_none());
        assert_eq!(
            pool.defaults.get("anthropic").map(|s| s.as_str()),
            Some("anthropic:b")
        );
    }

    #[test]
    fn test_remove_last_clears_default() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.remove("anthropic:a").unwrap();
        assert!(pool.defaults.get("anthropic").is_none());
        assert!(pool.get_default("anthropic").is_none());
    }

    #[test]
    fn test_remove_missing_is_error() {
        let mut pool = AuthPool::default();
        assert!(pool.remove("nope").is_err());
    }

    #[test]
    fn test_set_default() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.set_default("anthropic:b").unwrap();
        assert_eq!(
            pool.defaults.get("anthropic").map(|s| s.as_str()),
            Some("anthropic:b")
        );
    }

    #[test]
    fn test_set_default_reorders_rotation() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.add("anthropic:c", make_cred("anthropic", "sk-c"));
        pool.set_default("anthropic:c").unwrap();
        assert_eq!(
            names(&pool.credentials_for_provider("anthropic")),
            vec!["anthropic:c", "anthropic:a", "anthropic:b"]
        );
    }

    #[test]
    fn test_next_credential() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.add("anthropic:c", make_cred("anthropic", "sk-c"));
        let next = pool.next_credential("anthropic", "anthropic:a");
        assert!(next.is_some());
        assert_eq!(next.unwrap().0, "anthropic:b");
    }

    #[test]
    fn test_next_credential_wraps_and_skips_cooldown() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.add("anthropic:c", make_cred("anthropic", "sk-c"));
        // Wrap-around: the entry after the last one is the first one.
        assert_eq!(
            pool.next_credential("anthropic", "anthropic:c").map(|(n, _)| n),
            Some("anthropic:a")
        );
        // A credential that just failed is cooling down and gets skipped.
        pool.record_usage("anthropic:b", false);
        assert_eq!(
            pool.next_credential("anthropic", "anthropic:a").map(|(n, _)| n),
            Some("anthropic:c")
        );
    }

    #[test]
    fn test_next_credential_single_entry_is_none() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        assert!(pool.next_credential("anthropic", "anthropic:a").is_none());
        assert!(pool.next_credential("unknown", "anthropic:a").is_none());
    }

    #[test]
    fn test_record_usage_cooldown() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.record_usage("anthropic:a", false);
        let stats = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(stats.error_count, Some(1));
        assert!(stats.cooldown_until.is_some());
        pool.record_usage("anthropic:a", true);
        let stats = pool.usage_stats.get("anthropic:a").unwrap();
        assert_eq!(stats.error_count, Some(0));
        assert!(stats.cooldown_until.is_none());
    }

    #[test]
    fn test_listings_are_sorted() {
        let mut pool = AuthPool::default();
        pool.add("openai:z", make_cred("openai", "sk-z"));
        pool.add("anthropic:b", make_cred("anthropic", "sk-b"));
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        assert_eq!(
            names(&pool.all_credentials()),
            vec!["anthropic:a", "anthropic:b", "openai:z"]
        );
        assert_eq!(
            pool.providers(),
            vec!["anthropic".to_string(), "openai".to_string()]
        );
    }

    #[test]
    fn test_roundtrip_toml() {
        let mut pool = AuthPool::default();
        pool.add("anthropic:default", make_cred("anthropic", "sk-ant-test"));
        let toml_str = toml::to_string_pretty(&pool).unwrap();
        let loaded: AuthPool = toml::from_str(&toml_str).unwrap();
        assert!(loaded.get("anthropic:default").is_some());
    }

    #[test]
    fn test_load_missing_file_is_empty() {
        let dir = scratch_dir("missing");
        let loaded = AuthPool::load(&dir.join("auth.toml")).unwrap();
        assert!(loaded.pool.is_empty());
        assert!(loaded.defaults.is_empty());
    }

    #[test]
    fn test_save_and_load_file() {
        let dir = scratch_dir("save");
        let path = dir.join("nested").join("auth.toml");
        let mut pool = AuthPool::default();
        pool.add("anthropic:a", make_cred("anthropic", "sk-a"));
        pool.save(&path).unwrap();

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let mode = std::fs::metadata(&path).unwrap().permissions().mode();
            assert_eq!(mode & 0o777, 0o600);
        }

        let loaded = AuthPool::load(&path).unwrap();
        let cred = loaded.get("anthropic:a").unwrap();
        assert_eq!(cred.provider, "anthropic");
        assert_eq!(cred.token.as_deref(), Some("sk-a"));
        assert_eq!(
            loaded.get_default("anthropic").map(|(n, _)| n),
            Some("anthropic:a")
        );

        // Saving again over an existing file must also succeed.
        loaded.save(&path).unwrap();
        let _ = std::fs::remove_dir_all(&dir);
    }
}