use core::ptr;
use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU8, Ordering};
/// Flag bit raised by the writer that claims the cell, before it stores the value.
const SET_ACQUIRE_FLAG: u8 = 1 << 1;
/// Flag bit raised after the value has been completely written.
const SET_RELEASE_FLAG: u8 = 1 << 0;
/// Both flag bits together: the state of a fully initialized cell.
const IS_INIT_BITMASK: u8 = SET_ACQUIRE_FLAG | SET_RELEASE_FLAG;

/// A cell that can be written exactly once through a shared reference and
/// read any number of times afterwards.
///
/// The state byte moves `0 -> SET_ACQUIRE_FLAG -> IS_INIT_BITMASK`:
/// the first bit is raised when a writer claims the cell, the second once the
/// value is in place. Only the final state allows reads.
pub struct AtomicOnceCell<T> {
    /// Storage for the value; only initialized once both flag bits are set.
    data: MaybeUninit<UnsafeCell<T>>,
    /// Bit set of `SET_ACQUIRE_FLAG` / `SET_RELEASE_FLAG`.
    is_initialized: AtomicU8,
}

impl<T> Default for AtomicOnceCell<T> {
    /// Creates an empty cell, exactly like [`AtomicOnceCell::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T> AtomicOnceCell<T> {
    /// Creates an empty, uninitialized cell.
    ///
    /// This is a `const fn`, so it can be used to initialize constants.
    pub const fn new() -> Self {
        Self {
            data: MaybeUninit::uninit(),
            is_initialized: AtomicU8::new(0),
        }
    }

    /// Claims the cell for writing by raising `SET_ACQUIRE_FLAG`.
    ///
    /// # Panics
    ///
    /// Panics if any flag bit was already set, i.e. another `set` has started
    /// or finished.
    #[inline]
    fn start_set(&self) {
        // `fetch_or` returns the previous state, so exactly one caller can
        // observe "no bits set" and win the right to write.
        let previous = self
            .is_initialized
            .fetch_or(SET_ACQUIRE_FLAG, Ordering::Acquire);
        if previous & IS_INIT_BITMASK != 0 {
            panic!("cannot be set more than once");
        }
    }

    /// Publishes the written value by raising `SET_RELEASE_FLAG`.
    #[inline]
    fn end_set(&self) {
        // Release pairs with the Acquire load in `get`, making the write to
        // `data` visible to any reader that observes this flag.
        self.is_initialized
            .fetch_or(SET_RELEASE_FLAG, Ordering::Release);
    }

    /// Stores `val` in the cell.
    ///
    /// # Panics
    ///
    /// Panics with `"cannot be set more than once"` if the cell has already
    /// been set or another `set` call is in progress.
    pub fn set(&self, val: T) {
        self.start_set();
        // SAFETY: `start_set` returned without panicking, so this call is the
        // only one that saw the state byte at zero; no other writer exists, and
        // `get` refuses to read until `end_set` below has run. The pointer is
        // derived from `self.data`, so it is valid and aligned for `T`.
        unsafe {
            let cell = Self::maybe_uninit_as_ptr(self.ptr_to_maybe_uninit());
            Self::unsafe_cell_raw_get(cell).write(val);
        }
        self.end_set();
    }

    /// Returns a reference to the stored value.
    ///
    /// # Panics
    ///
    /// Panics with `"not initialized"` unless a `set` call has fully
    /// completed. This includes the window in which a `set` has claimed the
    /// cell but has not yet finished writing the value.
    pub fn get(&self) -> &T {
        let flags = self.is_initialized.load(Ordering::Acquire);
        // Require BOTH bits: `SET_ACQUIRE_FLAG` alone only means a writer has
        // claimed the cell, not that the value is present.
        if flags & IS_INIT_BITMASK != IS_INIT_BITMASK {
            panic!("not initialized");
        }
        // SAFETY: the release flag is only raised by `end_set`, after the value
        // was written, and the Acquire load above synchronizes with that
        // Release store. `start_set` rejects every later writer, so the value
        // is never mutated again while shared references exist.
        unsafe {
            let cell = Self::maybe_uninit_assume_init_ref(&self.data);
            &*cell.get()
        }
    }

    /// Raw const pointer to the storage slot.
    #[inline]
    fn ptr_to_maybe_uninit(&self) -> *const MaybeUninit<UnsafeCell<T>> {
        &self.data as *const MaybeUninit<UnsafeCell<T>>
    }

    /// Raw mutable pointer to the storage slot (requires exclusive access).
    #[inline]
    fn ptr_to_maybe_uninit_mut(&mut self) -> *mut MaybeUninit<UnsafeCell<T>> {
        &mut self.data as *mut MaybeUninit<UnsafeCell<T>>
    }

