// invocation_counter/lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::{fmt::Debug, sync::Mutex};
4
/// One sub-bucket slot: the latest key recorded in it and the number of invocations counted
/// since the slot was last reset.
struct Pair {
    /// Latest key recorded in this slot (`0` while the slot has never been written).
    key: u64,
    /// Invocations counted since the slot was last reset.
    value: usize,
}

impl Default for Pair {
    /// An untouched slot: key `0`, count `0`.
    fn default() -> Self {
        Pair { key: 0, value: 0 }
    }
}

11/// A counter useful for counting how many times a function invocation is called in last X seconds/minutes/hours.
12/// `N` is the number of buckets and `M` is the number of sub-buckets.
13///
14/// In the documentation and in the code, we use `key` to refer the temporal unit (e.g. seconds, minutes, hours) of the invocation.
15/// Because this library don't want force you to use seconds, you can use any unit you want.
16/// You can consider to use `std::time::Instant::elapsed().as_secs()` as the key for instance.
17///
18/// `Counter` groups keys into buckets based on the `group_shift_factor`: `key >> group_shift_factor % N` will be the bucket index.
19/// The index for the sub-bucket is `key % M`.
20///
21/// ## Internal structure
22///
23/// Internally, the `Counter` uses a ring buffer of `N` buckets. Each bucket has `M` sub-buckets.
24/// This allows the `Counter` to distribute the load across multiple sub-buckets when the keys have the same index.
25///
26/// For instance, `Counter<3, 2>::new(4)` will be like:
27/// ```text
28///                ----------   ----------   ----------
29///                | (0, 0) |   | (0, 0) |   | (0, 0) |
30///                | (0, 0) |   | (0, 0) |   | (0, 0) |
31///                ----------   ----------   ----------
32/// index              0            1            2
33/// key range 1       0-16        17-31        32-47
34/// key range 2      48-63        64-80         ...
35/// ```
36///
37pub struct Counter<const N: usize, const M: usize> {
38    buckets: [[Mutex<Pair>; M]; N],
39    group_shift_factor: u32,
40}
41
42impl<const N: usize, const M: usize> Debug for Counter<N, M> {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        let data = self
45            .buckets
46            .iter()
47            .map(|b| {
48                b.iter()
49                    .map(|m| {
50                        let pair = m.lock().unwrap();
51                        (pair.key, pair.value)
52                    })
53                    .collect::<Vec<_>>()
54            })
55            .collect::<Vec<_>>();
56
57        f.debug_struct("Counter")
58            .field("buckets", &data)
59            .field("group_shift_factor", &self.group_shift_factor)
60            .finish()
61    }
62}
63
64impl<const N: usize, const M: usize> Counter<N, M> {
65    /// Create a new counter with the given group shift factor.
66    ///
67    /// `group_shift_factor` is the number of bits to shift the key to get the group index.
68    pub fn new(group_shift_factor: u32) -> Self {
69        let locks = core::array::from_fn(|_| core::array::from_fn(|_| Mutex::new(Pair::default())));
70
71        Self {
72            buckets: locks,
73            group_shift_factor,
74        }
75    }
76
77    /// Register an invocation.
78    ///
79    /// This will increment the value of the key by one.
80    /// You can use `std::time::Instant::elapsed().as_secs()` as the key for instance.
81    pub fn increment_by_one(&self, key: u64) {
82        let index = (key >> self.group_shift_factor) as usize % N;
83        let sub_index = (key as usize) % M;
84
85        let mut pair = self.buckets[index][sub_index].lock().unwrap();
86
87        let lower_bound = self.get_lower_bound(key, N);
88        if (lower_bound..=key).contains(&pair.key) {
89            pair.key = key;
90            pair.value += 1;
91        } else {
92            if pair.key > key {
93                return;
94            }
95            pair.key = key;
96            pair.value = 1;
97        }
98    }
99
100    /// Get the count of invocations till the given key.
101    ///
102    /// This will return the total count of invocations till the given key.
103    /// You can use `std::time::Instant::elapsed().as_secs()` as the key for instance.
104    ///
105    pub fn get_count_till(&self, key: u64) -> usize {
106        let d = 2_u64.pow(self.group_shift_factor) * N as u64;
107
108        let allowed_range = self.get_lower_bound(key, d as usize)..=key;
109
110        let mut tot = 0;
111        for b in &self.buckets {
112            let mut s = 0;
113            for sub in b {
114                let pair = sub.lock().unwrap();
115
116                if allowed_range.contains(&pair.key) {
117                    s += pair.value;
118                }
119            }
120
121            tot += s;
122        }
123
124        tot
125    }
126
127    #[inline]
128    fn get_lower_bound(&self, key: u64, n: usize) -> u64 {
129        key.saturating_sub(n as u64 - 1)
130    }
131}
132
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use crate::Counter;

    /// Copies out the `(key, value)` pair stored at `buckets[bucket][sub]`, releasing the lock
    /// before returning.
    fn pair_at<const N: usize, const M: usize>(
        counter: &Counter<N, M>,
        bucket: usize,
        sub: usize,
    ) -> (u64, usize) {
        let guard = counter.buckets[bucket][sub].lock().unwrap();
        (guard.key, guard.value)
    }

    #[test]
    fn test_initially_empty() {
        let counter = Counter::<2, 5>::new(4);

        // Probe around the 16-key group boundaries: nothing has been recorded yet.
        let probes = [0, 1, 2, 16, 17, 16 * 2, 16 * 2 + 1, 16 * 3, 16 * 3 + 1];
        for &key in probes.iter() {
            assert_eq!(counter.get_count_till(key), 0, "probe {key}");
        }
    }

    #[test]
    fn test_increment_check_bucket() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        // Buckets 0, 1 and 2 first cover keys 0-3, 4-7 and 8-11; every slot starts at (0, 0).
        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        // Key 0 -> bucket 0, sub-bucket 0.
        counter.increment_by_one(0);
        assert_eq!(pair_at(&counter, 0, 0), (0, 1));
        assert_eq!(pair_at(&counter, 0, 1), (0, 0));

        // Key 1 -> bucket 0, sub-bucket 1.
        counter.increment_by_one(1);
        assert_eq!(pair_at(&counter, 0, 0), (0, 1));
        assert_eq!(pair_at(&counter, 0, 1), (1, 1));

        // Key 2 shares sub-bucket 0 with key 0 and accumulates.
        counter.increment_by_one(2);
        assert_eq!(pair_at(&counter, 0, 0), (2, 2));
        assert_eq!(pair_at(&counter, 0, 1), (1, 1));

        // Key 3 accumulates on top of key 1.
        counter.increment_by_one(3);
        assert_eq!(pair_at(&counter, 0, 0), (2, 2));
        assert_eq!(pair_at(&counter, 0, 1), (3, 2));

        // Key 4 opens bucket 1.
        counter.increment_by_one(4);
        assert_eq!(pair_at(&counter, 1, 0), (4, 1));
        assert_eq!(pair_at(&counter, 1, 1), (0, 0));

        counter.increment_by_one(5);
        assert_eq!(pair_at(&counter, 1, 0), (4, 1));
        assert_eq!(pair_at(&counter, 1, 1), (5, 1));

        // Key 11 is the last key of the first lap around the ring.
        counter.increment_by_one(11);
        assert_eq!(pair_at(&counter, 2, 0), (0, 0));
        assert_eq!(pair_at(&counter, 2, 1), (11, 1));

        // Key 12 wraps to bucket 0 and replaces the stale data of keys 0 and 2.
        counter.increment_by_one(12);
        assert_eq!(pair_at(&counter, 0, 0), (12, 1));
        assert_eq!(pair_at(&counter, 0, 1), (3, 2));

        // Key 13 replaces the stale data of keys 1 and 3.
        counter.increment_by_one(13);
        assert_eq!(pair_at(&counter, 0, 0), (12, 1));
        assert_eq!(pair_at(&counter, 0, 1), (13, 1));

        counter.increment_by_one(14);
        assert_eq!(pair_at(&counter, 0, 0), (14, 2));
        assert_eq!(pair_at(&counter, 0, 1), (13, 1));

        counter.increment_by_one(15);
        assert_eq!(pair_at(&counter, 0, 0), (14, 2));
        assert_eq!(pair_at(&counter, 0, 1), (15, 2));

        // Key 16 wraps to bucket 1 and replaces the stale data of key 4.
        counter.increment_by_one(16);
        assert_eq!(pair_at(&counter, 1, 0), (16, 1));
        assert_eq!(pair_at(&counter, 1, 1), (5, 1));
    }

    #[test]
    fn test_get_count_till() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        // Every key stays inside the 12-key window, so the total grows by one per increment.
        let keys = [0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11];
        for (done, &key) in keys.iter().enumerate() {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), done + 1, "after incrementing {key}");
        }
    }

    #[test]
    fn test_get_count_till_cycle() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        // Keys 12 and 13 wrap onto the slots of keys 0 and 1, which leave the window.
        let steps = [(0, 1), (1, 2), (11, 3), (12, 3), (13, 3), (14, 4)];
        for &(key, expected) in steps.iter() {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), expected, "after incrementing {key}");
        }
    }

    #[test]
    fn test_get_count_till_expired() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        counter.increment_by_one(0);
        assert_eq!(counter.get_count_till(0), 1);

        counter.increment_by_one(11);
        assert_eq!(counter.get_count_till(11), 2);

        // Both recorded keys are far outside the 12-key window ending at 1_000.
        assert_eq!(counter.get_count_till(1_000), 0);
    }

    #[test]
    fn test_increment_check_bucket_shift_factor_0() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 0;

        let counter = Counter::<BUCKET_COUNT, 1>::new(SHIFT_FACTOR);

        // Without a shift every key is its own group, so key `k` lands in bucket `k % 3`;
        // key 3 wraps around and replaces key 0.
        for key in 0..4u64 {
            counter.increment_by_one(key);
            let bucket = (key % BUCKET_COUNT as u64) as usize;
            assert_eq!(pair_at(&counter, bucket, 0), (key, 1));
        }
    }

    #[test]
    fn test_shift_factor_0() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 0;

        // One key per group and a 3-key window: `get_count_till(k)` sums keys k-2..=k.
        let counter = Counter::<BUCKET_COUNT, 1>::new(SHIFT_FACTOR);

        let steps = [
            (0, 1), // bucket 0 -> (0, 1)
            (1, 2), // bucket 1 -> (1, 1)
            (1, 3), // bucket 1 -> (1, 2)
            (2, 4), // bucket 2 -> (2, 1)
            (3, 4), // bucket 0 wraps to (3, 1); key 0 leaves the window
        ];
        for &(key, expected) in steps.iter() {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), expected, "after incrementing {key}");
        }
    }

    #[test]
    fn test_parallel() {
        const THREAD_NUMBER: usize = 2;

        for _round in 0..100 {
            let counter = Arc::new(Counter::<3, 5>::new(0));

            let handles: Vec<_> = (0..THREAD_NUMBER)
                .map(|_| {
                    let counter = Arc::clone(&counter);
                    thread::spawn(move || {
                        for key in 0..4 {
                            counter.increment_by_one(key);
                        }
                    })
                })
                .collect();

            handles
                .into_iter()
                .for_each(|handle| handle.join().unwrap());

            // With no shift, keys 0..=3 land in distinct slots, so no increment is lost.
            assert_eq!(counter.get_count_till(0), THREAD_NUMBER);
            assert_eq!(counter.get_count_till(1), 2 * THREAD_NUMBER);
            assert_eq!(counter.get_count_till(2), 3 * THREAD_NUMBER);
            // The window ending at 3 is 1..=3, so key 0 is no longer counted.
            assert_eq!(counter.get_count_till(3), 3 * THREAD_NUMBER);
        }
    }
}