/// A Count-Min Sketch: a fixed-size, approximate frequency table for byte-string keys.
///
/// Counters live in `depth` rows of `width` columns. Every row hashes a key with its
/// own seed, and [`CountMinSketch::estimate`] reports the smallest of the key's
/// counters. Collisions with other keys can therefore only inflate an estimate; it
/// never drops below what was added for that key (counters saturate at `u64::MAX`
/// instead of wrapping).
#[derive(Debug, Clone)]
pub struct CountMinSketch {
    /// `depth` rows, each holding `width` saturating counters.
    table: Vec<Vec<u64>>,
    /// Number of counters per row; the modulus applied to every hash.
    width: usize,
    /// Number of rows, i.e. the number of differently seeded hashes per key.
    depth: usize,
}

impl CountMinSketch {
    /// Row `i` is hashed with seed `i * SEED_STRIDE` (wrapping), so that consecutive
    /// row numbers map to widely separated 32-bit seeds.
    const SEED_STRIDE: u32 = 0x9e37_79b9;

    /// Creates an empty sketch sized with the usual Count-Min formulas:
    /// `width = ceil(e / epsilon)` and `depth = ceil(ln(1 / delta))`.
    ///
    /// Smaller `epsilon` means more columns, smaller `delta` means more rows. Very
    /// small values request a correspondingly large table. The classic probabilistic
    /// error bound assumes well-behaved hash functions; the seeded FNV-1a used here
    /// should be treated as a practical approximation of that assumption.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` or `delta` is not strictly between 0 and 1 (NaN included).
    #[must_use]
    pub fn new(epsilon: f64, delta: f64) -> Self {
        assert!(epsilon > 0.0 && epsilon < 1.0, "epsilon must be in (0,1)");
        assert!(delta > 0.0 && delta < 1.0, "delta must be in (0,1)");
        // Both operands are positive after the asserts, so the float-to-int casts
        // cannot drop a sign; values too large for `usize` saturate instead of wrapping.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let width = (std::f64::consts::E / epsilon).ceil() as usize;
        // `.max(1)` guarantees at least one row, so `estimate` never has to fall back
        // to its empty-table default.
        #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
        let depth = ((1.0_f64 / delta).ln().ceil() as usize).max(1);
        Self {
            table: vec![vec![0u64; width]; depth],
            width,
            depth,
        }
    }

    /// Seeded 32-bit FNV-1a over `bytes`, reduced to a column index in `0..width`.
    ///
    /// The seed is XOR-ed into the FNV offset basis before any byte is mixed in.
    /// `width` must be non-zero, otherwise the final modulo panics.
    fn hash(bytes: &[u8], seed: u32, width: usize) -> usize {
        let mut h: u32 = 2_166_136_261u32 ^ seed;
        for &b in bytes {
            h ^= u32::from(b);
            h = h.wrapping_mul(16_777_619);
        }
        (h as usize) % width
    }

    /// Yields the hash seed of every row, in row order.
    ///
    /// This is the single source of truth shared by `increment` and `estimate`; the
    /// two must derive identical seeds or lookups would miss the counters that were
    /// incremented.
    fn seeds(depth: usize) -> impl Iterator<Item = u32> {
        (0..depth).map(|row| {
            // Only the low 32 bits of the row number feed the seed; the multiply wraps.
            #[allow(clippy::cast_possible_truncation)]
            let row = row as u32;
            row.wrapping_mul(Self::SEED_STRIDE)
        })
    }

    /// Adds `count` to the counter selected for `key` in every row.
    ///
    /// Counters saturate at `u64::MAX` rather than overflowing.
    pub fn increment(&mut self, key: impl AsRef<[u8]>, count: u64) {
        let bytes = key.as_ref();
        for (counters, seed) in self.table.iter_mut().zip(Self::seeds(self.depth)) {
            let slot = &mut counters[Self::hash(bytes, seed, self.width)];
            *slot = slot.saturating_add(count);
        }
    }

    /// Returns the estimated total added for `key`: the minimum of its counters
    /// across all rows.
    ///
    /// The result is never lower than the sum of counts passed to
    /// [`CountMinSketch::increment`] for this key (up to saturation), but may be
    /// higher when other keys share its counters. A key that was never inserted
    /// usually reads `0`.
    #[must_use]
    pub fn estimate(&self, key: impl AsRef<[u8]>) -> u64 {
        let bytes = key.as_ref();
        self.table
            .iter()
            .zip(Self::seeds(self.depth))
            .map(|(counters, seed)| counters[Self::hash(bytes, seed, self.width)])
            .min()
            .unwrap_or(0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the sketch configuration shared by every test (`epsilon = delta = 0.01`).
    fn fresh_sketch() -> CountMinSketch {
        CountMinSketch::new(0.01, 0.01)
    }

    /// Repeated increments of one key add up, and a key that was never inserted
    /// reads zero in this small, nearly empty sketch.
    #[test]
    fn tracks_frequency() {
        let mut sketch = fresh_sketch();
        for &(key, amount) in &[("a", 5_u64), ("a", 3), ("b", 1)] {
            sketch.increment(key, amount);
        }
        assert!(sketch.estimate("a") >= 8);
        assert!(sketch.estimate("b") >= 1);
        assert_eq!(sketch.estimate("ghost"), 0);
    }

    /// Two unit increments of the same key are both reflected in the estimate.
    #[test]
    fn default_increment_one() {
        let mut sketch = fresh_sketch();
        (0..2).for_each(|_| sketch.increment("k", 1));
        assert!(sketch.estimate("k") >= 2);
    }

    /// Raw byte slices are accepted as keys, not only string slices.
    #[test]
    fn accepts_byte_slices() {
        let mut sketch = fresh_sketch();
        let key: &[u8] = b"sensor";
        sketch.increment(key, 3);
        assert!(sketch.estimate(key) >= 3);
    }
}