use std::sync::atomic::{AtomicU64, Ordering::{Acquire, Relaxed, Release}};
use std::time::{Duration, Instant};
/// Approximate sliding-window counter built from a ring of fixed-width buckets.
///
/// The window is divided into `resolution` buckets, each `bucket_duration_ns` wide.
/// Time since `start` is mapped to an *epoch* (`elapsed_ns / bucket_duration_ns`), and
/// epoch `e` is stored in bucket `e % resolution`. A bucket is reset lazily the first
/// time a newer epoch is written to it, and [`FixedWindowCounter::sum`] ignores every
/// bucket whose epoch has fallen out of the window.
///
/// Every method takes `&self` and touches only atomics, so one counter can be shared
/// between threads.
pub(crate) struct FixedWindowCounter {
    /// Epoch currently held by each bucket, or `RESETTING` while a writer is
    /// re-initialising that bucket for a new epoch.
    epochs: Box<[AtomicU64]>,
    /// Amount recorded in each bucket for the epoch stored in `epochs`.
    counts: Box<[AtomicU64]>,
    /// Width of one bucket in nanoseconds (guaranteed > 0 by `new`).
    bucket_duration_ns: u64,
    /// Number of buckets in the ring (guaranteed > 0 by `new`).
    resolution: usize,
    /// Reference instant all epochs are measured from.
    start: Instant,
}

impl FixedWindowCounter {
    /// Sentinel written to `epochs[i]` while one writer re-initialises bucket `i`.
    /// Real epochs cannot reach it in practice: even with 1 ns buckets that would take
    /// roughly 584 years.
    const RESETTING: u64 = u64::MAX;

    /// Creates a counter covering `window_size`, split into `resolution` buckets.
    ///
    /// Each bucket is `window_size / resolution` wide, rounded down to whole
    /// nanoseconds. The effective window is therefore `resolution * bucket_width`, which
    /// can be slightly shorter than `window_size` when it does not divide evenly.
    /// Windows longer than `u64::MAX` nanoseconds are clamped to that value.
    ///
    /// # Panics
    ///
    /// Panics if `resolution` is 0 or `window_size` is zero. Also panics if
    /// `window_size` is shorter than `resolution` nanoseconds, because the buckets
    /// would then have zero width.
    pub(crate) fn new(window_size: Duration, resolution: usize) -> Self {
        assert!(resolution > 0, "resolution must be > 0");
        // Saturate rather than `as u64`, which silently wraps for very large durations.
        let window_ns = u64::try_from(window_size.as_nanos()).unwrap_or(u64::MAX);
        assert!(window_ns > 0, "window_size must be non-zero");
        let bucket_duration_ns = window_ns / resolution as u64;
        assert!(
            bucket_duration_ns > 0,
            "window_size ({window_ns} ns) too small for resolution {resolution}"
        );
        Self {
            epochs: Self::zeroed_slots(resolution),
            counts: Self::zeroed_slots(resolution),
            bucket_duration_ns,
            resolution,
            start: Instant::now(),
        }
    }

    /// Allocates `len` atomics, all initialised to 0.
    fn zeroed_slots(len: usize) -> Box<[AtomicU64]> {
        (0..len).map(|_| AtomicU64::new(0)).collect()
    }

    /// Index of the bucket-width time slot that "now" falls into, counted from `start`.
    fn current_epoch(&self) -> u64 {
        let elapsed_ns = u64::try_from(self.start.elapsed().as_nanos()).unwrap_or(u64::MAX);
        elapsed_ns / self.bucket_duration_ns
    }

