fast_counter/
counter.rs

1use std::fmt;
2use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
3use std::cell::Cell;
4
5use crate::utils::{CachePadded, make_new_padded_counter};
6use crate::safe_getters::SafeGetters;
7
/// Process-wide source of thread identifiers.
///
/// Starts at `1`; every thread that initialises [`THREAD_ID`] takes the current value and
/// bumps the counter by one.
static THREAD_COUNTER: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    /// Identifier of the current thread, drawn from `THREAD_COUNTER` the first time the
    /// thread touches this key and unchanged for that thread afterwards.
    static THREAD_ID: Cell<usize> = {
        let assigned = THREAD_COUNTER.fetch_add(1, Ordering::SeqCst);
        Cell::new(assigned)
    };
}
13
14/// A sharded atomic counter
15///
16/// ConcurrentCounter shards cacheline aligned AtomicIsizes across a vector for faster updates in
17/// a high contention scenarios. 
18pub struct ConcurrentCounter {
19    cells: Vec<CachePadded::<AtomicIsize>>,
20}
21
22impl fmt::Debug for ConcurrentCounter {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_struct("ConcurrentCounter")
25         .field("sum", &self.sum())
26         .field("cells", &self.cells.len())
27         .finish()
28    }
29}
30
31impl ConcurrentCounter {
32    /// Creates a new ConcurrentCounter with a minimum of the `count` cells. Concurrent counter
33    /// will align the `count` to the next power of two for better speed when doing the modulus.
34    ///
35    /// # Examples 
36    ///
37    /// ```
38    /// use fast_counter::ConcurrentCounter;
39    ///
40    /// let counter = ConcurrentCounter::new(10);
41    /// ```
42    #[inline]
43    pub fn new(count: usize) -> Self {
44        let count = count.next_power_of_two();
45        Self {
46            cells: (0..count)
47                .into_iter()
48                .map(|_| make_new_padded_counter())
49                .collect(),
50        }
51    }
52
53    #[inline]
54    fn thread_id(&self) -> usize {
55        THREAD_ID.with(|id| {
56            id.get()
57        })
58    }
59
60    /// Adds the value to the counter, internally with is using `add_with_ordering` with a
61    /// `Ordering::Relaxed` and is mainly for convenience. 
62    ///
63    /// ConcurrentCounter will identify a cell to add the `value` too with using a thread_local
64    /// which will try to aleviate the contention on a single number 
65    ///
66    /// # Examples
67    ///
68    /// ```
69    /// use fast_counter::ConcurrentCounter;
70    ///
71    /// let counter = ConcurrentCounter::new(10);
72    /// counter.add(1);
73    /// counter.add(-1);
74    /// ```
75    #[inline]
76    pub fn add(&self, value: isize) {
77        self.add_with_ordering(value, Ordering::Relaxed)
78    }
79
80    /// ConcurrentCounter will identify a cell to add the `value` too with using a thread_local
81    /// which will try to aleviate the contention on a single number. The cell will be updated
82    /// atomically using the ordering provided in `ordering`
83    ///
84    /// # Examples
85    ///
86    /// ```
87    /// use fast_counter::ConcurrentCounter;
88    /// use std::sync::atomic::Ordering;
89    ///
90    /// let counter = ConcurrentCounter::new(10);
91    /// counter.add_with_ordering(1, Ordering::SeqCst);
92    /// counter.add_with_ordering(-1, Ordering::Relaxed);
93    /// ```
94    #[inline]
95    pub fn add_with_ordering(&self, value: isize, ordering: Ordering) {
96        let c = self.cells.safely_get(self.thread_id() & (self.cells.len() - 1));
97        c.value.fetch_add(value, ordering);
98    }
99
100    /// This will fetch the sum of the concurrent counter be iterating through each of the cells
101    /// and loading the values. Internally this uses `sum_with_ordering` with a `Relaxed` ordering.
102    ///
103    /// Due to the fact the cells are sharded and the concurrent nature of the library this sum
104    /// may be slightly inaccurate. For example if used in a concurrent map and using
105    /// ConcurrentCounter to track the length, depending on the ordering the length may be returned
106    /// as a negative value. 
107    ///
108    /// # Examples
109    ///
110    /// ```rust
111    /// use fast_counter::ConcurrentCounter;
112    ///
113    /// let counter = ConcurrentCounter::new(10);
114    ///
115    /// counter.add(1);
116    ///
117    /// let sum = counter.sum();
118    ///
119    /// assert_eq!(sum, 1);
120    /// ```
121    #[inline]
122    pub fn sum(&self) -> isize {
123        self.sum_with_ordering(Ordering::Relaxed)
124    }
125
126    /// This will fetch the sum of the concurrent counter be iterating through each of the cells
127    /// and loading the values with the ordering defined by `ordering`.
128    ///
129    /// Due to the fact the cells are sharded and the concurrent nature of the library this sum
130    /// may be slightly inaccurate. For example if used in a concurrent map and using
131    /// ConcurrentCounter to track the length, depending on the ordering the length may be returned
132    /// as a negative value. 
133    ///
134    /// # Examples
135    ///
136    /// ```rust
137    /// use std::sync::atomic::Ordering;
138    /// use fast_counter::ConcurrentCounter;
139    ///
140    /// let counter = ConcurrentCounter::new(10);
141    ///
142    /// counter.add(1);
143    ///
144    /// let sum = counter.sum_with_ordering(Ordering::SeqCst);
145    ///
146    /// assert_eq!(sum, 1);
147    /// ```
148    #[inline]
149    pub fn sum_with_ordering(&self, ordering: Ordering) -> isize {
150        self.cells.iter().map(|c| c.value.load(ordering)).sum()
151    }
152}
153
#[cfg(test)]
mod tests {
    use crate::ConcurrentCounter;

