use crate::thread_id::thread_id;
use crossbeam_utils::CachePadded;
use std::fmt;
use std::sync::atomic::{AtomicIsize, Ordering};
/// Builds a single counter cell: an `AtomicIsize` starting at zero, wrapped in
/// `CachePadded` so that neighbouring cells in a `Vec` do not share a cache line.
fn make_padded_counter() -> CachePadded<AtomicIsize> {
    let zero = AtomicIsize::default();
    CachePadded::new(zero)
}
/// A signed counter striped across several cache-padded atomic cells.
///
/// Each update goes to one cell chosen from the calling thread's id. Reads add
/// up every cell. The intent (presumably) is to reduce contention compared
/// with a single shared atomic.
pub struct Counter {
/// The per-stripe cells. `Counter::new` rounds the requested length up to a
/// power of two so a cell can be selected by masking the thread id.
cells: Vec<CachePadded<AtomicIsize>>,
}
impl fmt::Debug for Counter {
    /// Renders the counter as `Counter { sum: <total>, cells: <cell count> }`.
    ///
    /// The total is read with `Counter::sum`, i.e. with relaxed loads.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.sum();
        let cell_count = self.cells.len();
        let mut builder = f.debug_struct("Counter");
        builder.field("sum", &total);
        builder.field("cells", &cell_count);
        builder.finish()
    }
}
impl Counter {
    /// Creates a counter striped across `count` cache-padded cells.
    ///
    /// `count` is rounded up to the next power of two, so a cell index can be
    /// derived from the thread id with a bit mask. A `count` of `0` yields a
    /// single cell.
    ///
    /// # Panics
    ///
    /// Panics if rounding `count` up to a power of two overflows `usize`.
    /// `usize::next_power_of_two` would silently wrap to `0` in release builds.
    /// That would leave the counter with no cells and make the unchecked
    /// indexing in `add_with_ordering` go out of bounds.
    #[inline]
    pub fn new(count: usize) -> Self {
        let count = count
            .checked_next_power_of_two()
            .expect("Counter::new: cell count overflows usize when rounded up to a power of two");
        Self {
            cells: (0..count).map(|_| make_padded_counter()).collect(),
        }
    }

    /// Adds `value` (which may be negative) to the counter using
    /// `Ordering::Relaxed`.
    #[inline]
    pub fn add(&self, value: isize) {
        self.add_with_ordering(value, Ordering::Relaxed)
    }

    /// Adds `1` to the counter using `Ordering::Relaxed`.
    #[inline]
    pub fn inc(&self) {
        self.add(1)
    }

    /// Adds `value` to the calling thread's cell using the given memory
    /// `ordering`.
    ///
    /// The cell update wraps on overflow (the behaviour of `fetch_add`).
    #[inline]
    pub fn add_with_ordering(&self, value: isize, ordering: Ordering) {
        // `new` guarantees `cells.len()` is a non-zero power of two, so
        // `len - 1` is an all-ones mask and the masked id is always in range.
        let idx = thread_id() & (self.cells.len() - 1);
        let cell = if cfg!(debug_assertions) {
            self.cells.get(idx).expect("index out of bounds")
        } else {
            // SAFETY: `idx` was masked with `len - 1`, where `len` is a
            // non-zero power of two (enforced by `new`), so `idx < len`.
            unsafe { self.cells.get_unchecked(idx) }
        };
        cell.fetch_add(value, ordering);
    }

    /// Returns the current total using `Ordering::Relaxed` loads.
    #[inline]
    pub fn sum(&self) -> isize {
        self.sum_with_ordering(Ordering::Relaxed)
    }

    /// Returns the total of all cells, loading each one with `ordering`.
    ///
    /// Cells are read one at a time, so concurrent updates may or may not be
    /// reflected. Cells are combined with wrapping addition, so the result is
    /// always the exact total modulo 2^N. This matches what a single
    /// `AtomicIsize` updated with `fetch_add` would hold, and it never panics
    /// on overflow, even in debug builds.
    ///
    /// # Panics
    ///
    /// Panics if `ordering` is `Release` or `AcqRel` (invalid for atomic loads).
    #[inline]
    pub fn sum_with_ordering(&self, ordering: Ordering) -> isize {
        self.cells
            .iter()
            .fold(0isize, |acc, cell| acc.wrapping_add(cell.load(ordering)))
    }

