use crate::error::{CoreError, CoreResult, ErrorContext};
use crate::error_context;
use super::DoubleHasher;
/// Count-Min Sketch: a fixed-size table of `u64` counters from which the
/// count of an item is estimated as the minimum over one cell per row.
///
/// `Clone` copies both the counter table and the per-row hashers.
#[derive(Clone)]
pub struct CountMinSketch {
/// Flat counter table holding `width * depth` cells in row-major order:
/// the cell for `(row, col)` lives at index `row * width + col`.
counters: Vec<u64>,
/// Number of counters (columns) in every row.
width: usize,
/// Number of rows; each row is addressed through its own hasher.
depth: usize,
/// Saturating sum of all counts added via the increment methods and `merge`.
total_count: u64,
/// One `DoubleHasher` per row, indexed by row number.
hashers: Vec<DoubleHasher>,
}
impl CountMinSketch {
/// Builds a sketch sized from an error tolerance and a failure probability.
///
/// The table gets `ceil(e / epsilon)` columns and `ceil(ln(1 / delta))` rows
/// (each clamped to at least 1) and is then built by `with_dimensions`.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `epsilon` or `delta` does not lie
/// strictly inside `(0, 1)`; NaN is rejected as well. Any error produced by
/// `with_dimensions` is returned unchanged.
pub fn new(epsilon: f64, delta: f64) -> CoreResult<Self> {
    // Phrased as a positive range test on purpose: every comparison with NaN
    // is false, so a check written as `x <= 0.0 || x >= 1.0` would let NaN
    // through and silently produce a degenerate 1-wide / 1-deep sketch.
    let in_open_unit_interval = |x: f64| x > 0.0 && x < 1.0;

    if !in_open_unit_interval(epsilon) {
        return Err(CoreError::InvalidArgument(
            error_context!("epsilon must be in (0, 1)"),
        ));
    }
    if !in_open_unit_interval(delta) {
        return Err(CoreError::InvalidArgument(
            error_context!("delta must be in (0, 1)"),
        ));
    }

    // Float-to-integer `as` casts saturate, so an extremely small `epsilon`
    // yields `usize::MAX` columns rather than undefined behaviour.
    let width = ((std::f64::consts::E / epsilon).ceil() as usize).max(1);
    let depth = ((1.0_f64 / delta).ln().ceil() as usize).max(1);

    Self::with_dimensions(width, depth)
}
/// Builds an empty sketch with exactly `width` columns and `depth` rows.
///
/// A fresh `DoubleHasher` is created for every row and all counters start
/// at zero.
///
/// # Errors
///
/// Returns `CoreError::InvalidArgument` when `width` or `depth` is zero, or
/// when `width * depth` does not fit in a `usize`.
pub fn with_dimensions(width: usize, depth: usize) -> CoreResult<Self> {
    if width == 0 {
        return Err(CoreError::InvalidArgument(
            error_context!("width must be > 0"),
        ));
    }
    if depth == 0 {
        return Err(CoreError::InvalidArgument(
            error_context!("depth must be > 0"),
        ));
    }

    // An unchecked product would panic in debug builds and wrap in release
    // builds, leaving a table smaller than `row * width + col` assumes.
    let cell_count = match width.checked_mul(depth) {
        Some(cells) => cells,
        None => {
            return Err(CoreError::InvalidArgument(
                error_context!("width * depth overflows usize"),
            ));
        }
    };

    let hashers: Vec<DoubleHasher> = (0..depth).map(|_| DoubleHasher::new()).collect();

    Ok(Self {
        counters: vec![0u64; cell_count],
        width,
        depth,
        total_count: 0,
        hashers,
    })
}
/// Records a single occurrence of `item`.
///
/// Shorthand for `increment_by(item, 1)`.
pub fn increment(&mut self, item: &[u8]) {
self.increment_by(item, 1);
}
/// Records `count` occurrences of `item`.
///
/// In every row the counter selected for `item` grows by `count`, and the
/// running total grows by `count` as well. All additions saturate at
/// `u64::MAX` instead of wrapping.
pub fn increment_by(&mut self, item: &[u8], count: u64) {
    let row_len = self.width;
    for row in 0..self.depth {
        // Resolve the flat index first so the mutable borrow below is short.
        let slot = row * row_len + self.hash_to_col(row, item);
        let cell = &mut self.counters[slot];
        *cell = cell.saturating_add(count);
    }
    self.total_count = self.total_count.saturating_add(count);
}
/// Records a single occurrence of `item` using the conservative update rule.
///
/// Shorthand for `increment_conservative_by(item, 1)`.
pub fn increment_conservative(&mut self, item: &[u8]) {
self.increment_conservative_by(item, 1);
}
/// Records `count` occurrences of `item`, raising only the counters that lag.
///
/// The target value is the current estimate for `item` plus `count`
/// (saturating). Each of the item's counters that is below the target is
/// lifted to it; counters already at or above the target are left alone.
/// The running total still grows by `count`.
pub fn increment_conservative_by(&mut self, item: &[u8], count: u64) {
    let target = self.estimate(item).saturating_add(count);
    for row in 0..self.depth {
        let slot = row * self.width + self.hash_to_col(row, item);
        let cell = &mut self.counters[slot];
        // `max` keeps larger counters and lifts smaller ones to the target.
        *cell = (*cell).max(target);
    }
    self.total_count = self.total_count.saturating_add(count);
}
/// Returns the estimated count for `item`.
///
/// The estimate is the smallest of the item's counters, one per row.
/// Counters saturate at `u64::MAX`, so an item whose counters are all
/// saturated is reported as `u64::MAX` (it is not mistaken for "no data").
/// A sketch without rows yields 0.
pub fn estimate(&self, item: &[u8]) -> u64 {
    // `min()` returns `None` only when there are no rows, which removes the
    // need for a `u64::MAX` sentinel that collides with saturated counters.
    (0..self.depth)
        .map(|row| self.counters[row * self.width + self.hash_to_col(row, item)])
        .min()
        .unwrap_or(0)
}
/// Estimates the inner product of the two sketched count vectors.
///
/// For every row the cell-wise products of the two tables are summed with
/// saturating arithmetic; the smallest row sum is returned. A saturated
/// result is reported as `u64::MAX`, and a sketch without rows yields 0.
///
/// NOTE(review): only the dimensions are compared here. The result is
/// presumably meaningful only when both sketches use the same hashers
/// (for example when one was produced by `empty_clone`) — confirm.
///
/// # Errors
///
/// Returns `CoreError::DimensionError` when the widths or depths differ.
pub fn inner_product(&self, other: &CountMinSketch) -> CoreResult<u64> {
    if self.width != other.width || self.depth != other.depth {
        return Err(CoreError::DimensionError(
            error_context!("Sketches must have the same dimensions for inner product"),
        ));
    }

    let width = self.width;
    // `min()` is `None` only for a sketch with no rows, so no `u64::MAX`
    // sentinel is needed and a genuinely saturated product is not lost.
    let smallest_row_product = (0..self.depth)
        .map(|row| {
            let start = row * width;
            let end = start + width;
            self.counters[start..end]
                .iter()
                .zip(&other.counters[start..end])
                .fold(0u64, |acc, (&mine, &theirs)| {
                    acc.saturating_add(mine.saturating_mul(theirs))
                })
        })
        .min()
        .unwrap_or(0);

    Ok(smallest_row_product)
}
/// Adds every counter and the running total of `other` into `self`.
///
/// All additions saturate at `u64::MAX`. `other` is not modified.
///
/// NOTE(review): only the dimensions are compared here. Merging is
/// presumably meaningful only when both sketches use the same hashers —
/// confirm against how callers construct the sketches.
///
/// # Errors
///
/// Returns `CoreError::DimensionError` when the widths or depths differ;
/// `self` is left untouched in that case.
pub fn merge(&mut self, other: &CountMinSketch) -> CoreResult<()> {
    let same_shape = self.width == other.width && self.depth == other.depth;
    if !same_shape {
        return Err(CoreError::DimensionError(
            error_context!("Sketches must have the same dimensions for merge"),
        ));
    }

    // Equal dimensions imply equally long tables, so the zip covers every cell.
    for (mine, theirs) in self.counters.iter_mut().zip(&other.counters) {
        *mine = mine.saturating_add(*theirs);
    }
    self.total_count = self.total_count.saturating_add(other.total_count);

    Ok(())
}
/// Filters `candidates` down to those whose estimate reaches `threshold`.
///
/// Returns `(item, estimate)` pairs in the order the candidates were given,
/// keeping every item whose estimate is greater than or equal to
/// `threshold`. The sketch cannot enumerate items itself, so only the
/// supplied candidates are considered.
pub fn heavy_hitters<'a>(
    &self,
    candidates: &'a [&[u8]],
    threshold: u64,
) -> Vec<(&'a [u8], u64)> {
    candidates
        .iter()
        .map(|&candidate| (candidate, self.estimate(candidate)))
        .filter(|&(_, estimate)| estimate >= threshold)
        .collect()
}
/// Returns the running total of all counts added to this sketch
/// (saturating at `u64::MAX`).
pub fn total_count(&self) -> u64 {
self.total_count
}
/// Returns the number of counters (columns) in each row.
pub fn width(&self) -> usize {
self.width
}
/// Returns the number of rows in the sketch.
pub fn depth(&self) -> usize {
self.depth
}
/// Creates a zeroed sketch with the same dimensions and the same hashers.
///
/// The counters and the running total start at zero; the hashers are
/// cloned from `self`, so the new sketch maps items to the same cells.
pub fn empty_clone(&self) -> Self {
    let (width, depth) = (self.width, self.depth);
    let counters = vec![0u64; width * depth];
    let hashers = self.hashers.clone();
    Self {
        counters,
        width,
        depth,
        total_count: 0,
        hashers,
    }
}
/// Resets every counter and the running total to zero.
///
/// Dimensions and hashers are kept, so the sketch can be reused as is.
pub fn clear(&mut self) {
    self.total_count = 0;
    self.counters.iter_mut().for_each(|cell| *cell = 0);
}
/// Maps `item` to a column index for the given `row`.
///
/// The row's own hasher produces a hash pair, which `DoubleHasher::position`
/// turns into an index using `self.width` as the bound.
///
/// NOTE(review): presumably the result is always in `0..self.width`, since
/// callers use it unchecked as a column offset — confirm against
/// `DoubleHasher::position`.
///
/// Panics if `row` is not a valid index into `self.hashers`.
#[inline]
fn hash_to_col(&self, row: usize, item: &[u8]) -> usize {
let (h1, h2) = self.hashers[row].hash_pair(item);
// The third argument is fixed at 0 for every row; rows differ only by hasher.
DoubleHasher::position(h1, h2, 0, self.width)
}
}
/// Compact debug output: prints the dimensions and the running total only,
/// leaving out the counter table and the hashers.
impl std::fmt::Debug for CountMinSketch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            width,
            depth,
            total_count,
            ..
        } = self;

        let mut out = f.debug_struct("CountMinSketch");
        out.field("width", width);
        out.field("depth", depth);
        out.field("total_count", total_count);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Items inserted 100 and 50 times must each be estimated at no less than
