use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Instant;
/// Number of counter cells in each row (2^16).
pub const WIDTH: usize = 1 << 16;
/// Number of rows; each row is addressed through its own pair of hash parameters.
pub const DEPTH: usize = 4;
/// Bits of a packed cell that hold the epoch timestamp (the high bits, see `pack`).
const TS_BITS: u32 = 8;
/// Bits of a packed cell that hold the counter (the low bits, see `pack`).
const COUNT_BITS: u32 = 24;
/// Mask selecting the counter bits of a packed cell.
const COUNT_MASK: u32 = (1 << COUNT_BITS) - 1;
/// Largest value a counter can hold; `ClTds::increment` saturates here.
pub const MAX_COUNT: u32 = COUNT_MASK;
/// Upper bound on decay steps: shifting a `COUNT_BITS`-wide counter right by
/// this many bits always yields zero.
const MAX_DECAY: u32 = COUNT_BITS;
/// Mask selecting the low `TS_BITS` bits of an epoch, i.e. the stored timestamp.
const TS_MASK: u32 = (1 << TS_BITS) - 1;
/// Packs an epoch timestamp and a counter into a single 32-bit cell.
///
/// The timestamp occupies the high `TS_BITS` bits and the counter the low
/// `COUNT_BITS` bits. Bits of either argument that do not fit its field are
/// silently discarded.
#[inline(always)]
pub fn pack(timestamp: u32, count: u32) -> u32 {
    let ts_field = (timestamp & TS_MASK) << COUNT_BITS;
    let count_field = count & COUNT_MASK;
    ts_field | count_field
}
/// Splits a packed cell into `(timestamp, count)`.
///
/// Inverse of [`pack`] for values that fit their fields.
#[inline(always)]
pub fn unpack(cell: u32) -> (u32, u32) {
    (cell >> COUNT_BITS, cell & COUNT_MASK)
}
/// Returns how many halving steps separate a cell stamped `cell_ts` from
/// `current_epoch`.
///
/// Only the low `TS_BITS` bits of the epoch are compared, using wrapping
/// subtraction, so the distance is taken modulo `2^TS_BITS`. The result is
/// capped at `MAX_DECAY`.
#[inline(always)]
pub fn decay_steps(cell_ts: u32, current_epoch: u64) -> u32 {
    let now = current_epoch as u32 & TS_MASK;
    let elapsed = now.wrapping_sub(cell_ts) & TS_MASK;
    std::cmp::min(elapsed, MAX_DECAY)
}
/// Halves `count` once per step, rounding down each time.
///
/// Any step count of `MAX_DECAY` or more yields zero.
#[inline(always)]
fn apply_decay(count: u32, steps: u32) -> u32 {
    match steps {
        s if s < MAX_DECAY => count >> s,
        _ => 0,
    }
}
/// Fixed per-row multipliers used by `ClTds::new_deterministic`.
///
/// NOTE(review): the third value is even, whereas `random_hash_params` forces
/// every multiplier to be odd — confirm whether that is intentional.
const DEFAULT_HASH_A: [u64; DEPTH] = [
0x9e3779b97f4a7c15, 0x517cc1b727220a95, 0x6c62272e07bb0142, 0xbf58476d1ce4e5b9, ];
/// Fixed per-row additive offsets used by `ClTds::new_deterministic`.
const DEFAULT_HASH_B: [u64; DEPTH] = [
0xd2a98b26625eee7b,
0x94d049bb133111eb,
0xc4ceb9fe1a85ec53,
0xe7037ed1a0b428db,
];
/// Draws a fresh `(multipliers, offsets)` pair of hash parameters, one entry
/// per row.
///
/// Each value is the output of a hasher built from its own newly created
/// `RandomState`, so the standard library supplies the randomness. Every
/// multiplier has its lowest bit forced to 1.
fn random_hash_params() -> ([u64; DEPTH], [u64; DEPTH]) {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    // One fresh `RandomState` per drawn value, fed a small distinguishing tag.
    let draw = |tag: usize| -> u64 {
        let mut hasher = RandomState::new().build_hasher();
        hasher.write_usize(tag);
        hasher.finish()
    };

    let mut multipliers = [0u64; DEPTH];
    let mut offsets = [0u64; DEPTH];
    for (row, (mul, off)) in multipliers.iter_mut().zip(offsets.iter_mut()).enumerate() {
        *mul = draw(row) | 1;
        *off = draw(row + DEPTH);
    }
    (multipliers, offsets)
}
#[inline(always)]
fn compute_hash(id: u64, row: usize, hash_a: &[u64; DEPTH], hash_b: &[u64; DEPTH]) -> usize {
let h = hash_a[row].wrapping_mul(id).wrapping_add(hash_b[row]);
(h >> 48) as usize }
/// One row of the counter matrix: `WIDTH` atomically updated 32-bit cells.
///
/// The struct is aligned to 64 bytes (presumably one cache line — TODO confirm
/// for the target platform).
#[repr(align(64))]
pub struct Row([AtomicU32; WIDTH]);
/// A `DEPTH` x `WIDTH` matrix of packed (timestamp, count) cells whose counts
/// are halved once per elapsed epoch when they are next read or updated.
///
/// All methods take `&self`; cells and the manual epoch are atomics.
pub struct ClTds {
/// Heap-allocated counter matrix, one `Row` per hash function.
rows: Box<[Row; DEPTH]>,
/// Manually advanced epoch counter; consulted only when `created_at` is `None`.
epoch: AtomicU64,
/// Construction time; when `Some`, the epoch is derived from elapsed time.
created_at: Option<Instant>,
/// Length of one epoch in milliseconds for time-derived epochs (0 otherwise).
epoch_interval_ms: u64,
/// Per-row hash multipliers.
hash_a: [u64; DEPTH],
/// Per-row hash offsets.
hash_b: [u64; DEPTH],
}
impl Default for ClTds {
fn default() -> Self {
Self::new()
}
}
impl ClTds {
    /// Size in bytes of the serialized header: epoch plus both hash arrays.
    const HEADER_LEN: usize = 8 + DEPTH * 8 * 2;
    /// Size in bytes of the serialized counter matrix.
    const MATRIX_LEN: usize = DEPTH * WIDTH * 4;

