atomic_counter/
lib.rs

//! # AtomicCounter
//!
//! Atomic (thread-safe) counters for Rust.
//!
//! This crate contains an [`AtomicCounter`](trait.AtomicCounter.html) trait
//! that can safely be shared across threads.
//!
//! This crate provides two implementations:
//!
//! * [`RelaxedCounter`](struct.RelaxedCounter.html) which is suitable for
//!     e.g. collecting metrics or generating IDs, but which does not provide
//!     ["Sequential Consistency"](https://doc.rust-lang.org/nomicon/atomics.html#sequentially-consistent).
//!     `RelaxedCounter` uses [`Relaxed`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.Relaxed)
//!     memory ordering.
//!
//! * [`ConsistentCounter`](struct.ConsistentCounter.html) which provides the
//!     same interface but is sequentially consistent. Use this counter if the
//!     order of updates from multiple threads is important.
//!     `ConsistentCounter` uses [`Sequentially Consistent`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.SeqCst)
//!     memory ordering.
//!
//! Both implementations are lock-free. Both are a very thin layer over
//! [`AtomicUsize`](https://doc.rust-lang.org/std/sync/atomic/struct.AtomicUsize.html),
//! which is more powerful but might be harder to use correctly.
//!
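//! For example, here is a minimal sketch of sharing a counter across several
//! threads (assuming the crate is imported as `atomic_counter`):
//!
//! ```rust
//! use std::sync::Arc;
//! use std::thread;
//!
//! use atomic_counter::{AtomicCounter, RelaxedCounter};
//!
//! let counter = Arc::new(RelaxedCounter::new(0));
//! let handles: Vec<_> = (0..4)
//!     .map(|_| {
//!         let counter = Arc::clone(&counter);
//!         thread::spawn(move || {
//!             for _ in 0..1000 {
//!                 counter.inc();
//!             }
//!         })
//!     })
//!     .collect();
//! for handle in handles {
//!     handle.join().unwrap();
//! }
//! // All 4 * 1000 increments are observed; none are lost.
//! assert_eq!(4000, counter.get());
//! ```
//!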
//! ## Which counter to use
//!
//! * If you are just collecting metrics, `RelaxedCounter` is probably the right choice.
//!
//! * If you are generating IDs, but don't make strong assumptions (like allocating
//!     memory based on the ID count), `RelaxedCounter` is probably the right choice.
//!
//! * If you are generating multiple IDs where you maintain an ordering
//!     invariant (e.g. ID `a` is always greater than ID `b`), you need "Sequential
//!     Consistency" and thus need to use `ConsistentCounter`. The same is true
//!     for all use cases where the _ordering_ of increments to the counter is
//!     important.
//!
//! ## No updates are lost - It's just about the ordering!
//!
//! Note that in both implementations, _no count is lost_ and all operations are atomic.
//! The difference is _only_ in how the order of operations is observed by different
//! threads.
//!
//! ## Example
//! Assume `a` is 5 and `b` is 4. You always want to maintain `a > b`.
//!
//! Thread 1 executes this code:
//!
//! ```rust,ignore
//! a.inc();
//! b.inc();
//! ```
//!
//! Thread 2 gets the counts, reading `b` _before_ `a` (if it read `a` first,
//! both increments could land between the two reads, and even sequentially
//! consistent counters could observe `a_local == b_local`):
//!
//! ```rust,ignore
//! let b_local = b.get();
//! let a_local = a.get();
//! ```
//!
//! What are the values of `a_local` and `b_local`? That depends on the order
//! in which threads 1 and 2 have run:
//!
//! * `a_local` could still be 5 and `b_local` still be 4 (e.g. if thread 2 ran before thread 1).
//! * `a_local` could be incremented to 6 while `b_local` is still 4 (e.g. if threads 1 and 2 ran in parallel).
//! * `a_local` could be incremented to 6 and `b_local` incremented to 5 (e.g. if thread 2 ran after thread 1).
//! * Additionally, if at least one counter is a `RelaxedCounter`, we cannot make
//!     any assumption about the order in which `a.inc()` and `b.inc()` become
//!     visible. Thus, in this case thread 2 can also observe `a_local` to be 5
//!     (not incremented yet) but `b_local` to be incremented to 5, _breaking
//!     the invariant_ `a > b`. Note that if thread 2 (or any other thread)
//!     `get`s the counts again, at some point it will observe both values to
//!     be incremented. No operations are lost. It is only the _ordering_ of
//!     the operations that cannot be assumed if `Ordering` is `Relaxed`.
//!
//! So in order to maintain invariants such as `a > b` across multiple threads,
//! use `ConsistentCounter`.
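//!
//! As a quick sketch of the scenario above (the reader again checks `b` before
//! `a`), the invariant holds with `ConsistentCounter` no matter how the
//! threads interleave:
//!
//! ```rust
//! use std::sync::Arc;
//! use std::thread;
//!
//! use atomic_counter::{AtomicCounter, ConsistentCounter};
//!
//! let a = Arc::new(ConsistentCounter::new(5));
//! let b = Arc::new(ConsistentCounter::new(4));
//!
//! let (a_writer, b_writer) = (Arc::clone(&a), Arc::clone(&b));
//! let writer = thread::spawn(move || {
//!     for _ in 0..1000 {
//!         a_writer.inc();
//!         b_writer.inc();
//!     }
//! });
//!
//! // With SeqCst ordering, observing an increment of `b` implies that the
//! // matching earlier increment of `a` is also visible.
//! for _ in 0..1000 {
//!     let b_local = b.get();
//!     let a_local = a.get();
//!     assert!(a_local > b_local);
//! }
//! writer.join().unwrap();
//! ```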

use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};

/// Provides an atomic counter trait that can be shared across threads.
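///
/// A quick sketch using the [`RelaxedCounter`](struct.RelaxedCounter.html)
/// implementation; note that `inc`, `add` and `reset` all return the
/// _previous_ value:
///
/// ```rust
/// use atomic_counter::{AtomicCounter, RelaxedCounter};
///
/// let counter = RelaxedCounter::new(0);
/// assert_eq!(0, counter.inc());    // 0 -> 1, returns 0
/// assert_eq!(1, counter.add(41));  // 1 -> 42, returns 1
/// assert_eq!(42, counter.get());   // read-only
/// assert_eq!(42, counter.reset()); // 42 -> 0, returns 42
/// assert_eq!(0, counter.into_inner());
/// ```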
pub trait AtomicCounter: Send + Sync {
    /// Underlying primitive type that is being shared atomically.
    type PrimitiveType;

    /// Atomically increments the counter by one, returning the _previous_ value.
    fn inc(&self) -> Self::PrimitiveType;

    /// Atomically increments the counter by `amount`, returning the _previous_ value.
    fn add(&self, amount: Self::PrimitiveType) -> Self::PrimitiveType;

    /// Atomically gets the current value of the counter, without modifying the counter.
    fn get(&self) -> Self::PrimitiveType;

    /// Atomically returns the current value of the counter, while resetting the count to zero.
    fn reset(&self) -> Self::PrimitiveType;

    /// Consumes the atomic counter and returns the underlying primitive value.
    ///
    /// This is safe because passing `self` by value guarantees that no other threads are concurrently accessing the atomic data.
    fn into_inner(self) -> Self::PrimitiveType;
}

/// Implementation of [`AtomicCounter`](trait.AtomicCounter.html) that uses
/// [`Relaxed`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.Relaxed)
/// memory ordering.
///
/// See the [crate level documentation](index.html) for more details.
///
/// Note that all operations wrap if the counter is incremented beyond `usize::max_value()`.
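///
/// For example, a small sketch of the wrapping behavior:
///
/// ```rust
/// use atomic_counter::{AtomicCounter, RelaxedCounter};
///
/// let counter = RelaxedCounter::new(usize::max_value());
/// counter.inc();
/// assert_eq!(0, counter.get()); // wrapped around to zero
/// ```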
#[derive(Debug, Default)]
pub struct RelaxedCounter(AtomicUsize);

impl RelaxedCounter {
    /// Creates a new counter with the given initial count.
    pub fn new(initial_count: usize) -> RelaxedCounter {
        RelaxedCounter(AtomicUsize::new(initial_count))
    }
}

impl AtomicCounter for RelaxedCounter {
    type PrimitiveType = usize;

    fn inc(&self) -> usize {
        self.add(1)
    }

    fn add(&self, amount: usize) -> usize {
        self.0.fetch_add(amount, Relaxed)
    }

    fn get(&self) -> usize {
        self.0.load(Relaxed)
    }

    fn reset(&self) -> usize {
        self.0.swap(0, Relaxed)
    }

    fn into_inner(self) -> usize {
        self.0.into_inner()
    }
}

/// Implementation of [`AtomicCounter`](trait.AtomicCounter.html) that uses
/// [`Sequentially Consistent`](https://doc.rust-lang.org/std/sync/atomic/enum.Ordering.html#variant.SeqCst)
/// memory ordering.
///
/// See the [crate level documentation](index.html) for more details.
///
/// Note that all operations wrap if the counter is incremented beyond `usize::max_value()`.
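///
/// For example, a small sketch of draining the counter with `reset`, e.g. for
/// periodic metrics reporting:
///
/// ```rust
/// use atomic_counter::{AtomicCounter, ConsistentCounter};
///
/// let requests = ConsistentCounter::new(0);
/// requests.add(3);
///
/// // `reset` atomically swaps in zero and returns the old count, so no
/// // update can be counted twice or lost between reports.
/// assert_eq!(3, requests.reset());
/// assert_eq!(0, requests.get());
/// ```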
#[derive(Debug, Default)]
pub struct ConsistentCounter(AtomicUsize);

impl ConsistentCounter {
    /// Creates a new counter with the given initial count.
    pub fn new(initial_count: usize) -> ConsistentCounter {
        ConsistentCounter(AtomicUsize::new(initial_count))
    }
}

impl AtomicCounter for ConsistentCounter {
    type PrimitiveType = usize;

    fn inc(&self) -> usize {
        self.add(1)
    }

    fn add(&self, amount: usize) -> usize {
        self.0.fetch_add(amount, SeqCst)
    }

    fn get(&self) -> usize {
        self.0.load(SeqCst)
    }

    fn reset(&self) -> usize {
        self.0.swap(0, SeqCst)
    }

    fn into_inner(self) -> usize {
        self.0.into_inner()
    }
}

#[cfg(test)]
mod tests {

    use std::fmt::Debug;
    use std::ops::Deref;
    use std::sync::Arc;
    use std::thread;

    use super::*;

    const NUM_THREADS: usize = 29;
    const NUM_ITERATIONS: usize = 7_000_000;

    fn test_simple_with<Counter>(counter: Counter)
        where Counter: AtomicCounter<PrimitiveType = usize>
    {
        counter.reset();
        assert_eq!(0, counter.add(5));
        assert_eq!(5, counter.add(3));
        assert_eq!(8, counter.inc());
        assert_eq!(9, counter.inc());
        assert_eq!(10, counter.get());
        assert_eq!(10, counter.get());
    }

    #[test]
    fn test_simple_relaxed() {
        test_simple_with(RelaxedCounter::new(0))
    }

    #[test]
    fn test_simple_consistent() {
        test_simple_with(ConsistentCounter::new(0))
    }

    fn test_inc_with<Counter>(counter: Arc<Counter>)
        where Counter: AtomicCounter<PrimitiveType = usize> + 'static + Debug
    {
        let mut join_handles = Vec::new();
        println!("test_inc: Spawning {} threads, each with {} iterations...",
                 NUM_THREADS,
                 NUM_ITERATIONS);
        for _ in 0..NUM_THREADS {
            let counter_ref = counter.clone();
            join_handles.push(thread::spawn(move || {
                // Make sure we're not going through the Arc on each iteration.
                let counter: &Counter = counter_ref.deref();
                for _ in 0..NUM_ITERATIONS {
                    counter.inc();
                }
            }));
        }
        for handle in join_handles {
            handle.join().unwrap();
        }
        let count = Arc::try_unwrap(counter).unwrap().into_inner();
        println!("test_inc: Got count: {}", count);
        assert_eq!(NUM_THREADS * NUM_ITERATIONS, count);
    }

    #[test]
    fn test_inc_relaxed() {
        test_inc_with(Arc::new(RelaxedCounter::new(0)));
    }

    #[test]
    fn test_inc_consistent() {
        test_inc_with(Arc::new(ConsistentCounter::new(0)));
    }

    fn test_add_with<Counter>(counter: Arc<Counter>)
        where Counter: AtomicCounter<PrimitiveType = usize> + 'static + Debug
    {
        let mut join_handles = Vec::new();
        println!("test_add: Spawning {} threads, each with {} iterations...",
                 NUM_THREADS,
                 NUM_ITERATIONS);
        let mut expected_count = 0;
        for to_add in 0..NUM_THREADS {
            let counter_ref = counter.clone();
            expected_count += to_add * NUM_ITERATIONS;
            join_handles.push(thread::spawn(move || {
                // Make sure we're not going through the Arc on each iteration.
                let counter: &Counter = counter_ref.deref();
                for _ in 0..NUM_ITERATIONS {
                    counter.add(to_add);
                }
            }));
        }
        for handle in join_handles {
            handle.join().unwrap();
        }
        let count = Arc::try_unwrap(counter).unwrap().into_inner();
        println!("test_add: Expected count: {}, got count: {}",
                 expected_count,
                 count);
        assert_eq!(expected_count, count);
    }

    #[test]
    fn test_add_relaxed() {
        test_add_with(Arc::new(RelaxedCounter::new(0)));
    }

    #[test]
    fn test_add_consistent() {
        test_add_with(Arc::new(ConsistentCounter::new(0)));
    }

    fn test_reset_with<Counter>(counter: Arc<Counter>)
        where Counter: AtomicCounter<PrimitiveType = usize> + 'static + Debug
    {
        let mut join_handles = Vec::new();
        println!("test_add_reset: Spawning {} threads, each with {} iterations...",
                 NUM_THREADS,
                 NUM_ITERATIONS);
        let mut expected_count = 0;
        for to_add in 0..NUM_THREADS {
            expected_count += to_add * NUM_ITERATIONS;
        }

        // Set up a thread that `reset()`s all the time.
        let counter_ref = counter.clone();
        let reset_handle = thread::spawn(move || {
            // Usually, you would check for some better termination condition.
            // I don't want to pollute my test with thread synchronization
            // operations outside of AtomicCounter, hence this approach.
            let mut total_count = 0;
            let counter: &Counter = counter_ref.deref();
            while total_count < expected_count {
                total_count += counter.reset();
            }
            // Ok, now we got the total_count, but this could just be lucky.
            // Better do some more resets to be sure... ;)
            for _ in 0..NUM_ITERATIONS {
                total_count += counter.reset();
            }
            total_count
        });

        for to_add in 0..NUM_THREADS {
            let counter_ref = counter.clone();

            join_handles.push(thread::spawn(move || {
                // Make sure we're not going through the Arc on each iteration.
                let counter: &Counter = counter_ref.deref();
                for _ in 0..NUM_ITERATIONS {
                    counter.add(to_add);
                }
            }));
        }
        for handle in join_handles {
            handle.join().unwrap();
        }
        let actual_count = reset_handle.join().unwrap();
        println!("test_add_reset: Expected count: {}, got count: {}",
                 expected_count,
                 actual_count);
        assert_eq!(expected_count, actual_count);
    }

    #[test]
    fn test_reset_consistent() {
        test_reset_with(Arc::new(ConsistentCounter::new(0)));
    }

    #[test]
    fn test_reset_relaxed() {
        test_reset_with(Arc::new(RelaxedCounter::new(0)));
    }
}