use crate::error::FaucetError;
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
/// Upper bound on the exponential (pre-jitter) component of a retry delay: 60 seconds.
///
/// The clamp is applied *before* the random jitter factor in `[0.5, 1.5)` is
/// multiplied in, so an actual sleep chosen by `backoff_with_jitter` can reach
/// up to 1.5x this value.
const MAX_BACKOFF: Duration = Duration::from_millis(60_000);
/// Runs `operation`, retrying retriable failures with jittered exponential backoff.
///
/// `operation` is invoked once, then up to `max_retries` more times for as long
/// as it keeps failing with an error whose `is_retriable()` returns `true`.
/// Before retry number `k` (0-based) the task sleeps for
/// `backoff_with_jitter(base_backoff, k)` using `tokio::time::sleep`, and a
/// warning is logged via `tracing`.
///
/// # Errors
///
/// Returns the first non-retriable error immediately. If every attempt fails
/// with a retriable error, returns the error from the final attempt once the
/// `max_retries` retries have been used up.
pub async fn execute_with_retry<F, Fut, T>(
    max_retries: u32,
    base_backoff: Duration,
    mut operation: F,
) -> Result<T, FaucetError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, FaucetError>>,
{
    // Total number of attempts, used only for the log message. Saturate so a
    // caller passing `u32::MAX` neither panics (debug) nor logs "/0" (release).
    let total_attempts = max_retries.saturating_add(1);
    let mut attempt = 0u32;
    loop {
        match operation().await {
            Ok(val) => return Ok(val),
            // `attempt < max_retries` guarantees `attempt + 1` below cannot overflow.
            Err(e) if e.is_retriable() && attempt < max_retries => {
                let wait = backoff_with_jitter(base_backoff, attempt);
                tracing::warn!(
                    "request failed (attempt {}/{}), retrying in {wait:?}: {e}",
                    attempt + 1,
                    total_attempts
                );
                tokio::time::sleep(wait).await;
                attempt += 1;
            }
            Err(e) => return Err(e),
        }
    }
}
/// Computes the sleep before retry number `attempt` (0-based).
///
/// The un-jittered delay is `base * 2^attempt`. Both the power and the
/// multiplication saturate instead of overflowing, and the result is clamped
/// to `MAX_BACKOFF`. That clamped value is then scaled by
/// `pseudo_random_factor()` (in `[0.5, 1.5)`), so the returned delay may
/// exceed `MAX_BACKOFF` by up to 50%.
pub fn backoff_with_jitter(base: Duration, attempt: u32) -> Duration {
    let growth = 2u32.saturating_pow(attempt);
    let clamped = MAX_BACKOFF.min(base.saturating_mul(growth));
    // `clamped <= MAX_BACKOFF` (60s), far below u64::MAX nanoseconds, so this
    // u128 -> u64 narrowing cannot truncate.
    let whole_nanos = clamped.as_nanos() as u64;
    let jittered_nanos = whole_nanos as f64 * pseudo_random_factor();
    Duration::from_nanos(jittered_nanos as u64)
}
/// Returns a non-cryptographic jitter multiplier in `[0.5, 1.5)`.
///
/// Entropy comes from the sub-second part of the wall clock. It is mixed with a
/// process-wide call counter so that calls observing the same clock reading
/// (for example, concurrent callers) still receive different factors.
fn pseudo_random_factor() -> f64 {
    static CALL_SEQUENCE: AtomicU64 = AtomicU64::new(0);

    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    // A clock reading before the Unix epoch contributes zero sub-second nanos.
    let clock_nanos = match since_epoch {
        Ok(elapsed) => elapsed.subsec_nanos(),
        Err(_) => 0,
    };
    // Only uniqueness of the ticket matters, so relaxed ordering is enough.
    let ticket = CALL_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    jitter_factor(decorrelate(clock_nanos, ticket))
}
/// Mixes a clock reading with a call counter into a value in `[0, 1_000_000_000)`.
///
/// The counter is spread with the golden-ratio constant and XORed into the
/// clock nanos. The result is passed through the splitmix64 finalizer: two
/// xor-shift/multiply rounds followed by a final xor-shift. This makes equal
/// `nanos` with different `counter` values diverge widely. The reduction
/// modulo 1e9 keeps the output in the input range `jitter_factor` expects.
fn decorrelate(nanos: u32, counter: u64) -> u32 {
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    // (right-shift, multiplier) for each xor-shift/multiply round, in order.
    const MIX_ROUNDS: [(u32, u64); 2] = [
        (30, 0xBF58_476D_1CE4_E5B9),
        (27, 0x94D0_49BB_1331_11EB),
    ];

    let seed = u64::from(nanos) ^ counter.wrapping_mul(GOLDEN_GAMMA);
    let mixed = MIX_ROUNDS
        .iter()
        .fold(seed, |acc, &(shift, multiplier)| {
            (acc ^ (acc >> shift)).wrapping_mul(multiplier)
        });
    let finalized = mixed ^ (mixed >> 31);
    (finalized % 1_000_000_000) as u32
}
/// Maps a value in `[0, 1_000_000_000)` linearly onto a multiplier in `[0.5, 1.5)`.
///
/// `0` maps to exactly `0.5` and `500_000_000` maps to exactly `1.0`.
/// Inputs of `1_000_000_000` or more would exceed `1.5`; callers pass the
/// output of `decorrelate`, which stays below that.
fn jitter_factor(nanos: u32) -> f64 {
    const NANOS_PER_SEC: f64 = 1_000_000_000.0;
    let fraction_of_second = f64::from(nanos) / NANOS_PER_SEC;
    fraction_of_second + 0.5
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A failure these tests rely on being treated as retriable (HTTP 503).
    fn transient_503() -> FaucetError {
        FaucetError::HttpStatus {
            status: 503,
            url: "http://t".into(),
            body: "x".into(),
        }
    }

    #[tokio::test]
    async fn returns_immediately_on_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = calls.clone();
        let r = execute_with_retry(3, Duration::from_millis(1), move || {
            c.fetch_add(1, Ordering::SeqCst);
            async { Ok::<_, FaucetError>(7) }
        })
        .await;
        assert_eq!(r.unwrap(), 7);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_then_succeeds_on_transient_5xx() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = calls.clone();
        let r = execute_with_retry(3, Duration::from_millis(1), move || {
            let n = c.fetch_add(1, Ordering::SeqCst);
            async move {
                if n < 2 {
                    Err::<i32, _>(transient_503())
                } else {
                    Ok(42)
                }
            }
        })
        .await;
        assert_eq!(r.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn gives_up_after_max_retries_on_persistent_transient_error() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = calls.clone();
        let r = execute_with_retry(2, Duration::from_millis(1), move || {
            c.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, _>(transient_503()) }
        })
        .await;
        // Once the retry budget is spent, the last retriable error is surfaced.
        assert!(matches!(
            r,
            Err(FaucetError::HttpStatus { status: 503, .. })
        ));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn zero_max_retries_makes_a_single_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = calls.clone();
        let r = execute_with_retry(0, Duration::from_millis(1), move || {
            c.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, _>(transient_503()) }
        })
        .await;
        assert!(r.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn jitter_factor_spans_documented_half_to_one_and_a_half_range() {
        assert_eq!(jitter_factor(0), 0.5);
        let mid = jitter_factor(500_000_000);
        assert!((mid - 1.0).abs() < 1e-6, "midpoint factor was {mid}");
        let hi = jitter_factor(999_999_999);
        assert!(
            (1.4..1.5).contains(&hi),
            "factor at max sub-second nanos was {hi}, expected ~1.5"
        );
    }

    #[test]
    fn backoff_is_capped_for_large_attempt() {
        let d = backoff_with_jitter(Duration::from_secs(1), 60);
        // MAX_BACKOFF (60s) is applied before the [0.5, 1.5) jitter factor.
        assert!(d < Duration::from_secs(90), "backoff not capped: {d:?}");
        assert!(
            d >= Duration::from_secs(30),
            "backoff unexpectedly tiny: {d:?}"
        );
    }

    #[test]
    fn decorrelate_diverges_for_same_nanos_concurrent_calls() {
        let a = decorrelate(123_456_789, 0);
        let b = decorrelate(123_456_789, 1);
        let c = decorrelate(123_456_789, 2);
        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
        for v in [a, b, c] {
            assert!(
                v < 1_000_000_000,
                "decorrelate out of jitter_factor range: {v}"
            );
        }
    }

    #[tokio::test]
    async fn non_retriable_fails_immediately() {
        let calls = Arc::new(AtomicU32::new(0));
        let c = calls.clone();
        let r = execute_with_retry(3, Duration::from_millis(1), move || {
            c.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, _>(FaucetError::Auth("nope".into())) }
        })
        .await;
        assert!(r.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}