raii_counter/
lib.rs

1//! Rust type for a RAII Counter (counts number of held instances,
2//! decrements count on `Drop`), implemented with `Arc<AtomicUsize>`.
3//!
//! Useful for tracking the number of holders that exist for a handle,
//! the number of transactions that are in flight, etc.
6//!
7//! # Additional Features
//! * [`Counter`]s can have a size, e.g. a [`Counter`] with `size` 4 adds 4
//! to the count, and removes 4 when dropped.
//! * [`NotifyHandle`]s can be used for efficient conditional checking, e.g.
//! if you want to wait until there are no in-flight transactions, see:
12//! [`CounterBuilder::create_notify`] / [`WeakCounterBuilder::create_notify`]
13//! and [`NotifyHandle::wait_until_condition`].
14//!
15//! # Demo
16//!
17//! ```rust
18//! extern crate raii_counter;
19//! use raii_counter::Counter;
20//!
21//! let counter = Counter::builder().build();
22//! assert_eq!(counter.count(), 1);
23//!
24//! let weak = counter.downgrade();
25//! assert_eq!(weak.count(), 0);
26//!
27//! {
28//!     let _counter1 = weak.spawn_upgrade();
29//!     assert_eq!(weak.count(), 1);
30//!     let _counter2 = weak.spawn_upgrade();
31//!     assert_eq!(weak.count(), 2);
32//! }
33//!
34//! assert_eq!(weak.count(), 0);
35//! ```
36
37use notify::NotifySender;
38use std::fmt::{self, Display, Formatter};
39use std::sync::atomic::{AtomicUsize, Ordering};
40use std::sync::Arc;
41
42mod notify;
43
44pub use notify::{NotifyError, NotifyHandle, NotifyTimeoutError};
45
46/// Essentially an AtomicUsize that is clonable and whose count is based
47/// on the number of copies (and their size). The count is automatically updated on Drop.
48///
49/// If you want a weak reference to the counter that doesn't affect the count, see:
50/// [`WeakCounter`].
51#[derive(Debug)]
52pub struct Counter {
53    counter: Arc<AtomicUsize>,
54    notify: Vec<NotifySender>,
55    size: usize,
56}
57
58/// A 'weak' [`Counter`] that does not affect the count.
59#[derive(Clone, Debug)]
60pub struct WeakCounter {
61    counter: Arc<AtomicUsize>,
62    notify: Vec<NotifySender>,
63}
64
65/// A builder for the [`Counter`].
66pub struct CounterBuilder {
67    counter: Arc<AtomicUsize>,
68    size: usize,
69    notify: Vec<NotifySender>,
70}
71
72impl CounterBuilder {
73    /// Change the specified size of the new [`Counter`]. This counter will add
74    /// `size` to the count, and will remove `size` from the count
75    /// when dropped.
76    pub fn size(mut self, v: usize) -> Self {
77        self.size = v;
78        self
79    }
80
81    /// Create a [`NotifyHandle`] with a link to the count of this object. This [`NotifyHandle`] will
82    /// be notified when the value of this count changes.
83    ///
84    /// [`NotifyHandle`]s cannot be associated after creation, since all linked
85    /// [`Counter`] / [`WeakCounter`]s cannot be accounted for.
86    pub fn create_notify(&mut self) -> NotifyHandle {
87        let (handle, sender) = NotifyHandle::new(Arc::clone(&self.counter));
88        self.notify.push(sender);
89        handle
90    }
91
92    /// Create a new [`Counter`].
93    pub fn build(self) -> Counter {
94        self.counter.fetch_add(self.size, Ordering::SeqCst);
95        Counter {
96            counter: self.counter,
97            notify: self.notify,
98            size: self.size,
99        }
100    }
101}
102
103impl Default for CounterBuilder {
104    fn default() -> Self {
105        Self {
106            counter: Arc::new(AtomicUsize::new(0)),
107            size: 1,
108            notify: vec![],
109        }
110    }
111}
112
113impl Counter {
114    /// Create a new default [`CounterBuilder`].
115    pub fn builder() -> CounterBuilder {
116        CounterBuilder::default()
117    }
118
119    /// Consume self (causing the count to decrease by `size`)
120    /// and return a weak reference to the count through a [`WeakCounter`].
121    pub fn downgrade(self) -> WeakCounter {
122        self.spawn_downgrade()
123    }
124
125    /// Create a new [`WeakCounter`] without consuming self.
126    pub fn spawn_downgrade(&self) -> WeakCounter {
127        WeakCounter {
128            notify: self.notify.clone(),
129            counter: Arc::clone(&self.counter),
130        }
131    }
132
133    /// This method is inherently racey. Assume the count will have changed once
134    /// the value is observed.
135    #[inline]
136    pub fn count(&self) -> usize {
137        self.counter.load(Ordering::Acquire)
138    }
139}
140
141impl Clone for Counter {
142    fn clone(&self) -> Self {
143        self.counter.fetch_add(self.size, Ordering::SeqCst);
144        for sender in &self.notify {
145            sender.notify();
146        }
147        Counter {
148            notify: self.notify.clone(),
149            counter: Arc::clone(&self.counter),
150            size: self.size,
151        }
152    }
153}
154
155impl Display for Counter {
156    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
157        write!(f, "Counter(count={})", self.count())
158    }
159}
160
161impl Drop for Counter {
162    fn drop(&mut self) {
163        self.counter.fetch_sub(self.size, Ordering::SeqCst);
164        for sender in &self.notify {
165            sender.notify();
166        }
167    }
168}
169
170/// A builder for the [`WeakCounter`].
171pub struct WeakCounterBuilder {
172    counter: Arc<AtomicUsize>,
173    notify: Vec<NotifySender>,
174}
175
176impl WeakCounterBuilder {
177    /// Create a [`NotifyHandle`] with a link to the count of this object. This [`NotifyHandle`] will
178    /// be notified when the value of this count changes.
179    ///
180    /// [`NotifyHandle`]s cannot be associated after creation, since all linked
181    /// [`Counter`] / [`WeakCounter`]s cannot be accounted for.
182    pub fn create_notify(&mut self) -> NotifyHandle {
183        let (handle, sender) = NotifyHandle::new(Arc::clone(&self.counter));
184        self.notify.push(sender);
185        handle
186    }
187
188    /// Create a new [`WeakCounter`]. This [`WeakCounter`] creates a new count
189    /// with value: 0 since the [`WeakCounter`] has no effect on the count.
190    pub fn build(self) -> WeakCounter {
191        WeakCounter {
192            notify: self.notify,
193            counter: self.counter,
194        }
195    }
196}
197
198impl Default for WeakCounterBuilder {
199    fn default() -> Self {
200        Self {
201            counter: Arc::new(AtomicUsize::new(0)),
202            notify: vec![],
203        }
204    }
205}
206
207impl WeakCounter {
208    /// Create a new default [`WeakCounterBuilder`].
209    pub fn builder() -> WeakCounterBuilder {
210        WeakCounterBuilder::default()
211    }
212
213    /// This method is inherently racey. Assume the count will have changed once
214    /// the value is observed.
215    #[inline]
216    pub fn count(&self) -> usize {
217        self.counter.load(Ordering::Acquire)
218    }
219
220    /// Consumes self, becomes a [`Counter`] of `size` 1.
221    pub fn upgrade(self) -> Counter {
222        self.spawn_upgrade()
223    }
224
225    /// Create a new [`Counter`] with `size` 1 without consuming the
226    /// current [`WeakCounter`].
227    pub fn spawn_upgrade(&self) -> Counter {
228        self.spawn_upgrade_with_size(1)
229    }
230
231    /// Creates a new [`Counter`] with specified size without consuming the
232    /// current [`WeakCounter`].
233    pub fn spawn_upgrade_with_size(&self, size: usize) -> Counter {
234        self.counter.fetch_add(size, Ordering::SeqCst);
235        for sender in &self.notify {
236            sender.notify();
237        }
238        Counter {
239            notify: self.notify.clone(),
240            counter: Arc::clone(&self.counter),
241            size,
242        }
243    }
244}
245
246impl Display for WeakCounter {
247    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
248        write!(f, "WeakCounter(count={})", self.count())
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn it_works() {
        let counter = Counter::builder().build();
        assert_eq!(counter.count(), 1);

        let weak = counter.downgrade();
        assert_eq!(weak.count(), 0);

        {
            let _counter1 = weak.spawn_upgrade();
            assert_eq!(weak.count(), 1);
            let _counter2 = weak.spawn_upgrade();
            assert_eq!(weak.count(), 2);
        }

        assert_eq!(weak.count(), 0);
    }

    #[test]
    fn different_sizes_work() {
        let weak = WeakCounter::builder().build();
        assert_eq!(weak.count(), 0);

        let counter5 = weak.spawn_upgrade_with_size(5);
        assert_eq!(weak.count(), 5);

        {
            let _moved_counter5 = counter5;
            assert_eq!(weak.count(), 5);
            let _counter1 = weak.spawn_upgrade();
            assert_eq!(weak.count(), 6);
        }

        assert_eq!(weak.count(), 0);
    }

    #[test]
    fn counter_with_size_works() {
        let counter = Counter::builder().size(4).build();
        assert_eq!(counter.count(), 4);

        let weak = counter.spawn_downgrade();
        assert_eq!(weak.count(), 4);
        drop(counter);
        assert_eq!(weak.count(), 0);
    }

    /// A clone contributes its `size` to the count.
    /// Dropping each handle removes exactly its own `size`.
    #[test]
    fn clone_adds_size_and_drop_removes_it() {
        let counter = Counter::builder().size(3).build();
        let weak = counter.spawn_downgrade();

        let copy = counter.clone();
        assert_eq!(weak.count(), 6);

        drop(counter);
        assert_eq!(weak.count(), 3);

        drop(copy);
        assert_eq!(weak.count(), 0);
    }

    #[test]
    fn display_shows_count() {
        let counter = Counter::builder().size(2).build();
        let weak = counter.spawn_downgrade();
        assert_eq!(counter.to_string(), "Counter(count=2)");
        assert_eq!(weak.to_string(), "WeakCounter(count=2)");
    }

    #[test]
    fn wait_until_condition_works() {
        run_wait_until_condition_test(|notify| notify.wait_until_condition(|v| v == 10).unwrap());
    }

    #[test]
    fn wait_until_condition_with_timeout_works() {
        run_wait_until_condition_test(|notify| {
            notify
                .wait_until_condition_timeout(|v| v == 10, Duration::from_secs(2))
                .unwrap()
        });
    }

    /// Spawns a thread that, after a short delay, upgrades a fresh
    /// [`WeakCounter`] ten times. Those counters stay alive until the thread is
    /// joined. Meanwhile `notify_fn` receives the linked [`NotifyHandle`] and is
    /// expected to wait for the count to reach 10.
    fn run_wait_until_condition_test(notify_fn: impl Fn(NotifyHandle)) {
        let (weak, notify) = {
            let mut builder = WeakCounter::builder();
            let notify = builder.create_notify();
            (builder.build(), notify)
        };

        let join_handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            let mut counters = vec![];
            for _ in 0..10 {
                counters.push(weak.spawn_upgrade());
            }

            // Return counters from the thread so they
            // never get dropped (at least until the thread
            // gets joined).
            counters
        });

        notify_fn(notify);
        join_handle.join().unwrap();
    }

    /// Run this test to gain more confidence that the notify is not flaky due to
    /// race conditions.
    ///
    /// ```text
    /// cargo test --release -- --ignored --nocapture
    /// ```
    #[test]
    #[ignore]
    fn test_wait_until_condition_always_occurs() {
        let mut i = 0;
        loop {
            wait_until_condition_works();
            println!("[{}] Completed.", i);
            i += 1;
        }
    }

    #[test]
    fn notify_errors_when_all_references_are_dropped() {
        let (weak, notify) = {
            let mut builder = WeakCounter::builder();
            let notify = builder.create_notify();
            (builder.build(), notify)
        };

        let producer = thread::spawn(move || {
            thread::sleep(Duration::from_millis(100));
            let mut counters = vec![];
            for _ in 0..5 {
                counters.push(weak.spawn_upgrade());
            }
            // All references are dropped here, therefore the condition
            // will never be true.
        });

        assert_eq!(
            notify.wait_until_condition(|v| v == 10),
            Err(NotifyError::Disconnected),
        );

        // Join so that a panic in the producer thread fails this test instead
        // of being silently lost with a detached thread.
        producer.join().unwrap();
    }

    #[test]
    fn notify_checks_condition_before_erroring() {
        let (weak, notify) = {
            let mut builder = WeakCounter::builder();
            let notify = builder.create_notify();
            (builder.build(), notify)
        };

        // All counter references are dropped.
        drop(weak);

        // Shouldn't error since the condition is true.
        assert!(notify.wait_until_condition(|v| v == 0).is_ok());
    }

    #[test]
    fn notify_with_timeout_can_timeout() {
        let (weak, notify) = {
            let mut builder = WeakCounter::builder();
            let notify = builder.create_notify();
            (builder.build(), notify)
        };

        assert_eq!(
            notify.wait_until_condition_timeout(|v| v == 10, Duration::from_millis(100)),
            Err(NotifyTimeoutError::Timeout)
        );

        // Counters are not dropped until here.
        drop(weak);
    }
}