hermes_agent_cli_core/
credential_pool.rs1use std::collections::HashMap;
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
/// A single API credential together with its rate-limit bookkeeping.
#[derive(Debug, Clone)]
pub struct CredentialEntry {
    /// API key string for this credential.
    pub api_key: String,
    /// Optional base URL stored alongside the key.
    pub base_url: Option<String>,
    /// Moment of the most recent rate-limit report; `None` when no report is
    /// outstanding (fresh entry, or a success has been recorded since).
    pub rate_limited_at: Option<Instant>,
    /// How long after `rate_limited_at` the credential stays unavailable.
    pub cooldown: Duration,
    /// Number of rate-limit reports recorded since the last success.
    pub failures: u32,
}

impl CredentialEntry {
    /// Cooldown used for new entries and for rate-limit reports that carry
    /// no explicit retry delay.
    const DEFAULT_COOLDOWN: Duration = Duration::from_secs(60);

    /// Creates an entry that is immediately available, with the default
    /// 60-second cooldown and a zero failure count.
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        Self {
            api_key,
            base_url,
            rate_limited_at: None,
            cooldown: Self::DEFAULT_COOLDOWN,
            failures: 0,
        }
    }

    /// Returns `true` when the entry has no outstanding rate-limit report, or
    /// when at least `cooldown` has elapsed since that report.
    pub fn is_available(&self) -> bool {
        match self.rate_limited_at {
            None => true,
            Some(since) => since.elapsed() >= self.cooldown,
        }
    }

    /// Records a rate-limit report at the current instant.
    ///
    /// `retry_after` is interpreted as whole seconds; when it is `None` the
    /// default 60-second cooldown is applied.
    pub fn mark_rate_limited(&mut self, retry_after: Option<u64>) {
        self.rate_limited_at = Some(Instant::now());
        self.cooldown = retry_after.map_or(Self::DEFAULT_COOLDOWN, Duration::from_secs);
        // Saturate rather than overflow: a plain `+= 1` panics in builds with
        // overflow checks (and wraps to 0 otherwise) once the counter is full.
        self.failures = self.failures.saturating_add(1);
    }

    /// Clears the outstanding rate-limit report and resets the failure count.
    ///
    /// `cooldown` is left at its last value; it is irrelevant while
    /// `rate_limited_at` is `None`.
    pub fn mark_success(&mut self) {
        self.rate_limited_at = None;
        self.failures = 0;
    }
}
55
/// Credentials grouped by provider name, plus a per-provider cursor used by
/// `get` to rotate through them. Each map sits behind its own `Mutex`, so the
/// methods take `&self`.
pub struct CredentialPool {
    /// Credentials for each provider, keyed by provider name, in insertion order.
    pool: Mutex<HashMap<String, Vec<CredentialEntry>>>,
    /// Per-provider position at which the next `get` starts scanning.
    index: Mutex<HashMap<String, usize>>,
}
63
64impl Default for CredentialPool {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl CredentialPool {
71 pub fn new() -> Self {
72 Self { pool: Mutex::new(HashMap::new()), index: Mutex::new(HashMap::new()) }
73 }
74
75 pub fn from_auth_store(auth: &crate::auth::AuthStore) -> Self {
77 let pool = Self::new();
78 for cred in &auth.credentials {
79 pool.add(&cred.provider, cred.api_key.clone(), cred.base_url.clone());
80 }
81 pool
82 }
83
84 pub fn add(&self, provider: &str, api_key: String, base_url: Option<String>) {
86 let mut pool = self.pool.lock().unwrap();
87 let entries = pool.entry(provider.to_string()).or_default();
88
89 if entries.iter().any(|e| e.api_key == api_key) {
91 return;
92 }
93
94 entries.push(CredentialEntry::new(api_key, base_url));
95 }
96
97 pub fn remove(&self, provider: &str) {
99 let mut pool = self.pool.lock().unwrap();
100 pool.remove(provider);
101 self.index.lock().unwrap().remove(provider);
102 }
103
104 pub fn get(&self, provider: &str) -> Option<CredentialEntry> {
108 let pool = self.pool.lock().unwrap();
109 let entries = pool.get(provider)?;
110
111 if entries.is_empty() {
112 return None;
113 }
114
115 let mut index = self.index.lock().unwrap();
116 let start = *index.entry(provider.to_string()).or_insert(0);
117 let len = entries.len();
118
119 for i in 0..len {
121 let idx = (start + i) % len;
122 if entries[idx].is_available() {
123 *index.get_mut(provider).unwrap() = (idx + 1) % len;
124 return Some(entries[idx].clone());
125 }
126 }
127
128 let best = entries.iter().min_by_key(|e| e.rate_limited_at.map(|t| t.elapsed())).unwrap();
130 Some(best.clone())
131 }
132
133 pub fn report_rate_limit(&self, provider: &str, api_key: &str, retry_after: Option<u64>) {
135 let mut pool = self.pool.lock().unwrap();
136 if let Some(entries) = pool.get_mut(provider) {
137 if let Some(entry) = entries.iter_mut().find(|e| e.api_key == api_key) {
138 entry.mark_rate_limited(retry_after);
139 }
140 }
141 }
142
143 pub fn report_success(&self, provider: &str, api_key: &str) {
145 let mut pool = self.pool.lock().unwrap();
146 if let Some(entries) = pool.get_mut(provider) {
147 if let Some(entry) = entries.iter_mut().find(|e| e.api_key == api_key) {
148 entry.mark_success();
149 }
150 }
151 }
152
153 pub fn count(&self, provider: &str) -> usize {
155 self.pool.lock().unwrap().get(provider).map(|v| v.len()).unwrap_or(0)
156 }
157
158 pub fn available_count(&self, provider: &str) -> usize {
160 self.pool
161 .lock()
162 .unwrap()
163 .get(provider)
164 .map(|v| v.iter().filter(|e| e.is_available()).count())
165 .unwrap_or(0)
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider name used by the pool tests.
    const PROVIDER: &str = "openai";

    /// Builds a pool whose `PROVIDER` bucket is fed `keys` in the given order.
    fn pool_with(keys: &[&str]) -> CredentialPool {
        let pool = CredentialPool::new();
        for &key in keys {
            pool.add(PROVIDER, key.to_string(), None);
        }
        pool
    }

    // The first `get` on a fresh pool hands out the first key added.
    #[test]
    fn test_pool_add_and_get() {
        let pool = pool_with(&["key1", "key2"]);

        let picked = pool.get(PROVIDER).unwrap();
        assert_eq!(picked.api_key, "key1");
    }

    // Consecutive `get` calls rotate to a different key.
    #[test]
    fn test_pool_round_robin() {
        let pool = pool_with(&["key1", "key2"]);

        let earlier = pool.get(PROVIDER).unwrap();
        let later = pool.get(PROVIDER).unwrap();
        assert_ne!(earlier.api_key, later.api_key);
    }

    // Adding the same key twice stores it once.
    #[test]
    fn test_pool_no_duplicates() {
        let pool = pool_with(&["key1", "key1"]);
        assert_eq!(pool.count(PROVIDER), 1);
    }

    // A rate-limited key is skipped in favour of an available one.
    #[test]
    fn test_pool_rate_limit_cooldown() {
        let pool = pool_with(&["key1", "key2"]);
        pool.report_rate_limit(PROVIDER, "key1", None);

        let picked = pool.get(PROVIDER).unwrap();
        assert_eq!(picked.api_key, "key2");
    }

    // Reporting success makes a previously limited key available again.
    #[test]
    fn test_pool_report_success() {
        let pool = pool_with(&["key1"]);
        pool.report_rate_limit(PROVIDER, "key1", None);
        pool.report_success(PROVIDER, "key1");

        let picked = pool.get(PROVIDER).unwrap();
        assert_eq!(picked.api_key, "key1");
        assert!(picked.is_available());
    }

    // Unknown providers yield no credential.
    #[test]
    fn test_pool_empty_provider() {
        let pool = CredentialPool::new();
        assert!(pool.get("nonexistent").is_none());
    }

    // Removing a provider discards its credentials.
    #[test]
    fn test_pool_remove() {
        let pool = pool_with(&["key1"]);
        pool.remove(PROVIDER);
        assert!(pool.get(PROVIDER).is_none());
    }

    // `available_count` excludes limited keys; `count` does not.
    #[test]
    fn test_pool_available_count() {
        let pool = pool_with(&["key1", "key2"]);
        pool.report_rate_limit(PROVIDER, "key1", None);

        assert_eq!(pool.available_count(PROVIDER), 1);
        assert_eq!(pool.count(PROVIDER), 2);
    }

    // An entry limited 120 s ago with a 60 s cooldown is available again.
    #[test]
    fn test_credential_entry_cooldown_expiry() {
        let entry = CredentialEntry {
            rate_limited_at: Some(Instant::now() - Duration::from_secs(120)),
            cooldown: Duration::from_secs(60),
            ..CredentialEntry::new("key".to_string(), None)
        };
        assert!(entry.is_available());
    }
}