use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
use tokio::time::sleep;
use crate::error::LlmError;
/// Configuration describing how failed operations are retried: how many
/// attempts are made, how the delay between attempts grows, how that delay is
/// randomised, and which error categories qualify for a retry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryStrategy {
/// Total number of attempts `RetryExecutor::execute` makes (it iterates `0..max_attempts`).
pub max_attempts: u32,
/// Starting delay fed into the backoff computation in `calculate_delay`.
pub base_delay: Duration,
/// Cap applied to the backoff delay before jitter is applied.
pub max_delay: Duration,
/// How the delay grows with the attempt index.
pub backoff: BackoffStrategy,
/// How the capped delay is randomised.
pub jitter: JitterConfig,
/// Error categories for which `is_retryable` returns `true`.
pub retryable_errors: Vec<RetryableErrorType>,
}
/// Default policy: three attempts, exponential backoff doubling from one
/// second up to a sixty-second cap, full jitter, and retries limited to
/// network, rate-limit, server and timeout errors.
impl Default for RetryStrategy {
    fn default() -> Self {
        let retryable_errors = Vec::from([
            RetryableErrorType::NetworkError,
            RetryableErrorType::RateLimitError,
            RetryableErrorType::ServerError,
            RetryableErrorType::TimeoutError,
        ]);

        Self {
            max_attempts: 3,
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            backoff: BackoffStrategy::Exponential { multiplier: 2.0 },
            jitter: JitterConfig::Full,
            retryable_errors,
        }
    }
}
impl RetryStrategy {
    /// Creates a strategy populated with the `Default` settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the total number of attempts.
    pub const fn with_max_attempts(mut self, max_attempts: u32) -> Self {
        self.max_attempts = max_attempts;
        self
    }

    /// Sets the delay the backoff computation starts from.
    pub const fn with_base_delay(mut self, delay: Duration) -> Self {
        self.base_delay = delay;
        self
    }

    /// Sets the cap applied to the backoff delay before jitter.
    pub const fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Sets how the delay grows between attempts.
    pub const fn with_backoff(mut self, backoff: BackoffStrategy) -> Self {
        self.backoff = backoff;
        self
    }

    /// Sets how the delay is randomised.
    pub const fn with_jitter(mut self, jitter: JitterConfig) -> Self {
        self.jitter = jitter;
        self
    }

    /// Adds `error_type` to the retryable set unless it is already present.
    pub fn with_retryable_error(mut self, error_type: RetryableErrorType) -> Self {
        if !self.retryable_errors.contains(&error_type) {
            self.retryable_errors.push(error_type);
        }
        self
    }

