1use crate::{MetricsError, Result};
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::{Duration, Instant};
16
/// Lock-free `u64` counter.
///
/// The count lives in an [`AtomicU64`], so every operation on the counter takes
/// `&self` and one instance can be shared between threads. The construction
/// instant is retained so the counter can report its age and an average rate.
///
/// NOTE(review): the 64-byte alignment presumably matches a common cache-line
/// size so that adjacent counters do not share a line — confirm intent.
#[repr(align(64))]
pub struct Counter {
    /// Instant at which the counter was constructed; basis for age/rate reporting.
    created_at: Instant,
    /// The current count.
    value: AtomicU64,
}
28
/// Point-in-time snapshot of a [`Counter`], as returned by `Counter::stats`.
///
/// `PartialEq` is derived so snapshots can be compared directly (for example in
/// tests). `Eq`/`Hash` are not possible because of the `f64` field.
#[derive(Debug, Clone, PartialEq)]
pub struct CounterStats {
    /// Counter value at the moment the snapshot was taken.
    pub value: u64,
    /// Time elapsed since the counter was constructed.
    pub age: Duration,
    /// `value` divided by `age` in seconds, or `0.0` when no time has elapsed.
    pub rate_per_second: f64,
    /// The same reading as `value`, exposed under a second name.
    pub total: u64,
}
41
42impl Counter {
43    #[inline]
45    pub fn new() -> Self {
46        Self {
47            value: AtomicU64::new(0),
48            created_at: Instant::now(),
49        }
50    }
51
52    #[inline]
54    pub fn with_value(initial: u64) -> Self {
55        Self {
56            value: AtomicU64::new(initial),
57            created_at: Instant::now(),
58        }
59    }
60
61    #[inline(always)]
68    pub fn inc(&self) {
69        self.value.fetch_add(1, Ordering::Relaxed);
70    }
71
72    #[inline(always)]
86    pub fn try_inc(&self) -> Result<()> {
87        let current = self.value.load(Ordering::Relaxed);
88        if current == u64::MAX {
89            return Err(MetricsError::Overflow);
90        }
91        self.value.fetch_add(1, Ordering::Relaxed);
92        Ok(())
93    }
94
95    #[inline(always)]
100    pub fn add(&self, amount: u64) {
101        self.value.fetch_add(amount, Ordering::Relaxed);
102    }
103
104    #[inline(always)]
117    pub fn try_add(&self, amount: u64) -> Result<()> {
118        if amount == 0 {
119            return Ok(());
120        }
121        let current = self.value.load(Ordering::Relaxed);
122        if current.checked_add(amount).is_none() {
123            return Err(MetricsError::Overflow);
124        }
125        self.value.fetch_add(amount, Ordering::Relaxed);
126        Ok(())
127    }
128
129    #[inline(always)]
131    pub fn get(&self) -> u64 {
132        self.value.load(Ordering::Relaxed)
133    }
134
135    #[inline]
139    pub fn reset(&self) {
140        self.value.store(0, Ordering::SeqCst);
141    }
142
143    #[inline]
147    pub fn set(&self, value: u64) {
148        self.value.store(value, Ordering::SeqCst);
149    }
150
151    #[inline]
155    pub fn try_set(&self, value: u64) -> Result<()> {
156        self.set(value);
157        Ok(())
158    }
159
160    #[inline]
164    pub fn compare_and_swap(&self, expected: u64, new: u64) -> core::result::Result<u64, u64> {
165        match self
166            .value
167            .compare_exchange(expected, new, Ordering::SeqCst, Ordering::SeqCst)
168        {
169            Ok(prev) => Ok(prev),
170            Err(current) => Err(current),
171        }
172    }
173
174    #[inline]
176    pub fn fetch_add(&self, amount: u64) -> u64 {
177        self.value.fetch_add(amount, Ordering::Relaxed)
178    }
179
180    #[inline]
193    pub fn try_fetch_add(&self, amount: u64) -> Result<u64> {
194        if amount == 0 {
195            return Ok(self.get());
196        }
197        let current = self.value.load(Ordering::Relaxed);
198        if current.checked_add(amount).is_none() {
199            return Err(MetricsError::Overflow);
200        }
201        Ok(self.value.fetch_add(amount, Ordering::Relaxed))
202    }
203
204    #[inline]
206    pub fn add_and_get(&self, amount: u64) -> u64 {
207        self.value.fetch_add(amount, Ordering::Relaxed) + amount
208    }
209
210    #[inline]
212    pub fn inc_and_get(&self) -> u64 {
213        self.value.fetch_add(1, Ordering::Relaxed) + 1
214    }
215
216    #[inline]
229    pub fn try_inc_and_get(&self) -> Result<u64> {
230        let current = self.value.load(Ordering::Relaxed);
231        let new_val = current.checked_add(1).ok_or(MetricsError::Overflow)?;
232        let prev = self.value.fetch_add(1, Ordering::Relaxed);
233        debug_assert_eq!(prev, current);
234        Ok(new_val)
235    }
236
237    pub fn stats(&self) -> CounterStats {
239        let value = self.get();
240        let age = self.created_at.elapsed();
241        let age_seconds = age.as_secs_f64();
242
243        let rate_per_second = if age_seconds > 0.0 {
244            value as f64 / age_seconds
245        } else {
246            0.0
247        };
248
249        CounterStats {
250            value,
251            age,
252            rate_per_second,
253            total: value,
254        }
255    }
256
257    #[inline]
259    pub fn age(&self) -> Duration {
260        self.created_at.elapsed()
261    }
262
263    #[inline]
265    pub fn is_zero(&self) -> bool {
266        self.get() == 0
267    }
268
269    #[inline]
271    pub fn rate_per_second(&self) -> f64 {
272        let age_seconds = self.age().as_secs_f64();
273        if age_seconds > 0.0 {
274            self.get() as f64 / age_seconds
275        } else {
276            0.0
277        }
278    }
279
280    #[inline]
282    pub fn saturating_add(&self, amount: u64) {
283        loop {
284            let current = self.get();
285            let new_value = current.saturating_add(amount);
286
287            if new_value == current {
289                break;
290            }
291
292            match self.compare_and_swap(current, new_value) {
294                Ok(_) => break,
295                Err(_) => continue, }
297        }
298    }
299}
300
301impl Default for Counter {
302    #[inline]
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
308impl std::fmt::Display for Counter {
309    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
310        write!(f, "Counter({})", self.get())
311    }
312}
313
314impl std::fmt::Debug for Counter {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        f.debug_struct("Counter")
317            .field("value", &self.get())
318            .field("age", &self.age())
319            .field("rate_per_second", &self.rate_per_second())
320            .finish()
321    }
322}
323
324unsafe impl Send for Counter {}
326unsafe impl Sync for Counter {}
327
328impl Counter {
330    #[inline]
334    pub fn batch_inc(&self, count: usize) {
335        if count > 0 {
336            self.add(count as u64);
337        }
338    }
339
340    #[inline]
342    pub fn inc_if(&self, condition: bool) {
343        if condition {
344            self.inc();
345        }
346    }
347
348    #[inline]
350    pub fn inc_max(&self, max_value: u64) -> bool {
351        loop {
352            let current = self.get();
353            if current >= max_value {
354                return false;
355            }
356
357            match self.compare_and_swap(current, current + 1) {
358                Ok(_) => return true,
359                Err(_) => continue, }
361        }
362    }
363}
364
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    // Plain (infallible, wrapping) mutators and accessors.
    #[test]
    fn test_basic_operations() {
        let counter = Counter::new();

        assert_eq!(counter.get(), 0);
        assert!(counter.is_zero());

        counter.inc();
        assert_eq!(counter.get(), 1);
        assert!(!counter.is_zero());

        counter.add(5);
        assert_eq!(counter.get(), 6);

        counter.reset();
        assert_eq!(counter.get(), 0);

        counter.set(42);
        assert_eq!(counter.get(), 42);
    }

    // `fetch_add` returns the previous value; `*_and_get` return the new one.
    #[test]
    fn test_fetch_operations() {
        let counter = Counter::new();

        assert_eq!(counter.fetch_add(10), 0);
        assert_eq!(counter.get(), 10);

        assert_eq!(counter.inc_and_get(), 11);
        assert_eq!(counter.add_and_get(5), 16);
    }

    #[test]
    fn test_compare_and_swap() {
        let counter = Counter::new();
        counter.set(10);

        // Matching `expected`: swap succeeds and reports the previous value.
        assert_eq!(counter.compare_and_swap(10, 20), Ok(10));
        assert_eq!(counter.get(), 20);

        // Stale `expected`: swap fails and reports the actual value, unchanged.
        assert_eq!(counter.compare_and_swap(10, 30), Err(20));
        assert_eq!(counter.get(), 20);
    }

    #[test]
    fn test_saturating_add() {
        let counter = Counter::new();
        counter.set(u64::MAX - 5);

        counter.saturating_add(10);
        assert_eq!(counter.get(), u64::MAX);

        // Already saturated: stays at the maximum.
        counter.saturating_add(100);
        assert_eq!(counter.get(), u64::MAX);
    }

    // Checked operations at the `u64::MAX` boundary must fail without modifying
    // the stored value. Only `is_ok`/`is_err`/`ok()` are used so this test does
    // not rely on `MetricsError` implementing `PartialEq`.
    #[test]
    fn test_checked_operations_at_boundary() {
        let counter = Counter::with_value(u64::MAX - 1);

        assert_eq!(counter.try_inc_and_get().ok(), Some(u64::MAX));
        assert!(counter.try_inc().is_err());
        assert!(counter.try_add(1).is_err());
        assert!(counter.try_add(0).is_ok());
        assert!(counter.try_fetch_add(1).is_err());
        assert_eq!(counter.try_fetch_add(0).ok(), Some(u64::MAX));
        assert!(counter.try_inc_and_get().is_err());
        assert_eq!(counter.get(), u64::MAX);
    }

    #[test]
    fn test_conditional_operations() {
        let counter = Counter::new();

        counter.inc_if(true);
        assert_eq!(counter.get(), 1);

        counter.inc_if(false);
        assert_eq!(counter.get(), 1);

        // Below the cap: increments.
        assert!(counter.inc_max(5));
        assert_eq!(counter.get(), 2);

        // At the cap: refuses and leaves the value alone.
        counter.set(5);
        assert!(!counter.inc_max(5));
        assert_eq!(counter.get(), 5);
    }

    #[test]
    fn test_statistics() {
        let counter = Counter::new();
        counter.add(100);

        let stats = counter.stats();
        assert_eq!(stats.value, 100);
        assert_eq!(stats.total, 100);
        // `Instant` is monotonic, so the snapshot's age cannot exceed a later reading.
        // (Asserting `age > 0` is flaky on platforms with a coarse clock.)
        assert!(stats.age <= counter.age());
        assert!(stats.rate_per_second >= 0.0);
    }

    #[test]
    fn test_high_concurrency() {
        let counter = Arc::new(Counter::new());
        let num_threads = 100;
        let increments_per_thread = 1000;

        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let counter = Arc::clone(&counter);
                thread::spawn(move || {
                    for _ in 0..increments_per_thread {
                        counter.inc();
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(counter.get(), num_threads * increments_per_thread);

        let stats = counter.stats();
        assert!(stats.rate_per_second > 0.0);
    }

    #[test]
    fn test_batch_operations() {
        let counter = Counter::new();

        counter.batch_inc(1000);
        assert_eq!(counter.get(), 1000);

        // Zero-sized batch is a no-op.
        counter.batch_inc(0);
        assert_eq!(counter.get(), 1000);
    }

    #[test]
    fn test_display_and_debug() {
        let counter = Counter::new();
        counter.set(42);

        let display_str = format!("{}", counter);
        assert!(display_str.contains("42"));

        let debug_str = format!("{counter:?}");
        assert!(debug_str.contains("Counter"));
        assert!(debug_str.contains("42"));
    }
}
513
// Timing smoke-tests; compiled only with `--features bench-tests` (and skipped under
// tarpaulin coverage). The module-level `cfg` is the only gate — the per-test
// `cfg_attr(not(feature = "bench-tests"), ignore)` could never fire inside it.
#[cfg(all(test, feature = "bench-tests", not(tarpaulin)))]
// NOTE(review): reason for this allow is not visible here — confirm it is still needed.
#[allow(unused_imports)]
mod benchmarks {
    use super::*;
    use std::time::Instant;

    // Iteration counts are explicitly `u64`: left to inference, a divisor used with
    // `as_nanos()` (a `u128`) turns the timed loop counter into a `u128`, adding
    // loop overhead to the very thing being measured.
    #[test]
    fn bench_counter_increment() {
        let counter = Counter::new();
        let iterations: u64 = 10_000_000;

        let start = Instant::now();
        for _ in 0..iterations {
            counter.inc();
        }
        let elapsed = start.elapsed();

        println!(
            "Counter increment: {:.2} ns/op",
            elapsed.as_nanos() as f64 / iterations as f64
        );

        assert!(elapsed.as_nanos() / u128::from(iterations) < 100);
        assert_eq!(counter.get(), iterations);
    }

    #[test]
    fn bench_counter_add() {
        let counter = Counter::new();
        let iterations: u64 = 1_000_000;

        let start = Instant::now();
        for i in 0..iterations {
            counter.add(i + 1);
        }
        let elapsed = start.elapsed();

        println!(
            "Counter add: {:.2} ns/op",
            elapsed.as_nanos() as f64 / iterations as f64
        );

        assert!(elapsed.as_nanos() / u128::from(iterations) < 200);
        // 1 + 2 + ... + n == n(n + 1) / 2 — confirms every add landed.
        assert_eq!(counter.get(), iterations * (iterations + 1) / 2);
    }

    #[test]
    fn bench_counter_get() {
        let counter = Counter::new();
        counter.set(42);
        let iterations: u64 = 100_000_000;

        let start = Instant::now();
        // Accumulating the reads gives the loop an observable result to keep.
        let mut sum: u64 = 0;
        for _ in 0..iterations {
            sum += counter.get();
        }
        let elapsed = start.elapsed();

        println!(
            "Counter get: {:.2} ns/op",
            elapsed.as_nanos() as f64 / iterations as f64
        );

        assert_eq!(sum, 42 * iterations);

        assert!(elapsed.as_nanos() / u128::from(iterations) < 50);
    }
}