1use std::collections::HashMap;
37use std::time::{SystemTime, UNIX_EPOCH};
38
39use serde::{Deserialize, Serialize};
40use tokio::sync::RwLock;
41use tracing::{debug, warn};
42
/// Seconds after its most recent rate-limit report during which a key is treated as
/// rate-limited by `KeyStats::is_rate_limited` (and therefore skipped when leasing,
/// unless every key of the endpoint is cooling down).
const RATE_LIMIT_COOLDOWN_SECS: u64 = 60;
45
/// Expands the secret stored under `env_var` into `(label, key)` pairs.
///
/// A value containing a comma is treated as a list of keys: each entry is trimmed,
/// blank entries are dropped, and each remaining key is labelled `ENV_VAR[i]`, where `i`
/// is the entry's position in the original list (so labels stay stable when some entries
/// are blank). Any other value is a single key labelled with `env_var` itself.
///
/// Surrounding whitespace is trimmed in both forms, and a blank value yields no keys, so a
/// trailing newline or an empty variable never turns into an unusable API key.
fn expand_multi_key(env_var: &str, raw: String) -> Vec<(String, String)> {
    if raw.contains(',') {
        raw.split(',')
            .enumerate()
            .filter_map(|(i, part)| {
                let part = part.trim();
                if part.is_empty() {
                    None
                } else {
                    Some((format!("{}[{}]", env_var, i), part.to_string()))
                }
            })
            .collect()
    } else {
        // Trim like the comma branch does; a blank value is not a usable key.
        let key = raw.trim();
        if key.is_empty() {
            Vec::new()
        } else {
            vec![(env_var.to_string(), key.to_string())]
        }
    }
}
65
/// Per-key usage counters kept by the pool and written/read by
/// `KeyPool::save_stats` / `KeyPool::load_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyStats {
    /// Label of the key: the env var name, or `NAME[i]` for entry `i` of a
    /// comma-separated value.
    pub env_var: String,
    /// Leases issued for this key (counted when the lease is handed out, before its
    /// outcome is known).
    pub total_requests: u64,
    /// Requests reported as successful.
    pub successes: u64,
    /// Requests reported as failed for a reason other than rate limiting.
    pub failures: u64,
    /// Requests reported as rate-limited.
    pub rate_limits: u64,
    /// Sum of the latencies reported with successful requests, in milliseconds.
    pub total_latency_ms: u64,
    /// Sum of input tokens reported with successful requests.
    pub total_input_tokens: u64,
    /// Sum of output tokens reported with successful requests.
    pub total_output_tokens: u64,
    /// Leases issued but not yet reported. Runtime-only: skipped by serde, so it is 0
    /// after deserialization.
    #[serde(skip)]
    pub active_requests: u64,
    /// Unix time (seconds) of the most recent rate-limit report; 0 means never.
    /// Defaults to 0 when absent from serialized data.
    #[serde(default)]
    pub last_rate_limit_at: u64,
    /// Unix time (seconds) of the most recent success report; 0 means never.
    /// Defaults to 0 when absent from serialized data.
    #[serde(default)]
    pub last_success_at: u64,
}
95
96impl KeyStats {
97 fn new(env_var: String) -> Self {
98 Self {
99 env_var,
100 total_requests: 0,
101 successes: 0,
102 failures: 0,
103 rate_limits: 0,
104 total_latency_ms: 0,
105 total_input_tokens: 0,
106 total_output_tokens: 0,
107 active_requests: 0,
108 last_rate_limit_at: 0,
109 last_success_at: 0,
110 }
111 }
112
113 pub fn avg_latency_ms(&self) -> f64 {
115 if self.successes == 0 {
116 return 0.0;
117 }
118 self.total_latency_ms as f64 / self.successes as f64
119 }
120
121 pub fn success_rate(&self) -> f64 {
123 let total = self.successes + self.failures;
124 if total == 0 {
125 return 0.5;
126 }
127 self.successes as f64 / total as f64
128 }
129
130 pub fn is_rate_limited(&self) -> bool {
132 if self.last_rate_limit_at == 0 {
133 return false;
134 }
135 let now = now_unix();
136 now.saturating_sub(self.last_rate_limit_at) < RATE_LIMIT_COOLDOWN_SECS
137 }
138
139 pub fn estimated_cost(&self, input_per_mtok: f64, output_per_mtok: f64) -> f64 {
141 let input_cost = (self.total_input_tokens as f64 / 1_000_000.0) * input_per_mtok;
142 let output_cost = (self.total_output_tokens as f64 / 1_000_000.0) * output_per_mtok;
143 input_cost + output_cost
144 }
145}
146
/// A key handed out by `KeyPool::lease`. Report its outcome through
/// `KeyPool::report_success` / `KeyPool::report_failure` using `env_var`.
///
/// `Debug` is implemented by hand so the secret `api_key` is never printed when a lease
/// is logged or formatted with `{:?}`.
#[derive(Clone)]
pub struct KeyLease {
    /// The secret API key to send with the request.
    pub api_key: String,
    /// Label identifying the key in the pool's stats (env var name or `NAME[i]`).
    pub env_var: String,
    /// Position of the key in its endpoint's key list when the lease was issued.
    pub(crate) index: usize,
}

