use std::{fmt::Display, future::Future};
use crate::time::{sleep, Duration};
use rand::Rng;
use tracing::{debug, warn};
/// Tunable policy for retrying a fallible operation with exponential backoff.
///
/// The delay after the failed attempt numbered `n` (zero-based) is
/// `initial_delay * backoff_base^n`, capped at `max_delay`, and then optionally
/// scaled by a random jitter factor.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Number of retries allowed after the first attempt, so the operation
    /// runs at most `max_retries + 1` times.
    pub max_retries: usize,
    /// Delay before the first retry, prior to any jitter.
    pub initial_delay: Duration,
    /// Upper bound applied to the exponential delay *before* jitter; with
    /// jitter enabled the actual sleep may reach up to 1.5x this value.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for every further attempt.
    pub backoff_base: f64,
    /// When `true`, each delay is multiplied by a random factor drawn from
    /// `0.5..1.5`.
    pub use_jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
backoff_base: 2.0,
use_jitter: true,
}
}
}
impl RetryConfig {
    /// Creates a configuration populated with the [`Default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many retries are allowed after the first attempt.
    pub fn with_max_retries(mut self, max_retries: usize) -> Self {
        self.max_retries = max_retries;
        self
    }

    /// Sets the delay used before the first retry (before jitter).
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the cap applied to the exponential delay (before jitter).
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Disables the random jitter so delays are fully deterministic.
    pub fn without_jitter(mut self) -> Self {
        self.use_jitter = false;
        self
    }

    /// Computes how long to wait after the failed attempt numbered `attempt`
    /// (zero-based).
    ///
    /// The result is `initial_delay * backoff_base^attempt` in whole
    /// milliseconds, capped at `max_delay`. When jitter is enabled the capped
    /// value is then multiplied by a random factor from `0.5..1.5`, so the
    /// returned delay may exceed `max_delay` by up to 50%.
    fn calculate_delay(&self, attempt: usize) -> Duration {
        let base_delay = self.initial_delay.as_millis() as f64;

        // A zero base delay must stay zero for every attempt. Without this
        // guard, once `backoff_base^attempt` overflows to infinity the product
        // `0.0 * inf` is NaN, and `f64::min` ignores a NaN operand, which would
        // silently turn a "no delay" policy into a full `max_delay` sleep.
        if base_delay == 0.0 {
            return Duration::from_millis(0);
        }

        // Saturate instead of wrapping: a plain `attempt as i32` maps values
        // above `i32::MAX` onto negative exponents and collapses the delay.
        let exponent = attempt.min(i32::MAX as usize) as i32;

        let exponential_delay = base_delay * self.backoff_base.powi(exponent);
        let capped_delay = exponential_delay.min(self.max_delay.as_millis() as f64);

        let final_delay = if self.use_jitter {
            let mut rng = rand::thread_rng();
            let jitter_factor: f64 = rng.gen_range(0.5..1.5);
            capped_delay * jitter_factor
        } else {
            capped_delay
        };

        // Float-to-integer `as` saturates, so negative or NaN values become 0.
        Duration::from_millis(final_delay as u64)
    }
}
/// Lets an error type say whether the operation that produced it is worth
/// attempting again.
pub trait RetryableError {
    /// Returns `true` when the failure may succeed if the operation is retried.
    fn is_retryable(&self) -> bool;

