use chrono::{DateTime, Duration, Local};
use serde::Deserialize;
use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::path::Path;
/// All credential pools read from one configuration document.
///
/// The struct has a single flattened map, so every top-level table in the
/// document becomes one entry of `providers`.
#[derive(Debug, Clone, Deserialize)]
pub struct CredentialPool {
/// Per-provider pools keyed by the top-level table name (for example
/// `codex`); `get_credential` and `mark_exhausted` look entries up by
/// their `agent` argument.
#[serde(flatten)]
pub providers: HashMap<String, ProviderPool>,
}
/// The keys configured for one provider plus the runtime state used to pick
/// among them.
///
/// The selection state sits behind `Cell`/`RefCell` so that `get_credential`
/// can update it through a shared `&CredentialPool`.
#[derive(Debug, Clone, Deserialize)]
pub struct ProviderPool {
/// How `get_credential` chooses among `keys`.
pub strategy: Strategy,
/// Candidate keys in declaration order.
pub keys: Vec<KeyEntry>,
/// Round-robin cursor: index at which the next round-robin scan starts.
/// Not deserialized; starts at the `Default` value.
#[serde(skip)]
next_index: Cell<usize>,
/// Number of times each key has been handed out, indexed in parallel with
/// `keys`. Not deserialized; `get_credential` resizes it to match `keys`.
#[serde(skip)]
usage_counts: RefCell<Vec<u64>>,
}
/// One credential: a label, the environment variable that holds the secret,
/// and an optional cooldown deadline.
#[derive(Debug, Clone, Deserialize)]
pub struct KeyEntry {
/// Label of the key; `mark_exhausted` matches on this value.
pub name: String,
/// Name of the environment variable the secret is read from.
pub env: String,
/// While this instant lies in the future the key is skipped during
/// selection. Not deserialized; set by `mark_exhausted`.
#[serde(skip)]
pub exhausted_until: Option<DateTime<Local>>,
}
/// Policy used to choose a key from a provider's pool.
///
/// Variants are spelled in `snake_case` when deserialized (for example
/// `"round_robin"`).
#[derive(Debug, Clone, Copy, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
/// Always take the first ready key in declaration order. This is the
/// variant produced by the derived `Default` impl.
#[default]
FillFirst,
/// Rotate through the ready keys, resuming after the last key handed out.
RoundRobin,
/// Take the ready key with the lowest usage count.
LeastUsed,
}
/// Loads the credential pool from `credentials.toml` inside the directory
/// returned by `crate::paths::aid_dir()`.
///
/// Returns `None` when the file cannot be read or does not parse.
pub fn load_pool() -> Option<CredentialPool> {
    let credentials_file = crate::paths::aid_dir().join("credentials.toml");
    load_pool_from_path(&credentials_file)
}
/// Picks a usable credential for `agent` according to that provider's
/// strategy.
///
/// Returns `(environment variable name, secret value)` for the chosen key and
/// bumps its usage counter. Returns `None` when the agent has no pool or no
/// key is currently ready (every key is cooling down, unset, or empty).
pub fn get_credential(pool: &CredentialPool, agent: &str) -> Option<(String, String)> {
    let provider = pool.providers.get(agent)?;
    let now = Local::now();

    // Keep the counters parallel to `keys`; resizing to the current length is
    // a no-op when they already match.
    let mut counts = provider.usage_counts.borrow_mut();
    counts.resize(provider.keys.len(), 0);

    let picked = match provider.strategy {
        Strategy::FillFirst => select_fill_first(&provider.keys, &now),
        Strategy::RoundRobin => select_round_robin(provider, &now),
        Strategy::LeastUsed => select_least_used(&provider.keys, &counts, &now),
    };
    let (index, secret) = picked?;

    let slot = &mut counts[index];
    *slot = slot.saturating_add(1);

    let env_name = provider.keys[index].env.clone();
    Some((env_name, secret))
}
/// Puts the key named `key_name` of `agent`'s pool on cooldown for
/// `cooldown_mins` minutes, so selection skips it until the deadline passes.
///
/// Does nothing when the agent or the key is unknown. Only the first key with
/// a matching name is updated.
///
/// The cooldown is capped at roughly one hundred years. `Duration::minutes`
/// and `DateTime + Duration` both panic when the result is out of range, so
/// an oversized `cooldown_mins` (anything near or above `i64::MAX`) must be
/// clamped rather than forwarded as-is.
pub fn mark_exhausted(
    pool: &mut CredentialPool,
    agent: &str,
    key_name: &str,
    cooldown_mins: u64,
) {
    /// Upper bound on a cooldown (about 100 years), far inside chrono's range.
    const MAX_COOLDOWN_MINS: i64 = 100 * 365 * 24 * 60;

    let Some(provider) = pool.providers.get_mut(agent) else {
        return;
    };
    let Some(key) = provider.keys.iter_mut().find(|key| key.name == key_name) else {
        return;
    };

    // Values that do not fit in `i64` are treated as "longer than the cap".
    let cooldown = i64::try_from(cooldown_mins)
        .map_or(MAX_COOLDOWN_MINS, |mins| mins.min(MAX_COOLDOWN_MINS));
    key.exhausted_until = Some(Local::now() + Duration::minutes(cooldown));
}
/// Reads `path` and parses its contents as a TOML credential pool.
///
/// Both failure modes — unreadable file and invalid TOML — collapse to `None`.
fn load_pool_from_path(path: &Path) -> Option<CredentialPool> {
    std::fs::read_to_string(path)
        .ok()
        .and_then(|raw| toml::from_str(&raw).ok())
}
/// Returns the index and secret of the first key, in declaration order, that
/// is ready at `now`; `None` when no key is ready.
fn select_fill_first(keys: &[KeyEntry], now: &DateTime<Local>) -> Option<(usize, String)> {
    for (position, entry) in keys.iter().enumerate() {
        if let Some(secret) = key_value_if_ready(entry, now) {
            return Some((position, secret));
        }
    }
    None
}
/// Scans the provider's keys once, starting at the stored cursor and wrapping
/// around, and returns the first key that is ready at `now`.
///
/// On success the cursor is advanced to the slot after the chosen key, so the
/// next call resumes from there. The cursor is left untouched when nothing is
/// ready or the pool is empty.
fn select_round_robin(provider: &ProviderPool, now: &DateTime<Local>) -> Option<(usize, String)> {
    let key_count = provider.keys.len();
    if key_count == 0 {
        return None;
    }

    let first = provider.next_index.get() % key_count;
    let (chosen, secret) = (0..key_count)
        .map(|step| (first + step) % key_count)
        .find_map(|slot| key_value_if_ready(&provider.keys[slot], now).map(|v| (slot, v)))?;

    provider.next_index.set((chosen + 1) % key_count);
    Some((chosen, secret))
}
/// Returns the ready key with the smallest usage count, preferring the
/// earliest key on ties.
///
/// A key with no matching entry in `usage_counts` is treated as unused (0).
fn select_least_used(
    keys: &[KeyEntry],
    usage_counts: &[u64],
    now: &DateTime<Local>,
) -> Option<(usize, String)> {
    keys.iter()
        .enumerate()
        .filter_map(|(slot, entry)| {
            let secret = key_value_if_ready(entry, now)?;
            let used = usage_counts.get(slot).copied().unwrap_or(0);
            Some((used, slot, secret))
        })
        // `min_by_key` keeps the first of several equal minima, which matches
        // replacing the current best only on a strictly smaller count.
        .min_by_key(|(used, _, _)| *used)
        .map(|(_, slot, secret)| (slot, secret))
}
/// Returns the secret stored in the key's environment variable, provided the
/// key is usable at `now`.
///
/// A key is not usable while its cooldown deadline is still strictly in the
/// future, or when the variable is unset, not valid Unicode, or empty.
fn key_value_if_ready(key: &KeyEntry, now: &DateTime<Local>) -> Option<String> {
    let cooling_down = matches!(&key.exhausted_until, Some(until) if until > now);
    if cooling_down {
        return None;
    }

    match std::env::var(&key.env) {
        Ok(secret) if !secret.is_empty() => Some(secret),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::{
        CredentialPool, KeyEntry, Strategy, get_credential, load_pool, load_pool_from_path,
        mark_exhausted,
    };
    use crate::paths::AidHomeGuard;
    use std::fs;
    use tempfile::TempDir;

    /// Two providers: a round-robin pool with two keys and a fill-first pool
    /// with one.
    const SAMPLE: &str = r#"
[codex]
strategy = "round_robin"
keys = [
{ name = "personal", env = "OPENAI_API_KEY" },
{ name = "team", env = "OPENAI_API_KEY_2" },
]
[gemini]
strategy = "fill_first"
keys = [
{ name = "default", env = "GEMINI_API_KEY" },
]
"#;

    /// Writes `SAMPLE` to a fresh temporary directory and loads it back.
    fn sample_pool() -> CredentialPool {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("credentials.toml");
        fs::write(&file, SAMPLE).unwrap();
        load_pool_from_path(&file).unwrap()
    }

    /// Sets each `(name, value)` pair in the process environment.
    fn set_env(vars: &[(&str, &str)]) {
        for &(name, value) in vars {
            // SAFETY: NOTE(review) — the environment is process-global and
            // tests run on parallel threads by default. This relies on every
            // test writing the same value for a given variable; confirm, or
            // serialize the tests that touch the environment.
            unsafe {
                std::env::set_var(name, value);
            }
        }
    }

    /// Builds the `(env name, secret)` result expected from `get_credential`.
    fn credential(env: &str, secret: &str) -> Option<(String, String)> {
        Some((env.to_string(), secret.to_string()))
    }

    #[test]
    fn load_pool_parses_sample_toml() {
        let pool = sample_pool();
        let codex = &pool.providers["codex"];

        assert_eq!(pool.providers.len(), 2);
        assert_eq!(codex.strategy, Strategy::RoundRobin);
        assert_eq!(codex.keys[0].name, "personal");
    }

    #[test]
    fn round_robin_returns_keys_in_order() {
        let pool = sample_pool();
        set_env(&[("OPENAI_API_KEY", "first"), ("OPENAI_API_KEY_2", "second")]);

        let picks: Vec<_> = (0..3).map(|_| get_credential(&pool, "codex")).collect();

        assert_eq!(picks[0], credential("OPENAI_API_KEY", "first"));
        assert_eq!(picks[1], credential("OPENAI_API_KEY_2", "second"));
        assert_eq!(picks[2], credential("OPENAI_API_KEY", "first"));
    }

    #[test]
    fn fill_first_returns_first_non_exhausted_key() {
        let mut pool = sample_pool();
        let backup = KeyEntry {
            name: "backup".to_string(),
            env: "GEMINI_API_KEY_BACKUP".to_string(),
            exhausted_until: None,
        };
        pool.providers.get_mut("gemini").unwrap().keys.push(backup);
        set_env(&[("GEMINI_API_KEY", "primary"), ("GEMINI_API_KEY_BACKUP", "backup")]);

        mark_exhausted(&mut pool, "gemini", "default", 5);

        assert_eq!(
            get_credential(&pool, "gemini"),
            credential("GEMINI_API_KEY_BACKUP", "backup")
        );
    }

    #[test]
    fn mark_exhausted_skips_exhausted_keys() {
        let mut pool = sample_pool();
        set_env(&[("OPENAI_API_KEY", "first"), ("OPENAI_API_KEY_2", "second")]);

        mark_exhausted(&mut pool, "codex", "personal", 5);

        assert_eq!(
            get_credential(&pool, "codex"),
            credential("OPENAI_API_KEY_2", "second")
        );
    }

    #[test]
    fn missing_file_returns_none() {
        let empty_home = TempDir::new().unwrap();
        let _guard = AidHomeGuard::set(empty_home.path());

        assert!(load_pool().is_none());
    }
}