    /// Creates a manual-epoch sketch with randomly drawn hash parameters.
    ///
    /// The epoch only advances when [`ClTds::tick_epoch`] is called.
    pub fn new() -> Self {
        let (a, b) = random_hash_params();
        Self::alloc(None, 0, a, b)
    }

    /// Creates a manual-epoch sketch that uses the fixed default hash
    /// parameters, so two such sketches map ids to the same cells.
    pub fn new_deterministic() -> Self {
        Self::alloc(None, 0, DEFAULT_HASH_A, DEFAULT_HASH_B)
    }

    /// Creates a sketch whose epoch is derived from wall-clock time: it
    /// advances by one every `interval_ms` milliseconds since construction.
    ///
    /// # Panics
    ///
    /// Panics if `interval_ms` is zero.
    pub fn with_epoch_interval(interval_ms: u64) -> Self {
        assert!(interval_ms > 0, "epoch interval must be > 0");
        let (a, b) = random_hash_params();
        Self::alloc(Some(Instant::now()), interval_ms, a, b)
    }

    /// Allocates a zeroed counter matrix directly on the heap and assembles
    /// the sketch around it. The manual epoch starts at 0.
    fn alloc(
        created_at: Option<Instant>,
        epoch_interval_ms: u64,
        hash_a: [u64; DEPTH],
        hash_b: [u64; DEPTH],
    ) -> Self {
        let layout = std::alloc::Layout::new::<[Row; DEPTH]>();
        // SAFETY:
        // - `layout` has non-zero size (DEPTH * WIDTH * 4 bytes), as
        //   `alloc_zeroed` requires.
        // - `Row` wraps an array of `AtomicU32`, which has the same in-memory
        //   representation as `u32`, so the all-zero bit pattern is a valid
        //   `[Row; DEPTH]`.
        // - The pointer comes from the global allocator with exactly the
        //   layout of `[Row; DEPTH]`, which is what `Box::from_raw` needs in
        //   order to free it later.
        let rows = unsafe {
            let ptr = std::alloc::alloc_zeroed(layout).cast::<[Row; DEPTH]>();
            if ptr.is_null() {
                std::alloc::handle_alloc_error(layout);
            }
            Box::from_raw(ptr)
        };
        ClTds {
            rows,
            epoch: AtomicU64::new(0),
            created_at,
            epoch_interval_ms,
            hash_a,
            hash_b,
        }
    }

