use core::time::Duration;
/// Ceiling used by `Backoff::linear`, `Backoff::exponential` and `Backoff::default`.
const DEFAULT_MAX: Duration = Duration::from_millis(30_000);
/// Starting delay used by `Backoff::default`.
const DEFAULT_BASE: Duration = Duration::from_millis(100);
/// Growth multiplier used by `Backoff::default`; stored as a placeholder by the
/// non-exponential constructors, which never read it.
const DEFAULT_FACTOR: f64 = 2.0;
/// How randomness is mixed into each delay produced by a `Backoff`.
///
/// `Jitter::default()` is `Decorrelated`. Note that the `Backoff` constructors start
/// with `None`; only `Backoff::default()` opts into decorrelated jitter.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum Jitter {
    /// No randomness: the schedule's value, clamped to the ceiling.
    None,
    /// A random value between zero and the clamped schedule value (inclusive).
    Full,
    /// Half of the clamped schedule value is kept; the other half is randomized.
    Equal,
    /// A random value between the base delay and three times the previous delay,
    /// clamped to the ceiling; the policy's growth curve is not consulted.
    #[default]
    Decorrelated,
}
/// Growth curve of the un-jittered schedule of a `Backoff`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Kind {
    /// Every attempt uses the base delay.
    Constant,
    /// Attempt `n` uses `base + n * increment` (saturating).
    Linear,
    /// Attempt `n` uses `base` scaled by a power of `factor`.
    Exponential,
}
/// A retry-delay policy: a growth curve, a ceiling and a jitter mode.
///
/// Build one with `Backoff::constant`, `Backoff::linear`, `Backoff::exponential` or
/// `Backoff::default`, adjust it with the `with_*` methods, then call `iter` /
/// `iter_seeded` to obtain the delays.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Backoff {
    /// Shape of the un-jittered schedule.
    kind: Kind,
    /// Delay for the first attempt; also the floor for decorrelated jitter.
    base: Duration,
    /// Per-attempt multiplier; only read for `Kind::Exponential`.
    factor: f64,
    /// Per-attempt addend; only read for `Kind::Linear`.
    increment: Duration,
    /// Upper bound applied to every produced delay.
    max: Duration,
    /// How randomness is applied on top of the schedule.
    jitter: Jitter,
}
impl Backoff {
#[must_use]
pub fn constant(delay: Duration) -> Self {
Self {
kind: Kind::Constant,
base: delay,
factor: DEFAULT_FACTOR,
increment: Duration::ZERO,
max: delay,
jitter: Jitter::None,
}
}
#[must_use]
pub fn linear(initial: Duration, increment: Duration) -> Self {
Self {
kind: Kind::Linear,
base: initial,
factor: DEFAULT_FACTOR,
increment,
max: DEFAULT_MAX,
jitter: Jitter::None,
}
}
#[must_use]
pub fn exponential(initial: Duration, factor: f64) -> Self {
Self {
kind: Kind::Exponential,
base: initial,
factor,
increment: Duration::ZERO,
max: DEFAULT_MAX,
jitter: Jitter::None,
}
}
#[must_use]
pub fn with_max(mut self, max: Duration) -> Self {
self.max = max;
self
}
#[must_use]
pub fn with_jitter(mut self, jitter: Jitter) -> Self {
self.jitter = jitter;
self
}
#[must_use]
pub const fn max(&self) -> Duration {
self.max
}
#[must_use]
pub const fn jitter(&self) -> Jitter {
self.jitter
}
#[must_use]
pub fn iter(&self) -> BackoffIter {
self.iter_seeded(entropy_seed())
}
#[must_use]
pub fn iter_seeded(&self, seed: u64) -> BackoffIter {
BackoffIter {
policy: *self,
attempt: 0,
previous: self.base,
rng: Rng::new(seed),
}
}
}
impl Default for Backoff {
    /// Exponential schedule starting at `DEFAULT_BASE`, growing by `DEFAULT_FACTOR`,
    /// capped at `DEFAULT_MAX`, with decorrelated jitter.
    fn default() -> Self {
        let schedule = Self::exponential(DEFAULT_BASE, DEFAULT_FACTOR);
        schedule
            .with_jitter(Jitter::Decorrelated)
            .with_max(DEFAULT_MAX)
    }
}
/// Stateful, never-ending stream of retry delays derived from a `Backoff` policy.
///
/// Obtained from `Backoff::iter` or `Backoff::iter_seeded`.
#[derive(Clone, Debug)]
pub struct BackoffIter {
    /// Copy of the policy taken when the iterator was created.
    policy: Backoff,
    /// Number of delays handed out so far (saturates at `u32::MAX`).
    attempt: u32,
    /// Last delay chosen by decorrelated jitter, raised to at least the base; starts at
    /// the base.
    previous: Duration,
    /// Seeded generator used for every jitter draw.
    rng: Rng,
}
impl BackoffIter {
    /// Returns the delay to wait before the next retry and advances the attempt counter.
    ///
    /// Every delay is clamped to the policy's `max`. Per jitter mode:
    /// - `None`: the schedule value for this attempt.
    /// - `Full`: a draw from `[0, schedule]`.
    /// - `Equal`: a draw from `[schedule / 2, schedule]`.
    /// - `Decorrelated`: a draw based on the previous delay (see `decorrelated`); the
    ///   policy's growth curve and the attempt number are not consulted.
    pub fn next_delay(&mut self) -> Duration {
        let max = dur_to_nanos(self.policy.max);
        // Deliberately exhaustive: a new `Jitter` variant must fail to compile here rather
        // than silently fall back to the un-jittered delay.
        let nanos = match self.policy.jitter {
            Jitter::None => self.capped_nanos(max),
            Jitter::Full => {
                let capped = self.capped_nanos(max);
                self.rng.range(0, capped)
            }
            Jitter::Equal => {
                let capped = self.capped_nanos(max);
                let half = capped / 2;
                // `capped - half` rounds up, so the upper bound still reaches `capped`.
                half + self.rng.range(0, capped - half)
            }
            Jitter::Decorrelated => self.decorrelated(max),
        };
        self.attempt = self.attempt.saturating_add(1);
        Duration::from_nanos(nanos)
    }

