use core::cell::UnsafeCell;
use core::sync::atomic::{
AtomicI16, AtomicI32, AtomicI64, AtomicI8, AtomicIsize, AtomicU16, AtomicU32, AtomicU64,
AtomicU8, AtomicUsize, Ordering,
};
use std::thread::LocalKey;
/// Generates "flushing" counter types, one per `primitive Atomic Name` triple.
///
/// Each generated counter keeps a global atomic total plus a per-thread pending
/// count. `inc` only touches the calling thread's pending count; that count is
/// added to the global total only when the same thread calls `flush`. `get`
/// therefore reports the flushed total only.
///
/// Caveats that follow from the implementation:
/// - The thread-local slot is declared inside `new`, so there is exactly one
///   slot per generated *type*, shared by every instance of that type. If a
///   thread increments two instances of the same type between flushes, the next
///   `flush` on either instance moves the combined pending count into that
///   instance's total.
/// - A pending count that its thread never flushes (e.g. because the thread
///   exits first) is never added to the global total.
/// - Arithmetic wraps on overflow, both for the pending count and for the
///   global total (`fetch_add` wraps), so totals are computed modulo 2^bits.
macro_rules! flushing_counter {
    ($( $primitive:ident $atomic:ident $counter:ident ), *) => {
        $(
            /// Counter whose per-thread increments become visible through
            /// `get` only after the incrementing thread calls `flush`.
            pub struct $counter {
                global_counter: $atomic,
                thread_local_counter: &'static LocalKey<UnsafeCell<$primitive>>,
            }

            impl $counter {
                /// Creates a counter whose flushed total starts at `start`.
                #[inline]
                pub const fn new(start: $primitive) -> Self {
                    // Declared here, so one slot exists per generated type and is
                    // shared by all instances of it.
                    thread_local!(static TL_COUNTER: UnsafeCell<$primitive> = UnsafeCell::new(0));
                    $counter {
                        global_counter: $atomic::new(start),
                        thread_local_counter: &TL_COUNTER,
                    }
                }

                /// Adds one to the calling thread's pending (unflushed) count.
                #[inline]
                pub fn inc(&self) {
                    self.thread_local_counter.with(|tlc| {
                        // SAFETY: the cell is thread-local and this closure only does
                        // arithmetic on it, so it cannot re-enter `inc`/`flush` on this
                        // thread; this is the only live reference to the value.
                        let tlc = unsafe { &mut *tlc.get() };
                        // Wrap like the global `fetch_add` does, rather than panicking
                        // in debug builds when a narrow type (u8, i8, ...) overflows
                        // locally before being flushed.
                        *tlc = tlc.wrapping_add(1);
                    });
                }

                /// Returns the flushed total; pending per-thread counts are not included.
                #[inline]
                pub fn get(&self) -> $primitive {
                    self.global_counter.load(Ordering::Relaxed)
                }

                /// Adds the calling thread's pending count to the global total and
                /// resets the pending count to zero.
                #[inline]
                pub fn flush(&self) {
                    self.thread_local_counter.with(|tlc| {
                        // SAFETY: same reasoning as in `inc` - thread-local cell, no
                        // re-entrancy inside this closure.
                        let tlc = unsafe { &mut *tlc.get() };
                        self.global_counter.fetch_add(*tlc, Ordering::Relaxed);
                        *tlc = 0;
                    });
                }
            }
        )*
    };
}
// The matcher uses `,` as a separator, so the list must not end with a trailing comma.
flushing_counter![
    u8 AtomicU8 FlushingCounterU8,
    u16 AtomicU16 FlushingCounterU16,
    u32 AtomicU32 FlushingCounterU32,
    u64 AtomicU64 FlushingCounterU64,
    usize AtomicUsize FlushingCounterUsize,
    i8 AtomicI8 FlushingCounterI8,
    i16 AtomicI16 FlushingCounterI16,
    i32 AtomicI32 FlushingCounterI32,
    i64 AtomicI64 FlushingCounterI64,
    isize AtomicIsize FlushingCounterIsize
];
/// Generates approximate counter types, one per `primitive Atomic Name resolution` entry.
///
/// Each generated counter keeps a global atomic total plus a per-thread pending
/// count of type `resolution`. `inc` bumps the pending count and, as soon as it
/// reaches the counter's threshold, adds it to the global total. `flush` adds
/// whatever is pending immediately. `get` reads only the global total, so it
/// can lag behind the true count by whatever is still pending on each thread.
///
/// Caveats that follow from the implementation:
/// - The thread-local slot is declared inside `new`, so all instances of one
///   generated type share a single pending count per thread. Interleaved
///   increments of two instances of the same type can therefore be credited to
///   whichever instance crosses its threshold (or flushes) first.
/// - A pending count that its thread never flushes is never added to the total.
/// - For signed primitives the pending count is unsigned and is converted with
///   `as`. Values above the signed maximum wrap, and the wrapping `fetch_add`
///   turns them back into the intended sum modulo 2^bits.
macro_rules! approx_counter {
    ($( $primitive:ident $atomic:ident $counter:ident $resolution:ty), *) => {
        $(
            /// Counter that publishes per-thread increments to the shared total in
            /// batches of (at least) its configured threshold.
            pub struct $counter {
                threshold: $resolution,
                global_counter: $atomic,
                thread_local_counter: &'static LocalKey<UnsafeCell<$resolution>>,
            }

            impl $counter {
                /// Creates a counter whose total starts at `start`. Each thread
                /// publishes its pending increments once they reach `resolution`.
                #[inline]
                pub const fn new(start: $primitive, resolution: $resolution) -> Self {
                    // One slot per generated type, shared by all of its instances.
                    thread_local!(static TL_COUNTER: UnsafeCell<$resolution> = UnsafeCell::new(0));
                    Self {
                        threshold: resolution,
                        global_counter: $atomic::new(start),
                        thread_local_counter: &TL_COUNTER,
                    }
                }

                /// Adds `pending` to the global total and resets it to zero.
                #[inline]
                fn publish(&self, pending: &mut $resolution) {
                    let batch = *pending;
                    *pending = 0;
                    self.global_counter.fetch_add(batch as $primitive, Ordering::Relaxed);
                }

                /// Counts one event on the calling thread, publishing the batch
                /// once the thread's pending count reaches the threshold.
                #[inline]
                pub fn inc(&self) {
                    self.thread_local_counter.with(|cell| {
                        // SAFETY: the cell is thread-local and nothing in this closure
                        // can re-enter `inc`/`flush` on this thread, so this is the
                        // only live reference to the value.
                        let pending = unsafe { &mut *cell.get() };
                        *pending += 1;
                        if *pending < self.threshold {
                            return;
                        }
                        self.publish(pending);
                    });
                }

                /// Returns the published total (pending per-thread counts excluded).
                #[inline]
                pub fn get(&self) -> $primitive {
                    self.global_counter.load(Ordering::Relaxed)
                }

                /// Publishes the calling thread's pending count right away.
                #[inline]
                pub fn flush(&self) {
                    self.thread_local_counter.with(|cell| {
                        // SAFETY: same reasoning as in `inc`.
                        self.publish(unsafe { &mut *cell.get() });
                    });
                }
            }
        )*
    };
}
// The matcher uses `,` as a separator, so the list must not end with a trailing comma.
approx_counter![
    u8 AtomicU8 ApproxCounterU8 u8,
    u16 AtomicU16 ApproxCounterU16 u16,
    u32 AtomicU32 ApproxCounterU32 u32,
    u64 AtomicU64 ApproxCounterU64 u64,
    usize AtomicUsize ApproxCounterUsize usize,
    i8 AtomicI8 ApproxCounterI8 u8,
    i16 AtomicI16 ApproxCounterI16 u16,
    i32 AtomicI32 ApproxCounterI32 u32,
    i64 AtomicI64 ApproxCounterI64 u64,
    isize AtomicIsize ApproxCounterIsize usize
];
#[cfg(test)]
mod tests {
    use super::*;

