use std::time::Duration;
/// Knobs controlling how many times, and how patiently, a failed request is retried.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Budget for `RateLimited` retries; also the cap on the total failure
    /// count before a `StreamInterrupted` error aborts.
    pub max_retries: u32,
    /// Un-jittered delay used for the first backoff-based retry.
    pub initial_backoff: Duration,
    /// Ceiling applied to the exponential backoff before jitter is added.
    pub max_backoff: Duration,
    /// Growth factor applied to the backoff on each successive attempt.
    pub multiplier: f64,
    /// `Overloaded` retries allowed per model before falling back or aborting.
    pub max_overload_retries: u32,
}
impl Default for RetryConfig {
    /// Three retries of each kind, backing off from one second, doubling, up to one minute.
    fn default() -> Self {
        let (max_retries, max_overload_retries) = (3, 3);
        RetryConfig {
            max_retries,
            max_overload_retries,
            initial_backoff: Duration::from_secs(1),
            max_backoff: Duration::from_secs(60),
            multiplier: 2.0,
        }
    }
}
/// Mutable bookkeeping for a retry loop.
///
/// Feed every failure to [`RetryState::next_action`]; call
/// [`RetryState::reset`] to restore the full per-request budgets.
#[derive(Debug, Default)]
pub struct RetryState {
    /// Failures of any kind since the last `reset`; `StreamInterrupted`
    /// retries are budgeted against this total.
    pub consecutive_failures: u32,
    /// `RateLimited` failures since the last `reset`.
    pub rate_limit_retries: u32,
    /// `Overloaded` failures on the current model since the last `reset`;
    /// zeroed when switching to the fallback model.
    pub overload_retries: u32,
    /// Set once `next_action` has returned `RetryAction::FallbackModel`.
    pub using_fallback: bool,
}
impl RetryState {
    /// Records `error` and returns what the caller should do next.
    ///
    /// * `RateLimited` — retry after `retry_after` (interpreted as milliseconds),
    ///   at most `config.max_retries` times, then abort.
    /// * `Overloaded` — retry with exponential backoff at most
    ///   `config.max_overload_retries` times; the next overload yields
    ///   `FallbackModel` (only once, restarting the overload count), and
    ///   exhausting the budget again on the fallback aborts.
    /// * `StreamInterrupted` — retry with exponential backoff while the total
    ///   failure count (every error kind counts) is at most `config.max_retries`.
    /// * `NonRetryable` — abort with the error's message.
    pub fn next_action(&mut self, error: &RetryableError, config: &RetryConfig) -> RetryAction {
        // Saturate instead of `+= 1` so a caller that keeps reporting errors
        // after an abort cannot trigger an overflow panic in debug builds.
        self.consecutive_failures = self.consecutive_failures.saturating_add(1);
        match error {
            RetryableError::RateLimited { retry_after } => {
                self.rate_limit_retries = self.rate_limit_retries.saturating_add(1);
                if self.rate_limit_retries > config.max_retries {
                    return RetryAction::Abort("Rate limit retries exhausted".into());
                }
                RetryAction::Retry {
                    after: Duration::from_millis(*retry_after),
                }
            }
            RetryableError::Overloaded => {
                self.overload_retries = self.overload_retries.saturating_add(1);
                if self.overload_retries > config.max_overload_retries {
                    if !self.using_fallback {
                        // First exhaustion: switch models and give the
                        // fallback a fresh overload budget.
                        self.using_fallback = true;
                        self.overload_retries = 0;
                        return RetryAction::FallbackModel;
                    }
                    return RetryAction::Abort("Overload retries exhausted on fallback".into());
                }
                let backoff = calculate_backoff(
                    self.overload_retries,
                    config.initial_backoff,
                    config.max_backoff,
                    config.multiplier,
                );
                RetryAction::Retry { after: backoff }
            }
            RetryableError::StreamInterrupted => {
                if self.consecutive_failures > config.max_retries {
                    return RetryAction::Abort("Stream retry limit reached".into());
                }
                let backoff = calculate_backoff(
                    self.consecutive_failures,
                    config.initial_backoff,
                    config.max_backoff,
                    config.multiplier,
                );
                RetryAction::Retry { after: backoff }
            }
            RetryableError::NonRetryable(msg) => RetryAction::Abort(msg.clone()),
        }
    }

    /// Clears every failure counter so the next failure starts from a full budget.
    ///
    /// `overload_retries` is cleared along with the other counters; previously
    /// it was left behind, so overloads from earlier requests accumulated and
    /// triggered the fallback prematurely. `using_fallback` is left unchanged:
    /// once switched, the state stays on the fallback model.
    pub fn reset(&mut self) {
        self.consecutive_failures = 0;
        self.rate_limit_retries = 0;
        self.overload_retries = 0;
    }
}
/// A failure reported to `RetryState::next_action`, classified by how it is retried.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryableError {
    /// Rate-limit error; `retry_after` is used verbatim as the retry delay in milliseconds.
    RateLimited { retry_after: u64 },
    /// Overload error; retried with backoff, then via the fallback model.
    Overloaded,
    /// Interrupted stream; retried with backoff against the total failure budget.
    StreamInterrupted,
    /// Failure that must not be retried; the message becomes the abort reason.
    NonRetryable(String),
}
/// The decision returned by `RetryState::next_action`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RetryAction {
    /// Retry after waiting `after`.
    Retry { after: Duration },
    /// Switch to the fallback model and retry.
    FallbackModel,
    /// Stop retrying; carries a human-readable reason.
    Abort(String),
}
/// Computes an exponential backoff delay with additive jitter.
///
/// The un-jittered delay for 1-based `attempt` is
/// `initial * multiplier^(attempt - 1)`, capped at `max`. Up to 10% of that
/// capped value is then added as jitter (assuming `rand_f64` returns a value in
/// `[0, 1)`), so the result can exceed `max` by up to 10%.
///
/// `attempt == 0` is treated like `attempt == 1`. The exponent is clamped to
/// `i32::MAX` instead of being cast with `as`. A plain cast wraps huge attempt
/// counts to a negative exponent, which collapses the delay toward zero, or
/// overflow-panics in debug builds.
fn calculate_backoff(attempt: u32, initial: Duration, max: Duration, multiplier: f64) -> Duration {
    let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
    let base = initial.as_millis() as f64 * multiplier.powi(exponent);
    let capped = base.min(max.as_millis() as f64);
    let jitter = capped * 0.1 * rand_f64();
    // `as u64` on f64 saturates (and maps NaN to 0), so this cannot panic.
    Duration::from_millis((capped + jitter) as u64)
}
/// Returns a pseudo-random `f64` in `[0.0, 1.0)`, used only for backoff jitter.
///
/// Not cryptographically secure. The previous `subsec_nanos() % 1000` relied
/// on the clock's lowest decimal digits. On platforms whose `SystemTime` ticks
/// in 100ns or 1µs steps, those digits take only a handful of values (or are
/// always 0), which silently disables jitter. Instead, the nanoseconds are
/// mixed through std's `RandomState`, which is initialized with random keys,
/// so every output bit is usable.
fn rand_f64() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    let state = std::collections::hash_map::RandomState::new();
    let mut hasher = std::hash::BuildHasher::build_hasher(&state);
    std::hash::Hasher::write_u32(&mut hasher, nanos);
    let bits = std::hash::Hasher::finish(&hasher);
    // Keep the top 53 bits: exactly representable in an f64 and strictly < 1.0.
    (bits >> 11) as f64 / (1u64 << 53) as f64
}