use crate::error::{Error, Result};
use rand::RngExt;
use std::time::Duration;
use tokio::time::sleep;
use tracing::{debug, warn};
/// Tuning knobs for the exponential-backoff loop in `retry_async`.
///
/// Start from [`RetryConfig::new`] or [`Default::default`] and refine the
/// result with the `with_*` builder methods.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Budget of calls made to the retried operation; the very first call
    /// counts toward it.
    pub max_retries: u32,
    /// Delay in milliseconds used after the first failure; it doubles for
    /// every further failed attempt.
    pub base_delay_ms: u64,
    /// Ceiling in milliseconds applied to the exponential delay before any
    /// jitter is added.
    pub max_delay_ms: u64,
    /// When `true`, a random extra of up to a quarter of the capped delay is
    /// added to each sleep.
    pub use_jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay_ms: 1000,
max_delay_ms: 60000,
use_jitter: true,
}
}
}
impl RetryConfig {
    /// Creates a config with the given attempt budget; every other field comes
    /// from [`RetryConfig::default`].
    ///
    /// `max_retries` is the number of calls `retry_async` may make, and the
    /// initial call counts as one of them.
    #[must_use]
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            ..Default::default()
        }
    }

    /// Sets the delay (milliseconds) used after the first failure; each later
    /// failure doubles it, up to `max_delay_ms`.
    #[must_use]
    pub fn with_base_delay_ms(mut self, ms: u64) -> Self {
        self.base_delay_ms = ms;
        self
    }

    /// Sets the ceiling (milliseconds) for the exponential delay. Jitter is
    /// added on top of the capped value.
    #[must_use]
    pub fn with_max_delay_ms(mut self, ms: u64) -> Self {
        self.max_delay_ms = ms;
        self
    }

    /// Enables or disables random jitter on each backoff delay.
    #[must_use]
    pub fn with_jitter(mut self, use_jitter: bool) -> Self {
        self.use_jitter = use_jitter;
        self
    }
}
pub async fn retry_async<F, Fut, T>(config: &RetryConfig, operation: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut last_error = None;
for attempt in 1..=config.max_retries {
debug!("Retry attempt {} of {}", attempt, config.max_retries);
let outcome = operation().await;
match outcome {
Ok(result) => return Ok(result),
Err(e) => {
warn!("Attempt {} failed: {}", attempt, e);
if !e.is_retryable() {
debug!("Error is not retryable, failing immediately");
return Err(e);
}
last_error = Some(e);
if attempt < config.max_retries {
let delay = calculate_backoff(attempt, config);
debug!("Backing off for {} ms", delay.as_millis());
sleep(delay).await;
}
}
}
}
Err(last_error.unwrap_or(Error::MaxRetriesExceeded(config.max_retries)))
}
/// Computes how long to sleep after the given 1-based failed `attempt`.
///
/// The delay is `base_delay_ms * 2^(attempt - 1)`, capped at `max_delay_ms`.
/// With `use_jitter` set, a random extra in `0..=capped / 4` is added on top,
/// so a jittered delay may exceed `max_delay_ms` by up to 25%. All arithmetic
/// saturates at `u64::MAX` milliseconds instead of overflowing.
fn calculate_backoff(attempt: u32, config: &RetryConfig) -> Duration {
    // `attempt` is 1-based. Saturate so a stray 0 is treated like the first
    // attempt instead of underflowing.
    let exponent = attempt.saturating_sub(1);
    let exponential_delay = config
        .base_delay_ms
        .saturating_mul(2u64.saturating_pow(exponent));
    let capped_delay = exponential_delay.min(config.max_delay_ms);
    let final_delay = if config.use_jitter {
        let jitter = rand::rng().random_range(0..=(capped_delay / 4));
        // `max_delay_ms` may be set as high as `u64::MAX` (effectively
        // "no cap"), so the sum must not overflow.
        capped_delay.saturating_add(jitter)
    } else {
        capped_delay
    };
    Duration::from_millis(final_delay)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new(5)
            .with_base_delay_ms(500)
            .with_max_delay_ms(30000)
            .with_jitter(false);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.base_delay_ms, 500);
        assert_eq!(config.max_delay_ms, 30000);
        assert!(!config.use_jitter);
    }

    #[test]
    fn test_calculate_backoff_no_jitter() {
        let config = RetryConfig::new(3)
            .with_base_delay_ms(1000)
            .with_max_delay_ms(60000)
            .with_jitter(false);
        assert_eq!(calculate_backoff(1, &config), Duration::from_millis(1000));
        assert_eq!(calculate_backoff(2, &config), Duration::from_millis(2000));
        assert_eq!(calculate_backoff(3, &config), Duration::from_millis(4000));
    }

    #[test]
    fn test_calculate_backoff_with_cap() {
        let config = RetryConfig::new(10)
            .with_base_delay_ms(1000)
            .with_max_delay_ms(5000)
            .with_jitter(false);
        let delay = calculate_backoff(10, &config);
        assert_eq!(delay, Duration::from_millis(5000));
    }

    /// Jitter only ever adds to the capped delay, and by at most a quarter of it.
    #[test]
    fn test_calculate_backoff_jitter_bounds() {
        let config = RetryConfig::new(4)
            .with_base_delay_ms(1000)
            .with_max_delay_ms(60000)
            .with_jitter(true);
        for (attempt, expected_ms) in [(1, 1000u64), (2, 2000), (3, 4000), (4, 8000)] {
            let delay = calculate_backoff(attempt, &config);
            assert!(delay >= Duration::from_millis(expected_ms), "{:?}", delay);
            assert!(
                delay <= Duration::from_millis(expected_ms + expected_ms / 4),
                "{:?}",
                delay
            );
        }
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let config = RetryConfig::new(3);
        let call_count = AtomicUsize::new(0);
        let result = retry_async(&config, || async {
            call_count.fetch_add(1, Ordering::SeqCst);
            Ok::<_, Error>(42)
        })
        .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_succeeds_on_second_attempt() {
        let config = RetryConfig::new(3).with_base_delay_ms(10);
        let call_count = AtomicUsize::new(0);
        let result = retry_async(&config, || async {
            let count = call_count.fetch_add(1, Ordering::SeqCst) + 1;
            if count == 1 {
                Err(Error::Network("connection failed".to_string()))
            } else {
                Ok::<_, Error>(42)
            }
        })
        .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_fails_non_retryable() {
        let config = RetryConfig::new(3);
        let call_count = AtomicUsize::new(0);
        let result = retry_async(&config, || async {
            call_count.fetch_add(1, Ordering::SeqCst);
            Err::<i32, _>(Error::Authentication)
        })
        .await;
        // The non-retryable error itself must be surfaced, not a wrapper.
        assert!(matches!(result, Err(Error::Authentication)));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let config = RetryConfig::new(2).with_base_delay_ms(10);
        let call_count = AtomicUsize::new(0);
        let result = retry_async(&config, || async {
            call_count.fetch_add(1, Ordering::SeqCst);
            Err::<i32, _>(Error::Network("persistent failure".to_string()))
        })
        .await;
        // Once the budget is spent, the last attempt's error is returned.
        assert!(matches!(result, Err(Error::Network(_))));
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}