// nodedb_types/approx/count_min.rs

/// Count-Min Sketch: approximate per-item frequency counting for `u64` items
/// in a fixed-size table of `depth` rows × `width` counters.
///
/// Each insertion increments one counter in every row (the column is chosen by
/// a per-row seeded hash). A query returns the minimum of the item's counters
/// across rows. Hash collisions can therefore only inflate an estimate; for
/// counts recorded through `add*`/`merge`, the estimate is never below the true
/// count (as long as counters do not overflow).
#[derive(Debug)]
pub struct CountMinSketch {
    /// Counters indexed as `table[row][column]`: `depth` rows of `width` counters each.
    table: Vec<Vec<u64>>,
    /// Number of counters (columns) per row.
    width: usize,
    /// Number of rows; each row has its own hash seed.
    depth: usize,
    /// Running sum of all counts added to the sketch.
    total: u64,
    /// One seed per row, XORed into the item before it is hashed.
    seeds: Vec<u64>,
}
19
20impl CountMinSketch {
21 pub fn new() -> Self {
25 Self::with_params(1024, 4)
26 }
27
28 pub fn with_params(width: usize, depth: usize) -> Self {
33 let width = width.max(16);
34 let depth = depth.max(2);
35 let seeds: Vec<u64> = (0..depth as u64)
36 .map(|i| 0x517cc1b727220a95u64.wrapping_add(i.wrapping_mul(0x6c62272e07bb0142)))
37 .collect();
38 Self {
39 table: vec![vec![0u64; width]; depth],
40 width,
41 depth,
42 total: 0,
43 seeds,
44 }
45 }
46
47 pub fn add(&mut self, item: u64) {
49 self.add_count(item, 1);
50 }
51
52 pub fn add_count(&mut self, item: u64, count: u64) {
54 self.total += count;
55 for d in 0..self.depth {
56 let idx = self.hash(d, item);
57 self.table[d][idx] += count;
58 }
59 }
60
61 pub fn add_batch(&mut self, items: &[u64]) {
63 for &item in items {
64 self.add(item);
65 }
66 }
67
68 pub fn estimate(&self, item: u64) -> u64 {
73 let mut min_count = u64::MAX;
74 for d in 0..self.depth {
75 let idx = self.hash(d, item);
76 min_count = min_count.min(self.table[d][idx]);
77 }
78 min_count
79 }
80
81 pub fn total(&self) -> u64 {
83 self.total
84 }
85
86 pub fn merge(&mut self, other: &CountMinSketch) {
90 debug_assert_eq!(self.width, other.width);
91 debug_assert_eq!(self.depth, other.depth);
92 self.total += other.total;
93 for d in 0..self.depth {
94 for w in 0..self.width {
95 self.table[d][w] += other.table[d][w];
96 }
97 }
98 }
99
100 pub fn memory_bytes(&self) -> usize {
102 self.width * self.depth * std::mem::size_of::<u64>()
103 }
104
105 pub fn table_bytes(&self) -> Vec<u8> {
107 let mut bytes = Vec::with_capacity(self.width * self.depth * 8);
108 for row in &self.table {
109 for &val in row {
110 bytes.extend_from_slice(&val.to_le_bytes());
111 }
112 }
113 bytes
114 }
115
116 pub fn from_table_bytes(data: &[u8], width: usize, depth: usize) -> Self {
118 let mut sketch = Self::with_params(width, depth);
119 let mut offset = 0;
120 for d in 0..depth {
121 for w in 0..width {
122 if offset + 8 <= data.len() {
123 sketch.table[d][w] =
124 u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap_or([0; 8]));
125 sketch.total += sketch.table[d][w];
126 offset += 8;
127 }
128 }
129 }
130 sketch.total = sketch.table[0].iter().sum();
133 sketch
134 }
135
136 #[inline]
137 fn hash(&self, depth_idx: usize, item: u64) -> usize {
138 let h = splitmix64(item ^ self.seeds[depth_idx]);
139 (h as usize) % self.width
140 }
141}
142
143impl Default for CountMinSketch {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
/// SplitMix64 mixing function: adds the golden-ratio increment, then applies
/// two xor-shift/multiply rounds and a final xor-shift.
///
/// All arithmetic wraps modulo 2^64.
fn splitmix64(input: u64) -> u64 {
    const GOLDEN_GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
    const MIX_A: u64 = 0xbf58_476d_1ce4_e5b9;
    const MIX_B: u64 = 0x94d0_49bb_1331_11eb;

    let z = input.wrapping_add(GOLDEN_GAMMA);
    let z = (z ^ (z >> 30)).wrapping_mul(MIX_A);
    let z = (z ^ (z >> 27)).wrapping_mul(MIX_B);
    z ^ (z >> 31)
}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_sketch() {
        let cms = CountMinSketch::new();
        assert_eq!(cms.estimate(42), 0);
        assert_eq!(cms.total(), 0);
    }

    #[test]
    fn exact_for_single_item() {
        let mut cms = CountMinSketch::new();
        for _ in 0..1000 {
            cms.add(42);
        }
        assert_eq!(cms.estimate(42), 1000);
    }

    #[test]
    fn overestimate_bounded() {
        let mut cms = CountMinSketch::new();
        // 1000 distinct items, each inserted exactly 100 times.
        for i in 0..100_000u64 {
            cms.add(i % 1000);
        }
        for i in 0..1000u64 {
            let est = cms.estimate(i);
            // The sketch must never report less than the true count ...
            assert!(est >= 100, "item {i}: expected ≥100, got {est}");
            // ... and with 1000 items spread over 1024 columns, collisions stay modest.
            assert!(est <= 400, "item {i}: expected ≤400, got {est}");
        }
    }

    #[test]
    fn absent_item_bounded() {
        let mut cms = CountMinSketch::new();
        for i in 0..10_000u64 {
            cms.add(i);
        }
        let est = cms.estimate(99999);
        assert!(est <= 50, "absent item: expected ≤50, got {est}");
    }

    #[test]
    fn merge() {
        let mut a = CountMinSketch::new();
        let mut b = CountMinSketch::new();
        for _ in 0..500 {
            a.add(1);
        }
        for _ in 0..300 {
            b.add(1);
        }
        for _ in 0..200 {
            b.add(2);
        }
        a.merge(&b);
        assert_eq!(a.estimate(1), 800);
        assert_eq!(a.total(), 1000);
    }

    #[test]
    fn batch_add() {
        let mut cms = CountMinSketch::new();
        cms.add_batch(&[1, 1, 2, 3, 3, 3]);
        assert_eq!(cms.estimate(1), 2);
        assert_eq!(cms.estimate(3), 3);
        assert_eq!(cms.total(), 6);
    }

    #[test]
    fn add_count_matches_repeated_add() {
        let mut bulk = CountMinSketch::new();
        let mut single = CountMinSketch::new();
        bulk.add_count(7, 25);
        for _ in 0..25 {
            single.add(7);
        }
        assert_eq!(bulk.estimate(7), single.estimate(7));
        assert_eq!(bulk.total(), single.total());
        assert_eq!(bulk.table_bytes(), single.table_bytes());
    }

    #[test]
    fn table_bytes_round_trip() {
        let mut cms = CountMinSketch::new();
        cms.add_batch(&[1, 1, 2, 3, 3, 3]);
        let bytes = cms.table_bytes();
        assert_eq!(bytes.len(), cms.memory_bytes());

        let restored = CountMinSketch::from_table_bytes(&bytes, 1024, 4);
        assert_eq!(restored.table_bytes(), bytes);
        assert_eq!(restored.total(), 6);
        assert_eq!(restored.estimate(1), 2);
        assert_eq!(restored.estimate(3), 3);
    }

    #[test]
    fn memory() {
        let cms = CountMinSketch::new();
        assert_eq!(cms.memory_bytes(), 1024 * 4 * 8);
        assert_eq!(CountMinSketch::default().memory_bytes(), cms.memory_bytes());
    }
}