    /// Returns the current epoch: elapsed time divided by the interval when
    /// the sketch is time-driven, otherwise the manual counter.
    #[inline(always)]
    fn current_epoch(&self) -> u64 {
        match self.created_at {
            Some(t) => t.elapsed().as_millis() as u64 / self.epoch_interval_ms,
            None => self.epoch.load(Ordering::Relaxed),
        }
    }

    /// Returns an epoch that is not behind a cell stamped `cell_ts`.
    ///
    /// `epoch` may have been read before another thread advanced the epoch and
    /// re-stamped the cell. Seen through the wrapping 8-bit difference, such a
    /// cell looks almost a full cycle old and would be decayed to zero. When
    /// the difference falls in the upper half of the timestamp range the epoch
    /// is read again; the caller's acquire load of the cell guarantees that
    /// this fresh read is at least as new as the epoch the writer used. A cell
    /// that really is that old still decays to zero afterwards.
    #[inline]
    fn epoch_not_behind(&self, cell_ts: u32, epoch: u64) -> u64 {
        let gap = (epoch as u32).wrapping_sub(cell_ts) & TS_MASK;
        if gap > TS_MASK / 2 {
            self.current_epoch()
        } else {
            epoch
        }
    }

    /// Column index of `id` in `row`.
    #[inline(always)]
    fn hash(&self, id: u64, row: usize) -> usize {
        compute_hash(id, row, &self.hash_a, &self.hash_b)
    }

    /// Records one occurrence of `id`.
    ///
    /// In every row the addressed cell is first decayed to the current epoch,
    /// then incremented (saturating at `MAX_COUNT`) and re-stamped, all in a
    /// single compare-and-swap that is retried on contention.
    pub fn increment(&self, id: u64) {
        let mut epoch = self.current_epoch();
        for (row_idx, row) in self.rows.iter().enumerate() {
            let cell = &row.0[self.hash(id, row_idx)];
            // Acquire pairs with the AcqRel CAS of other writers so that a
            // refreshed epoch can never be older than the cell's timestamp.
            let mut old = cell.load(Ordering::Acquire);
            loop {
                let (old_ts, old_count) = unpack(old);
                epoch = self.epoch_not_behind(old_ts, epoch);
                let decayed = apply_decay(old_count, decay_steps(old_ts, epoch));
                let new_count = (decayed + 1).min(MAX_COUNT);
                let epoch_low = (epoch & u64::from(TS_MASK)) as u32;
                let new_val = pack(epoch_low, new_count);
                match cell.compare_exchange_weak(
                    old,
                    new_val,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => break,
                    // Lost the race (or failed spuriously): retry on the
                    // value that is actually stored.
                    Err(current) => old = current,
                }
            }
        }
    }

    /// Returns the estimated decayed count of `id`: the minimum over all rows
    /// of the addressed cell's count after decaying it to the current epoch.
    pub fn query(&self, id: u64) -> u32 {
        let mut epoch = self.current_epoch();
        let mut min_count = u32::MAX;
        for (row_idx, row) in self.rows.iter().enumerate() {
            let val = row.0[self.hash(id, row_idx)].load(Ordering::Acquire);
            let (ts, count) = unpack(val);
            epoch = self.epoch_not_behind(ts, epoch);
            min_count = min_count.min(apply_decay(count, decay_steps(ts, epoch)));
        }
        min_count
    }

