use crate::{
loom::sync::atomic::{AtomicU8, Ordering},
util::{Backoff, CheckedMaybeUninit},
};
use core::{
any,
cell::UnsafeCell,
fmt,
ops::{Deref, DerefMut},
};
/// A cell that can be written at most once through a shared reference and
/// read afterwards.
///
/// The value is supplied by the caller at initialization time (see
/// `InitOnce::try_init`), rather than by a stored initializer as with `Lazy`.
pub struct InitOnce<T> {
    // Storage for the value; only initialized once `state` is `INITIALIZED`.
    value: UnsafeCell<CheckedMaybeUninit<T>>,
    // One of `UNINITIALIZED`, `INITIALIZING`, or `INITIALIZED`.
    state: AtomicU8,
}
/// A value that is computed by `initializer` the first time it is requested.
///
/// The default `F = fn() -> T` lets the type be named with a plain function
/// pointer instead of a closure type.
pub struct Lazy<T, F = fn() -> T> {
    // Storage for the value; only initialized once `state` is `INITIALIZED`.
    value: UnsafeCell<CheckedMaybeUninit<T>>,
    // One of `UNINITIALIZED`, `INITIALIZING`, or `INITIALIZED`.
    state: AtomicU8,
    // Called (at most once per successful initialization) by the first caller
    // of `Lazy::init` to produce the value.
    initializer: F,
}
/// Error returned by `InitOnce::try_init` when the cell has already been
/// claimed by another initialization.
///
/// The rejected value is handed back and can be recovered with
/// `TryInitError::into_inner`.
pub struct TryInitError<T> {
    // The value that could not be stored.
    value: T,
    // The state observed by the failed compare-exchange in `try_init`.
    actual: u8,
}
/// No caller has claimed the cell yet; no value is stored.
const UNINITIALIZED: u8 = 0;
/// A caller has claimed the cell and is writing the value; it must not be
/// read yet.
const INITIALIZING: u8 = 1;
/// The value has been written and may be read.
const INITIALIZED: u8 = 2;
impl<T> InitOnce<T> {
loom_const_fn! {
#[must_use]
pub fn uninitialized() -> Self {
Self {
value: UnsafeCell::new(CheckedMaybeUninit::uninit()),
state: AtomicU8::new(UNINITIALIZED),
}
}
}
pub fn try_init(&self, value: T) -> Result<(), TryInitError<T>> {
if let Err(actual) = self.state.compare_exchange(
UNINITIALIZED,
INITIALIZING,
Ordering::AcqRel,
Ordering::Acquire,
) {
return Err(TryInitError { value, actual });
};
unsafe {
*(self.value.get()) = CheckedMaybeUninit::new(value);
}
let _prev = self.state.swap(INITIALIZED, Ordering::AcqRel);
debug_assert_eq!(
_prev,
INITIALIZING,
"InitOnce<{}>: state changed while locked. This is a bug!",
any::type_name::<T>(),
);
Ok(())
}
#[track_caller]
pub fn init(&self, value: T) -> &T {
self.try_init(value).unwrap();
self.get()
}
#[inline]
#[must_use]
pub fn try_get(&self) -> Option<&T> {
if self.state.load(Ordering::Acquire) != INITIALIZED {
return None;
}
unsafe {
Some(&*((*self.value.get()).as_ptr()))
}
}
#[track_caller]
#[inline]
#[must_use]
pub fn get(&self) -> &T {
if self.state.load(Ordering::Acquire) != INITIALIZED {
panic!("InitOnce<{}> not yet initialized!", any::type_name::<T>());
}
unsafe {
&*((*self.value.get()).as_ptr())
}
}
#[must_use]
pub fn get_or_else(&self, f: impl FnOnce() -> T) -> &T {
if let Some(val) = self.try_get() {
return val;
}
let _ = self.try_init(f());
self.get()
}
#[cfg_attr(not(debug_assertions), inline(always))]
#[cfg_attr(debug_assertions, track_caller)]
#[must_use]
pub unsafe fn get_unchecked(&self) -> &T {
debug_assert_eq!(
INITIALIZED,
self.state.load(Ordering::Acquire),
"InitOnce<{}>: accessed before initialized!\n\
/!\\ EXTREMELY SERIOUS WARNING: /!\\ This is REAL BAD! If you were \
running in release mode, you would have just read uninitialized \
memory! That's bad news indeed, buddy. Double- or triple-check \
your assumptions, or consider Just Using A Goddamn Mutex --- it's \
much safer that way. Maybe this whole `InitOnce` thing was a \
mistake...
",
any::type_name::<T>(),
);
unsafe {
&*((*self.value.get()).as_ptr())
}
}
}
impl<T> fmt::Debug for InitOnce<T>
where
    T: fmt::Debug,
{
    /// Formats the stored value if it has been published; otherwise prints a
    /// placeholder describing the current state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = self.state.load(Ordering::Acquire);
        match state {
            UNINITIALIZED => f.pad("<uninitialized>"),
            INITIALIZING => f.pad("<initializing>"),
            INITIALIZED => fmt::Debug::fmt(self.get(), f),
            _state => unsafe {
                unreachable_unchecked!("unexpected state value {}, this is a bug!", _state)
            },
        }
    }
}
// SAFETY: an `InitOnce<T>` owns at most one `T`, so sending the cell to
// another thread sends that `T`; this is sound whenever `T: Send`.
unsafe impl<T> Send for InitOnce<T> where T: Send {}
// SAFETY: through `&InitOnce<T>`, the value is written exactly once by the
// caller that wins the `UNINITIALIZED -> INITIALIZING` compare-exchange, and
// the checked accessors only hand out `&T` after observing `INITIALIZED` with
// `Acquire`, so sharing the cell only ever shares `&T`.
// NOTE(review): `try_init` moves a `T` into the cell through `&self`, so the
// value may be created on a different thread than the one that owns the cell.
// `std::sync::OnceLock` requires `T: Send + Sync` for `Sync` for this reason.
// Requiring only `T: Sync` is fine as long as no API hands out `&mut T`, moves
// the value out, or drops it on the owning thread -- re-check this bound
// before adding such an API.
unsafe impl<T> Sync for InitOnce<T> where T: Sync {}
impl<T, F> Lazy<T, F> {
    loom_const_fn! {
        /// Creates a new `Lazy` that will call `initializer` the first time
        /// its value is requested.
        #[must_use]
        pub fn new(initializer: F) -> Self {
            Self {
                value: UnsafeCell::new(CheckedMaybeUninit::uninit()),
                state: AtomicU8::new(UNINITIALIZED),
                initializer,
            }
        }
    }

