fast_counter/counter.rs
1use std::fmt;
2use std::sync::atomic::{AtomicIsize, AtomicUsize, Ordering};
3use std::cell::Cell;
4
5use crate::utils::{CachePadded, make_new_padded_counter};
6use crate::safe_getters::SafeGetters;
7
/// Source of per-thread identifiers. Each thread draws the next value from this counter the
/// first time it reads `THREAD_ID`; numbering starts at 1.
static THREAD_COUNTER: AtomicUsize = AtomicUsize::new(1);

thread_local! {
    /// Identifier of the current thread, taken from `THREAD_COUNTER` when the thread-local is
    /// first initialized. It is never changed afterwards, so it stays fixed for the thread's life.
    static THREAD_ID: Cell<usize> = {
        let assigned = THREAD_COUNTER.fetch_add(1, Ordering::SeqCst);
        Cell::new(assigned)
    };
}
13
14/// A sharded atomic counter
15///
16/// ConcurrentCounter shards cacheline aligned AtomicIsizes across a vector for faster updates in
17/// a high contention scenarios.
18pub struct ConcurrentCounter {
19 cells: Vec<CachePadded::<AtomicIsize>>,
20}
21
22impl fmt::Debug for ConcurrentCounter {
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_struct("ConcurrentCounter")
25 .field("sum", &self.sum())
26 .field("cells", &self.cells.len())
27 .finish()
28 }
29}
30
31impl ConcurrentCounter {
32 /// Creates a new ConcurrentCounter with a minimum of the `count` cells. Concurrent counter
33 /// will align the `count` to the next power of two for better speed when doing the modulus.
34 ///
35 /// # Examples
36 ///
37 /// ```
38 /// use fast_counter::ConcurrentCounter;
39 ///
40 /// let counter = ConcurrentCounter::new(10);
41 /// ```
42 #[inline]
43 pub fn new(count: usize) -> Self {
44 let count = count.next_power_of_two();
45 Self {
46 cells: (0..count)
47 .into_iter()
48 .map(|_| make_new_padded_counter())
49 .collect(),
50 }
51 }
52
53 #[inline]
54 fn thread_id(&self) -> usize {
55 THREAD_ID.with(|id| {
56 id.get()
57 })
58 }
59
60 /// Adds the value to the counter, internally with is using `add_with_ordering` with a
61 /// `Ordering::Relaxed` and is mainly for convenience.
62 ///
63 /// ConcurrentCounter will identify a cell to add the `value` too with using a thread_local
64 /// which will try to aleviate the contention on a single number
65 ///
66 /// # Examples
67 ///
68 /// ```
69 /// use fast_counter::ConcurrentCounter;
70 ///
71 /// let counter = ConcurrentCounter::new(10);
72 /// counter.add(1);
73 /// counter.add(-1);
74 /// ```
75 #[inline]
76 pub fn add(&self, value: isize) {
77 self.add_with_ordering(value, Ordering::Relaxed)
78 }
79
80 /// ConcurrentCounter will identify a cell to add the `value` too with using a thread_local
81 /// which will try to aleviate the contention on a single number. The cell will be updated
82 /// atomically using the ordering provided in `ordering`
83 ///
84 /// # Examples
85 ///
86 /// ```
87 /// use fast_counter::ConcurrentCounter;
88 /// use std::sync::atomic::Ordering;
89 ///
90 /// let counter = ConcurrentCounter::new(10);
91 /// counter.add_with_ordering(1, Ordering::SeqCst);
92 /// counter.add_with_ordering(-1, Ordering::Relaxed);
93 /// ```
94 #[inline]
95 pub fn add_with_ordering(&self, value: isize, ordering: Ordering) {
96 let c = self.cells.safely_get(self.thread_id() & (self.cells.len() - 1));
97 c.value.fetch_add(value, ordering);
98 }
99
100 /// This will fetch the sum of the concurrent counter be iterating through each of the cells
101 /// and loading the values. Internally this uses `sum_with_ordering` with a `Relaxed` ordering.
102 ///
103 /// Due to the fact the cells are sharded and the concurrent nature of the library this sum
104 /// may be slightly inaccurate. For example if used in a concurrent map and using
105 /// ConcurrentCounter to track the length, depending on the ordering the length may be returned
106 /// as a negative value.
107 ///
108 /// # Examples
109 ///
110 /// ```rust
111 /// use fast_counter::ConcurrentCounter;
112 ///
113 /// let counter = ConcurrentCounter::new(10);
114 ///
115 /// counter.add(1);
116 ///
117 /// let sum = counter.sum();
118 ///
119 /// assert_eq!(sum, 1);
120 /// ```
121 #[inline]
122 pub fn sum(&self) -> isize {
123 self.sum_with_ordering(Ordering::Relaxed)
124 }
125
126 /// This will fetch the sum of the concurrent counter be iterating through each of the cells
127 /// and loading the values with the ordering defined by `ordering`.
128 ///
129 /// Due to the fact the cells are sharded and the concurrent nature of the library this sum
130 /// may be slightly inaccurate. For example if used in a concurrent map and using
131 /// ConcurrentCounter to track the length, depending on the ordering the length may be returned
132 /// as a negative value.
133 ///
134 /// # Examples
135 ///
136 /// ```rust
137 /// use std::sync::atomic::Ordering;
138 /// use fast_counter::ConcurrentCounter;
139 ///
140 /// let counter = ConcurrentCounter::new(10);
141 ///
142 /// counter.add(1);
143 ///
144 /// let sum = counter.sum_with_ordering(Ordering::SeqCst);
145 ///
146 /// assert_eq!(sum, 1);
147 /// ```
148 #[inline]
149 pub fn sum_with_ordering(&self, ordering: Ordering) -> isize {
150 self.cells.iter().map(|c| c.value.load(ordering)).sum()
151 }
152}
153
#[cfg(test)]
mod tests {
    use crate::ConcurrentCounter;

