use std::{future::Future, time::Duration};
use tokio::time::sleep;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryDecision {
Retry,
DoNotRetry,
}
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
pub max_attempts: usize,
pub initial_backoff: Duration,
pub backoff_multiplier: u32,
pub max_backoff: Duration,
}
impl RetryPolicy {
pub const S3_DEFAULT: Self = Self {
max_attempts: 4,
initial_backoff: Duration::from_millis(50),
backoff_multiplier: 2,
max_backoff: Duration::from_secs(5),
};
}
pub async fn retry_with<F, Fut, T, E>(
policy: RetryPolicy,
classify: impl Fn(&E) -> RetryDecision,
mut op: F,
) -> std::result::Result<T, E>
where
F: FnMut() -> Fut,
Fut: Future<Output = std::result::Result<T, E>>,
{
let mut delay = policy.initial_backoff;
let max_attempts = policy.max_attempts.max(1);
for attempt in 0..max_attempts {
match op().await {
Ok(v) => return Ok(v),
Err(err) if attempt + 1 < max_attempts && classify(&err) == RetryDecision::Retry => {
sleep(delay).await;
delay = delay
.saturating_mul(policy.backoff_multiplier)
.min(policy.max_backoff);
}
Err(err) => return Err(err),
}
}
unreachable!("retry loop always returns success or error")
}
pub fn classify_transient_io(error: &std::io::Error) -> RetryDecision {
let message = error.to_string().to_ascii_lowercase();
let is_transient = [
"500",
"502",
"503",
"504",
"internal server error",
"service unavailable",
"slowdown",
"throttl",
"connection reset",
"connection aborted",
"broken pipe",
"timeout",
"timed out",
]
.iter()
.any(|pattern| message.contains(pattern));
if is_transient {
RetryDecision::Retry
} else {
RetryDecision::DoNotRetry
}
}
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicUsize, Ordering};
use super::*;
#[tokio::test]
async fn retries_until_success() {
let calls = AtomicUsize::new(0);
let result: Result<&str, &str> = retry_with(
RetryPolicy {
max_attempts: 4,
initial_backoff: Duration::from_millis(1),
backoff_multiplier: 2,
max_backoff: Duration::from_millis(10),
},
|_| RetryDecision::Retry,
|| async {
let n = calls.fetch_add(1, Ordering::SeqCst);
if n < 2 { Err("flaky") } else { Ok("ok") }
},
)
.await;
assert_eq!(result, Ok("ok"));
assert_eq!(calls.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn gives_up_after_max_attempts() {
let calls = AtomicUsize::new(0);
let result: Result<&str, &str> = retry_with(
RetryPolicy {
max_attempts: 3,
initial_backoff: Duration::from_millis(1),
backoff_multiplier: 2,
max_backoff: Duration::from_millis(10),
},
|_| RetryDecision::Retry,
|| async {
calls.fetch_add(1, Ordering::SeqCst);
Err::<&str, _>("always fails")
},
)
.await;
assert_eq!(result, Err("always fails"));
assert_eq!(calls.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn non_retryable_returns_immediately() {
let calls = AtomicUsize::new(0);
let result: Result<&str, &str> = retry_with(
RetryPolicy {
max_attempts: 4,
initial_backoff: Duration::from_millis(1),
backoff_multiplier: 2,
max_backoff: Duration::from_millis(10),
},
|_| RetryDecision::DoNotRetry,
|| async {
calls.fetch_add(1, Ordering::SeqCst);
Err::<&str, _>("permanent")
},
)
.await;
assert_eq!(result, Err("permanent"));
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[test]
fn classifies_http_5xx_as_transient() {
let err = std::io::Error::other("503 Service Unavailable");
assert_eq!(classify_transient_io(&err), RetryDecision::Retry);
}
#[test]
fn classifies_local_io_as_permanent() {
let err = std::io::Error::from(std::io::ErrorKind::PermissionDenied);
assert_eq!(classify_transient_io(&err), RetryDecision::DoNotRetry);
}
}