// their true count.
#[test]
fn test_cms_basic_frequency() {
    let mut cms = CountMinSketch::new(0.001, 0.01).expect("valid");

    (0..100).for_each(|_| cms.increment(b"apple"));
    (0..50).for_each(|_| cms.increment(b"banana"));

    let (est_apple, est_banana) = (cms.estimate(b"apple"), cms.estimate(b"banana"));
    assert!(est_apple >= 100, "apple estimate too low: {est_apple}");
    assert!(est_banana >= 50, "banana estimate too low: {est_banana}");
}
// After inserting 10,000 distinct keys once each, nearly all of the first
// 1,000 keys must be estimated within `1 + ceil(epsilon * n)`.
#[test]
fn test_cms_estimates_within_error_bounds() {
    let (epsilon, delta) = (0.01, 0.01);
    let mut cms = CountMinSketch::new(epsilon, delta).expect("valid");

    let n = 10_000u64;
    (0..n).for_each(|i| cms.increment(&i.to_le_bytes()));

    let max_error = (epsilon * n as f64).ceil() as u64;
    let test_count = 1000usize;
    let within_bounds = (0..test_count as u64)
        .filter(|i| cms.estimate(&i.to_le_bytes()) <= 1 + max_error)
        .count();

    // Allow a small slack of 10 below the (1 - delta) fraction.
    let expected_min = ((1.0 - delta) * test_count as f64) as usize;
    assert!(
        within_bounds >= expected_min.saturating_sub(10),
        "Only {within_bounds}/{test_count} estimates within bounds (expected >= {expected_min})"
    );
}
// Merging two equally sized sketches adds their running totals and leaves
// the merged-in sketch untouched.
//
// NOTE(review): the two sketches are built independently, and whether their
// hashers agree depends on `DoubleHasher::new`, so no per-item estimate is
// asserted here.
#[test]
fn test_cms_merge() {
    let mut cms1 = CountMinSketch::with_dimensions(100, 5).expect("valid");
    let mut cms2 = CountMinSketch::with_dimensions(100, 5).expect("valid");
    for _ in 0..30 {
        cms1.increment(b"event");
    }
    for _ in 0..20 {
        cms2.increment(b"event");
    }

    cms1.merge(&cms2).expect("same dimensions");

    // `assert_eq!` prints both values on failure, unlike `assert!(a == b)`.
    assert_eq!(cms1.total_count(), 50);
    assert_eq!(cms2.total_count(), 20);
}
// 100 conservative increments of one item must yield an estimate of at
// least 100.
#[test]
fn test_cms_conservative_update() {
    let mut cms = CountMinSketch::with_dimensions(200, 5).expect("valid");

    (0..100).for_each(|_| cms.increment_conservative(b"item"));

    let est = cms.estimate(b"item");
    assert!(est >= 100, "Conservative estimate too low: {est}");
}
// With a threshold of 500, the item seen 1000 times is reported and the
// item seen 10 times is not.
#[test]
fn test_cms_heavy_hitters() {
    let mut cms = CountMinSketch::new(0.001, 0.01).expect("valid");
    (0..1000).for_each(|_| cms.increment(b"hot"));
    (0..10).for_each(|_| cms.increment(b"cold"));

    let candidates: Vec<&[u8]> = vec![b"hot", b"cold", b"missing"];
    let hh = cms.heavy_hitters(&candidates, 500);

    assert!(!hh.is_empty());
    assert!(hh.iter().any(|&(item, _)| item == b"hot"));
    assert!(hh.iter().all(|&(item, _)| item != b"cold"));
}
// A freshly built sketch has a zero total and estimates 0 for any item.
#[test]
fn test_cms_empty() {
    let sketch = CountMinSketch::with_dimensions(50, 3).expect("valid");

    let (total, estimate) = (sketch.total_count(), sketch.estimate(b"nope"));
    assert_eq!(total, 0);
    assert_eq!(estimate, 0);
}
// Both constructors must reject out-of-range arguments.
#[test]
fn test_cms_invalid_params() {
    // `epsilon` must lie strictly inside (0, 1).
    assert!(CountMinSketch::new(0.0, 0.01).is_err());
    assert!(CountMinSketch::new(1.0, 0.01).is_err());
    assert!(CountMinSketch::new(-0.5, 0.01).is_err());

    // `delta` must lie strictly inside (0, 1); both bounds are exercised.
    assert!(CountMinSketch::new(0.01, 0.0).is_err());
    assert!(CountMinSketch::new(0.01, 1.0).is_err());
    assert!(CountMinSketch::new(0.01, -0.5).is_err());

    // Explicit dimensions must both be non-zero.
    assert!(CountMinSketch::with_dimensions(0, 5).is_err());
    assert!(CountMinSketch::with_dimensions(5, 0).is_err());
}
// A bulk increment of 42 shows up in both the estimate and the total.
#[test]
fn test_cms_increment_by() {
    let mut sketch = CountMinSketch::with_dimensions(200, 5).expect("valid");

    sketch.increment_by(b"bulk", 42);

    let estimate = sketch.estimate(b"bulk");
    assert!(estimate >= 42);
    assert_eq!(sketch.total_count(), 42);
}
// Two sketches sharing hashers (via `empty_clone`) that hold 10 and 5 of the
// same item must give an inner product of at least 10 * 5.
#[test]
fn test_cms_inner_product() {
    let mut left = CountMinSketch::with_dimensions(100, 5).expect("valid");
    let mut right = left.empty_clone();

    left.increment_by(b"a", 10);
    right.increment_by(b"a", 5);

    let ip = left.inner_product(&right).expect("same dims");
    assert!(ip >= 50, "Inner product too low: {ip}");
}
// `clear` wipes both the running total and the per-item counters.
#[test]
fn test_cms_clear() {
    let mut sketch = CountMinSketch::with_dimensions(50, 3).expect("valid");
    sketch.increment(b"data");

    sketch.clear();

    let (total, estimate) = (sketch.total_count(), sketch.estimate(b"data"));
    assert_eq!(total, 0);
    assert_eq!(estimate, 0);
}
// Merging sketches whose widths differ must be refused.
#[test]
fn test_cms_merge_incompatible() {
    let mut narrow = CountMinSketch::with_dimensions(100, 5).expect("valid");
    let wide = CountMinSketch::with_dimensions(200, 5).expect("valid");

    let outcome = narrow.merge(&wide);
    assert!(outcome.is_err());
}
}