    #[test]
    fn basic_test() {
        let counter = ConcurrentCounter::new(1);
        counter.add(1);
        assert_eq!(counter.sum(), 1);
    }

    #[test]
    fn increment_multiple_times() {
        let counter = ConcurrentCounter::new(1);
        counter.add(1);
        counter.add(1);
        counter.add(1);
        assert_eq!(counter.sum(), 3);
    }

    #[test]
    fn negative_values_are_subtracted() {
        let counter = ConcurrentCounter::new(4);
        counter.add(5);
        counter.add(-7);
        assert_eq!(counter.sum(), -2);
    }

    #[test]
    fn explicit_orderings_agree_with_defaults() {
        use std::sync::atomic::Ordering;

        let counter = ConcurrentCounter::new(2);
        counter.add_with_ordering(3, Ordering::SeqCst);
        counter.add_with_ordering(-1, Ordering::Release);
        assert_eq!(counter.sum_with_ordering(Ordering::Acquire), 2);
        assert_eq!(counter.sum(), 2);
    }

    #[test]
    fn cell_count_is_rounded_up_to_a_power_of_two() {
        assert_eq!(
            format!("{:?}", ConcurrentCounter::new(0)),
            "ConcurrentCounter { sum: 0, cells: 1 }"
        );
        assert_eq!(
            format!("{:?}", ConcurrentCounter::new(3)),
            "ConcurrentCounter { sum: 0, cells: 4 }"
        );
        assert_eq!(
            format!("{:?}", ConcurrentCounter::new(8)),
            "ConcurrentCounter { sum: 0, cells: 8 }"
        );
    }

    #[test]
    fn two_threads_incrementing_concurrently() {
        // Spin up two threads that increment the counter concurrently.
        let counter = ConcurrentCounter::new(2);

        std::thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    counter.add(1);
                });
            }
        });

        assert_eq!(counter.sum(), 2);
    }

    #[test]
    fn two_threads_incrementing_multiple_times_concurrently() {
        const WRITE_COUNT: isize = 100_000;
        // Spin up two threads that each increment the counter WRITE_COUNT times.
        let counter = ConcurrentCounter::new(2);

        std::thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    for _ in 0..WRITE_COUNT {
                        counter.add(1);
                    }
                });
            }
        });

        assert_eq!(counter.sum(), 2 * WRITE_COUNT);
    }

    #[test]
    fn multiple_threads_incrementing_multiple_times_concurrently() {
        const WRITE_COUNT: isize = 1_000_000;
        const THREAD_COUNT: isize = 8;
        // Spin up THREAD_COUNT threads that each increment the counter WRITE_COUNT times.
        let counter = ConcurrentCounter::new(THREAD_COUNT as usize);

        std::thread::scope(|s| {
            for _ in 0..THREAD_COUNT {
                s.spawn(|| {
                    for _ in 0..WRITE_COUNT {
                        counter.add(1);
                    }
                });
            }
        });

        assert_eq!(counter.sum(), THREAD_COUNT * WRITE_COUNT);
    }

    #[test]
    fn debug_works_as_expected() {
        const WRITE_COUNT: isize = 1_000_000;
        const CELL_COUNT: usize = 8;
        // Single-threaded: every increment goes to the current thread's cell.
        let counter = ConcurrentCounter::new(CELL_COUNT);

        for _ in 0..WRITE_COUNT {
            counter.add(1);
        }

        assert_eq!(counter.sum(), WRITE_COUNT);

        assert_eq!(
            format!("Counter is: {counter:?}"),
            format!("Counter is: ConcurrentCounter {{ sum: {WRITE_COUNT}, cells: {CELL_COUNT} }}")
        );
    }
}
247}