use std::collections::HashMap;
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{Duration, Instant};
use crate::error::ProviderError;
/// Classification of a provider failure, used to decide how long the provider
/// should be kept in cooldown.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FailoverReason {
    /// Mapped from `ProviderError::RateLimit`.
    RateLimit,
    /// Mapped from `ProviderError::Overloaded`.
    Overloaded,
    /// Mapped from `ProviderError::Timeout`.
    Timeout,
    /// Mapped from `ProviderError::Auth`.
    Auth,
    /// Mapped from `ProviderError::Billing`; has the longest backoff (hours).
    Billing,
    /// Mapped from `ProviderError::Format` and `ProviderError::InvalidRequest`.
    Format,
    /// Any other `ProviderError` variant.
    Unknown,
}

impl FailoverReason {
    /// Maps a `ProviderError` to the failover reason that governs its cooldown.
    ///
    /// `Format` and `InvalidRequest` both map to [`FailoverReason::Format`];
    /// every variant not listed explicitly maps to [`FailoverReason::Unknown`].
    pub fn from_provider_error(err: &ProviderError) -> Self {
        match err {
            ProviderError::RateLimit(_) => Self::RateLimit,
            ProviderError::Overloaded(_) => Self::Overloaded,
            ProviderError::Timeout(_) => Self::Timeout,
            ProviderError::Auth(_) => Self::Auth,
            ProviderError::Billing(_) => Self::Billing,
            ProviderError::Format(_) | ProviderError::InvalidRequest(_) => Self::Format,
            _ => Self::Unknown,
        }
    }

    /// Returns how long a provider should stay in cooldown after its
    /// `consecutive`-th consecutive failure for this reason.
    ///
    /// The backoff doubles with each consecutive failure, up to a cap:
    ///
    /// | reason       | first failure | cap    |
    /// |--------------|---------------|--------|
    /// | `Billing`    | 1 h           | 24 h   |
    /// | `RateLimit`  | 1 min         | 30 min |
    /// | `Overloaded` | 30 s          | 300 s  |
    /// | `Timeout`    | 15 s          | 120 s  |
    ///
    /// `Auth` and `Format` use a fixed 300 s and `Unknown` a fixed 60 s.
    /// A `consecutive` of 0 is treated like 1. Any `consecutive` value,
    /// including `u32::MAX`, is accepted without overflow.
    pub fn cooldown_duration(&self, consecutive: u32) -> Duration {
        match self {
            Self::Billing => Duration::from_secs(Self::backoff(1, consecutive, 24) * 3600),
            Self::RateLimit => Duration::from_secs(Self::backoff(1, consecutive, 30) * 60),
            Self::Overloaded => Duration::from_secs(Self::backoff(30, consecutive, 300)),
            Self::Timeout => Duration::from_secs(Self::backoff(15, consecutive, 120)),
            Self::Auth | Self::Format => Duration::from_secs(300),
            Self::Unknown => Duration::from_secs(60),
        }
    }

    /// Computes `base * 2^(consecutive - 1)` clamped to `cap`, in whatever
    /// unit `base` and `cap` share.
    ///
    /// Every step saturates, so a long failure streak pins the result at `cap`.
    /// An unchecked `base * factor` overflows for `consecutive` around 60 and
    /// above. That panics in debug builds and wraps to a wrong (even zero)
    /// cooldown in release builds.
    fn backoff(base: u64, consecutive: u32, cap: u64) -> u64 {
        let factor = 2u64.saturating_pow(consecutive.saturating_sub(1));
        base.saturating_mul(factor).min(cap)
    }
}
/// Failure bookkeeping kept for a single provider.
#[derive(Debug)]
struct CooldownEntry {
    /// Length of the current failure streak; zeroed by a staleness reset.
    consecutive: u32,
    /// Deadline of the regular cooldown window, if one has been set.
    cooldown_end: Option<Instant>,
    /// Time of the most recently recorded failure, if any.
    last_failure: Option<Instant>,
    /// Deadline of the billing disablement window, if one has been set.
    billing_disabled_until: Option<Instant>,
}