    /// Resets every cell to zero and returns the total that was removed, using
    /// `Ordering::Relaxed` for each swap.
    #[inline]
    pub fn swap(&self) -> isize {
        self.swap_with_ordering(Ordering::Relaxed)
    }

    /// Resets every cell to zero with an atomic swap using `ordering`, and
    /// returns the wrapping total of the values removed.
    ///
    /// Cells are drained one at a time, so this is not a single atomic
    /// snapshot. A concurrent `add` is either included in the returned total
    /// or left in the counter for a later read; it is never lost.
    #[inline]
    pub fn swap_with_ordering(&self, ordering: Ordering) -> isize {
        self.cells
            .iter()
            .fold(0isize, |acc, cell| acc.wrapping_add(cell.swap(0, ordering)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_test() {
        let counter = Counter::new(1);
        counter.add(1);
        assert_eq!(counter.sum(), 1);
    }

    #[test]
    fn increment_multiple_times() {
        let counter = Counter::new(1);
        counter.add(1);
        counter.add(1);
        counter.add(1);
        assert_eq!(counter.sum(), 3);
    }

    #[test]
    fn test_inc() {
        let counter = Counter::new(4);
        counter.inc();
        counter.inc();
        assert_eq!(counter.sum(), 2);
    }

    #[test]
    fn test_swap() {
        let counter = Counter::new(4);
        counter.add(100);
        let val = counter.swap();
        assert_eq!(val, 100);
        assert_eq!(counter.sum(), 0);
    }

    /// The requested cell count is rounded up to a power of two, and a request
    /// for zero cells still produces one usable cell.
    #[test]
    fn new_rounds_cell_count_up_to_power_of_two() {
        assert_eq!(Counter::new(0).cells.len(), 1);
        assert_eq!(Counter::new(1).cells.len(), 1);
        assert_eq!(Counter::new(3).cells.len(), 4);
        assert_eq!(Counter::new(5).cells.len(), 8);
        assert_eq!(Counter::new(8).cells.len(), 8);
    }

    #[test]
    fn zero_cell_request_still_counts() {
        let counter = Counter::new(0);
        counter.inc();
        counter.inc();
        assert_eq!(counter.sum(), 2);
    }

    #[test]
    fn negative_values_decrement() {
        let counter = Counter::new(4);
        counter.add(10);
        counter.add(-3);
        assert_eq!(counter.sum(), 7);
        counter.add(-7);
        assert_eq!(counter.sum(), 0);
    }

    #[test]
    fn explicit_orderings() {
        let counter = Counter::new(2);
        counter.add_with_ordering(5, Ordering::SeqCst);
        counter.add_with_ordering(-2, Ordering::Release);
        assert_eq!(counter.sum_with_ordering(Ordering::Acquire), 3);
        assert_eq!(counter.sum_with_ordering(Ordering::SeqCst), 3);
    }

    #[test]
    fn swap_after_concurrent_adds() {
        let counter = Counter::new(4);
        std::thread::scope(|s| {
            for _ in 0..4 {
                s.spawn(|| counter.add(2));
            }
        });
        assert_eq!(counter.swap(), 8);
        // A second swap finds every cell already drained.
        assert_eq!(counter.swap(), 0);
        assert_eq!(counter.sum(), 0);
    }

    #[test]
    fn two_threads_incrementing_concurrently() {
        let counter = Counter::new(2);
        std::thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    counter.add(1);
                });
            }
        });
        assert_eq!(counter.sum(), 2);
    }

    #[test]
    fn multiple_threads_incrementing_many_times() {
        const WRITE_COUNT: isize = 1_000_000;
        const THREAD_COUNT: isize = 8;
        let counter = Counter::new(THREAD_COUNT as usize);
        std::thread::scope(|s| {
            for _ in 0..THREAD_COUNT {
                s.spawn(|| {
                    for _ in 0..WRITE_COUNT {
                        counter.add(1);
                    }
                });
            }
        });
        assert_eq!(counter.sum(), THREAD_COUNT * WRITE_COUNT);
    }

    #[test]
    fn debug_format() {
        let counter = Counter::new(8);
        counter.add(42);
        let debug = format!("{counter:?}");
        assert!(debug.contains("sum: 42"));
        assert!(debug.contains("cells: 8"));
    }
}