use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Seconds a key is kept out of rotation after it reports a rate limit.
const RATE_LIMIT_COOLDOWN_SECS: u64 = 60;
/// Expands a resolved secret value into one or more `(env_var_name, key)` pairs.
///
/// A comma-separated value (`"k1,k2"`) fans out into multiple keys whose
/// synthetic names carry the original position (`VAR[0]`, `VAR[2]`, ...), so a
/// key keeps a stable identity even when an earlier entry is removed.
///
/// Entries are trimmed and blank entries are skipped in BOTH paths: the
/// single-key path previously passed `raw` through untrimmed and would emit an
/// empty key for a blank value, which was inconsistent with the CSV path.
fn expand_multi_key(env_var: &str, raw: String) -> Vec<(String, String)> {
    if raw.contains(',') {
        raw.split(',')
            .enumerate()
            .filter_map(|(i, k)| {
                let k = k.trim();
                // Skip blanks but keep the original index in the synthetic name.
                (!k.is_empty()).then(|| (format!("{}[{}]", env_var, i), k.to_string()))
            })
            .collect()
    } else {
        // Normalize the single-key path like the CSV path: trim whitespace
        // (e.g. a trailing newline from a keychain dump) and drop blanks so
        // an empty value never becomes a leasable "key".
        let key = raw.trim();
        if key.is_empty() {
            Vec::new()
        } else {
            vec![(env_var.to_string(), key.to_string())]
        }
    }
}
/// Per-key usage counters; serialized to JSON for persistence across restarts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KeyStats {
    /// Env-var name identifying this key (may be a synthetic `VAR[i]` name
    /// for comma-separated multi-key values).
    pub env_var: String,
    /// Leases issued — incremented at lease time, not at completion.
    pub total_requests: u64,
    /// Requests reported successful.
    pub successes: u64,
    /// Non-rate-limit failures; rate limits are counted separately below.
    pub failures: u64,
    /// Rate-limit responses observed for this key.
    pub rate_limits: u64,
    /// Sum of per-success latencies; divide by `successes` for the mean.
    pub total_latency_ms: u64,
    pub total_input_tokens: u64,
    pub total_output_tokens: u64,
    /// In-flight leases; transient state, so never persisted.
    #[serde(skip)]
    pub active_requests: u64,
    /// Unix seconds of the most recent rate-limit response (0 = never).
    #[serde(default)]
    pub last_rate_limit_at: u64,
    /// Unix seconds of the most recent success (0 = never).
    #[serde(default)]
    pub last_success_at: u64,
}
impl KeyStats {
    /// Fresh, zeroed stats for the key identified by `env_var`.
    fn new(env_var: String) -> Self {
        Self {
            env_var,
            total_requests: 0,
            successes: 0,
            failures: 0,
            rate_limits: 0,
            total_latency_ms: 0,
            total_input_tokens: 0,
            total_output_tokens: 0,
            active_requests: 0,
            last_rate_limit_at: 0,
            last_success_at: 0,
        }
    }

    /// Mean latency across successful requests; 0.0 before any success.
    pub fn avg_latency_ms(&self) -> f64 {
        match self.successes {
            0 => 0.0,
            n => self.total_latency_ms as f64 / n as f64,
        }
    }

    /// Fraction of completed (success + failure) requests that succeeded.
    /// Returns a neutral 0.5 when there is no data yet.
    pub fn success_rate(&self) -> f64 {
        match self.successes + self.failures {
            0 => 0.5,
            total => self.successes as f64 / total as f64,
        }
    }

    /// True while the key is still inside the cooldown window started by its
    /// most recent rate-limit response; a key that never rate-limited
    /// (`last_rate_limit_at == 0`) is never considered limited.
    pub fn is_rate_limited(&self) -> bool {
        self.last_rate_limit_at != 0
            && now_unix().saturating_sub(self.last_rate_limit_at) < RATE_LIMIT_COOLDOWN_SECS
    }

