use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
/// Common interface for counters that can be shared between threads.
///
/// Every implementor must be `Send + Sync`, and all mutating operations take
/// `&self`, so a single counter can be updated from several threads at once
/// (for example through an `Arc`).
pub trait AtomicCounter: Send + Sync {
/// The primitive value type the counter stores and hands back to callers.
/// Both implementations in this file use `usize`.
type PrimitiveType;
/// Increments the counter by one.
///
/// The implementations in this file return the value the counter held
/// *before* the increment.
fn inc(&self) -> Self::PrimitiveType;
/// Adds `amount` to the counter.
///
/// The implementations in this file return the value the counter held
/// *before* the addition.
fn add(&self, amount: Self::PrimitiveType) -> Self::PrimitiveType;
/// Returns the value currently held by the counter.
fn get(&self) -> Self::PrimitiveType;
/// Resets the counter.
///
/// The implementations in this file store zero and return the value the
/// counter held just before it was reset.
fn reset(&self) -> Self::PrimitiveType;
/// Consumes the counter and returns the value it held.
fn into_inner(self) -> Self::PrimitiveType;
}
/// A counter wrapping an `AtomicUsize`.
///
/// Its `AtomicCounter` implementation in this file performs every access with
/// `Ordering::Relaxed`. The derived `Default` yields a counter starting at zero.
#[derive(Debug, Default)]
pub struct RelaxedCounter(AtomicUsize);

impl RelaxedCounter {
    /// Creates a counter whose starting value is `initial_count`.
    pub fn new(initial_count: usize) -> Self {
        let inner = AtomicUsize::new(initial_count);
        RelaxedCounter(inner)
    }
}
impl AtomicCounter for RelaxedCounter {
type PrimitiveType = usize;
fn inc(&self) -> usize {
self.add(1)
}
fn add(&self, amount: usize) -> usize {
self.0.fetch_add(amount, Relaxed)
}
fn get(&self) -> usize {
self.0.load(Relaxed)
}
fn reset(&self) -> usize {
self.0.swap(0, Relaxed)
}
fn into_inner(self) -> usize {
self.0.into_inner()
}
}
/// A counter wrapping an `AtomicUsize`.
///
/// Its `AtomicCounter` implementation in this file performs every access with
/// `Ordering::SeqCst`. The derived `Default` yields a counter starting at zero.
#[derive(Debug, Default)]
pub struct ConsistentCounter(AtomicUsize);