    /// Evaluates to `true` when `$val` lies in `[$expected - $tol, $expected + $tol]`.
    ///
    /// `$val` is evaluated exactly once, so both bounds are checked against the
    /// same counter reading.
    macro_rules! within_tolerance {
        ($val:expr, $expected:expr, $tol:expr) => {{
            let val = $val;
            ($expected) - ($tol) <= val && val <= ($expected) + ($tol)
        }};
    }

    /// Runs `body` on one freshly spawned thread and waits for it to finish.
    fn run_on_new_thread(body: fn()) {
        std::thread::spawn(body).join().expect("Err joining thread");
    }

    /// Runs `body` on `threads` concurrently spawned threads, then joins them all.
    fn run_in_parallel(threads: usize, body: fn()) {
        // Collect first so every thread is started before the first join.
        let handles: Vec<_> = (0..threads).map(|_| std::thread::spawn(body)).collect();
        for handle in handles {
            handle.join().expect("Err joining thread");
        }
    }

    #[test]
    fn approx_new_const() {
        static COUNTER: ApproxCounterUsize = ApproxCounterUsize::new(0, 1024);
        assert_eq!(COUNTER.get(), 0);
        COUNTER.inc();
        assert!(COUNTER.get() <= 1);
    }

    #[test]
    fn approx_flush_single_threaded() {
        static COUNTER: ApproxCounterU64 = ApproxCounterU64::new(0, 1024);
        assert_eq!(COUNTER.get(), 0);
        COUNTER.inc();
        COUNTER.flush();
        assert_eq!(COUNTER.get(), 1);
    }