    /// Dollar estimate from accumulated token counts and per-million-token prices.
    pub fn estimated_cost(&self, input_per_mtok: f64, output_per_mtok: f64) -> f64 {
        let mtok = 1_000_000.0;
        (self.total_input_tokens as f64 / mtok) * input_per_mtok
            + (self.total_output_tokens as f64 / mtok) * output_per_mtok
    }
}
/// A checked-out API key. There is no `Drop` bookkeeping: the caller must
/// report the outcome via the pool's `report_success`/`report_failure` so the
/// key's `active_requests` counter is decremented again.
#[derive(Debug, Clone)]
pub struct KeyLease {
    /// The secret key material to send with the request.
    pub api_key: String,
    /// Stats identity of the key (possibly a synthetic `VAR[i]` name).
    pub env_var: String,
    /// Position in the endpoint's key list at lease time.
    #[allow(dead_code)]
    pub(crate) index: usize,
}
/// Key material and per-key stats for a single endpoint.
struct EndpointKeys {
    /// `(env_var_name, api_key)` pairs, in registration order.
    keys: Vec<(String, String)>,
    /// Stats keyed by env-var name; may be a superset of `keys` after
    /// `load_stats` restores entries for keys not (yet) registered.
    stats: HashMap<String, KeyStats>,
    /// Rotation cursor for the cold-start round-robin in `lease`.
    next_index: usize,
}
impl EndpointKeys {
    /// Builds the key set by resolving each env var (or keychain entry);
    /// vars that do not resolve are silently skipped. Comma-separated values
    /// fan out into multiple keys via `expand_multi_key`, each with its own
    /// zeroed stats entry.
    fn new(env_vars: Vec<String>) -> Self {
        let mut keys = Vec::new();
        let mut stats = HashMap::new();
        for env_var in env_vars {
            if let Some(raw) = car_secrets::resolve_env_or_keychain(&env_var) {
                for (sub_var, key) in expand_multi_key(&env_var, raw) {
                    stats.insert(sub_var.clone(), KeyStats::new(sub_var.clone()));
                    keys.push((sub_var, key));
                }
            }
        }
        Self {
            keys,
            stats,
            next_index: 0,
        }
    }

    /// Picks a key for the next request. Returns `None` only when no keys
    /// are registered.
    ///
    /// Strategy, in order:
    /// 1. keys inside their rate-limit cooldown are excluded from candidacy;
    /// 2. if no candidate has ever been leased ("cold start"), rotate
    ///    round-robin via `next_index`;
    /// 3. otherwise pick the candidate with the lowest estimated token load
    ///    (observed tokens plus a projection for in-flight requests);
    /// 4. if EVERY key is rate-limited, fall back to the one whose rate
    ///    limit is oldest — it will recover first.
    fn lease(&mut self) -> Option<KeyLease> {
        if self.keys.is_empty() {
            return None;
        }
        let mut candidates: Vec<(usize, f64)> = Vec::new();
        let mut all_cold = true;
        for (idx, (ref env_var, _)) in self.keys.iter().enumerate() {
            let stats = self.stats.get(env_var);
            if let Some(s) = stats {
                if s.is_rate_limited() {
                    // Cooling down: not a candidate at all.
                    continue;
                }
                if s.total_requests > 0 {
                    all_cold = false;
                }
            }
            let score = match stats {
                Some(s) if s.total_requests > 0 => {
                    let total_tokens = s.total_input_tokens + s.total_output_tokens;
                    let completed = s.successes + s.failures + s.rate_limits;
                    if completed > 0 {
                        // Project in-flight load from this key's historical
                        // average tokens per completed request.
                        let avg_tokens_per_req = total_tokens as f64 / completed as f64;
                        let inflight_estimate = s.active_requests as f64 * avg_tokens_per_req;
                        total_tokens as f64 + inflight_estimate
                    } else {
                        // Leased but nothing completed yet: flat per-request
                        // token guess for the in-flight work.
                        s.active_requests as f64 * 1000.0
                    }
                }
                // Never used (or no stats entry): cheapest possible score.
                _ => 0.0,
            };
            candidates.push((idx, score));
        }
        if all_cold && !candidates.is_empty() {
            // Cold start: simple rotation so the first requests spread across
            // keys instead of piling onto index 0.
            let start = self.next_index % candidates.len();
            let (idx, _) = candidates[start];
            self.next_index = start + 1;
            return self.issue_lease(idx);
        }
        if candidates.is_empty() {
            // Every key is rate-limited. Rather than fail the request, use
            // the key whose cooldown started earliest.
            let mut best_idx = 0;
            let mut oldest_rl = u64::MAX;
            for (idx, (ref env_var, _)) in self.keys.iter().enumerate() {
                if let Some(stats) = self.stats.get(env_var) {
                    if stats.last_rate_limit_at < oldest_rl {
                        oldest_rl = stats.last_rate_limit_at;
                        best_idx = idx;
                    }
                }
            }
            let env_var = &self.keys[best_idx].0;
            warn!(env_var = %env_var, "all keys rate-limited, using oldest-cooldown key");
            return self.issue_lease(best_idx);
        }
        // Lowest estimated load wins; the stable sort makes registration
        // order the tie-breaker.
        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        let (best_idx, _score) = candidates[0];
        self.issue_lease(best_idx)
    }

