use std::sync::atomic::{AtomicU16, AtomicU32, AtomicU64, AtomicU8, AtomicUsize, Ordering};

/// Combines two 64-bit hashes into one, following the `Hash128to64` mixing
/// construction used by CityHash/FarmHash.
#[inline(always)]
fn combine_hashes(upper: u64, lower: u64) -> u64 {
    const MUL: u64 = 0x9ddfea08eb382d69;

    let mut a = (lower ^ upper).wrapping_mul(MUL);
    a ^= a >> 47;
    let mut b = (upper ^ a).wrapping_mul(MUL);
    b ^= b >> 47;
    b = b.wrapping_mul(MUL);
    b
}

/// Thomas Wang's 64-bit integer mix function, used to derive a per-row seed
/// from the row index.
#[inline(always)]
fn twang_mix64(val: u64) -> u64 {
    let mut val = (!val).wrapping_add(val << 21);
    val = val ^ (val >> 24);
    val = val.wrapping_add(val << 3).wrapping_add(val << 8);
    val = val ^ (val >> 14);
    val = val.wrapping_add(val << 2).wrapping_add(val << 4);
    val = val ^ (val >> 28);
    val = val.wrapping_add(val << 31);
    val
}
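// A small sanity sketch added for illustration (not part of the original test
// suite): it re-derives a row bucket the same way `index()` does below, i.e. by
// mixing the row id through `twang_mix64`, combining it with the item hash via
// `combine_hashes`, and reducing modulo an assumed row width of 200.
#[cfg(test)]
mod mixer_sketch {
    use super::*;

    #[test]
    fn test_per_row_bucket_derivation() {
        let width = 200usize;
        let hash = 0x0123_4567_89ab_cdef_u64;
        for row in 0..4u64 {
            let bucket = combine_hashes(twang_mix64(row), hash) as usize % width;
            // The bucket always falls inside the row.
            assert!(bucket < width);
            // Both mixers are pure functions, so the derivation is deterministic.
            assert_eq!(bucket, combine_hashes(twang_mix64(row), hash) as usize % width);
        }
    }
}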

macro_rules! cmsketch {
    ($( {$type:ty, $atomic:ty, $sketch:ident}, )*) => {
        $(
            #[doc = concat!(
                "Count-Min Sketch that stores `",
                stringify!($type),
                "` counters using atomics for concurrent updates.\n\n",
                "Each bucket is backed by [`",
                stringify!($atomic),
                "`], allowing lock-free increments and decrements."
            )]
            #[derive(Debug)]
            pub struct $sketch {
                width: usize,
                depth: usize,

                table: Box<[$atomic]>,
            }
            impl $sketch {
                /// Creates a sketch sized so that estimates over-count by at most
                /// `eps * N` (with `N` the total count) with probability at least
                /// `confidence`.
                pub fn new(eps: f64, confidence: f64) -> Self {
                    // width = ceil(2 / eps), depth = ceil(-log2(1 - confidence)).
                    // Worked example: eps = 0.01 gives 200 counters per row, and
                    // confidence = 0.9 gives ceil(3.32..) = 4 rows.
                    let width = (2.0 / eps).ceil() as usize;
                    let depth = (-(1.0 - confidence).log2()).ceil() as usize;
                    debug_assert!(width > 0, "width: {width}");
                    debug_assert!(depth > 0, "depth: {depth}");

                    let table = std::iter::repeat_with(|| <$atomic>::new(0)).take(width * depth).collect();

                    Self {
                        width,
                        depth,
                        table,
                    }
                }

                /// Increments the counters for `hash` by 1.
                pub fn inc(&self, hash: u64) {
                    self.inc_by(hash, 1);
                }

                /// Increments the counters for `hash` by `count`.
                ///
                /// Counters that would overflow are left unchanged.
                pub fn inc_by(&self, hash: u64, count: $type) {
                    for depth in 0..self.depth {
                        let index = self.index(depth, hash);
                        let _ = self.table[index].fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| {
                            if x <= <$type>::MAX - count { Some(x + count) } else { None }
                        });
                    }
                }

                /// Decrements the counters for `hash` by 1.
                pub fn dec(&self, hash: u64) {
                    self.dec_by(hash, 1);
                }

                /// Decrements the counters for `hash` by `count`.
                ///
                /// Counters that would underflow are left unchanged.
                pub fn dec_by(&self, hash: u64, count: $type) {
                    for depth in 0..self.depth {
                        let index = self.index(depth, hash);
                        let _ = self.table[index].fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| {
                            if x >= count { Some(x - count) } else { None }
                        });
                    }
                }

                /// Returns the estimated count for `hash`, i.e. the minimum of its
                /// counters across all rows.
                pub fn estimate(&self, hash: u64) -> $type {
                    // SAFETY: `depth` is non-zero (debug-asserted at construction),
                    // so the iterator yields at least one value.
                    unsafe {
                        (0..self.depth).map(|depth| self.table[self.index(depth, hash)].load(Ordering::Relaxed)).min().unwrap_unchecked()
                    }
                }

                /// Resets all counters to zero.
                pub fn clear(&self) {
                    self.table.iter().for_each(|v| v.store(0, Ordering::Relaxed));
                }

                /// Halves all counters.
                pub fn halve(&self) {
                    self.table.iter().for_each(|v| {
                        let _ = v.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(x >> 1));
                    });
                }

                /// Scales all counters by the factor `decay`.
                pub fn decay(&self, decay: f64) {
                    self.table.iter().for_each(|v| {
                        let _ = v.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some((x as f64 * decay) as $type));
                    });
                }

                /// Returns the number of counters per row.
                pub fn width(&self) -> usize {
                    self.width
                }

                /// Returns the number of rows.
                pub fn depth(&self) -> usize {
                    self.depth
                }

                /// Returns the maximum value a single counter can hold.
                pub fn capacity(&self) -> $type {
                    <$type>::MAX
                }

                /// Returns the flat table index of `hash`'s bucket in row `depth`.
                #[inline(always)]
                fn index(&self, depth: usize, hash: u64) -> usize {
                    depth * self.width
                        + (combine_hashes(twang_mix64(depth as u64), hash) as usize % self.width)
                }

                /// Returns the approximate memory usage in bytes: the counter table
                /// plus three machine words of bookkeeping.
                pub fn memory(&self) -> usize {
                    (<$type>::BITS as usize * self.depth * self.width + usize::BITS as usize * 3) / 8
                }
            }
        )*
    };
}

cmsketch! {
    {u8, AtomicU8, CMSketchAtomicU8},
    {u16, AtomicU16, CMSketchAtomicU16},
    {u32, AtomicU32, CMSketchAtomicU32},
    {u64, AtomicU64, CMSketchAtomicU64},
    {usize, AtomicUsize, CMSketchAtomicUsize},
}
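
// A minimal usage sketch added for illustration (not part of the original test
// suite). It assumes callers hash their keys to `u64` themselves; here std's
// `DefaultHasher` stands in for whatever hasher the application uses, and the
// eps/confidence values are arbitrary.
#[cfg(test)]
mod usage_sketch {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    use super::*;

    fn hash_key<K: Hash>(key: &K) -> u64 {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_usage_sketch() {
        // Roughly 1% over-count error with 90% confidence.
        let cms = CMSketchAtomicU32::new(0.01, 0.9);

        let hash = hash_key(&"hello");
        for _ in 0..42 {
            cms.inc(hash);
        }

        // A Count-Min Sketch never under-estimates a count it has seen.
        assert!(cms.estimate(hash) >= 42);
    }
}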

macro_rules! test_cmsketch {
    ($( {$module:ident, $type:ty, $atomic:ty, $sketch:ident}, )*) => {
        $(
            #[cfg(test)]
            mod $module {
                use itertools::Itertools;
                use rand_mt::Mt64;

                use super::*;

                #[test]
                fn test_new() {
                    let cms = $sketch::new(0.01, 0.5);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 1);

                    let cms = $sketch::new(0.01, 0.6);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 2);

                    let cms = $sketch::new(0.01, 0.7);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 2);

                    let cms = $sketch::new(0.01, 0.8);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 3);

                    let cms = $sketch::new(0.01, 0.9);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 4);

                    let cms = $sketch::new(0.01, 0.95);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 5);

                    let cms = $sketch::new(0.01, 0.995);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 8);
                }

                #[test]
                #[should_panic]
                fn test_new_with_invalid_args() {
                    $sketch::new(0.0, 0.0);
                }

                #[test]
                fn test_inc() {
                    let cms = $sketch::new(0.01, 0.9);

                    let mut rng = Mt64::new_unseeded();
                    let keys = (0..100).map(|_| rng.next_u64()).collect_vec();

                    for i in 0..100 {
                        for _ in 0..i {
                            cms.inc(keys[i]);
                        }
                    }

                    for i in 0..100 {
                        assert!(
                            cms.estimate(keys[i]) >= std::cmp::min(i as $type, cms.capacity()),
                            "assert {} >= {} failed",
                            cms.estimate(keys[i]), std::cmp::min(i as $type, cms.capacity())
                        );
                    }
                }

                #[test]
                fn test_dec() {
                    let cms = $sketch::new(0.01, 0.9);

                    let mut rng = Mt64::new_unseeded();
                    let keys = (0..100).map(|_| rng.next_u64()).collect_vec();

                    for i in 0..100 {
                        for _ in 0..i {
                            cms.inc(keys[i]);
                        }
                    }

                    for i in 0..100 {
                        for _ in 0..i {
                            cms.dec(keys[i]);
                        }
                    }

                    for i in 0..100 {
                        assert_eq!(cms.estimate(keys[i]), 0);
                    }
                }

                #[test]
                fn test_clear() {
                    let cms = $sketch::new(0.01, 0.9);

                    let mut rng = Mt64::new_unseeded();
                    let keys = (0..100).map(|_| rng.next_u64()).collect_vec();

                    for i in 0..100 {
                        for _ in 0..i {
                            cms.inc(keys[i]);
                        }
                    }

                    cms.clear();

                    for i in 0..100 {
                        assert_eq!(cms.estimate(keys[i]), 0);
                    }
                }

                #[test]
                fn test_halve() {
                    let cms = $sketch::new(0.01, 0.9);

                    let mut rng = Mt64::new_unseeded();
                    let keys = (0..1000).map(|_| rng.next_u64()).collect_vec();

                    for i in 0..1000 {
                        for _ in 0..i {
                            cms.inc(keys[i]);
                        }
                    }

                    for i in 0..1000 {
                        assert!(
                            cms.estimate(keys[i]) >= std::cmp::min(i as $type, cms.capacity()),
                            "assert {} >= {} failed",
                            cms.estimate(keys[i]), std::cmp::min(i as $type, cms.capacity())
                        );
                    }

                    cms.halve();

                    for i in 0..1000 {
                        assert!(
                            cms.estimate(keys[i]) >= std::cmp::min(i as $type / 2, cms.capacity()),
                            "assert {} >= {} failed",
                            cms.estimate(keys[i]), std::cmp::min(i as $type / 2, cms.capacity())
                        );
                    }
                }

                #[test]
                fn test_decay() {
                    let cms = $sketch::new(0.01, 0.9);
                    let mut rng = Mt64::new_unseeded();
                    let keys = (0..1000).map(|_| rng.next_u64()).collect_vec();

                    for i in 0..1000 {
                        for _ in 0..i {
                            cms.inc(keys[i]);
                        }
                    }

                    for i in 0..1000 {
                        assert!(
                            cms.estimate(keys[i]) >= std::cmp::min(i as $type, cms.capacity()),
                            "assert {} >= {} failed",
                            cms.estimate(keys[i]), std::cmp::min(i as $type, cms.capacity())
                        );
                    }

                    const FACTOR: f64 = 0.5;
                    cms.decay(FACTOR);

                    for i in 0..1000 {
                        assert!(cms.estimate(keys[i]) >= (std::cmp::min(i as $type, cms.capacity()) as f64 * FACTOR).floor() as $type);
                    }
                }

                #[test]
                fn test_collisions() {
                    let cms = $sketch::new(0.01, 0.9);
                    let mut rng = Mt64::new_unseeded();
                    let keys = (0..1000).map(|_| rng.next_u64()).collect_vec();
                    let mut sum = 0;

                    for i in 0..1000 {
                        for _ in 0..i {
                            cms.inc(keys[i]);
                        }
                        sum += i;
                    }

                    let error = sum as f64 * 0.01;
                    for i in 0..10 {
                        assert!(cms.estimate(keys[i]) >= i as $type);
                        assert!(i as f64 + error >= cms.estimate(keys[i]) as f64);
                    }
                }
            }
        )*
    }
}

test_cmsketch! {
    {tests_cmsketch_atomic_u8, u8, AtomicU8, CMSketchAtomicU8},
    {tests_cmsketch_atomic_u16, u16, AtomicU16, CMSketchAtomicU16},
    {tests_cmsketch_atomic_u32, u32, AtomicU32, CMSketchAtomicU32},
    {tests_cmsketch_atomic_u64, u64, AtomicU64, CMSketchAtomicU64},
    {tests_cmsketch_atomic_usize, usize, AtomicUsize, CMSketchAtomicUsize},
}
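
// A concurrency sketch added for illustration (not part of the original test
// suite): the counters are plain atomics updated with `Relaxed` ordering, so a
// sketch can be shared across threads by reference and incremented without any
// locking. The thread and iteration counts here are arbitrary.
#[cfg(test)]
mod concurrency_sketch {
    use super::*;

    #[test]
    fn test_concurrent_increments() {
        const THREADS: usize = 4;
        const INCS_PER_THREAD: usize = 100;

        let cms = CMSketchAtomicU64::new(0.01, 0.9);
        let hash = 0x5eed_5eed_5eed_5eed_u64;

        std::thread::scope(|s| {
            for _ in 0..THREADS {
                s.spawn(|| {
                    for _ in 0..INCS_PER_THREAD {
                        cms.inc(hash);
                    }
                });
            }
        });

        // Every increment lands in every row, so the estimate reflects all of them.
        assert!(cms.estimate(hash) >= (THREADS * INCS_PER_THREAD) as u64);
    }
}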