use std::{
cell::UnsafeCell,
convert::Infallible,
future::Future,
mem,
panic::{RefUnwindSafe, UnwindSafe},
pin::Pin,
ptr,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
sync::{Arc, Mutex},
task,
};
/// A cell that is written at most once, with asynchronous initialization.
///
/// Concurrent callers of `get_or_init` / `get_or_try_init` are coordinated through `inner`:
/// at most one initializer runs at a time while the other callers wait for it.
#[derive(Debug)]
pub struct OnceCell<T> {
    // The stored value; it is written before READY_BIT is published in `inner.state`.
    value: UnsafeCell<Option<T>>,
    // Atomic state word plus the lazily-allocated wait queue used during initialization.
    inner: Inner,
}
// SAFETY: shared access hands out `&T` to many threads (requires `T: Sync`), and whichever
// thread runs the initializer stores a `T` that other threads later observe (requires `T: Send`).
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}
// The value is only ever exposed as plain `&T` / `&mut T`, never pinned, so moving the cell is fine.
impl<T> Unpin for OnceCell<T> {}
impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
/// Synchronization state shared by all callers of a `OnceCell`.
///
/// `state` packs the flag bits `READY_BIT` and `QINIT_BIT` into its top bits. The remaining low
/// bits count live `QueueRef`s: `Inner::initialize` increments the count and `QueueRef::drop`
/// decrements it.
#[derive(Debug)]
struct Inner {
    state: AtomicUsize,
    // Wait queue allocated on demand by `Inner::initialize`; null when none exists.
    queue: AtomicPtr<Queue>,
}
/// Wait queue for tasks blocked on an in-progress initialization.
struct Queue {
    // `None`: no waiter list is active, and a poller that also sees QINIT_BIT clear claims
    // initialization. `Some(list)`: wakers of tasks waiting for the current initializer.
    wakers: Mutex<Option<Vec<task::Waker>>>,
}
/// A counted reference to `Inner`'s queue.
///
/// Creating one (in `Inner::initialize`) increments the count in `state`. Dropping it
/// decrements the count and, on a ready cell, frees the queue when this was the last reference.
/// `queue` may be null when the cell was already ready.
struct QueueRef<'a> {
    inner: &'a Inner,
    queue: *const Queue,
}
// SAFETY (NOTE(review)): the raw pointer makes this !Send/!Sync by default. The pointee only
// exposes a `Mutex`, and the reference count in `inner.state` keeps it alive for as long as
// this value exists.
unsafe impl<'a> Sync for QueueRef<'a> {}
unsafe impl<'a> Send for QueueRef<'a> {}
/// Held by the caller that moved the state from `NEW` to `QINIT_BIT`, the allocation-free
/// fast path. Dropping it clears `QINIT_BIT` and wakes any tasks that queued up meanwhile.
#[derive(Debug)]
struct QuickInitGuard<'a>(&'a Inner);
/// Future returned by the queued path of `Inner::initialize`.
///
/// It resolves to `Some(QueueHead)` when this caller should run its initializer, and to
/// `None` when the cell became ready.
struct QueueWaiter<'a> {
    guard: Option<QueueRef<'a>>,
}
/// Ownership of the queue-based initialization lock. Dropping it takes the waiter list and
/// wakes every registered task.
struct QueueHead<'a> {
    guard: QueueRef<'a>,
}
/// State word of a fresh cell: no flags set and no live references.
const NEW: usize = 0x0;
/// Most significant bit: the value has been stored and may be read.
const READY_BIT: usize = !(usize::MAX >> 1);
/// Second most significant bit: a quick (queue-less) initializer is running.
const QINIT_BIT: usize = READY_BIT >> 1;
impl Inner {
    /// Creates the state for an empty cell: zero references, no flags, no queue.
    const fn new() -> Self {
        Inner { state: AtomicUsize::new(NEW), queue: AtomicPtr::new(ptr::null_mut()) }
    }

    /// Creates the state for a cell that already holds its value (`READY_BIT` set).
    const fn new_ready() -> Self {
        Inner { state: AtomicUsize::new(READY_BIT), queue: AtomicPtr::new(ptr::null_mut()) }
    }

