1use std::sync::Arc;
5
6use core::sync::atomic::{AtomicBool, AtomicU8, AtomicU16, AtomicU32, AtomicU64, AtomicUsize};
7
8use crate::CloseValue;
9
10#[derive(Default, Debug)]
15pub struct Counter(pub AtomicU64);
16impl Counter {
17 pub const fn new(starting_count: u64) -> Self {
19 Self(AtomicU64::new(starting_count))
20 }
21
22 pub fn increment(&self) {
24 self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
25 }
26
27 pub fn increment_scoped(&self) -> (CounterGuard<'_>, u64) {
30 let count = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
31 (CounterGuard(&self.0), count)
32 }
33
34 pub fn add(&self, i: u64) {
36 self.0.fetch_add(i, std::sync::atomic::Ordering::Relaxed);
37 }
38
39 pub fn set(&self, i: u64) {
41 self.0.store(i, std::sync::atomic::Ordering::SeqCst);
42 }
43
44 pub fn increment_owned(self: &Arc<Self>) -> (OwnedCounterGuard, u64) {
51 let count = self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) + 1;
52 (
53 OwnedCounterGuard {
54 counter: Arc::clone(self),
55 },
56 count,
57 )
58 }
59}
60
61#[must_use]
65pub struct CounterGuard<'a>(&'a AtomicU64);
66
67impl Drop for CounterGuard<'_> {
68 fn drop(&mut self) {
69 self.0
70 .fetch_update(
71 std::sync::atomic::Ordering::Relaxed,
72 std::sync::atomic::Ordering::Relaxed,
73 |v| Some(v.saturating_sub(1)),
74 )
75 .ok();
76 }
77}
78
79#[diagnostic::do_not_recommend]
80impl CloseValue for &CounterGuard<'_> {
81 type Closed = u64;
82
83 fn close(self) -> Self::Closed {
84 self.0.load(std::sync::atomic::Ordering::Relaxed)
85 }
86}
87
88#[diagnostic::do_not_recommend]
89impl CloseValue for CounterGuard<'_> {
90 type Closed = u64;
91
92 fn close(self) -> Self::Closed {
93 (&self).close()
94 }
95}
96
97#[must_use]
104pub struct OwnedCounterGuard {
105 counter: Arc<Counter>,
106}
107
108impl Drop for OwnedCounterGuard {
109 fn drop(&mut self) {
110 self.counter
111 .0
112 .fetch_update(
113 std::sync::atomic::Ordering::Relaxed,
114 std::sync::atomic::Ordering::Relaxed,
115 |v| Some(v.saturating_sub(1)),
116 )
117 .ok();
118 }
119}
120
121#[diagnostic::do_not_recommend]
122impl CloseValue for &OwnedCounterGuard {
123 type Closed = u64;
124
125 fn close(self) -> Self::Closed {
126 self.counter.0.load(std::sync::atomic::Ordering::Relaxed)
127 }
128}
129
130#[diagnostic::do_not_recommend]
131impl CloseValue for OwnedCounterGuard {
132 type Closed = u64;
133
134 fn close(self) -> Self::Closed {
135 (&self).close()
136 }
137}
138
139impl CloseValue for &'_ Counter {
140 type Closed = u64;
141
142 fn close(self) -> Self::Closed {
143 <&AtomicU64>::close(&self.0)
144 }
145}
146
147impl CloseValue for Counter {
148 type Closed = u64;
149
150 fn close(self) -> Self::Closed {
151 self.0.close()
152 }
153}
154
/// Implements `CloseValue` for an atomic type and for shared references to it. In both cases
/// the closed value is the atomic's plain value type (`inner`).
///
/// - A borrowed atomic is read with a relaxed load.
/// - An owned atomic is consumed with `into_inner`. This needs no atomic operation because
///   ownership already rules out concurrent access.
///
/// Usage: `close_value_atomic!(atomic: AtomicU64, inner: u64);`
macro_rules! close_value_atomic {
    (atomic: $atomic: ty, inner: $inner: ty) => {
        impl $crate::CloseValue for &'_ $atomic {
            type Closed = $inner;

            fn close(self) -> Self::Closed {
                // Absolute `::core` path so the expansion does not depend on names
                // in scope at the invocation site.
                self.load(::core::sync::atomic::Ordering::Relaxed)
            }
        }

        impl $crate::CloseValue for $atomic {
            type Closed = $inner;

            fn close(self) -> Self::Closed {
                // `self` is owned, so nothing else can observe it: take the value directly.
                self.into_inner()
            }
        }
    };
}
174
175close_value_atomic!(atomic: AtomicU64, inner: u64);
176close_value_atomic!(atomic: AtomicU32, inner: u32);
177close_value_atomic!(atomic: AtomicU16, inner: u16);
178close_value_atomic!(atomic: AtomicU8, inner: u8);
179close_value_atomic!(atomic: AtomicUsize, inner: usize);
180
181close_value_atomic!(atomic: AtomicBool, inner: bool);
182
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::Ordering;

    use super::*;

    /// Reads the counter's raw value with a relaxed load.
    fn current(counter: &Counter) -> u64 {
        counter.0.load(Ordering::Relaxed)
    }

    #[test]
    fn increment_scoped() {
        let counter = Counter::new(0);
        let (guard, count) = counter.increment_scoped();
        assert_eq!(count, 1);
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    #[test]
    fn increment_scoped_static() {
        static COUNTER: Counter = Counter::new(0);
        let (guard, count) = COUNTER.increment_scoped();
        assert_eq!(count, 1);
        drop(guard);
        assert_eq!(current(&COUNTER), 0);
    }

    #[test]
    fn counter_guard_close_value() {
        let counter = Counter::new(0);
        let (guard, _) = counter.increment_scoped();
        // Closing by reference reports the live count without releasing the guard.
        assert_eq!((&guard).close(), 1);
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    #[test]
    fn owned_counter_guard_increment_and_drop() {
        let counter = Arc::new(Counter::new(0));
        let (guard, count) = counter.increment_owned();
        assert_eq!(count, 1);
        assert_eq!(current(&counter), 1);
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    #[test]
    fn owned_counter_guard_saturates_at_zero() {
        let counter = Arc::new(Counter::new(0));
        let (guard, _) = counter.increment_owned();
        // Reset the counter underneath the live guard; its drop must not wrap below zero.
        counter.0.store(0, Ordering::Relaxed);
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    #[test]
    fn owned_counter_guard_close_value() {
        let counter = Arc::new(Counter::new(0));
        let (guard, _) = counter.increment_owned();
        assert_eq!((&guard).close(), 1);
        drop(guard);
        assert_eq!(current(&counter), 0);
    }

    #[test]
    fn owned_counter_guard_move_across_threads() {
        let counter = Arc::new(Counter::new(0));
        let (guard, count) = counter.increment_owned();
        assert_eq!(count, 1);
        let observer = Arc::clone(&counter);
        std::thread::spawn(move || {
            assert_eq!(current(&observer), 1);
            drop(guard);
        })
        .join()
        .unwrap();
        assert_eq!(current(&counter), 0);
    }

    #[test]
    fn owned_counter_guard_multiple_guards() {
        let counter = Arc::new(Counter::new(0));
        let (first, first_count) = counter.increment_owned();
        let (second, second_count) = counter.increment_owned();
        assert_eq!(first_count, 1);
        assert_eq!(second_count, 2);
        drop(first);
        assert_eq!(current(&counter), 1);
        drop(second);
        assert_eq!(current(&counter), 0);
    }
}