use std::hash::Hash;
use crate::{SketchError, seeded_hash64, splitmix64};
/// Count Sketch: a fixed-size table of signed counters used to approximate
/// per-item totals under signed updates.
///
/// Every update touches one counter in each row; the counter is chosen by a
/// per-row seeded hash of the item and the update is added or subtracted
/// according to a second per-row seeded hash. An estimate is the median of the
/// per-row readings.
#[derive(Debug, Clone)]
pub struct CountSketch {
/// Number of counters (columns) in each row.
width: usize,
/// Number of rows, i.e. independent hash-function pairs.
depth: usize,
/// Flat counter table of `width * depth` entries, stored row-major:
/// the counter for (`row`, `column`) lives at `row * width + column`.
counters: Vec<i64>,
/// One seed per row, fed to `seeded_hash64` to pick the item's column.
index_seeds: Vec<u64>,
/// One seed per row, fed to `seeded_hash64` to pick the item's sign (+1 / -1).
sign_seeds: Vec<u64>,
/// Saturating sum of `|delta|` over all recorded updates (including merged
/// sketches); reset to zero by `clear`.
total_update_magnitude: u64,
}
impl CountSketch {
    /// Offset added to the row number before it is mixed into that row's
    /// column-selection seed.
    const INDEX_SEED_OFFSET: u64 = 0x0D6E_8FD9_3A5E_4C31;

    /// Offset added to the row number before it is mixed into that row's
    /// sign-selection seed. It differs from [`Self::INDEX_SEED_OFFSET`] so the
    /// index and sign seeds of a row are derived from different inputs.
    const SIGN_SEED_OFFSET: u64 = 0xA076_1D64_78BD_642F;

    /// Largest counter count whose backing allocation still fits in
    /// `isize::MAX` bytes, the hard upper bound for a single `Vec` allocation.
    const MAX_COUNTERS: usize = isize::MAX as usize / std::mem::size_of::<i64>();

    /// Creates a sketch whose dimensions are derived from `epsilon` and `delta`.
    ///
    /// The table gets `ceil(3 / epsilon^2)` columns and `ceil(ln(1 / delta))`
    /// rows (each clamped to at least 1), so smaller values of either
    /// parameter produce a larger table.
    ///
    /// NOTE(review): presumably `epsilon` is the target error bound and
    /// `delta` the allowed failure probability; confirm against the crate docs.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::InvalidParameter`] when `epsilon` or `delta` is
    /// not finite or not strictly between 0 and 1, or when the derived table
    /// is rejected by [`Self::with_dimensions`] (for example because an
    /// extremely small `epsilon` asks for more counters than can be allocated).
    pub fn new(epsilon: f64, delta: f64) -> Result<Self, SketchError> {
        if !epsilon.is_finite() || epsilon <= 0.0 || epsilon >= 1.0 {
            return Err(SketchError::InvalidParameter(
                "epsilon must be finite and strictly between 0 and 1",
            ));
        }
        if !delta.is_finite() || delta <= 0.0 || delta >= 1.0 {
            return Err(SketchError::InvalidParameter(
                "delta must be finite and strictly between 0 and 1",
            ));
        }

        // Float-to-integer `as` casts saturate, so an out-of-range width
        // becomes `usize::MAX` and is then rejected by `with_dimensions`
        // rather than wrapping around to a small table.
        let width = (3.0 / (epsilon * epsilon)).ceil() as usize;
        let depth = (1.0 / delta).ln().ceil() as usize;
        Self::with_dimensions(width.max(1), depth.max(1))
    }

    /// Creates an empty sketch with exactly `depth` rows of `width` counters.
    ///
    /// Row seeds are derived deterministically from the row number, so two
    /// sketches built with the same dimensions always share seeds and can be
    /// merged.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::InvalidParameter`] when `width` or `depth` is
    /// zero, when `width * depth` overflows `usize`, or when the counter table
    /// would need more than `isize::MAX` bytes.
    pub fn with_dimensions(width: usize, depth: usize) -> Result<Self, SketchError> {
        if width == 0 {
            return Err(SketchError::InvalidParameter(
                "width must be greater than zero",
            ));
        }
        if depth == 0 {
            return Err(SketchError::InvalidParameter(
                "depth must be greater than zero",
            ));
        }
        let table_len = width
            .checked_mul(depth)
            .ok_or(SketchError::InvalidParameter(
                "width * depth overflows usize",
            ))?;
        // `vec![0; n]` panics with "capacity overflow" once the allocation
        // would exceed `isize::MAX` bytes. Reject such sizes here, before any
        // allocation, so this fallible constructor reports them as an error.
        if table_len > Self::MAX_COUNTERS {
            return Err(SketchError::InvalidParameter(
                "width * depth exceeds the maximum allocation size",
            ));
        }

        let index_seeds = (0..depth)
            .map(|row| splitmix64((row as u64).wrapping_add(Self::INDEX_SEED_OFFSET)))
            .collect();
        let sign_seeds = (0..depth)
            .map(|row| splitmix64((row as u64).wrapping_add(Self::SIGN_SEED_OFFSET)))
            .collect();

        Ok(Self {
            width,
            depth,
            counters: vec![0; table_len],
            index_seeds,
            sign_seeds,
            total_update_magnitude: 0,
        })
    }