impl CooldownEntry {
    /// How long after the last failure a streak is considered stale.
    const STALE_AFTER: Duration = Duration::from_secs(86400);

    /// Builds an entry with an empty streak and no deadlines.
    fn new() -> Self {
        Self {
            last_failure: None,
            cooldown_end: None,
            billing_disabled_until: None,
            consecutive: 0,
        }
    }

    /// Reports whether either deadline (billing or regular) is still ahead of now.
    fn is_in_cooldown(&self) -> bool {
        let now = Instant::now();
        [self.billing_disabled_until, self.cooldown_end]
            .iter()
            .flatten()
            .any(|&deadline| now < deadline)
    }

    /// Forgets the streak and both deadlines once the last failure is older
    /// than `STALE_AFTER`. `last_failure` itself is left untouched.
    fn reset_if_stale(&mut self) {
        let stale = self
            .last_failure
            .map_or(false, |last| last.elapsed() > Self::STALE_AFTER);
        if stale {
            self.consecutive = 0;
            self.cooldown_end = None;
            self.billing_disabled_until = None;
        }
    }
}
#[derive(Clone)]
pub struct CooldownTracker {
entries: Arc<RwLock<HashMap<String, CooldownEntry>>>,
}
impl CooldownTracker {
pub fn new() -> Self {
Self {
entries: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn is_in_cooldown(&self, provider: &str) -> bool {
let entries = self.entries.read().unwrap();
entries
.get(provider)
.map(|e| e.is_in_cooldown())
.unwrap_or(false)
}
pub fn mark_failure(&self, provider: &str, reason: FailoverReason) {
let mut entries = self.entries.write().unwrap();
let entry = entries
.entry(provider.to_string())
.or_insert_with(CooldownEntry::new);
entry.reset_if_stale();
entry.consecutive += 1;
entry.last_failure = Some(Instant::now());
let duration = reason.cooldown_duration(entry.consecutive);
if reason == FailoverReason::Billing {
entry.billing_disabled_until = Some(Instant::now() + duration);
} else {
entry.cooldown_end = Some(Instant::now() + duration);
}
}
pub fn mark_success(&self, provider: &str) {
let mut entries = self.entries.write().unwrap();
if let Some(entry) = entries.get_mut(provider) {
entry.consecutive = 0;
entry.cooldown_end = None;
entry.billing_disabled_until = None;
}
}
}
impl Default for CooldownTracker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_not_in_cooldown_initially() {
        let tracker = CooldownTracker::new();
        assert!(!tracker.is_in_cooldown("anthropic"));
    }

    #[test]
    fn test_mark_failure_puts_in_cooldown() {
        let tracker = CooldownTracker::new();
        tracker.mark_failure("anthropic", FailoverReason::RateLimit);
        assert!(tracker.is_in_cooldown("anthropic"));
    }

    #[test]
    fn test_mark_success_clears_cooldown() {
        let tracker = CooldownTracker::new();
        tracker.mark_failure("anthropic", FailoverReason::RateLimit);
        tracker.mark_success("anthropic");
        assert!(!tracker.is_in_cooldown("anthropic"));
    }

    #[test]
    fn test_mark_success_clears_billing_disable() {
        let tracker = CooldownTracker::new();
        tracker.mark_failure("anthropic", FailoverReason::Billing);
        tracker.mark_success("anthropic");
        assert!(!tracker.is_in_cooldown("anthropic"));
    }

    #[test]
    fn test_mark_success_unknown_provider_is_noop() {
        let tracker = CooldownTracker::new();
        tracker.mark_success("nobody");
        assert!(!tracker.is_in_cooldown("nobody"));
        assert!(tracker.entries.read().unwrap().is_empty());
    }

    #[test]
    fn test_billing_longer_than_rate_limit() {
        let billing = FailoverReason::Billing.cooldown_duration(1);
        let rate = FailoverReason::RateLimit.cooldown_duration(1);
        assert!(billing > rate);
    }

    #[test]
    fn test_overloaded_shorter_than_billing() {
        let billing = FailoverReason::Billing.cooldown_duration(1);
        let overloaded = FailoverReason::Overloaded.cooldown_duration(1);
        assert!(overloaded < billing);
    }

    #[test]
    fn test_exponential_backoff_increases() {
        let d1 = FailoverReason::RateLimit.cooldown_duration(1);
        let d2 = FailoverReason::RateLimit.cooldown_duration(2);
        let d3 = FailoverReason::RateLimit.cooldown_duration(3);
        assert!(d2 > d1);
        assert!(d3 > d2);
    }

    #[test]
    fn test_billing_caps_at_24h() {
        let d = FailoverReason::Billing.cooldown_duration(100);
        assert_eq!(d, Duration::from_secs(86400));
    }

    #[test]
    fn test_backoff_caps() {
        assert_eq!(
            FailoverReason::RateLimit.cooldown_duration(10),
            Duration::from_secs(30 * 60)
        );
        assert_eq!(
            FailoverReason::Overloaded.cooldown_duration(10),
            Duration::from_secs(300)
        );
        assert_eq!(
            FailoverReason::Timeout.cooldown_duration(10),
            Duration::from_secs(120)
        );
    }

    #[test]
    fn test_fixed_durations() {
        assert_eq!(FailoverReason::Auth.cooldown_duration(1), Duration::from_secs(300));
        assert_eq!(FailoverReason::Format.cooldown_duration(5), Duration::from_secs(300));
        assert_eq!(FailoverReason::Unknown.cooldown_duration(3), Duration::from_secs(60));
    }

    #[test]
    fn test_from_provider_error() {
        assert_eq!(
            FailoverReason::from_provider_error(&ProviderError::RateLimit("".into())),
            FailoverReason::RateLimit
        );
        assert_eq!(
            FailoverReason::from_provider_error(&ProviderError::Billing("".into())),
            FailoverReason::Billing
        );
        assert_eq!(
            FailoverReason::from_provider_error(&ProviderError::Auth("".into())),
            FailoverReason::Auth
        );
        assert_eq!(
            FailoverReason::from_provider_error(&ProviderError::Overloaded("".into())),
            FailoverReason::Overloaded
        );
        assert_eq!(
            FailoverReason::from_provider_error(&ProviderError::Format("".into())),
            FailoverReason::Format
        );
    }

    #[test]
    fn test_multiple_providers_independent() {
        let tracker = CooldownTracker::new();
        tracker.mark_failure("anthropic", FailoverReason::Billing);
        assert!(tracker.is_in_cooldown("anthropic"));
        assert!(!tracker.is_in_cooldown("openai"));
    }

    #[test]
    fn test_clones_share_state() {
        let tracker = CooldownTracker::new();
        let clone = tracker.clone();
        tracker.mark_failure("p", FailoverReason::Timeout);
        assert!(clone.is_in_cooldown("p"));
    }

    #[test]
    fn test_billing_failure_sets_only_billing_deadline() {
        let tracker = CooldownTracker::new();
        tracker.mark_failure("p", FailoverReason::Billing);
        let entries = tracker.entries.read().unwrap();
        let entry = &entries["p"];
        assert_eq!(entry.consecutive, 1);
        assert!(entry.billing_disabled_until.is_some());
        assert!(entry.cooldown_end.is_none());
    }

    #[test]
    fn test_consecutive_increases_cooldown() {
        let tracker = CooldownTracker::new();
        tracker.mark_failure("p", FailoverReason::RateLimit);
        let entry1_end = {
            let entries = tracker.entries.read().unwrap();
            entries["p"].cooldown_end.unwrap()
        };
        tracker.mark_failure("p", FailoverReason::RateLimit);
        let (entry2_end, consecutive) = {
            let entries = tracker.entries.read().unwrap();
            (entries["p"].cooldown_end.unwrap(), entries["p"].consecutive)
        };
        assert_eq!(consecutive, 2);
        // The second failure starts no earlier than the first and doubles the
        // backoff (1 min -> 2 min), so its deadline must be strictly later.
        assert!(
            entry2_end > entry1_end,
            "second failure should have a strictly longer cooldown"
        );
    }
}