Skip to main content

metrique_core/
atomics.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use core::sync::atomic::{AtomicBool, AtomicU8, AtomicU16, AtomicU32, AtomicU64, AtomicUsize};
7
8use crate::CloseValue;
9
10/// A thin wrapper around `AtomicU64` that implements [`CloseValue`](crate::CloseValue).
11///
12/// This is provided for convenience to avoid the need to specify an ordering. However,
13/// all other atomics also implement [`CloseValue`] and can be used directly.
14#[derive(Default, Debug)]
15pub struct Counter(pub AtomicU64);
16impl Counter {
17    /// Create a new [`Counter`], initialized a specific value
18    pub const fn new(starting_count: u64) -> Self {
19        Self(AtomicU64::new(starting_count))
20    }
21
22    /// Add 1 to this counter
23    pub fn increment(&self) {
24        self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
25    }
26
27    /// Increments the count by 1, returning a guard that decrements the count
28    /// on drop, and the new value. Useful for tracking in-flight operations.
29    pub fn increment_scoped(&self) -> (CounterGuard<'_>, u64) {
30        let count = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
31        (CounterGuard(&self.0), count)
32    }
33
34    /// Increase the value of this counter by `i`
35    pub fn add(&self, i: u64) {
36        self.0.fetch_add(i, std::sync::atomic::Ordering::Relaxed);
37    }
38
39    /// Set this counter to `i`, discarding the previous value
40    pub fn set(&self, i: u64) {
41        self.0.store(i, std::sync::atomic::Ordering::SeqCst);
42    }
43
44    /// Increments the count by 1, returning an owned guard that decrements the
45    /// count on drop, and the new value.
46    ///
47    /// Unlike [`increment_scoped`](Self::increment_scoped), the returned
48    /// [`OwnedCounterGuard`] can be moved across async boundaries or stored
49    /// in structs without lifetime constraints.
50    pub fn increment_owned(self: &Arc<Self>) -> (OwnedCounterGuard, u64) {
51        let count = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
52        (
53            OwnedCounterGuard {
54                counter: Arc::clone(self),
55            },
56            count,
57        )
58    }
59}
60
/// A guard that decrements a [`Counter`] when dropped.
///
/// Returned by [`Counter::increment_scoped`]. The guard holds a borrow of the
/// counter's underlying atomic, so it cannot outlive the [`Counter`] it came from.
///
/// Implements [`Debug`](std::fmt::Debug) so that types holding a guard can
/// derive it as well.
#[must_use]
#[derive(Debug)]
pub struct CounterGuard<'a>(&'a AtomicU64);
66
67impl Drop for CounterGuard<'_> {
68    fn drop(&mut self) {
69        self.0
70            .fetch_update(
71                std::sync::atomic::Ordering::Relaxed,
72                std::sync::atomic::Ordering::Relaxed,
73                |v| Some(v.saturating_sub(1)),
74            )
75            .ok();
76    }
77}
78
79#[diagnostic::do_not_recommend]
80impl CloseValue for &CounterGuard<'_> {
81    type Closed = u64;
82
83    fn close(self) -> Self::Closed {
84        self.0.load(std::sync::atomic::Ordering::Relaxed)
85    }
86}
87
88#[diagnostic::do_not_recommend]
89impl CloseValue for CounterGuard<'_> {
90    type Closed = u64;
91
92    fn close(self) -> Self::Closed {
93        (&self).close()
94    }
95}
96
97/// An owned guard that decrements a [`Counter`] when dropped.
98///
99/// Unlike [`CounterGuard`], this guard can be moved across async boundaries
100/// or stored in structs without lifetime constraints.
101///
102/// Returned by [`Counter::increment_owned`].
103#[must_use]
104pub struct OwnedCounterGuard {
105    counter: Arc<Counter>,
106}
107
108impl Drop for OwnedCounterGuard {
109    fn drop(&mut self) {
110        self.counter
111            .0
112            .fetch_update(
113                std::sync::atomic::Ordering::Relaxed,
114                std::sync::atomic::Ordering::Relaxed,
115                |v| Some(v.saturating_sub(1)),
116            )
117            .ok();
118    }
119}
120
121#[diagnostic::do_not_recommend]
122impl CloseValue for &OwnedCounterGuard {
123    type Closed = u64;
124
125    fn close(self) -> Self::Closed {
126        self.counter.0.load(std::sync::atomic::Ordering::Relaxed)
127    }
128}
129
130#[diagnostic::do_not_recommend]
131impl CloseValue for OwnedCounterGuard {
132    type Closed = u64;
133
134    fn close(self) -> Self::Closed {
135        (&self).close()
136    }
137}
138
139impl CloseValue for &'_ Counter {
140    type Closed = u64;
141
142    fn close(self) -> Self::Closed {
143        <&AtomicU64>::close(&self.0)
144    }
145}
146
147impl CloseValue for Counter {
148    type Closed = u64;
149
150    fn close(self) -> Self::Closed {
151        self.0.close()
152    }
153}
154
155macro_rules! close_value_atomic {
156    (atomic: $atomic: ty, inner: $inner: ty) => {
157        impl $crate::CloseValue for &'_ $atomic {
158            type Closed = $inner;
159
160            fn close(self) -> Self::Closed {
161                self.load(std::sync::atomic::Ordering::Relaxed)
162            }
163        }
164
165        impl $crate::CloseValue for $atomic {
166            type Closed = $inner;
167
168            fn close(self) -> Self::Closed {
169                self.load(std::sync::atomic::Ordering::Relaxed)
170            }
171        }
172    };
173}
174
175close_value_atomic!(atomic: AtomicU64, inner: u64);
176close_value_atomic!(atomic: AtomicU32, inner: u32);
177close_value_atomic!(atomic: AtomicU16, inner: u16);
178close_value_atomic!(atomic: AtomicU8, inner: u8);
179close_value_atomic!(atomic: AtomicUsize, inner: usize);
180
181close_value_atomic!(atomic: AtomicBool, inner: bool);
182
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::Ordering::Relaxed;

    use super::*;

    /// Reads the raw value currently stored in `counter`.
    fn current(counter: &Counter) -> u64 {
        counter.0.load(Relaxed)
    }

    /// A scoped guard reports the incremented value and undoes it on drop.
    #[test]
    fn increment_scoped() {
        let counter = Counter::new(0);

        let (guard, count) = counter.increment_scoped();
        assert_eq!(count, 1);

        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    /// Scoped guards can also be taken from a counter stored in a `static`.
    #[test]
    fn increment_scoped_static() {
        static COUNTER: Counter = Counter::new(0);

        let (guard, count) = COUNTER.increment_scoped();
        assert_eq!(count, 1);

        drop(guard);
        assert_eq!(current(&COUNTER), 0);
    }

    /// Closing a borrowed scoped guard reads the count without decrementing it.
    #[test]
    fn counter_guard_close_value() {
        let counter = Counter::new(0);
        let (guard, _) = counter.increment_scoped();

        // The guard's own increment is still in place, so the closed value is 1.
        assert_eq!((&guard).close(), 1);

        // Dropping the guard afterwards still performs the decrement.
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    /// An owned guard reports the incremented value and undoes it on drop.
    #[test]
    fn owned_counter_guard_increment_and_drop() {
        let counter = Arc::new(Counter::new(0));

        let (guard, count) = counter.increment_owned();
        assert_eq!(count, 1);
        assert_eq!(current(&counter), 1);

        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    /// Dropping an owned guard never takes the counter below zero.
    #[test]
    fn owned_counter_guard_saturates_at_zero() {
        let counter = Arc::new(Counter::new(0));
        let (guard, _) = counter.increment_owned();

        // Reset the counter behind the guard's back so that the drop has to saturate.
        counter.0.store(0, Relaxed);

        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    /// Closing a borrowed owned guard reads the count without decrementing it.
    #[test]
    fn owned_counter_guard_close_value() {
        let counter = Arc::new(Counter::new(0));
        let (guard, _) = counter.increment_owned();

        assert_eq!((&guard).close(), 1);

        // Dropping the guard afterwards still performs the decrement.
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    /// An owned guard can be moved into another thread and dropped there.
    #[test]
    fn owned_counter_guard_move_across_threads() {
        let counter = Arc::new(Counter::new(0));
        let (guard, count) = counter.increment_owned();
        assert_eq!(count, 1);

        let observer = Arc::clone(&counter);
        std::thread::spawn(move || {
            assert_eq!(current(&observer), 1);
            drop(guard);
        })
        .join()
        .unwrap();

        assert_eq!(current(&counter), 0);
    }

    /// Several owned guards each account for exactly one increment.
    #[test]
    fn owned_counter_guard_multiple_guards() {
        let counter = Arc::new(Counter::new(0));

        let (first, first_count) = counter.increment_owned();
        let (second, second_count) = counter.increment_owned();
        assert_eq!(first_count, 1);
        assert_eq!(second_count, 2);

        drop(first);
        assert_eq!(current(&counter), 1);

        drop(second);
        assert_eq!(current(&counter), 0);
    }
}