use std::cell::Cell;
use std::sync::atomic::{AtomicU64, Ordering};
/// Cache-line size (in bytes) the shard layout is designed around.
///
/// `Shard` repeats this value literally in `#[repr(align(128))]`, because
/// `repr(align)` cannot refer to a constant.
const CACHE_LINE: usize = 128;

/// Number of `AtomicU64` counter slots that fit in one cache line.
const SLOTS: usize = CACHE_LINE / std::mem::size_of::<AtomicU64>();

/// Number of shards every [`CounterGroup`] spreads its writes across.
const NUM_SHARDS: usize = 64;
thread_local! {
    /// Shard assigned to the current thread, or `None` while no shard has been
    /// assigned yet.
    static SHARD_ID: Cell<Option<usize>> = const { Cell::new(None) };
}

/// Pins the calling thread to a shard.
///
/// Any `id` is accepted; it is reduced modulo `NUM_SHARDS`. The setting is
/// per-thread and takes precedence over the default shard selection.
pub fn set_thread_shard(id: usize) {
    let shard = id % NUM_SHARDS;
    SHARD_ID.with(|cell| cell.set(Some(shard)));
}
/// One cache line's worth of counter slots.
///
/// The 128-byte alignment (the value of `CACHE_LINE`) places every shard on its
/// own line, so threads writing to different shards should not false-share on
/// targets whose cache line (or prefetch pair) is at most 128 bytes.
#[repr(C)]
#[repr(align(128))]
struct Shard {
    /// This shard's partial count for each slot.
    slots: [AtomicU64; SLOTS],
}
/// A fixed set of `SLOTS` counters, each split across `NUM_SHARDS`
/// cache-line-aligned shards to reduce write contention between threads.
pub struct CounterGroup {
    shards: [Shard; NUM_SHARDS],
}

// `CounterGroup` holds only atomics, so it is `Send + Sync` automatically and
// needs no `unsafe impl`. An `unsafe impl` would keep claiming thread safety
// even if a non-thread-safe field were added later; this compile-time check
// fails the build instead.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<CounterGroup>();
};
impl CounterGroup {
    /// Creates a group with every slot of every shard set to zero.
    ///
    /// This is a `const fn`, so a group can be declared as a `static`.
    // `ZERO` and `SHARD` are used only as array-repeat initializers, which is
    // the case where clippy's interior-mutable-const lint is a false positive.
    #[allow(clippy::declare_interior_mutable_const)]
    pub const fn new() -> Self {
        const ZERO: AtomicU64 = AtomicU64::new(0);
        const SHARD: Shard = Shard {
            slots: [ZERO; SLOTS],
        };
        Self {
            shards: [SHARD; NUM_SHARDS],
        }
    }

    /// Adds one to `slot` in the calling thread's shard.
    #[inline]
    fn increment(&self, slot: usize) {
        self.add(slot, 1);
    }

    /// Adds `value` to `slot` in the calling thread's shard.
    ///
    /// The update is a `Relaxed` `fetch_add`, which wraps on overflow.
    ///
    /// # Panics
    /// Panics if `slot >= SLOTS` (array indexing).
    #[inline]
    fn add(&self, slot: usize, value: u64) {
        debug_assert!(slot < SLOTS, "slot index out of bounds");
        let shard = shard_index();
        self.shards[shard].slots[slot].fetch_add(value, Ordering::Relaxed);
    }

    /// Returns the total for `slot`, summed over all shards.
    ///
    /// The shard values are combined with wrapping addition, matching the
    /// wrap-on-overflow of `fetch_add` in `add`. The result is therefore the
    /// total added modulo 2^64, the same as a single `AtomicU64` counter would
    /// report. A checked `sum()` would panic in debug builds once the
    /// cross-shard total passed `u64::MAX`. The loads are individual `Relaxed`
    /// reads rather than one atomic snapshot, so concurrent updates may be
    /// partially reflected.
    ///
    /// # Panics
    /// Panics if `slot >= SLOTS` (array indexing).
    fn value(&self, slot: usize) -> u64 {
        debug_assert!(slot < SLOTS, "slot index out of bounds");
        self.shards
            .iter()
            .map(|shard| shard.slots[slot].load(Ordering::Relaxed))
            .fold(0, u64::wrapping_add)
    }
}
impl Default for CounterGroup {
fn default() -> Self {
Self::new()
}
}
/// A handle to one slot of a `'static` [`CounterGroup`].
///
/// The handle is only a reference plus a slot index, and `Counter::new` is a
/// `const fn`, so counters can be declared as `static`s.
pub struct Counter {
    group: &'static CounterGroup,
    slot: usize,
}

