invocation_counter/lib.rs
1#![doc = include_str!("../README.md")]
2
3use std::{fmt::Debug, sync::Mutex};
4
/// One sub-bucket of the ring buffer.
///
/// `key` is the most recent key recorded in this slot and `value` is the number of
/// invocations counted for the group that `key` belongs to.
#[derive(Default)]
struct Pair {
    key: u64,
    value: usize,
}

/// Locks a slot, recovering its data if a previous holder panicked.
///
/// A `Pair` only holds two plain integers, so the value behind a poisoned lock is still
/// well-formed; recovering it keeps one panic that happened while a slot was locked from
/// turning every later `increment_by_one`, `get_count_till` and `Debug` call into a panic.
fn lock_slot(slot: &Mutex<Pair>) -> std::sync::MutexGuard<'_, Pair> {
    slot.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// A counter of how many times a function was invoked in the last X seconds/minutes/hours.
///
/// `N` is the number of buckets and `M` is the number of sub-buckets per bucket.
///
/// Throughout the documentation and the code, `key` refers to the temporal unit
/// (e.g. seconds, minutes, hours) of an invocation. The library does not force any
/// particular unit on you: for instance, you can use `start.elapsed().as_secs()`
/// (with `start` a [`std::time::Instant`]) as the key.
///
/// `Counter` groups consecutive keys according to `group_shift_factor`: a group spans
/// `2^group_shift_factor` keys, and the bucket index of a key is
/// `(key >> group_shift_factor) % N`. The index of the sub-bucket is `key % M`.
/// Because the buckets form a ring, [`Counter::get_count_till`] reports the invocations
/// of (roughly) the last `N * 2^group_shift_factor` keys.
///
/// ## Internal structure
///
/// Internally, the `Counter` uses a ring buffer of `N` buckets, each made of `M`
/// sub-buckets. Every sub-bucket has its own `Mutex`, so invocations whose keys fall into
/// the same bucket but differ in `key % M` lock different mutexes.
///
/// For instance, `Counter::<3, 2>::new(4)` (groups of 16 keys) looks like this, every cell
/// being a `(key, value)` pair:
/// ```text
///             ---------- ---------- ----------
///             | (0, 0) | | (0, 0) | | (0, 0) |
///             | (0, 0) | | (0, 0) | | (0, 0) |
///             ---------- ---------- ----------
/// index            0          1          2
/// key range 1     0-15      16-31      32-47
/// key range 2    48-63      64-79      80-95
/// ```
///
pub struct Counter<const N: usize, const M: usize> {
    /// `N` buckets of `M` independently locked sub-buckets each.
    buckets: [[Mutex<Pair>; M]; N],
    /// Number of bits a key is shifted right to obtain its group.
    group_shift_factor: u32,
}

impl<const N: usize, const M: usize> Debug for Counter<N, M> {
    /// Formats every slot as a `(key, value)` tuple, grouped by bucket.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buckets: Vec<Vec<(u64, usize)>> = self
            .buckets
            .iter()
            .map(|bucket| {
                bucket
                    .iter()
                    .map(|slot| {
                        let pair = lock_slot(slot);
                        (pair.key, pair.value)
                    })
                    .collect()
            })
            .collect();

        f.debug_struct("Counter")
            .field("buckets", &buckets)
            .field("group_shift_factor", &self.group_shift_factor)
            .finish()
    }
}

impl<const N: usize, const M: usize> Counter<N, M> {
    /// Creates a new, empty counter with the given group shift factor.
    ///
    /// `group_shift_factor` is the number of bits a key is shifted right to obtain its
    /// group, so every group spans `2^group_shift_factor` consecutive keys.
    ///
    /// # Panics
    ///
    /// Panics if `N` or `M` is zero, or if `group_shift_factor` is 64 or more
    /// (a `u64` key cannot be shifted by that many bits).
    pub fn new(group_shift_factor: u32) -> Self {
        assert!(N > 0, "Counter needs at least one bucket (N > 0)");
        assert!(M > 0, "Counter needs at least one sub-bucket (M > 0)");
        assert!(
            group_shift_factor < u64::BITS,
            "group_shift_factor must be smaller than {}",
            u64::BITS
        );

        let buckets =
            std::array::from_fn(|_| std::array::from_fn(|_| Mutex::new(Pair::default())));

        Self {
            buckets,
            group_shift_factor,
        }
    }

