toolbox-rs 0.6.0

A toolbox of basic data structures and algorithms
Documentation
use xxhash_rust::xxh3::xxh3_128_with_seed;

use crate::as_bytes::AsBytes;

fn optimal_k(delta: f64) -> usize {
    (1. / delta).ln().ceil() as usize
}

fn optimal_m(epsilon: f64) -> usize {
    (std::f64::consts::E / epsilon).ceil() as usize
}

fn get_seed() -> u64 {
    rand::Rng::random::<u64>(&mut rand::rng())
}

pub struct CountMinSketch {
    seed: u64,
    counter: Vec<Vec<u32>>,
    m: usize,
    k: usize,
    len: usize,
}

impl CountMinSketch {
    pub fn new(delta: f64, epsilon: f64) -> Self {
        let seed = get_seed();

        let m = optimal_m(delta);
        let k = optimal_k(epsilon);

        let counter = vec![vec![0u32; m]; k];

        Self {
            seed,
            counter,
            m,
            k,
            len: 0,
        }
    }
}

impl CountMinSketch {
    fn hash_pair<K: Eq + AsBytes>(&self, key: &K) -> (u64, u64) {
        let hash = xxh3_128_with_seed(key.as_bytes(), self.seed);
        (hash as u64, (hash >> 64) as u64)
    }

    fn get_buckets<K: Eq + AsBytes>(&self, key: &K) -> Vec<usize> {
        let (hash1, hash2) = self.hash_pair(key);
        let mut bucket_indices = Vec::with_capacity(self.k);
        if self.k == 1 {
            let index = hash1 % self.m as u64;
            bucket_indices.push(index as usize);
        } else {
            (0..self.k as u64).for_each(|i| {
                let hash = hash1.wrapping_add(i.wrapping_mul(hash2));
                let index = hash % self.m as u64;
                bucket_indices.push(index as usize);
            });
        }
        bucket_indices
    }

    pub fn insert<K: Eq + AsBytes>(&mut self, key: &K) {
        let indices = self.get_buckets(key);
        indices.iter().enumerate().for_each(|(k, &b)| {
            self.counter[k][b] = self.counter[k][b].saturating_add(1);
        });
        self.len += 1;
    }

    pub fn estimate<K: Eq + AsBytes>(&self, key: &K) -> u32 {
        let indices = self.get_buckets(key);
        indices
            .iter()
            .enumerate()
            .map(|(k, b)| self.counter[k][*b])
            .fold(u32::MAX, |a, b| a.min(b))
    }
}

#[cfg(test)]
mod tests {
    use super::CountMinSketch;

    #[test]
    fn insert_check_1m() {
        let mut sketch = CountMinSketch::new(0.01, 0.2);

        for _ in 0..1_000_000 {
            sketch.insert(&"key");
        }

        assert_eq!(sketch.estimate(&"key"), 1_000_000);
        assert_eq!(sketch.estimate(&"blah"), 0);
    }
}