use core::cell::UnsafeCell;
use core::convert::Infallible;
use core::fmt;
use core::future::Future;
use core::mem::{forget, MaybeUninit};
use core::ptr;
use core::sync::atomic::{AtomicUsize, Ordering};
#[cfg(all(feature = "std", not(target_family = "wasm")))]
use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use event_listener::{Event, EventListener};
use event_listener_strategy::{NonBlocking, Strategy};
/// Lifecycle phase of a cell, stored as a raw `usize` inside an atomic.
///
/// The explicit discriminants are the exact values written to and read back
/// from that atomic, so they must stay in sync with the `From` conversions
/// below.
#[derive(Copy, Clone, PartialEq, Eq)]
#[repr(usize)]
enum State {
    /// No value is stored and nobody is currently producing one.
    Uninitialized = 0,
    /// One initializer has claimed the cell and is producing the value.
    Initializing = 1,
    /// The value has been written and may be read.
    Initialized = 2,
}

impl From<usize> for State {
    /// Decodes a raw state word.
    ///
    /// # Panics
    ///
    /// Panics if `raw` is not one of the three discriminants of [`State`].
    fn from(raw: usize) -> Self {
        // Derive the patterns from the enum itself so the decoder cannot
        // drift away from the declared discriminants.
        const UNINITIALIZED: usize = State::Uninitialized as usize;
        const INITIALIZING: usize = State::Initializing as usize;
        const INITIALIZED: usize = State::Initialized as usize;

        match raw {
            UNINITIALIZED => Self::Uninitialized,
            INITIALIZING => Self::Initializing,
            INITIALIZED => Self::Initialized,
            _ => unreachable!("Invalid state"),
        }
    }
}

impl From<State> for usize {
    /// Encodes a state as the raw word stored in the atomic.
    fn from(state: State) -> Self {
        state as usize
    }
}
/// A cell that can be written at most once (until reset through `&mut`
/// access) and supports both async and blocking initialization and waiting.
pub struct OnceCell<T> {
// Listeners registered by tasks that want to run an initializer themselves
// but found another initializer in progress.
active_initializers: Event,
// Listeners registered by `wait`/`wait_blocking`, which only want to be told
// once a value has been stored.
passive_waiters: Event,
// Current `State`, stored as its `usize` discriminant.
state: AtomicUsize,
// Storage for the value; only initialized while `state` is `Initialized`.
value: UnsafeCell<MaybeUninit<T>>,
}
// SAFETY: the cell owns at most one `T`, so moving the cell to another thread
// moves that `T` with it; this is fine whenever `T: Send`.
unsafe impl<T: Send> Send for OnceCell<T> {}
// SAFETY: through `&OnceCell<T>` a thread can both store a `T` (initialization
// takes `&self`) and obtain `&T`, so sharing requires `T: Send + Sync`.
unsafe impl<T: Send + Sync> Sync for OnceCell<T> {}
impl<T> OnceCell<T> {
/// Creates an empty cell: the state is `Uninitialized`, the value slot is
/// left unwritten, and both events start with no listeners.
pub const fn new() -> Self {
    Self {
        state: AtomicUsize::new(State::Uninitialized as usize),
        value: UnsafeCell::new(MaybeUninit::uninit()),
        active_initializers: Event::new(),
        passive_waiters: Event::new(),
    }
}
/// Returns `true` if the cell currently holds a value.
///
/// The state is read with `Acquire` ordering, pairing with the `Release`
/// store performed after the value has been written.
pub fn is_initialized(&self) -> bool {
    let raw = self.state.load(Ordering::Acquire);
    matches!(State::from(raw), State::Initialized)
}
/// Returns a shared reference to the value, or `None` if the cell is not
/// initialized (including while an initializer is still running).
pub fn get(&self) -> Option<&T> {
    if !self.is_initialized() {
        return None;
    }

    // SAFETY: `is_initialized` just observed the `Initialized` state, which
    // is only stored after the value slot has been written.
    Some(unsafe { self.get_unchecked() })
}
/// Returns an exclusive reference to the value, or `None` if the cell does
/// not hold one.
///
/// Because `self` is borrowed mutably, the state can be inspected without
/// an atomic operation.
pub fn get_mut(&mut self) -> Option<&mut T> {
    match State::from(*self.state.get_mut()) {
        State::Initialized => {
            let slot: *mut T = self.value.get().cast();
            // SAFETY: the state says the slot holds a valid `T`, and the
            // `&mut self` borrow guarantees exclusive access to it.
            Some(unsafe { &mut *slot })
        }
        State::Uninitialized | State::Initializing => None,
    }
}
/// Moves the value out of the cell, leaving it `Uninitialized`.
///
/// Returns `None` (and changes nothing) if the cell does not hold a value.
pub fn take(&mut self) -> Option<T> {
    let state = self.state.get_mut();
    if State::from(*state) != State::Initialized {
        return None;
    }

    // SAFETY: the state says the slot holds a valid `T`. Resetting the state
    // right after the read ensures the moved-out value is never read or
    // dropped a second time through the cell.
    let value = unsafe { ptr::read(self.value.get().cast::<T>()) };
    *state = State::Uninitialized.into();
    Some(value)
}
/// Consumes the cell and returns its value, or `None` if it was never
/// initialized.
pub fn into_inner(self) -> Option<T> {
    // `take` resets the state, so dropping the emptied cell afterwards does
    // not touch the value again.
    let mut cell = self;
    cell.take()
}
/// Waits until the cell holds a value and returns a reference to it.
///
/// This never runs an initializer itself; it only listens on
/// `passive_waiters`, which is notified after a successful initialization.
pub async fn wait(&self) -> &T {
    if !self.is_initialized() {
        let waiter = EventListener::new();
        pin!(waiter);
        waiter.as_mut().listen(&self.passive_waiters);

        // Check again now that the listener is registered, so that an
        // initialization completing between the first check and the
        // registration is not missed.
        if !self.is_initialized() {
            waiter.await;
            debug_assert!(self.is_initialized());
        }
    }

    // SAFETY: either the state was observed as `Initialized`, or the
    // listener on `passive_waiters` fired, which happens only after the
    // value has been stored.
    unsafe { self.get_unchecked() }
}
/// Blocks the current thread until the cell holds a value and returns a
/// reference to it.
///
/// Blocking counterpart of `wait`: it never runs an initializer, it only
/// listens on `passive_waiters`.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub fn wait_blocking(&self) -> &T {
    if !self.is_initialized() {
        let waiter = EventListener::new();
        pin!(waiter);
        waiter.as_mut().listen(&self.passive_waiters);

        // Check again now that the listener is registered, so that an
        // initialization completing between the first check and the
        // registration is not missed.
        if !self.is_initialized() {
            waiter.wait();
            debug_assert!(self.is_initialized());
        }
    }