    /// Marks the key at `idx` as leased: bumps its in-flight and lifetime
    /// request counters and hands back a clone of the key material.
    fn issue_lease(&mut self, idx: usize) -> Option<KeyLease> {
        let (ref env_var, ref key) = self.keys[idx];
        let env_var = env_var.clone();
        let api_key = key.clone();
        if let Some(stats) = self.stats.get_mut(&env_var) {
            stats.active_requests += 1;
            stats.total_requests += 1;
        }
        Some(KeyLease {
            api_key,
            env_var,
            index: idx,
        })
    }

    /// Records a completed request: accumulates latency and token usage,
    /// stamps `last_success_at`, and releases the in-flight slot taken at
    /// lease time. Unknown `env_var` is a no-op.
    fn report_success(
        &mut self,
        env_var: &str,
        latency_ms: u64,
        input_tokens: u64,
        output_tokens: u64,
    ) {
        if let Some(stats) = self.stats.get_mut(env_var) {
            stats.successes += 1;
            stats.active_requests = stats.active_requests.saturating_sub(1);
            stats.total_latency_ms += latency_ms;
            stats.total_input_tokens += input_tokens;
            stats.total_output_tokens += output_tokens;
            stats.last_success_at = now_unix();
        }
    }

    /// Records a failed request and releases the in-flight slot. A rate
    /// limit starts the cooldown clock instead of counting as a failure.
    fn report_failure(&mut self, env_var: &str, is_rate_limit: bool) {
        if let Some(stats) = self.stats.get_mut(env_var) {
            stats.active_requests = stats.active_requests.saturating_sub(1);
            if is_rate_limit {
                stats.rate_limits += 1;
                stats.last_rate_limit_at = now_unix();
            } else {
                stats.failures += 1;
            }
        }
    }
}
/// Async, endpoint-scoped pool of API keys with token-aware load balancing
/// and rate-limit-aware rotation.
pub struct KeyPool {
    /// Endpoint identifier -> its keys and stats, behind a single RwLock.
    endpoints: RwLock<HashMap<String, EndpointKeys>>,
}
impl KeyPool {
    /// Creates an empty pool with no registered endpoints.
    pub fn new() -> Self {
        Self {
            endpoints: RwLock::new(HashMap::new()),
        }
    }

    /// Registers (or extends) an endpoint's key set from env/keychain vars.
    /// Idempotent: vars — and the `VAR[i]` sub-vars a CSV value expands to —
    /// that are already present are skipped, and existing stats are kept.
    pub async fn register_endpoint(&self, endpoint: &str, env_vars: Vec<String>) {
        let mut endpoints = self.endpoints.write().await;
        let entry = endpoints
            .entry(endpoint.to_string())
            .or_insert_with(|| EndpointKeys::new(vec![]));
        let existing_vars: std::collections::HashSet<String> =
            entry.keys.iter().map(|(v, _)| v.clone()).collect();
        let mut new_keys: Vec<(String, String)> = Vec::new();
        for env_var in env_vars {
            if existing_vars.contains(&env_var) {
                continue;
            }
            if let Some(raw) = car_secrets::resolve_env_or_keychain(&env_var) {
                for (sub_var, key) in expand_multi_key(&env_var, raw) {
                    // A CSV var expands to synthetic names; re-check those too.
                    if !existing_vars.contains(&sub_var) {
                        new_keys.push((sub_var, key));
                    }
                }
            }
        }
        for (var, key) in new_keys {
            // or_insert_with preserves counters restored by `load_stats`.
            entry
                .stats
                .entry(var.clone())
                .or_insert_with(|| KeyStats::new(var.clone()));
            entry.keys.push((var, key));
        }
        debug!(
            endpoint = %endpoint,
            key_count = entry.keys.len(),
            "registered endpoint keys"
        );
    }

