use alloc::vec;
use alloc::vec::Vec;
/// A Count-Min sketch: a `depth x width` grid of saturating `u64` counters
/// used to estimate per-key frequencies in bounded memory.
///
/// Estimates never fall below a key's true accumulated count (until a counter
/// saturates at `u64::MAX`), but hash collisions can make them larger.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CountMinSketch {
    /// `depth` rows of `width` counters each.
    table: Vec<Vec<u64>>,
    /// Counters per row; always > 0.
    width: usize,
    /// Number of rows (one seeded hash per row); always > 0.
    depth: usize,
}
impl CountMinSketch {
    /// FNV-1a 32-bit offset basis.
    const FNV_OFFSET_BASIS: u32 = 0x811c_9dc5;
    /// FNV-1a 32-bit prime.
    const FNV_PRIME: u32 = 0x0100_0193;
    /// Per-row seed multiplier (`floor(2^32 / golden ratio)`): row `i` hashes
    /// with seed `i * ROW_SEED_MULTIPLIER` (wrapping), so row 0 is plain FNV-1a.
    const ROW_SEED_MULTIPLIER: u32 = 0x9e37_79b9;

    /// Creates a sketch sized with the standard Count-Min formulas:
    /// `width = ceil(e / epsilon)` columns and `depth = ceil(ln(1 / delta))` rows.
    ///
    /// The textbook error guarantee behind these formulas assumes
    /// pairwise-independent hashes; the seeded FNV-1a rows used here only
    /// approximate that.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` or `delta` is not strictly inside `(0, 1)` (NaN
    /// included), or if the resulting table cannot be allocated.
    #[cfg(feature = "std")]
    #[must_use]
    pub fn new(epsilon: f64, delta: f64) -> Self {
        assert!(epsilon > 0.0 && epsilon < 1.0, "epsilon must be in (0,1)");
        assert!(delta > 0.0 && delta < 1.0, "delta must be in (0,1)");
        // Both values are positive after the asserts, so no sign is lost.
        // Float-to-int `as` casts saturate, so an extremely small `epsilon`
        // yields `usize::MAX` (and the allocation then panics) instead of a
        // wrapped-around width.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let width = (core::f64::consts::E / epsilon).ceil() as usize;
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let depth = (1.0_f64 / delta).ln().ceil() as usize;
        Self::from_raw_params(width, depth)
    }

    /// Creates a zeroed sketch with explicit dimensions.
    ///
    /// # Panics
    ///
    /// Panics if `width` or `depth` is zero.
    #[must_use]
    pub fn from_raw_params(width: usize, depth: usize) -> Self {
        assert!(width > 0, "width must be > 0");
        assert!(depth > 0, "depth must be > 0");
        Self {
            table: vec![vec![0u64; width]; depth],
            width,
            depth,
        }
    }

    /// Number of counters per row.
    #[must_use]
    pub fn width(&self) -> usize {
        self.width
    }

    /// Number of rows, each hashed with a different seed.
    #[must_use]
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Seeded 32-bit FNV-1a over `bytes`, reduced to a column in `0..width`.
    ///
    /// The seed is XOR-ed into the offset basis, so seed 0 is standard FNV-1a.
    fn hash(bytes: &[u8], seed: u32, width: usize) -> usize {
        let h = bytes.iter().fold(Self::FNV_OFFSET_BASIS ^ seed, |h, &b| {
            (h ^ u32::from(b)).wrapping_mul(Self::FNV_PRIME)
        });
        // The hash is 32 bits wide, so at most 2^32 columns are ever addressed.
        (h as usize) % width
    }

    /// Column that `bytes` maps to in row `row`.
    ///
    /// Shared by `increment` and `estimate` so both always derive the same
    /// per-row seed.
    fn row_index(bytes: &[u8], row: usize, width: usize) -> usize {
        // Truncation only matters past 2^32 rows, where seeds would start repeating.
        #[allow(clippy::cast_possible_truncation)]
        let seed = (row as u32).wrapping_mul(Self::ROW_SEED_MULTIPLIER);
        Self::hash(bytes, seed, width)
    }

    /// Adds `count` to the one counter `key` maps to in every row.
    ///
    /// Counters saturate at `u64::MAX` instead of overflowing.
    pub fn increment(&mut self, key: impl AsRef<[u8]>, count: u64) {
        let bytes = key.as_ref();
        let width = self.width;
        for (row, counters) in self.table.iter_mut().enumerate() {
            let cell = &mut counters[Self::row_index(bytes, row, width)];
            *cell = cell.saturating_add(count);
        }
    }

    /// Returns the smallest of the counters `key` maps to across all rows.
    ///
    /// Counters only ever grow, so this is never below the total passed to
    /// [`increment`](Self::increment) for `key` (barring saturation); collisions
    /// with other keys can only inflate it.
    #[must_use]
    pub fn estimate(&self, key: impl AsRef<[u8]>) -> u64 {
        let bytes = key.as_ref();
        self.table
            .iter()
            .enumerate()
            .map(|(row, counters)| counters[Self::row_index(bytes, row, self.width)])
            .min()
            // Defensive only: `from_raw_params` guarantees at least one row.
            .unwrap_or(0)
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    #[test]
    fn tracks_frequency() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        cms.increment("a", 5);
        cms.increment("a", 3);
        cms.increment("b", 1);
        assert!(cms.estimate("a") >= 8);
        assert!(cms.estimate("b") >= 1);
        assert_eq!(cms.estimate("ghost"), 0);
    }