    /// Spawns `threads` scoped threads that each call `counter.add(1)` `writes_per_thread`
    /// times, and returns once every spawned thread has finished.
    fn add_from_threads(counter: &ConcurrentCounter, threads: usize, writes_per_thread: isize) {
        std::thread::scope(|s| {
            for _ in 0..threads {
                s.spawn(|| {
                    for _ in 0..writes_per_thread {
                        counter.add(1);
                    }
                });
            }
        });
    }

    #[test]
    fn basic_test() {
        let counter = ConcurrentCounter::new(1);
        counter.add(1);
        assert_eq!(counter.sum(), 1);
    }

    #[test]
    fn increment_multiple_times() {
        let counter = ConcurrentCounter::new(1);
        counter.add(1);
        counter.add(1);
        counter.add(1);
        assert_eq!(counter.sum(), 3);
    }

    #[test]
    fn two_threads_incrementing_concurrently() {
        // Two threads each increment the counter once.
        let counter = ConcurrentCounter::new(2);

        add_from_threads(&counter, 2, 1);

        assert_eq!(counter.sum(), 2);
    }

    #[test]
    fn two_threads_incrementing_multiple_times_concurrently() {
        const WRITE_COUNT: isize = 100_000;
        // Two threads each increment the counter `WRITE_COUNT` times.
        let counter = ConcurrentCounter::new(2);

        add_from_threads(&counter, 2, WRITE_COUNT);

        assert_eq!(counter.sum(), 2 * WRITE_COUNT);
    }

    #[test]
    fn multiple_threads_incrementing_multiple_times_concurrently() {
        const WRITE_COUNT: isize = 1_000_000;
        const THREAD_COUNT: isize = 8;
        // `THREAD_COUNT` threads each increment the counter `WRITE_COUNT` times.
        let counter = ConcurrentCounter::new(THREAD_COUNT as usize);

        add_from_threads(&counter, THREAD_COUNT as usize, WRITE_COUNT);

        assert_eq!(counter.sum(), THREAD_COUNT * WRITE_COUNT);
    }

    #[test]
    fn debug_works_as_expected() {
        const WRITE_COUNT: isize = 1_000_000;
        const THREAD_COUNT: isize = 8;
        // All increments happen on the current thread; only the Debug output is under test.
        let counter = ConcurrentCounter::new(THREAD_COUNT as usize);

        for _ in 0..WRITE_COUNT {
            counter.add(1);
        }

        assert_eq!(counter.sum(), WRITE_COUNT);

        assert_eq!(
            format!("Counter is: {counter:?}"),
            "Counter is: ConcurrentCounter { sum: 1000000, cells: 8 }"
        );
    }
}