use std::{
cell::RefCell,
collections::HashMap,
future::Future,
marker::PhantomData,
pin::Pin,
ptr,
rc::Rc,
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
use corosensei::{Coroutine, CoroutineResult, Yielder};
use super::entities::function::Function as SysFunction;
use crate::{
AsStoreMut, AsStoreRef, ForcedStoreInstallGuard, LocalRwLockWriteGuard, RuntimeError, Store,
StoreAsync, StoreContext, StoreInner, StoreMut, StoreRef, Value,
};
use wasmer_types::StoreId;
/// A pinned, boxed host future whose output is the result of an imported
/// function. It has no `Send` bound; the coroutine hands it to
/// `AsyncCallFuture::poll` (wrapped in `AsyncYield`), which polls it.
type HostFuture = Pin<Box<dyn Future<Output = Result<Vec<Value>, RuntimeError>> + 'static>>;
/// Creates the future that performs an asynchronous call of `function` with
/// `params` against `store`.
///
/// All setup happens in [`AsyncCallFuture::new`]; the call body does not run
/// until the returned future is polled.
pub(crate) fn call_function_async(
    function: SysFunction,
    store: StoreAsync,
    params: Vec<Value>,
) -> AsyncCallFuture {
    AsyncCallFuture::new(function, store, params)
}
/// Value the coroutine yields to `AsyncCallFuture::poll`: a host future that
/// must complete before the coroutine can be resumed.
struct AsyncYield(HostFuture);
/// Value passed into the coroutine each time it is resumed.
enum AsyncResume {
    /// First resume; starts executing the call body.
    Start,
    /// Output of the host future the coroutine yielded most recently.
    HostFutureReady(Result<Vec<Value>, RuntimeError>),
}
/// Future returned by `call_function_async`.
///
/// It runs the wasm call on a coroutine and alternates between resuming that
/// coroutine (with the store context installed) and awaiting any host future
/// the coroutine yields.
// The coroutine's generic parameters spell out the full resume/yield/return
// protocol; a type alias would only hide it, so the lint is silenced instead.
#[allow(clippy::type_complexity)]
pub(crate) struct AsyncCallFuture {
    /// Coroutine running the call; `None` once it has returned.
    coroutine: Option<Coroutine<AsyncResume, AsyncYield, Result<Box<[Value]>, RuntimeError>>>,
    /// In-progress `install_store_context` future, kept across polls while it
    /// is still pending.
    pending_store_install: Option<Pin<Box<dyn Future<Output = ForcedStoreInstallGuard>>>>,
    /// Host future yielded by the coroutine that must finish before resuming.
    pending_future: Option<HostFuture>,
    /// Argument for the next `Coroutine::resume` call.
    next_resume: Option<AsyncResume>,
    /// Final result, held between the coroutine returning and `poll` handing
    /// it out.
    result: Option<Result<Box<[Value]>, RuntimeError>>,
    /// Store handle used to (re)install the store context before each resume.
    store: StoreAsync,
}
/// Store handle passed to the function call running inside the coroutine.
///
/// It carries only the store id; every access re-resolves the store through
/// `StoreContext::get_current_transient`.
struct AsyncCallStoreMut {
    store_id: StoreId,
}
impl AsStoreRef for AsyncCallStoreMut {
    /// Borrows the store currently resolved for `self.store_id`.
    ///
    /// # Panics
    ///
    /// Panics if `StoreContext::get_current_transient` returns a null pointer.
    fn as_store_ref(&self) -> StoreRef<'_> {
        // SAFETY: NOTE(review) assumes `get_current_transient` returns either
        // null or a pointer to the store installed for this id that stays valid
        // for the returned borrow — confirm against `StoreContext`.
        let installed = unsafe { StoreContext::get_current_transient(self.store_id).as_ref() };
        StoreRef {
            inner: installed.unwrap(),
        }
    }
}
impl AsStoreMut for AsyncCallStoreMut {
    /// Mutably borrows the store currently resolved for `self.store_id`.
    ///
    /// # Panics
    ///
    /// Panics if `StoreContext::get_current_transient` returns a null pointer.
    fn as_store_mut(&mut self) -> StoreMut<'_> {
        // SAFETY: NOTE(review) assumes `get_current_transient` returns either
        // null or a pointer to the installed store that nothing else borrows
        // while this `&mut self` borrow is live — confirm against `StoreContext`.
        let installed = unsafe { StoreContext::get_current_transient(self.store_id).as_mut() };
        StoreMut {
            inner: installed.unwrap(),
        }
    }

    /// Mutably borrows the `objects` of the currently resolved store.
    ///
    /// # Panics
    ///
    /// Panics if `StoreContext::get_current_transient` returns a null pointer.
    fn objects_mut(&mut self) -> &mut crate::StoreObjects {
        // SAFETY: same assumption as in `as_store_mut`.
        let installed = unsafe { StoreContext::get_current_transient(self.store_id).as_mut() };
        &mut installed.unwrap().objects
    }
}
impl AsyncCallFuture {
pub(crate) fn new(function: SysFunction, store: StoreAsync, params: Vec<Value>) -> Self {
let store_id = store.id;
let coroutine =
Coroutine::new(move |yielder: &Yielder<AsyncResume, AsyncYield>, resume| {
assert!(matches!(resume, AsyncResume::Start));
let ctx_state = CoroutineContext::new(yielder);
ctx_state.enter();
let result = {
let mut store_mut = AsyncCallStoreMut { store_id };
function.call(&mut store_mut, ¶ms)
};
ctx_state.leave();
result
});
Self {
coroutine: Some(coroutine),
pending_store_install: None,
pending_future: None,
next_resume: Some(AsyncResume::Start),
result: None,
store,
}
}
}
impl Future for AsyncCallFuture {
    type Output = Result<Box<[Value]>, RuntimeError>;