    /// Starts or joins initialization of the cell.
    ///
    /// Fast path: when `try_quick` is true and the state is exactly `NEW`, this atomically
    /// claims `QINIT_BIT` and returns `Err(QuickInitGuard)`. The caller may then initialize
    /// directly without allocating a queue.
    ///
    /// Slow path: otherwise it increments the reference count kept in the low bits of `state`.
    /// Unless the cell was already ready, it makes sure a wait queue is allocated and published.
    /// It then returns an `Ok(QueueWaiter)` that resolves to `Some(QueueHead)` (this caller
    /// should initialize) or `None` (the cell became ready).
    #[cold]
    fn initialize(&self, try_quick: bool) -> Result<QueueWaiter<'_>, QuickInitGuard<'_>> {
        if try_quick
            && self
                .state
                .compare_exchange(NEW, QINIT_BIT, Ordering::Acquire, Ordering::Relaxed)
                .is_ok()
        {
            // No QueueRef was alive and QINIT_BIT is now ours: we are the only initializer.
            return Err(QuickInitGuard(self));
        }

        // Take a reference so the queue cannot be freed while we use it. The matching
        // decrement happens in `QueueRef::drop`.
        let prev_state = self.state.fetch_add(1, Ordering::Acquire);
        let mut guard = QueueRef { inner: self, queue: self.queue.load(Ordering::Acquire) };

        if guard.queue.is_null() && prev_state & READY_BIT == 0 {
            // Not ready and no queue yet: allocate one and race to publish it.
            let new_queue = Box::into_raw(Box::new(Queue { wakers: Mutex::new(None) }));
            match self.queue.compare_exchange(
                ptr::null_mut(),
                new_queue,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_null) => guard.queue = new_queue,
                Err(actual) => {
                    guard.queue = actual;
                    // SAFETY: `new_queue` came from `Box::into_raw` just above and lost the
                    // publish race, so it was never shared; reclaim and free it.
                    drop(unsafe { Box::from_raw(new_queue) });
                }
            }
        }
        Ok(QueueWaiter { guard: Some(guard) })
    }

    /// Publishes the stored value by setting `READY_BIT`.
    ///
    /// The Release ordering pairs with the Acquire loads that check `READY_BIT` before reading
    /// the value.
    fn set_ready(&self) {
        let prev_state = self.state.fetch_or(READY_BIT, Ordering::Release);
        debug_assert_eq!(prev_state & READY_BIT, 0, "Invalid state: someone else set READY_BIT");
    }
}
impl<'a> Drop for QueueRef<'a> {
    /// Releases this reference to the queue.
    ///
    /// If the cell is ready and this was the last reference, the queue is no longer needed and
    /// is freed.
    fn drop(&mut self) {
        let prev_state = self.inner.state.fetch_sub(1, Ordering::Release);
        // After the decrement another party may free the queue, so `self.queue` is not
        // dereferenced past this point.
        let curr_state = prev_state - 1;
        if curr_state == READY_BIT || curr_state == READY_BIT | QINIT_BIT {
            let queue = self.inner.queue.swap(ptr::null_mut(), Ordering::Acquire);
            if !queue.is_null() {
                // SAFETY: a non-null queue pointer was produced by `Box::into_raw` in
                // `Inner::initialize`, and swapping in null makes this the sole owner.
                drop(unsafe { Box::from_raw(queue) });
            }
        }
    }
}
impl<'a> Drop for QuickInitGuard<'a> {
    /// Releases the quick-init lock (`QINIT_BIT`), whether or not initialization succeeded.
    ///
    /// Fast path: when no task has queued up (the state is exactly `QINIT_BIT`, optionally
    /// with `READY_BIT`), a single CAS clears `QINIT_BIT`. If the cell is ready, any leftover
    /// queue is also freed.
    ///
    /// Otherwise this joins the queue through the slow path and clears `QINIT_BIT` while
    /// holding the queue lock. It then wakes every waiting task by dropping a `QueueHead`.
    fn drop(&mut self) {
        let prev_state = self.0.state.load(Ordering::Relaxed);
        if prev_state == QINIT_BIT | READY_BIT || prev_state == QINIT_BIT {
            let target = prev_state & !QINIT_BIT;
            if self
                .0
                .state
                .compare_exchange(prev_state, target, Ordering::Relaxed, Ordering::Relaxed)
                .is_ok()
            {
                if target == READY_BIT {
                    let queue = self.0.queue.swap(ptr::null_mut(), Ordering::Relaxed);
                    if !queue.is_null() {
                        // Synchronize with the Release half of the CAS that published `queue`.
                        std::sync::atomic::fence(Ordering::Acquire);
                        // SAFETY: the pointer came from `Box::into_raw` in `Inner::initialize`,
                        // and swapping in null makes this the sole owner.
                        drop(unsafe { Box::from_raw(queue) });
                    }
                }
                return;
            }
        }
        // Tasks are queued, or the CAS lost a race. Register through the slow path so that
        // our reference count keeps the queue alive while we use it.
        let waiter = self.0.initialize(false).expect("Got a QuickInitGuard in slow init");
        let guard = waiter.guard.expect("No guard available even without polling");
        if guard.queue.is_null() {
            drop(guard);
        } else {
            // SAFETY: `guard` holds a reference count, so the queue outlives this borrow.
            let queue = unsafe { &*guard.queue };
            let mut lock = queue.wakers.lock().unwrap();
            // Ensure a waker list exists so that `QueueHead::drop` below can take it.
            lock.get_or_insert_with(Vec::new);
            // Clear QINIT_BIT while holding the lock, so pollers see it together with the list.
            self.0.state.fetch_and(!QINIT_BIT, Ordering::Relaxed);
            drop(lock);
            drop(QueueHead { guard })
        }
    }
}
impl Drop for Inner {
    /// Frees the wait queue if one is still allocated.
    fn drop(&mut self) {
        let queue = *self.queue.get_mut();
        if !queue.is_null() {
            // SAFETY: a non-null queue pointer was produced by `Box::into_raw` in
            // `Inner::initialize`. Every `QueueRef` borrows this `Inner`, so none can
            // still be using the queue while we hold `&mut self`.
            drop(unsafe { Box::from_raw(queue) });
        }
    }
}
impl<'a> Future for QueueWaiter<'a> {
    /// `Some(head)` if this task should run its initializer, `None` if the cell became ready.
    type Output = Option<QueueHead<'a>>;
    /// Waits until either the cell is ready or the queue's initialization lock is free.
    fn poll(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<Option<QueueHead<'a>>> {
        let guard = self.guard.as_ref().expect("Polled future after finished");
        // Fast path: the cell is already initialized.
        let state = guard.inner.state.load(Ordering::Acquire);
        if state & READY_BIT != 0 {
            return task::Poll::Ready(None);
        }
        // READY_BIT is never cleared through `&self`, so it was also clear when
        // `Inner::initialize` ran. That function therefore left a non-null queue in
        // `guard.queue`, and our reference count keeps it alive.
        let queue = unsafe { &*guard.queue };
        let mut lock = queue.wakers.lock().unwrap();
        // Re-check under the lock: the cell may have become ready since the first check.
        let state = guard.inner.state.load(Ordering::Acquire);
        if state & READY_BIT != 0 {
            return task::Poll::Ready(None);
        }
        match lock.as_mut() {
            // No active waiter list and no quick initializer: claim initialization.
            None if state & QINIT_BIT == 0 => {
                *lock = Some(Vec::new());
                drop(lock);
                task::Poll::Ready(Some(QueueHead { guard: self.guard.take().unwrap() }))
            }
            // A quick initializer is running: start the waiter list with our waker.
            None => {
                let waker = cx.waker().clone();
                *lock = Some(vec![waker]);
                task::Poll::Pending
            }
            // Someone else is initializing: register our waker unless an equivalent one is
            // already in the list.
            Some(wakers) => {
                let my_waker = cx.waker();
                for waker in wakers.iter() {
                    if waker.will_wake(my_waker) {
                        return task::Poll::Pending;
                    }
                }
                wakers.push(my_waker.clone());
                task::Poll::Pending
            }
        }
    }
}
impl<'a> Drop for QueueHead<'a> {
fn drop(&mut self) {
if let Some(queue) = unsafe { self.guard.queue.as_ref() } {
let wakers = queue
.wakers
.lock()
.expect("Lock poisoned")
.take()
.expect("QueueHead dropped without a waker list");
for waker in wakers {
waker.wake();
}
}
}
}
impl<T> OnceCell<T> {
    /// Creates a new, uninitialized cell.
    pub const fn new() -> Self {
        Self { value: UnsafeCell::new(None), inner: Inner::new() }
    }