    #[test]
    fn repeated_unit_increments_accumulate() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        cms.increment("k", 1);
        cms.increment("k", 1);
        assert!(cms.estimate("k") >= 2);
    }

    #[test]
    fn accepts_byte_slices() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        cms.increment(b"sensor" as &[u8], 3);
        assert!(cms.estimate(b"sensor" as &[u8]) >= 3);
    }

    #[test]
    fn never_undercounts() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        for i in 0..200_u64 {
            let key = format!("key-{i}");
            let count = i + 1;
            cms.increment(&key, count);
            let estimate = cms.estimate(&key);
            assert!(
                estimate >= count,
                "Undercount detected for {key}: estimate {estimate} < actual {count}"
            );
        }
        // Re-check every key once all inserts (and possible collisions) are done.
        for i in 0..200_u64 {
            let key = format!("key-{i}");
            let count = i + 1;
            let estimate = cms.estimate(&key);
            assert!(
                estimate >= count,
                "Undercount detected for {key}: estimate {estimate} < actual {count}"
            );
        }
    }

    #[test]
    fn zero_increment_does_not_change_estimate() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        cms.increment("k", 0);
        assert_eq!(cms.estimate("k"), 0);
    }

    #[test]
    fn zero_increment_after_nonzero() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        cms.increment("k", 5);
        cms.increment("k", 0);
        assert!(cms.estimate("k") >= 5);
    }

    #[test]
    fn very_large_counts() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        let large = u64::MAX / 2;
        cms.increment("big", large);
        assert!(cms.estimate("big") >= large);
    }

    #[test]
    fn saturating_add_on_overflow() {
        let mut cms = CountMinSketch::new(0.01, 0.01);
        cms.increment("max", u64::MAX);
        cms.increment("max", 1);
        assert_eq!(cms.estimate("max"), u64::MAX);
    }

    #[test]
    fn unseen_key_returns_zero() {
        let cms = CountMinSketch::new(0.01, 0.01);
        assert_eq!(cms.estimate("never-added"), 0);
    }

    #[test]
    fn new_derives_dimensions_from_error_bounds() {
        // ceil(e / 0.01) = ceil(271.83) = 272 and ceil(ln(100)) = ceil(4.61) = 5.
        let cms = CountMinSketch::new(0.01, 0.01);
        assert_eq!(cms.width, 272);
        assert_eq!(cms.depth, 5);
        assert_eq!(cms.table.len(), 5);
        assert!(cms.table.iter().all(|row| row.len() == 272));
    }

    #[test]
    fn from_raw_params_works() {
        let mut cms = CountMinSketch::from_raw_params(100, 5);
        cms.increment("x", 7);
        assert!(cms.estimate("x") >= 7);
        assert_eq!(cms.estimate("y"), 0);
    }

    #[test]
    fn single_key_is_counted_exactly() {
        // With only one key inserted there is nothing to collide with.
        let mut cms = CountMinSketch::from_raw_params(16, 4);
        cms.increment("solo", 40);
        cms.increment("solo", 2);
        assert_eq!(cms.estimate("solo"), 42);
    }

    #[test]
    fn estimate_takes_minimum_across_rows() {
        // Width 1 forces every key into column 0 of each row.
        let mut cms = CountMinSketch::from_raw_params(1, 3);
        cms.table = vec![vec![7], vec![3], vec![9]];
        assert_eq!(cms.estimate("anything"), 3);
    }

    #[test]
    fn width_one_merges_all_keys() {
        let mut cms = CountMinSketch::from_raw_params(1, 2);
        cms.increment("a", 2);
        cms.increment("b", 3);
        assert_eq!(cms.estimate("a"), 5);
        assert_eq!(cms.estimate("c"), 5);
    }

    #[test]
    #[should_panic(expected = "width must be > 0")]
    fn from_raw_params_zero_width_panics() {
        let _ = CountMinSketch::from_raw_params(0, 5);
    }

    #[test]
    #[should_panic(expected = "depth must be > 0")]
    fn from_raw_params_zero_depth_panics() {
        let _ = CountMinSketch::from_raw_params(5, 0);
    }

    #[test]
    #[should_panic(expected = "epsilon must be in (0,1)")]
    fn panics_on_zero_epsilon() {
        let _ = CountMinSketch::new(0.0, 0.01);
    }

    #[test]
    #[should_panic(expected = "epsilon must be in (0,1)")]
    fn panics_on_nan_epsilon() {
        let _ = CountMinSketch::new(f64::NAN, 0.01);
    }

    #[test]
    #[should_panic(expected = "delta must be in (0,1)")]
    fn panics_on_one_delta() {
        let _ = CountMinSketch::new(0.01, 1.0);
    }

    #[test]
    #[should_panic(expected = "delta must be in (0,1)")]
    fn panics_on_nan_delta() {
        let _ = CountMinSketch::new(0.01, f64::NAN);
    }

    #[test]
    fn multiple_keys_independent() {
        let mut cms = CountMinSketch::new(0.001, 0.001);
        cms.increment("alpha", 100);
        cms.increment("beta", 200);
        cms.increment("gamma", 300);
        assert!(cms.estimate("alpha") >= 100);
        assert!(cms.estimate("beta") >= 200);
        assert!(cms.estimate("gamma") >= 300);
    }
}