use std::collections::HashMap;
use std::time::Instant;
use tracing::{info, warn};
/// Returns how many seconds an API key should be benched after a request
/// with that key failed with HTTP `status`.
///
/// * 429 and 502 → 30 s
/// * 401 and 402 → 300 s
/// * 403 → 600 s
/// * anything else (including 500 and 503) → 60 s
fn cooldown_seconds(status: u16) -> f64 {
    match status {
        429 | 502 => 30.0,
        401 | 402 => 300.0,
        403 => 600.0,
        _ => 60.0,
    }
}
/// A single API key together with its failure/cooldown bookkeeping.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// profile (e.g. `{:?}` in a log line) never prints the full secret key —
/// only a short prefix of it.
struct AuthProfile {
    /// The secret API key.
    key: String,
    /// Name of the provider this key was configured for; shown in `Debug`.
    provider: String,
    /// When the most recent failure was recorded; `None` after a success.
    failed_at: Option<Instant>,
    /// Status code of the most recent failure; `0` when no failure is pending.
    failure_status: u16,
    /// Instant until which the key should not be used; `None` when not cooling down.
    cooldown_until: Option<Instant>,
    /// Number of successes recorded for this key.
    request_count: u64,
    /// Number of failures recorded for this key.
    failure_count: u64,
}

impl std::fmt::Debug for AuthProfile {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // At most the first 8 characters of the key, cut on a UTF-8 char
        // boundary (byte slicing could panic on multi-byte input).
        let key_prefix = self
            .key
            .char_indices()
            .nth(8)
            .map_or(self.key.as_str(), |(end, _)| &self.key[..end]);
        f.debug_struct("AuthProfile")
            .field("key_prefix", &key_prefix)
            .field("provider", &self.provider)
            .field("failed_at", &self.failed_at)
            .field("failure_status", &self.failure_status)
            .field("cooldown_until", &self.cooldown_until)
            .field("request_count", &self.request_count)
            .field("failure_count", &self.failure_count)
            .finish()
    }
}
impl AuthProfile {
    /// Creates a profile for `key` with no recorded successes, failures or cooldown.
    fn new(key: String, provider: String) -> Self {
        Self {
            key,
            provider,
            failed_at: None,
            failure_status: 0,
            cooldown_until: None,
            request_count: 0,
            failure_count: 0,
        }
    }

    /// Returns a log-safe prefix of the key: at most its first 8 characters.
    ///
    /// The cut is made on a UTF-8 character boundary. Slicing by bytes
    /// (`&key[..8]`) would panic whenever byte 8 falls inside a multi-byte
    /// character, which can happen with a mis-pasted key.
    fn key_prefix(&self) -> &str {
        match self.key.char_indices().nth(8) {
            Some((end, _)) => &self.key[..end],
            None => &self.key,
        }
    }

    /// Returns `true` when the key has no cooldown set or its cooldown deadline
    /// has already passed.
    fn is_available(&self) -> bool {
        match self.cooldown_until {
            None => true,
            Some(until) => Instant::now() >= until,
        }
    }

    /// Records a successful request: bumps `request_count` and clears any
    /// pending failure state and cooldown.
    fn mark_success(&mut self) {
        self.request_count += 1;
        self.failed_at = None;
        self.failure_status = 0;
        self.cooldown_until = None;
    }

    /// Records a failed request with HTTP `status_code`.
    ///
    /// This bumps `failure_count`, stores the failure time and status, and
    /// starts a cooldown whose length comes from `cooldown_seconds`. A warning
    /// is logged with only a short key prefix.
    fn mark_failure(&mut self, status_code: u16) {
        self.failure_count += 1;
        let now = Instant::now();
        let cooldown = cooldown_seconds(status_code);
        self.failed_at = Some(now);
        self.failure_status = status_code;
        self.cooldown_until = Some(now + std::time::Duration::from_secs_f64(cooldown));
        warn!(
            key_prefix = self.key_prefix(),
            status_code,
            cooldown_secs = cooldown,
            "Auth profile failed, cooling down"
        );
    }

    /// Seconds left until the cooldown ends; `0.0` when there is no cooldown
    /// or it has already elapsed.
    fn cooldown_remaining(&self) -> f64 {
        self.cooldown_until.map_or(0.0, |until| {
            // Saturates to zero once `until` is in the past.
            until.saturating_duration_since(Instant::now()).as_secs_f64()
        })
    }
}
/// Rotates across several API keys for one provider, skipping keys that are
/// cooling down after a failure.
pub struct AuthProfileManager {
    /// Index into `profiles` of the key currently in use.
    current_index: usize,
    /// Provider name, used in log messages.
    provider: String,
    /// One entry per configured (non-empty) API key, in configuration order.
    profiles: Vec<AuthProfile>,
}
impl AuthProfileManager {
    /// Builds a manager for `provider` over `keys`.
    ///
    /// Empty strings in `keys` are discarded. If nothing usable remains, a
    /// warning is logged and the manager is created empty, in which case
    /// `get_active_key` always returns `None`. Rotation starts at the first key.
    pub fn new(provider: impl Into<String>, keys: Vec<String>) -> Self {
        let provider = provider.into();
        let profiles: Vec<_> = keys
            .into_iter()
            .filter(|k| !k.is_empty())
            .map(|k| AuthProfile::new(k, provider.clone()))
            .collect();
        if profiles.is_empty() {
            warn!("No API keys configured for provider '{}'", provider);
        }
        Self {
            provider,
            profiles,
            current_index: 0,
        }
    }

