use std::sync::{
atomic::{AtomicU64, AtomicUsize, Ordering},
Mutex,
};
use tokio::time::{Duration, Instant};
/// A counter that sums the amounts added to it over a sliding time window.
///
/// The window is divided into [`NUM_BUCKETS`] slots, each `bucket_window_ms` milliseconds
/// wide. New amounts accumulate in `current`. Once at least one slot width has elapsed since
/// the last rotation, `current` is committed into the ring of `buckets` and the slots that
/// have aged out are zeroed. Expiry is lazy: it only happens inside [`add`] and [`sum`].
///
/// [`add`]: WindowedCounter::add
/// [`sum`]: WindowedCounter::sum
#[derive(Debug)]
pub struct WindowedCounter {
    /// Reference point that all epoch timestamps are measured from.
    anchor: Instant,
    /// Milliseconds after `anchor` at which the last rotation happened.
    epoch: AtomicU64,
    /// Slot in `buckets` that `current` is committed into at the next rotation.
    /// Holding this mutex is also what serializes rotations.
    epoch_index: Mutex<usize>,
    /// Committed per-slot totals, used as a ring addressed through `epoch_index`.
    buckets: [AtomicUsize; NUM_BUCKETS],
    /// Width of one slot in milliseconds (never less than 1).
    bucket_window_ms: u64,
    /// Amount added since the last rotation.
    current: AtomicUsize,
}

/// Number of committed slots the window is divided into.
const NUM_BUCKETS: usize = 10;

impl WindowedCounter {
    /// Creates an empty counter whose window approximates `window`.
    ///
    /// The slot width is `window / NUM_BUCKETS` rounded down to whole milliseconds and
    /// clamped to at least 1 ms. The effective window is therefore `NUM_BUCKETS` times that
    /// width, which can differ from `window` when it is not a multiple of 10 ms.
    pub fn new(window: Duration) -> Self {
        // Array-repeat syntax needs either a `Copy` value or a constant. Every element is
        // initialised from a fresh copy of this constant, so no atomic is actually shared,
        // which is the situation this lint warns about.
        #[allow(clippy::declare_interior_mutable_const)]
        const ATOMIC_USIZE_ZERO: AtomicUsize = AtomicUsize::new(0);
        WindowedCounter {
            anchor: Instant::now(),
            epoch: AtomicU64::new(0),
            epoch_index: Mutex::new(0),
            buckets: [ATOMIC_USIZE_ZERO; NUM_BUCKETS],
            bucket_window_ms: std::cmp::max(1, (window / NUM_BUCKETS as u32).as_millis() as u64),
            current: AtomicUsize::new(0),
        }
    }

    /// Adds `amount` to the counter after expiring any slots that have aged out.
    pub fn add(&self, amount: usize) {
        self.expire();
        self.current.fetch_add(amount, Ordering::SeqCst);
    }

    /// Returns the total of the amounts added within the window.
    ///
    /// The summation saturates at `usize::MAX`. Slots are read one at a time without taking
    /// the lock, so under concurrent `add` calls the result is a best-effort snapshot.
    pub fn sum(&self) -> usize {
        self.expire();
        let current = self.current.load(Ordering::SeqCst);
        let committed = self
            .buckets
            .iter()
            .fold(0usize, |acc, bucket| acc.saturating_add(bucket.load(Ordering::SeqCst)));
        current.saturating_add(committed)
    }

    /// Rotates the bucket ring if at least one slot width has passed since the last rotation.
    fn expire(&self) {
        // Load the epoch *before* reading the clock. Another thread may publish a newer epoch
        // at any moment; if the clock were read first, that epoch could be later than our
        // reading. The subtraction is also saturating, so it can never underflow: an
        // underflow would panic in debug builds or wrap to an enormous distance in release.
        let cur_epoch = self.epoch.load(Ordering::Acquire);
        let now_ms = self.anchor.elapsed().as_millis() as u64;
        if now_ms.saturating_sub(cur_epoch) < self.bucket_window_ms {
            return;
        }

        // The guarded value is always a valid index, so a poisoned lock is safe to reuse.
        let mut epoch_index = self
            .epoch_index
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Re-read the clock after (possibly) waiting for the lock. Derive the rotation
        // distance from this same reading, so the stored epoch and the number of slots
        // advanced agree with each other.
        let now_ms = self.anchor.elapsed().as_millis() as u64;
        if self
            .epoch
            .compare_exchange(cur_epoch, now_ms, Ordering::AcqRel, Ordering::Acquire)
            .is_err()
        {
            // Another thread rotated after `cur_epoch` was loaded; skip rather than rotate twice.
            return;
        }
        self.rotate(&mut epoch_index, now_ms.saturating_sub(cur_epoch));
    }

