use std::{
cell::UnsafeCell,
convert::Infallible,
future::Future,
mem,
panic::{RefUnwindSafe, UnwindSafe},
pin::Pin,
ptr,
sync::atomic::{AtomicPtr, AtomicUsize, Ordering},
sync::{Arc, Mutex},
task,
};
use super::{NEW, READY_BIT};
/// A cell that is populated once by driving a future to completion; any number of tasks may
/// await it concurrently, and only one of them polls the future at a time.
///
/// * `T` — the stored value (the future's output).
/// * `F` — the future type producing `T`. It is kept inside the shared waker queue
///   (`LazyWaker::future`), not inline in this struct, so moving the cell never moves the future.
/// * `I` — an optional seed handed to the future constructor on first use (`LazyState::New`);
///   `Infallible` when no seed is used.
#[derive(Debug)]
pub struct OnceFuture<T, F = Pin<Box<dyn Future<Output = T> + Send>>, I = Infallible> {
/// Seed, "running/no seed" marker, or the final value. Written only by the task holding the
/// queue's head lock (or through `&mut self`); read through `&self` only after `READY_BIT` is set.
value: UnsafeCell<LazyState<T, I>>,
/// Ready flag / in-flight counter plus the lazily allocated waker queue that owns the future.
inner: LazyInner<F>,
}
// SAFETY(review): shared access hands out `&T` to many threads (needs `T: Sync`), and the value is
// produced by whichever task happens to run the future and later owned/dropped by the cell's owner
// (needs `T: Send`). `F` and `I` are only moved or accessed by the single head task at a time, so
// `Send` is sufficient for them.
unsafe impl<T: Sync + Send, F: Send, I: Send> Sync for OnceFuture<T, F, I> {}
unsafe impl<T: Send, F: Send, I: Send> Send for OnceFuture<T, F, I> {}
// The future lives behind an `Arc` in `LazyWaker`, never inline, so the cell itself can be `Unpin`.
impl<T, F, I> Unpin for OnceFuture<T, F, I> {}
impl<T: RefUnwindSafe + UnwindSafe, F, I: RefUnwindSafe> RefUnwindSafe for OnceFuture<T, F, I> {}
impl<T: UnwindSafe, F, I: UnwindSafe> UnwindSafe for OnceFuture<T, F, I> {}
/// Contents of a [`OnceFuture`]'s value slot.
enum LazyState<T, I> {
/// Not started yet; holds the seed that will be passed to the future constructor.
New(I),
/// No seed available: either initialization has begun (`init_slow` swapped the seed out) or the
/// cell was created without a seed (`with_no_init` / `from_future`).
Running,
/// Initialization finished; holds the final value.
Ready(T),
}
/// Shared synchronization state of a [`OnceFuture`].
#[derive(Debug)]
struct LazyInner<F> {
/// `READY_BIT` is set once the value has been stored; the remaining bits count callers currently
/// inside `initialize` (each does `fetch_add(1)` on entry and `fetch_sub(1)` on exit).
state: AtomicUsize,
/// Null, or a pointer obtained from `Arc::into_raw` for the waker queue. A non-null value owns one
/// strong reference, which is released once the cell is ready and no `initialize` call is in
/// flight, or when the cell is dropped.
queue: AtomicPtr<LazyWaker<F>>,
}
/// Waker queue shared (through an `Arc`) by every task waiting on a [`OnceFuture`]. It is also
/// used as the `Waker` handed to the inner future (see its `task::Wake` impl).
struct LazyWaker<F> {
/// The future being driven: `None` until it is created, and `None` again after it completes.
/// Only the task holding the head lock (a `LazyHead`) touches it.
future: UnsafeCell<Option<F>>,
/// Head-lock state plus the wakers of tasks waiting for the value.
wakers: Mutex<(WakerState, Vec<task::Waker>)>,
}
/// Lock/wake state of a [`LazyWaker`]; always read and written under its mutex.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum WakerState {
/// Nobody is polling the future; the next `poll_head` caller becomes the head.
Unlocked,
/// A head is polling the future and no wake-up has arrived since it took the lock.
LockedWithoutWake,
/// The future last returned `Pending` and nobody is polling it. The next wake moves the state to
/// `Unlocked` and wakes every registered waiter.
Pending,
/// A head is polling the future and the future's waker fired during that poll. If the poll
/// returns `Pending`, the head unlocks and wakes the waiters instead of parking.
LockedWoken,
}
// SAFETY(review): the `UnsafeCell<Option<F>>` is only accessed by the single task holding the head
// lock, which is acquired by the `Unlocked` -> `LockedWithoutWake` transition in `poll_head` under
// the mutex. Sharing or sending a `LazyWaker` therefore only requires that `F` may move between
// threads.
unsafe impl<F: Send> Send for LazyWaker<F> {}
unsafe impl<F: Send> Sync for LazyWaker<F> {}
/// Proof that the current task holds the head lock of a [`LazyWaker`] and may poll its future.
/// Dropping it releases the lock and wakes every registered waiter. `poll_inner` instead forgets it
/// on `Pending`, after updating the lock state itself.
struct LazyHead<'a, F> {
waker: &'a Arc<LazyWaker<F>>,
}
impl<F> LazyInner<F> {
    /// Returns a strong reference to the waker queue, allocating it on first use, or `None` if the
    /// cell is already ready.
    ///
    /// The low bits of `state` count callers currently inside this function. While that count is
    /// non-zero nobody releases `queue`, so the raw pointer loaded below stays valid until this
    /// call has taken its own strong reference. The last caller to leave after `READY_BIT` was set
    /// releases the queue; `set_ready` handles the case where nobody was inside.
    fn initialize(&self) -> Option<Arc<LazyWaker<F>>> {
        // Register as "in flight" so the queue cannot be freed underneath us.
        let prev_state = self.state.fetch_add(1, Ordering::Acquire);
        let mut queue = self.queue.load(Ordering::Acquire);
        if queue.is_null() && prev_state & READY_BIT == 0 {
            // First use: allocate a queue. Another caller may race us, so publish it with a CAS
            // and discard our allocation if we lose.
            let waker: LazyWaker<F> = LazyWaker {
                future: UnsafeCell::new(None),
                wakers: Mutex::new((WakerState::Unlocked, Vec::new())),
            };
            let new_queue = Arc::into_raw(Arc::new(waker)) as *mut _;
            match self.queue.compare_exchange(
                ptr::null_mut(),
                new_queue,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_null) => {
                    queue = new_queue;
                }
                Err(actual) => {
                    queue = actual;
                    // SAFETY: `new_queue` came from `Arc::into_raw` above and was never published,
                    // so we still own its only strong reference; reclaim it and free it explicitly.
                    unsafe {
                        drop(Arc::from_raw(new_queue as *const LazyWaker<F>));
                    }
                }
            }
        }
        let rv = if queue.is_null() {
            None
        } else {
            // SAFETY: `queue` is non-null and came from `Arc::into_raw`; our in-flight count keeps
            // the field's strong reference alive, so bumping the count and materializing a second
            // `Arc` from the same pointer is sound.
            unsafe {
                Arc::increment_strong_count(queue as *const _);
                Some(Arc::from_raw(queue as *const _))
            }
        };
        let prev_state = self.state.fetch_sub(1, Ordering::AcqRel);
        if prev_state & READY_BIT == 0 {
            // Not ready at exit means not ready at entry either, so a queue exists.
            debug_assert!(rv.is_some());
            rv
        } else {
            // Ready: the queue is no longer needed. If we were the last in-flight caller,
            // `set_ready` left releasing the field's reference to us.
            if prev_state == READY_BIT + 1 {
                let queue = self.queue.swap(ptr::null_mut(), Ordering::Acquire);
                if !queue.is_null() {
                    // SAFETY: the swap took exclusive ownership of the strong reference that the
                    // `queue` field held.
                    unsafe {
                        Arc::decrement_strong_count(queue as *const _);
                    }
                }
            }
            None
        }
    }

    /// Publishes `READY_BIT`; call this only after the value has been written.
    ///
    /// If no `initialize` call is in flight, the waker queue is released immediately; otherwise
    /// the last such call releases it on exit.
    fn set_ready(&self) {
        let prev_state = self.state.fetch_or(READY_BIT, Ordering::Release);
        debug_assert_eq!(prev_state & READY_BIT, 0, "Invalid state: someone else set READY_BIT");
        if prev_state == NEW {
            let queue = self.queue.swap(ptr::null_mut(), Ordering::Acquire);
            if !queue.is_null() {
                // SAFETY: nobody is inside `initialize`, and the swap took exclusive ownership of
                // the strong reference that the `queue` field held.
                unsafe {
                    Arc::decrement_strong_count(queue as *const _);
                }
            }
        }
    }
}
impl<F> Drop for LazyInner<F> {
fn drop(&mut self) {
let queue = *self.queue.get_mut();
if !queue.is_null() {
unsafe {
Arc::decrement_strong_count(queue);
}
}
}
}
impl<F> LazyWaker<F> {
    /// Registers the current task as a waiter and tries to acquire the head lock.
    ///
    /// Returns:
    /// * `Ready(None)` if the cell is already ready;
    /// * `Ready(Some(head))` if this task is now the only one allowed to poll the shared future;
    /// * `Pending` otherwise. The task's waker is queued (at most once) and is woken when the
    ///   lock is released or the shared future signals progress.
    fn poll_head<'a>(
        self: &'a Arc<Self>,
        cx: &mut task::Context<'_>,
        inner: &LazyInner<F>,
    ) -> task::Poll<Option<LazyHead<'a, F>>> {
        let mut guard = self.wakers.lock().unwrap();
        // Checked while holding the mutex. The head sets READY before its `LazyHead` drop takes
        // this mutex, so we either see READY here or are registered before that drop wakes
        // everyone.
        if inner.state.load(Ordering::Acquire) & READY_BIT != 0 {
            return task::Poll::Ready(None);
        }
        let current = cx.waker();
        if guard.1.iter().any(|queued| queued.will_wake(current)) {
            // This task is already registered; keep waiting.
            return task::Poll::Pending;
        }
        guard.1.push(current.clone());
        if guard.0 != WakerState::Unlocked {
            return task::Poll::Pending;
        }
        guard.0 = WakerState::LockedWithoutWake;
        task::Poll::Ready(Some(LazyHead { waker: self }))
    }
}
impl<F> task::Wake for LazyWaker<F> {
    fn wake(self: Arc<Self>) {
        self.wake_by_ref()
    }