    /// Number of counters in each row.
    pub fn width(&self) -> usize {
        self.width
    }

    /// Number of rows.
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Saturating sum of `|delta|` over every update recorded since creation
    /// or the last [`Self::clear`], including updates brought in by
    /// [`Self::merge`].
    pub fn total_update_magnitude(&self) -> u64 {
        self.total_update_magnitude
    }

    /// Returns `true` when no non-zero update has been recorded.
    ///
    /// This looks at the accumulated update magnitude, not at the counters:
    /// an `add(x, n)` followed by `add(x, -n)` leaves the sketch non-empty.
    pub fn is_empty(&self) -> bool {
        self.total_update_magnitude == 0
    }

    /// Applies a signed update of `delta` to `item`.
    ///
    /// In every row one counter (chosen by the row's index hash) is moved by
    /// `delta` or by its negation (chosen by the row's sign hash). All
    /// arithmetic saturates instead of overflowing. A `delta` of zero is a
    /// no-op and does not affect [`Self::is_empty`].
    pub fn add<T: Hash>(&mut self, item: &T, delta: i64) {
        if delta == 0 {
            return;
        }
        for row in 0..self.depth {
            let idx = self.counter_index(row, item);
            let signed_delta = self.apply_sign(row, item, delta);
            let counter = &mut self.counters[idx];
            *counter = counter.saturating_add(signed_delta);
        }
        self.total_update_magnitude = self
            .total_update_magnitude
            .saturating_add(delta.unsigned_abs());
    }

    /// Shorthand for `add(item, 1)`.
    pub fn increment<T: Hash>(&mut self, item: &T) {
        self.add(item, 1);
    }

    /// Shorthand for `add(item, -1)`.
    pub fn decrement<T: Hash>(&mut self, item: &T) {
        self.add(item, -1);
    }

    /// Estimates the net total of all updates applied to `item`.
    ///
    /// Each row contributes its counter for `item`, multiplied by the item's
    /// sign in that row; the result is the median of those readings. With an
    /// even number of rows the two middle readings are averaged (integer
    /// division, truncating toward zero).
    pub fn estimate<T: Hash>(&self, item: &T) -> i64 {
        let mut readings: Vec<i64> = (0..self.depth)
            .map(|row| {
                let counter = self.counters[self.counter_index(row, item)];
                self.apply_sign(row, item, counter)
            })
            .collect();
        readings.sort_unstable();

        let mid = readings.len() / 2;
        if readings.len() % 2 == 1 {
            readings[mid]
        } else {
            // Widen to i128 so the sum cannot overflow; the mean of two i64
            // values always fits back into an i64.
            let left = i128::from(readings[mid - 1]);
            let right = i128::from(readings[mid]);
            ((left + right) / 2) as i64
        }
    }

    /// Resets every counter and the recorded update magnitude to zero while
    /// keeping the dimensions and seeds.
    pub fn clear(&mut self) {
        self.counters.fill(0);
        self.total_update_magnitude = 0;
    }

    /// Adds the contents of `other` into `self`.
    ///
    /// Counters are summed element-wise and the update magnitudes are added,
    /// both with saturating arithmetic. `self` is left untouched when an error
    /// is returned.
    ///
    /// # Errors
    ///
    /// Returns [`SketchError::IncompatibleSketches`] when the two sketches
    /// differ in width, depth, or hash seeds.
    pub fn merge(&mut self, other: &Self) -> Result<(), SketchError> {
        if self.width != other.width || self.depth != other.depth {
            return Err(SketchError::IncompatibleSketches(
                "width/depth must match for merge",
            ));
        }
        if self.index_seeds != other.index_seeds || self.sign_seeds != other.sign_seeds {
            return Err(SketchError::IncompatibleSketches(
                "hash seeds must match for merge",
            ));
        }
        for (mine, theirs) in self.counters.iter_mut().zip(&other.counters) {
            *mine = mine.saturating_add(*theirs);
        }
        self.total_update_magnitude = self
            .total_update_magnitude
            .saturating_add(other.total_update_magnitude);
        Ok(())
    }