    /// Computes the delay to wait after the given zero-based `attempt`.
    ///
    /// The backoff delay is derived from `base_delay` according to `backoff`,
    /// capped at `max_delay`, and then randomised according to `jitter`.
    /// The arithmetic saturates instead of panicking or wrapping, so extreme
    /// attempt counts or multipliers simply end up at the cap.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        let backoff_delay = match self.backoff {
            BackoffStrategy::Fixed => self.base_delay,
            BackoffStrategy::Linear { increment } => {
                // Float-to-integer `as` casts saturate (NaN and negatives become 0).
                let extra_ms = (increment * f64::from(attempt)) as u64;
                self.base_delay
                    .saturating_add(Duration::from_millis(extra_ms))
            }
            BackoffStrategy::Exponential { multiplier } => {
                // Clamp rather than wrap: a plain `as i32` would turn very large
                // attempt numbers into negative exponents and shrink the delay.
                let exponent = attempt.min(i32::MAX as u32) as i32;
                let delay_ms = self.base_delay.as_millis() as f64 * multiplier.powi(exponent);
                Duration::from_millis(delay_ms as u64)
            }
        };

        self.apply_jitter(backoff_delay.min(self.max_delay))
    }

    /// Randomises `delay` (already capped at `max_delay`) according to `jitter`.
    fn apply_jitter(&self, delay: Duration) -> Duration {
        // Millisecond counts are handled as u128; clamp instead of truncating
        // when converting back to the u64 that `Duration::from_millis` takes.
        let to_duration = |ms: u128| Duration::from_millis(ms.min(u128::from(u64::MAX)) as u64);

        match self.jitter {
            JitterConfig::None => delay,
            JitterConfig::Full => {
                let scaled_ms = delay.as_millis() as f64 * rand::random::<f64>();
                Duration::from_millis(scaled_ms as u64)
            }
            JitterConfig::Equal => {
                let half = delay.as_millis() / 2;
                to_duration(half + (half as f64 * rand::random::<f64>()) as u128)
            }
            JitterConfig::Decorrelated => {
                let upper = (delay.as_millis() * 3).min(self.max_delay.as_millis());
                // The floor must never exceed the ceiling. Without this clamp the
                // subtraction below underflows whenever `max_delay < base_delay`
                // or the backoff shrinks the delay (e.g. a multiplier below 1).
                let lower = self.base_delay.as_millis().min(upper);
                to_duration(lower + ((upper - lower) as f64 * rand::random::<f64>()) as u128)
            }
        }
    }

    /// Returns `true` when the category of `error` is in `retryable_errors`.
    pub fn is_retryable(&self, error: &LlmError) -> bool {
        let error_type = RetryableErrorType::from_error(error);
        self.retryable_errors.contains(&error_type)
    }
}
/// How the delay between attempts grows with the attempt index passed to
/// `RetryStrategy::calculate_delay`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BackoffStrategy {
/// Every attempt uses `base_delay` unchanged.
Fixed,
/// `base_delay` plus `increment * attempt` milliseconds (fractional part truncated).
Linear { increment: f64 },
/// `base_delay` in whole milliseconds multiplied by `multiplier` raised to the power `attempt`.
Exponential { multiplier: f64 },
}
/// Randomisation applied to the capped backoff delay by `RetryStrategy::apply_jitter`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum JitterConfig {
/// No randomisation; the delay is used as computed.
None,
/// The delay scaled by a random factor drawn from `rand::random::<f64>()`.
Full,
/// Half of the delay plus a random fraction of that half.
Equal,
/// A random value starting at `base_delay`, bounded by three times the delay and by `max_delay`.
Decorrelated,
}
/// Coarse error categories used to decide whether an `LlmError` may be retried.
/// The mapping from `LlmError` is implemented by `RetryableErrorType::from_error`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RetryableErrorType {
/// `HttpError` and `ConnectionError`, plus `ApiError` codes below 400.
NetworkError,
/// `RateLimitError`, or an `ApiError` with code 429.
RateLimitError,
/// An `ApiError` with a code of 500 or above.
ServerError,
/// `TimeoutError`.
TimeoutError,
/// `AuthenticationError`, or an `ApiError` with code 401 or 403.
AuthenticationError,
/// `QuotaExceededError`.
QuotaExceededError,
/// `InvalidParameter`, `InvalidInput`, the remaining 4xx `ApiError` codes, and every other variant.
ClientError,
}
impl RetryableErrorType {
    /// Maps an [`LlmError`] onto the retry category it belongs to.
    ///
    /// `ApiError` values are classified by status code: 500 and above are
    /// server errors, 429 is a rate limit, 401/403 are authentication
    /// failures, any other code from 400 upwards is a client error, and
    /// anything below 400 is treated as a network error. Variants without a
    /// dedicated mapping fall back to `ClientError`.
    pub const fn from_error(error: &LlmError) -> Self {
        match error {
            LlmError::ApiError { code, .. } => match *code {
                500.. => Self::ServerError,
                429 => Self::RateLimitError,
                401 | 403 => Self::AuthenticationError,
                400.. => Self::ClientError,
                _ => Self::NetworkError,
            },
            LlmError::RateLimitError(_) => Self::RateLimitError,
            LlmError::TimeoutError(_) => Self::TimeoutError,
            LlmError::AuthenticationError(_) => Self::AuthenticationError,
            LlmError::QuotaExceededError(_) => Self::QuotaExceededError,
            LlmError::HttpError(_) | LlmError::ConnectionError(_) => Self::NetworkError,
            LlmError::InvalidParameter(_) | LlmError::InvalidInput(_) => Self::ClientError,
            _ => Self::ClientError,
        }
    }
}
/// Sleeps in response to rate-limit errors, backing off further while such
/// errors keep arriving without an intervening success.
#[derive(Debug, Clone)]
pub struct RateLimitHandler {
/// Delay settings: the fallback delay, the cap, and the retry-after preference.
config: RateLimitConfig,
/// Book-keeping about the rate-limit errors handled so far.
state: RateLimitState,
}
impl RateLimitHandler {
    /// Creates a handler with the given configuration and empty state.
    pub fn new(config: RateLimitConfig) -> Self {
        Self {
            config,
            state: RateLimitState::default(),
        }
    }

