use alloc::vec;
use alloc::vec::Vec;
use core::hash::{Hash, Hasher};
use std::hash::DefaultHasher;
use crate::error::{RcfError, RcfResult};
/// Largest `width` (counters per row) that `CountMinSketch::new` accepts: 2^18.
pub const MAX_WIDTH: usize = 1 << 18;
/// Largest `depth` (number of rows) that `CountMinSketch::new` accepts.
pub const MAX_DEPTH: usize = 16;
/// Count-Min sketch over byte-string keys.
///
/// Holds `depth` rows of `width` `u64` counters. Each key is hashed to one
/// counter per row; incrementing bumps all of them and estimating returns the
/// smallest one.
pub struct CountMinSketch {
/// Counter grid: `depth` rows, each with `width` cells, all starting at zero.
table: Vec<Vec<u64>>,
/// One `(a, b)` seed pair per row; `a` is hashed before the key and `b` after it.
seeds: Vec<(u64, u64)>,
/// Number of counters in each row; hash values are reduced modulo this.
width: usize,
/// Number of rows (and of seed pairs).
depth: usize,
/// Saturating sum of every `count` passed to `increment` since construction or the last `reset`.
total: u64,
}
// Debug output is a summary: the dimensions, the running total and the size of
// the counter cells. The counter table and the seeds themselves are not printed.
#[allow(clippy::missing_fields_in_debug)]
impl core::fmt::Debug for CountMinSketch {
    /// Writes `CountMinSketch { width, depth, total, memory_bytes }` to `f`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let footprint = self.memory_bytes();
        let mut summary = f.debug_struct("CountMinSketch");
        summary.field("width", &self.width);
        summary.field("depth", &self.depth);
        summary.field("total", &self.total);
        summary.field("memory_bytes", &footprint);
        summary.finish()
    }
}
impl CountMinSketch {
    /// Creates an empty sketch with `depth` rows of `width` zeroed counters.
    ///
    /// The seed pair of each row depends only on the row's position, so the
    /// seeds are the same for every sketch of a given depth.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::InvalidConfig`] when `width` is zero or greater than
    /// [`MAX_WIDTH`], or when `depth` is zero or greater than [`MAX_DEPTH`].
    /// `width` is validated first.
    pub fn new(width: usize, depth: usize) -> RcfResult<Self> {
        if !(1..=MAX_WIDTH).contains(&width) {
            return Err(RcfError::InvalidConfig(
                alloc::format!("CountMinSketch: width {width} out of (0, {MAX_WIDTH}]").into(),
            ));
        }
        if !(1..=MAX_DEPTH).contains(&depth) {
            return Err(RcfError::InvalidConfig(
                alloc::format!("CountMinSketch: depth {depth} out of (0, {MAX_DEPTH}]").into(),
            ));
        }

        // Multipliers run 1, 2, ..., depth, so the first row uses the two base
        // constants unchanged and later rows use wrapping multiples of them.
        let seeds: Vec<(u64, u64)> = (1_u64..)
            .take(depth)
            .map(|multiplier| {
                (
                    0x517c_c1b7_2722_0a95_u64.wrapping_mul(multiplier),
                    0x6c62_272e_07bb_0142_u64.wrapping_mul(multiplier),
                )
            })
            .collect();
        let table = vec![vec![0_u64; width]; depth];

        Ok(Self {
            table,
            seeds,
            width,
            depth,
            total: 0,
        })
    }

    /// Number of rows in the sketch.
    #[must_use]
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Number of counters in each row.
    #[must_use]
    pub fn width(&self) -> usize {
        self.width
    }

    /// Records `count` occurrences of `key`.
    ///
    /// Adds `count` to the running total and to one counter in every row.
    /// Every addition saturates at `u64::MAX` instead of wrapping.
    #[inline]
    pub fn increment(&mut self, key: &[u8], count: u64) {
        self.total = self.total.saturating_add(count);
        for row in 0..self.depth {
            let col = self.hash_to_col(key, row);
            let cell = &mut self.table[row][col];
            *cell = cell.saturating_add(count);
        }
    }

    /// Returns the estimated count for `key`: the smallest of the counters the
    /// key maps to, one per row.
    ///
    /// Other keys sharing a counter can only raise it, so the estimate is never
    /// below the saturating sum of the counts added for `key` since the last
    /// reset. Yields `0` if there are no rows to inspect.
    #[must_use]
    #[inline]
    pub fn estimate(&self, key: &[u8]) -> u64 {
        let mut smallest: Option<u64> = None;
        for row in 0..self.depth {
            let observed = self.table[row][self.hash_to_col(key, row)];
            smallest = Some(smallest.map_or(observed, |best| best.min(observed)));
        }
        smallest.unwrap_or(0)
    }

    /// Saturating sum of every `count` passed to [`Self::increment`] since
    /// construction or the last [`Self::reset`].
    #[must_use]
    pub fn total(&self) -> u64 {
        self.total
    }

    /// Zeroes every counter and the running total; dimensions and seeds are kept.
    pub fn reset(&mut self) {
        self.table.iter_mut().for_each(|row| row.fill(0));
        self.total = 0;
    }

