use std::hash::Hasher;
use serde::{Deserialize, Serialize};
use twox_hash::XxHash64;
use crate::sketches::Sketch;
use crate::{Error, Result};
/// Count-Min Sketch: a fixed-size table of saturating `u32` counters used to
/// estimate per-item counts.
///
/// The table has `depth` rows of `width` counters each, stored row-major in a
/// single flat vector. The field order is part of the serialized layout, so it
/// must not be rearranged.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CountMinSketch {
    /// Number of rows; each row is indexed with its own hash seed.
    depth: u32,
    /// Number of counters in every row.
    width: u32,
    /// Row-major counter table; holds exactly `depth * width` entries.
    counters: Vec<u32>,
    /// Saturating sum of every count added or merged into the sketch.
    total: u64,
}
impl CountMinSketch {
pub fn try_new(depth: u32, width: u32) -> Result<Self> {
if depth == 0 || width == 0 {
return Err(Error::InvalidSketch(
"CMS depth and width must be > 0".into(),
));
}
let size = (depth as usize)
.checked_mul(width as usize)
.ok_or_else(|| Error::InvalidSketch("CMS depth*width overflows usize".into()))?;
Ok(Self {
depth,
width,
counters: vec![0u32; size],
total: 0,
})
}
pub fn new(depth: u32, width: u32) -> Result<Self> {
Self::try_new(depth, width)
}
pub fn with_defaults() -> Self {
Self::try_new(5, 1024).expect("defaults are valid")
}
fn validate(&self) -> Result<()> {
if self.depth == 0 || self.width == 0 {
return Err(Error::InvalidSketch(
"CMS decoded depth/width must both be > 0".into(),
));
}
let expected = (self.depth as usize)
.checked_mul(self.width as usize)
.ok_or_else(|| Error::InvalidSketch("CMS decoded depth*width overflows".into()))?;
if self.counters.len() != expected {
return Err(Error::InvalidSketch(format!(
"CMS counter length {} != depth*width = {}",
self.counters.len(),
expected
)));
}
Ok(())
}
fn hash(item: &[u8], row: u32) -> u64 {
let mut h = XxHash64::with_seed(0x1010_d017 ^ u64::from(row));
h.write(item);
h.finish()
}
pub fn add(&mut self, item: &[u8], count: u32) {
for row in 0..self.depth {
let idx = (Self::hash(item, row) % u64::from(self.width)) as usize;
let pos = (row as usize) * (self.width as usize) + idx;
self.counters[pos] = self.counters[pos].saturating_add(count);
}
self.total = self.total.saturating_add(u64::from(count));
}
pub fn estimate(&self, item: &[u8]) -> u32 {
(0..self.depth)
.map(|row| {
let idx = (Self::hash(item, row) % u64::from(self.width)) as usize;
let pos = (row as usize) * (self.width as usize) + idx;
self.counters[pos]
})
.min()
.unwrap_or(0)
}
pub fn merge(&mut self, other: &Self) -> Result<()> {
if self.depth != other.depth || self.width != other.width {
return Err(Error::InvalidSketch(
"CMS depth/width mismatch in merge".into(),
));
}
for (a, b) in self.counters.iter_mut().zip(other.counters.iter()) {
*a = a.saturating_add(*b);
}
self.total = self.total.saturating_add(other.total);
Ok(())
}
pub fn depth(&self) -> u32 {
self.depth
}
pub fn width(&self) -> u32 {
self.width
}
pub fn total(&self) -> u64 {
self.total
}
}
impl Sketch for CountMinSketch {
const KIND: &'static str = "samkhya.cms-v1";
fn to_bytes(&self) -> Result<Vec<u8>> {
bincode::serialize(self).map_err(Into::into)
}
fn from_bytes(bytes: &[u8]) -> Result<Self> {
let s: Self = bincode::deserialize(bytes).map_err(Error::from)?;
s.validate()?;
Ok(s)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian bytes of an integer test item.
    fn key(i: u32) -> [u8; 4] {
        i.to_le_bytes()
    }

    /// Every item added five times must be estimated at five or more.
    #[test]
    fn never_undercounts() {
        let mut sketch = CountMinSketch::new(5, 1024).unwrap();
        for i in 0..1000u32 {
            (0..5).for_each(|_| sketch.add(&key(i), 1));
        }
        for i in 0..1000u32 {
            let est = sketch.estimate(&key(i));
            assert!(est >= 5, "undercount for {i}: {}", est);
        }
    }

    /// A single heavy item stands out against a thousand light ones.
    #[test]
    fn heavy_hitter_detected() {
        let mut sketch = CountMinSketch::with_defaults();
        (0..1000u32).for_each(|i| sketch.add(&key(i), 1));
        sketch.add(b"heavy", 10_000);

        let heavy_est = sketch.estimate(b"heavy");
        let light_est = sketch.estimate(&key(42));

        assert!(
            heavy_est >= 10_000 && heavy_est < 11_000,
            "heavy est {heavy_est} out of range"
        );
        assert!(light_est < 50, "light est {light_est} too high");
    }

    /// Merging two same-shaped sketches sums their counts.
    #[test]
    fn merge_adds_counts() {
        let mut left = CountMinSketch::new(3, 100).unwrap();
        let mut right = CountMinSketch::new(3, 100).unwrap();
        left.add(b"x", 5);
        right.add(b"x", 3);

        left.merge(&right).unwrap();

        assert!(left.estimate(b"x") >= 8);
    }

    /// Sketches with different depths cannot be merged.
    #[test]
    fn merge_mismatched_dimensions_errors() {
        let mut shallow = CountMinSketch::new(3, 100).unwrap();
        let deep = CountMinSketch::new(4, 100).unwrap();
        assert!(shallow.merge(&deep).is_err());
    }

    /// Serializing and deserializing preserves estimates and the total.
    #[test]
    fn round_trip() {
        let mut original = CountMinSketch::with_defaults();
        (0..100u32).for_each(|i| original.add(&key(i), 1));

        let encoded = original.to_bytes().unwrap();
        let restored = CountMinSketch::from_bytes(&encoded).unwrap();

        for i in 0..100u32 {
            let item = key(i);
            assert_eq!(original.estimate(&item), restored.estimate(&item));
        }
        assert_eq!(original.total, restored.total);
    }

    /// `new` refuses a zero depth or a zero width.
    #[test]
    fn invalid_dimensions_error() {
        let zero_depth = CountMinSketch::new(0, 100);
        let zero_width = CountMinSketch::new(5, 0);
        assert!(zero_depth.is_err());
        assert!(zero_width.is_err());
    }

    /// `try_new` refuses every combination that contains a zero dimension.
    #[test]
    fn try_new_rejects_each_invalid_dimension() {
        assert!(CountMinSketch::try_new(0, 0).is_err());
        assert!(CountMinSketch::try_new(0, 100).is_err());
        assert!(CountMinSketch::try_new(5, 0).is_err());
    }

    /// `try_new` keeps the dimensions it was given.
    #[test]
    fn try_new_accepts_valid_dimensions() {
        let sketch = CountMinSketch::try_new(4, 256).unwrap();
        assert_eq!(sketch.depth(), 4);
        assert_eq!(sketch.width(), 256);
    }

    /// All-zero buffers of assorted lengths never decode to a sketch.
    #[test]
    fn from_bytes_rejects_all_zero_payload() {
        const LENGTHS: [usize; 8] = [4, 8, 16, 24, 32, 64, 128, 256];
        for &n in LENGTHS.iter() {
            let zeros = vec![0u8; n];
            let decoded = CountMinSketch::from_bytes(&zeros);
            assert!(decoded.is_err(), "all-zero len {n} accepted by from_bytes");
        }
    }

    /// A payload declaring 2x4 dimensions but carrying only three counters is
    /// rejected.
    #[test]
    fn from_bytes_rejects_counter_length_mismatch() {
        // Hand-built payload in field order: depth, width, counter count,
        // the counters themselves, then the total.
        let depth = 2u32.to_le_bytes();
        let width = 4u32.to_le_bytes();
        let bad_len = 3u64.to_le_bytes();
        let counters = [0u8; 3 * 4]; // three zeroed u32 counters
        let total = 0u64.to_le_bytes();
        let payload: Vec<u8> = [
            &depth[..],
            &width[..],
            &bad_len[..],
            &counters[..],
            &total[..],
        ]
        .concat();

        assert!(
            CountMinSketch::from_bytes(&payload).is_err(),
            "from_bytes accepted counter-length mismatch"
        );
    }

    /// A freshly serialized default sketch decodes with the same dimensions.
    #[test]
    fn from_bytes_accepts_valid_payload() {
        let original = CountMinSketch::with_defaults();
        let encoded = original.to_bytes().unwrap();
        let decoded = CountMinSketch::from_bytes(&encoded).unwrap();
        assert_eq!(original.depth(), decoded.depth());
        assert_eq!(original.width(), decoded.width());
    }
}