use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Strategy `BackoffPolicy` uses to randomize restart delays.
///
/// Each randomized variant carries a `seed`. The delay it produces depends only on that
/// seed and the policy inputs, so results are reproducible.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum JitterMode {
    /// No jitter: the exponential delay is used unchanged.
    Disabled,
    /// Adds a seed-derived offset of at most `jitter_percent`% of the exponential delay.
    Deterministic { seed: u64 },
    /// Replaces the delay with a seed-derived value between zero and the exponential delay.
    FullJitter { seed: u64 },
    /// Replaces the delay with a seed-derived value drawn from `[initial, 3 × delay]`,
    /// bounded by the policy's `max`.
    DecorrelatedJitter { seed: u64 },
}
/// Exponential restart backoff: the delay starts at `initial`, doubles for each child
/// start, is capped at `max`, and may be randomized according to `jitter_mode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct BackoffPolicy {
    /// Delay used for the first child start.
    pub initial: Duration,
    /// Upper bound applied to every delay returned by `delay_for_child_start_count`.
    pub max: Duration,
    /// Jitter magnitude, in percent, used by `JitterMode::Deterministic`.
    ///
    /// `new` clamps this to 100. The field is public and deserializable, though, so other
    /// construction paths can store larger values. A value of 0 disables jitter for every
    /// mode.
    pub jitter_percent: u8,
    /// Stable runtime at which `should_reset` starts returning `true`.
    pub reset_after: Duration,
    /// Jitter strategy; `new` starts with `JitterMode::Disabled`.
    pub jitter_mode: JitterMode,
}
impl BackoffPolicy {
    /// Creates a policy with jitter disabled. `jitter_percent` is clamped to 100.
    pub fn new(
        initial: Duration,
        max: Duration,
        jitter_percent: u8,
        reset_after: Duration,
    ) -> Self {
        Self {
            initial,
            max,
            jitter_percent: jitter_percent.min(100),
            reset_after,
            jitter_mode: JitterMode::Disabled,
        }
    }

    /// Switches to `JitterMode::Deterministic` using `seed`.
    pub fn with_deterministic_jitter(mut self, seed: u64) -> Self {
        self.jitter_mode = JitterMode::Deterministic { seed };
        self
    }

    /// Switches to `JitterMode::FullJitter` using `seed`.
    pub fn with_full_jitter(mut self, seed: u64) -> Self {
        self.jitter_mode = JitterMode::FullJitter { seed };
        self
    }

    /// Switches to `JitterMode::DecorrelatedJitter` using `seed`.
    pub fn with_decorrelated_jitter(mut self, seed: u64) -> Self {
        self.jitter_mode = JitterMode::DecorrelatedJitter { seed };
        self
    }

    /// Backoff delay for the given child start count. A count of 0 is treated as 1.
    ///
    /// The exponential delay is jittered according to `jitter_mode`, and the final result
    /// never exceeds `max`.
    pub fn delay_for_child_start_count(&self, child_start_count: u64) -> Duration {
        let exponential = self.exponential_delay(child_start_count.max(1));
        self.apply_jitter(exponential).min(self.max)
    }

    /// Returns `true` once `stable_for` has reached `reset_after`.
    pub fn should_reset(&self, stable_for: Duration) -> bool {
        stable_for >= self.reset_after
    }