    /// Size in bytes of the counter cells alone (`width * depth` cells of
    /// `u64`); the seed list and `Vec` bookkeeping are not included.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let cells = self.width * self.depth;
        cells * core::mem::size_of::<u64>()
    }

    /// Maps `key` to a column index in `row`.
    ///
    /// A fresh `DefaultHasher` is fed the row's first seed, then the key, then
    /// the row's second seed; the digest is reduced modulo `width`, so the
    /// result is always a valid index into the row.
    ///
    /// Panics if `row` is not a valid index into the seed list.
    // The `as usize` cast keeps only the low bits of the 64-bit digest on
    // targets where `usize` is narrower than `u64`; that is acceptable for a
    // hash value, hence the lint exemption.
    #[allow(clippy::cast_possible_truncation)]
    #[inline]
    fn hash_to_col(&self, key: &[u8], row: usize) -> usize {
        let (a, b) = self.seeds[row];
        let mut state = DefaultHasher::new();
        a.hash(&mut state);
        key.hash(&mut state);
        b.hash(&mut state);
        let digest = state.finish();
        (digest as usize) % self.width
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sketch from dimensions the test knows to be valid.
    fn sketch_of(width: usize, depth: usize) -> CountMinSketch {
        CountMinSketch::new(width, depth).unwrap()
    }

    /// Repeated increments of one key accumulate, and the total tracks all keys.
    #[test]
    fn basic_increment_and_estimate() {
        let mut sketch = sketch_of(2048, 4);

        sketch.increment(b"192.168.1.1", 100);
        sketch.increment(b"192.168.1.1", 50);
        sketch.increment(b"10.0.0.1", 30);

        // Estimates may overshoot because of collisions, but never undershoot.
        assert!(sketch.estimate(b"192.168.1.1") >= 150);
        assert!(sketch.estimate(b"10.0.0.1") >= 30);
        assert_eq!(sketch.total(), 180);
    }

    /// `reset` returns both the counters and the total to zero.
    #[test]
    fn reset_clears_all() {
        let mut sketch = sketch_of(256, 3);
        sketch.increment(b"key", 1000);
        assert!(sketch.estimate(b"key") >= 1000);

        sketch.reset();

        assert_eq!(sketch.estimate(b"key"), 0);
        assert_eq!(sketch.total(), 0);
    }

    /// A heavy key stays within a fixed error budget despite 100 000 background keys.
    #[test]
    fn accuracy_bounds_with_many_keys() {
        const BACKGROUND_KEYS: u64 = 100_000;

        let mut sketch = sketch_of(2048, 4);
        (0..BACKGROUND_KEYS).for_each(|i| sketch.increment(&i.to_le_bytes(), 1));

        let heavy_key = b"heavy_hitter";
        sketch.increment(heavy_key, 1000);

        let estimate = sketch.estimate(heavy_key);
        assert!(estimate >= 1000, "estimate {estimate} < true count 1000");
        assert!(
            estimate <= 1000 + 200,
            "estimate {estimate} too far from true count 1000 (> 200 error)"
        );
    }

    /// The reported footprint is exactly `width * depth` eight-byte counters.
    #[test]
    fn memory_footprint() {
        let sketch = sketch_of(2048, 4);
        assert_eq!(sketch.memory_bytes(), 2048 * 4 * 8);
    }

    /// Inserted keys keep their counts; a key that was never inserted stays low.
    #[test]
    fn different_keys_different_estimates() {
        let mut sketch = sketch_of(1024, 4);
        sketch.increment(b"alpha", 500);
        sketch.increment(b"beta", 100);

        assert!(sketch.estimate(b"alpha") >= 500);
        assert!(sketch.estimate(b"beta") >= 100);
        assert!(sketch.estimate(b"gamma") < 50);
    }

    /// Counters and the total clamp at `u64::MAX` rather than wrapping around.
    #[test]
    fn saturates_at_u64_max() {
        let mut sketch = sketch_of(64, 2);
        sketch.increment(b"k", u64::MAX - 10);
        sketch.increment(b"k", 100);

        assert_eq!(sketch.estimate(b"k"), u64::MAX);
        assert_eq!(sketch.total(), u64::MAX);
    }

    /// The accessors report the dimensions passed to the constructor.
    #[test]
    fn dim_accessors_match_constructor() {
        let sketch = sketch_of(512, 5);
        assert_eq!(sketch.width(), 512);
        assert_eq!(sketch.depth(), 5);
    }

    /// A zero in either dimension is rejected.
    #[test]
    fn rejects_zero_dims() {
        let zero_width = CountMinSketch::new(0, 4);
        let zero_depth = CountMinSketch::new(2048, 0);
        assert!(zero_width.is_err());
        assert!(zero_depth.is_err());
    }

    /// Dimensions above the published maxima are rejected without allocating.
    #[test]
    fn rejects_oversized_dims() {
        let too_wide = CountMinSketch::new(MAX_WIDTH + 1, 4);
        assert!(too_wide.is_err());

        let too_deep = CountMinSketch::new(2048, MAX_DEPTH + 1);
        assert!(too_deep.is_err());

        let both_huge = CountMinSketch::new(usize::MAX, usize::MAX);
        assert!(both_huge.is_err());
    }
}