use super::strategy::BackoffStrategy;
use async_trait::async_trait;
use std::error::Error;
use std::future::Future;
use std::time::Duration;
/// Retry policy whose delay grows geometrically with each failed attempt.
///
/// The delay before retry `n` (0-based) is `initial_delay * multiplier^n`,
/// optionally perturbed by random jitter and then capped at `max_delay`.
/// Construct it with [`ExponentialBackoff::builder`] or via `Default`.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
    /// Number of retries allowed after the initial attempt; once that many
    /// retries have failed, `execute` returns the last error.
    max_retries: u32,
    /// Base delay used before the first retry (attempt 0).
    initial_delay: Duration,
    /// Upper bound applied to every computed delay.
    max_delay: Duration,
    /// Factor the base delay is multiplied by for each successive attempt.
    multiplier: f64,
    /// Relative jitter: the base delay is shifted by a random amount in
    /// `[-jitter, +jitter) * base`. `0.0` disables jitter; the builder clamps
    /// this value to `[0.0, 1.0]`.
    jitter: f64,
}
impl ExponentialBackoff {
pub fn builder() -> ExponentialBackoffBuilder {
ExponentialBackoffBuilder::default()
}
}
impl Default for ExponentialBackoff {
    /// The stock policy: up to 3 retries, starting at 100 ms and doubling on
    /// each attempt, capped at 60 s, with ±10% jitter.
    fn default() -> Self {
        let max_retries = 3;
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(60);
        let multiplier = 2.0;
        let jitter = 0.1;

        Self {
            max_retries,
            initial_delay,
            max_delay,
            multiplier,
            jitter,
        }
    }
}
#[async_trait]
impl BackoffStrategy for ExponentialBackoff {
    /// Runs `operation`, retrying failures with exponentially growing delays.
    ///
    /// Returns the first `Ok`. An `Err` is returned immediately when
    /// `should_retry` rejects it, or once `max_retries` retries have already
    /// failed, so `operation` is invoked at most `max_retries + 1` times.
    /// Between attempts the task sleeps for `next_delay(attempt)` via
    /// `tokio::time::sleep`.
    ///
    /// NOTE(review): `should_retry` is not overridden here, so this presumably
    /// uses the trait's default implementation — confirm in `strategy`.
    async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T, E>
    where
        F: Fn() -> Fut + Send + Sync,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Error + Send + Sync + 'static,
    {
        // Number of retries already performed (0 while running the first call).
        let mut attempt = 0;
        loop {
            match operation().await {
                Ok(result) => return Ok(result),
                // Non-retryable errors are surfaced immediately, whatever the budget.
                Err(err) if !self.should_retry(&err, attempt) => return Err(err),
                // Retry budget exhausted: surface the last error.
                Err(err) if attempt >= self.max_retries => return Err(err),
                Err(_) => {
                    if let Some(delay) = self.next_delay(attempt) {
                        tokio::time::sleep(delay).await;
                    }
                    attempt += 1;
                }
            }
        }
    }

    /// Delay to wait before retry number `attempt` (0-based).
    ///
    /// Computes `initial_delay * multiplier^attempt`. When `jitter > 0`, it
    /// shifts the result by a random amount in `[-jitter, +jitter) * base`. It
    /// then clamps the result to `[0, max_delay]`. Always returns `Some`, and
    /// never panics, even for nonsensical configurations (negative multiplier,
    /// jitter above 1, overflowing exponents).
    fn next_delay(&self, attempt: u32) -> Option<Duration> {
        // `attempt as i32` would wrap to a negative exponent above `i32::MAX`,
        // shrinking the delay toward zero instead of saturating at `max_delay`.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        // A zero initial delay stays zero; otherwise `0.0 * inf` (huge exponent)
        // would be NaN, which the cap below would turn into `max_delay`.
        let base_delay = if self.initial_delay.is_zero() {
            0.0
        } else {
            self.initial_delay.as_secs_f64() * self.multiplier.powi(exponent)
        };
        let jittered = if self.jitter > 0.0 {
            let jitter_amount = base_delay * self.jitter * (rand::random::<f64>() - 0.5) * 2.0;
            base_delay + jitter_amount
        } else {
            base_delay
        };
        // `f64::min` returns the non-NaN operand, so overflow (`inf`, or the NaN
        // produced by `inf - inf` jitter) saturates at the cap. `max(0.0)` keeps
        // a negative multiplier, or jitter > 1 on a directly-built struct, from
        // reaching the `Duration` conversion, which rejects negative input.
        let secs = jittered.min(self.max_delay.as_secs_f64()).max(0.0);
        // Conversion can only fail when `max_delay` is close to `Duration::MAX`
        // and `as_secs_f64` rounded it up; saturate to the configured cap then.
        Some(Duration::try_from_secs_f64(secs).unwrap_or(self.max_delay))
    }

