Skip to main content

bloom_lib/
count_min.rs

1//! A Count-Min Sketch for approximate frequency estimation.
2
3use core::{hash::BuildHasher, marker::PhantomData};
4
5use alloc::{vec, vec::Vec};
6
7use crate::{
8    hash::{reduce, DefaultHashBuilder, HashPair},
9    Error,
10};
11
/// Euler's number `e`.
///
/// Used to size the sketch width from a target error: the width is computed as
/// `ceil(E / epsilon)`.
const E: f64 = core::f64::consts::E;
14
15/// A sublinear-space frequency estimator.
16///
17/// A Count-Min Sketch counts how many times it has seen each item using a fixed
18/// grid of counters, far smaller than a real `HashMap<T, u64>` would be. Each
19/// item is hashed into one counter per row and those counters are incremented;
20/// the estimate for an item is the *minimum* across its rows. The estimate never
21/// undercounts and overcounts by a bounded amount with high probability — the
22/// width controls the error magnitude `epsilon`, the depth controls the
23/// confidence `delta`.
24///
25/// The sketch is generic over the item type `T` and a
26/// [`BuildHasher`](core::hash::BuildHasher) `S`, defaulting to the deterministic
27/// [`DefaultHashBuilder`](crate::hash::DefaultHashBuilder).
28///
29/// # Examples
30///
31/// ```
32/// use bloom_lib::CountMinSketch;
33///
34/// // ~0.1% error with 99.9% confidence.
35/// let mut sketch = CountMinSketch::new(0.001, 0.001).unwrap();
36///
37/// sketch.increment("apple");
38/// sketch.add("apple", 4);
39/// sketch.increment("banana");
40///
41/// assert!(sketch.estimate("apple") >= 5); // never undercounts
42/// assert_eq!(sketch.estimate("cherry"), 0);
43/// ```
44#[derive(Debug, Clone)]
45#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
46pub struct CountMinSketch<T: ?Sized, S = DefaultHashBuilder> {
47    counters: Vec<u64>,
48    width: usize,
49    depth: usize,
50    total: u64,
51    #[cfg_attr(feature = "serde", serde(skip))]
52    hasher: S,
53    #[cfg_attr(feature = "serde", serde(skip))]
54    _marker: PhantomData<fn(&T)>,
55}
56
57impl<T: ?Sized> CountMinSketch<T, DefaultHashBuilder> {
    /// Creates a sketch sized for an error factor `epsilon` at confidence
    /// `1 - delta`, using the default hasher.
    ///
    /// This is a convenience wrapper that forwards to `with_hasher` with a
    /// [`DefaultHashBuilder`].
    ///
    /// The estimate for any item never undercounts and, with probability at
    /// least `1 - delta`, overcounts by at most `epsilon * N`, where `N` is the
    /// total of all counts added. Smaller `epsilon` widens the grid; smaller
    /// `delta` deepens it.
    ///
    /// # Parameters
    ///
    /// - `epsilon`: the relative error factor. Must be in `(0.0, 1.0)`. The
    ///   width is `ceil(e / epsilon)`.
    /// - `delta`: the failure probability. Must be in `(0.0, 1.0)`. The depth is
    ///   `ceil(ln(1 / delta))`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] if either argument is not a finite
    /// value in the open interval `(0.0, 1.0)`; `NaN` and infinities are
    /// rejected as well.
    ///
    /// # Examples
    ///
    /// ```
    /// use bloom_lib::CountMinSketch;
    ///
    /// let sketch = CountMinSketch::<&str>::new(0.01, 0.01).unwrap();
    /// assert!(sketch.width() >= 271);
    /// assert!(sketch.depth() >= 4);
    /// ```
    pub fn new(epsilon: f64, delta: f64) -> Result<Self, Error> {
        Self::with_hasher(epsilon, delta, DefaultHashBuilder)
    }
90
    /// Creates a sketch with an explicit `width` and `depth`, using the default
    /// hasher.
    ///
    /// This is a convenience wrapper that forwards to
    /// `with_dimensions_and_hasher` with a [`DefaultHashBuilder`]. Use it when
    /// the geometry is already known rather than derived from an
    /// `epsilon`/`delta` target.
    ///
    /// # Parameters
    ///
    /// - `width`: counters per row. Must be non-zero.
    /// - `depth`: number of rows (independent hashes). Must be non-zero.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidParameter`] if either argument is zero.
    ///
    /// # Examples
    ///
    /// ```
    /// use bloom_lib::CountMinSketch;
    ///
    /// let sketch = CountMinSketch::<u32>::with_dimensions(2_048, 5).unwrap();
    /// assert_eq!(sketch.width(), 2_048);
    /// assert_eq!(sketch.depth(), 5);
    /// ```
    pub fn with_dimensions(width: usize, depth: usize) -> Result<Self, Error> {
        Self::with_dimensions_and_hasher(width, depth, DefaultHashBuilder)
    }
115}
116
117impl<T: ?Sized, S: BuildHasher> CountMinSketch<T, S> {
118    /// Creates a sketch from `epsilon`/`delta` with a caller-supplied hasher.
119    ///
120    /// # Errors
121    ///
122    /// Returns [`Error::InvalidParameter`] if either argument is not a finite
123    /// value in `(0.0, 1.0)`.
124    ///
125    /// # Examples
126    ///
127    /// ```
128    /// # #[cfg(feature = "std")] {
129    /// use std::collections::hash_map::RandomState;
130    /// use bloom_lib::CountMinSketch;
131    ///
132    /// let sketch: CountMinSketch<&str, RandomState> =
133    ///     CountMinSketch::with_hasher(0.01, 0.01, RandomState::new()).unwrap();
134    /// # }
135    /// ```
136    pub fn with_hasher(epsilon: f64, delta: f64, hasher: S) -> Result<Self, Error> {
137        if !(epsilon.is_finite() && epsilon > 0.0 && epsilon < 1.0) {
138            return Err(Error::InvalidParameter {
139                param: "epsilon",
140                reason: "must be a finite value in the open interval (0.0, 1.0)",
141            });
142        }
143        if !(delta.is_finite() && delta > 0.0 && delta < 1.0) {
144            return Err(Error::InvalidParameter {
145                param: "delta",
146                reason: "must be a finite value in the open interval (0.0, 1.0)",
147            });
148        }
149
150        let width = libm::ceil(E / epsilon) as usize;
151        let depth = libm::ceil(libm::log(1.0 / delta)) as usize;
152        Self::with_dimensions_and_hasher(width.max(1), depth.max(1), hasher)
153    }
154
155    /// Creates a sketch with an explicit geometry and a caller-supplied hasher.
156    ///
157    /// # Errors
158    ///
159    /// Returns [`Error::InvalidParameter`] if `width` or `depth` is zero.
160    pub fn with_dimensions_and_hasher(
161        width: usize,
162        depth: usize,
163        hasher: S,
164    ) -> Result<Self, Error> {
165        if width == 0 {
166            return Err(Error::InvalidParameter {
167                param: "width",
168                reason: "must be greater than zero",
169            });
170        }
171        if depth == 0 {
172            return Err(Error::InvalidParameter {
173                param: "depth",
174                reason: "must be greater than zero",
175            });
176        }
177
178        Ok(Self {
179            counters: vec![0u64; width * depth],
180            width,
181            depth,
182            total: 0,
183            hasher,
184            _marker: PhantomData,
185        })
186    }
187
188    /// Records `count` additional occurrences of `item`.
189    ///
190    /// Counters saturate at [`u64::MAX`] rather than overflowing, so an
191    /// adversarial or runaway stream degrades gracefully instead of panicking or
192    /// wrapping.
193    ///
194    /// # Examples
195    ///
196    /// ```
197    /// use bloom_lib::CountMinSketch;
198    ///
199    /// let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
200    /// sketch.add("page-view", 250);
201    /// assert!(sketch.estimate("page-view") >= 250);
202    /// ```
203    pub fn add(&mut self, item: &T, count: u64)
204    where
205        T: core::hash::Hash,
206    {
207        let pair = HashPair::new(item, &self.hasher);
208        let width = self.width as u64;
209        for row in 0..self.depth {
210            let column = reduce(pair.nth(row as u64), width) as usize;
211            let cell = &mut self.counters[row * self.width + column];
212            *cell = cell.saturating_add(count);
213        }
214        self.total = self.total.saturating_add(count);
215    }
216
    /// Records a single occurrence of `item`. Shorthand for `add(item, 1)`.
    ///
    /// Because it forwards to `add`, the touched counters and the running total
    /// saturate in the same way.
    ///
    /// # Examples
    ///
    /// ```
    /// use bloom_lib::CountMinSketch;
    ///
    /// let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
    /// sketch.increment("hit");
    /// sketch.increment("hit");
    /// assert!(sketch.estimate("hit") >= 2);
    /// ```
    #[inline]
    pub fn increment(&mut self, item: &T)
    where
        T: core::hash::Hash,
    {
        self.add(item, 1);
    }
236
237    /// Estimates the number of times `item` has been added.
238    ///
239    /// The estimate is the minimum counter across all rows. It never undercounts
240    /// the true total and overcounts only by the sketch's error bound.
241    ///
242    /// # Examples
243    ///
244    /// ```
245    /// use bloom_lib::CountMinSketch;
246    ///
247    /// let mut sketch = CountMinSketch::new(0.001, 0.001).unwrap();
248    /// for _ in 0..100 {
249    ///     sketch.increment("frequent");
250    /// }
251    /// let estimate = sketch.estimate("frequent");
252    /// assert!((100..=110).contains(&estimate));
253    /// ```
254    #[must_use]
255    pub fn estimate(&self, item: &T) -> u64
256    where
257        T: core::hash::Hash,
258    {
259        let pair = HashPair::new(item, &self.hasher);
260        let width = self.width as u64;
261        let mut min = u64::MAX;
262        for row in 0..self.depth {
263            let column = reduce(pair.nth(row as u64), width) as usize;
264            let value = self.counters[row * self.width + column];
265            if value < min {
266                min = value;
267            }
268        }
269        min
270    }
271
    /// The sum of every count added (saturating).
    ///
    /// Unlike per-item estimates, this is exact up to saturation, because every
    /// `add` contributes to it directly. `merge` adds the other sketch's total
    /// to it and `clear` resets it to zero.
    ///
    /// # Examples
    ///
    /// ```
    /// use bloom_lib::CountMinSketch;
    ///
    /// let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
    /// sketch.add("a", 3);
    /// sketch.add("b", 7);
    /// assert_eq!(sketch.total_count(), 10);
    /// ```
    #[inline]
    #[must_use]
    pub fn total_count(&self) -> u64 {
        self.total
    }
292
    /// The number of counters per row.
    ///
    /// This is the value fixed at construction, either passed explicitly or
    /// derived from `epsilon` as `ceil(e / epsilon)`.
    #[inline]
    #[must_use]
    pub fn width(&self) -> usize {
        self.width
    }
299
    /// The number of rows (independent hash functions).
    ///
    /// This is the value fixed at construction, either passed explicitly or
    /// derived from `delta` as `ceil(ln(1 / delta))`.
    #[inline]
    #[must_use]
    pub fn depth(&self) -> usize {
        self.depth
    }
306
307    /// Resets every counter to zero, retaining the allocation.
308    ///
309    /// # Examples
310    ///
311    /// ```
312    /// use bloom_lib::CountMinSketch;
313    ///
314    /// let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
315    /// sketch.increment("x");
316    /// sketch.clear();
317    /// assert_eq!(sketch.estimate("x"), 0);
318    /// assert_eq!(sketch.total_count(), 0);
319    /// ```
320    pub fn clear(&mut self) {
321        self.counters.iter_mut().for_each(|cell| *cell = 0);
322        self.total = 0;
323    }
324
325    /// Merges `other` into `self` by summing counters cell by cell (saturating).
326    ///
327    /// After the merge, the sketch estimates frequencies as if every item from
328    /// both sketches had been added to one. Both sketches must share their
329    /// geometry.
330    ///
331    /// # Errors
332    ///
333    /// Returns [`Error::IncompatibleParameters`] if the two sketches differ in
334    /// width or depth.
335    ///
336    /// # Examples
337    ///
338    /// ```
339    /// use bloom_lib::CountMinSketch;
340    ///
341    /// let mut a = CountMinSketch::with_dimensions(512, 4).unwrap();
342    /// let mut b = CountMinSketch::with_dimensions(512, 4).unwrap();
343    /// a.add("shared", 2);
344    /// b.add("shared", 3);
345    ///
346    /// a.merge(&b).unwrap();
347    /// assert!(a.estimate("shared") >= 5);
348    /// ```
349    pub fn merge(&mut self, other: &Self) -> Result<(), Error> {
350        if self.width != other.width || self.depth != other.depth {
351            return Err(Error::IncompatibleParameters);
352        }
353        for (dst, src) in self.counters.iter_mut().zip(other.counters.iter()) {
354            *dst = dst.saturating_add(*src);
355        }
356        self.total = self.total.saturating_add(other.total);
357        Ok(())
358    }
359}
360
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// `epsilon == 0.0` and `delta == 1.0` sit on the excluded interval bounds.
    #[test]
    fn test_new_rejects_out_of_range() {
        let zero_epsilon = CountMinSketch::<&str>::new(0.0, 0.1);
        assert!(matches!(zero_epsilon, Err(Error::InvalidParameter { .. })));

        let unit_delta = CountMinSketch::<&str>::new(0.1, 1.0);
        assert!(matches!(unit_delta, Err(Error::InvalidParameter { .. })));
    }

    /// A zero width or a zero depth is refused.
    #[test]
    fn test_with_dimensions_rejects_zero() {
        let zero_width = CountMinSketch::<u8>::with_dimensions(0, 4);
        assert!(matches!(zero_width, Err(Error::InvalidParameter { .. })));

        let zero_depth = CountMinSketch::<u8>::with_dimensions(64, 0);
        assert!(matches!(zero_depth, Err(Error::InvalidParameter { .. })));
    }

    /// Every estimate is at least the true count of its item.
    #[test]
    fn test_estimate_never_undercounts() {
        // Ground truth: item `i` is added `(i % 7) + 1` times.
        let truth_for = |i: u32| u64::from(i % 7) + 1;

        let mut sketch = CountMinSketch::new(0.001, 0.001).unwrap();
        for i in 0..1_000u32 {
            sketch.add(&i, truth_for(i));
        }
        for i in 0..1_000u32 {
            let truth = truth_for(i);
            assert!(
                sketch.estimate(&i) >= truth,
                "estimate undercounted item {i}"
            );
        }
    }

    /// An item never added estimates zero in a lightly-loaded sketch.
    #[test]
    fn test_absent_item_estimates_low() {
        let mut sketch = CountMinSketch::new(0.001, 0.001).unwrap();
        (0..100u32).for_each(|i| sketch.increment(&i));

        let absent = 9_999u32;
        assert_eq!(sketch.estimate(&absent), 0);
    }

    /// The running total is the exact sum of all added counts.
    #[test]
    fn test_total_count_is_exact() {
        let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
        sketch.add("a", 10);
        sketch.add("b", 20);
        sketch.increment("c");

        assert_eq!(sketch.total_count(), 31);
    }

    /// Counters and the total clamp at `u64::MAX` instead of wrapping.
    #[test]
    fn test_saturating_add() {
        let mut sketch = CountMinSketch::<str>::with_dimensions(16, 2).unwrap();
        for count in [u64::MAX, 5] {
            sketch.add("x", count);
        }

        assert_eq!(sketch.estimate("x"), u64::MAX);
        assert_eq!(sketch.total_count(), u64::MAX);
    }

    /// `clear` wipes both the counters and the running total.
    #[test]
    fn test_clear() {
        let mut sketch = CountMinSketch::new(0.01, 0.01).unwrap();
        sketch.add("x", 9);

        sketch.clear();

        assert_eq!(sketch.estimate("x"), 0);
        assert_eq!(sketch.total_count(), 0);
    }

    /// Merging adds the counters and the totals of both sketches.
    #[test]
    fn test_merge_sums_counts() {
        let mut left = CountMinSketch::with_dimensions(512, 4).unwrap();
        let mut right = CountMinSketch::with_dimensions(512, 4).unwrap();
        left.add("shared", 2);
        right.add("shared", 3);

        left.merge(&right).unwrap();

        assert!(left.estimate("shared") >= 5);
        assert_eq!(left.total_count(), 5);
    }

    /// Sketches with different widths cannot be merged.
    #[test]
    fn test_merge_rejects_incompatible() {
        let mut wide = CountMinSketch::<&str>::with_dimensions(512, 4).unwrap();
        let narrow = CountMinSketch::<&str>::with_dimensions(256, 4).unwrap();

        let outcome = wide.merge(&narrow);

        assert_eq!(outcome, Err(Error::IncompatibleParameters));
    }
}
461}