use crate::error::SyncError;
use std::future::Future;
use std::time::Duration;
/// Settings for [`retry_with_backoff`]: how many times to retry and how long to
/// wait between attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of *retries* allowed after the initial attempt. The retry loop
    /// stops once this many retries have been made, so the operation runs at
    /// most `max_attempts + 1` times.
    pub max_attempts: u32,
    /// Delay before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on any single delay, regardless of how far the backoff grows.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each successive retry.
    pub backoff_multiplier: f64,
}

impl Default for RetryConfig {
    /// 3 retries, starting at 1 s, doubling each time, capped at 30 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Creates a config with the given retry count and initial delay; the
    /// remaining fields come from [`RetryConfig::default`].
    pub fn new(max_attempts: u32, initial_delay: Duration) -> Self {
        Self {
            max_attempts,
            initial_delay,
            ..Default::default()
        }
    }

    /// Builder-style setter for [`RetryConfig::max_delay`].
    // Public builder API; `dead_code` is allowed because it may have no callers in this crate.
    #[allow(dead_code)]
    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = max_delay;
        self
    }

    /// Builder-style setter for [`RetryConfig::backoff_multiplier`].
    // Public builder API; `dead_code` is allowed because it may have no callers in this crate.
    #[allow(dead_code)]
    pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Delay to wait before retry number `attempt` (zero-based):
    /// `initial_delay * backoff_multiplier^attempt`, capped at `max_delay` and
    /// floored at zero.
    ///
    /// Never panics. Huge exponents (which overflow to infinity), NaN products
    /// and negative multipliers are clamped rather than handed to
    /// `Duration::from_secs_f64`, which panics on such values.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        // `powi` takes an i32; saturate instead of letting `as` wrap to a negative exponent.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let delay_secs =
            self.initial_delay.as_secs_f64() * self.backoff_multiplier.powi(exponent);
        // Cap in f64 space first; this also covers `inf` and NaN (degenerate configs).
        if delay_secs.is_nan() || delay_secs >= self.max_delay.as_secs_f64() {
            return self.max_delay;
        }
        // A negative multiplier can produce negative seconds; treat those as "no wait".
        let delay = Duration::from_secs_f64(delay_secs.max(0.0));
        std::cmp::min(delay, self.max_delay)
    }
}
/// Runs `operation` until it succeeds, retrying retryable failures with
/// exponential backoff.
///
/// The first call is made immediately. When a call fails and
/// `e.is_retryable()` is `true`, the loop logs to stderr and sleeps for
/// `config.calculate_delay(n)` (`n` = zero-based retry index). It then calls
/// `operation` again, for up to `config.max_attempts` retries. That is at most
/// `max_attempts + 1` calls in total.
///
/// # Errors
///
/// Returns the error from the last call when that error is not retryable, or
/// when the retry budget is exhausted.
pub async fn retry_with_backoff<F, Fut, T>(
    config: &RetryConfig,
    mut operation: F,
) -> Result<T, SyncError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, SyncError>>,
{
    // Initial call plus `max_attempts` retries. The log denominator must reflect this,
    // otherwise it prints "Attempt N/N failed ... Retrying" before one more call.
    // Saturate so `u32::MAX` cannot overflow.
    let total_attempts = config.max_attempts.saturating_add(1);
    let mut attempt: u32 = 0;
    loop {
        match operation().await {
            Ok(result) => return Ok(result),
            // Budget exhausted: surface the last error whether or not it is retryable.
            Err(e) if attempt >= config.max_attempts => return Err(e),
            Err(e) if e.is_retryable() => {
                let delay = config.calculate_delay(attempt);
                eprintln!(
                    "Attempt {}/{} failed: {}. Retrying in {:?}...",
                    attempt + 1,
                    total_attempts,
                    e,
                    delay
                );
                tokio::time::sleep(delay).await;
                // Cannot overflow: the guard above returned unless attempt < max_attempts.
                attempt += 1;
            }
            Err(e) => return Err(e),
        }
    }
}
/// Convenience wrapper: runs `operation` through [`retry_with_backoff`] using
/// [`RetryConfig::default`].
///
/// # Errors
///
/// Propagates whatever error [`retry_with_backoff`] returns.
// Convenience API; `dead_code` is allowed because it may have no callers in this crate.
#[allow(dead_code)]
pub async fn retry_default<F, Fut, T>(operation: F) -> Result<T, SyncError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, SyncError>>,
{
    let defaults = RetryConfig::default();
    retry_with_backoff(&defaults, operation).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Error the retry loop is expected to treat as retryable
    /// (`test_retry_success_after_retries` depends on that).
    fn timeout_error() -> SyncError {
        SyncError::NetworkTimeout {
            duration: Duration::from_secs(1),
        }
    }

    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_attempts,
            initial_delay,
            max_delay,
            backoff_multiplier,
        } = RetryConfig::default();
        assert_eq!(max_attempts, 3);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert_eq!(backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_custom() {
        let RetryConfig {
            max_attempts,
            initial_delay,
            max_delay,
            backoff_multiplier,
        } = RetryConfig::new(5, Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(60))
            .with_backoff_multiplier(3.0);
        assert_eq!(max_attempts, 5);
        assert_eq!(initial_delay, Duration::from_millis(500));
        assert_eq!(max_delay, Duration::from_secs(60));
        assert_eq!(backoff_multiplier, 3.0);
    }

    #[test]
    fn test_calculate_delay_exponential() {
        let cfg = RetryConfig::default();
        for (attempt, expected_secs) in [(0, 1), (1, 2), (2, 4), (3, 8)] {
            assert_eq!(
                cfg.calculate_delay(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    #[test]
    fn test_calculate_delay_capped_at_max() {
        let cfg = RetryConfig {
            max_attempts: 5,
            initial_delay: Duration::from_secs(10),
            max_delay: Duration::from_secs(15),
            backoff_multiplier: 2.0,
        };
        for (attempt, expected_secs) in [(0, 10), (1, 15), (2, 15)] {
            assert_eq!(
                cfg.calculate_delay(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let cfg = RetryConfig::default();
        let outcome = retry_with_backoff(&cfg, move || {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, SyncError>(42)
            }
        })
        .await;
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let cfg = RetryConfig::new(3, Duration::from_millis(10));
        let outcome = retry_with_backoff(&cfg, move || {
            let calls = Arc::clone(&op_calls);
            async move {
                // Fail the first two calls, succeed on the third.
                if calls.fetch_add(1, Ordering::SeqCst) < 2 {
                    Err(timeout_error())
                } else {
                    Ok(42)
                }
            }
        })
        .await;
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let cfg = RetryConfig::new(2, Duration::from_millis(10));
        let outcome = retry_with_backoff(&cfg, move || {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(timeout_error())
            }
        })
        .await;
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            SyncError::NetworkTimeout { .. }
        ));
        // `max_attempts` counts retries: 2 retries means 3 calls in total.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_non_retryable_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let cfg = RetryConfig::default();
        let outcome = retry_with_backoff(&cfg, move || {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(SyncError::NetworkFatal {
                    message: "Fatal error".to_string(),
                })
            }
        })
        .await;
        assert!(outcome.is_err());
        assert!(matches!(
            outcome.unwrap_err(),
            SyncError::NetworkFatal { .. }
        ));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_default_wrapper() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let outcome = retry_default(move || {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok::<_, SyncError>(100)
            }
        })
        .await;
        assert_eq!(outcome.unwrap(), 100);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}