    /// Drives the call forward until it completes or must wait.
    ///
    /// Each loop iteration does the following:
    /// 1. Polls a pending host future, if any. Once it is ready, its output
    ///    becomes the next resume argument.
    /// 2. Returns the stored result if the coroutine has already returned.
    /// 3. Installs the store context via `install_store_context`.
    /// 4. Resumes the coroutine, then drops the store context guard.
    ///
    /// # Panics
    ///
    /// Panics with "polled after completion" if polled again after returning
    /// `Poll::Ready`. Also propagates the panic raised by
    /// `install_store_context` when this store is already installed on the
    /// current thread.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        loop {
            // Step 1: finish awaiting the host future the coroutine yielded;
            // its output is what the coroutine is resumed with next.
            if let Some(future) = self.pending_future.as_mut() {
                match future.as_mut().poll(cx) {
                    Poll::Ready(result) => {
                        self.pending_future = None;
                        self.next_resume = Some(AsyncResume::HostFutureReady(result));
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }
            // Step 2: the coroutine has returned; hand out the stored result.
            if self.coroutine.is_none() {
                return Poll::Ready(self.result.take().expect("polled after completion"));
            }
            // Step 3: install the store context. The install future is kept
            // across polls, so a pending `store.inner.write().await` continues
            // where it left off instead of being recreated.
            if self.pending_store_install.is_none() {
                self.pending_store_install = Some(Box::pin(install_store_context(StoreAsync {
                    id: self.store.id,
                    inner: self.store.inner.clone(),
                })));
            }
            let store_context_guard = match self
                .pending_store_install
                .as_mut()
                .unwrap()
                .as_mut()
                .poll(cx)
            {
                Poll::Ready(guard) => {
                    self.pending_store_install = None;
                    guard
                }
                Poll::Pending => return Poll::Pending,
            };
            // Step 4: run the coroutine until it yields a host future or returns.
            let resume_arg = self.next_resume.take().expect("no resume arg available");
            let coroutine = self.coroutine.as_mut().unwrap();
            match coroutine.resume(resume_arg) {
                CoroutineResult::Yield(AsyncYield(fut)) => {
                    self.pending_future = Some(fut);
                }
                CoroutineResult::Return(result) => {
                    self.coroutine = None;
                    self.result = Some(result);
                }
            }
            // Drop the store context guard before looping back to Step 1, so it
            // is not held while a yielded host future is pending (presumably
            // this uninstalls the context and releases the store lock).
            drop(store_context_guard);
        }
    }
}
/// Acquires the store's write lock and installs it as the current store
/// context, returning the guard produced by `StoreContext::install_async`.
///
/// # Panics
///
/// Panics if `StoreContext::try_get_current_async` reports anything other than
/// `NotInstalled` for this store. This is the case the message describes: a
/// `call_async` future polled from within another imported function.
async fn install_store_context(store: StoreAsync) -> ForcedStoreInstallGuard {
    if let crate::GetStoreAsyncGuardResult::NotInstalled =
        unsafe { crate::StoreContext::try_get_current_async(store.id) }
    {
        let write_guard = store.inner.write().await;
        unsafe { crate::StoreContext::install_async(write_guard) }
    } else {
        panic!(
            "Function::call_async futures cannot be polled recursively \
             from within another imported function. If you need to await \
             a recursive call_async, consider spawning the future into \
             your async runtime and awaiting the resulting task; \
             e.g. tokio::task::spawn(func.call_async(...)).await"
        );
    }
}
/// Error returned by `block_on_host_future`.
pub enum AsyncRuntimeError {
    /// No coroutine context was active on this thread and the host future was
    /// not ready on its single poll.
    YieldOutsideAsyncContext,
    /// The host future completed with this error.
    RuntimeError(RuntimeError),
}
/// Runs a host future to completion from synchronous host-function code.
///
/// If a coroutine context is active on this thread, the coroutine is suspended
/// and the future is handed to the enclosing `AsyncCallFuture::poll`. That
/// function awaits the future and resumes the coroutine with its output.
/// Otherwise the future is polled exactly once with a no-op waker.
///
/// # Errors
///
/// * [`AsyncRuntimeError::YieldOutsideAsyncContext`] if no coroutine context
///   is active and the future is not ready on that single poll.
/// * [`AsyncRuntimeError::RuntimeError`] if the future resolves to an error.
pub(crate) fn block_on_host_future<Fut>(future: Fut) -> Result<Vec<Value>, AsyncRuntimeError>
where
    Fut: Future<Output = Result<Vec<Value>, RuntimeError>> + 'static,
{
    // Look up the active context directly rather than inside a
    // `CURRENT_CONTEXT.with` closure. `block_on_future` suspends the
    // coroutine, and no thread-local access should be in progress on this
    // stack while it is suspended.
    match CoroutineContext::get_current() {
        None => run_immediate(future),
        Some(context) => {
            // SAFETY: NOTE(review) entries on `CURRENT_CONTEXT` are pushed by
            // `CoroutineContext::enter` and popped by `leave` while the
            // referenced context (a local of the running coroutine body) is
            // still alive; this assumes every exit path pops its entry.
            let context = unsafe { context.as_ref() }.expect("valid context pointer");
            context
                .block_on_future(Box::pin(future))
                .map_err(AsyncRuntimeError::RuntimeError)
        }
    }
}
thread_local! {
    /// Per-thread stack of active coroutine contexts; the last entry is the
    /// innermost one. Raw pointers are stored because each `CoroutineContext`
    /// is a local inside its coroutine body, not owned by this stack.
    static CURRENT_CONTEXT: RefCell<Vec<*const CoroutineContext>> = const { RefCell::new(Vec::new()) };
}
/// State that synchronous host code needs in order to suspend the coroutine it
/// is running on.
struct CoroutineContext {
    /// Raw pointer to the coroutine's `Yielder`, taken from the `&Yielder`
    /// argument of the coroutine body; presumably valid only while that body
    /// is executing.
    yielder: *const Yielder<AsyncResume, AsyncYield>,
}
impl CoroutineContext {
    /// Records the address of `yielder` so that code running deeper on the
    /// coroutine stack can later suspend the coroutine.
    fn new(yielder: &Yielder<AsyncResume, AsyncYield>) -> Self {
        let yielder: *const Yielder<AsyncResume, AsyncYield> = yielder;
        Self { yielder }
    }