// `&'static CounterGroup` (the group holds only atomics) and `usize` are both
// thread-safe, so `Counter` is `Send + Sync` automatically and needs no
// `unsafe impl`. This compile-time check fails the build if that ever changes.
const _: () = {
    const fn assert_send_sync<T: Send + Sync>() {}
    assert_send_sync::<Counter>();
};
impl Counter {
pub const fn new(group: &'static CounterGroup, slot: usize) -> Self {
Self { group, slot }
}
#[inline]
pub fn increment(&self) {
self.group.increment(self.slot);
}
#[inline]
pub fn add(&self, value: u64) {
self.group.add(self.slot, value);
}
pub fn value(&self) -> u64 {
self.group.value(self.slot)
}
}
impl metriken::Metric for Counter {
    /// Returns this counter as `&dyn Any` (always `Some`).
    fn as_any(&self) -> Option<&dyn std::any::Any> {
        let any: &dyn std::any::Any = self;
        Some(any)
    }

    /// Reports the counter's current total as a `metriken` counter value.
    fn value(&self) -> Option<metriken::Value<'_>> {
        Some(Counter::value(self)).map(metriken::Value::Counter)
    }
}
#[inline]
fn shard_index() -> usize {
SHARD_ID.get().unwrap_or_else(|| {
thread_local! {
static ID: u8 = const { 0 };
}
ID.with(|x| x as *const u8 as usize) % NUM_SHARDS
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_counter() {
        static GROUP: CounterGroup = CounterGroup::new();
        let counter = Counter::new(&GROUP, 0);
        assert_eq!(counter.value(), 0);
        counter.increment();
        assert_eq!(counter.value(), 1);
        counter.add(10);
        assert_eq!(counter.value(), 11);
    }

    #[test]
    fn multiple_slots() {
        static GROUP: CounterGroup = CounterGroup::new();
        let a = Counter::new(&GROUP, 0);
        let b = Counter::new(&GROUP, 1);
        a.increment();
        b.add(5);
        assert_eq!(a.value(), 1);
        assert_eq!(b.value(), 5);
    }

    #[test]
    fn thread_distribution() {
        use std::sync::Arc;
        use std::thread;
        static GROUP: CounterGroup = CounterGroup::new();
        let counter = Arc::new(Counter::new(&GROUP, 2));
        let iterations = 1000;
        let num_threads = 4;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let c = Arc::clone(&counter);
                thread::spawn(move || {
                    for _ in 0..iterations {
                        c.increment();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.value(), iterations * num_threads);
    }

    #[test]
    fn metriken_trait() {
        use metriken::Metric;
        static GROUP: CounterGroup = CounterGroup::new();
        let counter = Counter::new(&GROUP, 3);
        counter.add(42);
        let value = Metric::value(&counter);
        assert!(matches!(value, Some(metriken::Value::Counter(42))));
    }

    /// Each shard must occupy exactly one (assumed) cache line.
    #[test]
    fn shard_layout_is_one_cache_line() {
        assert_eq!(std::mem::size_of::<Shard>(), CACHE_LINE);
        assert_eq!(std::mem::align_of::<Shard>(), CACHE_LINE);
    }

    /// `set_thread_shard` reduces its argument modulo `NUM_SHARDS` and pins it.
    #[test]
    fn set_thread_shard_wraps_and_pins() {
        set_thread_shard(NUM_SHARDS + 5);
        assert_eq!(shard_index(), 5);
    }

    /// The shard chosen for a thread is in range and does not change.
    #[test]
    fn shard_index_is_in_range_and_stable() {
        let first = shard_index();
        assert!(first < NUM_SHARDS);
        assert_eq!(shard_index(), first);
    }

    /// Writes from a pinned thread land in the pinned shard.
    #[test]
    fn pinned_thread_writes_to_its_shard() {
        static GROUP: CounterGroup = CounterGroup::new();
        set_thread_shard(9);
        let counter = Counter::new(&GROUP, 4);
        counter.add(7);
        assert_eq!(
            GROUP.shards[9].slots[4].load(std::sync::atomic::Ordering::Relaxed),
            7
        );
        assert_eq!(counter.value(), 7);
    }
}