    /// Computes `initial * 2^(child_start_count - 1)`. The exponent is capped at 32 and
    /// the result is capped at `max`.
    fn exponential_delay(&self, child_start_count: u64) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let shift = child_start_count.saturating_sub(1).min(32);
        let multiplier = 1_u128 << shift;
        // Work in nanoseconds. Multiplying `as_millis()` would truncate a sub-millisecond
        // `initial` (e.g. 500µs) to 0 ms, giving a zero delay for every attempt.
        let nanos = self
            .initial
            .as_nanos()
            .saturating_mul(multiplier)
            .min(self.max.as_nanos());
        // `nanos <= max.as_nanos()`, so the whole-second part always fits in a u64.
        let secs = u64::try_from(nanos / NANOS_PER_SEC).unwrap_or(u64::MAX);
        // The remainder is below 1e9, so it fits in a u32.
        let subsec_nanos = (nanos % NANOS_PER_SEC) as u32;
        Duration::new(secs, subsec_nanos)
    }

    /// Applies the configured jitter strategy to `base`.
    fn apply_jitter(&self, base: Duration) -> Duration {
        // NOTE(review): this early return also disables FullJitter and DecorrelatedJitter,
        // which otherwise ignore `jitter_percent`. Confirm that is intended.
        if self.jitter_percent == 0 {
            return base;
        }
        match self.jitter_mode {
            JitterMode::Disabled => base,
            JitterMode::Deterministic { seed } => {
                let jitter = deterministic_jitter(base, self.jitter_percent, seed);
                base.saturating_add(jitter)
            }
            JitterMode::FullJitter { seed } => calculate_full_jitter(base, self.max, seed),
            JitterMode::DecorrelatedJitter { seed } => {
                calculate_decorrelated_jitter(base, self.initial, self.max, seed)
            }
        }
    }
}
/// Converts a millisecond count to a `Duration`, saturating at `u64::MAX` milliseconds.
fn duration_from_millis(millis: u128) -> Duration {
    let clamped_ms = u64::try_from(millis).unwrap_or(u64::MAX);
    Duration::from_millis(clamped_ms)
}
/// Returns a seed-derived jitter in `[0, base_ms * percent / 100]` milliseconds.
///
/// `percent` is clamped to 100 here. `BackoffPolicy` fields are public and
/// deserializable, so the clamp in `BackoffPolicy::new` can be bypassed, and a larger
/// percent would otherwise allow jitter beyond the base delay itself.
fn deterministic_jitter(base: Duration, percent: u8, seed: u64) -> Duration {
    let percent = u128::from(percent.min(100));
    let max_jitter = base.as_millis().saturating_mul(percent) / 100;
    if max_jitter == 0 {
        return Duration::ZERO;
    }
    // A single 64-bit LCG step turns the seed into a pseudo-random value. The result
    // depends only on `seed`, not on `base`, so it is stable across attempts.
    let mixed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
    duration_from_millis(u128::from(mixed) % (max_jitter + 1))
}
/// Full jitter: returns a seed-derived delay in `[0, min(base, max)]` milliseconds.
///
/// Returns `Duration::ZERO` when that bound is under one millisecond. The result depends
/// only on the arguments, so the same seed always yields the same delay.
pub fn calculate_full_jitter(base: Duration, max: Duration, seed: u64) -> Duration {
    let ceiling_ms = base.min(max).as_millis();
    if ceiling_ms == 0 {
        return Duration::ZERO;
    }
    // One LCG step from the seed.
    // NOTE(review): the modulo below consumes the LCG's low bits, which are weakly random.
    // For example, with a 1 ms ceiling the result simply follows the seed's parity.
    let sample = seed
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1);
    duration_from_millis(u128::from(sample) % (ceiling_ms + 1))
}
/// Decorrelated jitter: returns a seed-derived delay in `[initial, min(3 * base, max)]`
/// milliseconds.
///
/// When that range is empty, the lower bound is returned. Both bounds are capped at `max`,
/// so the result never exceeds `max`, even when `initial > max`.
pub fn calculate_decorrelated_jitter(
    base: Duration,
    initial: Duration,
    max: Duration,
    seed: u64,
) -> Duration {
    let cap_ms = max.as_millis();
    // Capping the lower bound keeps the `initial > max` case within `max`. Without it the
    // empty-range fallback below would return `initial` unchanged.
    let lower_ms = initial.as_millis().min(cap_ms);
    let upper_ms = base.as_millis().saturating_mul(3).min(cap_ms);
    if upper_ms <= lower_ms {
        return duration_from_millis(lower_ms);
    }
    // One LCG step from the seed, reduced into the inclusive range [lower_ms, upper_ms].
    let sample = seed
        .wrapping_mul(6_364_136_223_846_793_005)
        .wrapping_add(1);
    let span_ms = upper_ms - lower_ms;
    duration_from_millis(lower_ms + u128::from(sample) % (span_ms + 1))
}
/// Limits how many restarts are allowed within `window_secs` of `start_time_secs`.
#[derive(Debug, Clone)]
pub struct ColdStartBudget {
    /// Length of the budget window, in seconds after `start_time_secs`.
    pub window_secs: u64,
    /// Restarts permitted inside the window; the next one exceeds the budget.
    pub max_restarts: u32,
    /// Restarts recorded so far.
    pub restart_count: u32,
    /// Start of the budget window, on the same clock as the `current_time_secs` arguments.
    pub start_time_secs: u64,
}
impl ColdStartBudget {
    /// Creates a budget with no restarts recorded.
    pub fn new(window_secs: u64, max_restarts: u32, start_time_secs: u64) -> Self {
        Self {
            window_secs,
            max_restarts,
            restart_count: 0,
            start_time_secs,
        }
    }

