use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Selects whether and how [`BackoffPolicy`] adds jitter to its computed delays.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum JitterMode {
    /// No jitter is added; the exponential delay is used as-is.
    Disabled,
    /// Jitter is derived from `seed`. For a given seed, jitter percentage and base
    /// delay, the same jitter is always produced.
    Deterministic {
        /// Seed that the jitter value is derived from.
        seed: u64,
    },
}
/// Exponential backoff configuration with optional, seeded jitter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct BackoffPolicy {
    /// Delay for the first attempt. It doubles on each later attempt (up to a 2^32
    /// multiplier) until it reaches `max`.
    pub initial: Duration,
    /// Upper bound on every delay returned by `delay_for_attempt`, jitter included.
    pub max: Duration,
    /// Largest jitter as a percentage of the (whole-millisecond) base delay.
    /// `BackoffPolicy::new` clamps this to 100. The field is public and
    /// deserializable, so larger values are not rejected by the type itself.
    pub jitter_percent: u8,
    /// Stable period at or beyond which `should_reset` returns `true`.
    pub reset_after: Duration,
    /// Whether jitter is applied at all, and with which seed.
    pub jitter_mode: JitterMode,
}
impl BackoffPolicy {
    /// Builds a policy with jitter disabled.
    ///
    /// `jitter_percent` is clamped to at most 100. It has no effect until jitter
    /// is enabled, e.g. via [`BackoffPolicy::with_deterministic_jitter`].
    #[must_use]
    pub fn new(
        initial: Duration,
        max: Duration,
        jitter_percent: u8,
        reset_after: Duration,
    ) -> Self {
        Self {
            initial,
            max,
            jitter_percent: jitter_percent.min(100),
            reset_after,
            jitter_mode: JitterMode::Disabled,
        }
    }

    /// Returns this policy with seeded, reproducible jitter enabled.
    #[must_use]
    pub fn with_deterministic_jitter(mut self, seed: u64) -> Self {
        self.jitter_mode = JitterMode::Deterministic { seed };
        self
    }

    /// Returns the delay to wait before retry number `attempt`.
    ///
    /// Attempt `0` is treated the same as attempt `1`. The delay is computed as
    /// follows, and the final result is capped at `max` again:
    /// 1. Take `initial * 2^(attempt - 1)`, with the exponent capped at 32.
    /// 2. Cap that value at `max`.
    /// 3. Add any configured jitter.
    ///
    /// As a consequence, once the exponential part reaches `max`, jitter no longer
    /// has any visible effect.
    #[must_use]
    pub fn delay_for_attempt(&self, attempt: u64) -> Duration {
        let exponential = self.exponential_delay(attempt.max(1));
        self.apply_jitter(exponential).min(self.max)
    }

    /// Returns `true` when `stable_for` is at least `reset_after`.
    #[must_use]
    pub fn should_reset(&self, stable_for: Duration) -> bool {
        stable_for >= self.reset_after
    }

    /// Computes `initial * 2^(attempt - 1)`, capped at `max`.
    ///
    /// The arithmetic is done in nanoseconds. Sub-millisecond `initial` values are
    /// therefore honoured rather than truncated to zero, which would have produced
    /// zero-length delays on every attempt.
    fn exponential_delay(&self, attempt: u64) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        // Cap the exponent so the shift stays far inside u128's range.
        let shift = attempt.saturating_sub(1).min(32);
        let nanos = self.initial.as_nanos().saturating_mul(1_u128 << shift);
        let max_nanos = self.max.as_nanos();
        if nanos >= max_nanos {
            return self.max;
        }
        // `nanos < max_nanos`, and `max_nanos` came from a valid Duration, so the
        // whole-second part fits in u64. The remainder is always < 1e9 and fits in u32.
        let secs = (nanos / NANOS_PER_SEC) as u64;
        let subsec_nanos = (nanos % NANOS_PER_SEC) as u32;
        Duration::new(secs, subsec_nanos)
    }

    /// Adds jitter to `base` according to `jitter_mode`.
    ///
    /// A `jitter_percent` of 0 disables jitter regardless of the mode. The result is
    /// not capped here; `delay_for_attempt` caps it at `max`.
    fn apply_jitter(&self, base: Duration) -> Duration {
        if self.jitter_percent == 0 {
            return base;
        }
        match self.jitter_mode {
            JitterMode::Disabled => base,
            JitterMode::Deterministic { seed } => {
                let jitter = deterministic_jitter(base, self.jitter_percent, seed);
                base.saturating_add(jitter)
            }
        }
    }
}
/// Converts a millisecond count to a [`Duration`], saturating at `u64::MAX` milliseconds
/// instead of truncating.
fn duration_from_millis(millis: u128) -> Duration {
    let capped = u64::try_from(millis).unwrap_or(u64::MAX);
    Duration::from_millis(capped)
}
/// Derives a reproducible jitter in `0..=floor(base_ms * percent / 100)` milliseconds,
/// where `base_ms` is `base` truncated to whole milliseconds.
///
/// The seed is scrambled with one 64-bit linear-congruential step and then reduced
/// modulo the window size. The result depends only on `base`, `percent` and `seed`.
/// It therefore does not change between attempts that share the same base delay.
///
/// NOTE(review): a single LCG step carries no high seed bits into the low bits.
/// Seeds that differ only at or above bit `k` therefore yield identical jitter
/// whenever the window size (`max + 1`) is a power of two no larger than `2^k`.
/// Consider a stronger mixer if seeds may be poorly distributed in their low bits.
fn deterministic_jitter(base: Duration, percent: u8, seed: u64) -> Duration {
    const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;

    // Inclusive upper bound of the jitter window, in whole milliseconds.
    let window_ms = base.as_millis().saturating_mul(u128::from(percent)) / 100;
    if window_ms == 0 {
        return Duration::ZERO;
    }

    let scrambled = u128::from(seed.wrapping_mul(LCG_MULTIPLIER).wrapping_add(1));
    duration_from_millis(scrambled % (window_ms + 1))
}