use std::future::Future;
use std::time::Duration;
use rand::Rng;
use tokio::time::sleep;
/// Cap on a single un-jittered backoff delay in the policy built by `RetryPolicy::default`.
const DEFAULT_MAX_DELAY: Duration = Duration::from_millis(30_000);
/// Relative jitter (a quarter of the delay in either direction) used by `RetryPolicy::default`.
const DEFAULT_JITTER_FACTOR: f64 = 0.25;
/// Configuration for retrying a fallible async operation.
///
/// Construct with [`RetryPolicy::new`] (fixed delay), [`RetryPolicy::with_exponential_backoff`],
/// or `Default`, then run work through [`RetryPolicy::execute`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
/// Attempt budget for `execute`, counting the first call.
pub max_attempts: u32,
/// Delay between every retry in fixed mode; starting delay in exponential mode.
pub base_delay: Duration,
/// Upper bound on the un-jittered exponential delay; unused in fixed-delay mode.
pub max_delay: Duration,
/// Relative jitter applied to exponential delays (random factor around 1.0).
/// `with_exponential_backoff` clamps it to `0.0..=1.0`, but because the field is public it
/// can be set to any value afterwards. It is ignored in fixed-delay mode.
pub jitter_factor: f64,
/// `true` for exponential backoff, `false` for a fixed delay.
use_exponential_backoff: bool,
/// Optional filter on an error's `Display` text; `None` treats every error as retryable.
retryable_predicate: Option<fn(&str) -> bool>,
}
impl RetryPolicy {
pub fn new(max_attempts: u32, delay_ms: u64) -> Self {
Self {
max_attempts,
base_delay: Duration::from_millis(delay_ms),
max_delay: Duration::from_millis(delay_ms),
jitter_factor: 0.0,
use_exponential_backoff: false,
retryable_predicate: None,
}
}
pub fn with_exponential_backoff(
max_attempts: u32,
base_delay: Duration,
max_delay: Duration,
jitter_factor: f64,
) -> Self {
Self {
max_attempts,
base_delay,
max_delay,
jitter_factor: jitter_factor.clamp(0.0, 1.0),
use_exponential_backoff: true,
retryable_predicate: None,
}
}
pub fn with_retryable_predicate(mut self, predicate: fn(&str) -> bool) -> Self {
self.retryable_predicate = Some(predicate);
self
}
pub fn next_delay(&self, attempt: u32) -> Duration {
if !self.use_exponential_backoff {
return self.base_delay;
}
let base_ms = self.base_delay.as_millis() as u64;
let max_ms = self.max_delay.as_millis() as u64;
let delay_ms = if attempt >= 64 {
max_ms
} else {
let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
base_ms.saturating_mul(multiplier).min(max_ms)
};
if self.jitter_factor > 0.0 {
self.apply_jitter(Duration::from_millis(delay_ms))
} else {
Duration::from_millis(delay_ms)
}
}
fn apply_jitter(&self, delay: Duration) -> Duration {
let mut rng = rand::rng();
let jitter_range = self.jitter_factor * 2.0;
let jitter_offset = rng.random::<f64>() * jitter_range - self.jitter_factor;
let factor = 1.0 + jitter_offset;
let delay_ms = delay.as_millis() as f64;
let jittered_ms = (delay_ms * factor).max(1.0) as u64;
Duration::from_millis(jittered_ms)
}
fn is_retryable(&self, error_msg: &str) -> bool {
match self.retryable_predicate {
Some(pred) => pred(error_msg),
None => true,
}
}
pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, E>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
E: std::fmt::Display,
{
let mut last_error: Option<E> = None;
for attempt in 0..self.max_attempts {
match operation().await {
Ok(result) => return Ok(result),
Err(e) => {
let error_msg = e.to_string();
if !self.is_retryable(&error_msg) {
return Err(e);
}
last_error = Some(e);
if attempt < self.max_attempts - 1 {
let delay = self.next_delay(attempt);
sleep(delay).await;
}
}
}
}
Err(last_error.expect("at least one attempt must have been made"))
}
}
impl Default for RetryPolicy {
    /// Three attempts with exponential backoff starting at 200 ms.
    ///
    /// Delays are capped at `DEFAULT_MAX_DELAY` and jittered by `DEFAULT_JITTER_FACTOR`.
    fn default() -> Self {
        const ATTEMPTS: u32 = 3;
        const BASE_DELAY: Duration = Duration::from_millis(200);
        Self::with_exponential_backoff(
            ATTEMPTS,
            BASE_DELAY,
            DEFAULT_MAX_DELAY,
            DEFAULT_JITTER_FACTOR,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new(3, 200);
        assert_eq!(policy.next_delay(0), Duration::from_millis(200));
        assert_eq!(policy.next_delay(1), Duration::from_millis(200));
        assert_eq!(policy.next_delay(2), Duration::from_millis(200));
    }

    #[test]
    fn test_exponential_no_jitter() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(30),
            0.0,
        );
        assert_eq!(policy.next_delay(0), Duration::from_millis(100));
        assert_eq!(policy.next_delay(1), Duration::from_millis(200));
        assert_eq!(policy.next_delay(2), Duration::from_millis(400));
    }

    #[test]
    fn test_capped_at_max_delay() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(1),
            0.0,
        );
        assert_eq!(policy.next_delay(20), Duration::from_secs(1));
    }

    #[test]
    fn test_overflow_protection() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_secs(1),
            Duration::from_secs(3600),
            0.0,
        );
        let delay = policy.next_delay(100);
        assert!(delay <= Duration::from_secs(3600));
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(30),
            0.5,
        );
        // With a 0.5 factor, a 100 ms delay must land in [50 ms, 150 ms].
        for _ in 0..100 {
            let delay = policy.next_delay(0);
            assert!(
                delay >= Duration::from_millis(50) && delay <= Duration::from_millis(150),
                "jittered delay out of range: {delay:?}"
            );
        }
    }

    #[tokio::test]
    async fn test_execute_success() {
        let policy = RetryPolicy::new(3, 10);
        let result: Result<i32, String> = policy.execute(|| async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_execute_retries_then_succeeds() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);
        let policy = RetryPolicy::new(3, 1);
        let result: Result<i32, String> = policy
            .execute(|| {
                let a = Arc::clone(&attempts_clone);
                async move {
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err("transient".to_string())
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_execute_returns_last_error_when_exhausted() {
        let attempts = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy::new(3, 1);
        let result: Result<i32, String> = policy
            .execute(|| {
                let a = Arc::clone(&attempts);
                async move {
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    Err(format!("failure {count}"))
                }
            })
            .await;
        assert_eq!(result.unwrap_err(), "failure 2");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_execute_stops_on_non_retryable_error() {
        fn retry_unless_fatal(msg: &str) -> bool {
            msg != "fatal"
        }
        let attempts = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy::new(5, 1).with_retryable_predicate(retry_unless_fatal);
        let result: Result<i32, String> = policy
            .execute(|| {
                let a = Arc::clone(&attempts);
                async move {
                    a.fetch_add(1, Ordering::SeqCst);
                    Err("fatal".to_string())
                }
            })
            .await;
        assert_eq!(result.unwrap_err(), "fatal");
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }
}