    // SAFETY: either the state was observed as `Initialized`, or the
    // listener on `passive_waiters` fired, which happens only after the
    // value has been stored.
    unsafe { self.get_unchecked() }
}
/// Returns the stored value, initializing the cell with the future produced
/// by `closure` if it is empty.
///
/// If another initializer is already running, this waits for it instead of
/// calling `closure`.
///
/// # Errors
///
/// Returns the error produced by `closure`'s future; in that case the cell
/// is left uninitialized.
pub async fn get_or_try_init<E, Fut: Future<Output = Result<T, E>>>(
    &self,
    closure: impl FnOnce() -> Fut,
) -> Result<&T, E> {
    if !self.is_initialized() {
        // Slow path: either run `closure` or wait for whoever is already
        // initializing the cell.
        self.initialize_or_wait(closure, &mut NonBlocking::default())
            .await?;
        debug_assert!(self.is_initialized());
    }

    // SAFETY: the state was observed as `Initialized`, or
    // `initialize_or_wait` returned `Ok`, which it does only once the cell
    // is initialized.
    Ok(unsafe { self.get_unchecked() })
}
/// Blocking counterpart of `get_or_try_init`: returns the stored value,
/// running `closure` on the current thread if the cell is empty.
///
/// # Errors
///
/// Returns the error produced by `closure`; in that case the cell is left
/// uninitialized.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub fn get_or_try_init_blocking<E>(
    &self,
    closure: impl FnOnce() -> Result<T, E>,
) -> Result<&T, E> {
    if !self.is_initialized() {
        // The initializer is wrapped in an already-completed future and the
        // strategy used here is the blocking one, so the combined future is
        // expected to finish within the single poll done by `now_or_never`.
        now_or_never(self.initialize_or_wait(
            move || core::future::ready(closure()),
            &mut event_listener_strategy::Blocking::default(),
        ))?;
        debug_assert!(self.is_initialized());
    }

    // SAFETY: the state was observed as `Initialized`, or
    // `initialize_or_wait` returned `Ok`, which it does only once the cell
    // is initialized.
    Ok(unsafe { self.get_unchecked() })
}
/// Returns the stored value, initializing the cell with the future produced
/// by `closure` if it is empty.
///
/// Infallible wrapper around `get_or_try_init`.
pub async fn get_or_init<Fut: Future<Output = T>>(&self, closure: impl FnOnce() -> Fut) -> &T {
    // Adapt the infallible initializer to the fallible API by using an
    // error type that has no values.
    let fallible = move || async move { Ok::<T, Infallible>(closure().await) };

    match self.get_or_try_init(fallible).await {
        Ok(value) => value,
        // `Infallible` is uninhabited, so this arm can never execute.
        Err(never) => match never {},
    }
}
/// Blocking counterpart of `get_or_init`: returns the stored value, running
/// `closure` on the current thread if the cell is empty.
///
/// Infallible wrapper around `get_or_try_init_blocking`.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub fn get_or_init_blocking(&self, closure: impl FnOnce() -> T + Unpin) -> &T {
    // Adapt the infallible initializer to the fallible API by using an
    // error type that has no values.
    let fallible = move || Ok::<T, Infallible>(closure());

    match self.get_or_try_init_blocking(fallible) {
        Ok(value) => value,
        // `Infallible` is uninhabited, so this arm can never execute.
        Err(never) => match never {},
    }
}
/// Tries to store `value` in the cell.
///
/// Returns `Ok` with a reference to the stored value if this call performed
/// the initialization, or `Err(value)` handing the value back if the cell
/// ended up initialized by someone else. If another initializer is running,
/// this waits for it to finish first.
pub async fn set(&self, value: T) -> Result<&T, T> {
    // The initializer moves the value out of this slot only if it actually
    // runs, so a leftover `Some` afterwards means the cell was filled by
    // another party.
    let mut pending = Some(value);
    self.get_or_init(|| async { pending.take().unwrap() }).await;

    if let Some(rejected) = pending {
        return Err(rejected);
    }

    // SAFETY: the value was taken by our initializer, so the cell has been
    // initialized with it.
    Ok(unsafe { self.get_unchecked() })
}
/// Blocking counterpart of `set`: tries to store `value` in the cell.
///
/// Returns `Ok` with a reference to the stored value if this call performed
/// the initialization, or `Err(value)` handing the value back if the cell
/// ended up initialized by someone else.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
pub fn set_blocking(&self, value: T) -> Result<&T, T> {
    // The initializer moves the value out of this slot only if it actually
    // runs, so a leftover `Some` afterwards means the cell was filled by
    // another party.
    let mut pending = Some(value);
    self.get_or_init_blocking(|| pending.take().unwrap());

    if let Some(rejected) = pending {
        return Err(rejected);
    }

    // SAFETY: the value was taken by our initializer, so the cell has been
    // initialized with it.
    Ok(unsafe { self.get_unchecked() })
}
/// Slow path shared by the async and blocking initializers: loops until the
/// cell is `Initialized` or this caller's own initializer fails.
///
/// On each iteration the current state decides what happens:
///
/// * `Initialized`: return `Ok(())`.
/// * `Initializing`: register on `active_initializers` (first time) or wait
///   on the already-registered listener via `strategy`, then re-check.
/// * `Uninitialized`: try to claim the cell with a compare-exchange; the
///   winner runs `closure` and either publishes the value or resets the cell.
///
/// # Errors
///
/// Returns the error produced by `closure`'s future. The cell is reset to
/// `Uninitialized` and one waiting initializer is notified in that case.
#[cold]
async fn initialize_or_wait<E, Fut: Future<Output = Result<T, E>>, F: FnOnce() -> Fut>(
    &self,
    closure: F,
    strategy: &mut impl for<'a> Strategy<'a>,
) -> Result<(), E> {
    // The listener is created up front and only registered once another
    // initializer is actually observed.
    let event_listener = EventListener::new();
    pin!(event_listener);

    // Wrapped in an `Option` so the `FnOnce` can be moved out from inside
    // the loop body.
    let mut closure = Some(closure);

    loop {
        let state = self.state.load(Ordering::Acquire);

        match state.into() {
            State::Initialized => {
                return Ok(());
            }
            State::Initializing => {
                // Someone else holds the cell. Register first and loop to
                // re-check the state; only wait on a later iteration if the
                // cell is still being initialized.
                if event_listener.is_listening() {
                    strategy.wait(event_listener.as_mut()).await;
                } else {
                    event_listener.as_mut().listen(&self.active_initializers);
                }
            }
            State::Uninitialized => {
                // Try to claim the cell; on failure another party changed
                // the state first, so re-evaluate from the top.
                if self
                    .state
                    .compare_exchange(
                        State::Uninitialized.into(),
                        State::Initializing.into(),
                        Ordering::AcqRel,
                        Ordering::Acquire,
                    )
                    .is_err()
                {
                    continue;
                }

                // From here on we are the only initializer. The guard puts
                // the cell back to `Uninitialized` if the initializer fails,
                // panics, or this future is dropped while it is pending.
                let guard = Guard(self);

                // Both outcomes below return, so the closure is taken at
                // most once.
                let initializer = closure.take().unwrap();
                match (initializer)().await {
                    Ok(value) => {
                        // SAFETY: holding the `Initializing` state gives us
                        // exclusive access to the value slot.
                        unsafe {
                            ptr::write(self.value.get().cast(), value);
                        }

                        // Success: the guard must not reset the state.
                        forget(guard);
                        self.state
                            .store(State::Initialized.into(), Ordering::Release);

                        // Notify every listener on both events.
                        self.active_initializers.notify_additional(usize::MAX);
                        self.passive_waiters.notify_additional(usize::MAX);
                        return Ok(());
                    }
                    Err(err) => {
                        // Reset the state and notify the next initializer
                        // before reporting the error.
                        drop(guard);
                        return Err(err);
                    }
                }
            }
        }
    }

    /// Resets the cell to `Uninitialized` on drop and notifies one listener
    /// on `active_initializers` so that it can attempt initialization.
    struct Guard<'a, T>(&'a OnceCell<T>);

    impl<'a, T> Drop for Guard<'a, T> {
        fn drop(&mut self) {
            self.0
                .state
                .store(State::Uninitialized.into(), Ordering::Release);
            self.0.active_initializers.notify(1);
        }
    }
}
/// Returns a reference to the value without checking the cell's state.
///
/// # Safety
///
/// The caller must ensure the cell is initialized; otherwise the returned
/// reference points at uninitialized memory.
pub unsafe fn get_unchecked(&self) -> &T {
    let slot: *const T = self.value.get().cast::<T>();
    &*slot
}
}
impl<T> From<T> for OnceCell<T> {
    /// Creates a cell that already holds `value` and is in the
    /// `Initialized` state.
    fn from(value: T) -> Self {
        let state = AtomicUsize::new(usize::from(State::Initialized));
        let value = UnsafeCell::new(MaybeUninit::new(value));

        Self {
            active_initializers: Event::new(),
            passive_waiters: Event::new(),
            state,
            value,
        }
    }
}
impl<T: fmt::Debug> fmt::Debug for OnceCell<T> {
    /// Formats the cell as `OnceCell(<contents>)`, where the contents are
    /// the value's own `Debug` output or a placeholder for an empty or
    /// in-progress cell.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Adapter that renders just the contents of the cell.
        struct Contents<'a, T>(&'a OnceCell<T>);

        impl<T: fmt::Debug> fmt::Debug for Contents<'_, T> {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let cell = self.0;
                match State::from(cell.state.load(Ordering::Acquire)) {
                    State::Initialized => {
                        // SAFETY: the state was just observed as
                        // `Initialized`, so the value slot is valid.
                        let value = unsafe { cell.get_unchecked() };
                        <T as fmt::Debug>::fmt(value, f)
                    }
                    State::Initializing => f.write_str("<initializing>"),
                    State::Uninitialized => f.write_str("<uninitialized>"),
                }
            }
        }

        let contents = Contents(self);
        f.debug_tuple("OnceCell").field(&contents).finish()
    }
}
impl<T> Drop for OnceCell<T> {
fn drop(&mut self) {
if State::from(*self.state.get_mut()) == State::Initialized {
unsafe { self.value.get().cast::<T>().drop_in_place() }
}
}
}
impl<T> Default for OnceCell<T> {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Polls `f` exactly once with a waker that does nothing and returns its
/// output.
///
/// # Panics
///
/// Panics if the future is not ready on that first poll.
#[cfg(all(feature = "std", not(target_family = "wasm")))]
fn now_or_never<T>(f: impl Future<Output = T>) -> T {
    /// Shared no-op used for the wake, wake-by-ref and drop vtable entries.
    unsafe fn noop(_: *const ()) {}

    /// Cloning the no-op waker yields another no-op waker.
    unsafe fn noop_clone(_: *const ()) -> RawWaker {
        RawWaker::new(ptr::null(), &NOOP_VTABLE)
    }

    const NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

    pin!(f);

    // SAFETY: every vtable entry ignores its data pointer, so the null
    // pointer is never dereferenced and the waker contract is trivially met.
    let waker = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &NOOP_VTABLE)) };
    let mut cx = Context::from_waker(&waker);

    if let Poll::Ready(output) = f.poll(&mut cx) {
        return output;
    }
    unreachable!("future not ready")
}