    /// Waits out a rate limit reported by `error`.
    ///
    /// For `LlmError::RateLimitError` the base delay is the retry-after hint
    /// found in the message (only when `respect_retry_after` is enabled) or
    /// `default_delay` otherwise. That delay is doubled for every consecutive
    /// rate limit (the exponent is capped at 5), limited to `max_delay`, and
    /// then slept. Any other error returns immediately.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`.
    pub async fn handle_rate_limit(&mut self, error: &LlmError) -> Result<(), LlmError> {
        match error {
            LlmError::RateLimitError(message) => {
                // Only trust the server-provided hint when the configuration asks for it.
                let retry_after = if self.config.respect_retry_after {
                    self.extract_retry_after(message)
                } else {
                    None
                };
                let delay = retry_after.unwrap_or(self.config.default_delay);

                self.state.last_rate_limit = Some(Instant::now());
                self.state.consecutive_rate_limits =
                    self.state.consecutive_rate_limits.saturating_add(1);

                // The hint comes from an external message and may be arbitrarily
                // large; saturate instead of panicking on `Duration` overflow.
                let factor = 2_u32.pow(self.state.consecutive_rate_limits.min(5));
                let final_delay = delay.saturating_mul(factor).min(self.config.max_delay);
                sleep(final_delay).await;
                Ok(())
            }
            _ => Ok(()),
        }
    }

    /// Extracts a retry delay in whole seconds from an error message.
    ///
    /// Two forms are recognised, tried in this order:
    /// * `retry after <n>` — `<n>` is the text up to the next space;
    /// * `Retry-After: <n>` — `<n>` is the trimmed text up to the end of the line.
    ///
    /// Returns `None` when neither form yields a parseable `u64`.
    fn extract_retry_after(&self, message: &str) -> Option<Duration> {
        if let Some(seconds_str) = message.split("retry after ").nth(1)
            && let Some(seconds_str) = seconds_str.split(' ').next()
            && let Ok(seconds) = seconds_str.parse::<u64>()
        {
            return Some(Duration::from_secs(seconds));
        }
        if let Some(seconds_str) = message.split("Retry-After: ").nth(1)
            && let Some(seconds_str) = seconds_str.split('\n').next()
            && let Ok(seconds) = seconds_str.trim().parse::<u64>()
        {
            return Some(Duration::from_secs(seconds));
        }
        None
    }

    /// Clears the consecutive rate-limit counter after a successful call.
    pub const fn reset_on_success(&mut self) {
        self.state.consecutive_rate_limits = 0;
    }
}
/// Delay settings used by `RateLimitHandler`.
#[derive(Debug, Clone)]
pub struct RateLimitConfig {
    /// Delay used when no retry-after hint can be extracted from the error message.
    pub default_delay: Duration,
    /// Upper bound on the time slept for a single rate-limit error.
    pub max_delay: Duration,
    /// Presumably controls whether a retry-after hint in the error message is
    /// honoured — NOTE(review): confirm that the handler consults this flag.
    pub respect_retry_after: bool,
}

