use crate::error::{AixError, AixResult};
use rand::Rng;
use std::future::Future;
use std::time::Duration;
/// Policy describing when a failed operation is retried and how long to wait
/// between attempts.
///
/// `PartialEq` is derived so configurations can be compared directly (for
/// example with `assert_eq!`); `Eq` is not available because `multiplier` is
/// an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Total number of attempts, counting the initial call.
    pub max_attempts: u32,
    /// Delay used before the first retry (attempt index `0`).
    pub initial_backoff: Duration,
    /// Upper bound applied to the exponentially grown delay before jitter.
    pub max_backoff: Duration,
    /// Factor by which the delay grows for each subsequent attempt.
    pub multiplier: f64,
    /// When `true`, the computed delay is scaled by a random factor in
    /// `[0.5, 1.5)`.
    pub jitter: bool,
    /// Whether rate-limit errors are retried.
    pub retry_on_rate_limit: bool,
    /// Whether transport errors are retried.
    pub retry_on_transport: bool,
    /// Whether timeout errors are retried.
    pub retry_on_timeout: bool,
}
impl RetryConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a [`RetryConfigBuilder`] for assembling a configuration step by step.
    pub fn builder() -> RetryConfigBuilder {
        RetryConfigBuilder::new()
    }

    /// Returns the configuration with `max_attempts` replaced.
    pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
        self.max_attempts = max_attempts;
        self
    }

    /// Returns the configuration with `initial_backoff` replaced.
    pub fn with_initial_backoff(mut self, initial_backoff: Duration) -> Self {
        self.initial_backoff = initial_backoff;
        self
    }

    /// Returns the configuration with `max_backoff` replaced.
    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
        self.max_backoff = max_backoff;
        self
    }

    /// Returns the configuration with `multiplier` replaced.
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

    /// Returns the configuration with `jitter` replaced.
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.jitter = jitter;
        self
    }

    /// Returns the configuration with `retry_on_rate_limit` replaced.
    pub fn with_retry_on_rate_limit(mut self, retry: bool) -> Self {
        self.retry_on_rate_limit = retry;
        self
    }

    /// Returns the configuration with `retry_on_transport` replaced.
    pub fn with_retry_on_transport(mut self, retry: bool) -> Self {
        self.retry_on_transport = retry;
        self
    }

    /// Returns the configuration with `retry_on_timeout` replaced.
    pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
        self.retry_on_timeout = retry;
        self
    }

    /// Decides whether `error` should trigger another attempt.
    ///
    /// An error is retried only if it reports itself as retryable *and* the
    /// matching `retry_on_*` flag is enabled. Provider errors are retried only
    /// when they carry a status of 500 or above; every other variant is not
    /// retried.
    pub fn should_retry(&self, error: &AixError) -> bool {
        if !error.is_retryable() {
            return false;
        }
        match error {
            AixError::RateLimit { .. } => self.retry_on_rate_limit,
            AixError::Transport { .. } => self.retry_on_transport,
            AixError::Timeout { .. } => self.retry_on_timeout,
            AixError::Provider { status, .. } => status.map_or(false, |s| s >= 500),
            _ => false,
        }
    }

    /// Computes the delay to wait after the zero-based `attempt` has failed.
    ///
    /// The delay is `initial_backoff * multiplier^attempt`, capped at
    /// `max_backoff`. With `jitter` enabled the capped value is then scaled by
    /// a random factor in `[0.5, 1.5)`, so the result may exceed `max_backoff`
    /// by up to 50%.
    ///
    /// This never panics: values that do not fit a [`Duration`] (overflow,
    /// infinity, negative or NaN products caused by extreme attempts or
    /// multipliers) saturate instead, and a zero delay is returned as-is
    /// without sampling jitter.
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`; saturate rather than wrap to a negative exponent.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let scaled_secs = self.initial_backoff.as_secs_f64() * self.multiplier.powi(exponent);

        let delay = if scaled_secs.is_nan() {
            // Only reachable via `0 * inf` or a NaN multiplier: fall back to
            // the un-grown backoff.
            self.initial_backoff.min(self.max_backoff)
        } else {
            Self::saturating_duration_from_secs(scaled_secs).min(self.max_backoff)
        };

        // Sampling from an empty range panics, and jitter on zero is zero anyway.
        if !self.jitter || delay.is_zero() {
            return delay;
        }

        let factor: f64 = rand::thread_rng().gen_range(0.5..1.5);
        Self::saturating_duration_from_secs(delay.as_secs_f64() * factor)
    }

    /// Returns the server-provided retry delay carried by a rate-limit error,
    /// or `None` for any other error or when no delay was supplied.
    pub fn extract_retry_delay(&self, error: &AixError) -> Option<Duration> {
        match error {
            AixError::RateLimit { retry_after, .. } => *retry_after,
            _ => None,
        }
    }

    /// Converts seconds to a [`Duration`], clamping to `[ZERO, MAX]` instead
    /// of panicking like `Duration::from_secs_f64` does on out-of-range input.
    fn saturating_duration_from_secs(secs: f64) -> Duration {
        if secs.is_nan() || secs <= 0.0 {
            Duration::ZERO
        } else if secs >= Duration::MAX.as_secs_f64() {
            Duration::MAX
        } else {
            Duration::from_secs_f64(secs)
        }
    }
}
impl Default for RetryConfig {
    /// Default policy: three attempts, backoff starting at one second and
    /// doubling up to thirty seconds, with jitter and all retry categories
    /// enabled.
    fn default() -> Self {
        let initial_backoff = Duration::from_secs(1);
        let max_backoff = Duration::from_secs(30);
        Self {
            max_attempts: 3,
            initial_backoff,
            max_backoff,
            multiplier: 2.0,
            jitter: true,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        }
    }
}
/// Step-by-step builder for [`RetryConfig`].
///
/// Derives `Debug` and `Clone` so a partially configured builder can be
/// logged or reused as a template.
#[derive(Debug, Clone)]
pub struct RetryConfigBuilder {
    /// Configuration being assembled; returned by `build`.
    config: RetryConfig,
}
impl RetryConfigBuilder {
pub fn new() -> Self {
Self {
config: RetryConfig::default(),
}
}
pub fn max_attempts(mut self, max_attempts: u32) -> Self {
self.config.max_attempts = max_attempts;
self
}
pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
self.config.initial_backoff = initial_backoff;
self
}
pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
self.config.max_backoff = max_backoff;
self
}
pub fn multiplier(mut self, multiplier: f64) -> Self {
self.config.multiplier = multiplier;
self
}
pub fn jitter(mut self, jitter: bool) -> Self {
self.config.jitter = jitter;
self
}
pub fn retry_on_rate_limit(mut self, retry: bool) -> Self {
self.config.retry_on_rate_limit = retry;
self
}
pub fn retry_on_transport(mut self, retry: bool) -> Self {
self.config.retry_on_transport = retry;
self
}
pub fn retry_on_timeout(mut self, retry: bool) -> Self {
self.config.retry_on_timeout = retry;
self
}
pub fn build(self) -> RetryConfig {
self.config
}
}
impl Default for RetryConfigBuilder {
fn default() -> Self {
Self::new()
}
}
/// Runs fallible asynchronous operations under a [`RetryConfig`].
///
/// Derives `Debug` and `Clone` so a strategy can be logged and shared between
/// callers by value.
#[derive(Debug, Clone)]
pub struct RetryStrategy {
    /// Policy consulted for retry decisions and delays.
    config: RetryConfig,
}
impl RetryStrategy {
    /// Creates a strategy driven by `config`.
    pub fn new(config: RetryConfig) -> Self {
        Self { config }
    }

    /// Invokes `f` until it succeeds, fails with a non-retryable error, or
    /// `max_attempts` calls have been made.
    ///
    /// Between attempts the task sleeps for the delay supplied by the error
    /// (rate-limit `retry_after`) when present, otherwise for the configured
    /// exponential backoff.
    ///
    /// # Errors
    ///
    /// Returns the error of the last attempt made. If `max_attempts` is `0`,
    /// `f` is never called and a generic "All retry attempts failed" error is
    /// returned.
    pub async fn execute<F, Fut, T>(&self, mut f: F) -> AixResult<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = AixResult<T>>,
    {
        let max_attempts = self.config.max_attempts;
        for attempt in 0..max_attempts {
            let error = match f().await {
                Ok(value) => return Ok(value),
                Err(error) => error,
            };

            // `attempt < max_attempts`, so `attempt + 1` cannot overflow.
            let is_last_attempt = attempt + 1 == max_attempts;
            if !self.config.should_retry(&error) || is_last_attempt {
                // The error is returned by value; no copy needs to be kept
                // around because the loop never falls through after a failure
                // on its final iteration.
                return Err(error);
            }

            let delay = self
                .config
                .extract_retry_delay(&error)
                .unwrap_or_else(|| self.config.calculate_delay(attempt));
            tokio::time::sleep(delay).await;
        }

        // Only reachable when `max_attempts` is zero and `f` was never called.
        Err(AixError::other("All retry attempts failed"))
    }

    /// Returns the retry policy used by this strategy.
    pub fn config(&self) -> &RetryConfig {
        &self.config
    }

    /// Returns a mutable reference to the retry policy, allowing in-place tuning.
    pub fn config_mut(&mut self) -> &mut RetryConfig {
        &mut self.config
    }
}
impl From<RetryConfig> for RetryStrategy {
fn from(config: RetryConfig) -> Self {
Self::new(config)
}
}
impl RetryConfig {
    /// Preset that performs a single attempt and never retries.
    pub fn no_retry() -> Self {
        Self::new()
            .with_max_attempts(1)
            .with_initial_backoff(Duration::ZERO)
            .with_max_backoff(Duration::ZERO)
            .with_multiplier(1.0)
            .with_jitter(false)
            .with_retry_on_rate_limit(false)
            .with_retry_on_transport(false)
            .with_retry_on_timeout(false)
    }

    /// Preset with two attempts that retries rate-limit errors only, waiting
    /// between two and ten seconds before jitter.
    pub fn conservative() -> Self {
        Self::new()
            .with_max_attempts(2)
            .with_initial_backoff(Duration::from_secs(2))
            .with_max_backoff(Duration::from_secs(10))
            .with_multiplier(2.0)
            .with_jitter(true)
            .with_retry_on_rate_limit(true)
            .with_retry_on_transport(false)
            .with_retry_on_timeout(false)
    }

    /// Preset with five attempts that retries every supported category,
    /// starting at 500 ms and growing by 1.5x up to thirty seconds.
    pub fn aggressive() -> Self {
        Self::new()
            .with_max_attempts(5)
            .with_initial_backoff(Duration::from_millis(500))
            .with_max_backoff(Duration::from_secs(30))
            .with_multiplier(1.5)
            .with_jitter(true)
            .with_retry_on_rate_limit(true)
            .with_retry_on_transport(true)
            .with_retry_on_timeout(true)
    }

    /// Preset with three attempts and short delays (200 ms growing by 1.5x up
    /// to five seconds) that does not retry timeouts.
    pub fn fast() -> Self {
        Self::new()
            .with_max_attempts(3)
            .with_initial_backoff(Duration::from_millis(200))
            .with_max_backoff(Duration::from_secs(5))
            .with_multiplier(1.5)
            .with_jitter(true)
            .with_retry_on_rate_limit(true)
            .with_retry_on_transport(true)
            .with_retry_on_timeout(false)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AixError;

    /// Builds a strategy whose backoff is only a few milliseconds, so tests
    /// that go through the real sleep path do not stall the suite for seconds.
    fn fast_strategy(max_attempts: u32) -> RetryStrategy {
        RetryStrategy::new(
            RetryConfig::builder()
                .max_attempts(max_attempts)
                .initial_backoff(Duration::from_millis(1))
                .max_backoff(Duration::from_millis(5))
                .build(),
        )
    }

    /// Every builder setter must be reflected in the built configuration.
    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::builder()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .max_backoff(Duration::from_secs(10))
            .multiplier(1.5)
            .jitter(false)
            .retry_on_rate_limit(false)
            .build();
        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(500));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.multiplier, 1.5);
        assert!(!config.jitter);
        assert!(!config.retry_on_rate_limit);
    }

    /// Without jitter the delay grows exponentially and is capped at `max_backoff`.
    #[test]
    fn test_backoff_calculation() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };
        assert_eq!(config.calculate_delay(0), Duration::from_millis(1000));
        assert_eq!(config.calculate_delay(1), Duration::from_millis(2000));
        assert_eq!(config.calculate_delay(2), Duration::from_millis(4000));

        let long_config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_millis(3000),
            multiplier: 2.0,
            jitter: false,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };
        assert_eq!(long_config.calculate_delay(3), Duration::from_millis(3000));
    }

    /// With jitter the delay stays within 50% of the un-jittered value.
    #[test]
    fn test_jitter() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: true,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };
        // The result is random, so sample repeatedly instead of trusting one draw.
        for _ in 0..100 {
            let delay = config.calculate_delay(0);
            assert!(delay >= Duration::from_millis(500));
            assert!(delay <= Duration::from_millis(1500));
        }
    }

    /// Retryable categories and 5xx provider errors are retried; configuration,
    /// authentication and 4xx provider errors are not.
    #[test]
    fn test_should_retry() {
        let config = RetryConfig::default();
        assert!(config.should_retry(&AixError::transport("network error", "request")));
        assert!(config.should_retry(&AixError::rate_limit("openai", "too many requests")));
        assert!(config.should_retry(&AixError::timeout("chat", Duration::from_secs(30))));
        assert!(config.should_retry(&AixError::provider_with_details(
            "openai",
            "server error",
            500,
            "internal_error"
        )));
        assert!(!config.should_retry(&AixError::config("invalid config")));
        assert!(!config.should_retry(&AixError::auth("openai", "unauthorized")));
        assert!(!config.should_retry(&AixError::provider_with_details(
            "openai",
            "bad request",
            400,
            "invalid_request"
        )));
    }

    /// A call that succeeds immediately is invoked exactly once.
    #[tokio::test]
    async fn test_retry_strategy_success() {
        let strategy = RetryStrategy::new(RetryConfig::default());
        let mut call_count = 0;
        let result = strategy
            .execute(|| {
                call_count += 1;
                async move { Ok::<_, AixError>("success") }
            })
            .await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(call_count, 1);
    }

    /// Two transient failures followed by a success take three calls in total.
    #[tokio::test]
    async fn test_retry_strategy_with_retry() {
        let strategy = fast_strategy(3);
        let mut call_count = 0;
        let result = strategy
            .execute(|| {
                call_count += 1;
                async move {
                    if call_count < 3 {
                        Err::<_, AixError>(AixError::transport("network error", "request"))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(call_count, 3);
    }

    /// When every attempt fails, the operation is called `max_attempts` times
    /// and the error is surfaced.
    #[tokio::test]
    async fn test_retry_strategy_exhausted() {
        let strategy = fast_strategy(2);
        let mut call_count = 0;
        let result = strategy
            .execute(|| {
                call_count += 1;
                async move { Err::<(), AixError>(AixError::transport("network error", "request")) }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(call_count, 2);
    }

    /// Spot-checks the distinguishing fields of each preset.
    #[test]
    fn test_preset_configs() {
        let no_retry = RetryConfig::no_retry();
        assert_eq!(no_retry.max_attempts, 1);
        assert!(!no_retry.retry_on_rate_limit);

        let conservative = RetryConfig::conservative();
        assert_eq!(conservative.max_attempts, 2);
        assert!(!conservative.retry_on_transport);

        let aggressive = RetryConfig::aggressive();
        assert_eq!(aggressive.max_attempts, 5);
        assert!(aggressive.retry_on_transport);

        let fast = RetryConfig::fast();
        assert_eq!(fast.max_attempts, 3);
        assert_eq!(fast.initial_backoff, Duration::from_millis(200));
    }
}