cmsketch/
atomic.rs

1//  Copyright 2023 MrCroxx
2//
3//  Licensed under the Apache License, Version 2.0 (the "License");
4//  you may not use this file except in compliance with the License.
5//  You may obtain a copy of the License at
6//
7//  http://www.apache.org/licenses/LICENSE-2.0
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14
15use std::sync::atomic::{AtomicU16, AtomicU32, AtomicU64, AtomicU8, AtomicUsize, Ordering};
16
17/// Reduce two 64-bit hashes into one.
18///
19/// Ported from CacheLib, which uses the `Hash128to64` function from Google's city hash.
20#[inline(always)]
21fn combine_hashes(upper: u64, lower: u64) -> u64 {
22    const MUL: u64 = 0x9ddfea08eb382d69;
23
24    let mut a = (lower ^ upper).wrapping_mul(MUL);
25    a ^= a >> 47;
26    let mut b = (upper ^ a).wrapping_mul(MUL);
27    b ^= b >> 47;
28    b = b.wrapping_mul(MUL);
29    b
30}
31
/// Scrambles the bits of `val` with alternating wrapping multiplications and xor-shifts.
///
/// Each multiplication is the closed form of a shift-and-add chain: in wrapping arithmetic
/// `v + (v << a) + (v << b)` equals `v * (1 + (1 << a) + (1 << b))`.
#[inline(always)]
fn twang_mix64(val: u64) -> u64 {
    // `!v + (v << 21)` is `v * ((1 << 21) - 1) - 1`, because `!v == -v - 1` modulo 2^64.
    let mut mixed = val.wrapping_mul((1u64 << 21) - 1).wrapping_sub(1);
    mixed ^= mixed >> 24;
    // v + (v << 3) + (v << 8)
    mixed = mixed.wrapping_mul(1 + (1u64 << 3) + (1u64 << 8));
    mixed ^= mixed >> 14;
    // v + (v << 2) + (v << 4)
    mixed = mixed.wrapping_mul(1 + (1u64 << 2) + (1u64 << 4));
    mixed ^= mixed >> 28;
    // v + (v << 31)
    mixed.wrapping_mul(1 + (1u64 << 31))
}
43
/// Generates one atomic Count-Min Sketch type per `{counter type, atomic type, sketch name}` entry.
///
/// Every generated type stores `width * depth` atomic counters in a flat, row-major table and is
/// updated through `&self`. Each counter update is individually atomic (relaxed ordering); an
/// operation that touches several rows is not atomic as a whole.
macro_rules! cmsketch {
    ($( {$type:ty, $atomic:ty, $sketch:ident}, )*) => {
        $(
            #[doc = concat!(
                "Count-Min Sketch that stores `",
                stringify!($type),
                "` counters using atomics for concurrent updates.\n\n",
                "Each bucket is backed by [`",
                stringify!($atomic),
                "`], allowing lock-free increments and decrements."
            )]
            #[derive(Debug)]
            pub struct $sketch {
                width: usize,
                depth: usize,

                table: Box<[$atomic]>,
            }

            impl $sketch {
                /// Creates a new atomic sketch sized by error `eps` and `confidence`.
                ///
                /// The table has `ceil(2 / eps)` columns and `ceil(-log2(1 - confidence))` rows.
                /// See [`CMSketchU32::new`](crate::CMSketchU32::new) for the mapping between
                /// confidence and depth.
                ///
                /// # Panics
                ///
                /// Panics in every build profile if `eps` is not strictly positive, if the derived
                /// width or depth is zero (for example `confidence <= 0.0`, or a NaN argument), or
                /// if `width * depth` overflows `usize`.
                pub fn new(eps: f64, confidence: f64) -> Self {
                    assert!(eps > 0.0, "eps must be positive, got {eps}");

                    let width = (2.0 / eps).ceil() as usize;
                    let depth = (-(1.0 - confidence).log2()).ceil() as usize;
                    // These must hold in release builds too: `index` divides by `width`, and a
                    // zero `depth` would leave `estimate` with no rows to read.
                    assert!(width > 0, "width: {width}");
                    assert!(depth > 0, "depth: {depth}");

                    let cells = width
                        .checked_mul(depth)
                        .expect("sketch dimensions overflow usize");
                    let table = std::iter::repeat_with(|| <$atomic>::new(0)).take(cells).collect();

                    Self {
                        width,
                        depth,
                        table,
                    }
                }

                /// Atomically increments the count associated with `hash` by 1.
                pub fn inc(&self, hash: u64) {
                    self.inc_by(hash, 1);
                }

                /// Atomically increments the count associated with `hash` by `count`.
                ///
                /// A counter is left unchanged when adding `count` would overflow the counter
                /// type. With `count == 1` this pins a counter at its maximum value; with larger
                /// steps a counter close to the maximum is not raised to the maximum.
                pub fn inc_by(&self, hash: u64, count: $type) {
                    for depth in 0..self.depth {
                        let index = self.index(depth, hash);
                        // `checked_add` yields `None` on overflow, which makes `fetch_update`
                        // skip the store; the resulting `Err` is the "unchanged" outcome.
                        let _ = self.table[index]
                            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| x.checked_add(count));
                    }
                }

                /// Atomically decrements the count associated with `hash` by 1.
                pub fn dec(&self, hash: u64) {
                    self.dec_by(hash, 1);
                }

                /// Atomically decrements the count associated with `hash` by `count`.
                ///
                /// Leaves the counter unchanged if it would underflow.
                pub fn dec_by(&self, hash: u64, count: $type) {
                    for depth in 0..self.depth {
                        let index = self.index(depth, hash);
                        // `checked_sub` yields `None` on underflow, so the counter is not stored.
                        let _ = self.table[index]
                            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| x.checked_sub(count));
                    }
                }

                /// Returns the minimum counter across all rows for `hash`.
                pub fn estimate(&self, hash: u64) -> $type {
                    // `new` guarantees at least one row, so the `MAX` seed is always replaced by
                    // a real counter value; folding keeps this free of `unsafe`.
                    (0..self.depth)
                        .map(|depth| self.table[self.index(depth, hash)].load(Ordering::Relaxed))
                        .fold(<$type>::MAX, |lowest, seen| lowest.min(seen))
                }

                /// Resets all counters to zero.
                ///
                /// Counters are stored one at a time, so concurrent updates may interleave.
                pub fn clear(&self) {
                    self.table.iter().for_each(|v| v.store(0, Ordering::Relaxed));
                }

                /// Divides every counter by two (rounding down) using an atomic fetch-update.
                pub fn halve(&self) {
                    self.table.iter().for_each(|v| {
                        let _ = v.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some(x >> 1));
                    });
                }

                /// Applies a floating-point decay factor to every counter.
                ///
                /// Each counter becomes `(counter as f64 * decay)` converted back with an `as`
                /// cast: the fraction is truncated, out-of-range products saturate at the
                /// counter type's bounds, and a NaN product becomes zero.
                pub fn decay(&self, decay: f64) {
                    self.table.iter().for_each(|v| {
                        let _ = v.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |x| Some((x as f64 * decay) as $type));
                    });
                }

                /// Returns the configured table width (number of columns).
                pub fn width(&self) -> usize {
                    self.width
                }

                /// Returns the number of hash rows.
                pub fn depth(&self) -> usize {
                    self.depth
                }

                /// Returns the maximum representable counter for this sketch.
                pub fn capacity(&self) -> $type {
                    <$type>::MAX
                }

                /// Maps (`depth`, `hash`) to a position in the flat table.
                ///
                /// Row `depth` starts at `depth * width`; the column is a hash of the row number
                /// combined with `hash`, reduced modulo `width`.
                #[inline(always)]
                fn index(&self, depth: usize, hash: u64) -> usize {
                    depth * self.width
                        + (combine_hashes(twang_mix64(depth as u64), hash) as usize % self.width)
                }

                /// Returns the amount of memory used by the sketch in bytes.
                ///
                /// Computed as the counter bits for `depth * width` cells plus three
                /// `usize`-sized words, divided by eight.
                // NOTE(review): the struct holds two `usize` fields plus a boxed slice; confirm
                // whether three words is the intended accounting for the header.
                pub fn memory(&self) -> usize {
                    (<$type>::BITS as usize * self.depth * self.width + usize::BITS as usize * 3) / 8
                }
            }
        )*
    };
}
178
// Instantiate one atomic sketch type per unsigned counter width. Each entry is
// `{counter type, backing atomic type, generated sketch name}`.
cmsketch! {
    {u8, AtomicU8, CMSketchAtomicU8},
    {u16, AtomicU16, CMSketchAtomicU16},
    {u32, AtomicU32, CMSketchAtomicU32},
    {u64, AtomicU64, CMSketchAtomicU64},
    {usize, AtomicUsize, CMSketchAtomicUsize},
}
186
/// Generates a `#[cfg(test)]` module of unit tests for each
/// `{module name, counter type, atomic type, sketch type}` entry.
///
/// The tests insert key number `i` exactly `i` times and then check that the estimate never
/// falls below the true count, clamped to what the counter type can hold.
macro_rules! test_cmsketch {
    ($( {$module:ident, $type:ty, $atomic:ty, $sketch:ident}, )*) => {
        $(
            #[cfg(test)]
            mod $module {
                use itertools::Itertools;
                use rand_mt::Mt64;

                use super::*;

                /// Returns `n` keys from an unseeded (hence deterministic) Mersenne Twister.
                fn random_keys(n: usize) -> Vec<u64> {
                    let mut rng = Mt64::new_unseeded();
                    (0..n).map(|_| rng.next_u64()).collect_vec()
                }

                /// Increments `keys[i]` exactly `i` times.
                fn fill(cms: &$sketch, keys: &[u64]) {
                    for (i, &key) in keys.iter().enumerate() {
                        for _ in 0..i {
                            cms.inc(key);
                        }
                    }
                }

                /// Returns the count `i` clamped to the sketch capacity.
                ///
                /// The clamp happens in `usize` before narrowing: casting `i` to a narrow
                /// counter type first would wrap (300 becomes 44 for `u8`) and make the
                /// assertions vacuous.
                fn clamped(cms: &$sketch, i: usize) -> $type {
                    std::cmp::min(i, cms.capacity() as usize) as $type
                }

                #[test]
                fn test_new() {
                    let cms = $sketch::new(0.01, 0.5);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 1);

                    let cms = $sketch::new(0.01, 0.6);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 2);

                    let cms = $sketch::new(0.01, 0.7);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 2);

                    let cms = $sketch::new(0.01, 0.8);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 3);

                    let cms = $sketch::new(0.01, 0.9);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 4);

                    let cms = $sketch::new(0.01, 0.95);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 5);

                    let cms = $sketch::new(0.01, 0.995);
                    assert_eq!(cms.width(), 200);
                    assert_eq!(cms.depth(), 8);
                }

                #[test]
                #[should_panic]
                fn test_new_with_invalid_args() {
                    $sketch::new(0.0, 0.0);
                }

                #[test]
                fn test_inc() {
                    let cms = $sketch::new(0.01, 0.9);
                    let keys = random_keys(100);
                    fill(&cms, &keys);

                    for (i, &key) in keys.iter().enumerate() {
                        let (estimate, expected) = (cms.estimate(key), clamped(&cms, i));
                        assert!(estimate >= expected, "assert {} >= {} failed", estimate, expected);
                    }
                }

                #[test]
                fn test_dec() {
                    let cms = $sketch::new(0.01, 0.9);
                    let keys = random_keys(100);
                    fill(&cms, &keys);

                    // Undo every increment: key `i` is decremented `i` times.
                    for (i, &key) in keys.iter().enumerate() {
                        for _ in 0..i {
                            cms.dec(key);
                        }
                    }

                    for &key in &keys {
                        assert_eq!(cms.estimate(key), 0);
                    }
                }

                #[test]
                fn test_clear() {
                    let cms = $sketch::new(0.01, 0.9);
                    let keys = random_keys(100);
                    fill(&cms, &keys);

                    cms.clear();

                    for &key in &keys {
                        assert_eq!(cms.estimate(key), 0);
                    }
                }

                #[test]
                fn test_halve() {
                    let cms = $sketch::new(0.01, 0.9);
                    let keys = random_keys(1000);
                    fill(&cms, &keys);

                    for (i, &key) in keys.iter().enumerate() {
                        let (estimate, expected) = (cms.estimate(key), clamped(&cms, i));
                        assert!(estimate >= expected, "assert {} >= {} failed", estimate, expected);
                    }

                    cms.halve();

                    for (i, &key) in keys.iter().enumerate() {
                        let (estimate, expected) = (cms.estimate(key), clamped(&cms, i) / 2);
                        assert!(estimate >= expected, "assert {} >= {} failed", estimate, expected);
                    }
                }

                #[test]
                fn test_decay() {
                    let cms = $sketch::new(0.01, 0.9);
                    let keys = random_keys(1000);
                    fill(&cms, &keys);

                    for (i, &key) in keys.iter().enumerate() {
                        let (estimate, expected) = (cms.estimate(key), clamped(&cms, i));
                        assert!(estimate >= expected, "assert {} >= {} failed", estimate, expected);
                    }

                    const FACTOR: f64 = 0.5;
                    cms.decay(FACTOR);

                    for (i, &key) in keys.iter().enumerate() {
                        let expected = (clamped(&cms, i) as f64 * FACTOR).floor() as $type;
                        assert!(cms.estimate(key) >= expected);
                    }
                }

                #[test]
                fn test_collisions() {
                    let cms = $sketch::new(0.01, 0.9);
                    // More keys than the table has columns, so cells are shared.
                    let keys = random_keys(1000);
                    fill(&cms, &keys);

                    // Key `i` was inserted `i` times, so this is the total number of insertions.
                    let sum: usize = (0..keys.len()).sum();
                    let error = sum as f64 * 0.01;

                    for (i, &key) in keys.iter().enumerate().take(10) {
                        assert!(cms.estimate(key) >= i as $type);
                        assert!(i as f64 + error >= cms.estimate(key) as f64);
                    }
                }
            }
        )*
    }
}
387
// Generate one test module per atomic sketch type. Each entry is
// `{test module name, counter type, backing atomic type, sketch type under test}`.
test_cmsketch! {
    {tests_cmsketch_atomic_u8, u8, AtomicU8, CMSketchAtomicU8},
    {tests_cmsketch_atomic_u16, u16, AtomicU16, CMSketchAtomicU16},
    {tests_cmsketch_atomic_u32, u32, AtomicU32, CMSketchAtomicU32},
    {tests_cmsketch_atomic_u64, u64, AtomicU64, CMSketchAtomicU64},
    {tests_cmsketch_atomic_usize, usize, AtomicUsize, CMSketchAtomicUsize},
}