use std::time::Duration;
/// Small xorshift64 pseudo-random generator.
///
/// The wrapped `u64` is the generator state. A state of zero would stay zero
/// forever under the xor/shift steps, so [`Xorshift64::seeded`] never
/// produces it.
struct Xorshift64(u64);

impl Xorshift64 {
    /// Creates a generator seeded from the wall clock.
    ///
    /// The seed is the number of nanoseconds since the Unix epoch, truncated
    /// to its low 64 bits. A fixed fallback constant is used instead when the
    /// clock reads earlier than the epoch or when the truncated value is zero.
    fn seeded() -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};

        const FALLBACK_SEED: u64 = 0xDEAD_BEEF_CAFE_BABE;

        let state = match SystemTime::now().duration_since(UNIX_EPOCH) {
            // The `as` cast deliberately keeps only the low 64 bits.
            Ok(since_epoch) => match since_epoch.as_nanos() as u64 {
                0 => FALLBACK_SEED,
                nanos => nanos,
            },
            Err(_) => FALLBACK_SEED,
        };
        Self(state)
    }

    /// Advances the state using the 13 / 7 / 17 shift triple and returns the
    /// new state.
    fn next(&mut self) -> u64 {
        let after_left = self.0 ^ (self.0 << 13);
        let after_right = after_left ^ (after_left >> 7);
        let after_final = after_right ^ (after_right << 17);
        self.0 = after_final;
        after_final
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next
    /// output, so every result is exactly representable as an `f64`.
    fn next_f64(&mut self) -> f64 {
        const MANTISSA_BITS: u32 = 53;

        let top_bits = self.next() >> (u64::BITS - MANTISSA_BITS);
        let scale = (1u64 << MANTISSA_BITS) as f64;
        top_bits as f64 / scale
    }
}
/// Tunable parameters for [`retry_with_backoff`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of calls allowed, including the first one. Values below 1
    /// are treated as 1 by [`retry_with_backoff`].
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds; it doubles for each
    /// further retry.
    pub base_delay_ms: u64,
    /// Upper bound, in milliseconds, on the exponentially grown delay.
    pub max_delay_ms: u64,
    /// Fraction of the delay used as random spread. Values `<= 0.0` disable
    /// jitter and values above `1.0` are treated as `1.0`.
    pub jitter_factor: f64,
}

