use core::cell::UnsafeCell;
use core::future::Future;
use core::mem::MaybeUninit;
use core::pin::Pin;
use core::sync::atomic::Ordering;
use core::task::{Context, Poll, Waker};
use alloc::boxed::Box;
use crate::sync::mutual::Mutual;
use crate::sync::{InitState, Waiter, Waiters, WaitersExt};
/// An asynchronous write-once cell.
///
/// Callers of `get_or_init` / `get_or_try_init` / `set` race on a `compare_exchange` of `state`
/// from `Uninitialized` to `Initializing`; the winner writes the value and marks the cell
/// `Initialized`, while other async callers park their wakers in `waiters` until
/// `notify_all` is called.
pub struct OnceLock<T> {
    /// Lifecycle of the slot: `Uninitialized` -> `Initializing` (claimed by one writer) ->
    /// `Initialized`. It returns to `Uninitialized` when a fallible initializer fails or the
    /// value is moved out with `take`.
    state: Mutual<InitState>,
    /// Storage for the value; only treated as initialized while `state` is `Initialized`.
    value: UnsafeCell<MaybeUninit<T>>,
    /// Wakers of tasks waiting for an in-flight initialization to finish.
    waiters: Waiters,
}
// SAFETY: the cell owns its `T`, so sending the cell to another thread sends the value with it,
// which is sound whenever `T: Send`. NOTE(review): this impl also overrides the auto-trait
// analysis of `Mutual<InitState>` and `Waiters`; confirm both are safe to move across threads.
unsafe impl<T> Send for OnceLock<T> where T: Send {}

// SAFETY: a shared `&OnceLock<T>` hands out `&T` to any thread (requires `T: Sync`) and lets any
// thread install a value it created (requires `T: Send`), the same bounds `std::sync::OnceLock`
// uses. NOTE(review): confirm `Mutual` and `Waiters` tolerate concurrent shared access.
unsafe impl<T> Sync for OnceLock<T> where T: Send + Sync {}
impl<T> OnceLock<T> {
pub const fn new() -> Self {
Self {
state: Mutual::new(),
value: UnsafeCell::new(MaybeUninit::uninit()),
waiters: Waiters::new(),
}
}
pub fn get(&self) -> Option<&T> {
if self.state.is(&InitState::Initialized) {
Some(unsafe { (*self.value.get()).assume_init_ref() })
} else {
None
}
}
pub fn get_mut(&mut self) -> Option<&mut T> {
if self.state.is(&InitState::Initialized) {
Some(unsafe { (*self.value.get()).assume_init_mut() })
} else {
None
}
}
pub async fn get_or_init<F, Fut>(&self, f: F) -> &T
where
F: FnOnce() -> Fut,
Fut: Future<Output = T>,
{
if self.state.is(&InitState::Initialized) {
return unsafe { (*self.value.get()).assume_init_ref() };
}
if self.state.is(&InitState::Uninitialized) {
match self.state.compare_exchange(
InitState::Uninitialized,
InitState::Initializing,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
let value = f().await;
unsafe {
(*self.value.get()).write(value);
}
self.state.set(InitState::Initialized);
self.waiters.notify_all();
return unsafe { (*self.value.get()).assume_init_ref() };
}
Err(_) => {}
}
}
Wait { once: self }.await
}
pub async fn get_or_try_init<F, Fut, E>(&self, f: F) -> Result<&T, E>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<T, E>>,
{
if self.state.is(&InitState::Initialized) {
return Ok(unsafe { (*self.value.get()).assume_init_ref() });
}
if self.state.is(&InitState::Uninitialized) {
match self.state.compare_exchange(
InitState::Uninitialized,
InitState::Initializing,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => match f().await {
Ok(value) => {
unsafe {
(*self.value.get()).write(value);
}
self.state.set(InitState::Initialized);
self.waiters.notify_all();
return Ok(unsafe { (*self.value.get()).assume_init_ref() });
}
Err(e) => {
self.state.set(InitState::Uninitialized);
self.waiters.notify_all();
return Err(e);
}
},
Err(_) => {}
}
}
Ok(Wait { once: self }.await)
}
pub fn set(&self, value: T) -> Result<(), T> {
match self.state.compare_exchange(
InitState::Uninitialized,
InitState::Initializing,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
unsafe {
(*self.value.get()).write(value);
}
self.state.set(InitState::Initialized);
self.waiters.notify_all();
Ok(())
}
Err(_) => Err(value),
}
}
pub fn take(&mut self) -> Option<T> {
if self.state.is(&InitState::Initialized) {
self.state.set(InitState::Uninitialized);
Some(unsafe { (*self.value.get()).assume_init_read() })
} else {
None
}
}
pub fn into_inner(mut self) -> Option<T> {
self.take()
}
}
impl<T> Drop for OnceLock<T> {
fn drop(&mut self) {
if self.state.is(&InitState::Initialized) {
unsafe {
(*self.value.get()).assume_init_drop();
}
}
}
}
enum InitResult {
Initialized,
Failed,
}
struct Wait<'a, T> {
once: &'a OnceLock<T>,
}
impl<'a, T> Future for Wait<'a, T> {
type Output = &'a T;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
if self.once.state.is(&InitState::Initialized) {
return Poll::Ready(unsafe { (*self.once.value.get()).assume_init_ref() });
}
if self.once.state.is(&InitState::Uninitialized) {
panic!("AsyncOnceLock initialization failed");
}
if self.once.state.is(&InitState::Initializing) {
unsafe {
self.once
.waiters
.enqueue(Waiter::from_waker(cx.waker().clone()))
};
if self.once.state.is(&InitState::Initialized) {
return Poll::Ready(unsafe { (*self.once.value.get()).assume_init_ref() });
}
return Poll::Pending;
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::boxed::Box;
    use core::future::Future;
    use core::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

    /// Builds a `RawWaker` whose operations do nothing; enough for futures that never park.
    fn noop_raw_waker() -> RawWaker {
        fn clone(_: *const ()) -> RawWaker {
            noop_raw_waker()
        }
        fn noop(_: *const ()) {}
        static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
        RawWaker::new(core::ptr::null(), &VTABLE)
    }

    /// Polls `fut` exactly once with a no-op waker.
    fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
        // SAFETY: every vtable function ignores the data pointer, so a null pointer is fine.
        let waker = unsafe { Waker::from_raw(noop_raw_waker()) };
        let mut cx = Context::from_waker(&waker);
        let mut fut = Box::pin(fut);
        fut.as_mut().poll(&mut cx)
    }