    /// Creates a cell from an optional value.
    ///
    /// The cell is initialized when `value` is `Some` and uninitialized when it is `None`.
    pub const fn new_with(value: Option<T>) -> Self {
        let inner = match value {
            Some(_) => Inner::new_ready(),
            None => Inner::new(),
        };
        Self { value: UnsafeCell::new(value), inner }
    }

    /// Returns the cell's value, running `init` to produce it if the cell is uninitialized.
    ///
    /// Concurrent callers are coordinated so that only one initializer runs at a time; the
    /// others wait for it. If the running initializer is dropped before it completes, a
    /// waiting caller may take over with its own `init`.
    pub async fn get_or_init(&self, init: impl Future<Output = T>) -> &T {
        match self.get_or_try_init(async move { Ok::<T, Infallible>(init.await) }).await {
            Ok(t) => t,
            Err(e) => match e {},
        }
    }

    /// Like [`get_or_init`](Self::get_or_init), but the initializer may fail.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `init`. The cell is left uninitialized, and a waiting
    /// caller may attempt its own initialization.
    pub async fn get_or_try_init<E>(
        &self,
        init: impl Future<Output = Result<T, E>>,
    ) -> Result<&T, E> {
        let state = self.inner.state.load(Ordering::Acquire);
        if state & READY_BIT == 0 {
            // Only a completely idle cell (`NEW`) is eligible for the allocation-free path.
            self.init_slow(state == NEW, init).await?;
        }
        // SAFETY: READY_BIT has been observed, either here or inside `init_slow`. The value was
        // therefore written before it was published, and it is not mutated through `&self`.
        Ok(unsafe { (&*self.value.get()).as_ref().unwrap() })
    }

