use std::{future::Future, time::Duration};
/// Tuning knobs for [`retry_with_backoff`].
///
/// With `max_retries = n`, the wrapped operation may run up to `n + 1` times:
/// one initial attempt plus `n` retries.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many retries are allowed after the initial attempt.
    pub max_retries: u32,
    /// Delay slept before the first retry.
    pub initial_backoff: Duration,
    /// Ceiling applied to the delay each time it is scaled up.
    pub max_backoff: Duration,
    /// Integer factor the delay is multiplied by after every retry.
    pub backoff_multiplier: u32,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms and doubling up to a 5 s ceiling.
    fn default() -> Self {
        let initial_backoff = Duration::from_millis(100);
        let max_backoff = Duration::from_millis(5_000);
        RetryConfig {
            backoff_multiplier: 2,
            max_backoff,
            initial_backoff,
            max_retries: 3,
        }
    }
}
/// Runs `operation`, retrying failures that `is_retryable` accepts, with
/// exponential backoff between attempts.
///
/// The operation is invoked at most `config.max_retries + 1` times. After each
/// retryable failure that still has retry budget left, the task sleeps for the
/// current backoff. The backoff is then multiplied by
/// `config.backoff_multiplier` and capped at `config.max_backoff`. The backoff
/// never exceeds `config.max_backoff`, and it saturates to the cap instead of
/// overflowing.
///
/// `operation_name` is used only as a field on the `tracing` events emitted
/// when backing off and when an attempt succeeds after a retry.
///
/// # Errors
///
/// Returns the operation's error unchanged when `is_retryable` rejects it, or
/// the error from the final attempt once the retry budget is exhausted.
pub async fn retry_with_backoff<T, E, F, Fut, C>(
    config: &RetryConfig,
    operation_name: &str,
    mut operation: F,
    is_retryable: C,
) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, E>>,
    C: Fn(&E) -> bool,
    E: std::fmt::Display,
{
    // Clamp up front so a misconfigured `initial_backoff > max_backoff` cannot
    // make the first sleep exceed the configured ceiling.
    let mut backoff = config.initial_backoff.min(config.max_backoff);
    // Zero-based index of the attempt currently being made.
    let mut attempt: u32 = 0;
    // Terminates because `attempt` only increases while it is below
    // `max_retries`. Once it reaches `max_retries`, any error falls through to
    // the final `return Err(e)` arm.
    loop {
        match operation().await {
            Ok(val) => {
                if attempt > 0 {
                    tracing::debug!(operation = operation_name, attempt, "succeeded after retry");
                }
                return Ok(val);
            }
            Err(e) if attempt < config.max_retries && is_retryable(&e) => {
                tracing::warn!(
                    operation = operation_name,
                    error = %e,
                    attempt = attempt + 1,
                    max_retries = config.max_retries,
                    // Saturate rather than silently truncate the u128 millis.
                    backoff_ms = u64::try_from(backoff.as_millis()).unwrap_or(u64::MAX),
                    "retryable error, backing off"
                );
                tokio::time::sleep(backoff).await;
                // `Duration * u32` panics on overflow; saturate to the cap instead.
                backoff = backoff
                    .checked_mul(config.backoff_multiplier)
                    .map_or(config.max_backoff, |next| next.min(config.max_backoff));
                attempt += 1;
            }
            Err(e) => return Err(e),
        }
    }
}
#[cfg(test)]
mod tests {
    // Behavioural tests for `retry_with_backoff`: first-try success, recovery
    // from transient failures, short-circuit on non-retryable errors, and
    // exhaustion of the retry budget. Configs that actually retry use
    // millisecond backoffs so the suite stays fast.
    use super::*;
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Arc,
    };

    #[tokio::test]
    async fn succeeds_on_first_attempt() {
        let cfg = RetryConfig {
            max_retries: 3,
            ..RetryConfig::default()
        };
        let outcome: Result<u32, String> =
            retry_with_backoff(&cfg, "test", || async { Ok(42) }, |_| true).await;
        assert_eq!(outcome, Ok(42));
    }

    #[tokio::test]
    async fn retries_on_transient_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let cfg = RetryConfig {
            max_retries: 3,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(10),
            backoff_multiplier: 2,
        };
        let outcome: Result<u32, String> = retry_with_backoff(
            &cfg,
            "test",
            || {
                let counter = Arc::clone(&calls);
                async move {
                    // Calls #0 and #1 fail transiently; call #2 succeeds.
                    match counter.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err("transient".to_string()),
                        _ => Ok(99),
                    }
                }
            },
            |_| true,
        )
        .await;
        assert_eq!(outcome, Ok(99));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn fails_immediately_on_non_retryable() {
        let calls = Arc::new(AtomicU32::new(0));
        let cfg = RetryConfig {
            max_retries: 5,
            initial_backoff: Duration::from_millis(1),
            ..RetryConfig::default()
        };
        let outcome: Result<u32, String> = retry_with_backoff(
            &cfg,
            "test",
            || {
                let counter = Arc::clone(&calls);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Err("permanent".to_string())
                }
            },
            |_| false,
        )
        .await;
        // The predicate rejects every error, so there must be exactly one call.
        assert_eq!(outcome, Err("permanent".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn exhausts_retries_and_returns_last_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let cfg = RetryConfig {
            max_retries: 2,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2,
        };
        let outcome: Result<u32, String> = retry_with_backoff(
            &cfg,
            "test",
            || {
                let counter = Arc::clone(&calls);
                async move {
                    let call_index = counter.fetch_add(1, Ordering::SeqCst);
                    Err(format!("fail-{}", call_index))
                }
            },
            |_| true,
        )
        .await;
        // max_retries = 2 means three calls in total, so the surfaced error is
        // the one from call #2.
        assert_eq!(outcome, Err("fail-2".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }
}