// hermes_agent_cli_core/credential_pool.rs

//! Credential pool with automatic failover on rate-limit errors.
//!
//! Supports multiple API keys per provider. Keys are handed out round-robin;
//! when a caller reports a 429/rate-limit response for a key, that key enters
//! a temporary cooldown and is skipped by later selections until it expires.

6use std::collections::HashMap;
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// A single credential entry with health tracking.
#[derive(Debug, Clone)]
pub struct CredentialEntry {
    /// The API key itself.
    pub api_key: String,
    /// Optional base URL stored alongside this key.
    pub base_url: Option<String>,
    /// When this key was last rate-limited. `None` if healthy.
    pub rate_limited_at: Option<Instant>,
    /// How long after `rate_limited_at` the key stays unavailable.
    pub cooldown: Duration,
    /// Number of consecutive rate-limit reports since the last success.
    /// Saturates at `u32::MAX` instead of overflowing.
    pub failures: u32,
}

impl CredentialEntry {
    /// Cooldown used when a rate-limit report carries no `retry_after` value.
    pub const DEFAULT_COOLDOWN: Duration = Duration::from_secs(60);

    /// Create a healthy entry (not rate-limited, zero failures) with the
    /// default cooldown.
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            api_key,
            base_url,
            rate_limited_at: None,
            cooldown: Self::DEFAULT_COOLDOWN,
            failures: 0,
        }
    }

    /// Whether this credential is available (not in cooldown).
    pub fn is_available(&self) -> bool {
        match self.rate_limited_at {
            None => true,
            Some(t) => t.elapsed() >= self.cooldown,
        }
    }

    /// Time left before this credential leaves cooldown; `Duration::ZERO`
    /// when it is already available.
    pub fn remaining_cooldown(&self) -> Duration {
        self.rate_limited_at
            .map_or(Duration::ZERO, |t| self.cooldown.saturating_sub(t.elapsed()))
    }

    /// Mark this credential as rate-limited, starting a cooldown now.
    ///
    /// `retry_after` is the cooldown in whole seconds (presumably taken from a
    /// `Retry-After` hint — confirm at call sites); `None` falls back to
    /// [`Self::DEFAULT_COOLDOWN`]. `Some(0)` leaves the key immediately
    /// available again.
    pub fn mark_rate_limited(&mut self, retry_after: Option<u64>) {
        self.rate_limited_at = Some(Instant::now());
        self.cooldown = retry_after.map_or(Self::DEFAULT_COOLDOWN, Duration::from_secs);
        // Saturate so a repeatedly throttled key can never overflow the
        // counter (a plain `+= 1` panics on overflow in debug builds).
        self.failures = self.failures.saturating_add(1);
    }

    /// Mark this credential as successfully used: clears the rate-limit
    /// timestamp and the failure count. `cooldown` is left unchanged; it has
    /// no effect while `rate_limited_at` is `None`.
    pub fn mark_success(&mut self) {
        self.rate_limited_at = None;
        self.failures = 0;
    }
}