    /// Returns the value if it has already been initialized, without running
    /// the initializer.
    ///
    /// Returns `None` if the value has not been published yet, including while
    /// another caller is still running the initializer.
    #[inline]
    #[must_use]
    pub fn get_if_present(&self) -> Option<&T> {
        let published = self.state.load(Ordering::Acquire) == INITIALIZED;
        if !published {
            return None;
        }
        // SAFETY: `INITIALIZED` is only stored after the value is written, and
        // the `Acquire` load above synchronizes with that store.
        Some(unsafe { &*((*self.value.get()).as_ptr()) })
    }
}
impl<T, F> Lazy<T, F>
where
    F: Fn() -> T,
{
    /// Returns a shared reference to the value, running the initializer first
    /// if it has not been run yet.
    ///
    /// If another caller is currently running the initializer, this spins
    /// until that caller has published the value.
    #[inline]
    #[must_use]
    pub fn get(&self) -> &T {
        self.init();
        // SAFETY: `init` only returns after `state` has been observed (or set)
        // as `INITIALIZED`, which happens after the value is written.
        unsafe { &*((*self.value.get()).as_ptr()) }
    }

    /// Returns a mutable reference to the value, running the initializer first
    /// if it has not been run yet.
    #[inline]
    #[must_use]
    pub fn get_mut(&mut self) -> &mut T {
        self.init();
        // SAFETY: as in `get`, the value is initialized once `init` returns,
        // and `&mut self` rules out any other live reference to it.
        unsafe { &mut *((*self.value.get()).as_mut_ptr()) }
    }

    /// Ensures the value is initialized.
    ///
    /// - If the value is already initialized, returns immediately.
    /// - If another caller is running the initializer, spins (with backoff)
    ///   until it publishes the value.
    /// - Otherwise, runs the initializer on this thread and publishes the
    ///   result.
    ///
    /// Note: if the initializer panics and the panic unwinds, `state` is left
    /// at `INITIALIZING`, so every concurrent or later caller spins forever.
    pub fn init(&self) {
        let claim = self.state.compare_exchange(
            UNINITIALIZED,
            INITIALIZING,
            Ordering::AcqRel,
            Ordering::Acquire,
        );

        match claim {
            // This call won the race: compute the value and publish it.
            Ok(_) => {
                let value = (self.initializer)();
                // SAFETY: holding the `INITIALIZING` state gives this caller
                // exclusive access to the storage; readers wait for
                // `INITIALIZED` before touching it.
                unsafe {
                    *(self.value.get()) = CheckedMaybeUninit::new(value);
                }

                let published = self.state.compare_exchange(
                    INITIALIZING,
                    INITIALIZED,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                );
                if let Err(actual) = published {
                    unreachable!(
                        "Lazy<{}>: state changed while locked. This is a bug! (state={})",
                        any::type_name::<T>(),
                        actual
                    );
                }
            }
            // Already published: nothing to do.
            Err(INITIALIZED) => {}
            // Another caller is running the initializer: wait for it.
            Err(INITIALIZING) => {
                let mut backoff = Backoff::new();
                loop {
                    if self.state.load(Ordering::Acquire) == INITIALIZED {
                        break;
                    }
                    backoff.spin();
                }
            }
            Err(_state) => unsafe {
                unreachable_unchecked!(
                    "Lazy<{}>: unexpected state {}!. This is a bug!",
                    any::type_name::<T>(),
                    _state
                )
            },
        }
    }
}
impl<T, F> Deref for Lazy<T, F>
where
F: Fn() -> T,
{
type Target = T;
fn deref(&self) -> &Self::Target {
self.get()
}
}
impl<T, F> DerefMut for Lazy<T, F>
where
F: Fn() -> T,
{
fn deref_mut(&mut self) -> &mut Self::Target {
self.get_mut()
}
}
impl<T, F> fmt::Debug for Lazy<T, F>
where
    T: fmt::Debug,
{
    /// Formats the value if it has been published; otherwise prints a
    /// placeholder describing the current state. Never runs the initializer.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let state = self.state.load(Ordering::Acquire);
        match state {
            UNINITIALIZED => f.pad("<uninitialized>"),
            INITIALIZING => f.pad("<initializing>"),
            INITIALIZED => {
                let value = self
                    .get_if_present()
                    .expect("if state is `INITIALIZED`, value should be present");
                fmt::Debug::fmt(value, f)
            }
            _state => unsafe {
                unreachable_unchecked!("unexpected state value {}, this is a bug!", _state)
            },
        }
    }
}
// SAFETY: a `Lazy<T, F>` owns at most one `T` and one `F`; sending it to
// another thread sends both, which is sound when both are `Send`.
unsafe impl<T: Send, F: Send> Send for Lazy<T, F> {}
// SAFETY: through `&Lazy`, `F` is only called via `&F` (hence `F: Sync`), and
// the value is only handed out as `&T` after `INITIALIZED` is observed with
// `Acquire` (hence `T: Sync`). `T: Send` is also required:
// - the initializer runs on whichever thread first calls `init`, so the `T`
//   may be created on a thread other than the owner's;
// - the owner can later reach it via `get_mut`/`deref_mut` and move it out
//   (or, for `!Send` types such as lock guards, drop it) on its own thread.
// This matches the bounds on `std::sync::LazyLock`.
unsafe impl<T: Send + Sync, F: Sync> Sync for Lazy<T, F> {}
impl<T> TryInitError<T> {
    /// Consumes the error, returning the value that could not be stored.
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { value, .. } = self;
        value
    }
}
impl<T> fmt::Debug for TryInitError<T> {
    /// Formats the error without requiring `T: Debug`; the rejected value is
    /// shown as `...`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `actual` is the value observed by the failed strong compare-exchange
        // from `UNINITIALIZED` in `InitOnce::try_init`. A strong CAS only fails
        // when the current value differs from the expected one, so `actual` is
        // normally `INITIALIZING` or `INITIALIZED`. `INITIALIZED` is the
        // ordinary "initialized twice" case; `InitOnce::init` formats this error
        // via `unwrap()` in exactly that situation, so it must not be treated
        // as unreachable.
        let state = match self.actual {
            UNINITIALIZED => "UNINITIALIZED",
            INITIALIZING => "INITIALIZING",
            INITIALIZED => "INITIALIZED",
            _state => unsafe {
                unreachable_unchecked!("unexpected state value {}, this is a bug!", _state)
            },
        };
        f.debug_struct("TryInitError")
            .field("type", &any::type_name::<T>())
            .field("value", &format_args!("..."))
            .field("state", &format_args!("State::{}", state))
            .finish()
    }
}
impl<T> fmt::Display for TryInitError<T> {
    /// Writes a short, human-readable description naming the cell's value type.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let type_name = any::type_name::<T>();
        f.write_fmt(format_args!("InitOnce<{}> already initialized", type_name))
    }
}
/// With the `core-error` feature enabled, `TryInitError` implements
/// `core::error::Error`, using its `Debug` and `Display` impls.
#[cfg(feature = "core-error")]
impl<T> core::error::Error for TryInitError<T> {}