use alloc::boxed::Box;
use core::{
future::Future,
mem::{self, ManuallyDrop},
pin::Pin,
task::{ready, Context, Poll},
};
use async_lock::futures::Lock;
use crate::{
markers::ParallelSend,
runtime::{schedular::SchedularPoll, InnerRuntime},
AsyncContext, Ctx,
};
/// A future that runs an async closure against an [`AsyncContext`] while holding the
/// context's runtime lock.
///
/// Created with [`WithFuture::new`]. The closure is not called until this future is polled
/// and the runtime lock has been acquired.
pub struct WithFuture<'a, F, R> {
    /// Context whose runtime is locked, and from which the closure's [`Ctx`] is created.
    context: &'a AsyncContext,
    /// Progress of acquiring the runtime lock; keeps a pending lock future in place
    /// between polls.
    lock_state: LockState<'a>,
    /// Whether the closure has been called yet, and the boxed future it returned if so.
    state: WithFutureState<'a, F, R>,
}
/// Progress of acquiring the runtime lock across calls to `poll`.
enum LockState<'a> {
    /// No lock request is in flight.
    Initial,
    /// A lock request is in flight.
    ///
    /// The lock future is polled through a `Pin`, so it must not be moved while pending.
    /// It is wrapped in `ManuallyDrop` so the compiler never drops it implicitly; the
    /// `Drop` impl of `LockState` drops it explicitly.
    Pending(ManuallyDrop<Lock<'a, InnerRuntime>>),
}
impl Drop for LockState<'_> {
    /// Drops a still-pending lock future where it lies.
    ///
    /// Because the future sits inside `ManuallyDrop`, nothing else will run its destructor;
    /// doing it here (on `&mut self`, i.e. in place) also means it is never moved first.
    fn drop(&mut self) {
        match self {
            LockState::Initial => {}
            LockState::Pending(pending) => {
                // SAFETY: `Drop::drop` runs at most once for this value and `pending` is not
                // touched afterwards, so the wrapped future is dropped exactly once.
                unsafe { ManuallyDrop::drop(pending) }
            }
        }
    }
}
/// Lifecycle of the user closure driven by [`WithFuture`].
enum WithFutureState<'a, F, R> {
    /// The closure has not been called yet (no poll has acquired the runtime lock so far).
    Initial {
        closure: F,
    },
    /// The closure has been called; its boxed future is being driven to completion.
    FutureCreated {
        future: Pin<Box<dyn Future<Output = R> + 'a>>,
    },
    /// The future has completed. `poll` also leaves this in place while it has the previous
    /// state taken out, and writes the future back before every `Poll::Pending` return.
    Done,
}
impl<'a, F, R> WithFuture<'a, F, R>
where
    F: for<'js> AsyncFnOnce(Ctx<'js>) -> R + ParallelSend,
    R: ParallelSend,
{
    /// Wraps `f` so it will be run against `context` once this future is polled.
    ///
    /// Construction does no work: the runtime lock is not requested and `f` is not called
    /// until the first poll (and `f` only after the lock has actually been acquired).
    pub fn new(context: &'a AsyncContext, f: F) -> Self {
        let state = WithFutureState::Initial { closure: f };
        let lock_state = LockState::Initial;
        Self {
            context,
            lock_state,
            state,
        }
    }
}
impl<'a, F, R> Future for WithFuture<'a, F, R>
where
    F: for<'js> AsyncFnOnce(Ctx<'js>) -> R + ParallelSend + 'a,
    R: ParallelSend + 'static,
{
    type Output = R;

    /// Drives the user future together with the runtime's scheduler and job queue.
    ///
    /// Each poll:
    /// 1. acquires the runtime lock, or keeps waiting for it (returning `Pending` if needed);
    /// 2. calls the closure on the first poll that holds the lock, or resumes its future;
    /// 3. repeatedly polls the user future, polls the runtime's scheduler and runs pending
    ///    jobs, until the future completes, the scheduler asks to yield, or a whole round
    ///    makes no progress.
    ///
    /// The runtime lock is held for the rest of the poll once acquired and released before
    /// returning.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing below moves `lock_state` while it holds a pending lock future (that
        // future is only ever dropped in place), and `state` is only replaced as a whole; its
        // boxed future stays pinned on the heap.
        let this = unsafe { self.get_unchecked_mut() };

        let mut lock = loop {
            if let LockState::Pending(fut) = &mut this.lock_state {
                // SAFETY: the lock future lives inside the pinned `WithFuture` and is never
                // moved while pending; it is only dropped in place.
                let pin = unsafe { Pin::new_unchecked(&mut **fut) };
                let lock = ready!(pin.poll(cx));
                // Plain assignment drops the completed lock future in place via `LockState`'s
                // `Drop` impl. Moving it out first (e.g. with `mem::replace`) would move pinned
                // data and break the guarantee the `Pin::new_unchecked` above relies on.
                this.lock_state = LockState::Initial;
                break lock;
            } else {
                this.lock_state =
                    LockState::Pending(ManuallyDrop::new(this.context.0.rt().inner.lock()));
            }
        };

        // NOTE(review): presumably re-anchors the engine's stack-limit check to the current
        // thread's stack, since consecutive polls may run on different threads — confirm.
        lock.runtime.update_stack_top();

        // The state is taken out and `Done` left behind. It is written back before every
        // `Pending` return, so a panic inside the user future leaves this future in `Done`.
        let mut future = match mem::replace(&mut this.state, WithFutureState::Done) {
            WithFutureState::Initial { closure } => {
                // SAFETY: the runtime lock is held for the rest of this poll.
                // NOTE(review): confirm this is exactly what `Ctx::new_async` requires.
                let ctx = unsafe { Ctx::new_async(this.context) };
                Box::pin(closure(ctx))
            }
            WithFutureState::FutureCreated { future } => future,
            WithFutureState::Done => panic!("With future called after it returned"),
        };

        let res = loop {
            let mut made_progress = false;

            if let Poll::Ready(x) = future.as_mut().poll(cx) {
                break Poll::Ready(x);
            }

            let opaque = lock.runtime.get_opaque();
            match opaque.poll(cx) {
                // Nothing for the scheduler to drive; the root future is presumably waiting on
                // something external or on a promise.
                SchedularPoll::Empty => {}
                SchedularPoll::ShouldYield => {
                    this.state = WithFutureState::FutureCreated { future };
                    return Poll::Pending;
                }
                // Scheduled work exists but none advanced; running jobs below may unblock it.
                SchedularPoll::Pending => {}
                SchedularPoll::PendingProgress => made_progress = true,
            }

            // Drain the job queue; any executed job counts as progress.
            loop {
                match lock.runtime.execute_pending_job() {
                    Ok(false) => break,
                    Ok(true) => made_progress = true,
                    // NOTE(review): job errors are discarded here; the failed job still counts
                    // as progress.
                    Err(_ctx) => made_progress = true,
                }
            }

            // A full round without progress: park the future and wait to be woken.
            if !made_progress {
                this.state = WithFutureState::FutureCreated { future };
                break Poll::Pending;
            }
        };

        mem::drop(lock);
        res
    }
}
// SAFETY: `WithFuture` is not automatically `Send`, because the boxed user future
// (`Pin<Box<dyn Future<Output = R> + 'a>>` in `WithFutureState`) carries no `Send` bound.
// NOTE(review): soundness presumably relies on (a) the runtime only being touched while its
// lock is held, as `poll` does, and (b) `F`/`R` being `ParallelSend`, which `new` and the
// `Future` impl require but this impl itself does not — confirm the unbounded impl is intended.
#[cfg(feature = "parallel")]
unsafe impl<F, R> Send for WithFuture<'_, F, R> {}