    /// Reinterprets a pointer to the slot as a pointer to its `UnsafeCell`.
    ///
    /// # Safety
    ///
    /// The cast itself performs no memory access; the caller is responsible
    /// for how the returned pointer is used.
    #[inline]
    unsafe fn maybe_uninit_as_ptr(
        maybe_uninit: *const MaybeUninit<UnsafeCell<T>>,
    ) -> *const UnsafeCell<T> {
        maybe_uninit.cast::<UnsafeCell<T>>()
    }

    /// Mutable counterpart of [`Self::maybe_uninit_as_ptr`].
    ///
    /// # Safety
    ///
    /// The cast itself performs no memory access; the caller is responsible
    /// for how the returned pointer is used.
    #[inline]
    unsafe fn maybe_uninit_as_mut_ptr(
        maybe_uninit: *mut MaybeUninit<UnsafeCell<T>>,
    ) -> *mut UnsafeCell<T> {
        maybe_uninit.cast::<UnsafeCell<T>>()
    }

    /// Gets a `*mut T` to the contents of a possibly uninitialized
    /// `UnsafeCell<T>` without creating a reference to it.
    ///
    /// # Safety
    ///
    /// No memory is accessed here; the caller is responsible for how the
    /// returned pointer is used.
    #[inline]
    unsafe fn unsafe_cell_raw_get(cell: *const UnsafeCell<T>) -> *mut T {
        UnsafeCell::raw_get(cell)
    }

