hypercounter/
lib.rs

1//! An atomic, lock-free, hash map-like counter structure.
2//!
3//! It uses [`papaya::HashMap`] under the hood to provide concurrent access to multiple keys at
4//! once, allowing for efficient counting without the need for locks.
5//!
6//! ## Notes Before Use
7//!
8//! - Operations on atomics are always wrapping on overflow.
9//!
10//! # Getting Started
11//!
12//! To install this library, run the following command:
13//!
14//! ```sh
15//! cargo add hypercounter
16//! ```
17//!
18//! That's it! To start using it, create a new [`HyperCounter`] instance:
19//!
20//! ```rust
21//! use std::sync::atomic::{AtomicUsize, Ordering};
22//! use hypercounter::HyperCounter;
23//!
24//! let counter: HyperCounter<String, AtomicUsize> = HyperCounter::new();
25//!
26//! counter.fetch_add("example_key".to_string(), 1, Ordering::Relaxed);
27//! counter.fetch_sub("example_key".to_string(), 1, Ordering::Relaxed);
28//! ```
29//!
30//! Keys are automatically removed when their associated counter reaches zero. Neither inserts nor
31//! removals are needed explicitly. If you want to remove a key manually, however, you can do so
32//! using [`HyperCounter::swap()`] to swap the value with 0.
33//!
34//! ```rust
35//! # use std::sync::atomic::{AtomicUsize, Ordering};
36//! # use hypercounter::HyperCounter;
37//! # let counter: HyperCounter<String, AtomicUsize> = HyperCounter::new();
38//! let previous_value = counter.swap("example_key".to_string(), 0, Ordering::Relaxed);
39//! ```
40//!
41//! ## Supported Operations
42//!
43//! The following atomic operations are supported:
44//!
45//! - [`HyperCounter::load()`]: Atomically loads the current value for a given key.
46//! - [`HyperCounter::swap()`]: Atomically swaps the value for a given key.
47//! - [`HyperCounter::fetch_add()`]: Atomically adds a value to the counter for a given key.
48//! - [`HyperCounter::fetch_sub()`]: Atomically subtracts a value from the counter for a given key.
49//! - [`HyperCounter::fetch_and()`]: Atomically performs a bitwise AND operation on the counter for
50//!   a given key.
51//! - [`HyperCounter::fetch_nand()`]: Atomically performs a bitwise NAND operation on the counter
52//!   for a given key.
53//! - [`HyperCounter::fetch_or()`]: Atomically performs a bitwise OR operation on the counter for a
54//!   given key.
55//! - [`HyperCounter::fetch_xor()`]: Atomically performs a bitwise XOR operation on the counter for
56//!   a given key.
57//! - [`HyperCounter::fetch_max()`]: Atomically sets the counter for a given key to the maximum of
58//!   the current value and the provided value.
59//! - [`HyperCounter::fetch_min()`]: Atomically sets the counter for a given key to the minimum of
60//!   the current value and the provided value.
61//!
62//! # Benchmarking
63//!
64//! There's a simple benchmark example included in the `examples` directory. You can run it using:
65//!
66//! ```sh
67//! cargo run --example bench
68//! ```
69//!
70//! This will execute a series of single-threaded benchmarks and print the operations per second
71//! for various scenarios.
72
73use std::{
74    hash::{BuildHasher, Hash, RandomState},
75    sync::{Arc, atomic::Ordering},
76};
77
78use papaya::HashMap;
79
80use crate::numbers::AtomicNumber;
81
82mod numbers;
83
84pub struct HyperCounter<K, V, H = RandomState>
85where
86    K: Eq + Hash,
87    V: AtomicNumber,
88    H: BuildHasher + Default,
89{
90    inner: HashMap<K, Arc<V>, H>,
91}
92
93impl<K, V, H> HyperCounter<K, V, H>
94where
95    K: Eq + Hash,
96    V: AtomicNumber,
97    H: BuildHasher + Default,
98{
99    /// Creates a new, empty HyperCounter.
100    ///
101    /// Returns:
102    /// * [`HyperCounter<K, V>`] - A new HyperCounter instance.
103    pub fn new() -> Self {
104        Self {
105            inner: HashMap::with_hasher(H::default()),
106        }
107    }
108
109    /// Returns the current amount of occupied entries in the HyperCounter.
110    ///
111    /// Returns:
112    /// * [`usize`] - The current length.
113    pub fn len(&self) -> usize {
114        self.inner.len()
115    }
116
117    /// Checks if the HyperCounter is empty.
118    ///
119    /// Returns:
120    /// * `true` - if the HyperCounter is empty.
121    /// * `false` - otherwise.
122    pub fn is_empty(&self) -> bool {
123        self.inner.is_empty()
124    }
125
126    /// Gets the appropriate load ordering based on the provided ordering.
127    ///
128    /// This is because [`HyperCounter::fetch_add()`] is not a real read-modify-write operation
129    /// since it has to combine a fetch to the atomic pointer before updating the atomic number.
130    ///
131    /// Arguments:
132    /// * `ordering` - The original memory ordering.
133    ///
134    /// Returns:
135    /// * `Ordering` - The adjusted load ordering.
136    fn get_load_ordering(&self, ordering: Ordering) -> Ordering {
137        match ordering {
138            Ordering::Release | Ordering::AcqRel => Ordering::Acquire,
139            o => o,
140        }
141    }
142
143    /// Atomically loads the value for the given key.
144    ///
145    /// Arguments:
146    /// * `key` - The key to load.
147    /// * `ordering` - The memory ordering to use.
148    ///
149    /// Returns:
150    /// * [`V::Primitive`] - The current value associated with the key, or zero if the key is
151    ///   missing.
152    pub fn load(&self, key: &K, ordering: Ordering) -> V::Primitive {
153        self.inner
154            .pin()
155            .get(key)
156            .map(|i| i.load(ordering))
157            .unwrap_or(V::ZERO)
158    }
159
160    /// Atomically swaps the value for the given key.
161    ///
162    /// If the new value is zero, the entry is removed.
163    ///
164    /// If the key is missing, a new entry is created with the new value.
165    ///
166    /// Arguments:
167    /// * `key` - The key to swap.
168    /// * `new_value` - The new value to set.
169    /// * `ordering` - The memory ordering to use.
170    ///
171    /// Returns:
172    /// * [`V::Primitive`] - The previous value associated with the key. (before swap)
173    pub fn swap(&self, key: K, new_value: V::Primitive, ordering: Ordering) -> V::Primitive {
174        let map = self.inner.pin();
175
176        if new_value == V::ZERO {
177            let old = map.remove(&key);
178
179            old.map(|i| i.load(self.get_load_ordering(ordering)))
180                .unwrap_or(V::ZERO)
181        } else {
182            let entry = map.get(&key);
183
184            if let Some(entry) = entry {
185                entry.swap(new_value, ordering)
186            } else {
187                let value = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
188
189                value.swap(new_value, ordering)
190            }
191        }
192    }
193
194    /// Atomically adds a value to the counter for the given key.
195    ///
196    /// If the key ends up being zero after addition, the entry is removed.
197    ///
198    /// If the key is missing, a new entry is created with the given value.
199    ///
200    /// Arguments:
201    /// * `key` - The key to add to.
202    /// * `value` - The value to add.
203    /// * `ordering` - The memory ordering to use.
204    ///
205    /// Returns:
206    /// * [`V::Primitive`] - The previous value associated with the key. (before addition)
207    pub fn fetch_add(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
208        let map = self.inner.pin();
209
210        let entry = map.get(&key);
211
212        if let Some(entry) = entry {
213            let result = entry.fetch_add(value, ordering);
214
215            if V::primitive_wrapping_add(result, value) == V::ZERO {
216                map.remove(&key);
217            }
218
219            result
220        } else {
221            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
222
223            result.fetch_add(value, ordering)
224        }
225    }
226
227    /// Atomically subtracts a value from the counter for the given key.
228    ///
229    /// If the key ends up being zero after subtraction, the entry is removed.
230    ///
231    /// If the key is missing, a new entry is created with zero - the given value.
232    ///
233    /// Arguments:
234    /// * `key` - The key to subtract from.
235    /// * `value` - The value to subtract.
236    /// * `ordering` - The memory ordering to use.
237    ///
238    /// Returns:
239    /// * [`V::Primitive`] - The previous value associated with the key. (before subtraction)
240    pub fn fetch_sub(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
241        let map = self.inner.pin();
242
243        let entry = map.get(&key);
244
245        if let Some(entry) = entry {
246            let result = entry.fetch_sub(value, ordering);
247
248            if V::primitive_wrapping_sub(result, value) == V::ZERO {
249                map.remove(&key);
250            }
251
252            result
253        } else {
254            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
255
256            result.fetch_sub(value, ordering)
257        }
258    }
259
260    /// Atomically performs a bitwise AND operation on the counter for the given key.
261    ///
262    /// If the key is missing, nothing is done and zero is returned.
263    ///
264    /// Arguments:
265    /// * `key` - The key to perform the AND operation on.
266    /// * `value` - The value to AND with.
267    /// * `ordering` - The memory ordering to use.
268    ///
269    /// Returns:
270    /// * [`V::Primitive`] - The previous value associated with the key. (before AND operation)
271    pub fn fetch_and(&self, key: &K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
272        let map = self.inner.pin();
273
274        let entry = map.get(key);
275
276        if let Some(entry) = entry {
277            entry.fetch_and(value, ordering)
278        } else {
279            V::ZERO
280        }
281    }
282
283    /// Atomically performs a bitwise NAND operation on the counter for the given key.
284    ///
285    /// If the key is missing, the new value is inserted and all bits set (i.e., !0) is returned.
286    ///
287    /// Arguments:
288    /// * `key` - The key to perform the NAND operation on.
289    /// * `value` - The value to NAND with.
290    /// * `ordering` - The memory ordering to use.
291    ///
292    /// Returns:
293    /// * [`V::Primitive`] - The previous value associated with the key. (before NAND operation)
294    pub fn fetch_nand(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
295        let map = self.inner.pin();
296
297        let entry = map.get(&key);
298
299        if let Some(entry) = entry {
300            let result = entry.fetch_nand(value, ordering);
301
302            if !(result & value) == V::ZERO {
303                map.remove(&key);
304            }
305
306            result
307        } else {
308            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
309
310            result.fetch_nand(value, ordering)
311        }
312    }
313
314    /// Atomically performs a bitwise OR operation on the counter for the given key.
315    ///
316    /// If the key is missing, the new value is inserted and zero is returned.
317    ///
318    /// Arguments:
319    /// * `key` - The key to perform the OR operation on.
320    /// * `value` - The value to OR with.
321    /// * `ordering` - The memory ordering to use.
322    ///
323    /// Returns:
324    /// * [`V::Primitive`] - The previous value associated with the key. (before OR operation)
325    pub fn fetch_or(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
326        let map = self.inner.pin();
327
328        let entry = map.get(&key);
329
330        if let Some(entry) = entry {
331            let result = entry.fetch_or(value, ordering);
332
333            if result | value == V::ZERO {
334                map.remove(&key);
335            }
336
337            result
338        } else {
339            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
340
341            result.fetch_or(value, ordering)
342        }
343    }
344
345    /// Atomically performs a bitwise XOR operation on the counter for the given key.
346    ///
347    /// If the key is missing, the new value is inserted and zero is returned.
348    ///
349    /// If the resulting value is zero after the XOR operation, the entry is removed.
350    ///
351    /// Arguments:
352    /// * `key` - The key to perform the XOR operation on.
353    /// * `value` - The value to XOR with.
354    /// * `ordering` - The memory ordering to use.
355    ///
356    /// Returns:
357    /// * [`V::Primitive`] - The previous value associated with the key. (before XOR operation)
358    pub fn fetch_xor(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
359        let map = self.inner.pin();
360
361        let entry = map.get(&key);
362
363        if let Some(entry) = entry {
364            let result = entry.fetch_xor(value, ordering);
365
366            if result ^ value == V::ZERO {
367                map.remove(&key);
368            }
369
370            result
371        } else {
372            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
373
374            result.fetch_xor(value, ordering)
375        }
376    }
377
378    /// Atomically sets the counter for the given key to the maximum of its current value and the
379    /// given value.
380    ///
381    /// If the key is missing and the value is higher than zero, the new value is inserted and zero
382    /// is returned. Otherwise, nothing is done and zero is returned.
383    ///
384    /// Arguments:
385    /// * `key` - The key to perform the max operation on.
386    /// * `value` - The value to compare with.
387    /// * `ordering` - The memory ordering to use.
388    ///
389    /// Returns:
390    /// * [`V::Primitive`] - The previous value associated with the key. (before max operation)
391    pub fn fetch_max(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
392        let map = self.inner.pin();
393
394        let entry = map.get(&key);
395
396        if let Some(entry) = entry {
397            let result = entry.fetch_max(value, ordering);
398
399            if value >= result && value == V::ZERO {
400                map.remove(&key);
401            }
402
403            result
404        } else {
405            // If there's no value (i.e. entry == None), the default is zero. If the value is less
406            // than or equal to zero, then nothing changes and we return zero.
407            if value <= V::ZERO {
408                return V::ZERO;
409            }
410
411            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
412
413            result.fetch_max(value, ordering)
414        }
415    }
416
417    /// Atomically sets the counter for the given key to the minimum of its current value and the
418    /// given value.
419    ///
420    /// If the key is missing and the value is lower than zero, the new value is inserted and zero
421    /// is returned. Otherwise, nothing is done and zero is returned.
422    ///
423    /// Arguments:
424    /// * `key` - The key to perform the min operation on.
425    /// * `value` - The value to compare with.
426    /// * `ordering` - The memory ordering to use.
427    ///
428    /// Returns:
429    /// * [`V::Primitive`] - The previous value associated with the key. (before min operation)
430    pub fn fetch_min(&self, key: K, value: V::Primitive, ordering: Ordering) -> V::Primitive {
431        let map = self.inner.pin();
432
433        let entry = map.get(&key);
434
435        if let Some(entry) = entry {
436            let result = entry.fetch_min(value, ordering);
437
438            if value <= result && value == V::ZERO {
439                map.remove(&key);
440            }
441
442            result
443        } else {
444            // If there's no value (i.e. entry == None), the default is zero. If the value is
445            // higher than or equal to zero, then nothing changes and we return zero.
446            if value >= V::ZERO {
447                return V::ZERO;
448            }
449
450            let result = map.get_or_insert(key, Arc::new(V::new(V::ZERO)));
451
452            result.fetch_min(value, ordering)
453        }
454    }
455
456    /// Removes all entries in the [`HyperCounter`].
457    pub fn clear(&self) {
458        let map = self.inner.pin();
459        map.clear();
460    }
461
462    /// Scans all entries in the [`HyperCounter`] and applies the provided function to each
463    /// key-value.
464    ///
465    /// Arguments:
466    /// * `f` - The function to apply to each key-value pair.
467    /// * `ordering` - The memory ordering to use when loading values.
468    pub fn scan(&self, mut f: impl FnMut(&K, &V::Primitive), ordering: Ordering) {
469        let map = self.inner.pin();
470        map.iter().for_each(|(k, v)| {
471            let value = v.load(ordering);
472
473            f(k, &value);
474        });
475    }
476}
477
478impl<K, V, H> HyperCounter<K, V, H>
479where
480    K: Eq + Hash + Clone,
481    V: AtomicNumber,
482    H: BuildHasher + Default,
483{
484    /// Retains only the entries specified by the predicate function.
485    /// 
486    /// Arguments:
487    /// * `f` - The predicate function to determine which entries to retain.
488    /// * `ordering_load` - The memory ordering to use when loading values.
489    /// * `ordering_remove` - The memory ordering to use when removing entries.
490    pub fn retain(
491        &self,
492        mut f: impl FnMut(&K, &V::Primitive) -> bool,
493        ordering_load: Ordering,
494        ordering_remove: Ordering,
495    ) {
496        let map = self.inner.pin();
497
498        map.iter().for_each(|(k, v)| {
499            let value = v.load(ordering_load);
500
501            if !f(k, &value) {
502                if value > V::ZERO {
503                    self.fetch_sub(k.clone(), value, ordering_remove);
504                } else {
505                    self.fetch_add(k.clone(), !value, ordering_remove);
506                }
507            }
508        });
509    }
510}
511
512impl<K, V> Default for HyperCounter<K, V>
513where
514    K: Eq + Hash,
515    V: AtomicNumber,
516{
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;

    /// Counter keyed by owned strings.
    type StringCounter = HyperCounter<String, AtomicUsize>;
    /// Counter keyed by integers, used by the bulk insert/remove tests.
    type IndexCounter = HyperCounter<usize, AtomicUsize>;

    /// Builds an owned `String` key from a literal.
    fn owned(name: &str) -> String {
        name.to_string()
    }

    #[test]
    fn test_hypercounter_basic() {
        let counter = StringCounter::new();

        assert!(counter.is_empty());
        assert_eq!(counter.len(), 0);

        // First touch of a key reports zero as the previous value and creates the entry.
        assert_eq!(counter.fetch_add(owned("apple"), 5, Ordering::SeqCst), 0);
        assert!(!counter.is_empty());
        assert_eq!(counter.len(), 1);

        assert_eq!(counter.fetch_add(owned("apple"), 3, Ordering::SeqCst), 5);

        assert_eq!(counter.fetch_add(owned("banana"), 2, Ordering::SeqCst), 0);
        assert_eq!(counter.len(), 2);

        // Bringing a counter back to zero removes its key.
        assert_eq!(counter.fetch_sub(owned("apple"), 8, Ordering::SeqCst), 8);
        assert_eq!(counter.load(&owned("apple"), Ordering::SeqCst), 0);
        assert_eq!(counter.len(), 1);

        assert_eq!(counter.fetch_sub(owned("banana"), 2, Ordering::SeqCst), 2);
        assert_eq!(counter.len(), 0);
    }

    #[test]
    fn test_hypercounter_expand() {
        let counter = IndexCounter::new();

        (0..100).for_each(|n| {
            counter.fetch_add(n, n, Ordering::SeqCst);
            assert_eq!(counter.len(), n + 1);
        });

        assert_eq!(counter.len(), 100);

        (0..100).for_each(|n| assert_eq!(counter.load(&n, Ordering::SeqCst), n));
    }

    #[test]
    fn test_hypercounter_remove() {
        let counter = IndexCounter::new();

        (0..100).for_each(|n| {
            counter.fetch_add(n, n, Ordering::SeqCst);
        });

        // Draining key `n` leaves `99 - n` entries behind.
        for (n, remaining) in (0..100).zip((0..100).rev()) {
            assert_eq!(counter.fetch_sub(n, n, Ordering::SeqCst), n);
            assert_eq!(counter.load(&n, Ordering::SeqCst), 0);
            assert_eq!(counter.len(), remaining);
        }
    }

    #[test]
    fn test_hypercounter_orderings() {
        let counter = StringCounter::new();

        // Exercise every operation with every ordering; any panicking combination fails the test.
        for ordering in [
            Ordering::AcqRel,
            Ordering::Acquire,
            Ordering::Release,
            Ordering::SeqCst,
            Ordering::Relaxed,
        ] {
            counter.fetch_add(owned("key"), 10, ordering);
            counter.fetch_sub(owned("key"), 5, ordering);
            counter.fetch_and(&owned("key"), 7, ordering);
            counter.fetch_nand(owned("key"), 3, ordering);
            counter.fetch_or(owned("key"), 12, ordering);
            counter.fetch_xor(owned("key"), 6, ordering);
            counter.fetch_max(owned("key"), 15, ordering);
            counter.fetch_min(owned("key"), 5, ordering);
            counter.swap(owned("key"), 20, ordering);
        }

        // Plain loads are only exercised with load-compatible orderings.
        for ordering in [Ordering::Acquire, Ordering::SeqCst, Ordering::Relaxed] {
            counter.load(&owned("key"), ordering);
        }
    }

    #[test]
    fn test_hypercounter_concurrency() {
        const THREADS: usize = 10;
        const ITERATIONS: usize = 1_000_000;

        let counter: Arc<StringCounter> = Arc::new(StringCounter::new());

        let workers: Vec<_> = (0..THREADS)
            .map(|_| {
                let shared = Arc::clone(&counter);

                std::thread::spawn(move || {
                    let key = owned("key");

                    for step in 0..ITERATIONS {
                        if step % 2 == 0 {
                            shared.fetch_add(key.clone(), 1, Ordering::Relaxed);
                        } else {
                            shared.fetch_sub(key.clone(), 1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
    }
}