    /// Advances the manual epoch counter by one.
    ///
    /// On a time-driven sketch the counter is still incremented, but the epoch
    /// reported by [`ClTds::epoch`] keeps following elapsed time.
    pub fn tick_epoch(&self) {
        self.epoch.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns the current epoch.
    pub fn epoch(&self) -> u64 {
        self.current_epoch()
    }

    /// Returns `true` when the epoch is derived from elapsed time.
    pub fn is_auto_epoch(&self) -> bool {
        self.created_at.is_some()
    }

    /// Returns the size of the counter matrix in bytes.
    pub fn memory_bytes(&self) -> usize {
        DEPTH * WIDTH * std::mem::size_of::<AtomicU32>()
    }

    /// Returns `(epsilon, delta, WIDTH, DEPTH)` with `epsilon = e / WIDTH` and
    /// `delta = exp(-DEPTH)`.
    pub fn algorithm_parameters() -> (f64, f64, usize, usize) {
        let epsilon = std::f64::consts::E / WIDTH as f64;
        let delta = (-(DEPTH as f64)).exp();
        (epsilon, delta, WIDTH, DEPTH)
    }

    /// Returns `epsilon * n_effective`, using the `epsilon` reported by
    /// [`ClTds::algorithm_parameters`].
    pub fn error_bound(n_effective: u64) -> f64 {
        let (epsilon, ..) = Self::algorithm_parameters();
        epsilon * n_effective as f64
    }

    /// Serializes the sketch to little-endian bytes: the current epoch, the
    /// `DEPTH` multipliers, the `DEPTH` offsets, then every cell row by row.
    ///
    /// Cells are read one at a time, so concurrent updates may be only partly
    /// reflected in the output.
    pub fn to_bytes(&self) -> Vec<u8> {
        let epoch = self.current_epoch();
        let mut buf = Vec::with_capacity(Self::HEADER_LEN + Self::MATRIX_LEN);
        buf.extend_from_slice(&epoch.to_le_bytes());
        for word in self.hash_a.iter().chain(self.hash_b.iter()) {
            buf.extend_from_slice(&word.to_le_bytes());
        }
        for row in self.rows.iter() {
            for cell in row.0.iter() {
                buf.extend_from_slice(&cell.load(Ordering::Relaxed).to_le_bytes());
            }
        }
        buf
    }

    /// Rebuilds a sketch from the output of [`ClTds::to_bytes`].
    ///
    /// The result is always a manual-epoch sketch whose counter starts at the
    /// serialized epoch. Returns `None` when `bytes` does not have exactly the
    /// expected length.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        if bytes.len() != Self::HEADER_LEN + Self::MATRIX_LEN {
            return None;
        }
        let (header, matrix) = bytes.split_at(Self::HEADER_LEN);