    /// Short human-readable explanation of the retry decision.
    ///
    /// The default text depends only on [`is_retryable`](Self::is_retryable);
    /// implementors may override it with something more specific.
    fn retry_reason(&self) -> &str {
        if !self.is_retryable() {
            return "Non-transient error, will not retry";
        }
        "Transient error, will retry"
    }
}
/// Runs `operation` until it succeeds, fails with a non-retryable error, or
/// the retry budget in `config` is used up.
///
/// * `operation` - factory invoked once per attempt to produce the future to await.
/// * `config` - retry budget and backoff policy.
/// * `operation_name` - label attached to every log event.
///
/// The operation is attempted at most `config.max_retries + 1` times, sleeping
/// for `config.calculate_delay(attempt)` between attempts.
///
/// # Errors
///
/// Returns the error of the first non-retryable failure, or the error of the
/// last attempt when every allowed attempt failed with a retryable error.
pub async fn with_retry<F, Fut, T, E>(
    mut operation: F,
    config: &RetryConfig,
    operation_name: &str,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    E: RetryableError + Display,
{
    // Zero-based index of the attempt currently in flight.
    let mut attempt: usize = 0;

    loop {
        debug!(
            operation = operation_name,
            attempt = attempt,
            max_retries = config.max_retries,
            "Attempting operation"
        );

        // Success leaves the function here; only failures fall through.
        let error = match operation().await {
            Ok(value) => {
                if attempt > 0 {
                    debug!(
                        operation = operation_name,
                        attempt = attempt,
                        "Operation succeeded after retry"
                    );
                }
                return Ok(value);
            }
            Err(error) => error,
        };

        debug!(
            operation = operation_name,
            error = %error,
            is_retryable = error.is_retryable(),
            reason = error.retry_reason(),
            "Error encountered"
        );

        // Permanent failures are surfaced immediately, without any delay.
        if !error.is_retryable() {
            debug!(
                operation = operation_name,
                error = %error,
                reason = error.retry_reason(),
                "Non-retryable error encountered"
            );
            return Err(error);
        }

        // The retry budget is spent: report and hand back the last error.
        if attempt >= config.max_retries {
            warn!(
                operation = operation_name,
                attempts = attempt + 1,
                error = %error,
                "Max retries exceeded, operation failed"
            );
            return Err(error);
        }

        let delay = config.calculate_delay(attempt);
        debug!(
            operation = operation_name,
            attempt = attempt + 1,
            max_retries = config.max_retries,
            delay_ms = delay.as_millis(),
            error = %error,
            "Retryable error encountered, will retry after delay"
        );
        sleep(delay).await;
        attempt += 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Minimal error type with one retryable and one permanent variant.
    #[derive(Debug, thiserror::Error)]
    enum TestError {
        #[error("Retryable error")]
        Retryable,
        #[error("Non-retryable error")]
        NonRetryable,
    }

    impl RetryableError for TestError {
        fn is_retryable(&self) -> bool {
            match self {
                TestError::Retryable => true,
                TestError::NonRetryable => false,
            }
        }
    }

    /// An operation that succeeds right away is invoked exactly once.
    #[tokio::test]
    async fn test_successful_operation() {
        let config = RetryConfig::new().without_jitter();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Ok::<_, TestError>(42) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two retryable failures followed by a success take three invocations.
    #[tokio::test]
    async fn test_retry_then_success() {
        let config = RetryConfig::new()
            .with_max_retries(3)
            .with_initial_delay(Duration::from_millis(10))
            .without_jitter();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                // `fetch_add` yields the number of calls made before this one.
                let prior_calls = calls_in_op.fetch_add(1, Ordering::SeqCst);
                async move {
                    if prior_calls >= 2 {
                        Ok(42)
                    } else {
                        Err(TestError::Retryable)
                    }
                }
            },
            &config,
            "test_operation",
        )
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A non-retryable error is returned after a single invocation.
    #[tokio::test]
    async fn test_non_retryable_error() {
        let config = RetryConfig::new().with_max_retries(3);
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(TestError::NonRetryable) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert!(matches!(outcome, Err(TestError::NonRetryable)));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With `max_retries = 2` a permanently failing retryable operation runs
    /// three times (the first attempt plus two retries) before giving up.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let config = RetryConfig::new()
            .with_max_retries(2)
            .with_initial_delay(Duration::from_millis(10))
            .without_jitter();
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_in_op = Arc::clone(&calls);

        let outcome = with_retry(
            || {
                calls_in_op.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, _>(TestError::Retryable) }
            },
            &config,
            "test_operation",
        )
        .await;

        assert!(matches!(outcome, Err(TestError::Retryable)));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// Without jitter the delay doubles per attempt until it hits the cap.
    #[test]
    fn test_exponential_backoff_calculation() {
        let config = RetryConfig::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_max_delay(Duration::from_secs(30))
            .without_jitter();

        // (attempt, expected delay in seconds); the last two rows hit the 30 s cap.
        let expectations = [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16), (5, 30), (10, 30)];
        for &(attempt, expected_secs) in expectations.iter() {
            assert_eq!(
                config.calculate_delay(attempt),
                Duration::from_secs(expected_secs)
            );
        }
    }

    /// With jitter, attempt 1 (2000 ms nominal) lands within 1000..=3000 ms.
    #[test]
    fn test_jitter() {
        let config = RetryConfig::new().with_initial_delay(Duration::from_secs(1));

        let samples = [config.calculate_delay(1), config.calculate_delay(1)];
        for sample in samples.iter() {
            assert!(sample.as_millis() >= 1000);
            assert!(sample.as_millis() <= 3000);
        }
    }
}