    /// Runs or waits for initialization.
    ///
    /// Returns once the cell is ready, or with `init`'s error if our initializer failed.
    #[cold]
    async fn init_slow<E>(
        &self,
        try_quick: bool,
        init: impl Future<Output = Result<T, E>>,
    ) -> Result<(), E> {
        match self.inner.initialize(try_quick) {
            Err(guard) => {
                // Quick path: we hold QINIT_BIT, so no queue head can exist concurrently.
                let value = init.await?;
                // SAFETY: we are the only initializer and READY_BIT is not yet set, so no
                // reader can observe the value slot.
                unsafe {
                    *self.value.get() = Some(value);
                }
                self.inner.set_ready();
                // Releases QINIT_BIT and wakes any tasks that queued up meanwhile.
                drop(guard);
            }
            Ok(guard) => {
                // Queued path: wait until we own the queue or the cell becomes ready.
                if let Some(init_lock) = guard.await {
                    let value = init.await?;
                    // SAFETY: holding the `QueueHead` excludes other initializers, and
                    // READY_BIT is not yet set, so no reader can observe the value slot.
                    unsafe {
                        *self.value.get() = Some(value);
                    }
                    init_lock.guard.inner.set_ready();
                    // `init_lock` is dropped here, which wakes the remaining waiters.
                }
            }
        }
        Ok(())
    }

    /// Returns the value if the cell has been initialized, without waiting.
    pub fn get(&self) -> Option<&T> {
        let state = self.inner.state.load(Ordering::Acquire);
        if state & READY_BIT == 0 {
            None
        } else {
            // SAFETY: READY_BIT was observed with Acquire, so the value is fully written.
            unsafe { (&*self.value.get()).as_ref() }
        }
    }