    /// Leases a key for `endpoint`, or `None` if the endpoint is unknown or
    /// has no keys. Takes the write lock because leasing mutates stats.
    pub async fn lease(&self, endpoint: &str) -> Option<KeyLease> {
        let mut endpoints = self.endpoints.write().await;
        endpoints.get_mut(endpoint)?.lease()
    }

    /// Like `lease`, but on a miss lazily registers `fallback_env` for the
    /// endpoint and retries once.
    // NOTE(review): the fallback path resolves `fallback_env` twice (once in
    // the guard, once inside register_endpoint); if keychain resolution is
    // slow or prompts the user, consider restructuring — confirm against
    // car_secrets' behavior.
    pub async fn lease_or_env(&self, endpoint: &str, fallback_env: &str) -> Option<KeyLease> {
        if let Some(lease) = self.lease(endpoint).await {
            return Some(lease);
        }
        if car_secrets::resolve_env_or_keychain(fallback_env).is_some() {
            self.register_endpoint(endpoint, vec![fallback_env.to_string()])
                .await;
            return self.lease(endpoint).await;
        }
        None
    }

    /// Forwards a success report to the endpoint's stats; unknown endpoint
    /// or key is a no-op.
    pub async fn report_success(
        &self,
        endpoint: &str,
        env_var: &str,
        latency_ms: u64,
        input_tokens: u64,
        output_tokens: u64,
    ) {
        let mut endpoints = self.endpoints.write().await;
        if let Some(ep) = endpoints.get_mut(endpoint) {
            ep.report_success(env_var, latency_ms, input_tokens, output_tokens);
        }
    }

    /// Forwards a failure report (optionally a rate limit) to the endpoint's
    /// stats; unknown endpoint or key is a no-op.
    pub async fn report_failure(&self, endpoint: &str, env_var: &str, is_rate_limit: bool) {
        let mut endpoints = self.endpoints.write().await;
        if let Some(ep) = endpoints.get_mut(endpoint) {
            ep.report_failure(env_var, is_rate_limit);
        }
    }

    /// Snapshot of per-key stats for one endpoint (empty if unknown).
    /// Iteration order follows the stats HashMap and is unspecified.
    pub async fn endpoint_stats(&self, endpoint: &str) -> Vec<KeyStats> {
        let endpoints = self.endpoints.read().await;
        endpoints
            .get(endpoint)
            .map(|ep| ep.stats.values().cloned().collect())
            .unwrap_or_default()
    }

    /// Snapshot of per-key stats for every endpoint.
    pub async fn all_stats(&self) -> HashMap<String, Vec<KeyStats>> {
        let endpoints = self.endpoints.read().await;
        endpoints
            .iter()
            .map(|(ep, keys)| (ep.clone(), keys.stats.values().cloned().collect()))
            .collect()
    }

    /// Total registered keys across all endpoints.
    pub async fn total_keys(&self) -> usize {
        let endpoints = self.endpoints.read().await;
        endpoints.values().map(|ep| ep.keys.len()).sum()
    }

    /// Number of keys for `endpoint` that are not currently in a rate-limit
    /// cooldown (0 if the endpoint is unknown).
    pub async fn available_keys(&self, endpoint: &str) -> usize {
        let endpoints = self.endpoints.read().await;
        endpoints
            .get(endpoint)
            .map(|ep| {
                ep.keys
                    .iter()
                    .filter(|(env_var, _)| {
                        // A key without a stats entry counts as available.
                        ep.stats
                            .get(env_var)
                            .map(|s| !s.is_rate_limited())
                            .unwrap_or(true)
                    })
                    .count()
            })
            .unwrap_or(0)
    }