    /// Records a restart at `current_time_secs`. Returns `true` when this restart exceeds
    /// `max_restarts` inside the window.
    ///
    /// A restart after the window has closed sets the count to 1 and returns `false`.
    /// NOTE(review): `start_time_secs` is never advanced. Once the window has elapsed,
    /// every later restart takes that branch and the budget can never trip again.
    /// Confirm that a one-shot (cold-start only) window is intended.
    pub fn record_restart(&mut self, current_time_secs: u64) -> bool {
        if !self.is_window_active(current_time_secs) {
            self.restart_count = 1;
            return false;
        }
        // Saturate instead of overflowing. A plain `+= 1` panics in debug builds and wraps
        // to 0 in release, which would silently reopen the budget.
        self.restart_count = self.restart_count.saturating_add(1);
        self.restart_count > self.max_restarts
    }

    /// Returns `true` while the window is open and at least `max_restarts` restarts have
    /// been recorded. Always returns `false` once the window has elapsed.
    pub fn is_exhausted(&self, current_time_secs: u64) -> bool {
        self.is_window_active(current_time_secs) && self.restart_count >= self.max_restarts
    }

    /// Number of restarts currently recorded.
    pub fn get_restart_count(&self) -> u32 {
        self.restart_count
    }

    /// Returns `true` while at most `window_secs` have passed since `start_time_secs`.
    /// Times before the start count as zero elapsed.
    pub fn is_window_active(&self, current_time_secs: u64) -> bool {
        current_time_secs.saturating_sub(self.start_time_secs) <= self.window_secs
    }
}
/// Detects crash loops: `min_restarts` or more crashes within a sliding window of
/// `window_secs` seconds.
///
/// A crash at `t` is inside the window ending at `now` when `t > now - window_secs`. When
/// `now < window_secs`, the window reaches back past time 0 and every recorded crash is
/// inside it.
#[derive(Debug, Clone)]
pub struct HotLoopDetector {
    /// Width of the sliding window, in seconds.
    pub window_secs: u64,
    /// Crashes inside the window needed to report a hot loop.
    pub min_restarts: u32,
    /// Recorded crash timestamps. `record_crash` prunes them to its window.
    pub crash_times: Vec<u64>,
}
impl HotLoopDetector {
    /// Creates a detector with an empty crash history.
    pub fn new(window_secs: u64, min_restarts: u32) -> Self {
        Self {
            window_secs,
            min_restarts,
            crash_times: Vec::new(),
        }
    }

    /// Records a crash at `crash_time_secs` and drops crashes outside the window ending at
    /// that time. Returns whether a hot loop is now detected.
    pub fn record_crash(&mut self, crash_time_secs: u64) -> bool {
        self.crash_times.push(crash_time_secs);
        // Compute the cutoff before `retain` takes its mutable borrow of `crash_times`.
        let cutoff = self.window_cutoff(crash_time_secs);
        self.crash_times
            .retain(|&t| Self::is_after_cutoff(t, cutoff));
        self.is_hot_loop_detected(crash_time_secs)
    }

