use rand::Rng;
use std::time::Duration;
use thiserror::Error;
use tracing::{debug, warn};
/// Tuning knobs for the exponential-backoff retry loop run by `RetryPolicy`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries after the initial attempt; the operation runs at most
    /// `max_retries + 1` times in total.
    pub max_retries: u32,
    /// Backoff delay before the first retry (delays are computed at millisecond resolution).
    pub initial_delay: Duration,
    /// Upper bound applied to the exponential backoff delay before jitter is added.
    pub max_delay: Duration,
    /// Factor by which the delay grows for each further failed attempt
    /// (`initial_delay * multiplier^(attempt - 1)`).
    pub multiplier: f64,
    /// Whether each delay is randomised by up to `±jitter_factor` of its value.
    pub use_jitter: bool,
    /// Fraction of the delay used as the jitter range (e.g. `0.3` means ±30%).
    /// `RetryConfigBuilder::jitter_factor` clamps this to `[0, 1]`; assigning the field
    /// directly skips that clamp.
    pub jitter_factor: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms and doubling up to 30 s, with ±30% jitter.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            use_jitter: true,
            jitter_factor: 0.3,
        }
    }
}
/// Ways in which a retried operation can ultimately fail.
#[derive(Debug, Error)]
pub enum RetryError<E> {
    /// Retrying was stopped early; the string explains why (for example, an error the
    /// caller's retry predicate rejected).
    #[error("Retry aborted: {0}")]
    Aborted(String),
    /// Every permitted attempt failed.
    #[error("Operation failed after {attempts} attempts: {last_error}")]
    MaxRetriesExceeded {
        /// Total number of attempts made, including the first one.
        attempts: u32,
        /// Error returned by the final attempt.
        last_error: E,
    },
}
/// Executes fallible operations with exponential backoff as described by a `RetryConfig`.
///
/// `RetryPolicy::default()` is equivalent to `RetryPolicy::with_default_config()`.
#[derive(Debug, Clone, Default)]
pub struct RetryPolicy {
    config: RetryConfig,
}
impl RetryPolicy {
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
pub fn with_default_config() -> Self {
Self {
config: RetryConfig::default(),
}
}
pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display + Clone,
{
let mut attempts = 0;
#[allow(unused_assignments)]
let mut _last_error: Option<E> = None;
loop {
attempts += 1;
debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
match operation().await {
Ok(result) => {
if attempts > 1 {
debug!("Operation succeeded after {} attempts", attempts);
}
return Ok(result);
}
Err(e) => {
warn!("Attempt {} failed: {}", attempts, e);
if attempts > self.config.max_retries {
return Err(RetryError::MaxRetriesExceeded {
attempts,
last_error: e,
});
}
_last_error = Some(e);
let delay = self.calculate_delay(attempts);
debug!("Waiting {:?} before retry", delay);
tokio::time::sleep(delay).await;
}
}
}
}
pub fn execute_sync<F, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
where
F: FnMut() -> Result<T, E>,
E: std::fmt::Display + Clone,
{
let mut attempts = 0;
#[allow(unused_assignments)]
let mut _last_error: Option<E> = None;
loop {
attempts += 1;
debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
match operation() {
Ok(result) => {
if attempts > 1 {
debug!("Operation succeeded after {} attempts", attempts);
}
return Ok(result);
}
Err(e) => {
warn!("Attempt {} failed: {}", attempts, e);
if attempts > self.config.max_retries {
return Err(RetryError::MaxRetriesExceeded {
attempts,
last_error: e,
});
}
_last_error = Some(e);
let delay = self.calculate_delay(attempts);
debug!("Waiting {:?} before retry", delay);
std::thread::sleep(delay);
}
}
}
}
pub async fn execute_with_condition<F, Fut, T, E, C>(
&self,
mut operation: F,
mut should_retry: C,
) -> Result<T, RetryError<E>>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T, E>>,
E: std::fmt::Display + Clone,
C: FnMut(&E) -> bool,
{
let mut attempts = 0;
loop {
attempts += 1;
debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
match operation().await {
Ok(result) => {
if attempts > 1 {
debug!("Operation succeeded after {} attempts", attempts);
}
return Ok(result);
}
Err(e) => {
if !should_retry(&e) {
debug!("Error is not retryable: {}", e);
return Err(RetryError::Aborted(format!(
"Non-retryable error after {} attempts: {}",
attempts, e
)));
}
warn!("Attempt {} failed: {}", attempts, e);
if attempts > self.config.max_retries {
return Err(RetryError::MaxRetriesExceeded {
attempts,
last_error: e,
});
}
let delay = self.calculate_delay(attempts);
debug!("Waiting {:?} before retry", delay);
tokio::time::sleep(delay).await;
}
}
}
}
fn calculate_delay(&self, attempt: u32) -> Duration {
let base_delay_ms = self.config.initial_delay.as_millis() as f64
* self.config.multiplier.powi((attempt - 1) as i32);
let base_delay = Duration::from_millis(base_delay_ms as u64);
let capped_delay = if base_delay > self.config.max_delay {
self.config.max_delay
} else {
base_delay
};
if self.config.use_jitter {
self.add_jitter(capped_delay)
} else {
capped_delay
}
}
fn add_jitter(&self, delay: Duration) -> Duration {
let mut rng = rand::thread_rng();
let delay_ms = delay.as_millis() as f64;
let jitter_range = delay_ms * self.config.jitter_factor;
let jitter = rng.gen_range(-jitter_range..=jitter_range);
let jittered_ms = (delay_ms + jitter).max(0.0);
Duration::from_millis(jittered_ms as u64)
}
pub fn config(&self) -> &RetryConfig {
&self.config
}
}
pub struct RetryConfigBuilder {
config: RetryConfig,
}
impl RetryConfigBuilder {
pub fn new() -> Self {
Self {
config: RetryConfig::default(),
}
}
pub fn max_retries(mut self, max_retries: u32) -> Self {
self.config.max_retries = max_retries;
self
}
pub fn initial_delay(mut self, delay: Duration) -> Self {
self.config.initial_delay = delay;
self
}
pub fn max_delay(mut self, delay: Duration) -> Self {
self.config.max_delay = delay;
self
}
pub fn multiplier(mut self, multiplier: f64) -> Self {
self.config.multiplier = multiplier;
self
}
pub fn use_jitter(mut self, use_jitter: bool) -> Self {
self.config.use_jitter = use_jitter;
self
}
pub fn jitter_factor(mut self, factor: f64) -> Self {
self.config.jitter_factor = factor.clamp(0.0, 1.0);
self
}
pub fn build(self) -> RetryConfig {
self.config
}
}
impl Default for RetryConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Configuration with short delays so tests that actually retry stay fast.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_immediate_success() {
        let policy = RetryPolicy::with_default_config();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok::<_, String>("success")
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_on_failure() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy
            .execute(|| async {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary failure")
                } else {
                    Ok("success")
                }
            })
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let policy = RetryPolicy::new(fast_config(2));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent failure")
            })
            .await;
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        match result {
            Err(RetryError::MaxRetriesExceeded {
                attempts,
                last_error,
            }) => {
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "persistent failure");
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(50),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        assert_eq!(policy.calculate_delay(1), Duration::from_millis(50));
        assert_eq!(policy.calculate_delay(2), Duration::from_millis(100));
        assert_eq!(policy.calculate_delay(3), Duration::from_millis(200));
    }

    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        assert_eq!(policy.calculate_delay(5), Duration::from_millis(500));
        // Far past u64 millisecond range: must still saturate to the cap, not panic.
        assert_eq!(policy.calculate_delay(64), Duration::from_millis(500));
    }

    #[test]
    fn test_jitter_adds_variation() {
        let config = RetryConfig {
            max_retries: 1,
            initial_delay: Duration::from_millis(100),
            use_jitter: true,
            jitter_factor: 0.5,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let delays: Vec<Duration> = (0..10).map(|_| policy.calculate_delay(1)).collect();
        for delay in &delays {
            let ms = delay.as_millis();
            assert!((50..=150).contains(&ms), "Delay {} outside expected range", ms);
        }
        let all_same = delays.iter().all(|d| d == &delays[0]);
        assert!(!all_same, "All delays are the same, jitter not working");
    }

    // Plain `#[test]`: `execute_sync` blocks with `std::thread::sleep`, which should not run
    // on an async runtime's worker thread.
    #[test]
    fn test_synchronous_retry() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy.execute_sync(|| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 2 {
                Err("temporary failure")
            } else {
                Ok("success")
            }
        });
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_synchronous_max_retries_exceeded() {
        let policy = RetryPolicy::new(fast_config(2));
        let mut calls = 0;
        let result = policy.execute_sync(|| {
            calls += 1;
            Err::<(), _>("persistent failure")
        });
        assert_eq!(calls, 3);
        match result {
            Err(RetryError::MaxRetriesExceeded {
                attempts,
                last_error,
            }) => {
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "persistent failure");
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    #[tokio::test]
    async fn test_retry_with_custom_condition() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<String, _>("permanent error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;
        assert!(matches!(result, Err(RetryError::Aborted(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_custom_condition_retries_retryable_errors() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        if counter.fetch_add(1, Ordering::SeqCst) < 2 {
                            Err("temporary failure")
                        } else {
                            Ok("success")
                        }
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_builder() {
        let config = RetryConfigBuilder::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .multiplier(3.0)
            .use_jitter(false)
            .build();
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert_eq!(config.max_delay, Duration::from_secs(10));
        assert_eq!(config.multiplier, 3.0);
        assert!(!config.use_jitter);
    }

    #[tokio::test]
    async fn test_zero_retries() {
        let policy = RetryPolicy::new(fast_config(0));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        match result {
            Err(RetryError::MaxRetriesExceeded { attempts, .. }) => assert_eq!(attempts, 1),
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    #[tokio::test]
    async fn test_concurrent_retries() {
        // Short delays keep the failing tasks' backoff in the low milliseconds.
        let policy = Arc::new(RetryPolicy::new(RetryConfig {
            initial_delay: Duration::from_millis(1),
            ..Default::default()
        }));
        let mut handles = vec![];
        for i in 0..5 {
            let policy_clone = Arc::clone(&policy);
            let handle = tokio::spawn(async move {
                policy_clone
                    .execute(|| async move {
                        if i % 2 == 0 {
                            Ok(format!("success {}", i))
                        } else {
                            Err(format!("error {}", i))
                        }
                    })
                    .await
            });
            handles.push(handle);
        }
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            if i % 2 == 0 {
                assert!(result.is_ok());
            } else {
                assert!(result.is_err());
            }
        }
    }

    #[test]
    fn test_jitter_factor_clamping() {
        let config = RetryConfigBuilder::new().jitter_factor(1.5).build();
        assert_eq!(config.jitter_factor, 1.0);
        let config = RetryConfigBuilder::new().jitter_factor(-0.5).build();
        assert_eq!(config.jitter_factor, 0.0);
    }

    #[tokio::test]
    async fn test_timing_accuracy() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(20),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);
        let start = std::time::Instant::now();
        let _ = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;
        let elapsed = start.elapsed();
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        // Two sleeps: 20 ms + 40 ms.
        assert!(
            elapsed >= Duration::from_millis(60),
            "Elapsed time {:?} less than expected",
            elapsed
        );
    }
}