    /// Pushes this context onto the current thread's active-context stack.
    fn enter(&self) {
        let this: *const Self = self;
        CURRENT_CONTEXT.with(|stack| stack.borrow_mut().push(this));
    }

    /// Pops this context off the current thread's active-context stack.
    ///
    /// # Panics
    ///
    /// Panics with "Active coroutine stack corrupted" if the popped entry is
    /// not `self` or the stack is empty.
    fn leave(&self) {
        let this: *const Self = self;
        CURRENT_CONTEXT.with(|stack| {
            let mut entries = stack.borrow_mut();
            let top = entries.pop();
            assert_eq!(top, Some(this), "Active coroutine stack corrupted");
        });
    }

    /// Returns the innermost active context on this thread, if any.
    fn get_current() -> Option<*const Self> {
        CURRENT_CONTEXT.with(|stack| stack.borrow().last().cloned())
    }

    /// Suspends the coroutine, handing `future` to `AsyncCallFuture::poll`,
    /// and returns the future's output once the coroutine is resumed.
    ///
    /// The context is removed from the thread-local stack while suspended, so
    /// `get_current` does not return it during that time. It is pushed back
    /// after the coroutine resumes.
    fn block_on_future(&self, future: HostFuture) -> Result<Vec<Value>, RuntimeError> {
        self.leave();
        // SAFETY: NOTE(review) `yielder` comes from the `&Yielder` passed to
        // this coroutine's body, which presumably outlives this call.
        let yielder = unsafe { self.yielder.as_ref() }.expect("yielder pointer valid");
        let outcome = match yielder.suspend(AsyncYield(future)) {
            AsyncResume::Start => unreachable!("coroutine resumed without start"),
            AsyncResume::HostFutureReady(outcome) => outcome,
        };
        self.enter();
        outcome
    }
}
/// Polls `future` exactly once with a no-op waker. This is the path used when
/// no coroutine context is active on this thread.
///
/// # Errors
///
/// * `AsyncRuntimeError::RuntimeError` if the future completes with an error.
/// * `AsyncRuntimeError::YieldOutsideAsyncContext` if it is not ready on that
///   single poll; the future is then dropped without being polled again.
fn run_immediate(
    future: impl Future<Output = Result<Vec<Value>, RuntimeError>> + 'static,
) -> Result<Vec<Value>, AsyncRuntimeError> {
    let noop = futures::task::noop_waker();
    let mut task_cx = Context::from_waker(&noop);
    let mut pinned = Box::pin(future);
    if let Poll::Ready(outcome) = pinned.as_mut().poll(&mut task_cx) {
        outcome.map_err(AsyncRuntimeError::RuntimeError)
    } else {
        Err(AsyncRuntimeError::YieldOutsideAsyncContext)
    }
}