/// Defaults: one-second fallback delay, five-minute cap, retry-after hints enabled.
impl Default for RateLimitConfig {
    fn default() -> Self {
        let default_delay = Duration::from_millis(1_000);
        let max_delay = Duration::from_secs(5 * 60);

        Self {
            default_delay,
            max_delay,
            respect_retry_after: true,
        }
    }
}
/// Mutable book-keeping kept by `RateLimitHandler`.
#[derive(Debug, Clone, Default)]
pub struct RateLimitState {
/// When the most recent rate-limit error was handled, if any.
pub last_rate_limit: Option<Instant>,
/// Rate-limit errors handled since the counter was last reset by `reset_on_success`.
pub consecutive_rate_limits: u32,
}
/// Runs an asynchronous operation repeatedly according to a `RetryStrategy`.
pub struct RetryExecutor {
/// Policy deciding the number of attempts, the delays, and which errors are retried.
strategy: RetryStrategy,
/// Optional rate-limit handler; `None` unless `with_rate_limit_handler` was called.
rate_limit_handler: Option<RateLimitHandler>,
}
impl RetryExecutor {
pub const fn new(strategy: RetryStrategy) -> Self {
Self {
strategy,
rate_limit_handler: None,
}
}
pub fn with_rate_limit_handler(mut self, config: RateLimitConfig) -> Self {
self.rate_limit_handler = Some(RateLimitHandler::new(config));
self
}
pub async fn execute<F, Fut, T>(&mut self, mut operation: F) -> Result<T, LlmError>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, LlmError>>,
{
let mut last_error = None;
for attempt in 0..self.strategy.max_attempts {
match operation().await {
Ok(result) => {
if let Some(ref mut handler) = self.rate_limit_handler {
handler.reset_on_success();
}
return Ok(result);
}
Err(error) => {
last_error = Some(error.clone());
if !self.strategy.is_retryable(&error) {
return Err(error);
}
if let Some(ref mut handler) = self.rate_limit_handler {
handler.handle_rate_limit(&error).await?;
}
if attempt < self.strategy.max_attempts - 1 {
let delay = self.strategy.calculate_delay(attempt);
sleep(delay).await;
}
}
}
}
Err(last_error
.unwrap_or_else(|| LlmError::InternalError("All retry attempts failed".to_string())))
}
}
/// Settings controlling provider health tracking and failover.
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Priority per provider name; higher values are preferred and providers
    /// without an entry are treated as priority 0 by the failover manager.
    pub provider_priorities: std::collections::HashMap<String, u32>,
    /// Number of counted failures at which a provider is marked unhealthy.
    pub max_failures: u32,
    /// Failures only accumulate toward `max_failures` while each one follows
    /// the previous failure within this time span.
    pub failure_window: Duration,
    /// Time an unhealthy provider is left alone after its latest failure
    /// before it may be tried again.
    pub cooldown_period: Duration,
    /// When `false`, the failover manager always picks the first provider offered.
    pub auto_failover: bool,
}

/// Defaults: no priorities, three failures within five minutes mark a provider
/// unhealthy, one-minute cooldown, automatic failover enabled.
impl Default for FailoverConfig {
    fn default() -> Self {
        Self {
            provider_priorities: std::collections::HashMap::new(),
            max_failures: 3,
            failure_window: Duration::from_secs(300),
            cooldown_period: Duration::from_secs(60),
            auto_failover: true,
        }
    }
}

/// Health record of a single provider.
#[derive(Debug, Clone)]
pub struct ProviderHealth {
    /// Provider name.
    pub name: String,
    /// Failures counted since the last success (or since the failure window lapsed).
    pub failure_count: u32,
    /// When the most recent failure was recorded, if any.
    pub last_failure: Option<Instant>,
    /// `false` once `failure_count` has reached `max_failures`; restored by a success.
    pub is_healthy: bool,
    /// When the most recent success was recorded, if any.
    pub last_success: Option<Instant>,
}