    /// Adds `amount` to the bucket for the current time slot.
    ///
    /// If the bucket still holds an expired epoch, exactly one writer claims it by
    /// swapping in `RESETTING`. That writer overwrites the stale count with `amount` and
    /// then publishes the new epoch with `Release` ordering. Other writers for the new
    /// epoch wait while they observe `RESETTING`. They then `fetch_add` on top, so the
    /// reset never wipes their increment.
    ///
    /// If another thread has already advanced the bucket to a *newer* epoch (this
    /// thread read the clock earlier and was delayed), `amount` is added to that newer
    /// epoch instead of rolling the bucket back. This errs towards over-counting rather
    /// than discarding fresher data.
    ///
    /// A writer still working on the expired epoch can race the reset. Its increment may
    /// then be overwritten or carried into the new epoch. That increment belongs to a
    /// slot that has already left the window.
    pub(crate) fn record(&self, amount: u64) {
        let current_epoch = self.current_epoch();
        // Reduce in u64 before narrowing: `epoch as usize` would truncate on 32-bit
        // targets and map consecutive epochs onto the same bucket.
        let bucket_idx = (current_epoch % self.resolution as u64) as usize;
        let epoch_slot = &self.epochs[bucket_idx];
        let count_slot = &self.counts[bucket_idx];
        loop {
            let stored_epoch = epoch_slot.load(Acquire);
            if stored_epoch == Self::RESETTING {
                // Another writer is between its claim and publish (two plain stores).
                std::hint::spin_loop();
                continue;
            }
            if stored_epoch >= current_epoch {
                count_slot.fetch_add(amount, Relaxed);
                return;
            }
            // The bucket holds an expired epoch: try to claim it for the reset.
            if epoch_slot
                .compare_exchange(stored_epoch, Self::RESETTING, Acquire, Relaxed)
                .is_ok()
            {
                count_slot.store(amount, Relaxed);
                // Pairs with the `Acquire` loads above and in `sum`. Anyone who observes
                // `current_epoch` also observes the reset count, so later `fetch_add`s
                // are ordered after this store.
                epoch_slot.store(current_epoch, Release);
                return;
            }
            // Lost the claim race or the epoch changed underneath us: re-read and retry.
        }
    }

    /// Returns the total of all buckets whose epoch lies within the last `resolution`
    /// epochs, i.e. inside the window ending now. The total saturates at `u64::MAX`.
    ///
    /// This is a best-effort snapshot, because buckets are read one at a time. A bucket
    /// that is mid-reset, or stamped with an epoch later than this call's clock read,
    /// is skipped.
    pub(crate) fn sum(&self) -> u64 {
        let current_epoch = self.current_epoch();
        let window = self.resolution as u64;
        self.epochs
            .iter()
            .zip(self.counts.iter())
            .filter(|(epoch, _)| {
                // `RESETTING` is larger than any real epoch, so it fails the first check.
                let epoch = epoch.load(Acquire);
                epoch <= current_epoch && current_epoch - epoch < window
            })
            .fold(0, |total: u64, (_, count)| {
                total.saturating_add(count.load(Relaxed))
            })
    }

    /// Returns how much of `limit` is still unused in the current window.
    /// The result is 0 once `sum()` reaches or exceeds `limit`.
    pub(crate) fn remaining(&self, limit: u64) -> u64 {
        limit.saturating_sub(self.sum())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn record_and_sum() {
        let counter = FixedWindowCounter::new(Duration::from_secs(60), 6);
        counter.record(50);
        counter.record(30);
        assert_eq!(counter.sum(), 80);
    }

    #[test]
    fn remaining_calculation() {
        let counter = FixedWindowCounter::new(Duration::from_secs(60), 6);
        counter.record(200);
        assert_eq!(counter.remaining(1000), 800);
        // Usage above the limit clamps to zero instead of underflowing.
        assert_eq!(counter.remaining(100), 0);
    }

    #[test]
    fn expired_buckets_not_counted() {
        // 5 buckets of 10 ms each; sleeping 80 ms pushes the recorded bucket out.
        let counter = FixedWindowCounter::new(Duration::from_millis(50), 5);
        counter.record(42);
        assert_eq!(counter.sum(), 42);
        thread::sleep(Duration::from_millis(80));
        assert_eq!(counter.sum(), 0);
    }

    #[test]
    fn multiple_records_accumulate() {
        let counter = FixedWindowCounter::new(Duration::from_secs(10), 2);
        for _ in 0..100 {
            counter.record(1);
        }
        assert_eq!(counter.sum(), 100);
    }

    #[test]
    fn concurrent_records_are_all_counted() {
        // The 60 s window is far longer than the test, so no bucket rolls over and
        // every increment from every thread must be visible.
        let counter = FixedWindowCounter::new(Duration::from_secs(60), 4);
        thread::scope(|s| {
            for _ in 0..8 {
                s.spawn(|| {
                    for _ in 0..1_000 {
                        counter.record(1);
                    }
                });
            }
        });
        assert_eq!(counter.sum(), 8_000);
    }

    #[test]
    #[should_panic(expected = "resolution must be > 0")]
    fn zero_resolution_panics() {
        let _ = FixedWindowCounter::new(Duration::from_secs(1), 0);
    }

    #[test]
    #[should_panic(expected = "window_size must be non-zero")]
    fn zero_window_panics() {
        let _ = FixedWindowCounter::new(Duration::ZERO, 1);
    }

    #[test]
    #[should_panic(expected = "too small for resolution")]
    fn window_shorter_than_resolution_panics() {
        // 3 ns split into 4 buckets would give zero-width buckets.
        let _ = FixedWindowCounter::new(Duration::from_nanos(3), 4);
    }
}