    #[test]
    fn approx_negative_start_flush() {
        static COUNTER: ApproxCounterI64 = ApproxCounterI64::new(-1154, 1024);
        assert_eq!(COUNTER.get(), -1154);
        COUNTER.inc();
        COUNTER.flush();
        assert_eq!(COUNTER.get(), -1153);
    }

    #[test]
    fn approx_negative_to_positive() {
        static COUNTER: ApproxCounterI64 = ApproxCounterI64::new(-999, 1000);
        assert_eq!(COUNTER.get(), -999);
        for _ in 0..1000 {
            COUNTER.inc();
        }
        assert!(COUNTER.get() > 0);
    }

    #[test]
    fn approx_different_counters() {
        const NUM_THREADS: u32 = 1;
        const LOCAL_ACC: u32 = 1024;
        const GLOBAL_ACC: u32 = LOCAL_ACC * NUM_THREADS;
        static COUNTER: ApproxCounterU32 = ApproxCounterU32::new(0, LOCAL_ACC);
        assert_eq!(COUNTER.get(), 0);
        for _ in 0..50000 {
            COUNTER.inc();
        }
        assert!(within_tolerance!(COUNTER.get(), 50000, GLOBAL_ACC));
        static COUNTER_2: ApproxCounterU32 = ApproxCounterU32::new(0, LOCAL_ACC);
        assert!(within_tolerance!(COUNTER_2.get(), 0, 0));
        assert!(within_tolerance!(COUNTER.get(), 50000, GLOBAL_ACC));
        for _ in 0..50000 {
            COUNTER_2.inc();
        }
        assert!(within_tolerance!(COUNTER_2.get(), 50000, GLOBAL_ACC));
        assert!(within_tolerance!(COUNTER.get(), 50000, GLOBAL_ACC));
    }

    #[test]
    fn approx_count_to_50000_single_threaded() {
        const NUM_THREADS: u32 = 1;
        const LOCAL_ACC: u32 = 1024;
        const GLOBAL_ACC: u32 = LOCAL_ACC * NUM_THREADS;
        static COUNTER: ApproxCounterU32 = ApproxCounterU32::new(0, LOCAL_ACC);
        assert_eq!(COUNTER.get(), 0);
        for _ in 0..50000 {
            COUNTER.inc();
        }
        assert!(within_tolerance!(COUNTER.get(), 50000, GLOBAL_ACC));
    }

