#![doc = include_str!("../README.md")]
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
/// A single ring-buffer cell: the start time of the interval it currently
/// represents and the number of registrations counted for that interval.
#[derive(Debug)]
struct Slot {
    /// Start time of the interval held by this slot (always a multiple of the
    /// slot size, since `register` only stores aligned values).
    interval_start: AtomicU64,
    /// Registrations recorded for `interval_start`.
    counter: AtomicU32,
}

impl Slot {
    /// Creates an empty slot for the interval starting at time 0.
    fn new() -> Self {
        Self {
            interval_start: AtomicU64::new(0),
            counter: AtomicU32::new(0),
        }
    }
}

/// Lock-free counter of invocations over a sliding window of recent time.
///
/// Time is split into intervals of `2^slot_size_exp` units (the unit is whatever
/// callers pass to [`register`](Self::register)). The counter keeps the most recent
/// `2^slot_count_exp` intervals in a ring buffer, one slot per interval.
///
/// Under concurrent use counts are approximate: when a slot is recycled for a newer
/// interval its counter is reset with a plain store, so increments racing with that
/// reset can be lost.
#[derive(Debug)]
pub struct InvocationCounter {
    /// Ring buffer of `2^slot_count_exp` slots.
    slots: Box<[Slot]>,
    /// Base-2 logarithm of the number of slots.
    slot_count_exp: u8,
    /// Base-2 logarithm of the width of one interval, in time units.
    slot_size_exp: u8,
    /// Largest timestamp passed to `register`; anchors the newest end of the window.
    max_current_time: AtomicU64,
}

impl InvocationCounter {
    /// Creates a counter with `2^slot_count_exp` slots, each covering
    /// `2^slot_size_exp` time units.
    ///
    /// # Panics
    ///
    /// Panics if `slot_size_exp >= 64` (every shift by it would overflow a `u64`)
    /// or if `2^slot_count_exp` does not fit in a `usize`.
    pub fn new(slot_count_exp: u8, slot_size_exp: u8) -> Self {
        assert!(
            u32::from(slot_size_exp) < u64::BITS,
            "slot_size_exp must be less than 64"
        );
        let slot_count = 1usize
            .checked_shl(u32::from(slot_count_exp))
            .expect("slot_count_exp is too large for this platform");
        let slots: Box<[Slot]> = (0..slot_count).map(|_| Slot::new()).collect();
        Self {
            slots,
            slot_count_exp,
            slot_size_exp,
            max_current_time: AtomicU64::new(0),
        }
    }

    /// Records one invocation at `current_time`.
    ///
    /// If the target slot still holds an older interval, it is recycled for the new
    /// interval and its counter restarts at 1. If the slot already holds a *newer*
    /// interval, the registration is dropped: an interval that old is at least one
    /// full ring behind and therefore outside the window `count_in` reports on.
    pub fn register(&self, current_time: u64) {
        let interval = current_time >> self.slot_size_exp;
        let interval_start = interval << self.slot_size_exp;
        // The result is < slots.len(), so the cast back to usize is lossless.
        let slot_index = (interval % (self.slots.len() as u64)) as usize;
        let slot = &self.slots[slot_index];

        let mut seen = slot.interval_start.load(Ordering::Acquire);
        loop {
            if seen == interval_start {
                slot.counter.fetch_add(1, Ordering::Relaxed);
                break;
            }
            if seen > interval_start {
                // Stale, out-of-order timestamp: never clobber newer data.
                break;
            }
            // Only one thread may recycle the slot; losers re-examine the new value.
            match slot.interval_start.compare_exchange_weak(
                seen,
                interval_start,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    slot.counter.store(1, Ordering::Release);
                    break;
                }
                Err(actual) => seen = actual,
            }
        }