impl std::fmt::Debug for KeyLease {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KeyLease")
            .field("api_key", &"<redacted>")
            .field("env_var", &self.env_var)
            .field("index", &self.index)
            .finish()
    }
}
158
/// Keys and usage statistics for a single endpoint.
struct EndpointKeys {
    /// `(label, api_key)` pairs in registration order; the label is the env var name or
    /// `NAME[i]` for a comma-separated entry.
    keys: Vec<(String, String)>,
    /// Usage statistics keyed by label. `KeyPool::load_stats` can also insert entries for
    /// labels that are not (yet) present in `keys`.
    stats: HashMap<String, KeyStats>,
    /// Round-robin cursor used by `lease` while no candidate key has been used yet.
    next_index: usize,
}
168
169impl EndpointKeys {
170 fn new(env_vars: Vec<String>) -> Self {
171 let mut keys = Vec::new();
172 let mut stats = HashMap::new();
173
174 for env_var in env_vars {
175 if let Some(raw) = car_secrets::resolve_env_or_keychain(&env_var) {
176 for (sub_var, key) in expand_multi_key(&env_var, raw) {
177 stats.insert(sub_var.clone(), KeyStats::new(sub_var.clone()));
178 keys.push((sub_var, key));
179 }
180 }
181 }
182
183 Self {
184 keys,
185 stats,
186 next_index: 0,
187 }
188 }
189
190 fn lease(&mut self) -> Option<KeyLease> {
200 if self.keys.is_empty() {
201 return None;
202 }
203
204 let mut candidates: Vec<(usize, f64)> = Vec::new();
206 let mut all_cold = true;
207
208 for (idx, (ref env_var, _)) in self.keys.iter().enumerate() {
209 let stats = self.stats.get(env_var);
210
211 if let Some(s) = stats {
213 if s.is_rate_limited() {
214 continue;
215 }
216 if s.total_requests > 0 {
217 all_cold = false;
218 }
219 }
220
221 let score = match stats {
227 Some(s) if s.total_requests > 0 => {
228 let total_tokens = s.total_input_tokens + s.total_output_tokens;
229 let completed = s.successes + s.failures + s.rate_limits;
230 if completed > 0 {
231 let avg_tokens_per_req = total_tokens as f64 / completed as f64;
232 let inflight_estimate = s.active_requests as f64 * avg_tokens_per_req;
233 total_tokens as f64 + inflight_estimate
234 } else {
235 s.active_requests as f64 * 1000.0
237 }
238 }
239 _ => 0.0, };
241
242 candidates.push((idx, score));
243 }
244
245 if all_cold && !candidates.is_empty() {
247 let start = self.next_index % candidates.len();
248 let (idx, _) = candidates[start];
249 self.next_index = start + 1;
250 return self.issue_lease(idx);
251 }
252
253 if candidates.is_empty() {
254 let mut best_idx = 0;
256 let mut oldest_rl = u64::MAX;
257 for (idx, (ref env_var, _)) in self.keys.iter().enumerate() {
258 if let Some(stats) = self.stats.get(env_var) {
259 if stats.last_rate_limit_at < oldest_rl {
260 oldest_rl = stats.last_rate_limit_at;
261 best_idx = idx;
262 }
263 }
264 }
265 let env_var = &self.keys[best_idx].0;
266 warn!(env_var = %env_var, "all keys rate-limited, using oldest-cooldown key");
267 return self.issue_lease(best_idx);
268 }
269
270 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
272 let (best_idx, _score) = candidates[0];
273 self.issue_lease(best_idx)
274 }
275
276 fn issue_lease(&mut self, idx: usize) -> Option<KeyLease> {
278 let (ref env_var, ref key) = self.keys[idx];
279 let env_var = env_var.clone();
280 let api_key = key.clone();
281
282 if let Some(stats) = self.stats.get_mut(&env_var) {
283 stats.active_requests += 1;
284 stats.total_requests += 1;
285 }
286
287 Some(KeyLease {
288 api_key,
289 env_var,
290 index: idx,
291 })
292 }
293
294 fn report_success(
295 &mut self,
296 env_var: &str,
297 latency_ms: u64,
298 input_tokens: u64,
299 output_tokens: u64,
300 ) {
301 if let Some(stats) = self.stats.get_mut(env_var) {
302 stats.successes += 1;
303 stats.active_requests = stats.active_requests.saturating_sub(1);
304 stats.total_latency_ms += latency_ms;
305 stats.total_input_tokens += input_tokens;
306 stats.total_output_tokens += output_tokens;
307 stats.last_success_at = now_unix();
308 }
309 }
310
311 fn report_failure(&mut self, env_var: &str, is_rate_limit: bool) {
312 if let Some(stats) = self.stats.get_mut(env_var) {
313 stats.active_requests = stats.active_requests.saturating_sub(1);
314 if is_rate_limit {
315 stats.rate_limits += 1;
316 stats.last_rate_limit_at = now_unix();
317 } else {
318 stats.failures += 1;
319 }
320 }
321 }
322}
323
/// Pool of API keys grouped by endpoint, with usage-aware lease selection.
/// All state sits behind a single tokio `RwLock`.
pub struct KeyPool {
    /// Per-endpoint key sets, keyed by the endpoint string passed to the pool's methods.
    endpoints: RwLock<HashMap<String, EndpointKeys>>,
}
329
330impl KeyPool {
331 pub fn new() -> Self {
332 Self {
333 endpoints: RwLock::new(HashMap::new()),
334 }
335 }
336
337 pub async fn register_endpoint(&self, endpoint: &str, env_vars: Vec<String>) {
340 let mut endpoints = self.endpoints.write().await;
341 let entry = endpoints
342 .entry(endpoint.to_string())
343 .or_insert_with(|| EndpointKeys::new(vec![]));
344
345 let existing_vars: std::collections::HashSet<String> =
346 entry.keys.iter().map(|(v, _)| v.clone()).collect();
347
348 let mut new_keys: Vec<(String, String)> = Vec::new();
349
350 for env_var in env_vars {
351 if existing_vars.contains(&env_var) {
355 continue;
356 }
357 if let Some(raw) = car_secrets::resolve_env_or_keychain(&env_var) {
358 for (sub_var, key) in expand_multi_key(&env_var, raw) {
359 if !existing_vars.contains(&sub_var) {
360 new_keys.push((sub_var, key));
361 }
362 }
363 }
364 }
365
366 for (var, key) in new_keys {
367 entry
368 .stats
369 .entry(var.clone())
370 .or_insert_with(|| KeyStats::new(var.clone()));
371 entry.keys.push((var, key));
372 }
373
374 debug!(
375 endpoint = %endpoint,
376 key_count = entry.keys.len(),
377 "registered endpoint keys"
378 );
379 }
380
381 pub async fn lease(&self, endpoint: &str) -> Option<KeyLease> {
383 let mut endpoints = self.endpoints.write().await;
384 endpoints.get_mut(endpoint)?.lease()
385 }
386
387 pub async fn lease_or_env(&self, endpoint: &str, fallback_env: &str) -> Option<KeyLease> {
392 if let Some(lease) = self.lease(endpoint).await {
394 return Some(lease);
395 }
396
397 if car_secrets::resolve_env_or_keychain(fallback_env).is_some() {
400 self.register_endpoint(endpoint, vec![fallback_env.to_string()])
401 .await;
402 return self.lease(endpoint).await;
403 }
404
405 None
406 }
407
408 pub async fn report_success(
410 &self,
411 endpoint: &str,
412 env_var: &str,
413 latency_ms: u64,
414 input_tokens: u64,
415 output_tokens: u64,
416 ) {
417 let mut endpoints = self.endpoints.write().await;
418 if let Some(ep) = endpoints.get_mut(endpoint) {
419 ep.report_success(env_var, latency_ms, input_tokens, output_tokens);
420 }
421 }
422
423 pub async fn report_failure(&self, endpoint: &str, env_var: &str, is_rate_limit: bool) {
425 let mut endpoints = self.endpoints.write().await;
426 if let Some(ep) = endpoints.get_mut(endpoint) {
427 ep.report_failure(env_var, is_rate_limit);
428 }
429 }
430
431 pub async fn endpoint_stats(&self, endpoint: &str) -> Vec<KeyStats> {
433 let endpoints = self.endpoints.read().await;
434 endpoints
435 .get(endpoint)
436 .map(|ep| ep.stats.values().cloned().collect())
437 .unwrap_or_default()
438 }
439
440 pub async fn all_stats(&self) -> HashMap<String, Vec<KeyStats>> {
442 let endpoints = self.endpoints.read().await;
443 endpoints
444 .iter()
445 .map(|(ep, keys)| (ep.clone(), keys.stats.values().cloned().collect()))
446 .collect()
447 }
448
449 pub async fn total_keys(&self) -> usize {
451 let endpoints = self.endpoints.read().await;
452 endpoints.values().map(|ep| ep.keys.len()).sum()
453 }
454
455 pub async fn available_keys(&self, endpoint: &str) -> usize {
457 let endpoints = self.endpoints.read().await;
458 endpoints
459 .get(endpoint)
460 .map(|ep| {
461 ep.keys
462 .iter()
463 .filter(|(env_var, _)| {
464 ep.stats
465 .get(env_var)
466 .map(|s| !s.is_rate_limited())
467 .unwrap_or(true)
468 })
469 .count()
470 })
471 .unwrap_or(0)
472 }
473
474 pub async fn save_stats(&self, path: &std::path::Path) -> Result<(), std::io::Error> {
476 let endpoints = self.endpoints.read().await;
477 let stats: HashMap<String, Vec<KeyStats>> = endpoints
478 .iter()
479 .map(|(ep, keys)| (ep.clone(), keys.stats.values().cloned().collect()))
480 .collect();
481 let json = serde_json::to_string_pretty(&stats)
482 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
483 if let Some(parent) = path.parent() {
484 std::fs::create_dir_all(parent)?;
485 }
486 std::fs::write(path, json)
487 }
488
489 pub async fn load_stats(&self, path: &std::path::Path) -> Result<usize, std::io::Error> {
491 if !path.exists() {
492 return Ok(0);
493 }
494 let json = std::fs::read_to_string(path)?;
495 let saved: HashMap<String, Vec<KeyStats>> = serde_json::from_str(&json)
496 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
497
498 let mut endpoints = self.endpoints.write().await;
499 let mut count = 0;
500 for (endpoint, stats_list) in saved {
501 let ep = endpoints
502 .entry(endpoint)
503 .or_insert_with(|| EndpointKeys::new(vec![]));
504 for stats in stats_list {
505 ep.stats.insert(stats.env_var.clone(), stats);
506 count += 1;
507 }
508 }
509 Ok(count)
510 }
511}
512
513impl Default for KeyPool {
514 fn default() -> Self {
515 Self::new()
516 }
517}
518
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Returns 0 if the system clock reports a time before the epoch.
fn now_unix() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint used by most tests; every test builds its own pool, so they don't interact.
    const ENDPOINT: &str = "https://api.test.com";

    /// A single key is leased, reported and reflected in the stats.
    #[tokio::test]
    async fn single_key_round_trip() {
        std::env::set_var("TEST_KEY_POOL_1", "sk-test-111");

        let pool = KeyPool::new();
        pool.register_endpoint(ENDPOINT, vec!["TEST_KEY_POOL_1".into()])
            .await;

        let lease = pool.lease(ENDPOINT).await.unwrap();
        assert_eq!(lease.api_key, "sk-test-111");
        assert_eq!(lease.env_var, "TEST_KEY_POOL_1");

        pool.report_success(ENDPOINT, &lease.env_var, 500, 100, 50)
            .await;

        let stats = pool.endpoint_stats(ENDPOINT).await;
        assert_eq!(stats.len(), 1);
        let only = &stats[0];
        assert_eq!(only.successes, 1);
        assert_eq!(only.total_latency_ms, 500);

        std::env::remove_var("TEST_KEY_POOL_1");
    }

    /// Two fresh keys are used in turn rather than reusing the first one.
    #[tokio::test]
    async fn multi_key_cold_start_round_robin() {
        std::env::set_var("TEST_KEY_POOL_A", "sk-aaa");
        std::env::set_var("TEST_KEY_POOL_B", "sk-bbb");

        let pool = KeyPool::new();
        pool.register_endpoint(
            ENDPOINT,
            vec!["TEST_KEY_POOL_A".into(), "TEST_KEY_POOL_B".into()],
        )
        .await;

        let first = pool.lease(ENDPOINT).await.unwrap();
        pool.report_success(ENDPOINT, &first.env_var, 100, 10, 5)
            .await;

        let second = pool.lease(ENDPOINT).await.unwrap();
        pool.report_success(ENDPOINT, &second.env_var, 100, 10, 5)
            .await;

        assert_ne!(first.env_var, second.env_var);

        std::env::remove_var("TEST_KEY_POOL_A");
        std::env::remove_var("TEST_KEY_POOL_B");
    }

    /// Selection follows token usage: the key with fewer tokens is preferred, and the
    /// preference flips once the other key has consumed more.
    #[tokio::test]
    async fn token_aware_prefers_least_used() {
        std::env::set_var("TEST_KEY_POOL_TA1", "sk-ta1");
        std::env::set_var("TEST_KEY_POOL_TA2", "sk-ta2");

        let pool = KeyPool::new();
        pool.register_endpoint(
            ENDPOINT,
            vec!["TEST_KEY_POOL_TA1".into(), "TEST_KEY_POOL_TA2".into()],
        )
        .await;

        // Heavy first request (1500 tokens).
        let heavy = pool.lease(ENDPOINT).await.unwrap();
        pool.report_success(ENDPOINT, &heavy.env_var, 100, 1000, 500)
            .await;

        // Light second request (150 tokens) goes to the other key.
        let light = pool.lease(ENDPOINT).await.unwrap();
        pool.report_success(ENDPOINT, &light.env_var, 100, 100, 50)
            .await;

        let third = pool.lease(ENDPOINT).await.unwrap();
        assert_eq!(
            third.env_var, light.env_var,
            "should pick the key with fewer tokens"
        );

        // Push the light key well past the heavy one (10000 tokens).
        pool.report_success(ENDPOINT, &third.env_var, 100, 5000, 5000)
            .await;

        let fourth = pool.lease(ENDPOINT).await.unwrap();
        assert_eq!(
            fourth.env_var, heavy.env_var,
            "should pick key with fewer tokens after rebalance"
        );

        std::env::remove_var("TEST_KEY_POOL_TA1");
        std::env::remove_var("TEST_KEY_POOL_TA2");
    }

    /// A comma-separated value registers one key per entry, handed out in order.
    #[tokio::test]
    async fn comma_separated_keys() {
        std::env::set_var("TEST_KEY_POOL_CSV", "sk-one, sk-two, sk-three");

        let pool = KeyPool::new();
        pool.register_endpoint(ENDPOINT, vec!["TEST_KEY_POOL_CSV".into()])
            .await;

        assert_eq!(pool.total_keys().await, 3);

        for expected in ["sk-one", "sk-two", "sk-three"] {
            let lease = pool.lease(ENDPOINT).await.unwrap();
            assert_eq!(lease.api_key, expected);
        }

        std::env::remove_var("TEST_KEY_POOL_CSV");
    }

    /// A key that just reported a rate limit is skipped on the next lease.
    #[tokio::test]
    async fn rate_limited_key_skipped() {
        std::env::set_var("TEST_KEY_POOL_RL1", "sk-rl1");
        std::env::set_var("TEST_KEY_POOL_RL2", "sk-rl2");

        let pool = KeyPool::new();
        pool.register_endpoint(
            ENDPOINT,
            vec!["TEST_KEY_POOL_RL1".into(), "TEST_KEY_POOL_RL2".into()],
        )
        .await;

        let limited = pool.lease(ENDPOINT).await.unwrap();
        pool.report_failure(ENDPOINT, &limited.env_var, true)
            .await;

        let next = pool.lease(ENDPOINT).await.unwrap();
        assert_ne!(limited.env_var, next.env_var);

        std::env::remove_var("TEST_KEY_POOL_RL1");
        std::env::remove_var("TEST_KEY_POOL_RL2");
    }

    /// An unknown endpoint falls back to the given env var and registers it.
    #[tokio::test]
    async fn lease_or_env_fallback() {
        std::env::set_var("TEST_KEY_POOL_FB", "sk-fallback");

        let pool = KeyPool::new();

        let lease = pool
            .lease_or_env("https://api.new.com", "TEST_KEY_POOL_FB")
            .await
            .unwrap();
        assert_eq!(lease.api_key, "sk-fallback");
        assert_eq!(pool.total_keys().await, 1);

        std::env::remove_var("TEST_KEY_POOL_FB");
    }

    /// Leasing from an endpoint that was never registered yields nothing.
    #[tokio::test]
    async fn no_keys_returns_none() {
        let pool = KeyPool::new();
        assert!(pool.lease("https://nonexistent.com").await.is_none());
    }
}