use std::time::Duration;
#[derive(Debug, Clone)]
pub struct RetryConfig {
pub max_retries: usize,
pub initial_delay: Duration,
pub max_delay: Duration,
pub multiplier: f64,
}
impl Default for RetryConfig {
fn default() -> Self {
RetryConfig {
max_retries: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(60),
multiplier: 2.0,
}
}
}
impl RetryConfig {
pub fn new(max_retries: usize) -> Self {
RetryConfig {
max_retries,
..Default::default()
}
}
pub fn with_initial_delay(mut self, delay: Duration) -> Self {
self.initial_delay = delay;
self
}
pub fn with_max_delay(mut self, delay: Duration) -> Self {
self.max_delay = delay;
self
}
pub fn with_multiplier(mut self, multiplier: f64) -> Self {
self.multiplier = multiplier;
self
}
pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
let delay_ms = self.initial_delay.as_millis() as f64 * self.multiplier.powi(attempt as i32);
let delay = Duration::from_millis(delay_ms as u64);
delay.min(self.max_delay)
}
}
/// Runs fallible async operations, retrying failures on the backoff
/// schedule described by a [`RetryConfig`].
pub struct RetryExecutor {
    config: RetryConfig,
}

impl RetryExecutor {
    /// Creates an executor that applies `config` to every call.
    pub fn new(config: RetryConfig) -> Self {
        RetryExecutor { config }
    }

    /// Creates an executor using [`RetryConfig::default`].
    pub fn default_config() -> Self {
        Self::new(RetryConfig::default())
    }

    /// Calls `operation` until it succeeds or the retry budget is spent.
    ///
    /// The operation runs at most `max_retries + 1` times. After each failed
    /// attempt except the last, the executor sleeps for
    /// [`RetryConfig::delay_for_attempt`] using `tokio::time::sleep`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the final attempt.
    pub async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
    {
        // Retrying every error is exactly the predicate variant with an always-true filter.
        self.execute_with_predicate(operation, |_| true).await
    }

    /// Like [`RetryExecutor::execute`], but only retries errors for which
    /// `should_retry` returns `true`.
    ///
    /// `should_retry` is not called for the final attempt's error.
    ///
    /// # Errors
    ///
    /// Returns immediately with the first error rejected by `should_retry`.
    /// Otherwise returns the error from the final attempt.
    pub async fn execute_with_predicate<F, Fut, T, E, P>(
        &self,
        mut operation: F,
        mut should_retry: P,
    ) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, E>>,
        P: FnMut(&E) -> bool,
    {
        let mut attempt = 0;
        loop {
            let err = match operation().await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };
            // `||` short-circuits, so the predicate is never consulted once the
            // retry budget is exhausted.
            if attempt >= self.config.max_retries || !should_retry(&err) {
                return Err(err);
            }
            tokio::time::sleep(self.config.delay_for_attempt(attempt)).await;
            // Only incremented while `attempt < max_retries`, so this cannot overflow.
            attempt += 1;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            initial_delay,
            ..
        } = RetryConfig::default();
        assert_eq!(max_retries, 3);
        assert_eq!(initial_delay, Duration::from_millis(100));
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::default();
        // (attempt index, expected delay in ms) for the default doubling schedule.
        let expected: [u64; 3] = [100, 200, 400];
        for (attempt, &millis) in expected.iter().enumerate() {
            assert_eq!(config.delay_for_attempt(attempt), Duration::from_millis(millis));
        }
    }

    #[tokio::test]
    async fn test_retry_success() {
        let executor = RetryExecutor::new(RetryConfig::new(3));
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let outcome = executor
            .execute(move || {
                let counter = Arc::clone(&counter);
                async move {
                    // Fail only on the very first call.
                    match counter.fetch_add(1, Ordering::SeqCst) {
                        0 => Err("temporary error"),
                        _ => Ok(42),
                    }
                }
            })
            .await;
        assert_eq!(outcome, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let executor = RetryExecutor::new(RetryConfig::new(2));
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&calls);
        let outcome: Result<(), &str> = executor
            .execute(move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err("permanent error")
                }
            })
            .await;
        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}