    /// Views the slot as an initialized `UnsafeCell<T>`.
    ///
    /// # Safety
    ///
    /// The slot must already hold a fully written value.
    #[inline]
    unsafe fn maybe_uninit_assume_init_ref(
        maybe_uninit: &MaybeUninit<UnsafeCell<T>>,
    ) -> &UnsafeCell<T> {
        &*maybe_uninit.as_ptr()
    }
}
impl<T> Drop for AtomicOnceCell<T> {
    /// Drops the stored value, but only if the cell was fully initialized.
    fn drop(&mut self) {
        let flags = self.is_initialized.load(Ordering::Relaxed);
        // Anything short of both flag bits means no complete value was ever
        // stored, so there is nothing to drop.
        if flags & IS_INIT_BITMASK != IS_INIT_BITMASK {
            return;
        }
        let slot = self.ptr_to_maybe_uninit_mut();
        // SAFETY: both flag bits are set, so the slot holds a value, and
        // `&mut self` gives exclusive access to it for the drop.
        unsafe { ptr::drop_in_place(AtomicOnceCell::maybe_uninit_as_mut_ptr(slot)) }
    }
}
// SAFETY: sharing `&AtomicOnceCell<T>` between threads lets any of them call
// `get`, which hands out `&T` (so `T` must be `Sync`), and `set`, which moves a
// `T` into the cell from the calling thread while another thread may later
// drop it (so `T` must be `Send`). With those bounds, the atomic state byte
// serializes the single write against all reads.
unsafe impl<T: Send + Sync> Sync for AtomicOnceCell<T> {}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc::{self, Receiver, Sender};
    use std::{panic, thread};

    /// Payload that reports its `id` over an optional channel when it is dropped.
    struct DroppableElement {
        id: usize,
        sender: Option<Sender<usize>>,
    }

    impl DroppableElement {
        /// Builds an element; a provided sender is cloned so the element owns its own handle.
        pub fn new(id: usize, sender: Option<&Sender<usize>>) -> Self {
            let sender = sender.cloned();
            Self { id, sender }
        }
    }

    impl Drop for DroppableElement {
        fn drop(&mut self) {
            if let Some(tx) = self.sender.as_ref() {
                // Delivery is best-effort: a send error is deliberately ignored.
                let _ = tx.send(self.id);
            }
        }
    }

    /// Returns a cell already holding element `0` plus the receiver that observes its drop.
    fn default_drop() -> (AtomicOnceCell<DroppableElement>, Receiver<usize>) {
        let (tx, rx) = mpsc::channel();
        let cell = AtomicOnceCell::new();
        cell.set(DroppableElement::new(0, Some(&tx)));
        // Release the original sender so only the element's clone keeps the channel open.
        drop(tx);
        (cell, rx)
    }

    /// Drains the channel and checks that element `0`, and nothing else, was dropped.
    fn assert_only_element_zero_dropped(rx: &Receiver<usize>) {
        let dropped: Vec<usize> = rx.iter().collect();
        assert_eq!(dropped.len(), 1);
        assert_eq!(dropped[0], 0);
    }

    #[test]
    fn test_drop() {
        let (cell, rx) = default_drop();
        // Nothing may be dropped while the cell is still alive.
        assert!(rx.try_recv().is_err());
        drop(cell);
        assert_only_element_zero_dropped(&rx);
    }

    #[test]
    fn test_drop_panic() {
        let (cell, rx) = default_drop();
        assert!(rx.try_recv().is_err());
        // The second `set` must panic inside the spawned thread; the cell is
        // moved into that thread and dropped there.
        let handle = thread::spawn(move || {
            cell.set(DroppableElement::new(1, None));
        });
        assert!(handle.join().is_err());
        assert_only_element_zero_dropped(&rx);
    }

    #[test]
    fn test_drop_thread() {
        let (cell, rx) = default_drop();
        assert!(rx.try_recv().is_err());
        let handle = thread::spawn(move || {
            assert_eq!(cell.get().id, 0);
        });
        assert!(handle.join().is_ok());
        assert_only_element_zero_dropped(&rx);
    }

    /// Element whose destructor panics, used to prove an empty cell drops nothing.
    struct PanicOnDropElement {
        _id: u32,
    }

    impl Drop for PanicOnDropElement {
        fn drop(&mut self) {
            panic!("element dropped");
        }
    }

    fn default_panic_on_drop() -> AtomicOnceCell<PanicOnDropElement> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_drop_no_panic() {
        // Dropping a never-set cell must not run the element destructor.
        drop(default_panic_on_drop());
    }

    fn default_i32() -> AtomicOnceCell<i32> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_set_0() {
        let cell = default_i32();
        cell.set(7);
        assert_eq!(*cell.get(), 7);
    }

    #[test]
    #[should_panic(expected = "cannot be set more than once")]
    fn test_set_0_twice() {
        let cell = default_i32();
        cell.set(12);
        assert_eq!(*cell.get(), 12);
        cell.set(-2);
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_get_0_uninitialized() {
        let cell = default_i32();
        cell.get();
    }

    struct ZeroSizedType {}

    fn default_zst() -> AtomicOnceCell<ZeroSizedType> {
        AtomicOnceCell::new()
    }

    #[test]
    fn test_zst_set_7() {
        let cell = default_zst();
        cell.set(ZeroSizedType {});
        cell.get();
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_7_uninitialized() {
        let cell = default_zst();
        cell.get();
    }

    mod zst_lifetime {
        struct PrivateInnerZst {}

        /// Zero-sized type that cannot be constructed outside this module.
        pub struct CannotConstructZstLifetime<'a, T> {
            _guard: PrivateInnerZst,
            _phantom: std::marker::PhantomData<&'a T>,
        }
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_0_uninitialized_lifetime<'a>() {
        use zst_lifetime::CannotConstructZstLifetime;

        let cell = AtomicOnceCell::new();
        let _val: &CannotConstructZstLifetime<'a, u32> = cell.get();
    }

    mod zst_private {
        struct PrivateInnerZst {}

        /// Zero-sized type that cannot be constructed outside this module.
        pub struct CannotConstructZstInner(PrivateInnerZst);
    }

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_0_uninitialized_private_type() {
        use zst_private::CannotConstructZstInner;

        let cell = AtomicOnceCell::new();
        let _val: &CannotConstructZstInner = cell.get();
    }

    /// Uninhabited type: no value of it can ever exist.
    enum Void {}

    #[test]
    #[should_panic(expected = "not initialized")]
    fn test_zst_get_0_uninitialized_void() {
        let cell = AtomicOnceCell::new();
        let _val: &Void = cell.get();
    }

    #[test]
    fn test_zst_observable_drop() {
        mod zst_drop {
            use std::sync::atomic::{AtomicU32, Ordering};

            /// Number of `ObservableZstDrop` values currently alive.
            static ATOMIC_COUNTER: AtomicU32 = AtomicU32::new(0);

            struct PrivateInnerZst {}

            /// Zero-sized type whose construction and destruction are counted.
            pub struct ObservableZstDrop(PrivateInnerZst);

            impl ObservableZstDrop {
                pub fn new() -> Self {
                    assert_eq!(std::mem::size_of::<Self>(), 0);
                    ATOMIC_COUNTER.fetch_add(1, Ordering::Relaxed);
                    ObservableZstDrop(PrivateInnerZst {})
                }
            }

            impl Drop for ObservableZstDrop {
                fn drop(&mut self) {
                    ATOMIC_COUNTER.fetch_sub(1, Ordering::Relaxed);
                }
            }

            pub fn get_counter() -> u32 {
                ATOMIC_COUNTER.load(Ordering::Relaxed)
            }
        }
        use zst_drop::{get_counter, ObservableZstDrop};

        assert_eq!(get_counter(), 0);
        let cell = AtomicOnceCell::new();
        cell.set(ObservableZstDrop::new());
        assert_eq!(get_counter(), 1);
        // Dropping the cell must run the destructor even for a zero-sized value.
        drop(cell);
        assert_eq!(get_counter(), 0);
    }
}