impl ProviderHealth {
    /// Creates a healthy record with no recorded failures or successes.
    pub const fn new(name: String) -> Self {
        Self {
            name,
            failure_count: 0,
            last_failure: None,
            is_healthy: true,
            last_success: None,
        }
    }

    /// Records a failure and marks the provider unhealthy once
    /// `config.max_failures` failures have accumulated.
    ///
    /// While the provider is still healthy, a previous failure older than
    /// `config.failure_window` is considered stale and the count restarts, so
    /// isolated failures spread far apart do not add up to an outage.
    pub fn record_failure(&mut self, config: &FailoverConfig) {
        let now = Instant::now();

        let previous_is_stale = self
            .last_failure
            .map_or(false, |prev| now.duration_since(prev) > config.failure_window);
        // An unhealthy provider keeps its count; only a success restores it.
        if self.is_healthy && previous_is_stale {
            self.failure_count = 0;
        }

        self.failure_count = self.failure_count.saturating_add(1);
        self.last_failure = Some(now);
        if self.failure_count >= config.max_failures {
            self.is_healthy = false;
        }
    }

    /// Records a success: clears the failure count and marks the provider healthy.
    pub fn record_success(&mut self) {
        self.failure_count = 0;
        self.last_success = Some(Instant::now());
        self.is_healthy = true;
    }

    /// Returns whether the provider may be tried now.
    ///
    /// Healthy providers always qualify. Unhealthy ones qualify once
    /// `config.cooldown_period` has passed since their last failure (or if no
    /// failure time was ever recorded).
    pub fn should_retry(&self, config: &FailoverConfig) -> bool {
        if self.is_healthy {
            return true;
        }
        match self.last_failure {
            Some(last_failure) => last_failure.elapsed() >= config.cooldown_period,
            None => true,
        }
    }
}
/// Tracks per-provider health and chooses which provider to use next.
pub struct FailoverManager {
/// Failover settings (priorities, failure threshold, cooldown, auto-failover switch).
config: FailoverConfig,
/// Health records keyed by provider name; entries are created on first use.
provider_health: std::collections::HashMap<String, ProviderHealth>,
}
impl FailoverManager {
    /// Creates a manager with the given configuration and no health records yet.
    pub fn new(config: FailoverConfig) -> Self {
        let provider_health = std::collections::HashMap::new();
        Self {
            config,
            provider_health,
        }
    }

    /// Chooses the provider to use next from `providers`.
    ///
    /// With automatic failover disabled this is simply the first entry.
    /// Otherwise every provider gets a health record (created on demand),
    /// providers that may not be retried yet are skipped, and among the rest
    /// healthy providers win over unhealthy ones, then higher priority wins;
    /// ties go to the provider listed first. Returns `None` when no provider
    /// qualifies.
    pub fn get_next_provider(&mut self, providers: &[String]) -> Option<String> {
        if !self.config.auto_failover {
            return providers.first().cloned();
        }

        let config = &self.config;
        let health_by_name = &mut self.provider_health;

        // Best candidate seen so far as (name, priority, healthy).
        let mut best: Option<(&String, u32, bool)> = None;
        for name in providers {
            let health = health_by_name
                .entry(name.clone())
                .or_insert_with(|| ProviderHealth::new(name.clone()));
            if !health.should_retry(config) {
                continue;
            }

            let priority = config.provider_priorities.get(name).copied().unwrap_or(0);
            let healthy = health.is_healthy;
            // Strictly-better comparison keeps the earliest provider on ties.
            let replaces_best = match best {
                Some((_, best_priority, best_healthy)) => {
                    (healthy, priority) > (best_healthy, best_priority)
                }
                None => true,
            };
            if replaces_best {
                best = Some((name, priority, healthy));
            }
        }

        best.map(|(name, _, _)| name.clone())
    }

