use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Strategy a [`FailureWindow`] uses to decide which recorded failures still count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WindowMode {
    /// Count only failures whose age, measured against the `now` passed to
    /// [`FailureWindow::record_failure`] / [`FailureWindow::current_state_at`],
    /// is at most `window_secs` seconds (strictly older entries are dropped).
    TimeSliding {
        /// Window length in whole seconds.
        window_secs: u64,
    },
    /// Retain only the most recent `max_count` failures, regardless of age.
    CountSliding {
        /// Maximum number of failures kept; older ones are evicted first.
        max_count: usize,
    },
}
impl Default for WindowMode {
fn default() -> Self {
Self::TimeSliding { window_secs: 60 }
}
}
/// Configuration for a [`FailureWindow`]: the windowing strategy plus the number of
/// in-window failures at which the threshold is considered reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct FailureWindowConfig {
    /// Which failures are kept (time- or count-based sliding window).
    pub mode: WindowMode,
    /// `threshold_reached` is reported when the in-window count is `>=` this value.
    /// A threshold of 0 is therefore always reached, even with no failures recorded.
    pub threshold: usize,
}
impl FailureWindowConfig {
    /// Builds a config that keeps failures at most `window_secs` seconds old and
    /// reports the threshold as reached once `threshold` of them are in the window.
    ///
    /// This is a `const fn`, so a policy can be declared as a `const` item.
    #[must_use]
    pub const fn time_sliding(window_secs: u64, threshold: usize) -> Self {
        Self {
            mode: WindowMode::TimeSliding { window_secs },
            threshold,
        }
    }

    /// Builds a config that keeps only the latest `max_count` failures and reports
    /// the threshold as reached once `threshold` of them are retained.
    ///
    /// A `FailureWindow` never retains more than `max_count` failures in this mode.
    /// With `threshold > max_count`, the threshold can therefore never be reached.
    ///
    /// This is a `const fn`, so a policy can be declared as a `const` item.
    #[must_use]
    pub const fn count_sliding(max_count: usize, threshold: usize) -> Self {
        Self {
            mode: WindowMode::CountSliding { max_count },
            threshold,
        }
    }
}
/// Snapshot of a [`FailureWindow`] as returned by its state queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FailureWindowState {
    /// Number of failures counted in the window.
    pub current_count: usize,
    /// `true` when `current_count >= config.threshold`.
    pub threshold_reached: bool,
    /// Timestamp at the front of the window's queue (the first-recorded failure still
    /// counted), or `None` when the window is empty.
    pub oldest_timestamp: Option<Instant>,
}
/// Sliding window of failure timestamps, used to detect when failures within the
/// window reach the configured threshold.
#[derive(Debug, Clone)]
pub struct FailureWindow {
    /// Windowing strategy and threshold.
    pub config: FailureWindowConfig,
    /// Recorded failure timestamps in insertion order; eviction/expiry removes from
    /// the front.
    failures: VecDeque<Instant>,
    /// Timestamp passed to the most recent `record_failure`; reset to `None` by `clear`.
    last_failure: Option<Instant>,
}
impl FailureWindow {
    /// Creates an empty window governed by `config`.
    pub fn new(config: FailureWindowConfig) -> Self {
        Self {
            config,
            failures: VecDeque::new(),
            last_failure: None,
        }
    }

    /// Records a failure observed at `now` and returns the state right after recording.
    ///
    /// In `TimeSliding` mode, entries older than the window relative to `now` are
    /// dropped first. In `CountSliding` mode, the oldest entries are evicted until at
    /// most `max_count` remain.
    ///
    /// Expiry stops at the first non-expired entry from the front. It therefore
    /// assumes failures are recorded with non-decreasing `now` values.
    pub fn record_failure(&mut self, now: Instant) -> FailureWindowState {
        self.prune(now);
        self.failures.push_back(now);
        self.last_failure = Some(now);
        if let WindowMode::CountSliding { max_count } = self.config.mode {
            let excess = self.failures.len().saturating_sub(max_count);
            self.failures.drain(..excess);
        }
        self.current_state()
    }

    /// Forgets every recorded failure, including the most recent timestamp.
    pub fn clear(&mut self) {
        self.failures.clear();
        self.last_failure = None;
    }

    /// Timestamp of the most recent `record_failure` call, or `None` if nothing has
    /// been recorded since creation or the last `clear`.
    #[must_use]
    pub fn last_failure(&self) -> Option<Instant> {
        self.last_failure
    }

    /// Returns the window state as it would be observed at `now`, applying time-based
    /// expiry without modifying the stored failures.
    #[must_use]
    pub fn current_state_at(&self, now: Instant) -> FailureWindowState {
        // Skip the expired prefix in place rather than cloning the deque to pop it.
        let expired = self.expired_prefix_len(now);
        self.state_from(
            self.failures.len() - expired,
            self.failures.get(expired).copied(),
        )
    }

