rafx_base/atomic_once_cell.rs

use core::ptr;
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU8, Ordering};

// A `set` is tracked by two flags: the acquire flag is raised before the
// value is written and the release flag after it, so the cell is fully
// initialized only when both flags are set.
const SET_ACQUIRE_FLAG: u8 = 1 << 1;
const SET_RELEASE_FLAG: u8 = 1 << 0;
const IS_INIT_BITMASK: u8 = SET_ACQUIRE_FLAG | SET_RELEASE_FLAG;

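/// A cell holding a value that is written at most once and may then be read
/// from any thread. Writes and reads are synchronized by a single atomic
/// byte rather than a lock.
///
/// A minimal usage sketch (the module path is assumed from this file's
/// location in the crate):
///
/// ```ignore
/// use rafx_base::atomic_once_cell::AtomicOnceCell;
///
/// let cell = AtomicOnceCell::new();
/// cell.set(42);
/// assert_eq!(*cell.get(), 42);
/// ```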
pub struct AtomicOnceCell<T> {
    data: MaybeUninit<UnsafeCell<T>>,
    is_initialized: AtomicU8,
}

impl<T> Default for AtomicOnceCell<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> AtomicOnceCell<T> {
    pub fn new() -> Self {
        Self {
            data: MaybeUninit::uninit(),
            is_initialized: AtomicU8::new(0),
        }
    }

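    /// Raises the acquire flag, panicking if either flag was already set.
    /// The atomic read-modify-write guarantees that of two racing writers
    /// exactly one observes both flags clear and proceeds.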
    #[inline(always)]
    fn start_set(&self) {
        match self
            .is_initialized
            .fetch_update(Ordering::Acquire, Ordering::Relaxed, |atomic_val| {
                Some(atomic_val | SET_ACQUIRE_FLAG)
            }) {
            Ok(atomic_val) => {
                if atomic_val & IS_INIT_BITMASK > 0 {
                    panic!("cannot be set more than once");
                }
            }
            // The closure above never returns `None`, so `fetch_update`
            // cannot return `Err`.
            _ => unreachable!(),
        };
    }

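    /// Raises the release flag to mark the value as fully written. The
    /// `Release` success ordering pairs with the `Acquire` load in `get`,
    /// publishing the stored value to readers.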
    #[inline(always)]
    fn end_set(&self) {
        match self
            .is_initialized
            .fetch_update(Ordering::Release, Ordering::Relaxed, |atomic_val| {
                Some(atomic_val | SET_RELEASE_FLAG)
            }) {
            Ok(_) => {}
            _ => unreachable!(),
        };
    }

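    /// Stores `val` in the cell.
    ///
    /// Panics if `set` has already been called.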
    pub fn set(
        &self,
        val: T,
    ) {
        self.start_set();

        {
            let maybe_uninit = self.ptr_to_maybe_uninit();
            unsafe {
                // SAFETY: `start_set` panics unless this is the only `set`,
                // and readers do not touch `data` until `end_set` raises the
                // release flag, so this write cannot race.
                let ptr = AtomicOnceCell::maybe_uninit_as_ptr(maybe_uninit);
                AtomicOnceCell::unsafe_cell_raw_get(ptr).write(val);
            }
        }

        self.end_set();
    }

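    /// Returns a reference to the stored value.
    ///
    /// Panics if the cell has not been fully initialized by `set`.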
    pub fn get(&self) -> &T {
        // Both flags must be set: seeing only the acquire flag means a `set`
        // is still in progress on another thread and the value may not be
        // fully written yet, so reading it here would be undefined behavior.
        let is_initialized = self.is_initialized.load(Ordering::Acquire);
        if is_initialized & IS_INIT_BITMASK != IS_INIT_BITMASK {
            panic!("not initialized");
        }

        let maybe_uninit = self.ptr_to_maybe_uninit();
        let assume_init = unsafe {
            let maybe_uninit_ref = maybe_uninit.as_ref().unwrap();

            AtomicOnceCell::maybe_uninit_assume_init_ref(maybe_uninit_ref)
        };

        let val = unsafe { &*assume_init.get() };

        val
    }

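    // The helpers below mirror `MaybeUninit::as_ptr`/`as_mut_ptr`,
    // `UnsafeCell::raw_get`, and `MaybeUninit::assume_init_ref`, presumably
    // kept as local casts so the crate builds on toolchains where those
    // std APIs were not yet stable.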
    #[inline(always)]
    fn ptr_to_maybe_uninit(&self) -> *const MaybeUninit<UnsafeCell<T>> {
        &self.data as *const MaybeUninit<UnsafeCell<T>>
    }

    #[inline(always)]
    fn ptr_to_maybe_uninit_mut(&mut self) -> *mut MaybeUninit<UnsafeCell<T>> {
        &mut self.data as *mut MaybeUninit<UnsafeCell<T>>
    }

    #[inline(always)]
    unsafe fn maybe_uninit_as_ptr(
        maybe_uninit: *const MaybeUninit<UnsafeCell<T>>
    ) -> *const UnsafeCell<T> {
        maybe_uninit as *const _ as *const UnsafeCell<T>
    }

    #[inline(always)]
    unsafe fn maybe_uninit_as_mut_ptr(
        maybe_uninit: *mut MaybeUninit<UnsafeCell<T>>
    ) -> *mut UnsafeCell<T> {
        maybe_uninit as *mut _ as *mut UnsafeCell<T>
    }

    #[inline(always)]
    unsafe fn unsafe_cell_raw_get(cell: *const UnsafeCell<T>) -> *mut T {
        cell as *const T as *mut T
    }

    #[inline(always)]
    unsafe fn maybe_uninit_assume_init_ref(
        maybe_uninit: &MaybeUninit<UnsafeCell<T>>
    ) -> &UnsafeCell<T> {
        &*maybe_uninit.as_ptr()
    }
}

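// The value is dropped in place only if both flags are set; a missing or
// half-finished `set` leaves `data` uninitialized, and dropping it would be
// undefined behavior.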
impl<T> Drop for AtomicOnceCell<T> {
    fn drop(&mut self) {
        // Relaxed is sufficient: `&mut self` proves no other thread is
        // concurrently touching this cell.
        let atomic_val = self.is_initialized.load(Ordering::Relaxed);
        let is_initialized = atomic_val & IS_INIT_BITMASK == IS_INIT_BITMASK;

        if is_initialized {
            let maybe_uninit = self.ptr_to_maybe_uninit_mut();
            unsafe {
                ptr::drop_in_place(AtomicOnceCell::maybe_uninit_as_mut_ptr(maybe_uninit))
            }
        }
    }
}

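// SAFETY: a shared `AtomicOnceCell` hands out `&T` to multiple threads and
// lets any thread move a `T` in via `set`, so `T` must be both `Sync` and
// `Send` for cross-thread sharing to be sound.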
unsafe impl<T: Send + Sync> Sync for AtomicOnceCell<T> {}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::sync::mpsc::{Receiver, Sender};
    use std::thread;

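    // An element that reports its id over a channel when dropped, used to
    // observe exactly which values the cell drops and when.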
    struct DroppableElement {
        id: usize,
        sender: Option<Sender<usize>>,
    }

    impl DroppableElement {
        pub fn new(
            id: usize,
            sender: Option<&Sender<usize>>,
        ) -> Self {
            Self {
                id,
                sender: sender.cloned(),
            }
        }
    }

    impl Drop for DroppableElement {
        fn drop(&mut self) {
            if let Some(sender) = &self.sender {
                let _ = sender.send(self.id);
            }
        }
    }

    fn default_drop() -> (AtomicOnceCell<DroppableElement>, Receiver<usize>) {
        let array = AtomicOnceCell::new();

        let receiver = {
            let (sender, receiver) = mpsc::channel();
            array.set(DroppableElement::new(0, Some(&sender)));
            receiver
        };

        (array, receiver)
    }

    #[test]
    fn test_drop() {
        let (array, receiver) = default_drop();

        assert_eq!(receiver.try_recv().ok(), None);

        std::mem::drop(array);

        let indices = receiver.iter().collect::<Vec<_>>();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);
    }

    #[test]
    fn test_drop_panic() {
        let (array, receiver) = default_drop();

        assert_eq!(receiver.try_recv().ok(), None);

        let result = thread::spawn(move || {
            array.set(DroppableElement::new(1, None));
        })
        .join();

        assert!(result.is_err());

        let indices = receiver.iter().collect::<Vec<_>>();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);
    }

    #[test]
    fn test_drop_thread() {
        let (array, receiver) = default_drop();

        assert_eq!(receiver.try_recv().ok(), None);

        let result = thread::spawn(move || {
            assert_eq!(array.get().id, 0);
        })
        .join();

        assert!(result.is_ok());

        let indices = receiver.iter().collect::<Vec<_>>();
        assert_eq!(indices.len(), 1);
        assert_eq!(indices[0], 0);
    }

    struct PanicOnDropElement {
        _id: u32,
    }

    impl Drop for PanicOnDropElement {
        fn drop(&mut self) {
            panic!("element dropped");
        }
    }

    fn default_panic_on_drop() -> AtomicOnceCell<PanicOnDropElement> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_drop_no_panic() {
        let array = default_panic_on_drop();
        std::mem::drop(array);
    }

    fn default_i32() -> AtomicOnceCell<i32> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_set_0() {
        let array = default_i32();
        array.set(7);
        assert_eq!(array.get(), &7);
    }

    #[test]
    #[should_panic(expected = "cannot be set more than once")]
    fn test_set_0_twice() {
        let array = default_i32();
        array.set(12);
        assert_eq!(array.get(), &12);
        array.set(-2);
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_get_0_uninitialized() {
        let array = default_i32();
        array.get();
    }

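    // Zero-sized types never touch the stored bytes, so these tests check
    // that the flag protocol alone still drives `set`, `get`, and `Drop`
    // correctly for them.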
    struct ZeroSizedType {}

    fn default_zst() -> AtomicOnceCell<ZeroSizedType> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_zst_set_7() {
        let array = default_zst();
        array.set(ZeroSizedType {});
        array.get();
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_7_uninitialized() {
        let array = default_zst();

        array.get();
    }

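    // The types below cannot be constructed outside their modules (or at
    // all, for `Void`), so `get` must panic before it can conjure a
    // reference to a value that could never have existed.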
    mod zst_lifetime {
        struct PrivateInnerZst {}

        pub struct CannotConstructZstLifetime<'a, T> {
            _guard: PrivateInnerZst,
            _phantom: std::marker::PhantomData<&'a T>,
        }
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_0_uninitialized_lifetime<'a>() {
        use zst_lifetime::CannotConstructZstLifetime;

        let array = AtomicOnceCell::new();

        let _val: &CannotConstructZstLifetime<'a, u32> = array.get();
    }

    mod zst_private {
        struct PrivateInnerZst {}

        pub struct CannotConstructZstInner(PrivateInnerZst);
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_0_uninitialized_private_type() {
        use zst_private::CannotConstructZstInner;

        let array = AtomicOnceCell::new();

        let _val: &CannotConstructZstInner = array.get();
    }

    enum Void {}

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_0_uninitialized_void() {
        let array = AtomicOnceCell::new();

        let _val: &Void = array.get();
    }

    #[test]
    fn test_zst_observable_drop() {
        mod zst_drop {
            use std::sync::atomic::{AtomicU32, Ordering};

            static ATOMIC_COUNTER: AtomicU32 = AtomicU32::new(0);

            struct PrivateInnerZst {}

            pub struct ObservableZstDrop(PrivateInnerZst);

            impl ObservableZstDrop {
                pub fn new() -> Self {
                    assert_eq!(std::mem::size_of::<Self>(), 0);
                    ATOMIC_COUNTER.fetch_add(1, Ordering::Relaxed);
                    ObservableZstDrop(PrivateInnerZst {})
                }
            }

            impl Drop for ObservableZstDrop {
                fn drop(&mut self) {
                    ATOMIC_COUNTER.fetch_sub(1, Ordering::Relaxed);
                }
            }

            pub fn get_counter() -> u32 {
                ATOMIC_COUNTER.load(Ordering::Relaxed)
            }
        }

        use zst_drop::{get_counter, ObservableZstDrop};

        assert_eq!(get_counter(), 0);
        let array = AtomicOnceCell::new();
        array.set(ObservableZstDrop::new());
        assert_eq!(get_counter(), 1);

        std::mem::drop(array);
        assert_eq!(get_counter(), 0);
    }
}