    /// Registers an invocation.
    ///
    /// This increments the count of the group `key` belongs to by one.
    /// You can use `start.elapsed().as_secs()` (with `start` a [`std::time::Instant`])
    /// as the key, for instance.
    ///
    /// The slot selected for `key` is updated as follows:
    /// - it holds a key of the same group: the invocation is added to its count, even if
    ///   `key` is older than the stored key (e.g. keys taken by racing threads);
    /// - it holds a key of an older group (e.g. from a previous lap around the ring
    ///   buffer, or the initial empty state): the slot is reset to count only this
    ///   invocation;
    /// - it holds a key of a newer group: `key` is too old to be counted and is ignored.
    pub fn increment_by_one(&self, key: u64) {
        use std::cmp::Ordering;

        let group = key >> self.group_shift_factor;
        // Both remainders are smaller than a `usize` const generic, so the casts are lossless.
        let index = (group % N as u64) as usize;
        let sub_index = (key % M as u64) as usize;

        let mut pair = lock_slot(&self.buckets[index][sub_index]);
        let stored_group = pair.key >> self.group_shift_factor;

        // Compare whole groups rather than raw key distance: every key of a group shares
        // the slot's count, however far apart the keys are inside the group.
        match stored_group.cmp(&group) {
            Ordering::Equal => {
                pair.key = pair.key.max(key);
                pair.value += 1;
            }
            Ordering::Less => {
                pair.key = key;
                pair.value = 1;
            }
            Ordering::Greater => {}
        }
    }

    /// Gets the count of invocations till the given key.
    ///
    /// The count covers the window of the last `N * 2^group_shift_factor` keys ending at
    /// `key` (the window length saturates at `u64::MAX`). A slot contributes its whole
    /// count when the most recent key it recorded lies inside that window, so the start
    /// of the window has group granularity. Slots whose most recent key is after `key`
    /// are not counted.
    /// You can use `start.elapsed().as_secs()` (with `start` a [`std::time::Instant`])
    /// as the key, for instance.
    #[must_use]
    pub fn get_count_till(&self, key: u64) -> usize {
        let allowed = key.saturating_sub(self.window_len() - 1)..=key;

        self.buckets
            .iter()
            .flatten()
            // The chain is lazy: each guard is dropped before the next slot is locked.
            .map(lock_slot)
            .filter(|pair| allowed.contains(&pair.key))
            .map(|pair| pair.value)
            .sum()
    }

    /// Number of keys covered by the ring buffer, `N * 2^group_shift_factor`,
    /// saturating at `u64::MAX` instead of overflowing.
    fn window_len(&self) -> u64 {
        1_u64
            .checked_shl(self.group_shift_factor)
            .and_then(|group_len| group_len.checked_mul(N as u64))
            .unwrap_or(u64::MAX)
    }
}
132
#[cfg(test)]
mod tests {
    use std::{sync::Arc, thread};

    use crate::Counter;

    /// Asserts that sub-bucket `sub` of bucket `bucket` currently stores `expected`
    /// as a `(key, value)` pair.
    ///
    /// `#[track_caller]` makes a failing assertion report the calling test line.
    #[track_caller]
    fn assert_slot<const N: usize, const M: usize>(
        counter: &Counter<N, M>,
        bucket: usize,
        sub: usize,
        expected: (u64, usize),
    ) {
        let pair = counter.buckets[bucket][sub].lock().unwrap();
        assert_eq!((pair.key, pair.value), expected);
    }

    #[test]
    fn test_initially_empty() {
        let counter = Counter::<2, 5>::new(4);

        // Probe both sides of several 16-key group boundaries.
        for key in [0, 1, 2, 16, 17, 16 * 2, 16 * 2 + 1, 16 * 3, 16 * 3 + 1] {
            assert_eq!(counter.get_count_till(key), 0);
        }
    }