    /// Returns a mutable reference to the value, if any.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        self.value.get_mut().as_mut()
    }

    /// Takes the value out of the cell, moving it back to an uninitialized state.
    pub fn take(&mut self) -> Option<T> {
        // Reset the state as well as the value. Leaving READY_BIT set would make a later
        // `get_or_init` skip initialization and panic on the missing value. `&mut self`
        // proves no guard borrows `inner`, so it can be replaced wholesale; dropping the old
        // `Inner` frees any leftover queue.
        self.inner = Inner::new();
        self.value.get_mut().take()
    }

    /// Consumes the cell and returns its value, if any.
    pub fn into_inner(self) -> Option<T> {
        self.value.into_inner()
    }
}
/// A cell whose value is produced by a single shared future.
///
/// That future is polled by whichever waiting task currently holds the poll lock.
///
/// `T` is the value type and `F` the future type. `I` is an optional input, stored until it is
/// handed to the future-creating callback of `OnceFuture::get_or_populate_with`.
#[derive(Debug)]
pub struct OnceFuture<T, F = Pin<Box<dyn Future<Output = T> + Send>>, I = Infallible> {
    // `New(input)` until the input is consumed, `Running` when no input is left, and
    // `Ready(value)` once initialized.
    value: UnsafeCell<LazyState<T, I>>,
    // READY_BIT / reference-count word plus the shared waker that owns the future.
    inner: LazyInner<F>,
}
// SAFETY (NOTE(review)): `&T` is shared across threads (`T: Sync`). The value, the input `I`
// and the future `F` may be produced or consumed on whichever thread acts as head
// (`T`, `I`, `F: Send`). `F` is only accessed by the head while it holds the poll lock.
unsafe impl<T: Sync + Send, F: Send, I: Send> Sync for OnceFuture<T, F, I> {}
unsafe impl<T: Send, F: Send, I: Send> Send for OnceFuture<T, F, I> {}
// The future is kept behind an `Arc` (see `LazyWaker`), so the cell itself never needs pinning.
impl<T, F, I> Unpin for OnceFuture<T, F, I> {}
impl<T: RefUnwindSafe + UnwindSafe, F, I: RefUnwindSafe> RefUnwindSafe for OnceFuture<T, F, I> {}
impl<T: UnwindSafe, F, I: UnwindSafe> UnwindSafe for OnceFuture<T, F, I> {}
/// Contents of a `OnceFuture`'s value slot.
enum LazyState<T, I> {
    /// Input not yet consumed; the first head passes it to the future-creating callback.
    New(I),
    /// No input left: either none was provided, or the head has already consumed it.
    Running,
    /// Initialization finished; READY_BIT is set once this is stored.
    Ready(T),
}
/// Synchronization state of a `OnceFuture`.
#[derive(Debug)]
struct LazyInner<F> {
    // READY_BIT plus a count of callers currently inside `LazyInner::initialize`.
    state: AtomicUsize,
    // Pointer from `Arc::into_raw` that owns one strong count, or null.
    queue: AtomicPtr<LazyWaker<F>>,
}
/// Shared waker and future storage.
///
/// It is reference-counted so that the future can live at a stable address.
struct LazyWaker<F> {
    // The shared future, created on first use and reset to `None` once it completes.
    future: UnsafeCell<Option<F>>,
    // Poll-lock state plus the wakers of every task waiting on (or polling) the future.
    wakers: Mutex<(WakerState, Vec<task::Waker>)>,
}
/// Poll-lock state of a `LazyWaker`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum WakerState {
    /// Nobody is polling; the next `poll_head` caller becomes head.
    Unlocked,
    /// A head is polling the future and no wake has arrived during this poll.
    LockedWithoutWake,
    /// The last poll returned `Pending`; the shared waker will unlock and wake everyone.
    Pending,
    /// A head is polling and the shared waker fired during the poll.
    LockedWoken,
}
// SAFETY (NOTE(review)): `future` is only touched by the current head while it holds the
// poll lock, so sharing across threads only requires `F: Send`.
unsafe impl<F: Send> Send for LazyWaker<F> {}
unsafe impl<F: Send> Sync for LazyWaker<F> {}
/// Ownership of a `LazyWaker`'s poll lock (state `LockedWithoutWake` or `LockedWoken`).
struct LazyHead<'a, F> {
    waker: &'a Arc<LazyWaker<F>>,
}
impl<F> LazyInner<F> {
    /// Returns a new strong reference to the shared `LazyWaker`, creating and publishing it if
    /// needed. Returns `None` if the value is already ready.
    ///
    /// While a caller is inside this function it is counted in the low bits of `state`. That
    /// count keeps the published `queue` pointer from being released before its strong count
    /// has been bumped.
    fn initialize(&self) -> Option<Arc<LazyWaker<F>>> {
        let prev_state = self.state.fetch_add(1, Ordering::Acquire);
        let mut queue = self.queue.load(Ordering::Acquire);
        if queue.is_null() && prev_state & READY_BIT == 0 {
            let waker: LazyWaker<F> = LazyWaker {
                future: UnsafeCell::new(None),
                wakers: Mutex::new((WakerState::Unlocked, Vec::new())),
            };
            let new_queue = Arc::into_raw(Arc::new(waker)) as *mut LazyWaker<F>;
            match self.queue.compare_exchange(
                ptr::null_mut(),
                new_queue,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_null) => queue = new_queue,
                Err(actual) => {
                    queue = actual;
                    // SAFETY: `new_queue` came from `Arc::into_raw` just above and was never
                    // published; reclaim its only strong count so it is freed.
                    drop(unsafe { Arc::from_raw(new_queue as *const LazyWaker<F>) });
                }
            }
        }
        let rv = if queue.is_null() {
            None
        } else {
            // SAFETY: our count in `state` prevents the strong count owned by `self.queue`
            // from being released, so the pointee is alive and we may add a reference.
            unsafe {
                Arc::increment_strong_count(queue as *const LazyWaker<F>);
                Some(Arc::from_raw(queue as *const LazyWaker<F>))
            }
        };
        let prev_state = self.state.fetch_sub(1, Ordering::AcqRel);
        if prev_state & READY_BIT == 0 {
            debug_assert!(rv.is_some());
            rv
        } else {
            // The value became ready. The last caller to leave releases the strong count
            // owned by `self.queue`, since the queue is no longer needed.
            if prev_state == READY_BIT + 1 {
                let queue = self.queue.swap(ptr::null_mut(), Ordering::Acquire);
                if !queue.is_null() {
                    // SAFETY: swapping in null transfers the owned strong count to us.
                    unsafe {
                        Arc::decrement_strong_count(queue as *const LazyWaker<F>);
                    }
                }
            }
            None
        }
    }

    /// Sets `READY_BIT`.
    ///
    /// If no caller is inside `initialize`, this also releases the strong count owned by
    /// `queue`. Otherwise the last such caller releases it.
    fn set_ready(&self) {
        let prev_state = self.state.fetch_or(READY_BIT, Ordering::Release);
        debug_assert_eq!(prev_state & READY_BIT, 0, "Invalid state: someone else set READY_BIT");
        if prev_state == NEW {
            let queue = self.queue.swap(ptr::null_mut(), Ordering::Acquire);
            if !queue.is_null() {
                // SAFETY: swapping in null transfers the owned strong count to us.
                unsafe {
                    Arc::decrement_strong_count(queue as *const LazyWaker<F>);
                }
            }
        }
    }
}
impl<F> Drop for LazyInner<F> {
fn drop(&mut self) {
let queue = *self.queue.get_mut();
if !queue.is_null() {
unsafe {
Arc::decrement_strong_count(queue);
}
}
}
}
impl<F> LazyWaker<F> {
    /// Tries to become the task that polls the shared future.
    ///
    /// - Returns `Ready(None)` if the value is already ready.
    /// - Returns `Pending` if an equivalent waker is already registered.
    /// - Otherwise registers `cx`'s waker. It then returns `Ready(Some(head))` if the state was
    ///   `Unlocked` (moving it to `LockedWithoutWake`), or `Pending` if another head holds the
    ///   lock.
    ///
    /// The head's own waker also stays in the list, so it is woken together with everyone
    /// else whenever the list is drained.
    fn poll_head<'a>(
        self: &'a Arc<Self>,
        cx: &mut task::Context<'_>,
        inner: &LazyInner<F>,
    ) -> task::Poll<Option<LazyHead<'a, F>>> {
        let mut lock = self.wakers.lock().unwrap();
        // Checked under the lock so that it is ordered with respect to wakers being drained.
        let state = inner.state.load(Ordering::Acquire);
        if state & READY_BIT != 0 {
            return task::Poll::Ready(None);
        }
        let wakers = &mut lock.1;
        let my_waker = cx.waker();
        for waker in wakers.iter() {
            if waker.will_wake(my_waker) {
                return task::Poll::Pending;
            }
        }
        wakers.push(my_waker.clone());
        match lock.0 {
            WakerState::Unlocked => {
                lock.0 = WakerState::LockedWithoutWake;
                task::Poll::Ready(Some(LazyHead { waker: self }))
            }
            _ => {
                task::Poll::Pending
            }
        }
    }
}
impl<F> task::Wake for LazyWaker<F> {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref()
    }

    /// Called when the shared future makes progress.
    ///
    /// If a head is polling right now, the wake is only recorded; the head releases everyone
    /// once its poll returns. Otherwise the lock is opened and every registered task is woken,
    /// so that one of them becomes the next head.
    fn wake_by_ref(self: &Arc<Self>) {
        let mut guard = self.wakers.lock().unwrap();
        let (state, waiters) = &mut *guard;
        let current = *state;
        match current {
            WakerState::LockedWithoutWake => {
                *state = WakerState::LockedWoken;
                return;
            }
            WakerState::LockedWoken => return,
            WakerState::Pending => *state = WakerState::Unlocked,
            WakerState::Unlocked => {}
        }
        let to_wake = mem::take(waiters);
        drop(guard);
        to_wake.into_iter().for_each(task::Waker::wake);
    }
}
impl<'a, F> LazyHead<'a, F> {
    /// Polls the shared future, creating it with `init` if the slot is empty.
    ///
    /// The poll uses a waker that wakes every registered task.
    ///
    /// - `Ready`: the stored future is dropped, and the head is returned together with the
    ///   output, so the caller can store the value before releasing the lock.
    /// - `Pending`: the poll lock is released without running `LazyHead::drop`. If a wake
    ///   arrived during the poll (`LockedWoken`), all registered tasks are woken now. Otherwise
    ///   the state becomes `Pending`, and the next wake of the shared waker releases them.
    fn poll_inner(self, init: impl FnOnce() -> F) -> task::Poll<(Self, F::Output)>
    where
        F: Future + Send + 'static,
    {
        let ptr = self.waker.future.get();
        // SAFETY: only the head (us) touches the slot. The future lives inside the `Arc` and is
        // never moved out; it is only dropped in place by `*ptr = None` below.
        let fut = unsafe { Pin::new_unchecked((*ptr).get_or_insert_with(init)) };
        let shared_waker = task::Waker::from(Arc::clone(self.waker));
        let mut ctx = task::Context::from_waker(&shared_waker);
        match fut.poll(&mut ctx) {
            task::Poll::Pending => {
                let mut lock = self.waker.wakers.lock().unwrap();
                match lock.0 {
                    WakerState::LockedWithoutWake => {
                        lock.0 = WakerState::Pending;
                        drop(lock);
                    }
                    WakerState::LockedWoken => {
                        lock.0 = WakerState::Unlocked;
                        let wakers = mem::replace(&mut lock.1, Vec::new());
                        drop(lock);
                        for waker in wakers {
                            waker.wake();
                        }
                    }
                    // While we hold the head, only `wake_by_ref` can change the state, and it
                    // moves only between the two Locked* states.
                    WakerState::Pending | WakerState::Unlocked => {
                        unreachable!();
                    }
                }
                // Skip `LazyHead::drop`: the lock state has already been updated above.
                mem::forget(self);
                task::Poll::Pending
            }
            task::Poll::Ready(value) => {
                // SAFETY: we are still the head, so nothing else accesses the slot.
                unsafe {
                    *ptr = None;
                }
                task::Poll::Ready((self, value))
            }
        }
    }
}
impl<'a, F> Drop for LazyHead<'a, F> {
    /// Releases the poll lock and wakes every registered task.
    ///
    /// This runs when the head completes the future or is dropped mid-poll.
    fn drop(&mut self) {
        let mut guard = self.waker.wakers.lock().unwrap();
        let (state, waiters) = &mut *guard;
        let current = *state;
        match current {
            WakerState::LockedWoken | WakerState::LockedWithoutWake => {
                *state = WakerState::Unlocked;
            }
            WakerState::Unlocked | WakerState::Pending => unreachable!(),
        }
        let to_wake = mem::take(waiters);
        drop(guard);
        for pending in to_wake {
            pending.wake();
        }
    }
}
impl<T, F, I> OnceFuture<T, F, I> {
pub const fn with_init(init: I) -> Self {
OnceFuture {
value: UnsafeCell::new(LazyState::New(init)),
inner: LazyInner {
state: AtomicUsize::new(NEW),
queue: AtomicPtr::new(ptr::null_mut()),
},
}
}
pub const fn with_no_init() -> Self {
OnceFuture {
value: UnsafeCell::new(LazyState::Running),
inner: LazyInner {
state: AtomicUsize::new(NEW),
queue: AtomicPtr::new(ptr::null_mut()),
},
}
}
pub const fn with_value(value: T) -> Self {
OnceFuture {
value: UnsafeCell::new(LazyState::Ready(value)),
inner: LazyInner {
state: AtomicUsize::new(READY_BIT),
queue: AtomicPtr::new(ptr::null_mut()),
},
}
}
pub fn get(&self) -> Option<&T> {
let state = self.inner.state.load(Ordering::Acquire);
if state & READY_BIT == 0 {
None
} else {
unsafe {
match &*self.value.get() {
LazyState::Ready(v) => Some(v),
_ => unreachable!(),
}
}
}
}
pub fn get_mut(&mut self) -> (Option<&mut I>, Option<&mut T>) {
match self.value.get_mut() {
LazyState::New(i) => (Some(i), None),
LazyState::Running => (None, None),
LazyState::Ready(v) => (None, Some(v)),
}
}
pub fn into_inner(self) -> (Option<I>, Option<T>) {
match self.value.into_inner() {
LazyState::New(i) => (Some(i), None),
LazyState::Running => (None, None),
LazyState::Ready(v) => (None, Some(v)),
}
}
}
impl<T, F> OnceFuture<T, F> {
pub const fn new() -> Self {
Self::with_no_init()
}
}
impl<F> OnceFuture<F::Output, F>
where
F: Future + Send + 'static,
{
pub fn from_future(future: F) -> Self {
let rv = Self::new();
let waker = rv.inner.initialize().unwrap();
unsafe {
*waker.future.get() = Some(future);
}
rv
}
}
impl<T, F, I> OnceFuture<T, F, I>
where
    F: Future<Output = T> + Send + 'static,
{
    /// Returns the value, creating the shared future with `gen_future` if none exists yet.
    ///
    /// Any stored input `I` is consumed and discarded.
    pub async fn get_or_init_with(&self, gen_future: impl FnOnce() -> F) -> &T {
        self.get_or_populate_with(move |_unused_input| gen_future()).await
    }

    /// Returns the value, creating the shared future when none exists yet.
    ///
    /// The future is created by `into_future`, which receives the stored input if there is one.
    pub async fn get_or_populate_with(&self, into_future: impl FnOnce(Option<I>) -> F) -> &T {
        /// Future adapter that forwards each poll to `poll_populate`.
        ///
        /// It hands over the one-shot `into_future` callback at most once.
        struct Populate<'a, T, F, I, P> {
            cell: &'a OnceFuture<T, F, I>,
            into_future: Option<P>,
        }
        impl<'a, T, F, I, P> Unpin for Populate<'a, T, F, I, P> {}
        impl<'a, T, F, I, P> Future for Populate<'a, T, F, I, P>
        where
            F: Future<Output = T> + Send + 'static,
            P: FnOnce(Option<I>) -> F,
        {
            type Output = &'a T;
            fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<&'a T> {
                let this = &mut *self;
                let cell = this.cell;
                let slot = &mut this.into_future;
                cell.poll_populate(cx, |input| (slot.take().unwrap())(input))
            }
        }
        Populate { cell: self, into_future: Some(into_future) }.await
    }

    /// Polls for the value. If it is not ready yet, this takes part in driving the shared
    /// future, calling `into_future` only if the future has not been created yet.
    pub fn poll_populate(
        &self,
        cx: &mut task::Context<'_>,
        into_future: impl FnOnce(Option<I>) -> F,
    ) -> task::Poll<&T> {
        let ready = self.inner.state.load(Ordering::Acquire) & READY_BIT != 0;
        if !ready && self.init_slow(cx, into_future).is_pending() {
            return task::Poll::Pending;
        }
        // SAFETY: READY_BIT is set at this point, so the `Ready` value is fully written.
        match unsafe { &*self.value.get() } {
            LazyState::Ready(value) => task::Poll::Ready(value),
            _ => unreachable!(),
        }
    }

    /// Slow path of `poll_populate`.
    ///
    /// Tries to become head and, if that succeeds, polls the shared future. Returns `Ready(())`
    /// once the value is ready.
    #[cold]
    fn init_slow(
        &self,
        cx: &mut task::Context<'_>,
        into_future: impl FnOnce(Option<I>) -> F,
    ) -> task::Poll<()> {
        let shared = match self.inner.initialize() {
            Some(shared) => shared,
            None => return task::Poll::Ready(()),
        };
        let head = match shared.poll_head(cx, &self.inner) {
            task::Poll::Ready(Some(head)) => head,
            task::Poll::Ready(None) => return task::Poll::Ready(()),
            task::Poll::Pending => return task::Poll::Pending,
        };
        // SAFETY: only the head accesses the value slot before READY_BIT is set.
        let previous = mem::replace(unsafe { &mut *self.value.get() }, LazyState::Running);
        let input = match previous {
            LazyState::New(input) => Some(input),
            LazyState::Running => None,
            LazyState::Ready(_) => unreachable!(),
        };
        match head.poll_inner(move || into_future(input)) {
            task::Poll::Pending => task::Poll::Pending,
            task::Poll::Ready((head, value)) => {
                // SAFETY: still the head and READY_BIT is not yet set, so there are no readers.
                unsafe {
                    *self.value.get() = LazyState::Ready(value);
                }
                self.inner.set_ready();
                drop(head);
                task::Poll::Ready(())
            }
        }
    }
}
/// A lazily computed value, produced by a future supplied at construction.
#[derive(Debug)]
pub struct Lazy<T, F = Pin<Box<dyn Future<Output = T> + Send>>> {
    once: OnceFuture<T, F>,
}