    /// Commits `current` into slot `*epoch_index`, zeroes the slots aged out by `delta_ms`,
    /// and advances `*epoch_index` to the next commit slot.
    ///
    /// Advances by `ceil(delta_ms / bucket_window_ms)` slots, and always by at least one. At
    /// most `NUM_BUCKETS` slots are cleared, so the cost is bounded however long the counter
    /// sat idle. Callers are expected to hold the `epoch_index` lock.
    fn rotate(&self, epoch_index: &mut usize, delta_ms: u64) {
        let to_commit = self.current.swap(0, Ordering::SeqCst);
        self.buckets[*epoch_index].store(to_commit, Ordering::SeqCst);

        // ceil(delta / width) without overflow; evaluates to 1 for any delta <= width.
        let steps = delta_ms.saturating_sub(1) / self.bucket_window_ms + 1;
        // After NUM_BUCKETS clears every slot is already zero, so further clears are no-ops.
        let to_clear = steps.min(NUM_BUCKETS as u64) as usize;
        for offset in 1..=to_clear {
            self.buckets[(*epoch_index + offset) % NUM_BUCKETS].store(0, Ordering::SeqCst);
        }
        *epoch_index = (*epoch_index + (steps % NUM_BUCKETS as u64) as usize) % NUM_BUCKETS;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;

    // Every test pauses tokio's clock so elapsed time only changes through `time::advance`,
    // which makes slot rotation deterministic.

    /// A counter with a 3 s window, i.e. ten 300 ms slots.
    fn counter() -> WindowedCounter {
        WindowedCounter::new(time::Duration::from_secs(3))
    }

    #[tokio::test]
    async fn small_window_big_step() {
        time::pause();
        // 1 ms / 10 rounds down to 0 ms, so the slot width is clamped to the 1 ms minimum.
        let ctr = WindowedCounter::new(time::Duration::from_millis(1));
        ctr.add(1);
        assert_eq!(1, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        ctr.add(1);
        assert_eq!(1, ctr.sum());
    }

    #[tokio::test]
    async fn sum_no_sliding() {
        time::pause();
        let ctr = counter();
        ctr.add(1);
        assert_eq!(1, ctr.sum());
        ctr.add(1);
        assert_eq!(2, ctr.sum());
        ctr.add(3);
        assert_eq!(5, ctr.sum());
    }

    #[tokio::test]
    async fn sliding_short_window() {
        time::pause();
        let ctr = counter();
        ctr.add(1);
        assert_eq!(1, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(1, ctr.sum());
        ctr.add(2);
        assert_eq!(3, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(3, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(2, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(0, ctr.sum());
    }

    #[tokio::test]
    async fn sliding_over_large_window() {
        time::pause();
        let ctr = WindowedCounter::new(Duration::from_secs(20));
        for i in 0..21 {
            ctr.add(i % 3);
            time::advance(Duration::from_secs(1)).await;
        }
        assert_eq!(20, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(18, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(18, ctr.sum());
        time::advance(Duration::from_secs(5)).await;
        assert_eq!(12, ctr.sum());
        ctr.add(1);
        time::advance(Duration::from_secs(10)).await;
        assert_eq!(3, ctr.sum());
    }

    #[tokio::test]
    async fn sliding_past_empty_buckets() {
        time::pause();
        let ctr = counter();
        ctr.add(1);
        assert_eq!(1, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        ctr.add(2);
        assert_eq!(3, ctr.sum());
        time::advance(Duration::from_secs(1)).await;
        ctr.add(1);
        assert_eq!(4, ctr.sum());
        time::advance(Duration::from_secs(2)).await;
        assert_eq!(1, ctr.sum());
        time::advance(Duration::from_secs(100)).await;
        assert_eq!(0, ctr.sum());
        ctr.add(100);
        time::advance(Duration::from_secs(1)).await;
        assert_eq!(100, ctr.sum());
        ctr.add(100);
        time::advance(Duration::from_secs(1)).await;
        ctr.add(100);
        assert_eq!(300, ctr.sum());
        time::advance(Duration::from_secs(100)).await;
        assert_eq!(0, ctr.sum());
    }
}