use crate::{MetricsError, Result};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Atomic `u64` counter that also remembers when it was created, so it can
/// report its age and an average rate since creation.
///
/// Every operation takes `&self`; the value lives in an [`AtomicU64`], so a
/// single `Counter` can be shared across threads (e.g. behind an `Arc`).
// 64-byte alignment — presumably one cache line, so neighbouring counters do
// not false-share. NOTE(review): confirm the intended cache-line size.
#[repr(align(64))]
pub struct Counter {
    // Current count; every read and write goes through atomic operations.
    value: AtomicU64,
    // Creation timestamp; `age()` and the rate calculations measure from here.
    created_at: Instant,
}
/// Point-in-time snapshot of a [`Counter`], as produced by `Counter::stats`.
///
/// Derives `PartialEq` so snapshots can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct CounterStats {
    /// Counter value at the moment the snapshot was taken.
    pub value: u64,
    /// Time elapsed between the counter's creation and the snapshot.
    pub age: Duration,
    /// `value / age` in events per second, or `0.0` when `age` is zero.
    pub rate_per_second: f64,
    /// Running total; `Counter::stats` currently fills this with the same
    /// number as `value`.
    pub total: u64,
}
impl Counter {
    /// Creates a counter starting at zero; its age is measured from this call.
    #[inline]
    pub fn new() -> Self {
        Self::with_value(0)
    }

    /// Creates a counter starting at `initial`; its age is measured from this call.
    #[inline]
    pub fn with_value(initial: u64) -> Self {
        Self {
            value: AtomicU64::new(initial),
            created_at: Instant::now(),
        }
    }

