atomic_destroy/
lib.rs

1#![cfg_attr(not(test), no_std)]
2//! Atomically destroyable types.
3//!
4//! # Examples
5//! ```rust
6//! # use atomic_destroy::AtomicDestroy;
7//! let value = AtomicDestroy::new(Box::new(5));
8//! assert_eq!(**value.get().unwrap(), 5);
9//! value.destroy();
10//! // The Box's destructor is run here.
11//! assert!(value.get().is_none());
12//! ```
13#![warn(clippy::pedantic, clippy::cargo)]
14
15use core::cell::UnsafeCell;
16use core::marker::PhantomData;
17use core::mem::MaybeUninit;
18use core::ops::Deref;
19use core::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
20use core::ptr::drop_in_place;
21
/// An atomically destroyable value.
///
/// The value can be destroyed through a shared reference with [`AtomicDestroy::destroy`]; if it
/// is being used at that moment, its destructor is deferred until the last user is done.
#[derive(Debug)]
pub struct AtomicDestroy<T> {
    /// The number of [`Value`] guards currently registered on the value, including guards that
    /// are about to report `None`. When this is 0 and `drop_state` is 1, the value is dropped.
    held_count: AtomicUsize,
    /// The destruction state of `value`:
    ///
    /// - `0`: alive; new guards may read it.
    /// - `1`: destruction requested; it is dropped as soon as nobody holds it.
    /// - `2`: already dropped (or never initialised), so it must never be dropped again, to
    ///   avoid a double free.
    drop_state: AtomicU8,
    /// The value itself. It may only be read while `drop_state` is 0; once the state is 2 it
    /// has been (or is being) dropped, or was never initialised.
    value: UnsafeCell<MaybeUninit<T>>,
}
35
36impl<T> AtomicDestroy<T> {
37    /// Create a new atomically destroyable value.
38    #[must_use]
39    pub const fn new(value: T) -> Self {
40        Self {
41            held_count: AtomicUsize::new(0),
42            drop_state: AtomicU8::new(0),
43            value: UnsafeCell::new(MaybeUninit::new(value)),
44        }
45    }
46
47    /// Create an atomically destroyable value that has already been dropped.
48    #[must_use]
49    pub const fn empty() -> Self {
50        Self {
51            held_count: AtomicUsize::new(0),
52            drop_state: AtomicU8::new(2),
53            value: UnsafeCell::new(MaybeUninit::uninit()),
54        }
55    }
56
57    /// Create an atomically destroyable value from an `Option<T>`.
58    #[must_use]
59    pub fn maybe_new(value: Option<T>) -> Self {
60        match value {
61            Some(v) => Self::new(v),
62            None => Self::empty(),
63        }
64    }
65
66    /// Get the value if it hasn't been destroyed.
67    pub fn get(&self) -> Option<Value<T, &Self>> {
68        Value::new(self)
69    }
70
71    /// Run a function using the value.
72    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> Option<R> {
73        self.get().map(|v| f(&*v))
74    }
75
76    /// Destroy the value. If someone is currently using the value the destructor will be run when
77    /// they are done.
78    pub fn destroy(&self) {
79        if self.drop_state.compare_and_swap(0, 1, Ordering::SeqCst) == 0
80            && self.held_count.load(Ordering::SeqCst) == 0
81            && self.drop_state.swap(2, Ordering::SeqCst) == 1
82        {
83            // SAFETY: This code is only run if `drop_state` was zero. As this code sets it to one
84            // and nothing else can set it back, this block can only be run once.
85            //
86            // If we also know that `held_count` is zero then no code can currently be reading the
87            // value. Moreover, no code in the future can read from the value as `drop_state` is
88            // permanently nonzero.
89            unsafe {
90                self.drop_value();
91            }
92        }
93    }
94
95    /// Drop the value, not checking if anyone else is using it.
96    ///
97    /// # Safety
98    ///
99    /// This function must only be called once, and `value` must not be accessed from this point
100    /// onwards.
101    unsafe fn drop_value(&self) {
102        drop_in_place((*self.value.get()).as_mut_ptr());
103    }
104}
105
106// These can probably be relaxed but I want to play it safe
107unsafe impl<T: Send + Sync> Send for AtomicDestroy<T> {}
108unsafe impl<T: Send + Sync> Sync for AtomicDestroy<T> {}
109
110impl<T> Drop for AtomicDestroy<T> {
111    fn drop(&mut self) {
112        if self.drop_state.load(Ordering::SeqCst) < 2 {
113            // SAFETY: We have unique access and the value is about to be destroyed.
114            unsafe { self.drop_value() };
115        }
116    }
117}
118
119impl<T: Clone> Clone for AtomicDestroy<T> {
120    fn clone(&self) -> Self {
121        Self::maybe_new(self.get().as_deref().cloned())
122    }
123}
124
/// A "locked" value of an `AtomicDestroy`. While one of these exists the value inside the
/// `AtomicDestroy` is guaranteed not to be dropped.
///
/// Obtained from [`AtomicDestroy::get`] (with `R = &AtomicDestroy<T>`) or [`Value::new`]; it
/// dereferences to the inner `T`.
///
/// NOTE(review): `inner` is dereferenced afresh every time the count, the state or the value is
/// accessed, so correctness assumes `R`'s `Deref` always returns the same `AtomicDestroy`.
/// `Deref` is a safe trait, so an inconsistent implementation could break this guarantee —
/// consider restricting `R`.
#[derive(Debug)]
pub struct Value<T, R: Deref<Target = AtomicDestroy<T>>> {
    /// Handle to the `AtomicDestroy` whose `held_count` this guard has incremented.
    inner: R,
    /// Required because `T` otherwise only appears in `R`'s bound.
    phantom: PhantomData<T>,
}
132
133impl<T, R: Deref<Target = AtomicDestroy<T>>> Value<T, R> {
134    /// Get the value of an atomic destroyable. Equivalent to `AtomicDestroy::get`.
135    pub fn new(inner: R) -> Option<Self> {
136        // Prematurely make sure that the value won't be dropped.
137        inner.held_count.fetch_add(1, Ordering::SeqCst);
138
139        // Created here so that the destructor is always run.
140        let this = Self {
141            inner,
142            phantom: PhantomData,
143        };
144
145        if this.inner.drop_state.load(Ordering::SeqCst) > 0 {
146            // The value is dropped or is attempting to drop. Don't interfere.
147            None
148        } else {
149            Some(this)
150        }
151    }
152}
153
154impl<T, R: Deref<Target = AtomicDestroy<T>>> Deref for Value<T, R> {
155    type Target = T;
156
157    fn deref(&self) -> &Self::Target {
158        // SAFETY: Held count is guaranteed to be >0 here, and so the value cannot be dropped.
159        unsafe { &*(*self.inner.value.get()).as_ptr() }
160    }
161}
162
163impl<T, R: Deref<Target = AtomicDestroy<T>>> Drop for Value<T, R> {
164    fn drop(&mut self) {
165        if self.inner.held_count.fetch_sub(1, Ordering::SeqCst) == 1
166            && self
167                .inner
168                .drop_state
169                .compare_and_swap(1, 2, Ordering::SeqCst)
170                == 1
171        {
172            // SAFETY: This can only happen when the value has not been dropped yet, as `drop_state`
173            // is still 1.
174            //
175            // We also know that there are no other readers as `held_count` is zero.
176            unsafe {
177                self.inner.drop_value();
178            }
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use crate::AtomicDestroy;
    use std::iter;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::{Duration, Instant};

    // Boxes are used here to better catch double frees.

    /// Increments a shared counter when dropped, so tests can check the destructor runs
    /// exactly once and at the right moment.
    struct DropCounter(Arc<AtomicUsize>);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns the shared drop count together with a `DropCounter` reporting to it.
    fn drop_counter() -> (Arc<AtomicUsize>, DropCounter) {
        let count = Arc::new(AtomicUsize::new(0));
        let tracker = DropCounter(Arc::clone(&count));
        (count, tracker)
    }

    #[test]
    fn test_simple() {
        let value = AtomicDestroy::new(Box::new(5));
        assert_eq!(**value.get().unwrap(), 5);
        assert_eq!(**value.get().unwrap(), 5);
        value.destroy();
        assert!(value.get().is_none());
    }

    #[test]
    fn test_keep_alive() {
        let value = AtomicDestroy::new(Box::new(5));
        let contents_1 = value.get().unwrap();
        let contents_2 = value.get().unwrap();
        assert_eq!(**contents_1, 5);
        assert_eq!(**contents_2, 5);

        value.destroy();
        assert_eq!(**contents_1, 5);
        assert_eq!(**contents_2, 5);
        assert!(value.get().is_none());

        drop(contents_1);
        assert_eq!(**contents_2, 5);
        assert!(value.get().is_none());

        drop(contents_2);
        assert!(value.get().is_none());
    }

    #[test]
    fn test_empty() {
        assert!(<AtomicDestroy<()>>::empty().get().is_none());
    }

    #[test]
    fn test_maybe_new() {
        assert_eq!(*AtomicDestroy::maybe_new(Some(7)).get().unwrap(), 7);
        assert!(AtomicDestroy::<i32>::maybe_new(None).get().is_none());
    }

    #[test]
    fn test_with() {
        let value = AtomicDestroy::new(Box::new(5));
        assert_eq!(value.with(|v| **v + 1), Some(6));
        value.destroy();
        assert_eq!(value.with(|v| **v + 1), None);
    }

    #[test]
    fn test_clone() {
        let value = AtomicDestroy::new(Box::new(5));
        let copy = value.clone();
        value.destroy();
        assert!(value.get().is_none());
        // The clone is independent of the original's destruction.
        assert_eq!(**copy.get().unwrap(), 5);
        // Cloning a destroyed value yields a destroyed value.
        assert!(value.clone().get().is_none());
    }

    #[test]
    fn test_destructor_runs_once() {
        let (count, tracker) = drop_counter();
        let value = AtomicDestroy::new(tracker);

        let guard = value.get().unwrap();
        value.destroy();
        // Still held, so the destructor must be deferred.
        assert_eq!(count.load(Ordering::SeqCst), 0);

        drop(guard);
        assert_eq!(count.load(Ordering::SeqCst), 1);

        // Neither a repeated destroy nor dropping the owner may run it again.
        value.destroy();
        drop(value);
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_drop_without_destroy() {
        let (count, tracker) = drop_counter();
        drop(AtomicDestroy::new(tracker));
        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn stress_test() {
        let limit = Instant::now() + Duration::from_secs(3);
        let value = Arc::new(AtomicDestroy::new(Box::new(5)));

        let mut threads = iter::repeat_with(|| {
            let value = value.clone();

            thread::spawn(move || {
                while Instant::now() < limit {
                    match value.get() {
                        Some(v) => assert_eq!(**v, 5),
                        None => break,
                    }
                }
            })
        })
        .take(5)
        .collect::<Vec<_>>();

        thread::sleep(Duration::from_secs(1));

        threads.extend(
            iter::repeat_with(|| {
                let value = value.clone();

                thread::spawn(move || {
                    for _ in 0..800 {
                        value.destroy();
                    }
                })
            })
            .take(5),
        );

        for thread in threads {
            thread.join().unwrap();
        }
    }
}