use core::{hash::BuildHasher, marker::PhantomData};
use alloc::{vec, vec::Vec};
use crate::{
hash::{reduce, DefaultHashBuilder, HashPair},
Error,
};
// Euler's number `e`; used when deriving the sketch width as `ceil(e / epsilon)`.
const E: f64 = core::f64::consts::E;
/// Count-Min sketch: approximate per-item counts kept in a fixed-size table.
///
/// The table has `depth` rows of `width` saturating `u64` counters. Adding an
/// item bumps one counter in every row, and an estimate is the minimum of the
/// item's counters, so collisions with other items can only inflate it.
///
/// `T` is the item type accepted by the hashing methods; `S` is the hash
/// builder (the methods defined in this file require `S: BuildHasher`).
///
/// NOTE(review): with the `serde` feature the hasher is skipped, so a
/// deserialized sketch presumably gets its hasher from `S::default()`; its
/// estimates are only meaningful if that hasher matches the one used to fill
/// the counters — confirm against the hasher in use.
///
/// NOTE(review): the derived `Deserialize` does not appear to validate that
/// `counters.len() == width * depth` or that the dimensions are non-zero —
/// confirm whether untrusted input can reach it.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CountMinSketch<T: ?Sized, S = DefaultHashBuilder> {
/// Counter table, stored row-major: the cell for (`row`, `column`) lives at
/// index `row * width + column`.
counters: Vec<u64>,
/// Number of counters per row.
width: usize,
/// Number of rows; each row uses a different hash of the item.
depth: usize,
/// Saturating sum of every `count` that has been added or merged in.
total: u64,
/// Hash builder used to hash items; not serialized.
#[cfg_attr(feature = "serde", serde(skip))]
hasher: S,
/// Marker tying the sketch to its item type; no `T` value is stored.
#[cfg_attr(feature = "serde", serde(skip))]
_marker: PhantomData<fn(&T)>,
}
/// Constructors that hash with the crate's `DefaultHashBuilder`.
impl<T: ?Sized> CountMinSketch<T, DefaultHashBuilder> {
/// Creates a sketch sized from the accuracy targets `epsilon` and `delta`.
///
/// Delegates to [`CountMinSketch::with_hasher`] with `DefaultHashBuilder`.
///
/// # Errors
///
/// Propagates any error from `with_hasher`, e.g. [`Error::InvalidParameter`]
/// when `epsilon` or `delta` is not a finite value in the open interval
/// `(0.0, 1.0)`.
pub fn new(epsilon: f64, delta: f64) -> Result<Self, Error> {
Self::with_hasher(epsilon, delta, DefaultHashBuilder)
}
/// Creates a sketch with an explicit `width` (counters per row) and `depth`
/// (number of rows).
///
/// Delegates to [`CountMinSketch::with_dimensions_and_hasher`] with
/// `DefaultHashBuilder`.
///
/// # Errors
///
/// Propagates any error from `with_dimensions_and_hasher`, e.g.
/// [`Error::InvalidParameter`] when `width` or `depth` is zero.
pub fn with_dimensions(width: usize, depth: usize) -> Result<Self, Error> {
Self::with_dimensions_and_hasher(width, depth, DefaultHashBuilder)
}
}
impl<T: ?Sized, S: BuildHasher> CountMinSketch<T, S> {
pub fn with_hasher(epsilon: f64, delta: f64, hasher: S) -> Result<Self, Error> {
if !(epsilon.is_finite() && epsilon > 0.0 && epsilon < 1.0) {
return Err(Error::InvalidParameter {
param: "epsilon",
reason: "must be a finite value in the open interval (0.0, 1.0)",
});
}
if !(delta.is_finite() && delta > 0.0 && delta < 1.0) {
return Err(Error::InvalidParameter {
param: "delta",
reason: "must be a finite value in the open interval (0.0, 1.0)",
});
}
let width = libm::ceil(E / epsilon) as usize;
let depth = libm::ceil(libm::log(1.0 / delta)) as usize;
Self::with_dimensions_and_hasher(width.max(1), depth.max(1), hasher)
}
pub fn with_dimensions_and_hasher(
width: usize,
depth: usize,
hasher: S,
) -> Result<Self, Error> {
if width == 0 {
return Err(Error::InvalidParameter {
param: "width",
reason: "must be greater than zero",
});
}
if depth == 0 {
return Err(Error::InvalidParameter {
param: "depth",
reason: "must be greater than zero",
});
}
Ok(Self {
counters: vec![0u64; width * depth],
width,
depth,
total: 0,
hasher,
_marker: PhantomData,
})
}
pub fn add(&mut self, item: &T, count: u64)
where
T: core::hash::Hash,
{
let pair = HashPair::new(item, &self.hasher);
let width = self.width as u64;
for row in 0..self.depth {
let column = reduce(pair.nth(row as u64), width) as usize;
let cell = &mut self.counters[row * self.width + column];
*cell = cell.saturating_add(count);
}
self.total = self.total.saturating_add(count);
}
#[inline]
pub fn increment(&mut self, item: &T)
where
T: core::hash::Hash,
{
self.add(item, 1);
}
#[must_use]
pub fn estimate(&self, item: &T) -> u64
where
T: core::hash::Hash,
{
let pair = HashPair::new(item, &self.hasher);
let width = self.width as u64;
let mut min = u64::MAX;
for row in 0..self.depth {
let column = reduce(pair.nth(row as u64), width) as usize;
let value = self.counters[row * self.width + column];
if value < min {
min = value;
}
}
min
}
#[inline]
#[must_use]
pub fn total_count(&self) -> u64 {
self.total
}
#[inline]
#[must_use]
pub fn width(&self) -> usize {
self.width
}
#[inline]
#[must_use]
pub fn depth(&self) -> usize {
self.depth
}
pub fn clear(&mut self) {
self.counters.iter_mut().for_each(|cell| *cell = 0);
self.total = 0;
}
pub fn merge(&mut self, other: &Self) -> Result<(), Error> {
if self.width != other.width || self.depth != other.depth {
return Err(Error::IncompatibleParameters);
}
for (dst, src) in self.counters.iter_mut().zip(other.counters.iter()) {
*dst = dst.saturating_add(*src);
}
self.total = self.total.saturating_add(other.total);
Ok(())
}
}
// Unit tests for `CountMinSketch`: constructor validation, estimate bounds,
// saturation, clearing and merging.
#[cfg(test)]
mod tests {
// `unwrap` is acceptable here: a failed construction should fail the test.
#![allow(clippy::unwrap_used)]
use super::*;
// `new` must reject an `epsilon` of 0.0 and a `delta` of 1.0, both of which
// sit on the excluded endpoints of the open interval (0.0, 1.0).
#[test]
fn test_new_rejects_out_of_range() {
assert!(matches!(
CountMinSketch::<&str>::new(0.0, 0.1),
Err(Error::InvalidParameter { .. })
));
assert!(matches!(
CountMinSketch::<&str>::new(0.1, 1.0),
Err(Error::InvalidParameter { .. })
));
}
// A zero `width` or a zero `depth` is an invalid table shape.
#[test]
fn test_with_dimensions_rejects_zero() {
assert!(matches!(
CountMinSketch::<u8>::with_dimensions(0, 4),
Err(Error::InvalidParameter { .. })
));
assert!(matches!(
CountMinSketch::<u8>::with_dimensions(64, 0),
Err(Error::InvalidParameter { .. })
));
}
// Every item is added with a known count (1..=7); the estimate may be
// inflated by collisions but must never fall below that count.
#[test]
fn test_estimate_never_undercounts() {
let mut sketch = CountMinSketch::new(0.001, 0.001).unwrap();
for i in 0..1_000u32 {
let count = u64::from(i % 7) + 1;
sketch.add(&i, count);
}
for i in 0..1_000u32 {
let truth = u64::from(i % 7) + 1;
assert!(
sketch.estimate(&i) >= truth,
"estimate undercounted item {i}"
);
}
}
// An item that was never added should estimate to exactly zero here.
// NOTE(review): this relies on the default hasher producing no full-depth
// collision for this data set — presumably it is deterministic; confirm.
#[test]
fn test_absent_item_estimates_low() {
let mut sketch = CountMinSketch::new(0.001, 0.001).unwrap();
for i in 0..100u32 {
sketch.increment(&i);
}
assert_eq!(sketch.estimate(&9_999u32), 0);
}
// The running total is the exact sum of added counts: 10 + 20 + 1.
#[test]
fn test_total_count_is_exact() {
let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
sketch.add("a", 10);
sketch.add("b", 20);
sketch.increment("c");
assert_eq!(sketch.total_count(), 31);
}
// Adding past `u64::MAX` must saturate both the counters and the total
// rather than wrap around.
#[test]
fn test_saturating_add() {
let mut sketch = CountMinSketch::<str>::with_dimensions(16, 2).unwrap();
sketch.add("x", u64::MAX);
sketch.add("x", 5);
assert_eq!(sketch.estimate("x"), u64::MAX);
assert_eq!(sketch.total_count(), u64::MAX);
}
// `clear` zeroes the counters and the running total.
#[test]
fn test_clear() {
let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
sketch.add("x", 9);
sketch.clear();
assert_eq!(sketch.estimate("x"), 0);
assert_eq!(sketch.total_count(), 0);
}
// Merging two same-shaped sketches sums their counters and totals.
#[test]
fn test_merge_sums_counts() {
let mut a = CountMinSketch::with_dimensions(512, 4).unwrap();
let mut b = CountMinSketch::with_dimensions(512, 4).unwrap();
a.add("shared", 2);
b.add("shared", 3);
a.merge(&b).unwrap();
assert!(a.estimate("shared") >= 5);
assert_eq!(a.total_count(), 5);
}
// Sketches with different widths cannot be merged.
#[test]
fn test_merge_rejects_incompatible() {
let mut a = CountMinSketch::<&str>::with_dimensions(512, 4).unwrap();
let b = CountMinSketch::<&str>::with_dimensions(256, 4).unwrap();
assert_eq!(a.merge(&b), Err(Error::IncompatibleParameters));
}
}