use rustc_hash::FxHasher;
use static_assertions::{const_assert, const_assert_eq};
use std::cell::Cell;
use std::fmt as StdFmt;
use std::hash::{Hash, Hasher};
use std::sync::atomic::AtomicIsize;
use std::sync::atomic::Ordering as AtomicOrdering;
use std::thread::{self as StdThread, ThreadId};
/// Number of independent atomic counters a [`ShardedCounter`] spreads its updates over.
const SHARDS: usize = 1 << 4;
/// Assumed cache-line size in bytes; each shard is aligned to this so that adjacent
/// shards do not share a line.
const CACHE_LINE_SIZE: usize = 1 << 7;
/// A single shard: one atomic counter, aligned to 128 bytes so that neighbouring
/// shards stored in an array land on distinct cache lines.
#[derive(Debug)]
#[repr(C, align(128))]
struct PaddedCounter {
    /// This shard's contribution to the overall total. Signed because a shard can
    /// receive more decrements than increments.
    value: AtomicIsize,
}
impl PaddedCounter {
    /// Builds a shard holding zero. This is `const` so it can seed a `const` array initializer.
    const fn new() -> Self {
        let value = AtomicIsize::new(0);
        PaddedCounter { value }
    }
}
// Compile-time layout checks. Each shard must start on its own cache line and fill
// at least one whole line; otherwise adjacent shards could falsely share a line.
const_assert_eq!(CACHE_LINE_SIZE, std::mem::align_of::<PaddedCounter>());
const_assert!(CACHE_LINE_SIZE <= std::mem::size_of::<PaddedCounter>());
thread_local! {
    /// Per-thread memo of the shard index chosen for this thread; `None` until first use.
    static CACHED_SHARD: Cell<Option<usize>> = const { Cell::new(Option::None) };
}
/// A counter that spreads updates across `SHARDS` cache-line-aligned atomics.
///
/// Each thread writes to the shard selected from its own thread identity, so
/// concurrent writers on different threads usually touch different cache lines.
/// Reading the value sums every shard.
pub struct ShardedCounter {
    /// One padded atomic per shard; the logical count is the sum of all of them.
    shards: [PaddedCounter; SHARDS],
}
impl ShardedCounter {
    /// Creates a counter with every shard set to zero.
    ///
    /// Being `const`, it can initialise a `static` directly.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        // Array-repeat syntax needs either `Copy` or a `const` item. The atomic is not
        // `Copy`, so a `const` is used and every element gets its own fresh copy.
        #[expect(
            clippy::declare_interior_mutable_const,
            reason = "Intentional: each array element gets a fresh zeroed copy"
        )]
        const EMPTY: PaddedCounter = PaddedCounter::new();
        Self {
            shards: [EMPTY; SHARDS],
        }
    }

    /// Derives the calling thread's shard by hashing its `ThreadId` with `FxHasher`.
    ///
    /// The result is memoised per thread by [`Self::shard_index`], so this runs only on
    /// a thread's first access. That is why it is marked `#[cold]`.
    #[cold]
    #[expect(clippy::cast_possible_truncation, reason = "Intentional")]
    fn compute_shard_index() -> usize {
        let thread_id: ThreadId = StdThread::current().id();
        let mut hasher: FxHasher = FxHasher::default();
        thread_id.hash(&mut hasher);
        // On 32-bit targets the cast drops the high hash bits. The reduction below
        // still yields an index in `0..SHARDS`.
        (hasher.finish() as usize) % SHARDS
    }

    /// Returns the calling thread's shard index, computing and caching it on first use.
    #[inline(always)]
    fn shard_index() -> usize {
        CACHED_SHARD.with(|cell: &Cell<Option<usize>>| {
            cell.get().unwrap_or_else(|| {
                let index: usize = Self::compute_shard_index();
                cell.set(Some(index));
                index
            })
        })
    }

    /// Returns the atomic belonging to the calling thread's shard.
    #[inline(always)]
    #[expect(clippy::indexing_slicing, reason = "INVARIANT: index < SHARDS")]
    fn get_shard(&self) -> &AtomicIsize {
        let index: usize = Self::shard_index();
        &self.shards[index].value
    }

    /// Adds one to the calling thread's shard.
    ///
    /// Uses `Relaxed` ordering: the update itself is atomic, but it does not order any
    /// other memory accesses. The shard wraps around on overflow.
    #[inline(always)]
    pub fn increment(&self) {
        self.get_shard().fetch_add(1, AtomicOrdering::Relaxed);
    }

    /// Subtracts one from the calling thread's shard (`Relaxed`, wrapping).
    ///
    /// The shard may go negative when increments happened on other threads' shards.
    #[inline(always)]
    pub fn decrement(&self) {
        self.get_shard().fetch_sub(1, AtomicOrdering::Relaxed);
    }

    /// Adds `val` (which may be negative) to the calling thread's shard (`Relaxed`, wrapping).
    #[allow(dead_code, reason = "Public API")]
    #[inline(always)]
    pub fn add(&self, val: isize) {
        self.get_shard().fetch_add(val, AtomicOrdering::Relaxed);
    }

    /// Returns the sum of all shards, clamped to zero.
    ///
    /// The shards are read one at a time with `Relaxed` loads, so under concurrent
    /// updates the result is not an atomic snapshot. Consider an increment on a shard
    /// that was already read, followed by a decrement on a shard not yet read. The reader
    /// sees only the decrement, so the sum can be transiently negative even though the
    /// logical count never was. Such a value cannot be told apart from genuine misuse
    /// (more decrements than increments), so it is not asserted on. It is reported as `0`.
    pub fn load(&self) -> usize {
        // Wrapping addition mirrors the wrapping `fetch_add`/`fetch_sub` on the shards.
        // The modular sum equals the true total whenever that total fits in `isize`, no
        // matter in which order partial sums would have overflowed.
        let total: isize = self
            .shards
            .iter()
            .map(|shard| shard.value.load(AtomicOrdering::Relaxed))
            .fold(0, isize::wrapping_add);
        // Clamp negatives to zero; for a positive `isize`, `unsigned_abs` is a lossless
        // conversion to `usize`.
        if total > 0 { total.unsigned_abs() } else { 0 }
    }

    /// Sets every shard back to zero.
    ///
    /// Each shard is cleared with its own `Relaxed` store, so this is not atomic across
    /// shards. An update racing with `reset` may be discarded or kept, depending on
    /// whether it lands before or after its shard's store.
    #[allow(dead_code, reason = "Public API")]
    #[inline]
    pub fn reset(&self) {
        for shard in &self.shards {
            shard.value.store(0, AtomicOrdering::Relaxed);
        }
    }
}
impl Default for ShardedCounter {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl StdFmt::Debug for ShardedCounter {
    /// Formats only the aggregated total (the value of `load`), not the individual shards.
    fn fmt(&self, f: &mut StdFmt::Formatter<'_>) -> StdFmt::Result {
        let total = self.load();
        let mut builder = f.debug_struct("ShardedCounter");
        builder.field("total", &total);
        builder.finish()
    }
}
// Unit tests live in a separate `unit_tests` module file and are compiled only for test builds.
#[cfg(test)]
mod unit_tests;