use crate::error::{Error, Result};
use std::future::Future;
use std::time::Duration;
use tokio::time::{Instant, timeout};
/// Overall wall-clock budget that `with_retry_and_timeout` allows for all
/// attempts and the backoff sleeps between them combined.
pub const OPERATION_TIMEOUT: Duration = Duration::from_secs(60);

/// Backoff used before the first retry (before jitter is added).
const INITIAL_BACKOFF: Duration = Duration::from_millis(100);
/// Ceiling for the un-jittered backoff; jitter may push an individual sleep
/// above this value.
const MAX_BACKOFF: Duration = Duration::from_millis(5_000);
/// Factor the backoff grows by after each retryable failure.
const BACKOFF_MULTIPLIER: u32 = 2;
pub async fn with_retry_and_timeout<F, Fut, T>(op: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<T>>,
{
let deadline = Instant::now() + OPERATION_TIMEOUT;
let mut backoff = INITIAL_BACKOFF;
let mut attempt: u32 = 0;
loop {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(Error::timeout("operation exceeded 60s"));
}
let result = timeout(remaining, op()).await;
match result {
Ok(Ok(t)) => return Ok(t),
Ok(Err(e)) => {
if e.is_retryable() && Instant::now() < deadline {
let sleep_duration = sleep_duration_with_jitter(backoff, attempt);
tokio::time::sleep(sleep_duration).await;
backoff = (backoff * BACKOFF_MULTIPLIER).min(MAX_BACKOFF);
attempt += 1;
continue;
}
return Err(e);
}
Err(_) => {
return Err(Error::timeout("operation exceeded 60s"));
}
}
}
}
/// Returns `base` plus a jitter of `attempt % 25` percent (0–24%), computed in
/// whole milliseconds.
///
/// The jitter is deterministic, not random: the same `(base, attempt)` pair
/// always yields the same delay.
///
/// Any sub-millisecond part of `base` is discarded. All arithmetic saturates at
/// `u64::MAX` milliseconds rather than truncating or overflowing, so very large
/// `base` values do not panic.
const fn sleep_duration_with_jitter(base: Duration, attempt: u32) -> Duration {
    let millis = base.as_millis();
    // Clamp instead of a bare `u128 as u64` cast, which would wrap silently.
    let ms = if millis > u64::MAX as u128 {
        u64::MAX
    } else {
        millis as u64 // lossless: bounded by the check above
    };
    // Widening u32 -> u64 is lossless; `u64::from` is not callable in a const fn.
    let jitter_pct = (attempt % 25) as u64;
    // `saturating_mul` avoids the overflow (panic in debug, wrap in release)
    // that plain `*` would hit for very large `ms`.
    let jitter_ms = ms.saturating_mul(jitter_pct) / 100;
    Duration::from_millis(ms.saturating_add(jitter_ms))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[tokio::test]
    async fn succeeds_immediately() {
        let result = with_retry_and_timeout(|| async { Ok::<_, Error>(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn returns_permanent_error_without_retry() {
        // Count invocations so the "without retry" part is actually verified.
        let calls = Arc::new(AtomicU32::new(0));
        let result = with_retry_and_timeout(|| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err::<i32, _>(Error::invalid_input("bad input"))
            }
        })
        .await;
        let err = result.expect_err("a permanent error must be returned");
        assert!(!err.is_retryable());
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "non-retryable errors must not trigger another attempt"
        );
    }

    #[tokio::test]
    async fn retries_then_succeeds() {
        let attempts = Arc::new(AtomicU32::new(0));
        let result = with_retry_and_timeout(|| {
            let attempts = Arc::clone(&attempts);
            async move {
                let n = attempts.fetch_add(1, Ordering::SeqCst) + 1;
                if n < 3 {
                    Err(Error::api("transient"))
                } else {
                    Ok(7)
                }
            }
        })
        .await;
        assert_eq!(result.unwrap(), 7);
        // Two transient failures followed by one success: no extra calls after success.
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn jitter_adds_attempt_percentage() {
        assert_eq!(
            sleep_duration_with_jitter(Duration::from_millis(100), 0),
            Duration::from_millis(100)
        );
        assert_eq!(
            sleep_duration_with_jitter(Duration::from_millis(5000), 24),
            Duration::from_millis(6200)
        );
        // The percentage wraps every 25 attempts.
        assert_eq!(
            sleep_duration_with_jitter(Duration::from_millis(5000), 25),
            Duration::from_millis(5000)
        );
    }
}