56/// Pool of credentials per provider with round-robin selection and failover.
57pub struct CredentialPool {
58    /// provider_name → list of credentials
59    pool: Mutex<HashMap<String, Vec<CredentialEntry>>>,
60    /// provider_name → current index for round-robin
61    index: Mutex<HashMap<String, usize>>,
62}
63
64impl Default for CredentialPool {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl CredentialPool {
71    pub fn new() -> Self {
72        Self { pool: Mutex::new(HashMap::new()), index: Mutex::new(HashMap::new()) }
73    }
74
75    /// Create a pool pre-loaded from an AuthStore's credentials.
76    pub fn from_auth_store(auth: &crate::auth::AuthStore) -> Self {
77        let pool = Self::new();
78        for cred in &auth.credentials {
79            pool.add(&cred.provider, cred.api_key.clone(), cred.base_url.clone());
80        }
81        pool
82    }
83
84    /// Add a credential for a provider.
85    pub fn add(&self, provider: &str, api_key: String, base_url: Option<String>) {
86        let mut pool = self.pool.lock().unwrap();
87        let entries = pool.entry(provider.to_string()).or_default();
88
89        // Don't add duplicates
90        if entries.iter().any(|e| e.api_key == api_key) {
91            return;
92        }
93
94        entries.push(CredentialEntry::new(api_key, base_url));
95    }
96
97    /// Remove all credentials for a provider.
98    pub fn remove(&self, provider: &str) {
99        let mut pool = self.pool.lock().unwrap();
100        pool.remove(provider);
101        self.index.lock().unwrap().remove(provider);
102    }
103
104    /// Get the next available credential for a provider using round-robin.
105    ///
106    /// Returns `None` if no credentials are registered or all are in cooldown.
107    pub fn get(&self, provider: &str) -> Option<CredentialEntry> {
108        let pool = self.pool.lock().unwrap();
109        let entries = pool.get(provider)?;
110
111        if entries.is_empty() {
112            return None;
113        }
114
115        let mut index = self.index.lock().unwrap();
116        let start = *index.entry(provider.to_string()).or_insert(0);
117        let len = entries.len();
118
119        // Try each credential starting from current index
120        for i in 0..len {
121            let idx = (start + i) % len;
122            if entries[idx].is_available() {
123                *index.get_mut(provider).unwrap() = (idx + 1) % len;
124                return Some(entries[idx].clone());
125            }
126        }
127
128        // All in cooldown — return the one closest to expiry as last resort
129        let best = entries.iter().min_by_key(|e| e.rate_limited_at.map(|t| t.elapsed())).unwrap();
130        Some(best.clone())
131    }
132
133    /// Report that a credential hit a rate limit.
134    pub fn report_rate_limit(&self, provider: &str, api_key: &str, retry_after: Option<u64>) {
135        let mut pool = self.pool.lock().unwrap();
136        if let Some(entries) = pool.get_mut(provider) {
137            if let Some(entry) = entries.iter_mut().find(|e| e.api_key == api_key) {
138                entry.mark_rate_limited(retry_after);
139            }
140        }
141    }
142
143    /// Report that a credential was used successfully.
144    pub fn report_success(&self, provider: &str, api_key: &str) {
145        let mut pool = self.pool.lock().unwrap();
146        if let Some(entries) = pool.get_mut(provider) {
147            if let Some(entry) = entries.iter_mut().find(|e| e.api_key == api_key) {
148                entry.mark_success();
149            }
150        }
151    }
152
153    /// Get the number of credentials for a provider.
154    pub fn count(&self, provider: &str) -> usize {
155        self.pool.lock().unwrap().get(provider).map(|v| v.len()).unwrap_or(0)
156    }
157
158    /// Get the number of available (non-cooldown) credentials for a provider.
159    pub fn available_count(&self, provider: &str) -> usize {
160        self.pool
161            .lock()
162            .unwrap()
163            .get(provider)
164            .map(|v| v.iter().filter(|e| e.is_available()).count())
165            .unwrap_or(0)
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_add_and_get() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.add("openai", "key2".to_string(), None);

        let cred = pool.get("openai").unwrap();
        assert_eq!(cred.api_key, "key1");
    }

    #[test]
    fn test_pool_round_robin() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.add("openai", "key2".to_string(), None);

        let first = pool.get("openai").unwrap();
        let second = pool.get("openai").unwrap();
        assert_ne!(first.api_key, second.api_key);
    }

    #[test]
    fn test_pool_no_duplicates() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.add("openai", "key1".to_string(), None);
        assert_eq!(pool.count("openai"), 1);
    }

    #[test]
    fn test_pool_rate_limit_cooldown() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.add("openai", "key2".to_string(), None);

        pool.report_rate_limit("openai", "key1", None);

        // key1 is in cooldown, should get key2
        let cred = pool.get("openai").unwrap();
        assert_eq!(cred.api_key, "key2");
    }

    #[test]
    fn test_pool_report_success() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.report_rate_limit("openai", "key1", None);
        pool.report_success("openai", "key1");

        let cred = pool.get("openai").unwrap();
        assert_eq!(cred.api_key, "key1");
        assert!(cred.is_available());
    }

    #[test]
    fn test_pool_empty_provider() {
        let pool = CredentialPool::new();
        assert!(pool.get("nonexistent").is_none());
    }

    #[test]
    fn test_pool_remove() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.remove("openai");
        assert!(pool.get("openai").is_none());
    }

    #[test]
    fn test_pool_available_count() {
        let pool = CredentialPool::new();
        pool.add("openai", "key1".to_string(), None);
        pool.add("openai", "key2".to_string(), None);
        pool.report_rate_limit("openai", "key1", None);

        assert_eq!(pool.available_count("openai"), 1);
        assert_eq!(pool.count("openai"), 2);
    }

    #[test]
    fn test_credential_entry_cooldown_expiry() {
        let mut entry = CredentialEntry::new("key".to_string(), None);
        // A zero-length cooldown has already expired the moment it starts.
        // This avoids `Instant::now() - Duration`, which may panic when the
        // result predates the clock's earliest representable instant (see
        // `Instant::checked_sub`).
        entry.rate_limited_at = Some(Instant::now());
        entry.cooldown = Duration::ZERO;
        assert!(entry.is_available());

        // A cooldown that has not elapsed yet keeps the key unavailable.
        entry.cooldown = Duration::from_secs(60);
        assert!(!entry.is_available());
    }
}