use std::time::Duration;
/// How a [`RetryScheduler`] randomizes a deterministic backoff delay `d`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JitterStrategy {
    /// No randomness: the delay is exactly `d`.
    None,
    /// A random delay drawn from `[0, d]`.
    Full,
    /// Half of `d` plus a random share of the other half, i.e. `[d/2, d]`.
    Equal,
}
/// Configuration for exponential backoff between retry attempts.
///
/// The un-jittered delay before attempt `k >= 1` is
/// `base_delay * multiplier^(k - 1)`, capped at `max_delay`.
#[derive(Clone, Debug)]
pub struct BackoffPolicy {
    /// Delay before attempt 1 (the first retry).
    pub base_delay: Duration,
    /// Factor applied to the delay for each subsequent attempt.
    pub multiplier: f64,
    /// Upper bound on any single computed delay.
    pub max_delay: Duration,
    /// Highest attempt number that is allowed, or `None` for no limit.
    pub max_attempts: Option<u32>,
    /// How randomness is applied on top of the computed delay.
    pub jitter: JitterStrategy,
}
impl Default for BackoffPolicy {
    /// Returns the stock policy: a 500 ms base delay that doubles per attempt,
    /// capped at 30 s, allowing at most 5 attempts, with full jitter.
    fn default() -> Self {
        let base_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        Self {
            jitter: JitterStrategy::Full,
            max_attempts: Some(5),
            max_delay,
            multiplier: 2.0,
            base_delay,
        }
    }
}
impl BackoffPolicy {
    /// Creates a policy with the given base delay and growth multiplier.
    ///
    /// Every other setting is taken from [`BackoffPolicy::default`].
    #[must_use]
    pub fn new(base_delay: Duration, multiplier: f64) -> Self {
        Self {
            base_delay,
            multiplier,
            ..Self::default()
        }
    }

    /// Sets the upper bound applied to every computed delay.
    #[must_use]
    pub fn with_max_delay(mut self, max: Duration) -> Self {
        self.max_delay = max;
        self
    }

    /// Limits the policy to attempt numbers `<= n`.
    ///
    /// Beyond that, [`BackoffPolicy::deterministic_delay`] returns `None` and
    /// [`BackoffPolicy::should_retry`] returns `false`.
    #[must_use]
    pub fn with_max_attempts(mut self, n: u32) -> Self {
        self.max_attempts = Some(n);
        self
    }

    /// Removes the attempt limit, so every attempt number is allowed.
    #[must_use]
    pub fn unlimited(mut self) -> Self {
        self.max_attempts = None;
        self
    }

    /// Sets the jitter strategy.
    ///
    /// [`BackoffPolicy::deterministic_delay`] does not read this field; it is
    /// applied by [`RetryScheduler::delay_for_attempt`].
    #[must_use]
    pub fn with_jitter(mut self, jitter: JitterStrategy) -> Self {
        self.jitter = jitter;
        self
    }

    /// Returns the un-jittered delay for attempt number `attempt`.
    ///
    /// * Returns `None` when `attempt` exceeds `max_attempts`.
    /// * Attempt `0` yields [`Duration::ZERO`].
    /// * Attempt `k >= 1` yields `base_delay * multiplier^(k - 1)`, capped at
    ///   `max_delay` and rounded to whole microseconds. Inputs are read at
    ///   microsecond resolution. A negative product (negative multiplier)
    ///   becomes zero.
    #[must_use]
    pub fn deterministic_delay(&self, attempt: u32) -> Option<Duration> {
        if !self.should_retry(attempt) {
            return None;
        }
        if attempt == 0 {
            return Some(Duration::ZERO);
        }
        let base_micros = self.base_delay.as_micros();
        // A zero (or sub-microsecond) base always yields zero. This must be
        // handled up front: once `multiplier^exp` overflows to infinity,
        // `0.0 * inf` is NaN, and `f64::min` would then return the cap instead.
        if base_micros == 0 {
            return Some(Duration::ZERO);
        }
        // Clamp the exponent rather than casting with `as`. A plain cast wraps
        // attempt numbers above 2^31 into negative exponents, which would make
        // the delay shrink instead of staying at the cap.
        let exp = (attempt - 1).min(i32::MAX as u32) as i32;
        let scale = self.multiplier.powi(exp);
        let delay_us = (base_micros as f64 * scale).min(self.max_delay.as_micros() as f64);
        Some(Duration::from_micros(delay_us.round() as u64))
    }