    /// Un-jittered delay for the current attempt, clamped to `max` nanoseconds.
    fn capped_nanos(&self, max: u64) -> u64 {
        self.raw_nanos(self.attempt).min(max)
    }

    /// Un-jittered, un-clamped delay in nanoseconds for `attempt` (0-based); saturates at
    /// `u64::MAX` on overflow.
    fn raw_nanos(&self, attempt: u32) -> u64 {
        let base = dur_to_nanos(self.policy.base);
        match self.policy.kind {
            Kind::Constant => base,
            Kind::Linear => {
                let inc = dur_to_nanos(self.policy.increment);
                base.saturating_add(inc.saturating_mul(u64::from(attempt)))
            }
            Kind::Exponential => {
                // Use the real attempt number. Clamping the exponent to a small constant
                // (such as 64) freezes growth for factors between 1 and 2 (1.05^64 ≈ 22.6),
                // so such schedules could plateau far below `max`. For huge exponents
                // `powi` simply yields infinity (factor > 1) or ~0 (factor < 1), and both
                // are handled below.
                let exp = i32::try_from(attempt).unwrap_or(i32::MAX);
                let scaled = (base as f64) * self.policy.factor.powi(exp);
                if scaled.is_finite() && scaled >= 0.0 && scaled < (u64::MAX as f64) {
                    scaled as u64
                } else {
                    // Overflow, NaN, or a negative product (negative `factor`) saturate;
                    // the caller then clamps to `max`.
                    u64::MAX
                }
            }
        }
    }

    /// Decorrelated jitter: draws from `[lo, hi]`, where `hi = max(3 * previous, base)`
    /// clamped to `max` and `lo = min(base, hi)`. It then stores the draw, raised to at
    /// least `base`, as the new `previous`.
    fn decorrelated(&mut self, max: u64) -> u64 {
        let base = dur_to_nanos(self.policy.base);
        let prev = dur_to_nanos(self.previous);
        let hi = prev.saturating_mul(3).max(base).min(max);
        // When `max < base`, `hi == max == lo` and the draw collapses to exactly `max`.
        let lo = base.min(hi);
        let chosen = self.rng.range(lo, hi);
        self.previous = Duration::from_nanos(chosen.max(base));
        chosen
    }
}
impl Iterator for BackoffIter {
    type Item = Duration;

    /// Always yields `Some`: the stream is infinite, so callers must bound it themselves
    /// (for example with `take`).
    fn next(&mut self) -> Option<Duration> {
        Some(self.next_delay())
    }

    /// Reports an unbounded length, as `std::iter::Repeat` does, so adapters such as
    /// `take(n)` can report an exact size to collectors.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (usize::MAX, None)
    }
}
/// Converts a `Duration` to whole nanoseconds, saturating at `u64::MAX` (about 584 years).
#[inline]
fn dur_to_nanos(d: Duration) -> u64 {
    let total = d.as_nanos();
    if total > u128::from(u64::MAX) {
        u64::MAX
    } else {
        // Cannot truncate: `total` was just checked to fit in a u64.
        total as u64
    }
}
/// Small deterministic SplitMix64 generator used for jitter; not for cryptographic use.
#[derive(Debug, Clone)]
struct Rng(u64);
impl Rng {
    /// Creates a generator whose output sequence is fully determined by `seed`.
    #[inline]
    const fn new(seed: u64) -> Self {
        Self(seed)
    }