    /// Increments the counter by one.
    ///
    /// Wraps to zero on overflow (atomic `fetch_add` semantics). Use
    /// [`Counter::try_inc`] or [`Counter::saturating_add`] when that matters.
    #[inline(always)]
    pub fn inc(&self) {
        self.value.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the counter by one unless that would overflow.
    ///
    /// # Errors
    /// Returns [`MetricsError::Overflow`] if the counter is already at
    /// `u64::MAX`; the stored value is left unchanged.
    #[inline]
    pub fn try_inc(&self) -> Result<()> {
        self.try_fetch_add(1).map(|_| ())
    }

    /// Adds `amount` to the counter, wrapping on overflow.
    #[inline(always)]
    pub fn add(&self, amount: u64) {
        self.value.fetch_add(amount, Ordering::Relaxed);
    }

    /// Adds `amount` unless the result would exceed `u64::MAX`.
    ///
    /// # Errors
    /// Returns [`MetricsError::Overflow`] on overflow; the stored value is left
    /// unchanged.
    #[inline]
    pub fn try_add(&self, amount: u64) -> Result<()> {
        self.try_fetch_add(amount).map(|_| ())
    }

    /// Returns the current value (relaxed load).
    #[must_use]
    #[inline(always)]
    pub fn get(&self) -> u64 {
        self.value.load(Ordering::Relaxed)
    }

    /// Sets the counter back to zero (sequentially-consistent store).
    #[inline]
    pub fn reset(&self) {
        self.set(0);
    }

    /// Overwrites the counter with `value` (sequentially-consistent store).
    #[inline]
    pub fn set(&self, value: u64) {
        self.value.store(value, Ordering::SeqCst);
    }

    /// Fallible-signature variant of [`Counter::set`]; it never fails.
    #[inline]
    pub fn try_set(&self, value: u64) -> Result<()> {
        self.set(value);
        Ok(())
    }

    /// Atomically replaces the value with `new` if it currently equals `expected`.
    ///
    /// Returns `Ok(previous)` on success, or `Err(actual)` with the value that
    /// was observed instead. Uses `SeqCst` for both outcomes.
    #[inline]
    pub fn compare_and_swap(&self, expected: u64, new: u64) -> core::result::Result<u64, u64> {
        self.value
            .compare_exchange(expected, new, Ordering::SeqCst, Ordering::SeqCst)
    }

    /// Adds `amount` (wrapping on overflow) and returns the value *before* the add.
    #[must_use]
    #[inline]
    pub fn fetch_add(&self, amount: u64) -> u64 {
        self.value.fetch_add(amount, Ordering::Relaxed)
    }

    /// Adds `amount` unless that would overflow, returning the value *before* the add.
    ///
    /// An `amount` of zero performs no write and simply returns the current value.
    ///
    /// # Errors
    /// Returns [`MetricsError::Overflow`] on overflow; the stored value is left
    /// unchanged.
    #[inline]
    pub fn try_fetch_add(&self, amount: u64) -> Result<u64> {
        if amount == 0 {
            return Ok(self.get());
        }
        // `fetch_update` retries the CAS until it wins, or stops without
        // writing as soon as the closure returns `None` (i.e. on overflow).
        self.value
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_add(amount)
            })
            .map_err(|_| MetricsError::Overflow)
    }

    /// Adds `amount` and returns the value *after* the add.
    ///
    /// Wraps on overflow exactly like the underlying atomic, so the returned
    /// value always equals what was stored (no debug-build overflow panic).
    #[must_use]
    #[inline]
    pub fn add_and_get(&self, amount: u64) -> u64 {
        self.value
            .fetch_add(amount, Ordering::Relaxed)
            .wrapping_add(amount)
    }

    /// Increments by one and returns the value *after* the increment (wrapping).
    #[must_use]
    #[inline]
    pub fn inc_and_get(&self) -> u64 {
        self.add_and_get(1)
    }

    /// Increments by one unless that would overflow, returning the new value.
    ///
    /// # Errors
    /// Returns [`MetricsError::Overflow`] if the counter is already at
    /// `u64::MAX`; the stored value is left unchanged.
    #[inline]
    pub fn try_inc_and_get(&self) -> Result<u64> {
        // `try_fetch_add` only succeeds when `previous + 1` fits in a u64.
        self.try_fetch_add(1).map(|previous| previous + 1)
    }

    /// Captures the current value, age and average rate in one snapshot.
    #[must_use]
    pub fn stats(&self) -> CounterStats {
        let value = self.get();
        let age = self.age();
        CounterStats {
            value,
            age,
            rate_per_second: Self::average_rate(value, age),
            total: value,
        }
    }

    /// Time elapsed since the counter was constructed.
    #[must_use]
    #[inline]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Returns `true` when the current value is zero.
    #[must_use]
    #[inline]
    pub fn is_zero(&self) -> bool {
        self.get() == 0
    }

    /// Average increments per second since creation (`0.0` if no time has elapsed).
    #[must_use]
    #[inline]
    pub fn rate_per_second(&self) -> f64 {
        let age = self.age();
        Self::average_rate(self.get(), age)
    }

    /// `value / age` in events per second, or `0.0` for a zero-length `age`
    /// (avoids dividing by zero right after construction).
    #[inline]
    fn average_rate(value: u64, age: Duration) -> f64 {
        let seconds = age.as_secs_f64();
        if seconds > 0.0 {
            value as f64 / seconds
        } else {
            0.0
        }
    }

    /// Adds `amount`, clamping at `u64::MAX` instead of wrapping.
    #[inline]
    pub fn saturating_add(&self, amount: u64) {
        // Returning `None` when the value would not change (amount == 0 or
        // already saturated) ends `fetch_update` without a write, so the
        // resulting `Err` is expected and deliberately ignored.
        let _ = self
            .value
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |current| {
                let next = current.saturating_add(amount);
                if next == current {
                    None
                } else {
                    Some(next)
                }
            });
    }
}
impl Default for Counter {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for Counter {
    /// Renders as `Counter(<current value>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let current = self.get();
        write!(f, "Counter({})", current)
    }
}
impl std::fmt::Debug for Counter {
    /// Shows the live value together with the derived age and average rate.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let value = self.get();
        let age = self.age();
        let rate_per_second = self.rate_per_second();

        let mut builder = f.debug_struct("Counter");
        builder.field("value", &value);
        builder.field("age", &age);
        builder.field("rate_per_second", &rate_per_second);
        builder.finish()
    }
}
impl Counter {
    /// Adds `count` to the counter; a zero `count` skips the atomic operation.
    #[inline]
    pub fn batch_inc(&self, count: usize) {
        if count == 0 {
            return;
        }
        self.add(count as u64);
    }

    /// Increments the counter by one only when `condition` holds.
    #[inline]
    pub fn inc_if(&self, condition: bool) {
        if !condition {
            return;
        }
        self.inc();
    }