    #[test]
    fn test_sync_get() {
        let once = OnceLock::<i32>::new();
        assert_eq!(once.get(), None);
        let _ = once.set(42);
        assert_eq!(once.get(), Some(&42));
    }

    #[test]
    fn set_only_succeeds_once() {
        let once = OnceLock::<i32>::new();
        assert_eq!(once.set(1), Ok(()));
        assert_eq!(once.set(2), Err(2));
        assert_eq!(once.get(), Some(&1));
    }

    #[test]
    fn take_and_into_inner_move_the_value_out() {
        let mut once = OnceLock::<i32>::new();
        assert_eq!(once.take(), None);
        assert_eq!(once.set(9), Ok(()));
        assert_eq!(once.get_mut(), Some(&mut 9));
        assert_eq!(once.take(), Some(9));
        assert_eq!(once.get(), None);
        assert_eq!(once.set(10), Ok(()));
        assert_eq!(once.into_inner(), Some(10));
    }

    #[test]
    fn get_or_init_keeps_first_value() {
        let once = OnceLock::<i32>::new();
        assert_eq!(poll_once(once.get_or_init(|| async { 5 })), Poll::Ready(&5));
        assert_eq!(poll_once(once.get_or_init(|| async { 6 })), Poll::Ready(&5));
    }

    #[test]
    fn get_or_try_init_error_leaves_cell_empty() {
        let once = OnceLock::<i32>::new();
        let failed = poll_once(once.get_or_try_init(|| async { Err::<i32, &str>("boom") }));
        assert_eq!(failed, Poll::Ready(Err("boom")));
        assert_eq!(once.get(), None);
        let ok = poll_once(once.get_or_try_init(|| async { Ok::<i32, &str>(3) }));
        assert_eq!(ok, Poll::Ready(Ok(&3)));
    }
}