use std::time::Duration;
/// Tuning knobs for [`with_backoff`]: how many times to run an operation and
/// how long to pause between runs.
#[derive(Clone, Debug)]
pub struct BackoffPolicy {
    /// Total number of calls to the operation, counting the first one.
    pub max_attempts: u32,
    /// Pause taken after the first failed attempt.
    pub initial: Duration,
    /// Factor the pause is multiplied by after every retry.
    pub multiplier: f64,
    /// Cap applied to the pause once it has been scaled by `multiplier`.
    pub max: Duration,
}
impl Default for BackoffPolicy {
    /// Three attempts in total, pausing 100 ms after the first failure and
    /// doubling the pause each time, with a one-second cap.
    fn default() -> Self {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1_000);
        Self {
            max_attempts: 3,
            initial,
            multiplier: 2.0,
            max,
        }
    }
}
impl BackoffPolicy {
pub fn disabled() -> Self {
Self {
max_attempts: 1,
..Self::default()
}
}
}
/// Runs `op` until it succeeds, fails with a non-retryable error, or the
/// attempt budget in `policy` is used up, sleeping between attempts.
///
/// * `policy` — retry budget and delay schedule. `max_attempts` counts every
///   call to `op`, including the first; a value of `0` behaves like `1`.
/// * `is_retryable` — classifies an error; returning `false` stops at once.
/// * `on_retry` — called before each sleep with the 1-based number of the
///   attempt that just failed, its error, and the delay about to be slept.
/// * `op` — builds a fresh future for each attempt.
///
/// The first delay is `policy.initial` and each later delay is the previous
/// one multiplied by `policy.multiplier`. Every delay, including the first, is
/// capped at `policy.max`. A scaled delay that is NaN, negative, or too large
/// to represent as a [`Duration`] is also treated as `policy.max`.
///
/// # Errors
///
/// Returns the error from the last attempt made. That is either the first
/// non-retryable error, or the error from the final permitted attempt.
pub async fn with_backoff<T, E, F, Fut>(
    policy: &BackoffPolicy,
    is_retryable: impl Fn(&E) -> bool,
    on_retry: impl Fn(u32, &E, Duration),
    mut op: F,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    let mut attempt: u32 = 1;
    // Clamp the very first delay too, so `max` is a real ceiling even when
    // `initial` is configured above it.
    let mut backoff = policy.initial.min(policy.max);
    loop {
        match op().await {
            Ok(v) => return Ok(v),
            Err(err) if !is_retryable(&err) => return Err(err),
            Err(err) if attempt >= policy.max_attempts => return Err(err),
            Err(err) => {
                on_retry(attempt, &err, backoff);
                tokio::time::sleep(backoff).await;
                // `Duration::from_secs_f64` panics on NaN, negative or
                // overflowing input (e.g. an infinite multiplier or long
                // exponential growth). The checked constructor lets us fall
                // back to the cap instead of panicking mid-retry.
                let scaled = backoff.as_secs_f64() * policy.multiplier;
                backoff = Duration::try_from_secs_f64(scaled)
                    .map_or(policy.max, |d| d.min(policy.max));
                attempt += 1;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::RefCell;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Policy with microsecond-scale delays so retry tests finish quickly.
    fn fast(max_attempts: u32) -> BackoffPolicy {
        BackoffPolicy {
            max_attempts,
            initial: Duration::from_micros(10),
            multiplier: 2.0,
            max: Duration::from_millis(1),
        }
    }

    #[tokio::test]
    async fn ok_on_first_attempt_runs_op_exactly_once() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_op = Arc::clone(&calls);
        let result: Result<u32, &'static str> = with_backoff(
            &fast(5),
            |_| true,
            |_, _, _| {},
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok(42)
                }
            },
        )
        .await;
        assert_eq!(result, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn permanent_error_short_circuits_attempt_one() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_op = Arc::clone(&calls);
        let result: Result<(), &'static str> = with_backoff(
            &fast(5),
            |_| false,
            |_, _, _| panic!("on_retry must not fire for permanent errors"),
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err("fatal")
                }
            },
        )
        .await;
        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn transient_then_success_retries_then_returns_ok() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_op = Arc::clone(&calls);
        let result: Result<&'static str, &'static str> = with_backoff(
            &fast(3),
            |_| true,
            |_, _, _| {},
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 { Err("blip") } else { Ok("ok") }
                }
            },
        )
        .await;
        assert_eq!(result, Ok("ok"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn budget_exhausted_returns_last_error() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_op = Arc::clone(&calls);
        let result: Result<(), String> = with_backoff(
            &fast(4),
            |_| true,
            |_, _, _| {},
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    let n = calls.fetch_add(1, Ordering::SeqCst) + 1;
                    Err(format!("attempt #{n}"))
                }
            },
        )
        .await;
        assert_eq!(result, Err("attempt #4".to_owned()));
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn disabled_policy_runs_exactly_once_even_on_transient() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls_for_op = Arc::clone(&calls);
        let result: Result<(), &'static str> = with_backoff(
            &BackoffPolicy::disabled(),
            |_| true,
            |_, _, _| {},
            || {
                let calls = Arc::clone(&calls_for_op);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err("x")
                }
            },
        )
        .await;
        assert_eq!(result, Err("x"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn backoff_growth_is_clamped_by_max() {
        let policy = BackoffPolicy {
            max_attempts: 4,
            initial: Duration::from_millis(1),
            multiplier: 10.0,
            max: Duration::from_millis(2),
        };
        // Record the delays handed to `on_retry` rather than timing the whole
        // call. Wall-clock assertions are flaky on a loaded machine and only
        // bound the total, not each individual delay.
        let retries: RefCell<Vec<(u32, Duration)>> = RefCell::new(Vec::new());
        let result: Result<(), &'static str> = with_backoff(
            &policy,
            |_| true,
            |attempt, _, delay| retries.borrow_mut().push((attempt, delay)),
            || async move { Err("blip") },
        )
        .await;
        assert_eq!(result, Err("blip"));
        assert_eq!(
            retries.into_inner(),
            vec![
                (1, Duration::from_millis(1)),
                (2, Duration::from_millis(2)),
                (3, Duration::from_millis(2)),
            ]
        );
    }
}