    /// Returns whether attempt number `attempt` is within `max_attempts`
    /// (inclusive). Always `true` for an unlimited policy.
    #[must_use]
    pub fn should_retry(&self, attempt: u32) -> bool {
        self.max_attempts.map_or(true, |max| attempt <= max)
    }
}
/// Bookkeeping for a sequence of retries governed by a [`BackoffPolicy`].
#[derive(Clone, Debug)]
pub struct RetryState {
    /// Number of failures recorded since creation or the last reset.
    pub attempt: u32,
    /// Sum of the deterministic delays returned for recorded failures.
    pub total_delay: Duration,
    /// Policy that determines delays and the attempt limit.
    pub policy: BackoffPolicy,
}
impl RetryState {
    /// Creates a fresh state (no failures recorded) that follows `policy`.
    #[must_use]
    pub fn new(policy: BackoffPolicy) -> Self {
        Self {
            attempt: 0,
            total_delay: Duration::ZERO,
            policy,
        }
    }

    /// Creates a fresh state using [`BackoffPolicy::default`].
    #[must_use]
    pub fn with_default_policy() -> Self {
        Self::new(BackoffPolicy::default())
    }

    /// Returns whether the policy allows the attempt after the current one.
    #[must_use]
    pub fn can_retry(&self) -> bool {
        self.policy.should_retry(self.next_attempt())
    }

    /// Returns the un-jittered delay for the attempt after the current one,
    /// or `None` if the policy's limit would be exceeded.
    #[must_use]
    pub fn next_delay_deterministic(&self) -> Option<Duration> {
        self.policy.deterministic_delay(self.next_attempt())
    }

    /// Records a failure and returns the un-jittered delay for the new
    /// attempt number.
    ///
    /// The counter advances even when the policy is exhausted. In that case
    /// `None` is returned and `total_delay` is left unchanged. The counter and
    /// the accumulated delay saturate instead of overflowing.
    pub fn record_failure(&mut self) -> Option<Duration> {
        self.attempt = self.next_attempt();
        let delay = self.policy.deterministic_delay(self.attempt)?;
        self.total_delay = self.total_delay.saturating_add(delay);
        Some(delay)
    }

    /// Clears the attempt counter and accumulated delay; the policy is kept.
    pub fn reset(&mut self) {
        self.attempt = 0;
        self.total_delay = Duration::ZERO;
    }

    /// Returns the attempts left under the policy's limit (never below 0),
    /// or `None` for an unlimited policy.
    #[must_use]
    pub fn remaining_attempts(&self) -> Option<u32> {
        self.policy
            .max_attempts
            .map(|max| max.saturating_sub(self.attempt))
    }

    /// Returns the attempt number after the current one, saturating at
    /// `u32::MAX`. Plain `+ 1` would panic in debug builds and wrap to 0 in
    /// release builds.
    fn next_attempt(&self) -> u32 {
        self.attempt.saturating_add(1)
    }
}
/// Computes retry delays for a [`BackoffPolicy`], applying the policy's
/// jitter with a small internal pseudo-random generator.
#[derive(Debug)]
pub struct RetryScheduler {
    /// Policy supplying base delays, limits and the jitter strategy.
    policy: BackoffPolicy,
    /// State of the 64-bit linear congruential generator used for jitter.
    rng_state: u64,
}
impl RetryScheduler {
    /// Creates a scheduler for `policy` whose jitter sequence is fully
    /// determined by `seed`.
    ///
    /// The low bit of `seed` is forced on, so seeds `2k` and `2k + 1` produce
    /// the same sequence.
    #[must_use]
    pub fn new(policy: BackoffPolicy, seed: u64) -> Self {
        let rng_state = seed | 1;
        Self { policy, rng_state }
    }