    /// Advances the state by the SplitMix64 increment and returns the mixed output.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.0;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Returns a value in the inclusive range `[lo, hi]`.
    ///
    /// When `hi <= lo`, returns `lo` without consuming any randomness. Otherwise it
    /// consumes exactly one `next_u64` draw. It reduces with modulo, so spans that do not
    /// divide 2^64 carry a slight bias, which is accepted for jitter.
    #[inline]
    fn range(&mut self, lo: u64, hi: u64) -> u64 {
        if hi <= lo {
            return lo;
        }
        let draw = self.next_u64();
        match (hi - lo).checked_add(1) {
            Some(span) => lo + draw % span,
            // Only when `lo == 0 && hi == u64::MAX`: the inclusive span holds 2^64 values,
            // which does not fit in a u64 (`hi - lo + 1` would overflow and then divide by
            // zero). Every draw is already in range.
            None => draw,
        }
    }
}
/// Produces a seed from the wall clock mixed with a process-wide counter, finished with
/// the SplitMix64 output mixer.
///
/// The counter keeps seeds requested within the same clock reading distinct. A clock
/// set before the Unix epoch contributes 0.
fn entropy_seed() -> u64 {
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let ticket = COUNTER.fetch_add(1, Ordering::Relaxed);
    let clock = match SystemTime::now().duration_since(UNIX_EPOCH) {
        // Keeping only the low 64 bits is intended: they are the fast-changing ones.
        Ok(since_epoch) => since_epoch.as_nanos() as u64,
        Err(_) => 0,
    };

    let mut x = clock ^ ticket.wrapping_mul(0x9E37_79B9_7F4A_7C15);
    x ^= x >> 30;
    x = x.wrapping_mul(0xBF58_476D_1CE4_E5B9);
    x ^= x >> 27;
    x = x.wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^ (x >> 31)
}
#[cfg(test)]
mod tests {
    use super::{Backoff, BackoffIter, Jitter, Rng};
    use core::time::Duration;

    /// Shorthand for a millisecond `Duration`.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Draws the next `count` delays from `it`, in order.
    fn take_delays(it: &mut BackoffIter, count: usize) -> Vec<Duration> {
        (0..count).map(|_| it.next_delay()).collect()
    }

    #[test]
    fn test_constant_is_flat() {
        let mut it = Backoff::constant(ms(50)).iter();
        assert_eq!(take_delays(&mut it, 5), vec![ms(50); 5]);
    }

    #[test]
    fn test_linear_grows_by_increment_and_caps() {
        let mut it = Backoff::linear(ms(100), ms(100)).with_max(ms(250)).iter();
        assert_eq!(
            take_delays(&mut it, 4),
            [ms(100), ms(200), ms(250), ms(250)]
        );
    }

    #[test]
    fn test_exponential_doubles_and_caps() {
        let mut it = Backoff::exponential(ms(100), 2.0).with_max(ms(500)).iter();
        assert_eq!(
            take_delays(&mut it, 4),
            [ms(100), ms(200), ms(400), ms(500)]
        );
    }

    #[test]
    fn test_full_jitter_stays_within_zero_and_cap() {
        let cap = Duration::from_secs(10);
        let mut it = Backoff::exponential(ms(100), 2.0)
            .with_max(cap)
            .with_jitter(Jitter::Full)
            .iter_seeded(1);
        for (attempt, d) in (0..6u32).zip(take_delays(&mut it, 6)) {
            let ceiling = ms(100 * 2u64.pow(attempt)).min(cap);
            assert!(d <= ceiling, "{d:?} exceeded {ceiling:?}");
        }
    }

    #[test]
    fn test_equal_jitter_keeps_a_floor() {
        let mut it = Backoff::constant(ms(1000))
            .with_jitter(Jitter::Equal)
            .iter_seeded(42);
        for d in take_delays(&mut it, 20) {
            assert!(d >= ms(500), "{d:?} below the floor");
            assert!(d <= ms(1000), "{d:?} above the cap");
        }
    }

    #[test]
    fn test_decorrelated_respects_base_and_cap() {
        let cap = Duration::from_secs(2);
        let mut it = Backoff::exponential(ms(100), 2.0)
            .with_max(cap)
            .with_jitter(Jitter::Decorrelated)
            .iter_seeded(99);
        for d in take_delays(&mut it, 50) {
            assert!(d >= ms(100), "{d:?} below base");
            assert!(d <= cap, "{d:?} above cap");
        }
    }

    #[test]
    fn test_seeded_sequences_are_reproducible() {
        let backoff = Backoff::default();
        let run = |seed: u64| take_delays(&mut backoff.iter_seeded(seed), 8);
        assert_eq!(run(123), run(123));
    }

    #[test]
    fn test_default_is_decorrelated_exponential() {
        let backoff = Backoff::default();
        assert_eq!(backoff.jitter(), Jitter::Decorrelated);
        assert_eq!(backoff.max(), Duration::from_secs(30));
    }

    #[test]
    fn test_rng_range_is_within_bounds_and_handles_degenerate() {
        let mut rng = Rng::new(7);
        for v in (0..1000).map(|_| rng.range(10, 20)) {
            assert!((10..=20).contains(&v));
        }
        assert_eq!(rng.range(5, 5), 5);
        assert_eq!(rng.range(9, 4), 9);
    }
}