    /// Called by the inner future when it can make progress.
    ///
    /// If a head is currently polling, the wake-up is only recorded (`LockedWoken`) and the head
    /// reacts to it after its poll. Otherwise the lock is opened and every waiting task is woken so
    /// one of them can become the next head.
    fn wake_by_ref(self: &Arc<Self>) {
        let pending = {
            let mut guard = self.wakers.lock().unwrap();
            match guard.0 {
                WakerState::LockedWithoutWake => {
                    guard.0 = WakerState::LockedWoken;
                    return;
                }
                WakerState::LockedWoken => return,
                WakerState::Pending => guard.0 = WakerState::Unlocked,
                WakerState::Unlocked => {}
            }
            mem::take(&mut guard.1)
        };
        // Wake outside the lock so waker implementations cannot re-enter it.
        pending.into_iter().for_each(task::Waker::wake);
    }
}
impl<'a, F> LazyHead<'a, F> {
/// Polls the shared future once on behalf of every waiter, creating it with `init` if the slot is
/// empty.
///
/// On `Ready`, the future is dropped and the head guard is returned together with the output, so
/// the caller can store the value and mark the cell ready before the guard's `Drop` releases the
/// lock and wakes the waiters. On `Pending`, the lock state is moved to `Pending` (or, if the
/// future woke itself during the poll, straight back to `Unlocked` with the waiters woken) and the
/// guard is forgotten so its `Drop` does not run.
fn poll_inner(self, init: impl FnOnce() -> F) -> task::Poll<(Self, F::Output)>
where
F: Future + Send + 'static,
{
let ptr = self.waker.future.get();
// SAFETY(review): holding `self` means this task owns the head lock, so nobody else touches the
// future slot. The future lives inside the `Arc<LazyWaker>` allocation and is never moved; it is
// only dropped in place below, which satisfies the `Pin` contract.
let fut = unsafe { Pin::new_unchecked((*ptr).get_or_insert_with(init)) };
// The shared queue itself is the future's waker, so one wake-up reaches every waiting task.
let shared_waker = task::Waker::from(Arc::clone(self.waker));
let mut ctx = task::Context::from_waker(&shared_waker);
match fut.poll(&mut ctx) {
task::Poll::Pending => {
let mut lock = self.waker.wakers.lock().unwrap();
match lock.0 {
// No wake-up yet: park until the future's waker fires (see `wake_by_ref`).
WakerState::LockedWithoutWake => {
lock.0 = WakerState::Pending;
drop(lock);
}
// The future already woke itself during this poll: unlock and wake every waiter so
// one of them polls again.
WakerState::LockedWoken => {
lock.0 = WakerState::Unlocked;
let wakers = mem::replace(&mut lock.1, Vec::new());
drop(lock);
for waker in wakers {
waker.wake();
}
}
// Only the head leaves the Locked* states, and we are the head.
WakerState::Pending | WakerState::Unlocked => {
unreachable!();
}
}
// The lock state was already updated above; skip `LazyHead::drop`, which would reset it
// and treats `Pending` as unreachable.
mem::forget(self);
task::Poll::Pending
}
task::Poll::Ready(value) => {
// SAFETY(review): still the head; drop the completed future in place so it is never polled
// again.
unsafe {
*ptr = None;
}
task::Poll::Ready((self, value))
}
}
}
}
impl<'a, F> Drop for LazyHead<'a, F> {
    /// Releases the head lock and wakes every registered waiter, e.g. after the value has been
    /// stored or when the head's poll unwinds.
    fn drop(&mut self) {
        let to_wake = {
            let mut guard = self.waker.wakers.lock().unwrap();
            match guard.0 {
                // A live `LazyHead` implies one of the Locked* states.
                WakerState::LockedWoken | WakerState::LockedWithoutWake => {}
                WakerState::Unlocked | WakerState::Pending => unreachable!(),
            }
            guard.0 = WakerState::Unlocked;
            mem::take(&mut guard.1)
        };
        for queued in to_wake {
            queued.wake();
        }
    }
}
impl<T, F, I> OnceFuture<T, F, I> {
pub const fn with_init(init: I) -> Self {
OnceFuture {
value: UnsafeCell::new(LazyState::New(init)),
inner: LazyInner {
state: AtomicUsize::new(NEW),
queue: AtomicPtr::new(ptr::null_mut()),
},
}
}
pub const fn with_no_init() -> Self {
OnceFuture {
value: UnsafeCell::new(LazyState::Running),
inner: LazyInner {
state: AtomicUsize::new(NEW),
queue: AtomicPtr::new(ptr::null_mut()),
},
}
}
pub const fn with_value(value: T) -> Self {
OnceFuture {
value: UnsafeCell::new(LazyState::Ready(value)),
inner: LazyInner {
state: AtomicUsize::new(READY_BIT),
queue: AtomicPtr::new(ptr::null_mut()),
},
}
}
pub fn get(&self) -> Option<&T> {
let state = self.inner.state.load(Ordering::Acquire);
if state & READY_BIT == 0 {
None
} else {
unsafe {
match &*self.value.get() {
LazyState::Ready(v) => Some(v),
_ => unreachable!(),
}
}
}
}
pub fn get_mut(&mut self) -> (Option<&mut I>, Option<&mut T>) {
match self.value.get_mut() {
LazyState::New(i) => (Some(i), None),
LazyState::Running => (None, None),
LazyState::Ready(v) => (None, Some(v)),
}
}
pub fn into_inner(self) -> (Option<I>, Option<T>) {
match self.value.into_inner() {
LazyState::New(i) => (Some(i), None),
LazyState::Running => (None, None),
LazyState::Ready(v) => (None, Some(v)),
}
}
}
impl<T, F> OnceFuture<T, F> {
pub const fn new() -> Self {
Self::with_no_init()
}
}
impl<F> OnceFuture<F::Output, F>
where
    F: Future + Send + 'static,
{
    /// Creates a cell whose value will be produced by `future`. The future is stored in a freshly
    /// allocated waker queue right away, so later populating calls never construct one.
    pub fn from_future(future: F) -> Self {
        let cell = Self::new();
        let queue = cell.inner.initialize().unwrap();
        // SAFETY: the cell was just created and is not shared yet, so nothing else can be
        // accessing the queue's future slot.
        let slot = unsafe { &mut *queue.future.get() };
        *slot = Some(future);
        cell
    }
}
impl<T, F, I> OnceFuture<T, F, I>
where
    F: Future<Output = T> + Send + 'static,
{
    /// Returns the value, creating (if needed) and awaiting the future built by `gen_future`.
    /// Any seed stored in the cell is ignored.
    pub async fn get_or_init_with(&self, gen_future: impl FnOnce() -> F) -> &T {
        let ignore_seed = move |_seed: Option<I>| gen_future();
        self.get_or_populate_with(ignore_seed).await
    }

    /// Returns the value, creating (if needed) and awaiting the future built by `into_future`,
    /// which receives the cell's seed (`Some` only if it was still unused).
    pub async fn get_or_populate_with(&self, into_future: impl FnOnce(Option<I>) -> F) -> &T {
        /// Hand-written future that polls the cell, handing over the constructor on first use.
        struct Get<'a, T, F, I, P>(&'a OnceFuture<T, F, I>, Option<P>);

        // `Get` holds nothing that needs pinning.
        impl<'a, T, F, I, P> Unpin for Get<'a, T, F, I, P> {}

        impl<'a, T, F, I, P> Future for Get<'a, T, F, I, P>
        where
            F: Future<Output = T> + Send + 'static,
            P: FnOnce(Option<I>) -> F,
        {
            type Output = &'a T;

            fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<&'a T> {
                let cell: &'a OnceFuture<T, F, I> = self.0;
                let maker = &mut self.1;
                cell.poll_populate(cx, move |seed| {
                    let make = maker.take().unwrap();
                    make(seed)
                })
            }
        }

        Get(self, Some(into_future)).await
    }

    /// Polls for the value, taking part in initialization if it has not completed yet.
    /// `into_future` is called at most once, and only by the task that first drives the future.
    pub fn poll_populate(
        &self,
        cx: &mut task::Context<'_>,
        into_future: impl FnOnce(Option<I>) -> F,
    ) -> task::Poll<&T> {
        let ready = self.inner.state.load(Ordering::Acquire) & READY_BIT != 0;
        if !ready && self.init_slow(cx, into_future).is_pending() {
            return task::Poll::Pending;
        }
        // SAFETY: either READY_BIT was already observed, or `init_slow` returned `Ready`, which
        // happens only once the value has been stored and published.
        match unsafe { &*self.value.get() } {
            LazyState::Ready(v) => task::Poll::Ready(v),
            _ => unreachable!(),
        }
    }

    /// Slow path of `poll_populate`: join the waker queue and, if this task wins the head lock,
    /// poll the shared future once.
    #[cold]
    fn init_slow(
        &self,
        cx: &mut task::Context<'_>,
        into_future: impl FnOnce(Option<I>) -> F,
    ) -> task::Poll<()> {
        let queue = match self.inner.initialize() {
            Some(queue) => queue,
            None => return task::Poll::Ready(()),
        };
        let head = match queue.poll_head(cx, &self.inner) {
            task::Poll::Pending => return task::Poll::Pending,
            task::Poll::Ready(None) => return task::Poll::Ready(()),
            task::Poll::Ready(Some(head)) => head,
        };
        // SAFETY: holding the head lock gives this task exclusive access to the value slot until
        // READY_BIT is published.
        let previous = mem::replace(unsafe { &mut *self.value.get() }, LazyState::Running);
        let seed = match previous {
            LazyState::New(init) => Some(init),
            LazyState::Running => None,
            LazyState::Ready(_) => unreachable!(),
        };
        match head.poll_inner(move || into_future(seed)) {
            task::Poll::Pending => task::Poll::Pending,
            task::Poll::Ready((head, output)) => {
                // SAFETY: still the head; nobody reads the slot until `set_ready` publishes it.
                unsafe {
                    *self.value.get() = LazyState::Ready(output);
                }
                self.inner.set_ready();
                // Release the lock (and wake waiters) only after the value is visible.
                drop(head);
                task::Poll::Ready(())
            }
        }
    }
}
/// A value produced either by a future supplied to `new` or directly by `with_value`, shared by
/// every task that awaits it.
#[derive(Debug)]
pub struct Lazy<T, F = Pin<Box<dyn Future<Output = T> + Send>>> {
once: OnceFuture<T, F>,
}
impl<T, F> Lazy<T, F>
where
    F: Future<Output = T> + Send + 'static,
{
    /// Creates a lazy value that will be produced by polling `future` on first access.
    pub fn new(future: F) -> Self {
        let once = OnceFuture::from_future(future);
        Lazy { once }
    }

    /// Awaits the value; the awaiting task drives `future` whenever no other task is polling it.
    pub async fn get(&self) -> &T {
        let this: &Self = self;
        this.await
    }
}
impl<T, F> Lazy<T, F> {
pub const fn with_value(value: T) -> Self {
Self { once: OnceFuture::with_value(value) }
}
pub fn try_get(&self) -> Option<&T> {
self.once.get()
}
pub fn try_get_mut(&mut self) -> Option<&mut T> {
self.once.get_mut().1
}
pub fn into_value(self) -> Option<T> {
self.once.into_inner().1
}
}
impl<'a, T, F> Future for &'a Lazy<T, F>
where
F: Future<Output = T> + Send + 'static,
{
type Output = &'a T;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<&'a T> {
self.once.poll_populate(cx, |_| unreachable!())
}
}
/// A lazily computed value whose future is stored as the cell's seed, so it can be built in a
/// `const` context (`ConstLazy::new` is a `const fn`). The future is moved into the waker queue on
/// first poll.
#[derive(Debug)]
pub struct ConstLazy<T, F> {
once: OnceFuture<T, F, F>,
}
impl<T, F> ConstLazy<T, F> {
    /// Creates a lazy value that will be produced by `future`; usable in `const`/`static` items.
    pub const fn new(future: F) -> Self {
        ConstLazy { once: OnceFuture::with_init(future) }
    }

    /// Creates an already-initialized lazy value.
    pub const fn with_value(value: T) -> Self {
        ConstLazy { once: OnceFuture::with_value(value) }
    }

    /// Returns the value if it has already been computed, without waiting.
    pub fn try_get(&self) -> Option<&T> {
        let once = &self.once;
        once.get()
    }

    /// Exclusive access to the value if it has already been computed.
    pub fn try_get_mut(&mut self) -> Option<&mut T> {
        let (_seed, value) = self.once.get_mut();
        value
    }

    /// Consumes the lazy value, returning it if it has already been computed.
    pub fn into_value(self) -> Option<T> {
        let (_seed, value) = self.once.into_inner();
        value
    }
}
impl<T, F> ConstLazy<T, F>
where
    F: Future<Output = T> + Send + 'static,
{
    /// Awaits the value; the awaiting task drives the future whenever no other task is polling it.
    pub async fn get(&self) -> &T {
        let this: &Self = self;
        this.await
    }
}
impl<'a, T, F> Future for &'a ConstLazy<T, F>
where
F: Future<Output = T> + Send + 'static,
{
type Output = &'a T;
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<&'a T> {
self.once.poll_populate(cx, |i| i.unwrap_or_else(|| unreachable!()))
}
}