    /// Creates a scheduler with [`BackoffPolicy::default`].
    ///
    /// The seed is the current wall-clock time in nanoseconds since the Unix
    /// epoch, truncated to 64 bits. A fixed fallback seed is used if the clock
    /// reads earlier than the epoch.
    #[must_use]
    pub fn with_default_policy() -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};
        const FALLBACK_SEED: u64 = 0xdead_beef_cafe_babe;
        let seed = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_nanos() as u64,
            Err(_) => FALLBACK_SEED,
        };
        Self::new(BackoffPolicy::default(), seed)
    }

    /// Advances the internal LCG and returns a value in `[0, 1)` built from
    /// the top 53 bits of the new state.
    fn next_rand(&mut self) -> f64 {
        const LCG_MUL: u64 = 6_364_136_223_846_793_005;
        const LCG_INC: u64 = 1_442_695_040_888_963_407;
        // 2^53: one more than the largest 53-bit value.
        const TWO_POW_53: f64 = 9_007_199_254_740_992.0;
        let next = LCG_MUL.wrapping_mul(self.rng_state).wrapping_add(LCG_INC);
        self.rng_state = next;
        // Shifting out 11 bits keeps the 53 high-order bits, which fit an f64
        // mantissa exactly.
        (next >> 11) as f64 / TWO_POW_53
    }

    /// Returns the delay for attempt number `attempt` with the policy's jitter
    /// applied, or `None` once the policy's attempt limit is exceeded.
    ///
    /// With `d` the deterministic delay:
    /// * [`JitterStrategy::None`] returns `d` unchanged.
    /// * `Full` picks from `[0, d]`.
    /// * `Equal` picks from `[d/2, d]`.
    ///
    /// Jittered values are rounded to whole microseconds and consume one
    /// random draw per call.
    pub fn delay_for_attempt(&mut self, attempt: u32) -> Option<Duration> {
        let det = self.policy.deterministic_delay(attempt)?;
        let det_us = det.as_micros() as f64;
        let jittered_us = match self.policy.jitter {
            JitterStrategy::None => return Some(det),
            JitterStrategy::Full => det_us * self.next_rand(),
            JitterStrategy::Equal => {
                let half = det_us / 2.0;
                half + self.next_rand() * half
            }
        };
        Some(Duration::from_micros(jittered_us.round() as u64))
    }

    /// Delegates to [`BackoffPolicy::should_retry`] on this scheduler's policy.
    #[must_use]
    pub fn should_retry(&self, attempt: u32) -> bool {
        self.policy().should_retry(attempt)
    }

    /// Borrows the policy this scheduler was built with.
    #[must_use]
    pub fn policy(&self) -> &BackoffPolicy {
        &self.policy
    }
}
/// Sums the deterministic (un-jittered) delays for attempts `1..=n`.
///
/// When the policy has an attempt limit, `n` is clamped to `max_attempts`.
/// `n == 0` yields [`Duration::ZERO`].
///
/// The sum saturates at [`Duration::MAX`] instead of panicking on overflow.
/// Overflow is reachable with an unlimited policy, a very large `max_delay`
/// and a large `n`.
#[must_use]
pub fn total_delay_estimate(policy: &BackoffPolicy, n: u32) -> Duration {
    let limit = policy.max_attempts.map_or(n, |max| n.min(max));
    (1..=limit)
        .filter_map(|attempt| policy.deterministic_delay(attempt))
        .fold(Duration::ZERO, Duration::saturating_add)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a policy with every knob set explicitly.
    fn build_policy(
        base: Duration,
        multiplier: f64,
        max_delay: Duration,
        max_attempts: u32,
        jitter: JitterStrategy,
    ) -> BackoffPolicy {
        BackoffPolicy::new(base, multiplier)
            .with_max_delay(max_delay)
            .with_max_attempts(max_attempts)
            .with_jitter(jitter)
    }

    #[test]
    fn test_default_policy_parameters() {
        let BackoffPolicy {
            base_delay,
            multiplier,
            max_delay,
            max_attempts,
            jitter,
        } = BackoffPolicy::default();
        assert_eq!(base_delay, Duration::from_millis(500));
        assert!((multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(max_delay, Duration::from_secs(30));
        assert_eq!(max_attempts, Some(5));
        assert_eq!(jitter, JitterStrategy::Full);
    }

    #[test]
    fn test_deterministic_delay_first_attempt() {
        let first = BackoffPolicy::new(Duration::from_secs(1), 2.0)
            .deterministic_delay(1)
            .expect("attempt 1 is always valid");
        assert_eq!(first, Duration::from_secs(1));
    }

    #[test]
    fn test_deterministic_delay_doubles_each_attempt() {
        let policy = build_policy(
            Duration::from_secs(1),
            2.0,
            Duration::from_secs(1000),
            10,
            JitterStrategy::Full,
        );
        let delays: Vec<Duration> = (1..=3)
            .map(|a| {
                policy
                    .deterministic_delay(a)
                    .unwrap_or_else(|| panic!("attempt {a} valid"))
            })
            .collect();
        assert_eq!(delays[1], delays[0] * 2);
        assert_eq!(delays[2], delays[0] * 4);
    }

    #[test]
    fn test_deterministic_delay_capped_at_max() {
        let policy = build_policy(
            Duration::from_secs(1),
            4.0,
            Duration::from_secs(5),
            20,
            JitterStrategy::Full,
        );
        let capped = policy.deterministic_delay(5).expect("attempt 5 valid");
        assert_eq!(capped, Duration::from_secs(5));
    }

    #[test]
    fn test_deterministic_delay_none_beyond_max_attempts() {
        assert_eq!(BackoffPolicy::default().deterministic_delay(6), None);
    }

    #[test]
    fn test_should_retry_within_limit() {
        let policy = BackoffPolicy::default();
        for allowed in [1, 5] {
            assert!(policy.should_retry(allowed));
        }
        assert!(!policy.should_retry(6));
    }

    #[test]
    fn test_unlimited_policy_always_retries() {
        let policy = BackoffPolicy::default().unlimited();
        let (retry, delay) = (policy.should_retry(1000), policy.deterministic_delay(1000));
        assert!(retry);
        assert!(delay.is_some());
    }

    #[test]
    fn test_retry_state_initial() {
        let state = RetryState::with_default_policy();
        assert_eq!((state.attempt, state.total_delay), (0, Duration::ZERO));
        assert!(state.can_retry());
    }

    #[test]
    fn test_retry_state_record_failure_increments_attempt() {
        let mut state = RetryState::with_default_policy();
        state.record_failure().expect("first failure allowed");
        assert_eq!(state.attempt, 1);
    }

    #[test]
    fn test_retry_state_record_failure_accumulates_delay() {
        let mut state = RetryState::new(build_policy(
            Duration::from_millis(100),
            2.0,
            Duration::from_secs(60),
            5,
            JitterStrategy::None,
        ));
        let first = state.record_failure().expect("attempt 1");
        let second = state.record_failure().expect("attempt 2");
        assert_eq!(first, Duration::from_millis(100));
        assert_eq!(second, Duration::from_millis(200));
        assert_eq!(state.total_delay, Duration::from_millis(300));
    }

    #[test]
    fn test_retry_state_none_after_exhaustion() {
        let mut state = RetryState::new(BackoffPolicy::default().with_max_attempts(2));
        for label in ["attempt 1", "attempt 2"] {
            state.record_failure().expect(label);
        }
        assert!(state.record_failure().is_none());
    }

    #[test]
    fn test_retry_state_reset() {
        let mut state = RetryState::with_default_policy();
        state.record_failure().expect("attempt 1");
        state.reset();
        assert_eq!(state.attempt, 0);
        assert_eq!(state.total_delay, Duration::ZERO);
    }

    #[test]
    fn test_retry_state_remaining_attempts() {
        let mut state = RetryState::new(BackoffPolicy::default().with_max_attempts(3));
        assert_eq!(state.remaining_attempts(), Some(3));
        state.record_failure().expect("attempt 1");
        assert_eq!(state.remaining_attempts(), Some(2));
    }

    #[test]
    fn test_scheduler_no_jitter_deterministic() {
        let mut sched = RetryScheduler::new(
            build_policy(
                Duration::from_millis(200),
                2.0,
                Duration::from_secs(60),
                5,
                JitterStrategy::None,
            ),
            42,
        );
        let d1 = sched.delay_for_attempt(1).expect("attempt 1");
        let d2 = sched.delay_for_attempt(2).expect("attempt 2");
        assert_eq!(d1, Duration::from_millis(200));
        assert_eq!(d2, Duration::from_millis(400));
    }

    #[test]
    fn test_scheduler_full_jitter_within_bounds() {
        let mut sched = RetryScheduler::new(
            build_policy(
                Duration::from_secs(1),
                2.0,
                Duration::from_secs(60),
                10,
                JitterStrategy::Full,
            ),
            0xfeed_beef,
        );
        for attempt in 1..=5 {
            let d = sched
                .delay_for_attempt(attempt)
                .expect("should succeed in test");
            // The jitter strategy does not affect the deterministic delay, so
            // the scheduler's own policy gives the upper bound.
            let det = sched
                .policy()
                .deterministic_delay(attempt)
                .expect("should succeed in test");
            assert!(d <= det, "jittered delay {d:?} must be ≤ det {det:?}");
        }
    }

    #[test]
    fn test_scheduler_returns_none_beyond_limit() {
        let mut sched = RetryScheduler::new(BackoffPolicy::default().with_max_attempts(3), 1);
        assert_eq!(sched.delay_for_attempt(4), None);
    }

    #[test]
    fn test_total_delay_estimate_no_jitter() {
        let policy = build_policy(
            Duration::from_secs(1),
            2.0,
            Duration::from_secs(1000),
            4,
            JitterStrategy::None,
        );
        // 1 + 2 + 4 + 8 seconds.
        assert_eq!(total_delay_estimate(&policy, 4), Duration::from_secs(15));
    }

    #[test]
    fn test_total_delay_capped_by_max_attempts() {
        let policy = build_policy(
            Duration::from_secs(1),
            2.0,
            Duration::from_secs(1000),
            2,
            JitterStrategy::None,
        );
        // Only attempts 1 and 2 count: 1 + 2 seconds.
        assert_eq!(total_delay_estimate(&policy, 10), Duration::from_secs(3));
    }

    #[test]
    fn test_equal_jitter_midpoint_range() {
        let mut sched = RetryScheduler::new(
            build_policy(
                Duration::from_secs(2),
                1.0,
                Duration::from_secs(60),
                10,
                JitterStrategy::Equal,
            ),
            12345,
        );
        let (low, high) = (Duration::from_secs(1), Duration::from_secs(2));
        for _ in 0..20 {
            let d = sched
                .delay_for_attempt(1)
                .expect("should succeed in test");
            assert!(d >= low, "equal-jitter low bound");
            assert!(d <= high, "equal-jitter high bound");
        }
    }
}