    /// Number of retries allowed after the initial attempt.
    fn max_retries(&self) -> u32 {
        self.max_retries
    }
}
/// Fluent builder for [`ExponentialBackoff`].
///
/// Every field starts as `None` (via `Default`) and is filled in by the
/// matching setter; `build()` substitutes a default value for any field that
/// is still `None`.
#[derive(Debug, Default)]
pub struct ExponentialBackoffBuilder {
    /// Set by `max_retries()`.
    max_retries: Option<u32>,
    /// Set by `initial_delay()`.
    initial_delay: Option<Duration>,
    /// Set by `max_delay()`.
    max_delay: Option<Duration>,
    /// Set by `multiplier()`; stored without validation.
    multiplier: Option<f64>,
    /// Set by `jitter()`, which clamps the value to `[0.0, 1.0]`.
    jitter: Option<f64>,
}
impl ExponentialBackoffBuilder {
pub fn max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = Some(max_retries);
self
}
pub fn initial_delay(mut self, delay: Duration) -> Self {
self.initial_delay = Some(delay);
self
}
pub fn max_delay(mut self, delay: Duration) -> Self {
self.max_delay = Some(delay);
self
}
pub fn multiplier(mut self, multiplier: f64) -> Self {
self.multiplier = Some(multiplier);
self
}
pub fn jitter(mut self, jitter: f64) -> Self {
self.jitter = Some(jitter.clamp(0.0, 1.0));
self
}
pub fn build(self) -> ExponentialBackoff {
ExponentialBackoff {
max_retries: self.max_retries.unwrap_or(3),
initial_delay: self.initial_delay.unwrap_or(Duration::from_millis(100)),
max_delay: self.max_delay.unwrap_or(Duration::from_secs(60)),
multiplier: self.multiplier.unwrap_or(2.0),
jitter: self.jitter.unwrap_or(0.1),
}
}
}
// Unit tests for `ExponentialBackoff`: delay arithmetic, the max-delay cap,
// retry counting in `execute`, jitter spread, builder defaults/clamping, and
// a strategy that overrides `should_retry`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    // With jitter disabled, each delay is exactly initial_delay * 2^attempt.
    #[test]
    fn test_exponential_delay_calculation() {
        let backoff = ExponentialBackoff {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: 0.0,
        };
        assert_eq!(backoff.next_delay(0).unwrap(), Duration::from_millis(100));
        assert_eq!(backoff.next_delay(1).unwrap(), Duration::from_millis(200));
        assert_eq!(backoff.next_delay(2).unwrap(), Duration::from_millis(400));
        assert_eq!(backoff.next_delay(3).unwrap(), Duration::from_millis(800));
    }

    // For attempts 5..10 the uncapped delay (1s * 10^attempt) far exceeds
    // max_delay, so every result must be clamped to at most 5 s.
    #[test]
    fn test_max_delay_cap() {
        let backoff = ExponentialBackoff {
            max_retries: 100,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            multiplier: 10.0,
            jitter: 0.0,
        };
        for attempt in 5..10 {
            let delay = backoff.next_delay(attempt).unwrap();
            assert!(
                delay <= Duration::from_secs(5),
                "Delay at attempt {} ({:?}) exceeded max_delay",
                attempt,
                delay
            );
        }
    }

    // The operation fails twice and then succeeds, so `execute` must call it
    // exactly three times and return the successful value.
    #[tokio::test]
    async fn test_retry_success_on_third_attempt() {
        let backoff = ExponentialBackoff::builder()
            .max_retries(5)
            .initial_delay(Duration::from_millis(1))
            .jitter(0.0)
            .build();
        // Shared call counter; each closure invocation clones the Arc into its future.
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);
        let result = backoff
            .execute(|| {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    let current = attempts.fetch_add(1, Ordering::SeqCst);
                    if current < 2 {
                        Err(std::io::Error::other("retry me"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    // An always-failing operation with max_retries = 2 is called once plus two
    // retries (3 total) before the error is returned.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let backoff = ExponentialBackoff::builder()
            .max_retries(2)
            .initial_delay(Duration::from_millis(1))
            .jitter(0.0)
            .build();
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);
        let result = backoff
            .execute(|| {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(std::io::Error::other("always fail"))
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    // With 50% jitter on a 1 s base, every sample must land in [500 ms, 1500 ms]
    // and the 20 samples should not all be identical.
    #[test]
    fn test_jitter_variation() {
        let backoff = ExponentialBackoff {
            max_retries: 10,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(60),
            multiplier: 2.0,
            jitter: 0.5,
        };
        let mut delays = Vec::new();
        for _ in 0..20 {
            delays.push(backoff.next_delay(0).unwrap());
        }
        for delay in &delays {
            let millis = delay.as_millis();
            assert!(
                (500..=1500).contains(&millis),
                "Delay with 50% jitter should be in range [500ms, 1500ms], got {}ms",
                millis
            );
        }
        let all_same = delays.windows(2).all(|w| w[0] == w[1]);
        assert!(!all_same, "With randomization, delays should vary");
    }

    // A builder with nothing set produces the documented default policy.
    #[test]
    fn test_builder_defaults() {
        let backoff = ExponentialBackoff::builder().build();
        assert_eq!(backoff.max_retries, 3);
        assert_eq!(backoff.initial_delay, Duration::from_millis(100));
        assert_eq!(backoff.max_delay, Duration::from_secs(60));
        assert_eq!(backoff.multiplier, 2.0);
        assert_eq!(backoff.jitter, 0.1);
    }

    // Every setter's value is carried through to the built policy unchanged.
    #[test]
    fn test_builder_custom_values() {
        let backoff = ExponentialBackoff::builder()
            .max_retries(5)
            .initial_delay(Duration::from_millis(200))
            .max_delay(Duration::from_secs(30))
            .multiplier(1.5)
            .jitter(0.2)
            .build();
        assert_eq!(backoff.max_retries, 5);
        assert_eq!(backoff.initial_delay, Duration::from_millis(200));
        assert_eq!(backoff.max_delay, Duration::from_secs(30));
        assert_eq!(backoff.multiplier, 1.5);
        assert_eq!(backoff.jitter, 0.2);
    }

    // Out-of-range jitter values are clamped into [0.0, 1.0] by the builder.
    #[test]
    fn test_jitter_clamped() {
        let backoff = ExponentialBackoff::builder().jitter(2.0).build();
        assert_eq!(backoff.jitter, 1.0);
        let backoff = ExponentialBackoff::builder().jitter(-0.5).build();
        assert_eq!(backoff.jitter, 0.0);
    }

    // A first-try success is returned directly.
    #[tokio::test]
    async fn test_immediate_success() {
        let backoff = ExponentialBackoff::default();
        let result = backoff
            .execute(|| async { Ok::<_, std::io::Error>(42) })
            .await;
        assert_eq!(result.unwrap(), 42);
    }

    // A wrapper strategy that only retries errors whose message contains
    // "network": a non-network error must stop after one call, while network
    // errors are retried until the operation succeeds on the third call.
    #[tokio::test]
    async fn test_custom_retry_predicate() {
        struct NetworkOnlyBackoff {
            inner: ExponentialBackoff,
        }
        // NOTE(review): this repeats ExponentialBackoff's retry loop verbatim,
        // presumably because the trait's `execute` has no default body — confirm.
        #[async_trait]
        impl BackoffStrategy for NetworkOnlyBackoff {
            async fn execute<F, Fut, T, E>(&self, operation: F) -> Result<T, E>
            where
                F: Fn() -> Fut + Send + Sync,
                Fut: Future<Output = Result<T, E>> + Send,
                T: Send,
                E: Error + Send + Sync + 'static,
            {
                let mut attempt = 0;
                loop {
                    match operation().await {
                        Ok(result) => return Ok(result),
                        Err(err) if !self.should_retry(&err, attempt) => return Err(err),
                        Err(err) if attempt >= self.max_retries() => return Err(err),
                        Err(_) => {
                            if let Some(delay) = self.next_delay(attempt) {
                                tokio::time::sleep(delay).await;
                            }
                            attempt += 1;
                        }
                    }
                }
            }
            // Retry only errors whose Display text mentions "network".
            fn should_retry(&self, error: &dyn Error, _attempt: u32) -> bool {
                error.to_string().contains("network")
            }
            // Delay schedule and retry budget are delegated to the wrapped policy.
            fn next_delay(&self, attempt: u32) -> Option<Duration> {
                self.inner.next_delay(attempt)
            }
            fn max_retries(&self) -> u32 {
                self.inner.max_retries()
            }
        }
        let backoff = NetworkOnlyBackoff {
            inner: ExponentialBackoff::builder()
                .max_retries(5)
                .initial_delay(Duration::from_millis(1))
                .build(),
        };
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);
        // Phase 1: "auth failed" is not retryable, so exactly one call is made.
        let result = backoff
            .execute(|| {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(std::io::Error::other("auth failed"))
                }
            })
            .await;
        assert!(result.is_err());
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
        // Phase 2: reset the shared counter; network errors are retried until success.
        attempts.store(0, Ordering::SeqCst);
        let result = backoff
            .execute(|| {
                let attempts = Arc::clone(&attempts_clone);
                async move {
                    let current = attempts.fetch_add(1, Ordering::SeqCst);
                    if current < 2 {
                        Err(std::io::Error::other("network error"))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}