impl ConsistentCounter {
    /// Creates a counter whose starting value is `initial_count`.
    pub fn new(initial_count: usize) -> Self {
        let inner = AtomicUsize::new(initial_count);
        ConsistentCounter(inner)
    }
}
impl AtomicCounter for ConsistentCounter {
type PrimitiveType = usize;
fn inc(&self) -> usize {
self.add(1)
}
fn add(&self, amount: usize) -> usize {
self.0.fetch_add(amount, SeqCst)
}
fn get(&self) -> usize {
self.0.load(SeqCst)
}
fn reset(&self) -> usize {
self.0.swap(0, SeqCst)
}
fn into_inner(self) -> usize {
self.0.into_inner()
}
}
#[cfg(test)]
mod tests {
use std::fmt::Debug;
use std::thread;
use std::sync::Arc;
use std::ops::Deref;
use super::*;
/// Number of worker threads spawned by each multi-threaded test.
const NUM_THREADS: usize = 29;
/// Number of counter operations each worker thread performs (the reset test
/// also uses it for its trailing run of extra `reset` calls).
const NUM_ITERATIONS: usize = 7_000_000;
/// Single-threaded sanity check that exercises every `AtomicCounter` method.
///
/// The counter is reset first, so the caller may pass one with any starting
/// value. Besides `add`, `inc` and `get`, this also pins down the value
/// returned by `reset`, the state left behind by `reset`, and `into_inner`.
fn test_simple_with<Counter>(counter: Counter)
    where Counter: AtomicCounter<PrimitiveType=usize>
{
    counter.reset();

    // `add` and `inc` report the value held *before* the update.
    assert_eq!(0, counter.add(5));
    assert_eq!(5, counter.add(3));
    assert_eq!(8, counter.inc());
    assert_eq!(9, counter.inc());

    // Reading twice proves that `get` does not modify the counter.
    assert_eq!(10, counter.get());
    assert_eq!(10, counter.get());

    // `reset` hands back the accumulated value and leaves zero behind.
    assert_eq!(10, counter.reset());
    assert_eq!(0, counter.get());

    // `into_inner` yields whatever the counter holds when it is consumed.
    assert_eq!(0, counter.add(7));
    assert_eq!(7, counter.into_inner());
}
/// Runs the single-threaded checks against a `RelaxedCounter`.
#[test]
fn test_simple_relaxed() {
    let counter = RelaxedCounter::new(0);
    test_simple_with(counter)
}
/// Runs the single-threaded checks against a `ConsistentCounter`.
#[test]
fn test_simple_consistent() {
    let counter = ConsistentCounter::new(0);
    test_simple_with(counter)
}
/// Calls `inc` concurrently from `NUM_THREADS` threads, `NUM_ITERATIONS` times
/// each, and checks that no increment was lost.
///
/// `counter` must not have any other outstanding `Arc` clones: once the
/// workers are joined it is unwrapped to read the final value.
fn test_inc_with<Counter>(counter: Arc<Counter>)
    where Counter: AtomicCounter<PrimitiveType=usize> + 'static + Debug
{
    println!("test_inc: Spawning {} threads, each with {} iterations...",
             NUM_THREADS,
             NUM_ITERATIONS);

    // Collect eagerly so every worker is running before the first join.
    let workers: Vec<_> = (0..NUM_THREADS)
        .map(|_| {
            let shared = counter.clone();
            thread::spawn(move || {
                let target: &Counter = shared.deref();
                for _ in 0..NUM_ITERATIONS {
                    target.inc();
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    // All worker clones are gone by now, so this is the only remaining handle.
    let total = Arc::try_unwrap(counter).unwrap().into_inner();
    println!("test_inc: Got count: {}", total);
    assert_eq!(NUM_THREADS * NUM_ITERATIONS, total);
}
/// Runs the concurrent `inc` test against a `RelaxedCounter`.
#[test]
fn test_inc_relaxed() {
    let counter = Arc::new(RelaxedCounter::new(0));
    test_inc_with(counter);
}
/// Runs the concurrent `inc` test against a `ConsistentCounter`.
#[test]
fn test_inc_consistent() {
    let counter = Arc::new(ConsistentCounter::new(0));
    test_inc_with(counter);
}
/// Calls `add` concurrently from `NUM_THREADS` threads and checks the total.
///
/// Worker number `k` (starting at zero) adds `k` to the counter
/// `NUM_ITERATIONS` times, so the final value must equal the sum of
/// `k * NUM_ITERATIONS` over all workers.
fn test_add_with<Counter>(counter: Arc<Counter>)
    where Counter: AtomicCounter<PrimitiveType=usize> + 'static + Debug
{
    println!("test_add: Spawning {} threads, each with {} iterations...",
             NUM_THREADS,
             NUM_ITERATIONS);

    let expected_total: usize = (0..NUM_THREADS)
        .map(|step| step * NUM_ITERATIONS)
        .sum();

    // Collect eagerly so every worker is running before the first join.
    let workers: Vec<_> = (0..NUM_THREADS)
        .map(|step| {
            let shared = counter.clone();
            thread::spawn(move || {
                let target: &Counter = shared.deref();
                for _ in 0..NUM_ITERATIONS {
                    target.add(step);
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    // All worker clones are gone by now, so this is the only remaining handle.
    let total = Arc::try_unwrap(counter).unwrap().into_inner();
    println!("test_add: Expected count: {}, got count: {}",
             expected_total,
             total);
    assert_eq!(expected_total, total);
}
/// Runs the concurrent `add` test against a `RelaxedCounter`.
#[test]
fn test_add_relaxed() {
    let counter = Arc::new(RelaxedCounter::new(0));
    test_add_with(counter);
}
/// Runs the concurrent `add` test against a `ConsistentCounter`.
#[test]
fn test_add_consistent() {
    let counter = Arc::new(ConsistentCounter::new(0));
    test_add_with(counter);
}
/// Checks that concurrent `reset` calls never lose or duplicate counts.
///
/// One thread repeatedly resets the counter and sums the values it gets back,
/// while `NUM_THREADS` workers add to it (worker `k` adds `k`,
/// `NUM_ITERATIONS` times). The values drained by the resetting thread must
/// add up to exactly what the workers put in.
fn test_reset_with<Counter>(counter: Arc<Counter>)
    where Counter: AtomicCounter<PrimitiveType=usize> + 'static + Debug
{
    println!("test_add_reset: Spawning {} threads, each with {} iterations...",
             NUM_THREADS,
             NUM_ITERATIONS);

    let expected_count: usize = (0..NUM_THREADS)
        .map(|step| step * NUM_ITERATIONS)
        .sum();

    // The resetting thread is started before any worker exists.
    let resetter = {
        let shared = counter.clone();
        thread::spawn(move || {
            let target: &Counter = shared.deref();
            let mut drained = 0;
            // Keep draining until everything the workers will add has been seen.
            while drained < expected_count {
                drained += target.reset();
            }
            // Extra resets: anything drained here would push the sum past the
            // expected total and fail the final assertion.
            drained += (0..NUM_ITERATIONS)
                .map(|_| target.reset())
                .sum::<usize>();
            drained
        })
    };

    // Collect eagerly so every worker is running before the first join.
    let workers: Vec<_> = (0..NUM_THREADS)
        .map(|step| {
            let shared = counter.clone();
            thread::spawn(move || {
                let target: &Counter = shared.deref();
                for _ in 0..NUM_ITERATIONS {
                    target.add(step);
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let drained_total = resetter.join().unwrap();
    println!("test_add_reset: Expected count: {}, got count: {}",
             expected_count,
             drained_total);
    assert_eq!(expected_count, drained_total);
}
/// Runs the concurrent `reset` test against a `ConsistentCounter`.
#[test]
fn test_reset_consistent() {
    let counter = Arc::new(ConsistentCounter::new(0));
    test_reset_with(counter);
}
/// Runs the concurrent `reset` test against a `RelaxedCounter`.
#[test]
fn test_reset_relaxed() {
    let counter = Arc::new(RelaxedCounter::new(0));
    test_reset_with(counter);
}
}