use std::ops::ControlFlow;
use std::time::Duration;
/// Exponential-backoff configuration consumed by [`retry_sync`] and [`retry_async`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first one. The retry helpers treat `0` as `1`.
    pub max_attempts: u32,
    /// Delay before the second attempt; each subsequent retry doubles it.
    pub base_delay: Duration,
    /// Upper bound applied to every individual delay.
    pub max_delay: Duration,
}

impl RetryPolicy {
    /// Policy used for uploads: up to 10 attempts, starting at 50 ms, capped at 30 s.
    pub const UPLOAD: RetryPolicy = RetryPolicy {
        max_attempts: 10,
        base_delay: Duration::from_millis(50),
        max_delay: Duration::from_secs(30),
    };

    /// Returns how long to wait before attempt number `next_attempt` (1-based).
    ///
    /// The delay is `base_delay * 2^(next_attempt - 2)`, capped at `max_delay`. So attempt 2
    /// waits `base_delay`, attempt 3 waits twice that, and so on. Attempts `0` and `1` also
    /// yield `base_delay` (capped); the retry helpers never sleep before the first attempt.
    ///
    /// The computation uses saturating nanosecond arithmetic. Sub-millisecond base delays
    /// are honoured exactly, and arbitrarily large attempt numbers or durations simply
    /// produce `max_delay` instead of overflowing or truncating.
    pub fn delay_for(&self, next_attempt: u32) -> Duration {
        let exp = next_attempt.saturating_sub(2);
        // 2^exp; shifts of 128 or more overflow u128, so saturate (the product is capped below).
        let mult = 1u128.checked_shl(exp).unwrap_or(u128::MAX);
        let nanos = self.base_delay.as_nanos().saturating_mul(mult);
        let cap = self.max_delay.as_nanos();
        if nanos >= cap {
            self.max_delay
        } else {
            // `nanos < cap`, and `cap` is a valid Duration, so the seconds part fits in u64.
            Duration::new(
                (nanos / 1_000_000_000) as u64,
                (nanos % 1_000_000_000) as u32,
            )
        }
    }
}
/// Calls `op` until it succeeds, fails permanently, or runs out of attempts.
/// The current thread is blocked while waiting between attempts.
///
/// `op` receives the 1-based attempt number and reports its outcome as:
/// * `Ok(value)`: success; `value` is returned immediately.
/// * `Err(ControlFlow::Break(e))`: permanent failure; `e` is returned without retrying.
/// * `Err(ControlFlow::Continue(e))`: transient failure. Unless this was the last
///   permitted attempt, the thread sleeps for [`RetryPolicy::delay_for`] of the next
///   attempt and tries again; otherwise `e` is returned.
///
/// A `max_attempts` of `0` is treated as `1`, so `op` always runs at least once.
///
/// # Errors
///
/// Returns the `Break` error, or the last `Continue` error once attempts are exhausted.
pub fn retry_sync<T, E, F>(policy: &RetryPolicy, mut op: F) -> Result<T, E>
where
    F: FnMut(u32) -> Result<T, ControlFlow<E, E>>,
{
    // Always permit at least one call, even for a zero-attempt policy.
    let budget = policy.max_attempts.max(1);
    let mut current: u32 = 1;
    loop {
        match op(current) {
            Ok(value) => return Ok(value),
            Err(ControlFlow::Break(fatal)) => return Err(fatal),
            // Out of attempts: surface the most recent transient error.
            Err(ControlFlow::Continue(last)) if current >= budget => return Err(last),
            // Transient error with attempts remaining: discard it and retry.
            Err(ControlFlow::Continue(_)) => {}
        }
        // Here `current < budget <= u32::MAX`, so the increment cannot overflow.
        current += 1;
        std::thread::sleep(policy.delay_for(current));
    }
}
/// Async counterpart of [`retry_sync`]: awaits the future produced by `op` until it
/// succeeds, fails permanently, or runs out of attempts.
///
/// Between attempts it waits with `tokio::time::sleep` rather than blocking the thread.
///
/// `op` receives the 1-based attempt number and returns a future resolving to:
/// * `Ok(value)`: success; `value` is returned immediately.
/// * `Err(ControlFlow::Break(e))`: permanent failure; `e` is returned without retrying.
/// * `Err(ControlFlow::Continue(e))`: transient failure. Unless this was the last
///   permitted attempt, the function waits for [`RetryPolicy::delay_for`] of the next
///   attempt and tries again; otherwise `e` is returned.
///
/// A `max_attempts` of `0` is treated as `1`, so `op` always runs at least once.
///
/// # Errors
///
/// Returns the `Break` error, or the last `Continue` error once attempts are exhausted.
pub async fn retry_async<T, E, F, Fut>(policy: &RetryPolicy, mut op: F) -> Result<T, E>
where
    F: FnMut(u32) -> Fut,
    Fut: std::future::Future<Output = Result<T, ControlFlow<E, E>>>,
{
    // Always permit at least one call, even for a zero-attempt policy.
    let budget = policy.max_attempts.max(1);
    let mut current: u32 = 1;
    loop {
        match op(current).await {
            Ok(value) => return Ok(value),
            Err(ControlFlow::Break(fatal)) => return Err(fatal),
            // Out of attempts: surface the most recent transient error.
            Err(ControlFlow::Continue(last)) if current >= budget => return Err(last),
            // Transient error with attempts remaining: it is dropped here, before the
            // sleep below, so it is never held across an await point.
            Err(ControlFlow::Continue(_)) => {}
        }
        // Here `current < budget <= u32::MAX`, so the increment cannot overflow.
        current += 1;
        tokio::time::sleep(policy.delay_for(current)).await;
    }
}
/// Maps the result of a blocking `reqwest` call onto the retry protocol used by
/// [`retry_sync`].
///
/// * 2xx and 3xx responses are returned as `Ok(response)`.
/// * 5xx, `408 Request Timeout` and `429 Too Many Requests` are transient and become
///   `Err(ControlFlow::Continue(..))`.
/// * Every other status (the remaining 4xx, or anything unexpected) is permanent and
///   becomes `Err(ControlFlow::Break(..))`.
/// * Transport errors are transient, except request-builder errors (e.g. an invalid
///   URL). Those would fail identically on every attempt, so they break immediately.
///
/// Status errors carry the message `HTTP <code> <reason>`. Note that a `Retry-After`
/// header on 429/503 responses is not consulted; the caller's backoff policy applies.
pub fn classify_http_sync(
    result: reqwest::Result<reqwest::blocking::Response>,
) -> Result<reqwest::blocking::Response, ControlFlow<anyhow::Error, anyhow::Error>> {
    use anyhow::anyhow;

    let resp = match result {
        Ok(resp) => resp,
        // The request could not even be constructed, so a retry cannot help.
        Err(e) if e.is_builder() => return Err(ControlFlow::Break(anyhow!(e))),
        // Connection resets, timeouts, etc. are worth retrying.
        Err(e) => return Err(ControlFlow::Continue(anyhow!(e))),
    };

    let status = resp.status();
    if status.is_success() || status.is_redirection() {
        return Ok(resp);
    }

    let fallback_reason = if status.is_server_error() {
        "server error"
    } else {
        "client error"
    };
    let err = anyhow!(
        "HTTP {} {}",
        status.as_u16(),
        status.canonical_reason().unwrap_or(fallback_reason)
    );

    // 408 and 429 are client-class codes that explicitly signal "try again later".
    let transient = status.is_server_error()
        || status == reqwest::StatusCode::REQUEST_TIMEOUT
        || status == reqwest::StatusCode::TOO_MANY_REQUESTS;
    if transient {
        Err(ControlFlow::Continue(err))
    } else {
        Err(ControlFlow::Break(err))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Policy with millisecond-scale delays so the retry tests finish quickly.
    fn fast_policy() -> RetryPolicy {
        RetryPolicy {
            max_attempts: 4,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
        }
    }

    #[test]
    fn delay_progression_caps_at_max() {
        let p = RetryPolicy {
            max_attempts: 10,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
        };
        assert_eq!(p.delay_for(2), Duration::from_millis(100));
        assert_eq!(p.delay_for(3), Duration::from_millis(200));
        assert_eq!(p.delay_for(4), Duration::from_millis(400));
        assert_eq!(p.delay_for(5), Duration::from_millis(500));
        assert_eq!(p.delay_for(8), Duration::from_millis(500));
    }

    #[test]
    fn delay_saturates_for_huge_attempt_numbers() {
        let p = fast_policy();
        assert_eq!(p.delay_for(u32::MAX), p.max_delay);
    }

    #[test]
    fn sync_succeeds_on_first_attempt() {
        let calls = AtomicU32::new(0);
        let result: Result<&str, ()> = retry_sync(&fast_policy(), |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Ok("ok")
        });
        assert_eq!(result, Ok("ok"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn sync_retries_until_success() {
        let calls = AtomicU32::new(0);
        let result: Result<u32, &str> = retry_sync(&fast_policy(), |attempt| {
            calls.fetch_add(1, Ordering::SeqCst);
            if attempt < 3 {
                Err(ControlFlow::Continue("transient"))
            } else {
                Ok(attempt)
            }
        });
        assert_eq!(result, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn sync_break_stops_immediately() {
        let calls = AtomicU32::new(0);
        let result: Result<(), &str> = retry_sync(&fast_policy(), |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Break("fatal"))
        });
        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn sync_returns_last_error_after_exhaustion() {
        let calls = AtomicU32::new(0);
        let result: Result<(), String> = retry_sync(&fast_policy(), |attempt| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Continue(format!("fail {attempt}")))
        });
        assert_eq!(result, Err("fail 4".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn sync_zero_max_attempts_still_runs_once() {
        let policy = RetryPolicy {
            max_attempts: 0,
            ..fast_policy()
        };
        let calls = AtomicU32::new(0);
        let result: Result<(), &str> = retry_sync(&policy, |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Continue("transient"))
        });
        assert_eq!(result, Err("transient"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn async_retries_until_success() {
        let calls = std::sync::Arc::new(AtomicU32::new(0));
        let calls_inner = std::sync::Arc::clone(&calls);
        let result: Result<u32, &str> = retry_async(&fast_policy(), move |attempt| {
            let c = std::sync::Arc::clone(&calls_inner);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    Err(ControlFlow::Continue("transient"))
                } else {
                    Ok(attempt)
                }
            }
        })
        .await;
        assert_eq!(result, Ok(2));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn async_break_stops_immediately() {
        let calls = std::sync::Arc::new(AtomicU32::new(0));
        let calls_inner = std::sync::Arc::clone(&calls);
        let result: Result<(), &str> = retry_async(&fast_policy(), move |_| {
            let c = std::sync::Arc::clone(&calls_inner);
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                let outcome: Result<(), ControlFlow<&str, &str>> =
                    Err(ControlFlow::Break("fatal"));
                outcome
            }
        })
        .await;
        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn async_returns_last_error_after_exhaustion() {
        let result: Result<(), String> = retry_async(&fast_policy(), |attempt| async move {
            let outcome: Result<(), ControlFlow<String, String>> =
                Err(ControlFlow::Continue(format!("fail {attempt}")));
            outcome
        })
        .await;
        assert_eq!(result, Err("fail 4".to_string()));
    }
}