use std::hash::Hash;
/// Count-Min sketch: approximate per-item frequency counts in fixed memory.
///
/// The sketch is a `depth` x `width` grid of `u64` counters with one hash seed
/// per row. Adding an item bumps one counter in every row, and an estimate is the
/// minimum of the item's counters across rows. Hash collisions can only inflate a
/// counter, so (as long as `depth >= 1` and counters have not saturated) an
/// estimate is never below the item's true count.
#[derive(Debug, Clone)]
pub struct CountMinSketch {
/// Number of counters per row; item hashes are reduced modulo this value.
width: usize,
/// Number of rows, i.e. number of independently seeded hash functions.
depth: usize,
/// `depth` rows of `width` counters each.
counters: Vec<Vec<u64>>,
/// Per-row hash seeds, generated from a fixed starting value so that sketches
/// of equal depth use identical seeds.
seeds: Vec<u64>,
/// Sum of all counts added (including totals absorbed via `merge`), saturating.
total: u64,
}
impl CountMinSketch {
    /// Creates a sketch sized from a relative error `epsilon` and a failure
    /// probability `delta`.
    ///
    /// Uses the standard Count-Min sizing `width = ceil(e / epsilon)` and
    /// `depth = ceil(ln(1 / delta))`, intended to give
    /// `estimate(x) <= true_count(x) + epsilon * total()` with probability at
    /// least `1 - delta`. Both dimensions are clamped to at least 1, so a very
    /// large `epsilon` or a `delta >= 1` still produces a working sketch instead
    /// of one that divides by zero or estimates 0 for everything.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` or `delta` is not strictly positive (NaN included).
    pub fn new(epsilon: f64, delta: f64) -> Self {
        assert!(epsilon > 0.0, "epsilon must be positive, got {}", epsilon);
        assert!(delta > 0.0, "delta must be positive, got {}", delta);
        // Float-to-int `as` casts saturate (and map NaN to 0); `.max(1)` guards the
        // cases that would otherwise produce an empty dimension.
        let width = ((std::f64::consts::E / epsilon).ceil() as usize).max(1);
        let depth = ((1.0 / delta).ln().ceil() as usize).max(1);
        Self::with_dimensions(width, depth)
    }

    /// Creates a sketch with `epsilon = 0.001` and `delta = 0.01`.
    pub fn default_params() -> Self {
        Self::new(0.001, 0.01)
    }

    /// Creates a sketch with an explicit number of columns (`width`) and rows (`depth`).
    ///
    /// # Panics
    ///
    /// Panics if `width` or `depth` is zero: a zero width would make every update
    /// divide by zero, and a zero depth would make every estimate 0.
    pub fn with_dimensions(width: usize, depth: usize) -> Self {
        assert!(width > 0, "CountMinSketch width must be at least 1");
        assert!(depth > 0, "CountMinSketch depth must be at least 1");
        Self {
            width,
            depth,
            counters: vec![vec![0; width]; depth],
            seeds: Self::row_seeds(depth),
            total: 0,
        }
    }

    /// Generates one hash seed per row by stepping a 64-bit LCG from a fixed start.
    ///
    /// Determinism matters: `merge` only compares width and depth, so sketches of
    /// equal depth must end up with identical seeds (and hence hash functions).
    fn row_seeds(depth: usize) -> Vec<u64> {
        let mut seed = 0x517c_c1b7_2722_0a95_u64;
        (0..depth)
            .map(|_| {
                seed = seed.wrapping_mul(0x5851_f42d_4c95_7f2d).wrapping_add(1);
                seed
            })
            .collect()
    }

    /// Maps `item` to a column index in `0..width` using an ahash hasher keyed by `seed`.
    #[inline]
    fn hash_with_seed<T: Hash>(&self, item: &T, seed: u64) -> usize {
        let state = ahash::RandomState::with_seeds(seed, seed, seed, seed);
        let hash = state.hash_one(item);
        (hash as usize) % self.width
    }

    /// Records a single occurrence of `item`.
    #[inline]
    pub fn add<T: Hash>(&mut self, item: &T) {
        self.add_count(item, 1);
    }

