use std::{error::Error, fmt::Display, ptr::null_mut};
use num_enum::{IntoPrimitive, TryFromPrimitive};
use crate::{
atomic_try_update,
bits::{Align8, FlagPtr},
Atom,
};
/// Lifecycle of a `OnceLockFree` cell, stored as the flag of its `FlagPtr`.
///
/// The discriminants are spelled out because they are the raw values written into the flag
/// (`IntoPrimitive`) and decoded back out of it (`TryFromPrimitive`).
#[derive(IntoPrimitive, TryFromPrimitive)]
#[repr(usize)]
enum Lifecycle {
    /// Nothing has been stored and nobody has reserved the cell.
    NotSet = 0,
    /// Reserved by `get_or_prepare_to_set`; a value is expected via `set_prepared`.
    Setting = 1,
    /// Final state: the pointer is either the published value or null (sealed empty by
    /// `get_or_seal`).
    Set = 2,
    /// Written by `Drop`; observing it anywhere else is treated as a use-after-free.
    Dead = 3,
}
/// Error produced inside the atomic update closures.
///
/// Converted to the public `OnceLockFreeError` by `panic_on_memory_bug`, which panics
/// instead of converting `UseAfterFreeBug`.
enum OnceLockFreeInternalError {
    /// The cell is already in the `Set` state.
    AlreadySet,
    /// `get` found no value to return.
    AttemptToReadWhenUnset,
    /// The cell is in the `Setting` state (another caller holds the reservation).
    AttemptToSetConcurrently,
    /// The cell was observed in the `Dead` state; never surfaced to callers.
    UseAfterFreeBug,
    /// `set_prepared` was called while the cell was still `NotSet`.
    UnpreparedForSet,
}
/// Recoverable errors returned by `OnceLockFree` operations.
#[derive(Debug, PartialEq, Eq)]
pub enum OnceLockFreeError {
    /// A value (or an empty seal from `get_or_seal`/`get`) is already in place.
    AlreadySet,
    /// `get` was called but the cell holds no value.
    AttemptToReadWhenUnset,
    /// Another caller reserved the cell with `get_or_prepare_to_set` and has not yet
    /// published a value.
    AttemptToSetConcurrently,
    /// `set_prepared` was called without a prior successful `get_or_prepare_to_set`.
    UnpreparedForSet,
}
// Marker impl using the default `source()` (no underlying cause); messages come from the
// `Debug` derive and the `Display` impl.
impl Error for OnceLockFreeError {}
impl Display for OnceLockFreeError {
    /// Renders the error as its variant name (e.g. `AlreadySet`), i.e. the same text as its
    /// derived `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}
fn panic_on_memory_bug(err: OnceLockFreeInternalError) -> OnceLockFreeError {
match err {
OnceLockFreeInternalError::AlreadySet => OnceLockFreeError::AlreadySet,
OnceLockFreeInternalError::AttemptToReadWhenUnset => {
OnceLockFreeError::AttemptToReadWhenUnset
}
OnceLockFreeInternalError::AttemptToSetConcurrently => {
OnceLockFreeError::AttemptToSetConcurrently
}
OnceLockFreeInternalError::UseAfterFreeBug => {
panic!("Encountered use-after-free in OnceLockFree");
}
OnceLockFreeInternalError::UnpreparedForSet => OnceLockFreeError::UnpreparedForSet,
}
}
/// The state held inside the cell's `Atom`: a pointer to the boxed value plus a
/// `Lifecycle` flag.
///
/// NOTE(review): presumably `FlagPtr` packs the flag into the pointer's low bits, which the
/// `Align8` wrapper keeps free; confirm against `bits::FlagPtr`.
#[derive(Default)]
struct OnceLockFreeState<T> {
    // Flag holds a `Lifecycle` discriminant. The pointer is null until `set`/`set_prepared`
    // publishes a `Box::into_raw` allocation, and stays null if the cell is sealed empty by
    // `get_or_seal`.
    flag_ptr: FlagPtr<Align8<T>>,
}
/// A write-once cell that can be initialized and read through shared references.
///
/// All state transitions go through `atomic_try_update`; no mutex is used. The value is
/// heap-allocated on `set` and freed when the cell is dropped.
pub struct OnceLockFree<T> {
    // Lifecycle flag + value pointer. NOTE(review): the `u64` parameter is presumably the
    // atom's backing word type; confirm against `Atom`.
    inner: Atom<OnceLockFreeState<T>, u64>,
}
impl<'a, T> OnceLockFree<T> {
pub fn new() -> Self {
Default::default()
}
pub fn get_or_prepare_to_set(&'a self) -> Result<Option<&'a T>, OnceLockFreeError> {
unsafe {
Ok(
atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
Ok(Lifecycle::NotSet) => {
s.flag_ptr.set_flag(Lifecycle::Setting.into());
(true, Ok(None))
}
Ok(Lifecycle::Setting) => (
false,
Err(OnceLockFreeInternalError::AttemptToSetConcurrently),
),
Ok(Lifecycle::Set) => {
let ptr = s.flag_ptr.get_ptr();
(false, Ok(if ptr.is_null() { None } else { Some(ptr) }))
}
Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
Err(_) => {
panic!("torn read?")
}
})
.map_err(panic_on_memory_bug)?
.map(|ptr| &(*ptr).inner),
)
}
}
pub fn get(&'a self) -> Result<&'a T, OnceLockFreeError> {
match self.get_or_seal()? {
Some(t) => Ok(t),
None => Err(OnceLockFreeInternalError::AttemptToReadWhenUnset),
}
.map_err(panic_on_memory_bug)
}
pub fn get_poll(&'a self) -> Option<&'a T> {
unsafe {
atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
Ok(Lifecycle::Set) => {
let ptr = s.flag_ptr.get_ptr();
(false, if ptr.is_null() { None } else { Some(ptr) })
}
_ => (false, None),
})
.map(|ptr| &(*ptr).inner)
}
}
pub fn get_or_seal(&'a self) -> Result<Option<&'a T>, OnceLockFreeError> {
unsafe {
Ok(
atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
Ok(Lifecycle::NotSet) => {
s.flag_ptr.set_flag(Lifecycle::Set.into());
s.flag_ptr.set_ptr(null_mut());
(true, Ok(None))
}
Ok(Lifecycle::Setting) => (
false,
Err(OnceLockFreeInternalError::AttemptToSetConcurrently),
),
Ok(Lifecycle::Set) => {
let ptr = s.flag_ptr.get_ptr();
(false, Ok(if ptr.is_null() { None } else { Some(ptr) }))
}
Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
Err(_) => {
panic!("torn read?")
}
})
.map_err(panic_on_memory_bug)?
.map(|ptr| (&(*ptr).inner)),
)
}
}
pub fn set_prepared(&'a self, val: T) -> Result<&'a T, OnceLockFreeError> {
let ptr: *mut Align8<T> = Box::into_raw(Box::new(val.into()));
unsafe {
atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
Ok(Lifecycle::NotSet) => (false, Err(OnceLockFreeInternalError::UnpreparedForSet)),
Ok(Lifecycle::Setting) => {
s.flag_ptr.set_flag(Lifecycle::Set.into());
s.flag_ptr.set_ptr(ptr);
(true, Ok(()))
}
Ok(Lifecycle::Set) => (false, Err(OnceLockFreeInternalError::AlreadySet)),
Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
Err(_) => {
panic!("torn read?")
}
})
.map_err(panic_on_memory_bug)?;
Ok(&(*ptr).inner)
}
}
pub fn set(&'a self, val: T) -> Result<&'a T, OnceLockFreeError> {
let ptr: *mut Align8<T> = Box::into_raw(Box::new(val.into()));
unsafe {
atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
Ok(Lifecycle::NotSet) => {
s.flag_ptr.set_flag(Lifecycle::Set.into());
s.flag_ptr.set_ptr(ptr);
(true, Ok(()))
}
Ok(Lifecycle::Setting) => (
false,
Err(OnceLockFreeInternalError::AttemptToSetConcurrently),
),
Ok(Lifecycle::Set) => (false, Err(OnceLockFreeInternalError::AlreadySet)),
Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
Err(_) => {
panic!("torn read?")
}
})
.map_err(panic_on_memory_bug)?;
Ok(&(*ptr).inner)
}
}
}
impl<T> Default for OnceLockFree<T> {
fn default() -> Self {
Self {
inner: Default::default(),
}
}
}
impl<T> Drop for OnceLockFree<T> {
    /// Moves the cell to `Dead` and frees the published value, if there is one.
    ///
    /// # Panics
    ///
    /// Panics (via `panic_on_memory_bug`) if the cell is already `Dead`.
    fn drop(&mut self) {
        // SAFETY: `drop` has exclusive access (`&mut self`), so no reference handed out by the
        // getters can still be alive when the pointer is freed below.
        let published = unsafe {
            atomic_try_update(&self.inner, |s| match s.flag_ptr.get_flag().try_into() {
                Ok(Lifecycle::NotSet) | Ok(Lifecycle::Setting) => {
                    s.flag_ptr.set_flag(Lifecycle::Dead.into());
                    (true, Ok(None))
                }
                Ok(Lifecycle::Set) => {
                    s.flag_ptr.set_flag(Lifecycle::Dead.into());
                    let ptr = s.flag_ptr.get_ptr();
                    (true, Ok(if ptr.is_null() { None } else { Some(ptr) }))
                }
                Ok(Lifecycle::Dead) => (false, Err(OnceLockFreeInternalError::UseAfterFreeBug)),
                // Unlike the getters, this arm does not panic on an undecodable flag
                // (presumably to avoid panicking inside `drop`). The state is left unmodified
                // and the pointer, which cannot be trusted, is not freed.
                Err(_) => (true, Ok(None)),
            })
        };
        let owned = published.map_err(panic_on_memory_bug).unwrap();
        if let Some(ptr) = owned {
            // SAFETY: a non-null stored pointer came from `Box::into_raw` in `set`/`set_prepared`.
            // The transition to `Dead` above makes any second drop panic instead of freeing it
            // again.
            let _reclaimed = unsafe { Box::from_raw(ptr) };
        }
    }
}