    /// Position in the flat counter table of `item`'s counter for `row`.
    fn counter_index<T: Hash>(&self, row: usize, item: &T) -> usize {
        let hash = seeded_hash64(item, self.index_seeds[row]);
        // Reduce in u64 before narrowing so that targets with a 32-bit usize
        // use the full hash and pick the same column as 64-bit targets. The
        // remainder is below `width`, so the cast back to usize is lossless.
        let column = (hash % self.width as u64) as usize;
        row * self.width + column
    }

    /// Sign (+1 or -1) assigned to `item` in `row`, taken from the lowest bit
    /// of the row's sign hash.
    fn sign<T: Hash>(&self, row: usize, item: &T) -> i64 {
        if (seeded_hash64(item, self.sign_seeds[row]) & 1) == 0 {
            1
        } else {
            -1
        }
    }

    /// Returns `value` multiplied by `item`'s sign in `row`; negation
    /// saturates, so `i64::MIN` maps to `i64::MAX` instead of overflowing.
    fn apply_sign<T: Hash>(&self, row: usize, item: &T, value: i64) -> i64 {
        if self.sign(row, item) == 1 {
            value
        } else {
            value.saturating_neg()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::CountSketch;

    /// Ordinary error bounds must yield a table with at least one row and column.
    #[test]
    fn constructor_from_error_bounds_creates_non_zero_dimensions() {
        let built = CountSketch::new(0.05, 0.01).expect("valid bounds");
        assert_ne!(built.width(), 0);
        assert_ne!(built.depth(), 0);
    }

    /// Out-of-range or non-finite bounds and zero dimensions are all rejected.
    #[test]
    fn constructor_rejects_invalid_parameters() {
        let bad_bounds = [
            (0.0, 0.1),
            (0.1, 0.0),
            (1.0, 0.1),
            (0.1, 1.0),
            (f64::NAN, 0.1),
        ];
        for &(epsilon, delta) in &bad_bounds {
            assert!(CountSketch::new(epsilon, delta).is_err());
        }

        let bad_dimensions = [(0, 2), (2, 0)];
        for &(width, depth) in &bad_dimensions {
            assert!(CountSketch::with_dimensions(width, depth).is_err());
        }
    }

    /// A heavy hitter stays recognisable among many distinct light items.
    #[test]
    fn estimate_is_reasonable_with_noise() {
        let mut sketch = CountSketch::with_dimensions(2_048, 7).unwrap();
        (0..5_000).for_each(|_| sketch.increment(&"hot-key"));
        (0_u64..50_000).for_each(|value| sketch.increment(&value));

        let estimate = sketch.estimate(&"hot-key");
        let plausible = 3_500 < estimate && estimate < 6_500;
        assert!(plausible, "estimate={estimate}");
    }

    /// Positive and negative deltas for the same item net out.
    #[test]
    fn signed_updates_are_supported() {
        let mut sketch = CountSketch::with_dimensions(1_024, 7).unwrap();
        for delta in [10, -3].iter().copied() {
            sketch.add(&"x", delta);
        }

        let estimate = sketch.estimate(&"x");
        assert!((5..=9).contains(&estimate), "estimate={estimate}");
    }

    /// Merging two same-shaped sketches sums their contents.
    #[test]
    fn merge_combines_two_sketches() {
        let mut target = CountSketch::with_dimensions(512, 5).unwrap();
        let mut incoming = CountSketch::with_dimensions(512, 5).unwrap();
        target.add(&"alpha", 100);
        incoming.add(&"alpha", 50);

        target.merge(&incoming).unwrap();

        assert_eq!(target.estimate(&"alpha"), 150);
    }

    /// Sketches with different widths cannot be merged.
    #[test]
    fn merge_rejects_mismatched_shape() {
        let mut wide = CountSketch::with_dimensions(256, 4).unwrap();
        let narrow = CountSketch::with_dimensions(128, 4).unwrap();

        let outcome = wide.merge(&narrow);

        assert!(outcome.is_err());
    }

    /// `clear` wipes both the counters and the emptiness flag.
    #[test]
    fn clear_resets_state() {
        let mut sketch = CountSketch::with_dimensions(128, 4).unwrap();
        sketch.add(&"item", 7);
        assert!(!sketch.is_empty());

        sketch.clear();

        assert_eq!(sketch.estimate(&"item"), 0);
        assert!(sketch.is_empty());
    }
}