#![allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss
)]
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use snapdir_core::StoreError;
use crate::transfer::RateLimiter;
/// Configuration shared by the retry helpers in this module.
///
/// The wait before retry `n` (zero-based) is drawn from an exponential
/// envelope `min(cap, base * 2^n)` that is then scaled by a jitter fraction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Total number of times the operation may be invoked, first try
    /// included. The retry loops treat `0` as `1`.
    pub max_attempts: u32,
    /// Width of the backoff envelope for the first retry; it doubles with
    /// every further retry.
    pub base: Duration,
    /// Ceiling applied to the backoff envelope.
    pub cap: Duration,
}
impl Default for RetryPolicy {
    /// Five attempts, a 250 ms base delay and a 30 s cap.
    fn default() -> Self {
        let base = Duration::from_millis(250);
        let cap = Duration::from_secs(30);
        Self {
            max_attempts: 5,
            base,
            cap,
        }
    }
}
impl RetryPolicy {
    /// Computes how long to wait before retry number `n` (zero-based).
    ///
    /// This is exponential backoff with full jitter: the envelope
    /// `min(cap, base * 2^n)` is multiplied by `jitter01`.
    ///
    /// * `n` - number of retries already performed. Values of 1024 and above
    ///   use the cap directly instead of evaluating `2^n`.
    /// * `server_hint` - optional wait requested by the server. It acts as a
    ///   floor on the result and is not limited by `cap`.
    /// * `jitter01` - jitter fraction. It is clamped into `[0, 1]`; NaN and
    ///   infinities are treated as `0`.
    ///
    /// Without a hint the result never exceeds `cap`, and the function never
    /// panics, even when `cap` is so large that its `f64` representation no
    /// longer fits back into a [`Duration`].
    #[must_use]
    fn backoff(&self, n: u32, server_hint: Option<Duration>, jitter01: f64) -> Duration {
        let cap = self.cap;
        let cap_secs = cap.as_secs_f64();

        let envelope_secs = if n >= 1024 {
            cap_secs
        } else {
            // `n < 1024`, so the cast is lossless and `2^n` is a finite f64.
            (self.base.as_secs_f64() * 2f64.powi(n as i32)).min(cap_secs)
        };
        // Defensive: keep the envelope inside `[0, cap]` whatever the inputs.
        let envelope_secs = if envelope_secs.is_finite() {
            envelope_secs.clamp(0.0, cap_secs)
        } else {
            cap_secs
        };

        let frac = if jitter01.is_finite() {
            jitter01.clamp(0.0, 1.0)
        } else {
            0.0
        };

        // `cap_secs` is a rounded view of `cap`: for very large caps it can
        // round up past what a `Duration` can hold (`Duration::MAX` becomes
        // 2^64 s), which would make `from_secs_f64` panic. Convert fallibly
        // and re-apply the exact cap so rounding can never exceed it.
        let jittered =
            Duration::try_from_secs_f64(envelope_secs * frac).map_or(cap, |d| d.min(cap));

        server_hint.map_or(jittered, |hint| hint.max(jittered))
    }
}
/// Source of the jitter fraction that scales each backoff delay.
pub trait Jitter {
/// Returns the next jitter fraction. `RetryPolicy::backoff` clamps the
/// value into `[0, 1]` and treats non-finite values as `0`.
fn jitter01(&self) -> f64;
}
/// Process-wide counter bumped once per `DefaultJitter::new` call and mixed
/// into that instance's seed alongside the wall-clock time.
static JITTER_SEED_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Jitter source backed by a SplitMix64-style generator whose state lives in
/// an atomic, so it can be advanced through a shared reference.
#[derive(Debug)]
pub struct DefaultJitter {
/// Generator state; advanced by a fixed increment on every `jitter01` call.
state: AtomicU64,
}
impl DefaultJitter {
    /// Creates a generator with a fresh seed.
    ///
    /// The seed XORs the wall-clock time (nanoseconds since the Unix epoch,
    /// truncated to 64 bits, or `0` if the clock reads earlier than the
    /// epoch) with a scrambled process-wide counter, so that instances
    /// created at the same instant are still seeded differently.
    #[must_use]
    pub fn new() -> Self {
        let wall_clock = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_nanos() as u64,
            Err(_) => 0,
        };
        let ticket = JITTER_SEED_COUNTER.fetch_add(1, Ordering::Relaxed);
        let salt = splitmix64_mix(ticket.wrapping_add(0x9E37_79B9_7F4A_7C15));
        Self {
            state: AtomicU64::new(wall_clock ^ salt),
        }
    }
}
impl Default for DefaultJitter {
fn default() -> Self {
Self::new()
}
}
impl Jitter for DefaultJitter {
    /// Advances the shared state by the SplitMix64 increment, scrambles the
    /// new state and maps the result onto `[0, 1)`.
    fn jitter01(&self) -> f64 {
        const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
        // `fetch_add` hands back the previous state; add the increment again
        // to obtain the value this call just stored.
        let previous = self.state.fetch_add(GAMMA, Ordering::Relaxed);
        let advanced = previous.wrapping_add(GAMMA);
        u64_to_unit_f64(splitmix64_mix(advanced))
    }
}
/// SplitMix64 output scrambler: two xor-shift-multiply rounds followed by a
/// final xor-shift. All arithmetic wraps.
#[inline]
fn splitmix64_mix(z: u64) -> u64 {
    let round1 = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let round2 = (round1 ^ (round1 >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    round2 ^ (round2 >> 31)
}
/// Maps 64 random bits onto `[0, 1)` by keeping the top 53 bits (the width
/// of an `f64` significand) and dividing by `2^53`.
#[inline]
fn u64_to_unit_f64(bits: u64) -> f64 {
    let top53 = bits >> 11;
    let two_pow_53 = (1u64 << 53) as f64;
    top53 as f64 / two_pow_53
}
/// Jitter source that always reports the same, caller-chosen fraction.
#[derive(Debug, Clone, Copy)]
pub struct FixedJitter(
    /// The fraction to report; it is clamped when read through [`Jitter`].
    pub f64,
);
impl Jitter for FixedJitter {
    /// Returns the stored fraction clamped into `[0, 1 - f64::EPSILON]`.
    ///
    /// `f64::clamp` passes NaN through unchanged, so NaN is mapped to `0.0`
    /// explicitly. That is the same value `RetryPolicy::backoff` substitutes
    /// for a non-finite fraction, so resulting delays are unchanged while
    /// the returned value now always lies in `[0, 1)`.
    fn jitter01(&self) -> f64 {
        if self.0.is_nan() {
            return 0.0;
        }
        self.0.clamp(0.0, 1.0 - f64::EPSILON)
    }
}
/// Abstraction over "wait for `dur`" used by the async retry loops, so a
/// different sleeper can be substituted for the real timer.
pub trait AsyncSleeper {
/// Returns a `Send` future representing the wait for `dur`.
fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send;
}
/// Synchronous counterpart of `AsyncSleeper`, used by `retry_blocking`.
pub trait BlockingSleeper {
/// Performs the wait for `dur` before returning.
fn sleep(&self, dur: Duration);
}
/// Zero-sized [`AsyncSleeper`] that waits on the tokio timer.
#[derive(Debug, Default, Clone, Copy)]
pub struct TokioSleeper;
// Async sleeper backed by the tokio timer.
impl AsyncSleeper for TokioSleeper {
/// Delegates directly to `tokio::time::sleep`.
fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send {
tokio::time::sleep(dur)
}
}
/// Zero-sized [`BlockingSleeper`] that parks the calling thread.
#[derive(Debug, Default, Clone, Copy)]
pub struct ThreadSleeper;
// Blocking sleeper backed by the standard library.
impl BlockingSleeper for ThreadSleeper {
/// Blocks the calling thread via `std::thread::sleep`.
fn sleep(&self, dur: Duration) {
std::thread::sleep(dur);
}
}
/// Failure report returned by one invocation of a retried operation.
#[derive(Debug)]
pub struct Attempt {
/// Error handed back to the caller when the retry loop gives up.
pub err: StoreError,
/// When `true` the retry loops may try again (attempt budget permitting);
/// when `false` they return `err` immediately.
pub transient: bool,
/// Optional wait requested by the server; forwarded to
/// `RetryPolicy::backoff` as its `server_hint`.
pub retry_after: Option<Duration>,
}
/// Effective attempt budget for `policy`: a configured `0` still allows the
/// single initial attempt.
#[inline]
fn clamped_max_attempts(policy: &RetryPolicy) -> u32 {
    match policy.max_attempts {
        0 => 1,
        budget => budget,
    }
}
/// Runs `op` until it succeeds, fails with a non-transient error, or the
/// attempt budget of `policy` is used up.
///
/// After every transient failure that still has budget left, the loop waits
/// on `sleeper` for `policy.backoff(..)`, passing the number of retries done
/// so far, the failure's `retry_after` hint and a fresh value from `jitter`.
/// No wait happens after the final attempt.
///
/// # Errors
///
/// Returns the `err` of the last [`Attempt`] when giving up.
pub async fn retry_async<T, F, Fut>(
    policy: &RetryPolicy,
    sleeper: &impl AsyncSleeper,
    jitter: &impl Jitter,
    mut op: F,
) -> Result<T, StoreError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, Attempt>>,
{
    let budget = clamped_max_attempts(policy);
    let mut failures: u32 = 0;
    loop {
        let failed = match op().await {
            Ok(value) => return Ok(value),
            Err(attempt) => attempt,
        };
        failures += 1;
        if !failed.transient || failures >= budget {
            return Err(failed.err);
        }
        // `failures - 1` is the zero-based index of the retry about to start.
        let delay = policy.backoff(failures - 1, failed.retry_after, jitter.jitter01());
        sleeper.sleep(delay).await;
    }
}
/// Synchronous retry loop: calls `op` until it succeeds, fails with a
/// non-transient error, or the attempt budget of `policy` is used up.
///
/// Between attempts the calling thread waits on `sleeper` for
/// `policy.backoff(..)`, which receives the number of retries done so far,
/// the failure's `retry_after` hint and a fresh value from `jitter`. No wait
/// happens after the final attempt.
///
/// # Errors
///
/// Returns the `err` of the last [`Attempt`] when giving up.
pub fn retry_blocking<T, F>(
    policy: &RetryPolicy,
    sleeper: &impl BlockingSleeper,
    jitter: &impl Jitter,
    mut op: F,
) -> Result<T, StoreError>
where
    F: FnMut() -> Result<T, Attempt>,
{
    let budget = clamped_max_attempts(policy);
    let mut retries_done: u32 = 0;
    loop {
        let attempt = match op() {
            Ok(value) => return Ok(value),
            Err(attempt) => attempt,
        };
        let budget_spent = retries_done + 1 >= budget;
        if budget_spent || !attempt.transient {
            return Err(attempt.err);
        }
        let delay = policy.backoff(retries_done, attempt.retry_after, jitter.jitter01());
        sleeper.sleep(delay);
        retries_done += 1;
    }
}
/// Rate-limited variant of the async retry loop.
///
/// Before every attempt, retries included, one permit is acquired from
/// `req_limiter`. Otherwise the loop runs `op` until it succeeds, fails with
/// a non-transient error, or the attempt budget of `policy` is used up,
/// waiting on `sleeper` for `policy.backoff(..)` between attempts.
///
/// # Errors
///
/// Returns the `err` of the last [`Attempt`] when giving up.
pub async fn retry_network<T, F, Fut>(
    policy: &RetryPolicy,
    req_limiter: &RateLimiter,
    sleeper: &impl AsyncSleeper,
    jitter: &impl Jitter,
    mut op: F,
) -> Result<T, StoreError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, Attempt>>,
{
    let budget = clamped_max_attempts(policy);
    let mut retries_done: u32 = 0;
    loop {
        // Every attempt is charged against the request limiter.
        req_limiter.acquire(1).await;
        let attempt = match op().await {
            Ok(value) => return Ok(value),
            Err(attempt) => attempt,
        };
        let budget_spent = retries_done + 1 >= budget;
        if budget_spent || !attempt.transient {
            return Err(attempt.err);
        }
        let delay = policy.backoff(retries_done, attempt.retry_after, jitter.jitter01());
        sleeper.sleep(delay).await;
        retries_done += 1;
    }
}
/// Parses the delay-seconds form of a `Retry-After` header value.
///
/// Surrounding whitespace is ignored. Anything that does not parse as a
/// `u64` (including the HTTP-date form of the header) yields `None`.
#[must_use]
pub fn parse_retry_after(value: &str) -> Option<Duration> {
    match value.trim().parse::<u64>() {
        Ok(seconds) => Some(Duration::from_secs(seconds)),
        Err(_) => None,
    }
}
// Unit tests for the retry loops (attempt budgeting and sleep counts), the
// backoff envelope arithmetic and the default jitter source. Delays are
// recorded by a test double instead of being slept.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};
use std::sync::Mutex;
// Sleeper double: stores every requested delay instead of waiting.
#[derive(Default)]
struct RecordingSleeper {
delays: Mutex<Vec<Duration>>,
}
impl RecordingSleeper {
// Snapshot of all delays requested so far, in call order.
fn recorded(&self) -> Vec<Duration> {
self.delays.lock().unwrap().clone()
}
}
impl AsyncSleeper for RecordingSleeper {
// Records the delay when called and returns an already-completed future.
fn sleep(&self, dur: Duration) -> impl Future<Output = ()> + Send {
self.delays.lock().unwrap().push(dur);
std::future::ready(())
}
}
impl BlockingSleeper for RecordingSleeper {
// Records the delay and returns immediately.
fn sleep(&self, dur: Duration) {
self.delays.lock().unwrap().push(dur);
}
}
// Jitter double that returns the wrapped value as-is (no clamping).
struct StubJitter(f64);
impl Jitter for StubJitter {
fn jitter01(&self) -> f64 {
self.0
}
}
// Builds the backend error used by every failing attempt in these tests.
fn boom() -> StoreError {
StoreError::Backend {
message: "boom".into(),
source: None,
}
}
// A retryable failure, optionally carrying a server wait hint.
fn transient_attempt(retry_after: Option<Duration>) -> Attempt {
Attempt {
err: boom(),
transient: true,
retry_after,
}
}
// A non-retryable failure.
fn hard_attempt() -> Attempt {
Attempt {
err: boom(),
transient: false,
retry_after: None,
}
}
// Current-thread tokio runtime with the time driver enabled.
fn runtime() -> tokio::runtime::Runtime {
tokio::runtime::Builder::new_current_thread()
.enable_time()
.build()
.expect("build tokio runtime")
}
// Policy used by most tests: 5 attempts, 100 ms base, 10 s cap.
fn small_policy() -> RetryPolicy {
RetryPolicy {
max_attempts: 5,
base: Duration::from_millis(100),
cap: Duration::from_secs(10),
}
}
// An operation that always fails transiently is invoked `max_attempts`
// times with `max_attempts - 1` sleeps in between.
#[test]
fn retry_async_persistent_transient_uses_full_attempt_budget() {
let rt = runtime();
rt.block_on(async {
let policy = small_policy();
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.5);
let calls = AtomicUsize::new(0);
let result: Result<(), StoreError> = retry_async(&policy, &sleeper, &jitter, || {
calls.fetch_add(1, AtomicOrdering::SeqCst);
async { Err(transient_attempt(None)) }
})
.await;
assert!(result.is_err(), "persistent transient => surfaces the err");
assert_eq!(
calls.load(AtomicOrdering::SeqCst),
policy.max_attempts as usize,
"op invoked exactly max_attempts times"
);
assert_eq!(
sleeper.recorded().len(),
(policy.max_attempts - 1) as usize,
"max_attempts-1 sleeps recorded (no sleep after the last attempt)"
);
});
}
// A non-transient failure ends the loop after one call and zero sleeps.
#[test]
fn retry_async_hard_error_returns_immediately_without_sleeping() {
let rt = runtime();
rt.block_on(async {
let policy = small_policy();
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.5);
let calls = AtomicUsize::new(0);
let result: Result<(), StoreError> = retry_async(&policy, &sleeper, &jitter, || {
calls.fetch_add(1, AtomicOrdering::SeqCst);
async { Err(hard_attempt()) }
})
.await;
assert!(result.is_err(), "hard error surfaces");
assert_eq!(
calls.load(AtomicOrdering::SeqCst),
1,
"non-transient error => op called exactly once"
);
assert!(
sleeper.recorded().is_empty(),
"no sleeps for a non-transient error"
);
});
}
// `k` transient failures followed by a success: `k + 1` calls, `k` sleeps.
#[test]
fn retry_async_success_after_k_transient_fails() {
let rt = runtime();
rt.block_on(async {
let policy = small_policy();
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.25);
let calls = AtomicUsize::new(0);
let k = 3usize;
let result: Result<u32, StoreError> = retry_async(&policy, &sleeper, &jitter, || {
let prev = calls.fetch_add(1, AtomicOrdering::SeqCst);
async move {
if prev < k {
Err(transient_attempt(None))
} else {
Ok(42)
}
}
})
.await;
assert_eq!(result.unwrap(), 42, "succeeds on the (k+1)-th attempt");
assert_eq!(
calls.load(AtomicOrdering::SeqCst),
k + 1,
"op invoked k+1 times"
);
assert_eq!(
sleeper.recorded().len(),
k,
"exactly k sleeps recorded (one before each retry)"
);
});
}
// With jitter pinned to 0 the computed delay is zero, so every recorded
// delay must come from (and be at least) the server hint.
#[test]
fn retry_async_retry_after_is_a_floor() {
let rt = runtime();
rt.block_on(async {
let policy = small_policy();
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.0);
let hint = Duration::from_secs(5);
let calls = AtomicUsize::new(0);
let _result: Result<(), StoreError> = retry_async(&policy, &sleeper, &jitter, || {
calls.fetch_add(1, AtomicOrdering::SeqCst);
async move { Err(transient_attempt(Some(hint))) }
})
.await;
let recorded = sleeper.recorded();
assert!(!recorded.is_empty(), "at least one retry happened");
for d in &recorded {
assert!(
*d >= hint,
"recorded delay {d:?} must be >= the server hint {hint:?}"
);
}
});
}
// With 12 attempts, a 250 ms base and a 2 s cap the uncapped envelope would
// pass the cap after a few retries; no recorded delay may exceed it.
#[test]
fn retry_async_cap_respected_for_large_n() {
let rt = runtime();
rt.block_on(async {
let policy = RetryPolicy {
max_attempts: 12,
base: Duration::from_millis(250),
cap: Duration::from_secs(2),
};
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.999_999);
let _result: Result<(), StoreError> =
retry_async(&policy, &sleeper, &jitter, || async {
Err(transient_attempt(None))
})
.await;
for d in sleeper.recorded() {
assert!(
d <= policy.cap,
"delay {d:?} must never exceed cap {:?} (no hint)",
policy.cap
);
}
});
}
// Recomputes `min(cap, base * 2^n) * frac` independently and compares it
// with `backoff` for the first eight retries.
#[test]
fn backoff_full_jitter_lands_in_envelope() {
let policy = RetryPolicy {
max_attempts: 10,
base: Duration::from_millis(100),
cap: Duration::from_secs(30),
};
for n in 0..8u32 {
let frac = 0.37;
let delay = policy.backoff(n, None, frac);
let exp = policy
.base
.as_secs_f64()
.mul_add(2f64.powi(n as i32), 0.0)
.min(policy.cap.as_secs_f64());
let expected = exp * frac;
assert!(
(delay.as_secs_f64() - expected).abs() < 1e-9,
"n={n}: delay {delay:?} != jitter01*envelope {expected}"
);
assert!(
delay.as_secs_f64() >= 0.0 && delay.as_secs_f64() < exp + 1e-9,
"n={n}: delay {delay:?} outside [0, {exp})"
);
}
}
// `n = 2000` takes the `n >= 1024` branch of `backoff`, which uses the cap
// as the envelope.
#[test]
fn backoff_saturates_at_cap_for_huge_n() {
let policy = RetryPolicy {
max_attempts: 99,
base: Duration::from_millis(250),
cap: Duration::from_secs(30),
};
let d = policy.backoff(2000, None, 0.999_999);
assert!(
d <= policy.cap,
"huge n must saturate at the cap, got {d:?}"
);
assert!(
d.as_secs_f64() > 29.0,
"with jitter ~1 the delay should be near the cap, got {d:?}"
);
}
// 10,000 samples must fall in `[0, 1)`, and two fresh instances must not
// produce the same first four samples.
#[test]
fn default_jitter_is_in_unit_interval() {
let j = DefaultJitter::new();
for _ in 0..10_000 {
let x = j.jitter01();
assert!((0.0..1.0).contains(&x), "jitter {x} outside [0,1)");
}
let a = DefaultJitter::new();
let b = DefaultJitter::new();
let sa: Vec<f64> = (0..4).map(|_| a.jitter01()).collect();
let sb: Vec<f64> = (0..4).map(|_| b.jitter01()).collect();
assert_ne!(sa, sb, "distinct instances should diverge");
}
// Blocking counterpart of the full-budget test above.
#[test]
fn retry_blocking_persistent_transient_uses_full_budget() {
let policy = small_policy();
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.5);
let calls = AtomicUsize::new(0);
let result: Result<(), StoreError> = retry_blocking(&policy, &sleeper, &jitter, || {
calls.fetch_add(1, AtomicOrdering::SeqCst);
Err(transient_attempt(None))
});
assert!(result.is_err());
assert_eq!(
calls.load(AtomicOrdering::SeqCst),
policy.max_attempts as usize,
"blocking: op invoked max_attempts times"
);
assert_eq!(
sleeper.recorded().len(),
(policy.max_attempts - 1) as usize,
"blocking: max_attempts-1 sleeps"
);
}
// Blocking counterpart of the success-after-k-failures test above.
#[test]
fn retry_blocking_success_after_k_transient() {
let policy = small_policy();
let sleeper = RecordingSleeper::default();
let jitter = StubJitter(0.5);
let calls = AtomicUsize::new(0);
let k = 2usize;
let result: Result<&str, StoreError> = retry_blocking(&policy, &sleeper, &jitter, || {
let prev = calls.fetch_add(1, AtomicOrdering::SeqCst);
if prev < k {
Err(transient_attempt(None))
} else {
Ok("done")
}
});
assert_eq!(result.unwrap(), "done");
assert_eq!(calls.load(AtomicOrdering::SeqCst), k + 1);
assert_eq!(sleeper.recorded().len(), k, "blocking: k sleeps");
}
// Pins the values produced by `RetryPolicy::default`.
#[test]
fn default_policy_values() {
let p = RetryPolicy::default();
assert_eq!(p.max_attempts, 5);
assert_eq!(p.base, Duration::from_millis(250));
assert_eq!(p.cap, Duration::from_secs(30));
}
}