use crate::error::{Error, Result};
use std::time::Duration;
use tokio::time::sleep;
use tracing::debug;
/// Configuration for retrying fallible async operations with exponential backoff.
///
/// The backoff delay after the zero-based attempt `n` fails is
/// `initial_delay * backoff_multiplier^n`, capped at `max_delay` and optionally
/// randomized by jitter.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryStrategy {
    /// Maximum number of times the operation is invoked, including the first call.
    pub max_attempts: usize,
    /// Base delay used before the first retry.
    pub initial_delay: Duration,
    /// Cap applied to the exponential backoff delay.
    pub max_delay: Duration,
    /// Factor the delay grows by after each failed attempt.
    pub backoff_multiplier: f64,
    /// Whether to scale each delay by a random factor in `[0.75, 1.25)`.
    pub jitter: bool,
}
impl Default for RetryStrategy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryStrategy {
    /// Computes how long to wait after the zero-based `attempt` has failed.
    ///
    /// The base delay is `initial_delay * backoff_multiplier^attempt`, capped at
    /// `max_delay`. When `jitter` is enabled, the capped delay is scaled by a
    /// random factor in `[0.75, 1.25)` and then capped at `max_delay` again. The
    /// returned value therefore never exceeds `max_delay`.
    fn calculate_delay(&self, attempt: usize) -> Duration {
        // Saturate instead of wrapping for (unrealistically) large attempt counts.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        // Work in nanoseconds so sub-millisecond delays are not truncated away.
        let scaled_nanos =
            self.initial_delay.as_nanos() as f64 * self.backoff_multiplier.powi(exponent);
        // Float-to-int `as` casts saturate (NaN -> 0, +inf -> u64::MAX), so an
        // overflowing or degenerate product cannot panic here.
        let delay = Duration::from_nanos(scaled_nanos as u64).min(self.max_delay);
        if self.jitter {
            let jitter_factor = 0.75 + fastrand::f64() * 0.5;
            let jittered_nanos = (delay.as_nanos() as f64 * jitter_factor) as u64;
            // Re-cap: a jitter factor above 1.0 must not push the delay past `max_delay`.
            Duration::from_nanos(jittered_nanos).min(self.max_delay)
        } else {
            delay
        }
    }

    /// Runs `operation` until one of three things happens:
    ///
    /// - it succeeds;
    /// - it returns an error whose `is_retryable()` is false;
    /// - `max_attempts` invocations have been made.
    ///
    /// Between attempts the task sleeps for the backoff delay derived from this
    /// strategy's settings. No sleep happens after the final attempt.
    ///
    /// # Errors
    ///
    /// - If an attempt returns a non-retryable error, that error is returned
    ///   immediately.
    /// - Otherwise, when all attempts fail, the error from the last attempt is
    ///   returned.
    /// - If `max_attempts` is 0, the operation is never called and a generic
    ///   "All retry attempts failed" error is returned.
    pub async fn execute<F, Fut, T>(&self, operation: F) -> Result<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T>>,
    {
        let mut last_error = None;
        for attempt in 0..self.max_attempts {
            match operation().await {
                Ok(result) => {
                    if attempt > 0 {
                        debug!("Operation succeeded after {} retries", attempt);
                    }
                    return Ok(result);
                }
                Err(error) => {
                    if !error.is_retryable() {
                        debug!("Error is not retryable, failing immediately: {}", error);
                        return Err(error);
                    }
                    debug!("Attempt {} failed: {}", attempt + 1, error);
                    last_error = Some(error);
                    // Only back off if another attempt follows. `attempt + 1 <` avoids
                    // the `max_attempts - 1` subtraction, which could underflow.
                    if attempt + 1 < self.max_attempts {
                        let delay = self.calculate_delay(attempt);
                        debug!("Retrying in {:?}", delay);
                        sleep(delay).await;
                    }
                }
            }
        }
        Err(last_error.unwrap_or_else(|| Error::generic("All retry attempts failed")))
    }
}
pub async fn with_retry<F, Fut, T>(strategy: &RetryStrategy, operation: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
strategy.execute(operation).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let strategy = RetryStrategy::default();
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();
        let result = strategy
            .execute(|| async {
                attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                Ok::<i32, Error>(42)
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let strategy = RetryStrategy {
            max_attempts: 3,
            initial_delay: Duration::from_millis(10),
            jitter: false,
            ..Default::default()
        };
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();
        let result = strategy
            .execute(|| async {
                let count = attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(Error::session("Temporary failure"))
                } else {
                    Ok::<i32, Error>(42)
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    /// Every attempt fails with a retryable error: the operation runs exactly
    /// `max_attempts` times and the final error is surfaced.
    #[tokio::test]
    async fn test_retry_exhausts_all_attempts() {
        let strategy = RetryStrategy {
            max_attempts: 2,
            initial_delay: Duration::from_millis(1),
            jitter: false,
            ..Default::default()
        };
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();
        let result = strategy
            .execute(|| async {
                attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(Error::session("Temporary failure"))
            })
            .await;
        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_non_retryable_error() {
        let strategy = RetryStrategy::default();
        let attempt_count = Arc::new(AtomicUsize::new(0));
        let attempt_count_clone = attempt_count.clone();
        let result = strategy
            .execute(|| async {
                attempt_count_clone.fetch_add(1, Ordering::SeqCst);
                Err::<i32, Error>(Error::authentication("Invalid credentials"))
            })
            .await;
        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_delay_calculation() {
        let strategy = RetryStrategy {
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(10),
            jitter: false,
            ..Default::default()
        };
        assert_eq!(strategy.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(strategy.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(strategy.calculate_delay(2), Duration::from_millis(400));
    }

    /// 100 ms * 2^10 = 102.4 s, which must be clamped to `max_delay`.
    #[test]
    fn test_delay_capped_at_max_delay() {
        let strategy = RetryStrategy {
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(10),
            jitter: false,
            ..Default::default()
        };
        assert_eq!(strategy.calculate_delay(10), Duration::from_secs(10));
    }

    /// With jitter enabled, the first delay stays within ±25% of `initial_delay`.
    #[test]
    fn test_jitter_stays_within_bounds() {
        let strategy = RetryStrategy {
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            jitter: true,
            ..Default::default()
        };
        for _ in 0..100 {
            let delay = strategy.calculate_delay(0);
            assert!(delay >= Duration::from_millis(75));
            assert!(delay <= Duration::from_millis(125));
        }
    }
}