use rand::RngExt;
use std::time::Duration;
/// Configuration for exponential retry backoff with optional random jitter
/// and an optional consecutive-failure trip threshold.
#[derive(Debug, Clone)]
pub struct BackoffPolicy {
    /// Delay produced for the first failure of a streak.
    pub initial: Duration,
    /// Factor the delay grows by with each additional consecutive failure.
    pub multiplier: f64,
    /// Upper bound on any delay returned by `compute_delay`.
    pub max_delay: Duration,
    /// Relative jitter: the delay is scaled by a random factor in
    /// `[1 - jitter, 1 + jitter]`, with `jitter` clamped to `[0, 1]`.
    /// A value `<= 0` disables jitter entirely.
    pub jitter: f64,
    /// Consecutive-failure count at which `is_tripped` starts returning
    /// `true`; `None` means the policy never trips.
    pub streak_limit: Option<u32>,
}
impl Default for BackoffPolicy {
    /// Stock policy: a one-second first delay that doubles per failure,
    /// capped at five minutes, with ±10% jitter and no trip threshold.
    fn default() -> Self {
        let initial = Duration::from_millis(1_000);
        let max_delay = Duration::from_secs(5 * 60);
        BackoffPolicy {
            initial,
            multiplier: 2.0,
            max_delay,
            jitter: 0.1,
            streak_limit: None,
        }
    }
}
impl BackoffPolicy {
    /// Returns the default policy with a trip threshold of `n` consecutive
    /// failures.
    pub fn with_trip(n: u32) -> Self {
        Self {
            streak_limit: Some(n),
            ..Self::default()
        }
    }

    /// Computes how long to wait after the `streak`-th consecutive failure.
    ///
    /// The un-jittered delay is `initial * multiplier^(streak - 1)`, capped at
    /// `max_delay`; a `streak` of `0` is treated as `1`. When `jitter > 0`, the
    /// capped delay is scaled by a random factor drawn from
    /// `[1 - jitter, 1 + jitter]` (with `jitter` clamped to `[0, 1]`). The
    /// result is then clamped to `max_delay` again, so near the cap jitter can
    /// only shorten the wait. The returned value never exceeds `max_delay`.
    ///
    /// Edge cases:
    /// - very large streaks saturate the exponent instead of wrapping;
    /// - a zero `initial` always yields `Duration::ZERO`;
    /// - a delay too large to represent as a `Duration` is clamped to
    ///   `max_delay` instead of panicking.
    pub fn compute_delay(&self, streak: u32) -> Duration {
        // A streak of 0 is treated like the first failure.
        let streak = streak.max(1);
        // Saturate the exponent: a plain `(streak - 1) as i32` wraps to a
        // negative value for streaks above `i32::MAX`, which would collapse
        // the delay toward zero. After `min` the value fits in `i32`, so the
        // cast below is lossless.
        let exponent = (streak - 1).min(i32::MAX as u32) as i32;
        let growth = self.multiplier.powi(exponent);
        // `0 * inf` is NaN, and `NaN.min(cap)` returns `cap`, so without this
        // guard a zero initial delay would jump to `max_delay` once `growth`
        // overflows to infinity.
        let base = if self.initial.is_zero() {
            0.0
        } else {
            self.initial.as_secs_f64() * growth
        };
        let cap = self.max_delay.as_secs_f64();
        let capped = base.min(cap);
        let jitter_factor = if self.jitter > 0.0 {
            let j = self.jitter.clamp(0.0, 1.0);
            let mut rng = rand::rng();
            let r: f64 = rng.random_range(-1.0..=1.0);
            1.0 + (j * r)
        } else {
            1.0
        };
        let delayed = (capped * jitter_factor).max(0.0);
        // `Duration::from_secs_f64` panics on values that do not fit in a
        // `Duration`, e.g. `max_delay` near `Duration::MAX` or jitter pushing
        // past it. Such values are above the cap anyway, so clamp to it.
        Duration::try_from_secs_f64(delayed).map_or(self.max_delay, |d| d.min(self.max_delay))
    }

    /// Returns `true` once `failure_streak` reaches the configured
    /// `streak_limit`; a policy without a limit never trips.
    pub fn is_tripped(&self, failure_streak: u32) -> bool {
        matches!(self.streak_limit, Some(limit) if failure_streak >= limit)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a policy from explicit knobs so each test spells out exactly
    /// the configuration it exercises.
    fn policy(
        initial: Duration,
        multiplier: f64,
        max_delay: Duration,
        jitter: f64,
        streak_limit: Option<u32>,
    ) -> BackoffPolicy {
        BackoffPolicy {
            initial,
            multiplier,
            max_delay,
            jitter,
            streak_limit,
        }
    }

    #[test]
    fn compute_delay_no_jitter_is_exponential() {
        let p = policy(Duration::from_millis(100), 2.0, Duration::from_secs(60), 0.0, None);
        let expected: [(u32, u64); 4] = [(1, 100), (2, 200), (3, 400), (4, 800)];
        for (streak, millis) in expected {
            assert_eq!(p.compute_delay(streak), Duration::from_millis(millis));
        }
    }

    #[test]
    fn compute_delay_caps_at_max_delay() {
        let cap = Duration::from_millis(500);
        let p = policy(Duration::from_millis(100), 10.0, cap, 0.0, None);
        for streak in [5u32, 20] {
            assert_eq!(p.compute_delay(streak), Duration::from_millis(500));
        }
    }

    #[test]
    fn compute_delay_with_jitter_stays_within_bounds() {
        let p = policy(Duration::from_millis(1000), 1.0, Duration::from_secs(60), 0.5, None);
        let (lower, upper) = (Duration::from_millis(500), Duration::from_millis(1500));
        for _ in 0..50 {
            let d = p.compute_delay(1);
            assert!(d >= lower, "too small: {d:?}");
            assert!(d <= upper, "too big: {d:?}");
        }
    }

    #[test]
    fn compute_delay_streak_zero_treated_as_one() {
        let p = policy(Duration::from_millis(100), 2.0, Duration::from_secs(10), 0.0, None);
        assert_eq!(p.compute_delay(0), Duration::from_millis(100));
    }

    #[test]
    fn is_tripped_respects_limit() {
        let p = policy(Duration::from_millis(100), 2.0, Duration::from_secs(60), 0.0, Some(3));
        for below in [0u32, 2] {
            assert!(!p.is_tripped(below));
        }
        for at_or_above in [3u32, 99] {
            assert!(p.is_tripped(at_or_above));
        }
    }

    #[test]
    fn is_tripped_none_never_trips() {
        let p = BackoffPolicy {
            streak_limit: None,
            ..BackoffPolicy::default()
        };
        for streak in [0u32, u32::MAX] {
            assert!(!p.is_tripped(streak));
        }
    }

    #[test]
    fn default_policy_values() {
        let p = BackoffPolicy::default();
        assert_eq!(p.initial, Duration::from_secs(1));
        assert!((p.multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(p.max_delay, Duration::from_secs(300));
        assert!((p.jitter - 0.1).abs() < f64::EPSILON);
        assert_eq!(p.streak_limit, None);
    }

    #[test]
    fn with_trip_inherits_default_and_sets_limit() {
        let tripped = BackoffPolicy::with_trip(7);
        let baseline = BackoffPolicy::default();
        assert_eq!(tripped.initial, baseline.initial);
        assert_eq!(tripped.multiplier, baseline.multiplier);
        assert_eq!(tripped.max_delay, baseline.max_delay);
        assert_eq!(tripped.jitter, baseline.jitter);
        assert_eq!(tripped.streak_limit, Some(7));
    }
}