    /// Increments by one while the current value is strictly below `max_value`.
    ///
    /// Returns `true` if this call performed the increment, or `false` if the
    /// counter was already at or above `max_value`.
    #[inline]
    pub fn inc_max(&self, max_value: u64) -> bool {
        let mut observed = self.get();
        while observed < max_value {
            // `observed < max_value <= u64::MAX`, so `observed + 1` cannot overflow.
            if self.compare_and_swap(observed, observed + 1).is_ok() {
                return true;
            }
            // Lost a race with another writer: re-read and re-check the cap.
            observed = self.get();
        }
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_basic_operations() {
        let counter = Counter::new();
        assert_eq!(counter.get(), 0);
        assert!(counter.is_zero());
        counter.inc();
        assert_eq!(counter.get(), 1);
        assert!(!counter.is_zero());
        counter.add(5);
        assert_eq!(counter.get(), 6);
        counter.reset();
        assert_eq!(counter.get(), 0);
        counter.set(42);
        assert_eq!(counter.get(), 42);
    }

    #[test]
    fn test_fetch_operations() {
        let counter = Counter::new();
        assert_eq!(counter.fetch_add(10), 0);
        assert_eq!(counter.get(), 10);
        assert_eq!(counter.inc_and_get(), 11);
        assert_eq!(counter.add_and_get(5), 16);
    }

    #[test]
    fn test_compare_and_swap() {
        let counter = Counter::new();
        counter.set(10);
        assert_eq!(counter.compare_and_swap(10, 20), Ok(10));
        assert_eq!(counter.get(), 20);
        assert_eq!(counter.compare_and_swap(10, 30), Err(20));
        assert_eq!(counter.get(), 20);
    }

    #[test]
    fn test_saturating_add() {
        let counter = Counter::new();
        counter.set(u64::MAX - 5);
        counter.saturating_add(10);
        assert_eq!(counter.get(), u64::MAX);
        counter.saturating_add(100);
        assert_eq!(counter.get(), u64::MAX);
    }

    #[test]
    fn test_saturating_add_zero_is_noop() {
        let counter = Counter::with_value(7);
        counter.saturating_add(0);
        assert_eq!(counter.get(), 7);
    }

    #[test]
    fn test_conditional_operations() {
        let counter = Counter::new();
        counter.inc_if(true);
        assert_eq!(counter.get(), 1);
        counter.inc_if(false);
        assert_eq!(counter.get(), 1);
        assert!(counter.inc_max(5));
        assert_eq!(counter.get(), 2);
        counter.set(5);
        assert!(!counter.inc_max(5));
        assert_eq!(counter.get(), 5);
    }

    #[test]
    fn test_inc_max_boundaries() {
        let counter = Counter::with_value(u64::MAX - 1);
        assert!(counter.inc_max(u64::MAX));
        assert!(!counter.inc_max(u64::MAX));
        assert_eq!(counter.get(), u64::MAX);
        assert!(!Counter::new().inc_max(0));
    }

    #[test]
    fn test_statistics() {
        let counter = Counter::new();
        counter.add(100);
        let stats = counter.stats();
        assert_eq!(stats.value, 100);
        assert_eq!(stats.total, 100);
        // `Instant` is monotonic but may have coarse resolution, so the
        // snapshot's age can legitimately be zero; only require that it does
        // not exceed an age measured afterwards.
        assert!(stats.age <= counter.age());
        assert!(stats.rate_per_second >= 0.0);
        assert!(stats.rate_per_second.is_finite());
    }

    #[test]
    fn test_high_concurrency() {
        let counter = Arc::new(Counter::new());
        let num_threads = 100;
        let increments_per_thread = 1000;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    for _ in 0..increments_per_thread {
                        counter.inc();
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert_eq!(counter.get(), num_threads * increments_per_thread);
        let stats = counter.stats();
        assert!(stats.rate_per_second > 0.0);
    }

    #[test]
    fn test_batch_operations() {
        let counter = Counter::new();
        counter.batch_inc(1000);
        assert_eq!(counter.get(), 1000);
        counter.batch_inc(0);
        assert_eq!(counter.get(), 1000);
    }

    #[test]
    fn test_display_and_debug() {
        let counter = Counter::new();
        counter.set(42);
        let display_str = format!("{counter}");
        assert!(display_str.contains("42"));
        let debug_str = format!("{counter:?}");
        assert!(debug_str.contains("Counter"));
        assert!(debug_str.contains("42"));
    }

    #[test]
    fn test_checked_operations_and_overflow_paths() {
        let counter = Counter::new();
        counter.try_set(3).unwrap();
        assert_eq!(counter.get(), 3);
        counter.try_inc().unwrap();
        assert_eq!(counter.get(), 4);
        counter.try_add(0).unwrap();
        assert_eq!(counter.get(), 4);
        assert_eq!(counter.try_fetch_add(2).unwrap(), 4);
        assert_eq!(counter.get(), 6);
        assert_eq!(counter.try_fetch_add(0).unwrap(), 6);
        assert_eq!(counter.try_inc_and_get().unwrap(), 7);
        let overflow = Counter::with_value(u64::MAX);
        assert!(matches!(overflow.try_inc(), Err(MetricsError::Overflow)));
        assert!(matches!(overflow.try_add(1), Err(MetricsError::Overflow)));
        assert!(matches!(
            overflow.try_fetch_add(1),
            Err(MetricsError::Overflow)
        ));
        assert!(matches!(
            overflow.try_inc_and_get(),
            Err(MetricsError::Overflow)
        ));
    }

    #[test]
    fn test_failed_checked_ops_leave_value_unchanged() {
        let counter = Counter::with_value(u64::MAX - 1);
        assert!(matches!(counter.try_add(2), Err(MetricsError::Overflow)));
        assert!(matches!(
            counter.try_fetch_add(5),
            Err(MetricsError::Overflow)
        ));
        assert_eq!(counter.get(), u64::MAX - 1);
        assert_eq!(counter.try_inc_and_get().unwrap(), u64::MAX);
        assert!(matches!(counter.try_inc(), Err(MetricsError::Overflow)));
        assert_eq!(counter.get(), u64::MAX);
    }
}
/// Wall-clock micro-benchmarks written as ordinary `#[test]`s.
///
/// Compiled only for test builds with the `bench-tests` feature enabled, and
/// excluded when the `tarpaulin` cfg is set (per the `cfg` below).
// NOTE(review): the absolute ns/op thresholds asserted here depend on the
// machine, system load and build profile (debug vs release) and may be flaky
// on shared CI runners.
#[cfg(all(test, feature = "bench-tests", not(tarpaulin)))]
// Presumably silences a warning about the explicit `Instant` import, which is
// also reachable through `super::*` — TODO confirm.
#[allow(unused_imports)]
mod benchmarks {
    use super::*;
    use std::time::Instant;

    // NOTE(review): redundant — this module only exists when `bench-tests` is
    // enabled, so `not(feature = "bench-tests")` can never be true here.
    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_counter_increment() {
        let counter = Counter::new();
        // Inferred as u128 because it divides `as_nanos()` (a u128) below.
        let iterations = 10_000_000;
        let start = Instant::now();
        for _ in 0..iterations {
            counter.inc();
        }
        let elapsed = start.elapsed();
        println!(
            "Counter increment: {:.2} ns/op",
            elapsed.as_nanos() as f64 / iterations as f64
        );
        // Integer division: average whole nanoseconds per `inc`.
        assert!(elapsed.as_nanos() / iterations < 100);
        assert_eq!(counter.get(), iterations as u64);
    }

    // NOTE(review): redundant for the same reason as above.
    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_counter_add() {
        let counter = Counter::new();
        // Inferred as u64 from `counter.add(i + 1)`.
        let iterations = 1_000_000;
        let start = Instant::now();
        for i in 0..iterations {
            // Varying amounts (1..=iterations); the total stays far below u64::MAX.
            counter.add(i + 1);
        }
        let elapsed = start.elapsed();
        println!(
            "Counter add: {:.2} ns/op",
            elapsed.as_nanos() as f64 / iterations as f64
        );
        assert!(elapsed.as_nanos() / (iterations as u128) < 200);
    }

    // NOTE(review): redundant for the same reason as above.
    #[cfg_attr(not(feature = "bench-tests"), ignore)]
    #[test]
    fn bench_counter_get() {
        let counter = Counter::new();
        counter.set(42);
        let iterations = 100_000_000;
        let start = Instant::now();
        // Accumulating into `sum` (asserted below) keeps the loads' results in
        // use. NOTE(review): consider `std::hint::black_box` to make it harder
        // for the optimizer to shortcut the loop.
        let mut sum = 0;
        for _ in 0..iterations {
            sum += counter.get();
        }
        let elapsed = start.elapsed();
        println!(
            "Counter get: {:.2} ns/op",
            elapsed.as_nanos() as f64 / iterations as f64
        );
        assert_eq!(sum, 42 * iterations);
        assert!(elapsed.as_nanos() / (iterations as u128) < 50);
    }
}