    /// Returns `true` when at least `min_restarts` crashes fall inside the window ending at
    /// `current_time_secs`.
    pub fn is_hot_loop_detected(&self, current_time_secs: u64) -> bool {
        self.get_crash_count_in_window(current_time_secs) >= self.min_restarts as usize
    }

    /// Counts recorded crashes inside the window ending at `current_time_secs`.
    pub fn get_crash_count_in_window(&self, current_time_secs: u64) -> usize {
        let cutoff = self.window_cutoff(current_time_secs);
        self.crash_times
            .iter()
            .filter(|&&t| Self::is_after_cutoff(t, cutoff))
            .count()
    }

    /// Forgets every recorded crash.
    pub fn clear_history(&mut self) {
        self.crash_times.clear();
    }

    /// Returns the exclusive lower bound of the window ending at `now_secs`, or `None`
    /// when the window reaches back past time 0.
    ///
    /// `checked_sub` is used instead of `saturating_sub`. A saturated cutoff of 0 would
    /// wrongly exclude crashes recorded at time 0.
    fn window_cutoff(&self, now_secs: u64) -> Option<u64> {
        now_secs.checked_sub(self.window_secs)
    }

    /// Returns whether a crash at `crash_secs` lies strictly after `cutoff`. With no
    /// cutoff, every crash qualifies.
    fn is_after_cutoff(crash_secs: u64, cutoff: Option<u64>) -> bool {
        match cutoff {
            Some(bound) => crash_secs > bound,
            None => true,
        }
    }
}
#[cfg(test)]
mod backoff_extended_tests {
    use crate::policy::backoff::{
        ColdStartBudget, HotLoopDetector, calculate_decorrelated_jitter, calculate_full_jitter,
    };
    use std::time::Duration;

    /// Three restarts fit in a budget of three; the fourth inside the window exceeds it.
    #[test]
    fn test_cold_start_budget_basic_tracking() {
        let mut tracker = ColdStartBudget::new(300, 3, 1000);
        for when in [1010, 1020, 1030] {
            assert!(!tracker.record_restart(when));
        }
        assert!(tracker.record_restart(1040));
    }

    /// A restart after the window has elapsed is not over budget and sets the count to 1.
    #[test]
    fn test_cold_start_window_expiry() {
        let mut tracker = ColdStartBudget::new(300, 2, 1000);
        tracker.record_restart(1010);
        tracker.record_restart(1020);
        let over_budget = tracker.record_restart(1400);
        assert!(!over_budget);
        assert_eq!(tracker.get_restart_count(), 1);
    }

    /// Two crashes are not a hot loop; the third within 60 seconds is.
    #[test]
    fn test_hot_loop_detection_basic() {
        let mut loop_check = HotLoopDetector::new(60, 3);
        loop_check.record_crash(1000);
        loop_check.record_crash(1010);
        assert!(!loop_check.is_hot_loop_detected(1010));
        loop_check.record_crash(1020);
        assert!(loop_check.is_hot_loop_detected(1020));
    }

    /// Once older crashes slide out of the window, the hot loop is no longer reported.
    #[test]
    fn test_hot_loop_sliding_window() {
        let mut loop_check = HotLoopDetector::new(60, 3);
        [1000, 1010, 1020]
            .into_iter()
            .for_each(|when| {
                loop_check.record_crash(when);
            });
        assert!(loop_check.is_hot_loop_detected(1020));
        assert!(!loop_check.is_hot_loop_detected(1070));
    }

    /// Full jitter never exceeds the base delay.
    #[test]
    fn test_full_jitter_bounds() {
        let base = Duration::from_millis(100);
        let cap = Duration::from_millis(1000);
        let jittered = calculate_full_jitter(base, cap, 42);
        assert!(jittered <= base);
    }

    /// Decorrelated jitter stays between `initial` and `max`.
    #[test]
    fn test_decorrelated_jitter_bounds() {
        let floor = Duration::from_millis(10);
        let ceiling = Duration::from_millis(1000);
        let jittered = calculate_decorrelated_jitter(Duration::from_millis(100), floor, ceiling, 42);
        assert!(jittered >= floor);
        assert!(jittered <= ceiling);
    }
}