    /// Serializes all stats (not the key material) to pretty JSON at `path`,
    /// creating parent directories as needed.
    // NOTE(review): blocking std::fs inside an async fn — acceptable for an
    // occasional stats dump, but consider spawn_blocking/tokio::fs if this
    // lands on a hot path; depends on which tokio features are enabled.
    pub async fn save_stats(&self, path: &std::path::Path) -> Result<(), std::io::Error> {
        let endpoints = self.endpoints.read().await;
        let stats: HashMap<String, Vec<KeyStats>> = endpoints
            .iter()
            .map(|(ep, keys)| (ep.clone(), keys.stats.values().cloned().collect()))
            .collect();
        let json = serde_json::to_string_pretty(&stats)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        std::fs::write(path, json)
    }

    /// Restores stats saved by `save_stats`, returning how many key-stat
    /// entries were loaded. A missing file is not an error (returns 0).
    /// Stats are restored even for endpoints/keys not yet registered;
    /// `register_endpoint` attaches key material later without resetting them.
    // NOTE(review): blocking std::fs inside an async fn — see save_stats.
    pub async fn load_stats(&self, path: &std::path::Path) -> Result<usize, std::io::Error> {
        if !path.exists() {
            return Ok(0);
        }
        let json = std::fs::read_to_string(path)?;
        let saved: HashMap<String, Vec<KeyStats>> = serde_json::from_str(&json)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
        let mut endpoints = self.endpoints.write().await;
        let mut count = 0;
        for (endpoint, stats_list) in saved {
            let ep = endpoints
                .entry(endpoint)
                .or_insert_with(|| EndpointKeys::new(vec![]));
            for stats in stats_list {
                ep.stats.insert(stats.env_var.clone(), stats);
                count += 1;
            }
        }
        Ok(count)
    }
}
impl Default for KeyPool {
    /// Same as [`KeyPool::new`]: an empty pool with no endpoints.
    fn default() -> Self {
        Self::new()
    }
}
/// Current wall-clock time as whole seconds since the Unix epoch.
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_unix() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): these tests assume car_secrets::resolve_env_or_keychain
    // reads the process environment for the given name — confirm against that
    // crate. Each test uses unique var names so concurrently-running tests
    // cannot interfere through the shared process environment.

    /// Lease a single key, report success, and verify stats accumulate.
    #[tokio::test]
    async fn single_key_round_trip() {
        std::env::set_var("TEST_KEY_POOL_1", "sk-test-111");
        let pool = KeyPool::new();
        pool.register_endpoint("https://api.test.com", vec!["TEST_KEY_POOL_1".into()])
            .await;
        let lease = pool.lease("https://api.test.com").await.unwrap();
        assert_eq!(lease.api_key, "sk-test-111");
        assert_eq!(lease.env_var, "TEST_KEY_POOL_1");
        pool.report_success("https://api.test.com", &lease.env_var, 500, 100, 50)
            .await;
        let stats = pool.endpoint_stats("https://api.test.com").await;
        assert_eq!(stats.len(), 1);
        assert_eq!(stats[0].successes, 1);
        assert_eq!(stats[0].total_latency_ms, 500);
        std::env::remove_var("TEST_KEY_POOL_1");
    }

    /// With two unused keys, consecutive leases should rotate between them.
    #[tokio::test]
    async fn multi_key_cold_start_round_robin() {
        std::env::set_var("TEST_KEY_POOL_A", "sk-aaa");
        std::env::set_var("TEST_KEY_POOL_B", "sk-bbb");
        let pool = KeyPool::new();
        pool.register_endpoint(
            "https://api.test.com",
            vec!["TEST_KEY_POOL_A".into(), "TEST_KEY_POOL_B".into()],
        )
        .await;
        let l1 = pool.lease("https://api.test.com").await.unwrap();
        pool.report_success("https://api.test.com", &l1.env_var, 100, 10, 5)
            .await;
        let l2 = pool.lease("https://api.test.com").await.unwrap();
        pool.report_success("https://api.test.com", &l2.env_var, 100, 10, 5)
            .await;
        assert_ne!(l1.env_var, l2.env_var);
        std::env::remove_var("TEST_KEY_POOL_A");
        std::env::remove_var("TEST_KEY_POOL_B");
    }

    /// Once warm, the scheduler should prefer the key with the lowest
    /// accumulated token count, and rebalance as usage shifts.
    #[tokio::test]
    async fn token_aware_prefers_least_used() {
        std::env::set_var("TEST_KEY_POOL_TA1", "sk-ta1");
        std::env::set_var("TEST_KEY_POOL_TA2", "sk-ta2");
        let pool = KeyPool::new();
        pool.register_endpoint(
            "https://api.test.com",
            vec!["TEST_KEY_POOL_TA1".into(), "TEST_KEY_POOL_TA2".into()],
        )
        .await;
        let l1 = pool.lease("https://api.test.com").await.unwrap();
        pool.report_success("https://api.test.com", &l1.env_var, 100, 1000, 500)
            .await;
        let l2 = pool.lease("https://api.test.com").await.unwrap();
        pool.report_success("https://api.test.com", &l2.env_var, 100, 100, 50)
            .await;
        let l3 = pool.lease("https://api.test.com").await.unwrap();
        assert_eq!(
            l3.env_var, l2.env_var,
            "should pick the key with fewer tokens"
        );
        pool.report_success("https://api.test.com", &l3.env_var, 100, 5000, 5000)
            .await;
        let l4 = pool.lease("https://api.test.com").await.unwrap();
        assert_eq!(
            l4.env_var, l1.env_var,
            "should pick key with fewer tokens after rebalance"
        );
        std::env::remove_var("TEST_KEY_POOL_TA1");
        std::env::remove_var("TEST_KEY_POOL_TA2");
    }

    /// A comma-separated value expands into one key per entry, leased in
    /// round-robin order on a cold start.
    #[tokio::test]
    async fn comma_separated_keys() {
        std::env::set_var("TEST_KEY_POOL_CSV", "sk-one, sk-two, sk-three");
        let pool = KeyPool::new();
        pool.register_endpoint("https://api.test.com", vec!["TEST_KEY_POOL_CSV".into()])
            .await;
        assert_eq!(pool.total_keys().await, 3);
        let l1 = pool.lease("https://api.test.com").await.unwrap();
        assert_eq!(l1.api_key, "sk-one");
        let l2 = pool.lease("https://api.test.com").await.unwrap();
        assert_eq!(l2.api_key, "sk-two");
        let l3 = pool.lease("https://api.test.com").await.unwrap();
        assert_eq!(l3.api_key, "sk-three");
        std::env::remove_var("TEST_KEY_POOL_CSV");
    }

    /// A key that just reported a rate limit must not be leased again while
    /// its cooldown is active.
    #[tokio::test]
    async fn rate_limited_key_skipped() {
        std::env::set_var("TEST_KEY_POOL_RL1", "sk-rl1");
        std::env::set_var("TEST_KEY_POOL_RL2", "sk-rl2");
        let pool = KeyPool::new();
        pool.register_endpoint(
            "https://api.test.com",
            vec!["TEST_KEY_POOL_RL1".into(), "TEST_KEY_POOL_RL2".into()],
        )
        .await;
        let l1 = pool.lease("https://api.test.com").await.unwrap();
        pool.report_failure("https://api.test.com", &l1.env_var, true)
            .await;
        let l2 = pool.lease("https://api.test.com").await.unwrap();
        assert_ne!(l1.env_var, l2.env_var);
        std::env::remove_var("TEST_KEY_POOL_RL1");
        std::env::remove_var("TEST_KEY_POOL_RL2");
    }

    /// An unknown endpoint falls back to registering the given env var.
    #[tokio::test]
    async fn lease_or_env_fallback() {
        std::env::set_var("TEST_KEY_POOL_FB", "sk-fallback");
        let pool = KeyPool::new();
        let lease = pool
            .lease_or_env("https://api.new.com", "TEST_KEY_POOL_FB")
            .await
            .unwrap();
        assert_eq!(lease.api_key, "sk-fallback");
        assert_eq!(pool.total_keys().await, 1);
        std::env::remove_var("TEST_KEY_POOL_FB");
    }

    /// Leasing from an endpoint that was never registered yields None.
    #[tokio::test]
    async fn no_keys_returns_none() {
        let pool = KeyPool::new();
        assert!(pool.lease("https://nonexistent.com").await.is_none());
    }
}