atomic_destructor/
lib.rs

1// Copyright (c) 2024 Yuki Kishimoto
2// Distributed under the MIT software license
3
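//! Atomic destructor: wraps a value and runs a user-supplied cleanup hook
//! ([`AtomicDestroyer::on_destroy`]) when the last tracked (non-stealth) handle is dropped.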
5
6#![no_std]
7#![forbid(unsafe_code)]
8#![warn(missing_docs)]
9
10extern crate alloc;
11
12use alloc::sync::Arc;
13use core::fmt::{self, Debug};
14use core::ops::{Deref, DerefMut};
15use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
16
/// Saturating increment/decrement helpers for atomic counters.
trait SaturatingUsize {
    /// Atomically adds 1, saturating at `usize::MAX`.
    ///
    /// Returns the value stored after the operation (`usize::MAX` if the
    /// counter was already saturated).
    fn saturating_increment(&self, order: Ordering) -> usize;

    /// Atomically subtracts 1, saturating at `0`.
    ///
    /// Returns the value stored after the operation (`0` if the counter was
    /// already zero). A return of `0` therefore does not by itself tell the
    /// caller whether this call performed the final decrement.
    fn saturating_decrement(&self, order: Ordering) -> usize;
}

/// Returns an ordering that is valid for the load side of a read-modify-write
/// performed with `order`.
///
/// `Release` and `AcqRel` are not valid for plain loads or for the failure
/// case of a compare-exchange (they panic), so they are weakened to their
/// load-side counterparts.
fn load_ordering(order: Ordering) -> Ordering {
    match order {
        Ordering::Release => Ordering::Relaxed,
        Ordering::AcqRel => Ordering::Acquire,
        other => other,
    }
}

impl SaturatingUsize for AtomicUsize {
    fn saturating_increment(&self, order: Ordering) -> usize {
        // `fetch_update` runs the CAS retry loop; `checked_add` returning `None`
        // aborts the update when the counter is already at `usize::MAX`.
        match self.fetch_update(order, load_ordering(order), |current| current.checked_add(1)) {
            Ok(previous) => previous + 1,
            Err(saturated) => saturated,
        }
    }

    fn saturating_decrement(&self, order: Ordering) -> usize {
        // `checked_sub` returning `None` aborts the update when the counter is 0.
        match self.fetch_update(order, load_ordering(order), |current| current.checked_sub(1)) {
            Ok(previous) => previous - 1,
            Err(zero) => zero,
        }
    }
}
64
/// Clone a handle without registering it with a shared instance counter.
pub trait StealthClone {
    /// Clone without incrementing the atomic destructor counter.
    ///
    /// Stealth-cloned items do NOT decrement the counter when dropped either,
    /// so they never count towards the number of live tracked instances.
    fn stealth_clone(&self) -> Self;
}
72
/// Atomic destroyer
///
/// Cleanup hook invoked by [`AtomicDestructor`] when its shared counter of
/// tracked handles reaches zero.
pub trait AtomicDestroyer: Debug + Clone {
    /// Instructions to execute when all tracked (non-stealth) instances have
    /// been dropped.
    fn on_destroy(&self);
}
78
79/// Atomic destructor
80pub struct AtomicDestructor<T>
81where
82    T: AtomicDestroyer,
83{
84    destroyed: Arc<AtomicBool>,
85    counter: Arc<AtomicUsize>,
86    stealth: bool,
87    inner: T,
88}
89
90impl<T> Deref for AtomicDestructor<T>
91where
92    T: AtomicDestroyer,
93{
94    type Target = T;
95
96    fn deref(&self) -> &Self::Target {
97        &self.inner
98    }
99}
100
101impl<T> DerefMut for AtomicDestructor<T>
102where
103    T: AtomicDestroyer,
104{
105    fn deref_mut(&mut self) -> &mut Self::Target {
106        &mut self.inner
107    }
108}
109
110impl<T> Debug for AtomicDestructor<T>
111where
112    T: AtomicDestroyer,
113{
114    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115        f.debug_struct("AtomicDestructor")
116            .field("destroyed", &self.destroyed)
117            .field("counter", &self.counter)
118            .field("stealth", &self.stealth)
119            .field("inner", &self.inner)
120            .finish()
121    }
122}
123
124impl<T> Clone for AtomicDestructor<T>
125where
126    T: AtomicDestroyer,
127{
128    fn clone(&self) -> Self {
129        // Increase counter
130        self.counter.saturating_increment(Ordering::SeqCst);
131
132        // Clone
133        Self {
134            destroyed: self.destroyed.clone(),
135            counter: self.counter.clone(),
136            stealth: false,
137            inner: self.inner.clone(),
138        }
139    }
140}
141
142impl<T> StealthClone for AtomicDestructor<T>
143where
144    T: AtomicDestroyer,
145{
146    fn stealth_clone(&self) -> Self {
147        Self {
148            destroyed: self.destroyed.clone(),
149            counter: self.counter.clone(),
150            stealth: true,
151            inner: self.inner.clone(),
152        }
153    }
154}
155
156impl<T> Drop for AtomicDestructor<T>
157where
158    T: AtomicDestroyer,
159{
160    fn drop(&mut self) {
161        // Stealth or already destroyed, immediately return
162        if self.is_stealth() || self.is_destroyed() {
163            return;
164        }
165
166        // Decrease counter
167        let value: usize = self.counter.saturating_decrement(Ordering::SeqCst);
168
169        // Check if it's time for destruction
170        if value == 0 {
171            // Destroy
172            self.inner.on_destroy();
173
174            // Mark as destroyed
175            self.destroyed.store(true, Ordering::SeqCst);
176        }
177    }
178}
179
180impl<T> AtomicDestructor<T>
181where
182    T: AtomicDestroyer,
183{
184    /// New wrapper
185    pub fn new(inner: T) -> Self {
186        Self {
187            destroyed: Arc::new(AtomicBool::new(false)),
188            counter: Arc::new(AtomicUsize::new(1)),
189            stealth: false,
190            inner,
191        }
192    }
193
194    /// Get counter
195    pub fn counter(&self) -> usize {
196        self.counter.load(Ordering::SeqCst)
197    }
198
199    /// Check if destroyed
200    pub fn is_destroyed(&self) -> bool {
201        self.destroyed.load(Ordering::SeqCst)
202    }
203
204    /// Check if is stealth (stealth cloned, not subject to counter increase/decrease)
205    pub fn is_stealth(&self) -> bool {
206        self.stealth
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// No-op destroyer used to exercise counter bookkeeping.
    #[derive(Debug, Clone)]
    struct InternalTestingStealth;

    impl AtomicDestroyer for InternalTestingStealth {
        fn on_destroy(&self) {}
    }

    #[derive(Clone)]
    struct TestingStealth {
        inner: AtomicDestructor<InternalTestingStealth>,
    }

    impl StealthClone for TestingStealth {
        fn stealth_clone(&self) -> Self {
            Self {
                inner: self.inner.stealth_clone(),
            }
        }
    }

    impl TestingStealth {
        pub fn new() -> Self {
            Self {
                inner: AtomicDestructor::new(InternalTestingStealth),
            }
        }
    }

    /// Destroyer that records how many times `on_destroy` was invoked.
    #[derive(Debug, Clone)]
    struct CountingDestroyer {
        calls: Arc<AtomicUsize>,
    }

    impl AtomicDestroyer for CountingDestroyer {
        fn on_destroy(&self) {
            self.calls.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_clone() {
        let t = TestingStealth::new();
        assert_eq!(t.inner.counter(), 1);
        assert!(!t.inner.is_stealth());
        assert!(!t.inner.is_destroyed());

        let t_1 = t.clone();
        assert_eq!(t.inner.counter(), 2);
        assert_eq!(t_1.inner.counter(), 2);

        let t_2 = t_1.clone();
        assert_eq!(t.inner.counter(), 3);
        assert_eq!(t_1.inner.counter(), 3);
        assert_eq!(t_2.inner.counter(), 3);

        drop(t_1);
        assert_eq!(t.inner.counter(), 2);
        assert!(!t.inner.is_destroyed());

        drop(t_2);
        assert_eq!(t.inner.counter(), 1);
    }

    #[test]
    fn test_stealth_clone() {
        let t = TestingStealth::new();
        assert_eq!(t.inner.counter(), 1);
        assert!(!t.inner.is_stealth());

        let t_1 = t.stealth_clone();
        assert_eq!(t.inner.counter(), 1);
        assert_eq!(t_1.inner.counter(), 1);
        assert!(!t.inner.is_stealth());
        assert!(t_1.inner.is_stealth());

        let t_2 = t_1.clone(); // Cloning stealth destructor
        assert_eq!(t.inner.counter(), 2);
        assert_eq!(t_1.inner.counter(), 2);
        assert_eq!(t_2.inner.counter(), 2);

        let t_3 = t.clone(); // Cloning NON-stealth destructor
        assert_eq!(t.inner.counter(), 3);
        assert_eq!(t_1.inner.counter(), 3);
        assert_eq!(t_2.inner.counter(), 3);

        drop(t_1); // Stealth
        assert_eq!(t.inner.counter(), 3);

        drop(t_2); // Classical
        assert_eq!(t.inner.counter(), 2);

        drop(t_3); // Classical
        assert_eq!(t.inner.counter(), 1);
    }

    #[test]
    fn test_on_destroy_called_once() {
        let calls = Arc::new(AtomicUsize::new(0));
        let d = AtomicDestructor::new(CountingDestroyer {
            calls: Arc::clone(&calls),
        });
        let stealth = d.stealth_clone();
        let d_1 = d.clone();

        // One tracked handle is still alive: no destruction yet.
        drop(d);
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        assert!(!stealth.is_destroyed());

        // Last tracked handle gone: the hook fires and the flag is set.
        drop(d_1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(stealth.is_destroyed());
        assert_eq!(stealth.counter(), 0);

        // A tracked clone made after destruction must not fire the hook again.
        let late = stealth.clone();
        drop(late);
        drop(stealth);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}
297}