#![expect(clippy::unwrap_used)]
use super::{CACHE_LINE_SIZE, PaddedCounter, SHARDS, ShardedCounter};
use std::{sync::Arc, thread::Scope};
/// A counter fresh out of `ShardedCounter::new()` must report a total of zero.
#[test]
fn test_new_counter_is_zero() {
    let fresh = ShardedCounter::new();
    let total = fresh.load();
    assert_eq!(total, 0);
}
/// Increments performed from a single thread must all be reflected in `load()`.
#[test]
#[cfg_attr(miri, ignore)]
fn test_single_thread_increment() {
    const STEPS: usize = 1000;

    let counter = ShardedCounter::new();
    for _step in 0..STEPS {
        counter.increment();
    }
    assert_eq!(counter.load(), 1000);
}
/// Decrements cancel out earlier increments: 1000 ups followed by 300 downs
/// must leave a net total of 700.
#[test]
#[cfg_attr(miri, ignore)]
fn test_single_thread_decrement() {
    const UPS: usize = 1000;
    const DOWNS: usize = 300;

    let counter = ShardedCounter::new();
    for _up in 0..UPS {
        counter.increment();
    }
    for _down in 0..DOWNS {
        counter.decrement();
    }
    assert_eq!(counter.load(), 700);
}
/// Several threads increment one shared counter at the same time; no update may
/// be lost, so the final total equals threads × increments-per-thread.
#[test]
#[cfg_attr(miri, ignore)]
fn test_concurrent_increments() {
    const THREADS: usize = 8;
    const INCREMENTS_PER_THREAD: usize = 10_000;

    let shared = Arc::new(ShardedCounter::new());
    std::thread::scope(|scope| {
        for _worker in 0..THREADS {
            // Each worker gets its own handle to the same counter.
            let worker_counter = Arc::clone(&shared);
            scope.spawn(move || {
                for _ in 0..INCREMENTS_PER_THREAD {
                    worker_counter.increment();
                }
            });
        }
    });
    // `scope` joins every worker before returning, so all writes are visible here.
    assert_eq!(shared.load(), THREADS * INCREMENTS_PER_THREAD);
}
/// Starting from a baseline of 5000, four threads each add 1000 while four
/// others each subtract 1000 concurrently; the baseline must be restored once
/// every thread has finished.
#[test]
#[cfg_attr(miri, ignore)]
fn test_concurrent_mixed() {
    const BASELINE: usize = 5000;

    let shared = Arc::new(ShardedCounter::new());
    for _ in 0..BASELINE {
        shared.increment();
    }
    std::thread::scope(|scope| {
        // First pass spawns the four incrementing workers, second pass the
        // four decrementing ones.
        for incrementing in [true, false] {
            for _ in 0..4 {
                let worker_counter = Arc::clone(&shared);
                scope.spawn(move || {
                    for _ in 0..1000 {
                        if incrementing {
                            worker_counter.increment();
                        } else {
                            worker_counter.decrement();
                        }
                    }
                });
            }
        }
    });
    assert_eq!(shared.load(), BASELINE);
}
/// `reset()` must bring a non-zero counter back to a total of zero.
#[test]
#[cfg_attr(miri, ignore)]
fn test_reset() {
    const FILL: usize = 1000;

    let counter = ShardedCounter::new();
    for _ in 0..FILL {
        counter.increment();
    }

    counter.reset();
    let after_reset = counter.load();
    assert_eq!(after_reset, 0);
}
/// Repeated calls to `shard_index()` from the same thread must keep returning
/// the value observed on the first call.
#[test]
fn test_shard_index_is_cached() {
    let first = ShardedCounter::shard_index();
    // Two further calls, each compared against the first observation.
    for _ in 0..2 {
        assert_eq!(ShardedCounter::shard_index(), first);
    }
}
/// The shard index handed out to a thread must be usable as an index into the
/// `SHARDS`-sized shard array.
#[test]
fn test_shard_index_in_valid_range() {
    let index: usize = ShardedCounter::shard_index();
    let in_range = index < SHARDS;
    assert!(in_range, "shard index {index} >= SHARDS ({SHARDS})");
}
/// `PaddedCounter` must be aligned to exactly `CACHE_LINE_SIZE` bytes.
#[test]
fn test_cache_line_alignment() {
    let alignment = std::mem::align_of::<PaddedCounter>();
    assert_eq!(
        alignment, CACHE_LINE_SIZE,
        "PaddedCounter alignment should be {CACHE_LINE_SIZE}"
    );
}
/// `PaddedCounter` must occupy at least one full cache line.
#[test]
fn test_padded_counter_size() {
    let size = std::mem::size_of::<PaddedCounter>();
    let fills_a_line = size >= CACHE_LINE_SIZE;
    assert!(
        fills_a_line,
        "PaddedCounter size should be >= {CACHE_LINE_SIZE}"
    );
}
/// Compile-time check that `ShardedCounter` is both `Send` and `Sync`.
#[test]
fn test_sharded_counter_is_send_sync() {
    // The call only type-checks when the bounds hold; there is nothing to
    // verify at runtime.
    fn require_thread_safe<T: Send + Sync>() {}
    require_thread_safe::<ShardedCounter>();
}
/// `add` accepts both positive and negative deltas: +100 brings the total to
/// 100, then -30 brings it down to 70.
#[test]
fn test_add_positive_and_negative() {
    let counter = ShardedCounter::new();
    // Each step applies a delta and checks the running total afterwards.
    for (delta, expected_total) in [(100, 100), (-30, 70)] {
        counter.add(delta);
        assert_eq!(counter.load(), expected_total);
    }
}
/// Heavier version of the concurrent-increment test: 16 threads × 100 000
/// increments each, all of which must be counted.
#[test]
#[cfg_attr(miri, ignore)]
fn test_stress_large_scale() {
    const THREADS: usize = 16;
    const OPS_PER_THREAD: usize = 100_000;

    let shared = Arc::new(ShardedCounter::new());
    let expected_total = THREADS * OPS_PER_THREAD;

    std::thread::scope(|scope| {
        for _worker in 0..THREADS {
            let handle = Arc::clone(&shared);
            scope.spawn(move || {
                for _ in 0..OPS_PER_THREAD {
                    handle.increment();
                }
            });
        }
    });

    assert_eq!(shared.load(), expected_total);
}
/// Pre-loads the counter, then runs 16 threads where even-numbered workers
/// increment and odd-numbered workers decrement by the same amount. The two
/// halves cancel, so the pre-loaded value must survive.
#[test]
#[cfg_attr(miri, ignore)]
fn test_stress_mixed_operations() {
    const THREADS: usize = 16;
    const OPS_PER_THREAD: usize = 50_000;

    let shared = Arc::new(ShardedCounter::new());
    // Pre-load so the total stays high while the decrementing workers run.
    let initial = THREADS * OPS_PER_THREAD;
    for _ in 0..initial {
        shared.increment();
    }

    std::thread::scope(|scope| {
        for worker in 0..THREADS {
            let handle = Arc::clone(&shared);
            let incrementing = worker % 2 == 0;
            scope.spawn(move || {
                if incrementing {
                    for _ in 0..OPS_PER_THREAD {
                        handle.increment();
                    }
                } else {
                    for _ in 0..OPS_PER_THREAD {
                        handle.decrement();
                    }
                }
            });
        }
    });

    assert_eq!(shared.load(), initial);
}
/// Spawns many short-lived threads, records the shard index each one is
/// assigned, and asserts that the threads land on several distinct shards
/// rather than collapsing onto a few.
///
/// The required number of distinct shards is capped at `SHARDS`. A set of
/// shard indices can never hold more than `SHARDS` elements, so a fixed
/// threshold larger than `SHARDS` would make the test fail unconditionally.
#[test]
#[cfg_attr(miri, ignore)]
fn test_shard_distribution_across_threads() {
    use std::collections::HashSet;
    use std::sync::Mutex;

    let observed_shards: Arc<Mutex<HashSet<usize>>> = Arc::new(Mutex::new(HashSet::new()));
    let threads: usize = 32;
    // Ask for at least 8 distinct shards, but never more than exist.
    let min_expected: usize = SHARDS.min(8);

    std::thread::scope(|s: &Scope<'_, '_>| {
        for _ in 0..threads {
            let observed: Arc<Mutex<HashSet<usize>>> = Arc::clone(&observed_shards);
            s.spawn(move || {
                let index: usize = ShardedCounter::shard_index();
                observed.lock().unwrap().insert(index);
            });
        }
    });

    let shards_hit: usize = observed_shards.lock().unwrap().len();
    assert!(
        shards_hit >= min_expected,
        "Poor shard distribution: only {shards_hit}/{SHARDS} shards hit by {threads} threads (expected at least {min_expected})"
    );
}
/// `ShardedCounter` must be exactly `SHARDS` cache lines in size, i.e. carry no
/// bytes beyond its padded shards.
#[test]
fn test_sharded_counter_memory_layout() {
    let actual_size = std::mem::size_of::<ShardedCounter>();
    let expected_size: usize = SHARDS * CACHE_LINE_SIZE;
    assert_eq!(
        actual_size, expected_size,
        "ShardedCounter should be {expected_size} bytes ({SHARDS} shards × {CACHE_LINE_SIZE} bytes)"
    );
}
/// A counter built through the `Default` trait starts at zero.
#[test]
fn test_default_trait() {
    let counter = <ShardedCounter as Default>::default();
    let total = counter.load();
    assert_eq!(total, 0);
}
/// After three increments, the `Debug` rendering must mention the type name,
/// a `total` label, and the digit 3.
// NOTE(review): `contains('3')` passes if a '3' appears anywhere in the output,
// not necessarily as the total; consider pinning the exact field format once
// the `Debug` impl's layout is confirmed.
#[test]
fn test_debug_output() {
    let counter = ShardedCounter::new();
    for _ in 0..3 {
        counter.increment();
    }

    let rendered: String = format!("{counter:?}");
    assert!(
        rendered.contains("ShardedCounter"),
        "Debug output should contain type name"
    );
    assert!(
        rendered.contains("total"),
        "Debug output should contain 'total'"
    );
    assert!(rendered.contains('3'), "Debug output should contain count");
}