Skip to main content

nodedb_types/approx/
count_min.rs

//! Count-Min Sketch — approximate frequency estimation for high-cardinality streams.
//!
//! Fixed memory: `width × depth × 8` bytes for the counter table. Default (1024 × 4) = 32 KB.
//! Error guarantee: over-estimates by at most `ε·N` (where `N` is the total count added)
//! with probability at least `1 − δ`, where `ε = e/width` and `δ = e^(−depth)`.
6
/// Count-Min Sketch for approximate frequency queries.
///
/// Answers "how many times did item X appear?" with bounded over-estimation.
/// Mergeable across shards: element-wise addition of tables.
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers can snapshot a shard's sketch
/// before merging and compare sketches structurally (e.g. after a serialization round-trip).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CountMinSketch {
    /// Counter matrix: `depth` rows of `width` counters each.
    table: Vec<Vec<u64>>,
    /// Number of counters per row.
    width: usize,
    /// Number of rows (one hash function per row).
    depth: usize,
    /// Sum of all counts added (and merged in).
    total: u64,
    /// Per-row hash seeds; `seeds[d]` is mixed into the item before hashing for row `d`.
    seeds: Vec<u64>,
}
19
20impl CountMinSketch {
21    /// Create a sketch with default parameters (width=1024, depth=4).
22    ///
23    /// Error ≤ `e/1024 ≈ 0.27%` of total count, confidence ≥ `1 − e^(−4) ≈ 98.2%`.
24    pub fn new() -> Self {
25        Self::with_params(1024, 4)
26    }
27
28    /// Create with custom width and depth.
29    ///
30    /// * `width` — number of counters per row (controls accuracy; larger = less error)
31    /// * `depth` — number of hash functions/rows (controls confidence; larger = higher)
32    pub fn with_params(width: usize, depth: usize) -> Self {
33        let width = width.max(16);
34        let depth = depth.max(2);
35        let seeds: Vec<u64> = (0..depth as u64)
36            .map(|i| 0x517cc1b727220a95u64.wrapping_add(i.wrapping_mul(0x6c62272e07bb0142)))
37            .collect();
38        Self {
39            table: vec![vec![0u64; width]; depth],
40            width,
41            depth,
42            total: 0,
43            seeds,
44        }
45    }
46
47    /// Add an item occurrence.
48    pub fn add(&mut self, item: u64) {
49        self.add_count(item, 1);
50    }
51
52    /// Add an item with a specified count.
53    pub fn add_count(&mut self, item: u64, count: u64) {
54        self.total += count;
55        for d in 0..self.depth {
56            let idx = self.hash(d, item);
57            self.table[d][idx] += count;
58        }
59    }
60
61    /// Add a batch of items (each with count 1).
62    pub fn add_batch(&mut self, items: &[u64]) {
63        for &item in items {
64            self.add(item);
65        }
66    }
67
68    /// Estimate the frequency of an item.
69    ///
70    /// Returns the minimum count across all hash rows (point query).
71    /// This is always ≥ the true count and ≤ `true_count + ε·N`.
72    pub fn estimate(&self, item: u64) -> u64 {
73        let mut min_count = u64::MAX;
74        for d in 0..self.depth {
75            let idx = self.hash(d, item);
76            min_count = min_count.min(self.table[d][idx]);
77        }
78        min_count
79    }
80
81    /// Total number of items added.
82    pub fn total(&self) -> u64 {
83        self.total
84    }
85
86    /// Merge another sketch (element-wise addition).
87    ///
88    /// Both sketches must have the same width and depth.
89    pub fn merge(&mut self, other: &CountMinSketch) {
90        debug_assert_eq!(self.width, other.width);
91        debug_assert_eq!(self.depth, other.depth);
92        self.total += other.total;
93        for d in 0..self.depth {
94            for w in 0..self.width {
95                self.table[d][w] += other.table[d][w];
96            }
97        }
98    }
99
100    /// Memory usage in bytes.
101    pub fn memory_bytes(&self) -> usize {
102        self.width * self.depth * std::mem::size_of::<u64>()
103    }
104
105    /// Serialize the table as a flat byte array (row-major, little-endian u64).
106    pub fn table_bytes(&self) -> Vec<u8> {
107        let mut bytes = Vec::with_capacity(self.width * self.depth * 8);
108        for row in &self.table {
109            for &val in row {
110                bytes.extend_from_slice(&val.to_le_bytes());
111            }
112        }
113        bytes
114    }
115
116    /// Reconstruct from serialized table bytes.
117    pub fn from_table_bytes(data: &[u8], width: usize, depth: usize) -> Self {
118        let mut sketch = Self::with_params(width, depth);
119        let mut offset = 0;
120        for d in 0..depth {
121            for w in 0..width {
122                if offset + 8 <= data.len() {
123                    sketch.table[d][w] =
124                        u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap_or([0; 8]));
125                    sketch.total += sketch.table[d][w];
126                    offset += 8;
127                }
128            }
129        }
130        // Total is over-counted (each item is in `depth` rows).
131        // Approximate: use row 0's sum.
132        sketch.total = sketch.table[0].iter().sum();
133        sketch
134    }
135
136    #[inline]
137    fn hash(&self, depth_idx: usize, item: u64) -> usize {
138        let h = splitmix64(item ^ self.seeds[depth_idx]);
139        (h as usize) % self.width
140    }
141}
142
143impl Default for CountMinSketch {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
/// SplitMix64 mixer: scrambles a 64-bit input into a well-distributed 64-bit hash.
///
/// Same mixer as the HyperLogLog sketch, so both structures hash consistently.
fn splitmix64(x: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9e3779b97f4a7c15;
    const MIX_1: u64 = 0xbf58476d1ce4e5b9;
    const MIX_2: u64 = 0x94d049bb133111eb;

    let z = x.wrapping_add(GOLDEN_GAMMA);
    let z = (z ^ (z >> 30)).wrapping_mul(MIX_1);
    let z = (z ^ (z >> 27)).wrapping_mul(MIX_2);
    z ^ (z >> 31)
}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_sketch() {
        let cms = CountMinSketch::new();
        assert_eq!(cms.estimate(42), 0);
        assert_eq!(cms.total(), 0);
    }

    #[test]
    fn exact_for_single_item() {
        let mut cms = CountMinSketch::new();
        for _ in 0..1000 {
            cms.add(42);
        }
        assert_eq!(cms.estimate(42), 1000);
    }

    #[test]
    fn overestimate_bounded() {
        let mut cms = CountMinSketch::new();
        // Add 100K items uniformly over 1000 distinct keys.
        for i in 0..100_000u64 {
            cms.add(i % 1000);
        }
        // Each of the 1000 items appears exactly 100 times.
        // CMS should return ≥ 100 for any item.
        for i in 0..1000u64 {
            let est = cms.estimate(i);
            assert!(est >= 100, "item {i}: expected ≥100, got {est}");
        }
        // Over-estimation should be bounded by ~ε·N = (e/1024)*100000 ≈ 265.
        for i in 0..1000u64 {
            let est = cms.estimate(i);
            assert!(est <= 400, "item {i}: expected ≤400, got {est}");
        }
    }

    #[test]
    fn absent_item_bounded() {
        let mut cms = CountMinSketch::new();
        for i in 0..10_000u64 {
            cms.add(i);
        }
        // Item 99999 was never added. Estimate should be low.
        let est = cms.estimate(99999);
        // Bounded by ε·N ≈ (e/1024)*10000 ≈ 26.5
        assert!(est <= 50, "absent item: expected ≤50, got {est}");
    }

    #[test]
    fn merge() {
        let mut a = CountMinSketch::new();
        let mut b = CountMinSketch::new();
        for _ in 0..500 {
            a.add(1);
        }
        for _ in 0..300 {
            b.add(1);
        }
        for _ in 0..200 {
            b.add(2);
        }
        a.merge(&b);
        assert_eq!(a.estimate(1), 800);
        assert_eq!(a.total(), 1000);
    }

    #[test]
    fn batch_add() {
        let mut cms = CountMinSketch::new();
        cms.add_batch(&[1, 1, 2, 3, 3, 3]);
        assert_eq!(cms.estimate(1), 2);
        assert_eq!(cms.estimate(3), 3);
        assert_eq!(cms.total(), 6);
    }

    #[test]
    fn add_count_weights() {
        let mut cms = CountMinSketch::new();
        cms.add_count(7, 5);
        cms.add(7);
        assert_eq!(cms.estimate(7), 6);
        assert_eq!(cms.total(), 6);
    }

    #[test]
    fn table_bytes_round_trip() {
        let mut cms = CountMinSketch::new();
        for i in 0..5_000u64 {
            cms.add(i % 97);
        }
        let bytes = cms.table_bytes();
        assert_eq!(bytes.len(), cms.memory_bytes());

        let restored = CountMinSketch::from_table_bytes(&bytes, 1024, 4);
        assert_eq!(restored.total(), cms.total());
        for i in 0..97u64 {
            assert_eq!(restored.estimate(i), cms.estimate(i), "item {i}");
        }
    }

    #[test]
    fn memory() {
        let cms = CountMinSketch::new();
        assert_eq!(cms.memory_bytes(), 1024 * 4 * 8); // 32 KB
    }
}