    /// Records a failure for `provider`, creating its health record if needed.
    pub fn record_failure(&mut self, provider: &str) {
        let config = &self.config;
        self.provider_health
            .entry(provider.to_string())
            .or_insert_with(|| ProviderHealth::new(provider.to_string()))
            .record_failure(config);
    }

    /// Records a success for `provider`, creating its health record if needed.
    pub fn record_success(&mut self, provider: &str) {
        self.provider_health
            .entry(provider.to_string())
            .or_insert_with(|| ProviderHealth::new(provider.to_string()))
            .record_success();
    }

    /// Returns the health record of `provider`, if one exists.
    pub fn get_provider_health(&self, provider: &str) -> Option<&ProviderHealth> {
        self.provider_health.get(provider)
    }

    /// Returns all health records keyed by provider name.
    pub const fn get_all_health(&self) -> &std::collections::HashMap<String, ProviderHealth> {
        &self.provider_health
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With jitter disabled, exponential backoff doubles the base delay per attempt.
    #[test]
    fn test_retry_strategy_delay_calculation() {
        let strategy = RetryStrategy::new()
            .with_base_delay(Duration::from_millis(100))
            .with_backoff(BackoffStrategy::Exponential { multiplier: 2.0 })
            .with_jitter(JitterConfig::None);

        for (attempt, expected_ms) in [(0_u32, 100_u64), (1, 200), (2, 400)] {
            assert_eq!(
                strategy.calculate_delay(attempt),
                Duration::from_millis(expected_ms)
            );
        }
    }

    /// The default strategy retries transport, rate-limit and timeout errors only.
    #[test]
    fn test_retryable_error_detection() {
        let strategy = RetryStrategy::default();

        let retryable = [
            LlmError::HttpError("Connection failed".to_string()),
            LlmError::RateLimitError("Rate limited".to_string()),
            LlmError::TimeoutError("Request timeout".to_string()),
        ];
        for error in &retryable {
            assert!(strategy.is_retryable(error));
        }

        let rejected = LlmError::InvalidParameter("Bad param".to_string());
        assert!(!strategy.is_retryable(&rejected));
    }

    /// Both supported message formats yield a delay; anything else yields `None`.
    #[test]
    fn test_rate_limit_retry_after_extraction() {
        let handler = RateLimitHandler::new(RateLimitConfig::default());

        let cases = [
            (
                "Rate limited. Please retry after 30 seconds.",
                Some(Duration::from_secs(30)),
            ),
            ("HTTP 429: Retry-After: 60", Some(Duration::from_secs(60))),
            ("No retry info", None),
        ];
        for (message, expected) in cases {
            assert_eq!(handler.extract_retry_after(message), expected);
        }
    }

    /// A provider turns unhealthy on the third failure and recovers on success.
    #[test]
    fn test_provider_health_tracking() {
        let config = FailoverConfig::default();
        let mut health = ProviderHealth::new("test-provider".to_string());
        assert!(health.is_healthy);
        assert_eq!(health.failure_count, 0);

        for _ in 0..2 {
            health.record_failure(&config);
        }
        assert!(health.is_healthy);

        health.record_failure(&config);
        assert!(!health.is_healthy);

        health.record_success();
        assert!(health.is_healthy);
        assert_eq!(health.failure_count, 0);
    }

    /// The higher-priority provider is chosen until it becomes unhealthy.
    #[test]
    fn test_failover_manager() {
        let mut config = FailoverConfig::default();
        for (name, priority) in [("provider1", 10_u32), ("provider2", 5)] {
            config.provider_priorities.insert(name.to_string(), priority);
        }
        let mut manager = FailoverManager::new(config);
        let providers = vec!["provider1".to_string(), "provider2".to_string()];

        assert_eq!(
            manager.get_next_provider(&providers),
            Some("provider1".to_string())
        );

        for _ in 0..3 {
            manager.record_failure("provider1");
        }
        assert_eq!(
            manager.get_next_provider(&providers),
            Some("provider2".to_string())
        );
    }
}