    /// Returns the state of the currently stored failures without applying time-based
    /// expiry.
    ///
    /// In `TimeSliding` mode, expiry only happens inside `record_failure`, so this
    /// reflects the window as of the last recorded failure. Use `current_state_at` for
    /// an up-to-date view.
    #[must_use]
    pub fn current_state(&self) -> FailureWindowState {
        self.state_from(self.failures.len(), self.failures.front().copied())
    }

    /// Drops stored failures that have expired relative to `now` (no-op in
    /// `CountSliding` mode).
    fn prune(&mut self, now: Instant) {
        let expired = self.expired_prefix_len(now);
        self.failures.drain(..expired);
    }

    /// Number of stored failures. Like `current_state`, this does not apply
    /// time-based expiry.
    #[must_use]
    pub fn failure_count(&self) -> usize {
        self.failures.len()
    }

    /// Counts leading entries that are strictly older than the window at `now`.
    /// Always 0 in `CountSliding` mode.
    fn expired_prefix_len(&self, now: Instant) -> usize {
        match self.config.mode {
            WindowMode::TimeSliding { window_secs } => {
                let window = Duration::from_secs(window_secs);
                self.failures
                    .iter()
                    .take_while(|&&ts| now.duration_since(ts) > window)
                    .count()
            }
            WindowMode::CountSliding { .. } => 0,
        }
    }

    /// Builds a snapshot for `current_count` in-window failures, evaluating the
    /// configured threshold.
    fn state_from(
        &self,
        current_count: usize,
        oldest_timestamp: Option<Instant>,
    ) -> FailureWindowState {
        FailureWindowState {
            current_count,
            threshold_reached: current_count >= self.config.threshold,
            oldest_timestamp,
        }
    }
}
#[cfg(test)]
mod tests {
    // `super::` keeps the tests working if this module is moved within the crate.
    use super::{FailureWindow, FailureWindowConfig, WindowMode};
    use std::time::{Duration, Instant};

    #[test]
    fn test_time_sliding_window_expiration() {
        let config = FailureWindowConfig::time_sliding(10, 3);
        let mut window = FailureWindow::new(config);
        let base = Instant::now();
        window.record_failure(base);
        window.record_failure(base + Duration::from_secs(5));
        let state = window.current_state_at(base + Duration::from_secs(8));
        assert_eq!(state.current_count, 2);
        assert!(!state.threshold_reached);
        let state = window.current_state_at(base + Duration::from_secs(11));
        assert_eq!(state.current_count, 1);
        assert_eq!(state.oldest_timestamp, Some(base + Duration::from_secs(5)));
    }

    #[test]
    fn test_time_sliding_boundary_is_inclusive() {
        let mut window = FailureWindow::new(FailureWindowConfig::time_sliding(10, 1));
        let base = Instant::now();
        window.record_failure(base);
        // An entry exactly `window_secs` old is still counted; expiry uses a strict `>`.
        let at_edge = window.current_state_at(base + Duration::from_secs(10));
        assert_eq!(at_edge.current_count, 1);
        assert!(at_edge.threshold_reached);
        let past_edge = window.current_state_at(base + Duration::from_secs(11));
        assert_eq!(past_edge.current_count, 0);
        assert!(!past_edge.threshold_reached);
    }

    #[test]
    fn test_current_state_at_does_not_mutate() {
        let mut window = FailureWindow::new(FailureWindowConfig::time_sliding(10, 1));
        let base = Instant::now();
        window.record_failure(base);
        let state = window.current_state_at(base + Duration::from_secs(60));
        assert_eq!(state.current_count, 0);
        assert_eq!(state.oldest_timestamp, None);
        assert_eq!(window.failure_count(), 1);
    }

    #[test]
    fn test_count_sliding_window_limit() {
        let config = FailureWindowConfig::count_sliding(3, 5);
        let mut window = FailureWindow::new(config);
        let base = Instant::now();
        for offset in 0..4 {
            window.record_failure(base + Duration::from_secs(offset));
        }
        assert_eq!(window.failure_count(), 3);
        let state = window.current_state();
        assert_eq!(state.oldest_timestamp, Some(base + Duration::from_secs(1)));
        assert!(!state.threshold_reached);
    }

    #[test]
    fn test_threshold_detection() {
        let config = FailureWindowConfig::time_sliding(60, 3);
        let mut window = FailureWindow::new(config);
        let base = Instant::now();
        window.record_failure(base);
        window.record_failure(base + Duration::from_secs(1));
        assert!(!window.current_state().threshold_reached);
        window.record_failure(base + Duration::from_secs(2));
        assert!(window.current_state().threshold_reached);
    }

    #[test]
    fn test_clear_resets_window() {
        let mut window = FailureWindow::new(FailureWindowConfig::count_sliding(5, 1));
        let base = Instant::now();
        window.record_failure(base);
        assert!(window.current_state().threshold_reached);
        window.clear();
        let state = window.current_state();
        assert_eq!(state.current_count, 0);
        assert!(!state.threshold_reached);
        assert_eq!(state.oldest_timestamp, None);
    }

    #[test]
    fn test_default_config() {
        assert_eq!(
            WindowMode::default(),
            WindowMode::TimeSliding { window_secs: 60 }
        );
    }
}