#![deny(unsafe_code)]
use parking_lot::Mutex;
use std::cell::Cell;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Process-wide source of thread indices: `global_thread_index` bumps it once
/// for every thread that asks for an index for the first time.
static SLOT_COUNTER: AtomicUsize = AtomicUsize::new(0);
std::thread_local! {
// Index claimed by the current thread. Starts as `None` and is filled in by
// `global_thread_index` on that thread's first call, then reused afterwards.
static THREAD_SLOT: Cell<Option<usize>> = const { Cell::new(None) };
}
/// A fixed-size set of independently locked accumulators of type `A`.
///
/// Each calling thread is mapped to one slot (its global thread index modulo
/// the number of slots), so threads that map to different slots do not contend
/// on the same lock. When more threads than slots are in use, several threads
/// share a slot and serialize on that slot's mutex.
#[derive(Debug)]
pub struct PerThreadAccumulator<A> {
/// One mutex-guarded accumulator per slot; `new` always creates at least one.
slots: Vec<Mutex<A>>,
}
impl<A: Default + Send> PerThreadAccumulator<A> {
    /// Creates an accumulator with `num_slots` slots, each holding `A::default()`.
    ///
    /// A request for zero slots is treated as a request for one, so the
    /// returned accumulator always has at least one slot. The value is handed
    /// back inside an [`Arc`] so it can be shared across threads directly.
    #[must_use]
    pub fn new(num_slots: usize) -> Arc<Self> {
        let slot_count = num_slots.max(1);
        let slots: Vec<Mutex<A>> = std::iter::repeat_with(|| Mutex::new(A::default()))
            .take(slot_count)
            .collect();
        Arc::new(Self { slots })
    }

    /// Runs `f` with exclusive access to the calling thread's slot and returns
    /// whatever `f` returns.
    ///
    /// The slot is chosen as the thread's global index modulo the slot count,
    /// so a given thread always hits the same slot of a given accumulator.
    /// The slot's lock is held for the whole duration of `f`.
    ///
    /// Calling `with_slot` on the same accumulator from inside `f` would try
    /// to lock the slot this thread already holds; `parking_lot::Mutex` is not
    /// reentrant, so that deadlocks.
    #[inline]
    pub fn with_slot<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut A) -> R,
    {
        let index = global_thread_index() % self.slots.len();
        let mut held = self.slots[index].lock();
        f(&mut *held)
    }

    /// Borrows every slot, e.g. so a reader can lock each one and combine the
    /// partial results without consuming the accumulator.
    #[must_use]
    pub fn slots(&self) -> &[Mutex<A>] {
        self.slots.as_slice()
    }

    /// Consumes this handle and returns the value of every slot, in slot order.
    ///
    /// When this is the only strong reference, the values are moved out
    /// without taking any lock. If other `Arc` holders still exist (a debug
    /// build asserts against this), each slot is locked in turn and its value
    /// is swapped for `A::default()`; anything the other holders write after
    /// their slot has been drained is not part of the returned vector.
    pub fn into_slots(self: Arc<Self>) -> Vec<A> {
        debug_assert_eq!(
            Arc::strong_count(&self),
            1,
            "into_slots called with outstanding Arc holders; fallback is lossy under concurrent writes",
        );
        let shared = match Arc::try_unwrap(self) {
            // Sole owner: unwrap the mutexes directly, no locking required.
            Ok(owned) => return owned.slots.into_iter().map(|cell| cell.into_inner()).collect(),
            Err(still_shared) => still_shared,
        };
        // Shared fallback: drain each slot under its lock, leaving a default behind.
        shared
            .slots
            .iter()
            .map(|cell| {
                let mut held = cell.lock();
                std::mem::take(&mut *held)
            })
            .collect()
    }
}
/// Returns a process-wide index for the calling thread.
///
/// The first call on a thread claims the next value of `SLOT_COUNTER` and
/// caches it in that thread's `THREAD_SLOT`; every later call on the same
/// thread returns the cached value.
#[inline]
fn global_thread_index() -> usize {
    THREAD_SLOT.with(|cached| {
        if let Some(index) = cached.get() {
            return index;
        }
        // `fetch_add` is an atomic read-modify-write, so each thread gets a
        // distinct value even with relaxed ordering; nothing else is
        // synchronized through this counter.
        let index = SLOT_COUNTER.fetch_add(1, Ordering::Relaxed);
        cached.set(Some(index));
        index
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Minimal accumulator payload: a single running count.
    #[derive(Default, Debug)]
    struct Counter(u64);

    /// All updates made by one thread must land in the same slot.
    #[test]
    fn single_thread_accumulates_in_one_slot() {
        let acc: Arc<PerThreadAccumulator<Counter>> = PerThreadAccumulator::new(4);
        for _ in 0..100 {
            acc.with_slot(|c| c.0 += 1);
        }
        let counts: Vec<u64> = acc.slots().iter().map(|s| s.lock().0).collect();
        assert_eq!(counts.iter().sum::<u64>(), 100);
        // Checking only the total would also pass if the updates were spread
        // over several slots; pin the single-slot property the name promises.
        assert_eq!(
            counts.iter().filter(|&&n| n != 0).count(),
            1,
            "a single thread must always map to the same slot",
        );
    }

    /// Updates from many threads must all be counted, whichever slots they hit.
    #[test]
    fn parallel_threads_share_total_count() {
        let acc: Arc<PerThreadAccumulator<Counter>> = PerThreadAccumulator::new(8);
        thread::scope(|s| {
            for _ in 0..8 {
                let acc = Arc::clone(&acc);
                s.spawn(move || {
                    for _ in 0..1_000 {
                        acc.with_slot(|c| c.0 += 1);
                    }
                });
            }
        });
        let total: u64 = acc.slots().iter().map(|s| s.lock().0).sum();
        assert_eq!(total, 8_000);
    }

    /// Asking for zero slots still yields one usable slot.
    #[test]
    fn num_slots_zero_clamped_to_one() {
        let acc: Arc<PerThreadAccumulator<Counter>> = PerThreadAccumulator::new(0);
        assert_eq!(acc.slots().len(), 1);
        acc.with_slot(|c| c.0 = 42);
        assert_eq!(acc.slots()[0].lock().0, 42);
    }

    /// With a single owner, `into_slots` returns every slot and keeps the data.
    #[test]
    fn into_slots_drains_under_unique_ownership() {
        let acc: Arc<PerThreadAccumulator<Counter>> = PerThreadAccumulator::new(4);
        acc.with_slot(|c| c.0 = 7);
        let slots = acc.into_slots();
        assert_eq!(slots.len(), 4);
        assert_eq!(slots.iter().map(|c| c.0).sum::<u64>(), 7);
    }
}