impl Default for RetryPolicy {
    /// Three attempts, 100 ms base delay, 5 s cap, 10 % jitter.
    fn default() -> Self {
        const DEFAULT_ATTEMPTS: u32 = 3;
        const DEFAULT_BASE_MS: u64 = 100;
        const DEFAULT_CAP_MS: u64 = 5_000;
        const DEFAULT_JITTER: f64 = 0.1;

        Self {
            max_attempts: DEFAULT_ATTEMPTS,
            base_delay_ms: DEFAULT_BASE_MS,
            max_delay_ms: DEFAULT_CAP_MS,
            jitter_factor: DEFAULT_JITTER,
        }
    }
}
impl RetryPolicy {
    /// Computes how long to sleep after failed attempt `n` (zero-based).
    ///
    /// The un-jittered delay is `base_delay_ms * 2^n`, saturating on overflow
    /// and capped at `max_delay_ms`. When `jitter_factor` is positive it is
    /// limited to at most `1.0` (call it `j`) and the delay is scaled by
    /// `1 - j + 2 * j * r`, where `r` is the next `[0, 1)` value from `rng`.
    /// The jittered value is clamped to `max_delay_ms` again, so the cap is a
    /// true upper bound on the returned duration.
    ///
    /// A non-positive or NaN `jitter_factor` disables jitter; in that case
    /// `rng` is not advanced and the capped delay is returned exactly.
    fn delay_for_attempt(&self, n: u32, rng: &mut Xorshift64) -> Duration {
        // `checked_shl` yields `None` once the shift reaches the bit width;
        // treat that as "as large as possible" and let the cap take over.
        let multiplier = 1u64.checked_shl(n).unwrap_or(u64::MAX);
        let capped = self
            .base_delay_ms
            .saturating_mul(multiplier)
            .min(self.max_delay_ms);

        // NaN must be checked explicitly: it fails `<= 0.0`, and
        // `NaN.min(1.0)` is `1.0`, which would silently mean full jitter.
        let jitter = self.jitter_factor;
        if jitter.is_nan() || jitter <= 0.0 {
            return Duration::from_millis(capped);
        }

        let jitter = jitter.min(1.0);
        let factor = 1.0 - jitter + 2.0 * jitter * rng.next_f64();
        // `factor` is never negative because `jitter <= 1.0`; the float to
        // integer cast saturates rather than wrapping.
        let jittered = (capped as f64 * factor) as u64;
        // Upward jitter must not push the delay past the configured maximum.
        Duration::from_millis(jittered.min(self.max_delay_ms))
    }
}
/// Lets [`retry_with_backoff`] decide whether a failed attempt may be retried.
pub trait ErrorClassification {
    /// Returns `true` when the error is transient, i.e. the operation is
    /// worth retrying; `false` makes [`retry_with_backoff`] return the error
    /// immediately.
    fn is_transient(&self) -> bool;
}
/// Runs `op` until it succeeds, fails with a non-transient error, or the
/// attempt budget in `policy` is used up.
///
/// `op` is called at most `policy.max_attempts` times (at least once, even if
/// `max_attempts` is 0). After a transient failure that is not the last
/// allowed attempt, the task sleeps for `policy.delay_for_attempt(attempt)`
/// before calling `op` again. Jitter comes from a generator freshly seeded on
/// every call to this function.
///
/// # Errors
///
/// Returns the error of the most recent attempt when that error is not
/// transient or when no attempts remain.
pub async fn retry_with_backoff<F, T, E, Fut>(mut op: F, policy: &RetryPolicy) -> Result<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
    E: ErrorClassification + std::fmt::Debug,
{
    let mut rng = Xorshift64::seeded();
    let max_attempts = policy.max_attempts.max(1);
    // Zero-based index of the attempt about to run.
    let mut attempt: u32 = 0;

    // Every path through the body either returns or schedules another
    // attempt, so `op` can never be invoked more than `max_attempts` times.
    loop {
        let err = match op().await {
            Ok(val) => return Ok(val),
            Err(err) => err,
        };

        // Cannot overflow: `attempt` is always below `max_attempts`.
        let attempts_made = attempt + 1;
        if attempts_made >= max_attempts || !err.is_transient() {
            return Err(err);
        }

        let delay = policy.delay_for_attempt(attempt, &mut rng);
        tokio::time::sleep(delay).await;
        attempt = attempts_made;
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
/// Error type used by the tests to exercise both retry classifications.
#[derive(Debug, Clone, PartialEq)]
enum TestError {
    /// Classified as transient, so the operation is retried.
    Transient,
    /// Classified as non-transient, so the operation is not retried.
    Permanent,
}
impl std::fmt::Display for TestError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TestError::Transient => write!(f, "transient error"),
TestError::Permanent => write!(f, "permanent error"),
}
}
}
/// Only the `Transient` variant is eligible for retry.
impl ErrorClassification for TestError {
    fn is_transient(&self) -> bool {
        match self {
            TestError::Transient => true,
            TestError::Permanent => false,
        }
    }
}
/// Two transient failures followed by a success must yield `Ok` after exactly
/// three calls.
#[tokio::test]
async fn test_retry_succeeds_on_third_attempt() {
    let policy = RetryPolicy {
        max_attempts: 3,
        base_delay_ms: 1,
        max_delay_ms: 5,
        jitter_factor: 0.0,
    };
    let calls = Arc::new(Mutex::new(0u32));
    let shared = Arc::clone(&calls);

    let outcome = retry_with_backoff(
        || {
            let shared = Arc::clone(&shared);
            async move {
                // Bump the counter and release the lock before deciding.
                let seen = {
                    let mut slot = shared.lock().expect("lock poisoned");
                    *slot += 1;
                    *slot
                };
                if seen >= 3 {
                    Ok(seen)
                } else {
                    Err(TestError::Transient)
                }
            }
        },
        &policy,
    )
    .await;

    assert!(outcome.is_ok(), "expected success on third attempt");
    assert_eq!(outcome.expect("ok"), 3);
    let total_calls = *calls.lock().expect("lock");
    assert_eq!(total_calls, 3);
}
/// A non-transient error must be returned after a single call even though
/// the policy would allow more attempts.
#[tokio::test]
async fn test_retry_permanent_error_not_retried() {
    let policy = RetryPolicy {
        max_attempts: 5,
        base_delay_ms: 1,
        max_delay_ms: 10,
        jitter_factor: 0.0,
    };
    let calls = Arc::new(Mutex::new(0u32));
    let shared = Arc::clone(&calls);

    let outcome: Result<u32, TestError> = retry_with_backoff(
        || {
            let shared = Arc::clone(&shared);
            async move {
                *shared.lock().expect("lock poisoned") += 1;
                Err(TestError::Permanent)
            }
        },
        &policy,
    )
    .await;

    assert_eq!(outcome, Err(TestError::Permanent));
    let total_calls = *calls.lock().expect("lock");
    assert_eq!(total_calls, 1, "permanent error must not be retried");
}
/// An operation that always fails transiently must be called exactly
/// `max_attempts` times before the error is returned.
#[tokio::test]
async fn test_retry_respects_max_attempts() {
    let policy = RetryPolicy {
        max_attempts: 4,
        base_delay_ms: 1,
        max_delay_ms: 5,
        jitter_factor: 0.0,
    };
    let calls = Arc::new(Mutex::new(0u32));
    let shared = Arc::clone(&calls);

    let outcome: Result<u32, TestError> = retry_with_backoff(
        || {
            let shared = Arc::clone(&shared);
            async move {
                *shared.lock().expect("lock poisoned") += 1;
                Err(TestError::Transient)
            }
        },
        &policy,
    )
    .await;

    assert_eq!(outcome, Err(TestError::Transient));
    let total_calls = *calls.lock().expect("lock");
    assert_eq!(
        total_calls, policy.max_attempts,
        "total calls must equal max_attempts"
    );
}
/// With a 50 ms base delay and no jitter, the pause before the second call
/// must be about 50 ms and the pause before the third about 100 ms.
#[tokio::test]
async fn test_retry_backoff_increases_exponentially() {
    // Allow 2 ms of slack for timer granularity on every lower bound.
    const SLACK: Duration = Duration::from_millis(2);

    // One timestamp per call, so each pause can be checked individually.
    let call_times = Arc::new(Mutex::new(Vec::<std::time::Instant>::new()));
    let recorder = Arc::clone(&call_times);
    let policy = RetryPolicy {
        max_attempts: 3,
        base_delay_ms: 50,
        max_delay_ms: 5_000,
        jitter_factor: 0.0,
    };

    let start = std::time::Instant::now();
    let result: Result<u32, TestError> = retry_with_backoff(
        || {
            let recorder = Arc::clone(&recorder);
            async move {
                recorder
                    .lock()
                    .expect("lock poisoned")
                    .push(std::time::Instant::now());
                Err(TestError::Transient)
            }
        },
        &policy,
    )
    .await;
    let elapsed = start.elapsed();

    assert!(result.is_err());

    // Total sleep is 50 ms + 100 ms; the message reports the real threshold.
    let min_total = Duration::from_millis(150) - SLACK;
    assert!(
        elapsed >= min_total,
        "expected elapsed >= {:?}, got {:?}",
        min_total,
        elapsed
    );

    let times = call_times.lock().expect("lock");
    assert_eq!(times.len(), 3);
    let first_pause = times[1].duration_since(times[0]);
    let second_pause = times[2].duration_since(times[1]);
    assert!(
        first_pause + SLACK >= Duration::from_millis(50),
        "first pause too short: {:?}",
        first_pause
    );
    assert!(
        second_pause + SLACK >= Duration::from_millis(100),
        "second pause too short: {:?}",
        second_pause
    );
}
/// The first ten outputs of a freshly seeded generator must all be non-zero.
#[test]
fn test_xorshift64_non_zero() {
    let mut rng = Xorshift64::seeded();
    std::iter::repeat_with(|| rng.next())
        .take(10)
        .for_each(|output| assert_ne!(output, 0));
}
/// One thousand consecutive `next_f64` samples must all fall in `[0, 1)`.
#[test]
fn test_xorshift64_f64_in_range() {
    let mut rng = Xorshift64::seeded();
    let unit_interval = 0.0..1.0;
    let mut remaining = 1000;
    while remaining > 0 {
        let v = rng.next_f64();
        assert!(unit_interval.contains(&v), "out of range: {v}");
        remaining -= 1;
    }
}
/// Without jitter the delay doubles per attempt (100, 200, 400, 800 ms) and
/// is then held at the 1 000 ms cap.
#[test]
fn test_delay_for_attempt_no_jitter() {
    let policy = RetryPolicy {
        max_attempts: 5,
        base_delay_ms: 100,
        max_delay_ms: 1_000,
        jitter_factor: 0.0,
    };
    let mut rng = Xorshift64::seeded();

    // Index in the table is the zero-based attempt number.
    let expected_ms = [100u64, 200, 400, 800, 1_000];
    for (attempt, ms) in (0u32..).zip(expected_ms) {
        assert_eq!(
            policy.delay_for_attempt(attempt, &mut rng),
            Duration::from_millis(ms)
        );
    }
}
}