impl<T, F> Lazy<T, F>
where
    F: Future<Output = T> + Send + 'static,
{
    /// Creates a `Lazy` whose value is produced by `future` when it is first awaited.
    pub fn new(future: F) -> Self {
        let once: OnceFuture<T, F> = OnceFuture::from_future(future);
        Self { once }
    }

    /// Waits for the value, helping to drive the shared future if needed.
    pub async fn get(&self) -> &T {
        self.await
    }
}

impl<T, F> Lazy<T, F> {
    /// Creates a `Lazy` that is already initialized with `value`.
    pub const fn with_value(value: T) -> Self {
        Lazy { once: OnceFuture::with_value(value) }
    }

    /// Returns the value if it is already available, without waiting.
    pub fn try_get(&self) -> Option<&T> {
        self.once.get()
    }

    /// Returns a mutable reference to the value, if it is available.
    pub fn try_get_mut(&mut self) -> Option<&mut T> {
        let (_input, value) = self.once.get_mut();
        value
    }

    /// Consumes the `Lazy` and returns its value, if it was computed.
    pub fn into_value(self) -> Option<T> {
        let (_input, value) = self.once.into_inner();
        value
    }
}

impl<'a, T, F> Future for &'a Lazy<T, F>
where
    F: Future<Output = T> + Send + 'static,
{
    type Output = &'a T;
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<&'a T> {
        // The future is installed by `Lazy::new`, so the creation callback is not expected to
        // run.
        let lazy: &'a Lazy<T, F> = *self;
        lazy.once.poll_populate(cx, |_| unreachable!())
    }
}
/// A lazily computed value whose future is supplied in a `const` constructor.
///
/// The future is held as input until first use, then moved into shared storage and polled.
#[derive(Debug)]
pub struct ConstLazy<T, F> {
    once: OnceFuture<T, F, F>,
}

