use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::time::sleep;
/// Decides whether a failed attempt should be retried.
///
/// Returning `false` makes [`Retry::run`] give up immediately and return that error.
pub type RetryPredicate<E> = Arc<dyn Fn(&E) -> bool + Send + Sync>;

/// Policy controlling how [`Retry::run`] re-invokes a failing operation.
pub struct RetryConfig<E = std::io::Error> {
    /// Total number of calls to the operation, including the first one.
    /// A value of `0` still allows a single call.
    pub max_attempts: u32,
    /// Backoff bound before the first retry; it doubles for each further retry.
    pub base_delay: Duration,
    /// Ceiling applied to the (doubled) backoff bound.
    pub max_delay: Duration,
    /// Optional filter; `None` means every error is retried.
    pub retry_if: Option<RetryPredicate<E>>,
}

// Written by hand: `#[derive(Clone)]` would require `E: Clone`, even though only an
// `Arc` closure over `&E` is stored. That excluded the default `E = std::io::Error`,
// which is not `Clone`.
impl<E> Clone for RetryConfig<E> {
    fn clone(&self) -> Self {
        Self {
            max_attempts: self.max_attempts,
            base_delay: self.base_delay,
            max_delay: self.max_delay,
            retry_if: self.retry_if.clone(),
        }
    }
}
impl<E> Default for RetryConfig<E> {
fn default() -> Self {
Self {
max_attempts: 3,
base_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(5),
retry_if: None,
}
}
}
impl<E> std::fmt::Debug for RetryConfig<E> {
    /// Prints every field. The predicate closure is not `Debug`, so it is shown as
    /// `Some("<predicate>")` when present and `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let predicate: Option<&str> = match self.retry_if {
            Some(_) => Some("<predicate>"),
            None => None,
        };
        let mut out = f.debug_struct("RetryConfig");
        out.field("max_attempts", &self.max_attempts);
        out.field("base_delay", &self.base_delay);
        out.field("max_delay", &self.max_delay);
        out.field("retry_if", &predicate);
        out.finish()
    }
}
#[derive(Clone, Debug)]
pub struct Retry<E = std::io::Error> {
cfg: RetryConfig<E>,
seed: Arc<AtomicU64>,
}
impl<E> Retry<E> {
    /// Creates a retrier using `cfg`.
    ///
    /// The jitter generator is seeded from the wall clock (nanoseconds since the Unix
    /// epoch). It falls back to a fixed constant if the clock reads earlier than the epoch.
    pub fn new(cfg: RetryConfig<E>) -> Self {
        let seed = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            // Truncation is intended: only the low bits matter for a seed.
            .map_or(0xdead_beef, |d| d.as_nanos() as u64);
        Self {
            cfg,
            seed: Arc::new(AtomicU64::new(seed.wrapping_add(1))),
        }
    }

    /// Calls `make_fut` and awaits the result, retrying on error.
    ///
    /// Each attempt calls `make_fut` again to obtain a fresh future. Between attempts
    /// the task sleeps for a jittered exponential backoff (see `backoff`) using
    /// `tokio::time::sleep`. That presumably requires a Tokio runtime with timers
    /// enabled; confirm against the tokio docs.
    ///
    /// # Errors
    ///
    /// Returns the error from the final attempt when `max_attempts` is exhausted,
    /// or as soon as `retry_if` returns `false` for an error. A `max_attempts` of
    /// `0` behaves like `1`: the operation always runs at least once.
    pub async fn run<F, Fut, T>(&self, mut make_fut: F) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
    {
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            let err = match make_fut().await {
                Ok(v) => return Ok(v),
                Err(e) => e,
            };
            // `attempt` never exceeds `max_attempts` (<= u32::MAX), so the increment can't overflow.
            if attempt >= self.cfg.max_attempts {
                return Err(err);
            }
            if let Some(pred) = &self.cfg.retry_if {
                if !pred(&err) {
                    return Err(err);
                }
            }
            sleep(self.backoff(attempt)).await;
        }
    }

    /// Computes the "full jitter" delay before the retry that follows attempt `attempt`.
    ///
    /// `attempt` is 1-based. The bound is `base_delay * 2^(attempt - 1)`, with the
    /// exponent clamped to 30 and saturating multiplication. The bound is then capped
    /// at `max_delay`. The result is drawn from `[0, bound]`, with slight modulo bias.
    fn backoff(&self, attempt: u32) -> Duration {
        let exp = attempt.saturating_sub(1).min(30);
        let raw = self.cfg.base_delay.saturating_mul(1u32 << exp);
        let capped = raw.min(self.cfg.max_delay);
        // Nanosecond resolution, so sub-millisecond caps don't collapse to a zero sleep.
        let max_ns = capped.as_nanos().min(u128::from(u64::MAX)) as u64;
        let r = self.next_rand();
        let jitter_ns = match max_ns.checked_add(1) {
            Some(span) => r % span,
            // max_ns == u64::MAX: every u64 is already within [0, max_ns]; the old
            // `max + 1` here overflowed (debug panic / release modulo-by-zero).
            None => r,
        };
        Duration::from_nanos(jitter_ns)
    }

    /// Returns the next pseudo-random value from an xorshift64* generator in `self.seed`.
    fn next_rand(&self) -> u64 {
        /// One xorshift64 state transition. Zero is a fixed point, so it is remapped first.
        fn step(mut x: u64) -> u64 {
            if x == 0 {
                x = 0xdead_beef;
            }
            x ^= x >> 12;
            x ^= x << 25;
            x ^= x >> 27;
            x
        }
        // `fetch_update` is an atomic read-modify-write loop. Concurrent callers, which
        // can share the seed via cloned `Retry`s, therefore each advance the state and
        // never reuse it. A separate load/store could hand two callers the same jitter.
        // Relaxed is enough: no other memory is synchronized through this atomic.
        let prev = self
            .seed
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(step(x)))
            // The closure always returns `Some`, so this is always `Ok`.
            .unwrap_or_else(|x| x);
        step(prev).wrapping_mul(0x2545_F491_4F6C_DD1D)
    }
}