        // `chunks_exact(8)` only yields 8-byte chunks, so the copy cannot fail.
        let mut words = header.chunks_exact(8).map(|chunk| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(chunk);
            u64::from_le_bytes(raw)
        });
        let epoch = words.next()?;
        let mut hash_a = [0u64; DEPTH];
        for slot in hash_a.iter_mut() {
            *slot = words.next()?;
        }
        let mut hash_b = [0u64; DEPTH];
        for slot in hash_b.iter_mut() {
            *slot = words.next()?;
        }

        let sketch = Self::alloc(None, 0, hash_a, hash_b);
        sketch.epoch.store(epoch, Ordering::Relaxed);
        let cells = sketch.rows.iter().flat_map(|row| row.0.iter());
        for (cell, chunk) in cells.zip(matrix.chunks_exact(4)) {
            let mut raw = [0u8; 4];
            raw.copy_from_slice(chunk);
            cell.store(u32::from_le_bytes(raw), Ordering::Relaxed);
        }
        Some(sketch)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Records `times` occurrences of `id` in `sketch`.
    fn bump(sketch: &ClTds, id: u64, times: usize) {
        (0..times).for_each(|_| sketch.increment(id));
    }

    /// Advances the manual epoch of `sketch` by `epochs`.
    fn advance(sketch: &ClTds, epochs: usize) {
        (0..epochs).for_each(|_| sketch.tick_epoch());
    }

    // ---- cell packing -------------------------------------------------

    #[test]
    fn pack_unpack_roundtrip() {
        let timestamps = [0, 1, 127, 255];
        let counts = [0, 1, 1000, MAX_COUNT];
        for &ts in timestamps.iter() {
            for &count in counts.iter() {
                let (got_ts, got_count) = unpack(pack(ts, count));
                assert_eq!(got_ts, ts, "ts mismatch");
                assert_eq!(got_count, count, "count mismatch");
            }
        }
    }

    #[test]
    fn pack_truncates_overflow() {
        let (wrapped_ts, _) = unpack(pack(256, 0));
        assert_eq!(wrapped_ts, 0);

        let (_, wrapped_count) = unpack(pack(0, MAX_COUNT + 1));
        assert_eq!(wrapped_count, 0);
    }

    // ---- decay step computation ----------------------------------------

    #[test]
    fn decay_steps_same_epoch() {
        for (cell_ts, epoch) in [(5, 5), (0, 0), (255, 255)] {
            assert_eq!(decay_steps(cell_ts, epoch), 0);
        }
    }

    #[test]
    fn decay_steps_simple_gap() {
        for (cell_ts, epoch, want) in [(0, 1, 1), (0, 5, 5), (10, 15, 5)] {
            assert_eq!(decay_steps(cell_ts, epoch), want);
        }
    }

    #[test]
    fn decay_steps_wraparound() {
        let steps = decay_steps(250, 2);
        assert_eq!(steps, 8);
    }

    #[test]
    fn decay_steps_capped_at_count_bits() {
        let steps = decay_steps(0, 100);
        assert_eq!(steps, 24);
    }

    // ---- hashing ---------------------------------------------------------

    #[test]
    fn hash_produces_valid_indices() {
        let ids = [0, 1, u64::MAX, 0xDEADBEEF, 42];
        for row in 0..DEPTH {
            for &id in ids.iter() {
                let idx = compute_hash(id, row, &DEFAULT_HASH_A, &DEFAULT_HASH_B);
                assert!(idx < WIDTH, "hash out of range: {}", idx);
            }
        }
    }

    #[test]
    fn hash_different_rows_differ() {
        let id = 0xDEADBEEF_u64;
        let distinct: HashSet<usize> = (0..DEPTH)
            .map(|row| compute_hash(id, row, &DEFAULT_HASH_A, &DEFAULT_HASH_B))
            .collect();
        assert!(distinct.len() >= 2, "hash functions not independent enough");
    }

    #[test]
    fn hash_distribution_uniformity() {
        let mut histogram = vec![0u32; WIDTH];
        for id in 0..100_000u64 {
            let key = id.wrapping_mul(0x12345);
            let slot = compute_hash(key, 0, &DEFAULT_HASH_A, &DEFAULT_HASH_B);
            histogram[slot] += 1;
        }
        let max = histogram.iter().copied().max().unwrap();
        let expected = 100_000.0 / WIDTH as f64;
        assert!(
            f64::from(max) <= expected * 10.0,
            "distribution too skewed: max={}, expected={:.1}",
            max,
            expected
        );
    }

    // ---- basic counting --------------------------------------------------

    #[test]
    fn basic_increment_and_query() {
        let sketch = ClTds::new_deterministic();
        bump(&sketch, 42, 100);
        let count = sketch.query(42);
        assert_eq!(count, 100, "expected 100, got {}", count);
    }

    #[test]
    fn query_unseen_item_returns_zero() {
        let sketch = ClTds::new_deterministic();
        for id in [1u64, 2] {
            sketch.increment(id);
        }
        let count = sketch.query(999);
        assert_eq!(count, 0, "unseen item should have count 0, got {}", count);
    }

    #[test]
    fn multiple_items_independent() {
        let sketch = ClTds::new_deterministic();
        bump(&sketch, 100, 3);
        bump(&sketch, 200, 1);
        assert_eq!(sketch.query(100), 3);
        assert_eq!(sketch.query(200), 1);
    }

    #[test]
    fn counter_saturation() {
        let sketch = ClTds::new_deterministic();
        let id = 1u64;
        // Pre-load the row-0 cell one below the ceiling.
        let cell = &sketch.rows[0].0[sketch.hash(id, 0)];
        cell.store(pack(0, MAX_COUNT - 1), Ordering::Relaxed);
        sketch.increment(id);
        let (_, count) = unpack(cell.load(Ordering::Relaxed));
        assert_eq!(count, MAX_COUNT);
    }

    // ---- decay -------------------------------------------------------------

    #[test]
    fn decay_halves_count() {
        let sketch = ClTds::new_deterministic();
        let id = 1u64;
        bump(&sketch, id, 1000);
        assert_eq!(sketch.query(id), 1000);
        advance(&sketch, 1);
        assert_eq!(sketch.query(id), 500);
        advance(&sketch, 1);
        assert_eq!(sketch.query(id), 250);
    }

    #[test]
    fn decay_multiple_epochs() {
        let sketch = ClTds::new_deterministic();
        let id = 7u64;
        bump(&sketch, id, 1024);
        advance(&sketch, 5);
        assert_eq!(sketch.query(id), 32);
    }

    #[test]
    fn decay_to_zero() {
        let sketch = ClTds::new_deterministic();
        let id = 1u64;
        bump(&sketch, id, 1);
        assert_eq!(sketch.query(id), 1);
        advance(&sketch, 1);
        assert_eq!(sketch.query(id), 0);
    }

    #[test]
    fn decay_with_interleaved_inserts() {
        let sketch = ClTds::new_deterministic();
        let id = 1u64;
        bump(&sketch, id, 100);
        advance(&sketch, 1);
        bump(&sketch, id, 50);
        // 100 halved to 50, plus the 50 new occurrences.
        assert_eq!(sketch.query(id), 100);
    }

    // ---- claim 1: estimates never undercount, overcount is bounded ---------

    #[test]
    fn claim1_no_undercounting() {
        let sketch = ClTds::new_deterministic();
        let target = 42u64;
        bump(&sketch, target, 500);
        (1000..2000u64).for_each(|noise| sketch.increment(noise));
        let est = sketch.query(target);
        assert!(
            est >= 500,
            "Claim 1 violation: undercounting! got {} < 500",
            est
        );
    }

    #[test]
    fn claim1_overcount_bounded() {
        let sketch = ClTds::new_deterministic();
        let heavy_hitter = 42u64;
        let true_count = 1000u32;
        let n_noise = 100_000u64;

        bump(&sketch, heavy_hitter, true_count as usize);
        for offset in 0..n_noise {
            sketch.increment(offset + 10_000);
        }

        let n_total = true_count as f64 + n_noise as f64;
        let epsilon = std::f64::consts::E / WIDTH as f64;
        let allowed = epsilon * n_total * 10.0;
        let overcount = sketch.query(heavy_hitter) as f64 - true_count as f64;
        assert!(
            overcount <= allowed,
            "Claim 1 error too large: overcount={:.0}, bound={:.0}",
            overcount,
            allowed
        );
    }

    // ---- claim 2: few false positives ---------------------------------------

    #[test]
    fn claim2_false_positive_rate() {
        let sketch = ClTds::new_deterministic();
        for id in 0..100u64 {
            bump(&sketch, id, 1000);
        }
        let _n_total = 100 * 1000;
        let threshold = 50u32;
        let n_queries = 10_000u64;
        let false_positives = (1_000_000..1_000_000 + n_queries)
            .filter(|&id| sketch.query(id) > threshold)
            .count();
        let fp_rate = false_positives as f64 / n_queries as f64;
        assert!(
            fp_rate <= 0.05,
            "Claim 2 violation: false positive rate {:.2}% > 5%",
            fp_rate * 100.0
        );
    }

    // ---- claim 3: lazy decay equals step-by-step decay ------------------------

    #[test]
    fn claim3_lazy_equals_full_decay() {
        let initials = [1, 7, 100, 1000, 65535, MAX_COUNT];
        for &initial in initials.iter() {
            for total_steps in 0..=24u32 {
                // Halve once per step, the way an eager sweep would.
                let full = (0..total_steps).fold(initial, |acc, _| acc >> 1);
                let lazy = apply_decay(initial, total_steps);
                assert_eq!(
                    full, lazy,
                    "Claim 3 violation! initial={}, steps={}: full={}, lazy={}",
                    initial, total_steps, full, lazy
                );
            }
        }
    }

    #[test]
    fn claim3_lazy_decay_in_sketch() {
        let sketch_lazy = ClTds::new_deterministic();
        let id = 99u64;
        bump(&sketch_lazy, id, 1024);
        advance(&sketch_lazy, 5);
        let lazy_result = sketch_lazy.query(id);
        let full_result = 1024u32 >> 5;
        assert_eq!(
            lazy_result, full_result,
            "Claim 3 in-sketch: lazy={}, full={}",
            lazy_result, full_result
        );
    }

    // ---- concurrency ------------------------------------------------------------

    #[test]
    fn concurrent_increments() {
        let sketch = ClTds::new_deterministic();
        let id = 42u64;
        let threads = 4usize;
        let per_thread = 10_000usize;
        std::thread::scope(|s| {
            for _ in 0..threads {
                s.spawn(|| bump(&sketch, id, per_thread));
            }
        });
        let total = sketch.query(id);
        let expected = (threads * per_thread) as u32;
        assert_eq!(
            total, expected,
            "Thread safety: got {}, expected {}",
            total, expected
        );
    }

    // ---- bookkeeping ---------------------------------------------------------------

    #[test]
    fn matrix_size_is_1mb() {
        let size = ClTds::new_deterministic().memory_bytes();
        assert_eq!(
            size, 1_048_576,
            "Matrix should be exactly 1 MB, got {} bytes",
            size
        );
    }

    #[test]
    fn epoch_advances() {
        let sketch = ClTds::new_deterministic();
        assert_eq!(sketch.epoch(), 0);
        advance(&sketch, 1);
        assert_eq!(sketch.epoch(), 1);
        advance(&sketch, 2);
        assert_eq!(sketch.epoch(), 3);
    }

    #[test]
    fn claim2_zipf_stress_test() {
        let sketch = ClTds::new_deterministic();
        let n_heavy = 10u64;
        let n_mice = 990u64;
        let per_heavy = 8_000u64;
        let per_mouse = 20u64;

        for id in 1..=n_heavy {
            bump(&sketch, id, per_heavy as usize);
        }
        for mouse in 0..n_mice {
            bump(&sketch, mouse + 100_000, per_mouse as usize);
        }

        let threshold = 1000u32;
        for id in 1..=n_heavy {
            let count = sketch.query(id);
            assert!(
                count >= per_heavy as u32,
                "Zipf: heavy hitter {} not detected: count={} < {}",
                id,
                count,
                per_heavy
            );
        }

        let n_test = 50_000u64;
        let fp = (1_000_000..1_000_000 + n_test)
            .filter(|&id| sketch.query(id) > threshold)
            .count();
        let fp_rate = fp as f64 / n_test as f64;
        assert!(
            fp_rate <= 0.02,
            "Zipf Claim 2: false positive rate {:.3}% exceeds 2%",
            fp_rate * 100.0
        );
    }

    // ---- epoch modes -------------------------------------------------------------------

    #[test]
    fn manual_mode_by_default() {
        let sketch = ClTds::new_deterministic();
        assert!(!sketch.is_auto_epoch());
        assert_eq!(sketch.epoch(), 0);
    }

    #[test]
    fn auto_mode_via_constructor() {
        let sketch = ClTds::with_epoch_interval(1000);
        assert!(sketch.is_auto_epoch());
        assert_eq!(sketch.epoch(), 0);
    }

    #[test]
    fn auto_epoch_advances_with_time() {
        let sketch = ClTds::with_epoch_interval(1);
        bump(&sketch, 42, 100);
        std::thread::sleep(Duration::from_millis(15));
        let epoch = sketch.epoch();
        assert!(
            epoch >= 10,
            "auto epoch should advance with time, got {}",
            epoch
        );
    }

    #[test]
    fn auto_epoch_decay_works() {
        let sketch = ClTds::with_epoch_interval(10);
        bump(&sketch, 7, 1024);
        assert_eq!(sketch.query(7), 1024);
        std::thread::sleep(Duration::from_millis(55));
        let count = sketch.query(7);
        assert!(
            (16..=64).contains(&count),
            "auto decay after ~5 epochs: expected ~32, got {}",
            count
        );
    }

    #[test]
    #[should_panic(expected = "epoch interval must be > 0")]
    fn auto_epoch_zero_interval_panics() {
        let _sketch = ClTds::with_epoch_interval(0);
    }

    // ---- hash parameter selection ---------------------------------------------------------

    #[test]
    fn two_sketches_different_hashes() {
        let id = 42u64;
        let columns =
            |sketch: &ClTds| -> Vec<usize> { (0..DEPTH).map(|row| sketch.hash(id, row)).collect() };
        let idx1 = columns(&ClTds::new());
        let idx2 = columns(&ClTds::new());
        assert_ne!(
            idx1, idx2,
            "Two sketches should have different hash mappings"
        );
    }

    #[test]
    fn random_hashes_produce_valid_results() {
        let sketch = ClTds::new();
        bump(&sketch, 42, 100);
        let count = sketch.query(42);
        assert_eq!(count, 100, "random hashes: expected 100, got {}", count);
        assert_eq!(sketch.query(999999), 0, "unseen item should be 0");
    }

    #[test]
    fn deterministic_hashes_reproducible() {
        let s1 = ClTds::new_deterministic();
        let s2 = ClTds::new_deterministic();
        for _ in 0..500 {
            s1.increment(42);
            s2.increment(42);
        }
        assert_eq!(s1.query(42), s2.query(42));
    }

    // ---- reported parameters -----------------------------------------------------------------

    #[test]
    fn algorithm_parameters_correct() {
        let (epsilon, delta, w, d) = ClTds::algorithm_parameters();
        assert_eq!(w, 65536);
        assert_eq!(d, 4);
        let want_epsilon = std::f64::consts::E / 65536.0;
        let want_delta = (-4.0_f64).exp();
        assert!((epsilon - want_epsilon).abs() < 1e-10);
        assert!((delta - want_delta).abs() < 1e-10);
        assert!(epsilon < 0.0001, "ε should be tiny");
        assert!(delta < 0.02, "δ should be < 2%");
    }

    #[test]
    fn error_bound_scales_with_stream() {
        let err_1m = ClTds::error_bound(1_000_000);
        let err_10m = ClTds::error_bound(10_000_000);
        let ratio = err_10m / err_1m;
        assert!((ratio - 10.0).abs() < 0.001);
        assert!(err_1m < 50.0, "error bound for 1M stream should be < 50");
    }

    // ---- larger scenarios -----------------------------------------------------------------------

    #[test]
    fn adversarial_collision_min_filter_holds() {
        let sketch = ClTds::new_deterministic();
        for id in 0..1000u64 {
            bump(&sketch, id, 1000);
        }
        let high_false = (500_000..510_000u64)
            .filter(|&id| sketch.query(id) > 100)
            .count();
        let fp_rate = high_false as f64 / 10_000.0;
        assert!(
            fp_rate < 0.02,
            "Adversarial: false positive rate {:.2}% exceeds 2%",
            fp_rate * 100.0
        );
    }

    #[test]
    fn decay_accuracy_over_many_epochs() {
        let sketch = ClTds::new_deterministic();
        let id = 1u64;
        let initial = 1_000_000u32;
        bump(&sketch, id, initial as usize);
        for epoch in 1..=20u32 {
            advance(&sketch, 1);
            let expected = initial >> epoch;
            let actual = sketch.query(id);
            assert_eq!(
                actual, expected,
                "Epoch {}: expected {}, got {}",
                epoch, expected, actual
            );
        }
    }

    #[test]
    fn concurrent_increment_with_decay() {
        let sketch = ClTds::new_deterministic();
        let id = 99u64;
        let threads = 4usize;
        let per_thread = 5_000usize;
        std::thread::scope(|s| {
            for _ in 0..threads {
                s.spawn(|| bump(&sketch, id, per_thread));
            }
            // A fifth thread advances the epoch three times while the
            // writers are running.
            s.spawn(|| {
                for _ in 0..3 {
                    std::thread::sleep(Duration::from_millis(1));
                    sketch.tick_epoch();
                }
            });
        });
        let count = sketch.query(id);
        let ceiling = (threads * per_thread) as u32;
        assert!(
            count > 0 && count <= ceiling,
            "Concurrent+decay: count {} out of reasonable range",
            count
        );
    }

    #[test]
    fn full_integration_test() {
        let sketch = ClTds::new();

        // Five heavy hitters on top of a thousand light ids.
        for id in 1..=5u64 {
            bump(&sketch, id, 10_000);
        }
        for id in 100..1100u64 {
            bump(&sketch, id, 10);
        }
        for id in 1..=5u64 {
            assert!(
                sketch.query(id) >= 10_000,
                "Integration: heavy hitter {} not detected",
                id
            );
        }

        // Ten epochs later the heavy hitters are nearly forgotten.
        advance(&sketch, 10);
        for id in 1..=5u64 {
            let decayed = sketch.query(id);
            assert!(
                decayed <= 15,
                "Integration: heavy hitter {} not forgotten after decay: {}",
                id,
                decayed
            );
        }

        bump(&sketch, 42, 500);
        assert_eq!(
            sketch.query(42),
            500,
            "New traffic after decay should be exact"
        );

        let (eps, delta, w, d) = ClTds::algorithm_parameters();
        assert!(eps > 0.0 && delta > 0.0 && w == WIDTH && d == DEPTH);
    }
}