    /// Builds a manager from environment variables named after `provider`.
    ///
    /// The prefix is `provider` upper-cased with `-` replaced by `_`.
    ///
    /// - `{PREFIX}_API_KEY` is read first and used if it is set and non-empty.
    /// - `{PREFIX}_API_KEY_2` through `{PREFIX}_API_KEY_9` are then read in order.
    ///   This numbered scan stops at the first variable that is unset, empty,
    ///   or not valid Unicode.
    pub fn from_env(provider: &str) -> Self {
        let prefix = provider.to_uppercase().replace('-', "_");
        let mut keys = Vec::new();
        if let Ok(val) = std::env::var(format!("{prefix}_API_KEY"))
            && !val.is_empty()
        {
            keys.push(val);
        }
        for i in 2..10 {
            match std::env::var(format!("{prefix}_API_KEY_{i}")) {
                Ok(val) if !val.is_empty() => keys.push(val),
                _ => break,
            }
        }
        Self::new(provider, keys)
    }

    /// Builds a manager from a JSON-like config map.
    ///
    /// - If `"api_keys"` is a JSON array, its string elements are used.
    ///   Non-string elements are skipped, and `"api_key"` is ignored in this case.
    /// - Otherwise, if `"api_key"` is a JSON string, it is the single key.
    /// - Otherwise the manager is created with no keys.
    pub fn from_config(provider: &str, config: &HashMap<String, serde_json::Value>) -> Self {
        let keys = if let Some(serde_json::Value::Array(arr)) = config.get("api_keys") {
            arr.iter()
                .filter_map(|v| v.as_str().map(String::from))
                .collect()
        } else if let Some(serde_json::Value::String(single)) = config.get("api_key") {
            vec![single.clone()]
        } else {
            vec![]
        };
        Self::new(provider, keys)
    }

    /// Returns the key to use for the next request.
    ///
    /// The current key is kept while it is available. Otherwise the remaining
    /// keys are probed round-robin, starting just after the current one. The
    /// first available key becomes current and an info message is logged.
    ///
    /// Returns `None` in two cases:
    /// - no keys are configured;
    /// - every key is in cooldown. A warning is then logged with the shortest
    ///   remaining cooldown, and the current position is left unchanged.
    pub fn get_active_key(&mut self) -> Option<&str> {
        if self.profiles.is_empty() {
            return None;
        }
        if self.profiles[self.current_index].is_available() {
            return Some(&self.profiles[self.current_index].key);
        }
        let len = self.profiles.len();
        let next = (1..len)
            .map(|offset| (self.current_index + offset) % len)
            .find(|&idx| self.profiles[idx].is_available());
        if let Some(idx) = next {
            self.current_index = idx;
            let profile = &self.profiles[idx];
            info!(
                key_prefix = Self::key_prefix(&profile.key),
                provider = %self.provider,
                "Rotated to next API key"
            );
            return Some(&profile.key);
        }
        let soonest = self
            .profiles
            .iter()
            .map(|p| p.cooldown_remaining())
            .fold(f64::MAX, f64::min);
        warn!(
            total = self.profiles.len(),
            provider = %self.provider,
            soonest_available_secs = soonest,
            "All API keys are in cooldown"
        );
        None
    }

    /// Records a success for the key at the current rotation position.
    ///
    /// Does nothing if no keys are configured.
    pub fn mark_success(&mut self) {
        if !self.profiles.is_empty() {
            self.profiles[self.current_index].mark_success();
        }
    }

    /// Records a failure with HTTP `status_code` for the key at the current
    /// rotation position, putting it into cooldown.
    ///
    /// Does nothing if no keys are configured.
    pub fn mark_failure(&mut self, status_code: u16) {
        if !self.profiles.is_empty() {
            self.profiles[self.current_index].mark_failure(status_code);
        }
    }

    /// Number of configured (non-empty) keys.
    pub fn profile_count(&self) -> usize {
        self.profiles.len()
    }

    /// Number of keys not currently in cooldown.
    pub fn available_count(&self) -> usize {
        self.profiles.iter().filter(|p| p.is_available()).count()
    }

    /// Provider name this manager was created for.
    pub fn provider(&self) -> &str {
        &self.provider
    }

    /// Returns a log-safe prefix of `key`: at most its first 8 characters.
    ///
    /// The cut is made on a UTF-8 char boundary. Byte slicing (`&key[..8]`)
    /// would panic when byte 8 lands inside a multi-byte character.
    fn key_prefix(key: &str) -> &str {
        key.char_indices().nth(8).map_or(key, |(end, _)| &key[..end])
    }
}
impl std::fmt::Debug for AuthProfileManager {
    /// Prints the provider, the number of configured keys and the rotation
    /// cursor. The keys themselves are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let profile_count = self.profiles.len();
        let mut builder = f.debug_struct("AuthProfileManager");
        builder.field("provider", &self.provider);
        builder.field("profile_count", &profile_count);
        builder.field("current_index", &self.current_index);
        builder.finish()
    }
}
// Unit tests for this module are loaded from `rotation_tests.rs` (via
// `#[path]`) and compiled only in test builds.
#[cfg(test)]
#[path = "rotation_tests.rs"]
mod tests;