impl<T, F> ConstLazy<T, F> {
    /// Creates a `ConstLazy` that will run `future` the first time it is awaited.
    pub const fn new(future: F) -> Self {
        Self { once: OnceFuture::with_init(future) }
    }

    /// Creates a `ConstLazy` that is already initialized with `value`.
    pub const fn with_value(value: T) -> Self {
        ConstLazy { once: OnceFuture::with_value(value) }
    }

    /// Returns the value if it is already available, without waiting.
    pub fn try_get(&self) -> Option<&T> {
        self.once.get()
    }

    /// Returns a mutable reference to the value, if it is available.
    pub fn try_get_mut(&mut self) -> Option<&mut T> {
        let (_future, value) = self.once.get_mut();
        value
    }

    /// Consumes the `ConstLazy` and returns its value, if it was computed.
    pub fn into_value(self) -> Option<T> {
        let (_future, value) = self.once.into_inner();
        value
    }
}

impl<T, F> ConstLazy<T, F>
where
    F: Future<Output = T> + Send + 'static,
{
    /// Waits for the value, helping to drive the shared future if needed.
    pub async fn get(&self) -> &T {
        self.await
    }
}

impl<'a, T, F> Future for &'a ConstLazy<T, F>
where
    F: Future<Output = T> + Send + 'static,
{
    type Output = &'a T;
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<&'a T> {
        let lazy: &'a ConstLazy<T, F> = *self;
        // The stored input *is* the future; it is present the first time the callback runs.
        lazy.once.poll_populate(cx, |input: Option<F>| match input {
            Some(future) => future,
            None => unreachable!(),
        })
    }
}