    /// Records `count` occurrences of `item`; counters and the total saturate at `u64::MAX`.
    pub fn add_count<T: Hash>(&mut self, item: &T, count: u64) {
        for (i, &seed) in self.seeds.iter().enumerate() {
            let j = self.hash_with_seed(item, seed);
            self.counters[i][j] = self.counters[i][j].saturating_add(count);
        }
        self.total = self.total.saturating_add(count);
    }

    /// Returns the estimated count of `item`: the minimum of its counter in each row.
    ///
    /// Never below the true count (barring saturation); may be above it due to
    /// hash collisions.
    pub fn estimate<T: Hash>(&self, item: &T) -> u64 {
        self.counters
            .iter()
            .zip(&self.seeds)
            .map(|(row, &seed)| row[self.hash_with_seed(item, seed)])
            .min()
            .unwrap_or(0)
    }

    /// Adds every counter of `other` into `self` (saturating), as if `self` had
    /// also seen `other`'s stream.
    ///
    /// # Panics
    ///
    /// Panics if the sketches differ in width or depth.
    pub fn merge(&mut self, other: &CountMinSketch) {
        assert_eq!(self.width, other.width, "Width mismatch");
        assert_eq!(self.depth, other.depth, "Depth mismatch");
        // Seeds depend only on depth, so this holds whenever the asserts above pass.
        debug_assert_eq!(self.seeds, other.seeds, "Seed mismatch");
        for (mine, theirs) in self.counters.iter_mut().zip(&other.counters) {
            for (cell, &extra) in mine.iter_mut().zip(theirs) {
                *cell = cell.saturating_add(extra);
            }
        }
        self.total = self.total.saturating_add(other.total);
    }

    /// Sum of all counts added so far (saturating).
    pub fn total(&self) -> u64 {
        self.total
    }

    /// Number of counters per row.
    pub fn width(&self) -> usize {
        self.width
    }

    /// Number of rows (hash functions).
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns `true` when the recorded total is zero.
    pub fn is_empty(&self) -> bool {
        self.total == 0
    }

    /// Resets every counter and the total to zero, keeping dimensions and seeds.
    pub fn clear(&mut self) {
        for row in &mut self.counters {
            row.fill(0);
        }
        self.total = 0;
    }

    /// Approximate memory footprint in bytes: the struct itself, each row's `Vec`
    /// header plus its counters, and the seed array. Spare capacity is not counted.
    pub fn memory_usage(&self) -> usize {
        use std::mem::size_of;
        size_of::<Self>()
            + self.depth * (size_of::<Vec<u64>>() + self.width * size_of::<u64>())
            + self.seeds.len() * size_of::<u64>()
    }
}
impl Default for CountMinSketch {
fn default() -> Self {
Self::default_params()
}
}
/// Approximate heavy-hitter tracker using the Space-Saving algorithm.
///
/// At most `capacity` distinct items are monitored. When an unmonitored item
/// arrives and the tracker is full, the entry with the smallest count is evicted
/// and the newcomer inherits that count, recorded as its error. Each entry holds
/// `(count, error)`: `count` may overestimate the item's true frequency by at
/// most `error` (ignoring saturation at `u64::MAX`).
#[derive(Debug, Clone)]
pub struct TopKTracker<K: Hash + Eq + Clone> {
    /// Maximum number of distinct items monitored at once.
    capacity: usize,
    /// Monitored items mapped to `(estimated count, maximum overestimation)`.
    items: std::collections::HashMap<K, (u64, u64)>,
}

impl<K: Hash + Eq + Clone> TopKTracker<K> {
    /// Creates an empty tracker monitoring at most `k` distinct items.
    ///
    /// With `k == 0` nothing is ever tracked.
    pub fn new(k: usize) -> Self {
        Self {
            capacity: k,
            items: std::collections::HashMap::with_capacity(k),
        }
    }

    /// Records a single occurrence of `item`.
    pub fn add(&mut self, item: K) {
        self.add_count(item, 1);
    }

    /// Records `count` occurrences of `item`. Counts saturate at `u64::MAX`.
    pub fn add_count(&mut self, item: K, count: u64) {
        self.record(item, count, 0);
    }