    #[test]
    fn approx_count_to_50000_seq_threaded() {
        const NUM_THREADS: u16 = 5;
        const LOCAL_ACC: u16 = 256;
        const GLOBAL_ACC: u16 = (LOCAL_ACC - 1) * NUM_THREADS;
        static COUNTER: ApproxCounterU16 = ApproxCounterU16::new(0, LOCAL_ACC);
        assert_eq!(COUNTER.get(), 0);
        for round in 1..=NUM_THREADS {
            run_on_new_thread(|| {
                for _ in 0..10000 {
                    COUNTER.inc();
                }
            });
            assert!(within_tolerance!(COUNTER.get(), round * 10000, GLOBAL_ACC));
        }
    }

    #[test]
    fn approx_count_to_50000_par_threaded() {
        const NUM_THREADS: u32 = 5;
        const LOCAL_ACC: u32 = 419;
        const GLOBAL_ACC: u32 = (LOCAL_ACC - 1) * NUM_THREADS;
        static COUNTER: ApproxCounterI32 = ApproxCounterI32::new(0, LOCAL_ACC);
        assert_eq!(COUNTER.get(), 0);
        run_in_parallel(NUM_THREADS as usize, || {
            for _ in 0..10000 {
                COUNTER.inc();
            }
        });
        let total = COUNTER.get();
        assert!((50000 - GLOBAL_ACC) as i32 <= total && total <= (50000 + GLOBAL_ACC) as i32);
    }

    #[test]
    fn approx_flushed_count_to_50000_par_threaded() {
        const LOCAL_ACC: usize = 419;
        static COUNTER: ApproxCounterIsize = ApproxCounterIsize::new(0, LOCAL_ACC);
        assert_eq!(COUNTER.get(), 0);
        run_in_parallel(5, || {
            for _ in 0..10000 {
                COUNTER.inc();
            }
            COUNTER.flush();
        });
        assert_eq!(50000, COUNTER.get());
    }

    #[test]
    fn flushing_new_const() {
        static COUNTER: FlushingCounterUsize = FlushingCounterUsize::new(0);
        assert_eq!(COUNTER.get(), 0);
    }

    #[test]
    fn flushing_different_counters() {
        static COUNTER: FlushingCounterU32 = FlushingCounterU32::new(0);
        assert_eq!(COUNTER.get(), 0);
        for _ in 0..50000 {
            COUNTER.inc();
        }
        COUNTER.flush();
        assert!(within_tolerance!(COUNTER.get(), 50000, 0));
        static COUNTER_2: FlushingCounterU32 = FlushingCounterU32::new(0);
        assert!(within_tolerance!(COUNTER_2.get(), 0, 0));
        assert!(within_tolerance!(COUNTER.get(), 50000, 0));
        for _ in 0..50000 {
            COUNTER_2.inc();
        }
        COUNTER_2.flush();
        assert!(within_tolerance!(COUNTER_2.get(), 50000, 0));
        assert!(within_tolerance!(COUNTER.get(), 50000, 0));
    }

    #[test]
    fn flushing_count_to_50000_single_threaded() {
        static COUNTER: FlushingCounterU64 = FlushingCounterU64::new(0);
        assert_eq!(COUNTER.get(), 0);
        for _ in 0..50000 {
            COUNTER.inc();
        }
        COUNTER.flush();
        assert_eq!(50000, COUNTER.get());
    }

    #[test]
    fn flushing_count_to_50000_seq_threaded() {
        static COUNTER: FlushingCounterU32 = FlushingCounterU32::new(0);
        assert_eq!(COUNTER.get(), 0);
        for round in 1..=5u32 {
            run_on_new_thread(|| {
                for _ in 0..10000 {
                    COUNTER.inc();
                }
                COUNTER.flush();
            });
            assert_eq!(round * 10000, COUNTER.get());
        }
    }

    #[test]
    fn flushing_count_to_50000_par_threaded() {
        static COUNTER: FlushingCounterU16 = FlushingCounterU16::new(0);
        assert_eq!(COUNTER.get(), 0);
        run_in_parallel(5, || {
            for _ in 0..10000 {
                COUNTER.inc();
            }
            COUNTER.flush();
        });
        assert_eq!(50000, COUNTER.get());
    }
}