    #[test]
    fn test_increment_check_bucket() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        // Groups of 4 keys: bucket 0 holds 0-3, bucket 1 holds 4-7, bucket 2 holds 8-11,
        // and from 12 on the groups wrap around (12-15 -> bucket 0, 16-19 -> bucket 1).
        // Each step records `key`, then checks both sub-buckets of `bucket`, listed as
        // `[sub-bucket 0, sub-bucket 1]` with every entry being `(key, value)`.
        let steps: [(u64, usize, [(u64, usize); 2]); 12] = [
            (0, 0, [(0, 1), (0, 0)]),
            (1, 0, [(0, 1), (1, 1)]),
            (2, 0, [(2, 2), (1, 1)]),
            (3, 0, [(2, 2), (3, 2)]),
            (4, 1, [(4, 1), (0, 0)]),
            (5, 1, [(4, 1), (5, 1)]),
            // Almost at the end of the ring buffer.
            (11, 2, [(0, 0), (11, 1)]),
            // First lap completed: group 12-15 replaces group 0-3 in bucket 0.
            (12, 0, [(12, 1), (3, 2)]),
            (13, 0, [(12, 1), (13, 1)]),
            (14, 0, [(14, 2), (13, 1)]),
            (15, 0, [(14, 2), (15, 2)]),
            (16, 1, [(16, 1), (5, 1)]),
        ];

        for (key, bucket, expected) in steps {
            counter.increment_by_one(key);
            for (sub, &slot) in expected.iter().enumerate() {
                assert_slot(&counter, bucket, sub, slot);
            }
        }
    }

    #[test]
    fn test_get_count_till() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        // Every key stays inside the 12-key window, so the total grows by one per call.
        let keys = [0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 11];
        for (recorded_before, &key) in keys.iter().enumerate() {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), recorded_before + 1);
        }
    }

    #[test]
    fn test_get_count_till_cycle() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        // `(key recorded, expected total up to that key)`; keys 12 and 13 evict the
        // slots previously holding 0 and 1.
        let expectations = [(0, 1), (1, 2), (11, 3), (12, 3), (13, 3), (14, 4)];
        for &(key, total) in expectations.iter() {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), total);
        }
    }

    #[test]
    fn test_get_count_till_expired() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 2;

        let counter = Counter::<BUCKET_COUNT, 2>::new(SHIFT_FACTOR);

        for &(key, total) in [(0, 1), (11, 2)].iter() {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), total);
        }

        // Far beyond the window: everything recorded so far has expired.
        assert_eq!(counter.get_count_till(1_000), 0);
    }

    #[test]
    fn test_increment_check_bucket_shift_factor_0() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 0;

        let counter = Counter::<BUCKET_COUNT, 1>::new(SHIFT_FACTOR);

        // With no shift every group is a single key, so key `k` lands in bucket `k % 3`;
        // key 3 starts the second lap and replaces key 0 in bucket 0.
        for (key, bucket) in [(0, 0), (1, 1), (2, 2), (3, 0)] {
            counter.increment_by_one(key);
            assert_slot(&counter, bucket, 0, (key, 1));
        }
    }

    #[test]
    fn test_shift_factor_0() {
        const BUCKET_COUNT: usize = 3;
        const SHIFT_FACTOR: u32 = 0;

        let counter = Counter::<BUCKET_COUNT, 1>::new(SHIFT_FACTOR);

        // The window is 3 keys wide; recording key 3 evicts key 0 from bucket 0.
        for (key, total) in [(0, 1), (1, 2), (1, 3), (2, 4), (3, 4)] {
            counter.increment_by_one(key);
            assert_eq!(counter.get_count_till(key), total);
        }
    }

    #[test]
    fn test_parallel() {
        const THREAD_NUMBER: usize = 2;

        for _ in 0..100 {
            let counter = Arc::new(Counter::<3, 5>::new(0));

            let handles: Vec<_> = (0..THREAD_NUMBER)
                .map(|_| {
                    let counter = Arc::clone(&counter);
                    thread::spawn(move || {
                        for key in 0..4 {
                            counter.increment_by_one(key);
                        }
                    })
                })
                .collect();

            handles.into_iter().for_each(|handle| handle.join().unwrap());

            assert_eq!(counter.get_count_till(0), THREAD_NUMBER);
            assert_eq!(counter.get_count_till(1), 2 * THREAD_NUMBER);
            assert_eq!(counter.get_count_till(2), 3 * THREAD_NUMBER);
            // The window is 3 keys wide, so key 0 is no longer counted.
            assert_eq!(counter.get_count_till(3), 3 * THREAD_NUMBER);
        }
    }
}