    /// Space-Saving update: adds `count` occurrences of `item` that already carry
    /// up to `error` overestimation (0 for fresh observations, non-zero when
    /// folding in another tracker's entry).
    fn record(&mut self, item: K, count: u64, error: u64) {
        if let Some((c, e)) = self.items.get_mut(&item) {
            *c = c.saturating_add(count);
            *e = e.saturating_add(error);
        } else if self.items.len() < self.capacity {
            self.items.insert(item, (count, error));
        } else {
            // O(capacity) scan for the current minimum; `None` only when capacity is 0.
            let victim = self
                .items
                .iter()
                .min_by_key(|(_, (c, _))| *c)
                .map(|(k, &(c, _))| (k.clone(), c));
            if let Some((victim_key, floor)) = victim {
                self.items.remove(&victim_key);
                // The newcomer may have occurred up to `floor` times before it was
                // monitored, so that amount joins both its count and its error.
                self.items.insert(
                    item,
                    (floor.saturating_add(count), floor.saturating_add(error)),
                );
            }
        }
    }

    /// Returns every monitored item with its estimated count, highest first.
    /// Items with equal counts appear in unspecified order.
    pub fn top_k(&self) -> Vec<(K, u64)> {
        let mut items: Vec<_> = self
            .items
            .iter()
            .map(|(k, &(c, _))| (k.clone(), c))
            .collect();
        items.sort_unstable_by(|a, b| b.1.cmp(&a.1));
        items
    }

    /// Estimated count of `item`, or `None` if it is not currently monitored.
    pub fn get(&self, item: &K) -> Option<u64> {
        self.items.get(item).map(|&(c, _)| c)
    }

    /// `(estimated count, maximum overestimation)` for `item`, or `None` if it is
    /// not currently monitored.
    pub fn get_with_error(&self, item: &K) -> Option<(u64, u64)> {
        self.items.get(item).copied()
    }

    /// Returns `true` when no item is monitored.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Number of items currently monitored (at most the capacity).
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Forgets every monitored item; the capacity is unchanged.
    pub fn clear(&mut self) {
        self.items.clear();
    }

    /// Folds `other`'s monitored items into `self`, carrying each entry's error
    /// bound along with its count so `get_with_error` stays meaningful after a
    /// merge. Entries are visited in `HashMap` order, so when `self` overflows,
    /// which items survive eviction can vary between runs.
    pub fn merge(&mut self, other: &TopKTracker<K>) {
        for (item, &(count, error)) in &other.items {
            self.record(item.clone(), count, error);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adds `item` to `sketch` one occurrence at a time, `times` times.
    fn feed_sketch(sketch: &mut CountMinSketch, item: &str, times: usize) {
        for _ in 0..times {
            sketch.add(&item);
        }
    }

    /// Adds `item` to `tracker` one occurrence at a time, `times` times.
    fn feed_tracker(tracker: &mut TopKTracker<&'static str>, item: &'static str, times: usize) {
        for _ in 0..times {
            tracker.add(item);
        }
    }

    #[test]
    fn test_basic_counting() {
        let mut sketch = CountMinSketch::new(0.01, 0.01);
        feed_sketch(&mut sketch, "hello", 100);
        feed_sketch(&mut sketch, "world", 50);

        assert!(sketch.estimate(&"hello") >= 100);
        assert!(sketch.estimate(&"world") >= 50);
        assert_eq!(sketch.estimate(&"unknown"), 0);
    }

    #[test]
    fn test_merge() {
        let mut left = CountMinSketch::with_dimensions(1000, 5);
        let mut right = CountMinSketch::with_dimensions(1000, 5);
        feed_sketch(&mut left, "hello", 50);
        feed_sketch(&mut right, "hello", 50);

        left.merge(&right);
        assert!(left.estimate(&"hello") >= 100);
    }

    #[test]
    fn test_top_k() {
        let mut tracker = TopKTracker::new(3);
        for (item, times) in [("a", 100), ("b", 50), ("c", 25), ("d", 10)] {
            feed_tracker(&mut tracker, item, times);
        }

        let ranked: Vec<&str> = tracker.top_k().into_iter().map(|(item, _)| item).collect();
        assert_eq!(ranked, ["a", "b", "d"]);
    }

    #[test]
    fn test_top_k_merge() {
        let mut first = TopKTracker::new(5);
        let mut second = TopKTracker::new(5);
        feed_tracker(&mut first, "a", 50);
        feed_tracker(&mut second, "a", 50);

        first.merge(&second);
        let merged = first.get(&"a").unwrap();
        assert!(merged >= 100);
    }
}