//! # AtomicCounter
//!
//! Atomic (thread-safe) counters for Rust.
//!
//! This crate contains an [`AtomicCounter`](trait.AtomicCounter.html) trait
//! for counters that can safely be shared across threads.
//!
//! This crate provides two implementations:
//!
//! * [`RelaxedCounter`](struct.RelaxedCounter.html), which is suitable for
//! e.g. collecting metrics or generating IDs, but which does not provide
//! ["Sequential Consistency"](https://doc.rust-lang.org/nomicon/atomics.html#sequentially-consistent).
//! `RelaxedCounter` uses [`Relaxed`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.Relaxed)
//! memory ordering.
//!
//! * [`ConsistentCounter`](struct.ConsistentCounter.html), which provides the
//! same interface but is sequentially consistent. Use this counter if the
//! order of updates from multiple threads is important.
//! `ConsistentCounter` uses [`Sequentially Consistent`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.SeqCst)
//! memory ordering.
//!
//! Both implementations are lock-free. Both are a very thin layer over
//! [`AtomicUsize`](https://doc.rust-lang.org/std/sync/atomic/struct.AtomicUsize.html),
//! which is more powerful but might be harder to use correctly.
//!
//! ## Which counter to use
//!
//! * If you are just collecting metrics, `RelaxedCounter` is probably the right choice.
//!
//! * If you are generating IDs but don't make strong assumptions (like allocating
//! memory based on the ID count), `RelaxedCounter` is probably the right choice.
//!
//! * If you are generating multiple IDs where you maintain an ordering
//! invariant (e.g. ID `a` is always greater than ID `b`), you need "Sequential
//! Consistency" and thus need to use `ConsistentCounter`. The same is true
//! for all use cases where the _ordering_ of incrementing the counter is
//! important.
//!
//! ## No updates are lost - It's just about the ordering!
//!
//! Note that in both implementations, _no count is lost_ and all operations are atomic.
//! The difference is _only_ in how the order of operations is observed by different
//! threads.
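//!
//! A minimal sketch of sharing one counter across threads. It assumes the crate
//! is imported as `atomic_counter`, and the thread and iteration counts are
//! arbitrary illustration values. Every increment is counted, even though
//! `RelaxedCounter` makes no ordering guarantees:
//!
//! ```rust
//! use std::sync::Arc;
//! use std::thread;
//!
//! use atomic_counter::{AtomicCounter, RelaxedCounter};
//!
//! // Arbitrary illustration values.
//! const THREADS: usize = 4;
//! const ITERATIONS: usize = 1_000;
//!
//! let counter = Arc::new(RelaxedCounter::new(0));
//!
//! // Spawn workers that each increment the shared counter ITERATIONS times.
//! let handles: Vec<_> = (0..THREADS)
//!     .map(|_| {
//!         let counter = Arc::clone(&counter);
//!         thread::spawn(move || {
//!             for _ in 0..ITERATIONS {
//!                 counter.inc();
//!             }
//!         })
//!     })
//!     .collect();
//!
//! for handle in handles {
//!     handle.join().unwrap();
//! }
//!
//! // Once all threads are joined, no increment has been lost.
//! assert_eq!(THREADS * ITERATIONS, counter.get());
//! ```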
//!
//! ## Example
//!
//! Assume `a` is 5 and `b` is 4. You always want to maintain `a > b`.
//!
//! Thread 1 executes this code:
//!
//! ```rust,ignore
//! a.inc();
//! b.inc();
//! ```
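//!
//! Thread 2 reads the counts, `b` first and then `a`:
//!
//! ```rust,ignore
//! let b_local = b.get();
//! let a_local = a.get();
//! ```
//!
//! The read order matters. If thread 2 read `a` first, thread 1 could run
//! completely between the two reads, and thread 2 would then observe both
//! `a_local` and `b_local` as 5, regardless of which counter type is used.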
//!
//! What are the values of `a_local` and `b_local`? That depends on the order
//! in which threads 1 and 2 run:
//!
//! * `a_local` could still be 5 and `b_local` still 4 (e.g. if thread 2 ran before thread 1).
//! * `a_local` could be incremented to 6 while `b_local` is still 4 (e.g. if threads 1 and 2 ran in parallel).
//! * `a_local` could be incremented to 6 and `b_local` to 5 (e.g. if thread 2 ran after thread 1).
//! * Additionally, if the counters are `RelaxedCounter`s, thread 2 cannot assume
//! that it observes `a.inc()` and `b.inc()` in the order in which thread 1 performed them.
//! In this case, thread 2 can also observe `a_local` to be 5 (not incremented yet) but
//! `b_local` to be incremented to 5, _breaking the invariant_ `a > b`.
//! Note that if thread 2 (or any other thread) calls `get()` on the counters
//! again, at some point it will observe both values incremented.
//! No operations are lost. Only the _ordering_ of the operations
//! cannot be assumed if `Ordering` is `Relaxed`.
//!
//! So to maintain invariants such as `a > b` across multiple threads,
//! use `ConsistentCounter`.
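//!
//! A minimal sketch of the scenario above using two `ConsistentCounter`s. It
//! assumes the crate is imported as `atomic_counter`, and the iteration count
//! is an arbitrary illustration value. A writer thread increments `a` and then
//! `b`, while the main thread reads `b` and then `a` and checks the invariant:
//!
//! ```rust
//! use std::sync::Arc;
//! use std::thread;
//!
//! use atomic_counter::{AtomicCounter, ConsistentCounter};
//!
//! // Arbitrary illustration value.
//! const ITERATIONS: usize = 10_000;
//!
//! let a = Arc::new(ConsistentCounter::new(5));
//! let b = Arc::new(ConsistentCounter::new(4));
//!
//! // Writer: always increment `a` before `b`, so `a > b` holds at every step.
//! let writer = {
//!     let (a, b) = (Arc::clone(&a), Arc::clone(&b));
//!     thread::spawn(move || {
//!         for _ in 0..ITERATIONS {
//!             a.inc();
//!             b.inc();
//!         }
//!     })
//! };
//!
//! // Reader: read `b` first, then `a`. With sequentially consistent ordering,
//! // every increment of `b` seen here is preceded by an increment of `a`
//! // that is also visible, so this check cannot fail.
//! for _ in 0..ITERATIONS {
//!     let b_local = b.get();
//!     let a_local = a.get();
//!     assert!(a_local > b_local);
//! }
//!
//! writer.join().unwrap();
//! assert_eq!(5 + ITERATIONS, a.get());
//! assert_eq!(4 + ITERATIONS, b.get());
//! ```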

use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};

/// An atomic counter that can be shared across threads.
pub trait AtomicCounter: Send + Sync {
    /// Underlying primitive type that is being shared atomically.
    type PrimitiveType;

    /// Atomically increments the counter by one, returning the _previous_ value.
    fn inc(&self) -> Self::PrimitiveType;

    /// Atomically increments the counter by `amount`, returning the _previous_ value.
    fn add(&self, amount: Self::PrimitiveType) -> Self::PrimitiveType;

    /// Atomically gets the current value of the counter, without modifying the counter.
    fn get(&self) -> Self::PrimitiveType;

    /// Atomically returns the current value of the counter, while resetting the count to zero.
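    ///
    /// A minimal sketch, assuming the crate is imported as `atomic_counter`.
    /// `reset` hands back the count accumulated so far, and later increments
    /// start again from zero:
    ///
    /// ```rust
    /// use atomic_counter::{AtomicCounter, RelaxedCounter};
    ///
    /// let counter = RelaxedCounter::new(0);
    /// counter.add(3);
    ///
    /// // `reset` returns the value the counter held before resetting...
    /// assert_eq!(3, counter.reset());
    /// // ...and leaves the counter at zero.
    /// assert_eq!(0, counter.get());
    /// ```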
    fn reset(&self) -> Self::PrimitiveType;

    /// Consumes the atomic counter and returns the primitive value.
    ///
    /// This is safe because passing `self` by value guarantees that no other threads are concurrently accessing the atomic data.
    fn into_inner(self) -> Self::PrimitiveType;
}

/// Implementation of [`AtomicCounter`](trait.AtomicCounter.html) that uses
/// [`Relaxed`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.Relaxed)
/// memory ordering.
///
/// See [crate level documentation](index.html) for more details.
///
/// Note that the counter wraps around if it is incremented beyond `usize::MAX`.
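///
/// A minimal sketch of the wrap-around behavior, assuming the crate is
/// imported as `atomic_counter`:
///
/// ```rust
/// use atomic_counter::{AtomicCounter, RelaxedCounter};
///
/// let counter = RelaxedCounter::new(usize::MAX);
/// // `inc` returns the previous value, and the counter wraps to zero.
/// assert_eq!(usize::MAX, counter.inc());
/// assert_eq!(0, counter.get());
/// ```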
#[derive(Debug, Default)]
pub struct RelaxedCounter(AtomicUsize);

impl RelaxedCounter {
    /// Creates a new counter with the given `initial_count`.
    pub fn new(initial_count: usize) -> RelaxedCounter {
        RelaxedCounter(AtomicUsize::new(initial_count))
    }
}

impl AtomicCounter for RelaxedCounter {
    type PrimitiveType = usize;

    fn inc(&self) -> usize {
        self.add(1)
    }

    fn add(&self, amount: usize) -> usize {
        self.0.fetch_add(amount, Relaxed)
    }

    fn get(&self) -> usize {
        self.0.load(Relaxed)
    }

    fn reset(&self) -> usize {
        self.0.swap(0, Relaxed)
    }

    fn into_inner(self) -> usize {
        self.0.into_inner()
    }
}

/// Implementation of [`AtomicCounter`](trait.AtomicCounter.html) that uses
/// [`Sequentially Consistent`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.SeqCst)
/// memory ordering.
///
/// See [crate level documentation](index.html) for more details.
///
/// Note that the counter wraps around if it is incremented beyond `usize::MAX`.
#[derive(Debug, Default)]
pub struct ConsistentCounter(AtomicUsize);

impl ConsistentCounter {
    /// Creates a new counter with the given `initial_count`.
    pub fn new(initial_count: usize) -> ConsistentCounter {
        ConsistentCounter(AtomicUsize::new(initial_count))
    }
}

impl AtomicCounter for ConsistentCounter {
    type PrimitiveType = usize;

    fn inc(&self) -> usize {
        self.add(1)
    }

    fn add(&self, amount: usize) -> usize {
        self.0.fetch_add(amount, SeqCst)
    }

    fn get(&self) -> usize {
        self.0.load(SeqCst)
    }

    fn reset(&self) -> usize {
        self.0.swap(0, SeqCst)
    }

    fn into_inner(self) -> usize {
        self.0.into_inner()
    }
}

#[cfg(test)]
mod tests {
    use std::fmt::Debug;
    use std::thread;
    use std::sync::Arc;
    use std::ops::Deref;

    use super::*;

    const NUM_THREADS: usize = 29;
    const NUM_ITERATIONS: usize = 7_000_000;

    fn test_simple_with<Counter>(counter: Counter)
        where Counter: AtomicCounter<PrimitiveType=usize>
    {
        counter.reset();
        assert_eq!(0, counter.add(5));
        assert_eq!(5, counter.add(3));
        assert_eq!(8, counter.inc());
        assert_eq!(9, counter.inc());
        assert_eq!(10, counter.get());
        assert_eq!(10, counter.get());
    }

    #[test]
    fn test_simple_relaxed() {
        test_simple_with(RelaxedCounter::new(0))
    }

    #[test]
    fn test_simple_consistent() {
        test_simple_with(ConsistentCounter::new(0))
    }

    fn test_inc_with<Counter>(counter: Arc<Counter>)
        where Counter: AtomicCounter<PrimitiveType=usize> + 'static + Debug
    {
        let mut join_handles = Vec::new();
        println!("test_inc: Spawning {} threads, each with {} iterations...",
                 NUM_THREADS,
                 NUM_ITERATIONS);
        for _ in 0..NUM_THREADS {
            let counter_ref = counter.clone();
            join_handles.push(thread::spawn(move || {
                // Make sure we're not going through the Arc on each iteration.
                let counter: &Counter = counter_ref.deref();
                for _ in 0..NUM_ITERATIONS {
                    counter.inc();
                }
            }));
        }
        for handle in join_handles {
            handle.join().unwrap();
        }
        let count = Arc::try_unwrap(counter).unwrap().into_inner();
        println!("test_inc: Got count: {}", count);
        assert_eq!(NUM_THREADS * NUM_ITERATIONS, count);
    }

    #[test]
    fn test_inc_relaxed() {
        test_inc_with(Arc::new(RelaxedCounter::new(0)));
    }

    #[test]
    fn test_inc_consistent() {
        test_inc_with(Arc::new(ConsistentCounter::new(0)));
    }

    fn test_add_with<Counter>(counter: Arc<Counter>)
        where Counter: AtomicCounter<PrimitiveType=usize> + 'static + Debug
    {
        let mut join_handles = Vec::new();
        println!("test_add: Spawning {} threads, each with {} iterations...",
                 NUM_THREADS,
                 NUM_ITERATIONS);
        let mut expected_count = 0;
        for to_add in 0..NUM_THREADS {
            let counter_ref = counter.clone();
            expected_count += to_add * NUM_ITERATIONS;
            join_handles.push(thread::spawn(move || {
                // Make sure we're not going through the Arc on each iteration.
                let counter: &Counter = counter_ref.deref();
                for _ in 0..NUM_ITERATIONS {
                    counter.add(to_add);
                }
            }));
        }
        for handle in join_handles {
            handle.join().unwrap();
        }
        let count = Arc::try_unwrap(counter).unwrap().into_inner();
        println!("test_add: Expected count: {}, got count: {}",
                 expected_count,
                 count);
        assert_eq!(expected_count, count);
    }

    #[test]
    fn test_add_relaxed() {
        test_add_with(Arc::new(RelaxedCounter::new(0)));
    }

    #[test]
    fn test_add_consistent() {
        test_add_with(Arc::new(ConsistentCounter::new(0)));
    }

    fn test_reset_with<Counter>(counter: Arc<Counter>)
        where Counter: AtomicCounter<PrimitiveType=usize> + 'static + Debug
    {
        let mut join_handles = Vec::new();
        println!("test_add_reset: Spawning {} threads, each with {} iterations...",
                 NUM_THREADS,
                 NUM_ITERATIONS);
        let mut expected_count = 0;
        for to_add in 0..NUM_THREADS {
            expected_count += to_add * NUM_ITERATIONS;
        }

        // Set up a thread that calls `reset()` all the time.
        let counter_ref = counter.clone();
        let reset_handle = thread::spawn(move || {
            // Usually, you would check for some better termination condition.
            // I don't want to pollute my test with thread synchronization
            // operations outside of AtomicCounter, hence this approach.
            let mut total_count = 0;
            let counter: &Counter = counter_ref.deref();
            while total_count < expected_count {
                total_count += counter.reset();
            }
            // Ok, now we have the total_count, but this could just be lucky.
            // Better do some more resets to be sure... ;)
            for _ in 0..NUM_ITERATIONS {
                total_count += counter.reset();
            }
            total_count
        });

        for to_add in 0..NUM_THREADS {
            let counter_ref = counter.clone();

            join_handles.push(thread::spawn(move || {
                // Make sure we're not going through the Arc on each iteration.
                let counter: &Counter = counter_ref.deref();
                for _ in 0..NUM_ITERATIONS {
                    counter.add(to_add);
                }
            }));
        }
        for handle in join_handles {
            handle.join().unwrap();
        }
        let actual_count = reset_handle.join().unwrap();
        println!("test_add_reset: Expected count: {}, got count: {}",
                 expected_count,
                 actual_count);
        assert_eq!(expected_count, actual_count);
    }

    #[test]
    fn test_reset_consistent() {
        test_reset_with(Arc::new(ConsistentCounter::new(0)));
    }

    #[test]
    fn test_reset_relaxed() {
        test_reset_with(Arc::new(RelaxedCounter::new(0)));
    }
}