rafx_base/atomic_once_cell.rs

use core::ptr;
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU8, Ordering};

const SET_ACQUIRE_FLAG: u8 = 1 << 1;
const SET_RELEASE_FLAG: u8 = 1 << 0;
const IS_INIT_BITMASK: u8 = SET_ACQUIRE_FLAG | SET_RELEASE_FLAG;
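
// NOTE(editor): The three meaningful states of `is_initialized`:
//   0                                   -> uninitialized; `set` has not started.
//   SET_ACQUIRE_FLAG                    -> a `set` started but has not finished.
//   SET_ACQUIRE_FLAG | SET_RELEASE_FLAG -> fully initialized; `get` may read the value.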

/// A thread-safe container that does not require default initialization. The cell may be
/// initialized with `set` and then retrieved as a reference with `get`. Calling `set` is
/// thread-safe. The cell will panic if `set` is called more than once. The cell will only
/// drop a value that was fully initialized.
///
/// # Guarantees
///
/// - The allocated memory will not be `default` initialized.
/// - A value initialized by `set` is immutable.
/// - The synchronization is lock-free.
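///
/// # Example
///
/// A minimal sketch of the intended usage: initialize once with `set`, then read
/// with `get` from any thread.
///
/// ```ignore
/// let cell = AtomicOnceCell::new();
/// cell.set(42);
/// assert_eq!(cell.get(), &42);
/// ```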
pub struct AtomicOnceCell<T> {
    data: MaybeUninit<UnsafeCell<T>>,
    is_initialized: AtomicU8,
}

impl<T> Default for AtomicOnceCell<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> AtomicOnceCell<T> {
    pub fn new() -> Self {
        Self {
            data: MaybeUninit::uninit(),
            is_initialized: AtomicU8::new(0),
        }
    }

    #[inline(always)]
    fn start_set(&self) {
        // NOTE(dvd): Use `Acquire` to start a protected section.
        match self
            .is_initialized
            .fetch_update(Ordering::Acquire, Ordering::Relaxed, |atomic_val| {
                Some(atomic_val | SET_ACQUIRE_FLAG)
            }) {
            Ok(atomic_val) => {
                if atomic_val & IS_INIT_BITMASK > 0 {
                    // SAFETY: Panic if a second attempt is made to initialize the cell.
                    panic!("cannot be set more than once");
                }
            }
            _ => unreachable!(),
        };
    }

    #[inline(always)]
    fn end_set(&self) {
        // NOTE(dvd): Use `Release` to end the protected section.
        match self
            .is_initialized
            .fetch_update(Ordering::Release, Ordering::Relaxed, |atomic_val| {
                Some(atomic_val | SET_RELEASE_FLAG)
            }) {
            Ok(_) => {}
            _ => unreachable!(),
        };
    }

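    /// Initializes the cell to `val`.
    ///
    /// # Panics
    ///
    /// Panics if the cell was already set, or if another thread's `set` is still in progress.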
    pub fn set(
        &self,
        val: T,
    ) {
        // NOTE(dvd): "Acquire" a lock.
        self.start_set();

        {
            let maybe_uninit = self.ptr_to_maybe_uninit();
            unsafe {
                // SAFETY: If `atomic_val` had neither bit set, we know that this value
                // is uninitialized & no other thread is trying to initialize it at the same
                // time. If another thread had been trying to initialize it, then the
                // `SET_ACQUIRE_FLAG` would have been set and we would have panicked above.
                // We can therefore safely initialize the `MaybeUninit` value following the
                // example for how to initialize an `UnsafeCell` inside of `MaybeUninit`.
                // https://doc.rust-lang.org/beta/std/cell/struct.UnsafeCell.html#method.raw_get.
                let ptr = AtomicOnceCell::maybe_uninit_as_ptr(maybe_uninit);
                AtomicOnceCell::unsafe_cell_raw_get(ptr).write(val);
            }
        }

        // NOTE(dvd): "Release" the lock.
        self.end_set();
    }

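    /// Returns a reference to the initialized value.
    ///
    /// # Panics
    ///
    /// Panics if the cell has not been fully initialized by `set`.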
    pub fn get(&self) -> &T {
        let is_initialized = self.is_initialized.load(Ordering::Acquire);
        if is_initialized & IS_INIT_BITMASK != IS_INIT_BITMASK {
            // SAFETY: Panic if uninitialized data would be read. Requiring both bits
            // also rejects the state where a `set` on another thread has started but
            // has not yet finished writing the value.
            panic!("not initialized");
        }

        let maybe_uninit = self.ptr_to_maybe_uninit();
        let assume_init = unsafe {
            // SAFETY: We can create a &MaybeUninit because we've initialized the memory
            // in `set`, otherwise we would have panicked above when checking the bitmask.
            let maybe_uninit_ref = maybe_uninit.as_ref().unwrap();

            // SAFETY: We can then use `assume_init_ref` to get the initialized UnsafeCell<T>.
            AtomicOnceCell::maybe_uninit_assume_init_ref(maybe_uninit_ref)
        };

        let val = unsafe {
            // SAFETY: Cast the &UnsafeCell<T> to &T.
            // This is ok because we know that nothing can mutate the underlying value.
            // If something tried to `set` the cell again, it would panic instead.
            &*assume_init.get()
        };

        val
    }

    #[inline(always)]
    fn ptr_to_maybe_uninit(&self) -> *const MaybeUninit<UnsafeCell<T>> {
        &self.data as *const MaybeUninit<UnsafeCell<T>>
    }

    #[inline(always)]
    fn ptr_to_maybe_uninit_mut(&mut self) -> *mut MaybeUninit<UnsafeCell<T>> {
        &mut self.data as *mut MaybeUninit<UnsafeCell<T>>
    }

    #[inline(always)]
    unsafe fn maybe_uninit_as_ptr(
        maybe_uninit: *const MaybeUninit<UnsafeCell<T>>
    ) -> *const UnsafeCell<T> {
        // SAFETY: Equivalent to MaybeUninit::as_ptr, but defined for a ptr instead of &self.
        // https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.as_ptr
        maybe_uninit as *const _ as *const UnsafeCell<T>
    }

    #[inline(always)]
    unsafe fn maybe_uninit_as_mut_ptr(
        maybe_uninit: *mut MaybeUninit<UnsafeCell<T>>
    ) -> *mut UnsafeCell<T> {
        // SAFETY: Equivalent to MaybeUninit::as_mut_ptr, but defined for a ptr instead of &mut self.
        // https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.as_mut_ptr
        maybe_uninit as *mut _ as *mut UnsafeCell<T>
    }

    #[inline(always)]
    unsafe fn unsafe_cell_raw_get(cell: *const UnsafeCell<T>) -> *mut T {
        // SAFETY: Equivalent to the unstable API UnsafeCell::raw_get defined at
        // https://doc.rust-lang.org/beta/std/cell/struct.UnsafeCell.html#method.raw_get
        cell as *const T as *mut T
    }

    #[inline(always)]
    unsafe fn maybe_uninit_assume_init_ref(
        maybe_uninit: &MaybeUninit<UnsafeCell<T>>
    ) -> &UnsafeCell<T> {
        // SAFETY: Equivalent to the unstable API MaybeUninit::assume_init_ref defined at
        // https://doc.rust-lang.org/std/mem/union.MaybeUninit.html#method.assume_init_ref
        &*maybe_uninit.as_ptr()
    }
}

impl<T> Drop for AtomicOnceCell<T> {
    fn drop(&mut self) {
        // SAFETY: We don't need to be concerned about any set that conceptually occurs while
        // the `drop` is in progress because `drop` takes a &mut self, so no other code can be
        // holding a &self.

        let atomic_val = self.is_initialized.load(Ordering::Relaxed);
        let is_initialized = atomic_val & IS_INIT_BITMASK == IS_INIT_BITMASK;

        if is_initialized {
            let maybe_uninit = self.ptr_to_maybe_uninit_mut();
            unsafe {
                // SAFETY: If the bitmask has both bits set, the value is definitely initialized.
                ptr::drop_in_place(AtomicOnceCell::maybe_uninit_as_mut_ptr(maybe_uninit))
            }
        } else {
            // SAFETY: If the bitmask only has `SET_ACQUIRE_FLAG` set (a `set` started but never
            // completed), we won't drop the value, so the T that was moved into `set` is leaked
            // just like with mem::forget (which is safe).
            // If the bitmask has both bits unset, the value doesn't need to be dropped
            // because it's definitely uninitialized.
        }
    }
}

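// NOTE(editor): `UnsafeCell` suppresses the automatic `Sync` impl, so one is provided
// manually. Sharing the cell lets other threads obtain `&T` via `get` (requiring
// `T: Sync`) and lets a value moved in by one thread be dropped by another
// (requiring `T: Send`), hence the bounds.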
unsafe impl<T: Send + Sync> Sync for AtomicOnceCell<T> {}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::mpsc::{Receiver, Sender};
    use std::{panic, thread};

    struct DroppableElement {
        id: usize,
        sender: Option<Sender<usize>>,
    }

    impl DroppableElement {
        pub fn new(
            id: usize,
            sender: Option<&Sender<usize>>,
        ) -> Self {
            Self {
                id,
                sender: sender.cloned(),
            }
        }
    }

    impl Drop for DroppableElement {
        fn drop(&mut self) {
            if let Some(sender) = &self.sender {
                let _ = sender.send(self.id);
            }
        }
    }

    fn default_drop() -> (AtomicOnceCell<DroppableElement>, Receiver<usize>) {
        let cell = AtomicOnceCell::new();

        let receiver = {
            let (sender, receiver) = mpsc::channel();
            cell.set(DroppableElement::new(0, Some(&sender)));
            receiver
        };

        (cell, receiver)
    }

    #[test]
    fn test_drop() {
        let (cell, receiver) = default_drop();

        assert_eq!(receiver.try_recv().ok(), None);

        // NOTE(dvd): `cell` is dropped here.
        std::mem::drop(cell);

        let ids = receiver.iter().collect::<Vec<_>>();
        assert_eq!(ids.len(), 1);
        assert_eq!(ids[0], 0);
    }

    #[test]
    fn test_drop_panic() {
        let (cell, receiver) = default_drop();

        assert_eq!(receiver.try_recv().ok(), None);

        let result = thread::spawn(move || {
            cell.set(DroppableElement::new(1, None)); // NOTE(dvd): `set` panics here.
        })
        .join();

        assert!(result.is_err());

        let ids = receiver.iter().collect::<Vec<_>>();
        assert_eq!(ids.len(), 1);
        assert_eq!(ids[0], 0);
    }

    #[test]
    fn test_drop_thread() {
        let (cell, receiver) = default_drop();

        assert_eq!(receiver.try_recv().ok(), None);

        let result = thread::spawn(move || {
            assert_eq!(cell.get().id, 0);
            // NOTE(dvd): `cell` is dropped here.
        })
        .join();

        assert!(result.is_ok());

        let ids = receiver.iter().collect::<Vec<_>>();
        assert_eq!(ids.len(), 1);
        assert_eq!(ids[0], 0);
    }

    struct PanicOnDropElement {
        _id: u32,
    }

    impl Drop for PanicOnDropElement {
        fn drop(&mut self) {
            panic!("element dropped");
        }
    }

    fn default_panic_on_drop() -> AtomicOnceCell<PanicOnDropElement> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_drop_no_panic() {
        let cell = default_panic_on_drop();
        std::mem::drop(cell);
    }

    fn default_i32() -> AtomicOnceCell<i32> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_set() {
        let cell = default_i32();
        cell.set(7);
        assert_eq!(cell.get(), &7);
    }
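
    // NOTE(editor): Illustrative addition, not in the original tests: set the cell
    // from a spawned thread, then read it from the parent. Assumes `std::thread::scope`
    // (Rust 1.63+); the scope joins the writer thread before `get` runs.
    #[test]
    fn test_set_from_other_thread() {
        let cell = default_i32();
        thread::scope(|s| {
            s.spawn(|| cell.set(42));
        });
        assert_eq!(cell.get(), &42);
    }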

    #[test]
    #[should_panic(expected = "cannot be set more than once")]
    fn test_set_twice() {
        let cell = default_i32();
        cell.set(12);
        assert_eq!(cell.get(), &12);
        cell.set(-2);
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_get_uninitialized() {
        let cell = default_i32();
        cell.get();
    }

    // NOTE(dvd): The zero-sized T variant of the struct requires separate tests.

    struct ZeroSizedType {}

    fn default_zst() -> AtomicOnceCell<ZeroSizedType> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_zst_set() {
        let cell = default_zst();
        cell.set(ZeroSizedType {});
        cell.get();
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_uninitialized() {
        let cell = default_zst();

        // NOTE(dvd): See comment on `test_zst_get_uninitialized_private_type`.
        cell.get();
    }

    mod zst_lifetime {
        struct PrivateInnerZst {}

        pub struct CannotConstructZstLifetime<'a, T> {
            _guard: PrivateInnerZst,
            _phantom: std::marker::PhantomData<&'a T>,
        }
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_uninitialized_lifetime<'a>() {
        use zst_lifetime::CannotConstructZstLifetime;

        let cell = AtomicOnceCell::new();

        // NOTE(dvd): See comment on `test_zst_get_uninitialized_private_type`.
        let _val: &CannotConstructZstLifetime<'a, u32> = cell.get();
    }

    mod zst_private {
        struct PrivateInnerZst {}

        pub struct CannotConstructZstInner(PrivateInnerZst);
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_uninitialized_private_type() {
        use zst_private::CannotConstructZstInner;

        let cell = AtomicOnceCell::new();

        // NOTE(dvd): Even though T is zero-sized, we must have
        // proof that the user could construct T; otherwise
        // this container would allow the user to get a &T that
        // they aren't supposed to have -- e.g. due to a private
        // zero-sized member in T, or a lifetime requirement.
        let _val: &CannotConstructZstInner = cell.get();
    }

    enum Void {}

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_uninitialized_void() {
        let cell = AtomicOnceCell::new();

        // NOTE(dvd): See comment on `test_zst_get_uninitialized_private_type`.
        let _val: &Void = cell.get();
    }

    #[test]
    fn test_zst_observable_drop() {
        mod zst_drop {
            // IMPORTANT(dvd): This mod is defined inside of the function because
            // sharing a static atomic here would be a hilarious race condition if
            // multiple tests tried to use the `ObservableZstDrop`. The reason why
            // we can't put a reference to the counter inside of the zero-sized type
            // is that it wouldn't be zero-sized anymore.

            use std::sync::atomic::{AtomicU32, Ordering};

            static ATOMIC_COUNTER: AtomicU32 = AtomicU32::new(0);

            struct PrivateInnerZst {}

            pub struct ObservableZstDrop(PrivateInnerZst);

            impl ObservableZstDrop {
                pub fn new() -> Self {
                    assert_eq!(std::mem::size_of::<Self>(), 0);
                    ATOMIC_COUNTER.fetch_add(1, Ordering::Relaxed);
                    ObservableZstDrop(PrivateInnerZst {})
                }
            }

            impl Drop for ObservableZstDrop {
                fn drop(&mut self) {
                    ATOMIC_COUNTER.fetch_sub(1, Ordering::Relaxed);
                }
            }

            pub fn get_counter() -> u32 {
                ATOMIC_COUNTER.load(Ordering::Relaxed)
            }
        }

        use zst_drop::{get_counter, ObservableZstDrop};

        assert_eq!(get_counter(), 0);
        let cell = AtomicOnceCell::new();
        cell.set(ObservableZstDrop::new());
        assert_eq!(get_counter(), 1);

        std::mem::drop(cell);
        assert_eq!(get_counter(), 0);
    }
}