        // Unlike a single ignored compare-exchange, this cannot fail to advance.
        self.max_current_time
            .fetch_max(current_time, Ordering::AcqRel);
    }

    /// Returns the base-2 logarithm of the number of slots.
    pub fn slot_count_exp(&self) -> u8 {
        self.slot_count_exp
    }

    /// Returns the base-2 logarithm of the interval width, in time units.
    pub fn slot_size_exp(&self) -> u8 {
        self.slot_size_exp
    }

    /// Counts registrations in intervals that overlap `[start_time, end_time)`.
    ///
    /// Granularity is whole intervals: any interval overlapping the requested range
    /// is counted in full. Only intervals still inside the ring window are included.
    /// The window is the `2^slot_count_exp` intervals ending with the one that holds
    /// the largest registered time. Returns 0 when `start_time >= end_time`. The sum
    /// saturates at `u32::MAX`.
    pub fn count_in(&self, start_time: u64, end_time: u64) -> u32 {
        if start_time >= end_time {
            return 0;
        }
        let shift = self.slot_size_exp;

        // Work in interval indices with inclusive bounds so that nothing overflows,
        // even for timestamps at the top of the u64 range.
        let newest = self.max_current_time.load(Ordering::Acquire) >> shift;
        let oldest = newest.saturating_sub(self.slots.len() as u64 - 1);
        let first = (start_time >> shift).max(oldest);
        // end_time > start_time >= 0, so end_time - 1 cannot underflow.
        let last = ((end_time - 1) >> shift).min(newest);
        let wanted = first..=last;

        self.slots
            .iter()
            .filter(|slot| wanted.contains(&(slot.interval_start.load(Ordering::Acquire) >> shift)))
            .fold(0u32, |total, slot| {
                total.saturating_add(slot.counter.load(Ordering::Acquire))
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Asserts `count_in(start, end) == expected` for each `(start, end, expected)` triple.
    fn expect_counts(counter: &InvocationCounter, cases: &[(u64, u64, u32)]) {
        for &(start, end, expected) in cases {
            assert_eq!(counter.count_in(start, end), expected);
        }
    }

    #[test]
    fn test_basic_functionality() {
        let counter = InvocationCounter::new(2, 3);
        for &time in &[0, 1, 8, 16] {
            counter.register(time);
        }
        assert_eq!(counter.count_in(0, 16 + 1), 4);
        assert_eq!(counter.count_in(100 - 32, 100 + 1), 0);
    }

    /// Two slots of width 4: registers times 0..=16 one by one and checks the
    /// window after each registration.
    #[test]
    fn test_count_in() {
        let counter = InvocationCounter::new(1, 2);

        counter.register(0);
        expect_counts(&counter, &[(0, 1, 1), (0, 4, 1), (0, 8, 1), (4, 8, 0)]);

        counter.register(1);
        expect_counts(&counter, &[(0, 4, 2), (0, 8, 2), (4, 8, 0)]);

        counter.register(2);
        expect_counts(&counter, &[(0, 4, 3), (4, 8, 0)]);

        counter.register(3);
        expect_counts(&counter, &[(0, 4, 4), (4, 8, 0)]);

        counter.register(4);
        expect_counts(&counter, &[(0, 4, 4), (4, 8, 1), (0, 8, 5)]);

        counter.register(5);
        expect_counts(&counter, &[(0, 4, 4), (4, 8, 2), (0, 8, 6)]);

        counter.register(6);
        expect_counts(&counter, &[(0, 4, 4), (4, 8, 3), (0, 8, 7)]);

        counter.register(7);
        expect_counts(
            &counter,
            &[(0, 4, 4), (4, 8, 4), (0, 8, 8), (4, 12, 4), (12, 16, 0)],
        );

        counter.register(8);
        expect_counts(
            &counter,
            &[(0, 4, 0), (4, 8, 4), (8, 12, 1), (4, 12, 5), (8, 9, 1), (0, 12, 5)],
        );

        counter.register(9);
        expect_counts(
            &counter,
            &[(0, 4, 0), (4, 8, 4), (8, 12, 2), (4, 12, 6), (8, 10, 2), (0, 16, 6)],
        );

        counter.register(10);
        expect_counts(
            &counter,
            &[(4, 8, 4), (8, 12, 3), (4, 12, 7), (8, 11, 3), (0, 16, 7)],
        );

        counter.register(11);
        expect_counts(
            &counter,
            &[(4, 8, 4), (8, 12, 4), (4, 12, 8), (8, 12, 4), (10, 12, 4), (0, 16, 8)],
        );

        counter.register(12);
        expect_counts(
            &counter,
            &[(4, 8, 0), (8, 12, 4), (12, 16, 1), (8, 16, 5), (0, 16, 5), (12, 13, 1)],
        );

        counter.register(13);
        expect_counts(
            &counter,
            &[(8, 12, 4), (12, 16, 2), (8, 16, 6), (0, 16, 6), (12, 14, 2)],
        );

        counter.register(14);
        expect_counts(
            &counter,
            &[(8, 12, 4), (12, 16, 3), (8, 16, 7), (0, 16, 7), (12, 15, 3)],
        );

        counter.register(15);
        expect_counts(
            &counter,
            &[(8, 12, 4), (12, 16, 4), (8, 16, 8), (0, 16, 8), (12, 16, 4)],
        );

        counter.register(16);
        expect_counts(
            &counter,
            &[(8, 12, 0), (12, 16, 4), (16, 20, 1), (12, 20, 5), (0, 20, 5), (16, 17, 1)],
        );
    }

    #[test]
    fn test_slot_reuse() {
        let counter = InvocationCounter::new(2, 2);
        counter.register(0);
        // Time 16 maps to the same slot as time 0 (interval 4 % 4 == 0).
        counter.register(16);
        assert_eq!(counter.count_in(0, 17), 1);
    }

    #[test]
    fn test_concurrent_access() {
        let num_threads = 4;
        let registrations_per_thread = 100;
        let counter = Arc::new(InvocationCounter::new(3, 6));
        let workers: Vec<_> = (0..num_threads)
            .map(|thread_id| {
                let shared = Arc::clone(&counter);
                thread::spawn(move || {
                    for i in 0..registrations_per_thread {
                        shared.register(thread_id as u64 * 10 + i as u64);
                    }
                })
            })
            .collect();
        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        let count = counter.count_in(0, 399);
        // Racing slot resets can drop increments, so only the bounds are checked.
        assert!(count > 0);
        assert!(count <= num_threads * registrations_per_thread);
    }

    #[test]
    fn test_edge_cases() {
        let counter = InvocationCounter::new(2, 3);
        // An empty query range counts nothing.
        assert_eq!(counter.count_in(0, 0), 0);

        counter.register(0);
        assert_eq!(counter.count_in(0, 1), 1);

        let large_time = 1_000_000u64;
        counter.register(large_time);
        assert_eq!(counter.count_in(large_time, large_time + 1), 1);
    }
}