use std::{
cell::RefCell,
collections::HashMap,
future::Future,
marker::PhantomData,
pin::Pin,
ptr,
rc::Rc,
sync::{LazyLock, atomic::AtomicU64},
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};
use corosensei::{Coroutine, CoroutineResult, Yielder};
use dashmap::DashMap;
use super::entities::function::Function as SysFunction;
use crate::{
AsStoreAsync, AsStoreMut, AsStoreRef, ForcedStoreInstallGuard, LocalRwLockWriteGuard,
RuntimeError, Store, StoreAsync, StoreContext, StoreInner, StoreMut, StoreRef, Value,
};
use wasmer_types::StoreId;
/// A boxed future that a host function wants to await from inside a `call_async`
/// coroutine (see `block_on_host_future`). It is not required to be `Send` and resolves
/// to the host function's return values or a `RuntimeError`.
type HostFuture = Pin<Box<dyn Future<Output = Result<Vec<Value>, RuntimeError>> + 'static>>;
/// Starts an asynchronous call of `function` with `params` against `store`.
///
/// The returned future does no work until it is polled; see [`AsyncCallFuture`].
pub(crate) fn call_function_async(
    function: SysFunction,
    store: StoreAsync,
    params: Vec<Value>,
) -> AsyncCallFuture {
    AsyncCallFuture::new(function, store, params)
}
/// Value a coroutine yields back to `AsyncCallFuture::poll`: a host future that must
/// complete before the coroutine can continue.
struct AsyncYield(HostFuture);
/// Value `AsyncCallFuture::poll` passes into the coroutine when resuming it.
enum AsyncResume {
    /// The first resume, which starts executing the wrapped function call.
    Start,
    /// Resume after the previously yielded host future completed, carrying its output.
    HostFutureReady(Result<Vec<Value>, RuntimeError>),
}
/// Key distinguishing `AsyncCallFuture`s that share a store in the waker registry.
type AsyncCallFutureId = u64;
/// Source of `AsyncCallFutureId`s; incremented once per `AsyncCallFuture::new`.
static NEXT_ASYNC_CALL_FUTURE_ID: AtomicU64 = AtomicU64::new(0);
/// Wakers of in-flight `AsyncCallFuture`s, grouped by store id.
/// `notify_pending_futures_of_interrupt` wakes every entry for a store.
/// A store's inner map is removed once it becomes empty.
static ASYNC_CALL_FUTURE_WAKERS: LazyLock<DashMap<StoreId, HashMap<AsyncCallFutureId, Waker>>> =
    LazyLock::new(DashMap::new);
/// Future returned by `call_function_async`. It runs a function call on a corosensei
/// coroutine and suspends the coroutine whenever a host function needs to await a
/// `HostFuture`.
///
/// Fields are dropped in declaration order (after `Drop::drop` runs), so the coroutine
/// is dropped before `store`. NOTE(review): confirm nothing depends on that ordering
/// before reordering fields.
// The coroutine field spells out its resume/yield/return types in full.
#[allow(clippy::type_complexity)]
pub(crate) struct AsyncCallFuture {
    /// This future's key in `ASYNC_CALL_FUTURE_WAKERS`.
    id: AsyncCallFutureId,
    /// The coroutine executing the call; set to `None` once it has returned.
    coroutine: Option<Coroutine<AsyncResume, AsyncYield, Result<Box<[Value]>, RuntimeError>>>,
    /// In-progress store-context installation. It awaits `store.inner.write()` and is
    /// kept across polls until it completes.
    pending_store_install: Option<Pin<Box<dyn Future<Output = ForcedStoreInstallGuard>>>>,
    /// Host future the coroutine is currently suspended on, if any.
    pending_future: Option<HostFuture>,
    /// Argument for the next `Coroutine::resume` call.
    next_resume: Option<AsyncResume>,
    /// The call's result, stored when the coroutine returns and taken by `poll`.
    result: Option<Result<Box<[Value]>, RuntimeError>>,
    /// The store the call runs against.
    store: StoreAsync,
}
/// Store handle used inside the call's coroutine.
///
/// It holds only the store id and resolves the store through
/// `StoreContext::get_current_transient` on every access. It therefore keeps no borrow
/// of the store alive while the coroutine is suspended.
struct AsyncCallStoreMut {
    store_id: StoreId,
}
impl AsStoreRef for AsyncCallStoreMut {
    /// Borrows the store through the transient store context registered for `store_id`.
    ///
    /// # Panics
    ///
    /// Panics if `StoreContext::get_current_transient` yields a null pointer.
    fn as_store_ref(&self) -> StoreRef<'_> {
        // SAFETY: `AsyncCallStoreMut` is only created inside the coroutine started by
        // `AsyncCallFuture::new`. `AsyncCallFuture::poll` holds the store-context guard for
        // as long as that coroutine is running, so the pointer refers to a live store.
        let current = unsafe { StoreContext::get_current_transient(self.store_id).as_ref() };
        StoreRef {
            inner: current.unwrap(),
        }
    }
}
impl AsStoreMut for AsyncCallStoreMut {
    /// Mutably borrows the store through the transient store context for `store_id`.
    ///
    /// # Panics
    ///
    /// Panics if `StoreContext::get_current_transient` yields a null pointer.
    fn as_store_mut(&mut self) -> StoreMut<'_> {
        // SAFETY: `AsyncCallStoreMut` only exists inside the coroutine started by
        // `AsyncCallFuture::new`, and `AsyncCallFuture::poll` keeps the store context
        // installed for as long as that coroutine is running.
        let current = unsafe { StoreContext::get_current_transient(self.store_id).as_mut() };
        StoreMut {
            inner: current.unwrap(),
        }
    }

    /// Returns the object table of the store behind the transient store context.
    ///
    /// # Panics
    ///
    /// Panics if `StoreContext::get_current_transient` yields a null pointer.
    fn objects_mut(&mut self) -> &mut crate::StoreObjects {
        // SAFETY: same invariant as in `as_store_mut`: the context stays installed while
        // the coroutine that owns this handle runs.
        let current = unsafe { StoreContext::get_current_transient(self.store_id).as_mut() };
        &mut current.unwrap().objects
    }
}
impl AsyncCallFuture {
pub(crate) fn new(function: SysFunction, store: StoreAsync, params: Vec<Value>) -> Self {
let store_id = store.id;
let coroutine =
Coroutine::new(move |yielder: &Yielder<AsyncResume, AsyncYield>, resume| {
assert!(matches!(resume, AsyncResume::Start));
let ctx_state = CoroutineContext::new(yielder);
ctx_state.enter();
let result = {
let mut store_mut = AsyncCallStoreMut { store_id };
function.call(&mut store_mut, ¶ms)
};
ctx_state.leave();
result
});
Self {
id: NEXT_ASYNC_CALL_FUTURE_ID.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
coroutine: Some(coroutine),
pending_store_install: None,
pending_future: None,
next_resume: Some(AsyncResume::Start),
result: None,
store,
}
}
fn remove_from_wakers_list(&self) {
let mut wakers_entry = match ASYNC_CALL_FUTURE_WAKERS.entry(self.store.store_id()) {
dashmap::Entry::Occupied(o) => o,
dashmap::Entry::Vacant(v) => return,
};
let mut waker_map_ref = wakers_entry.get_mut();
waker_map_ref.remove(&self.id);
if waker_map_ref.is_empty() {
wakers_entry.remove();
}
}
}
impl Future for AsyncCallFuture {
    type Output = Result<Box<[Value]>, RuntimeError>;

    /// Drives the call as far as it can go without blocking.
    ///
    /// Each pass of the loop:
    /// 1. returns a `HostInterrupt` trap if the store has been interrupted (only with the
    ///    `experimental-host-interrupt` feature);
    /// 2. polls the host future the coroutine is suspended on, if any;
    /// 3. returns the stored result once the coroutine has finished;
    /// 4. otherwise installs the store context, resumes the coroutine until it yields
    ///    another host future or returns, and then drops the store-context guard.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned the call's result.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The field assignments below already require `Self: Unpin`, so work on a plain
        // `&mut Self` and let the borrow checker split field borrows.
        let this = &mut *self;
        let store_id = this.store.store_id();

        // Register this poll's waker before doing anything else. An executor may pass a
        // different waker on each poll, and only the most recent one is guaranteed to be
        // scheduled, so `notify_pending_futures_of_interrupt` must always see the current
        // one. Registering before the interrupt check below also means an interrupt
        // raised after that check still finds a waker to wake. (This assumes the
        // interrupter sets the flag before notifying; TODO confirm in `interrupt_registry`.)
        if this.coroutine.is_some() {
            ASYNC_CALL_FUTURE_WAKERS
                .entry(store_id)
                .or_default()
                .insert(this.id, cx.waker().clone());
        }

        loop {
            #[cfg(feature = "experimental-host-interrupt")]
            {
                if super::vm::interrupt_registry::is_interrupted(store_id) {
                    this.remove_from_wakers_list();
                    return Poll::Ready(Err(super::vm::Trap::lib(
                        super::vm::TrapCode::HostInterrupt,
                    )
                    .into()));
                }
            }

            // Wait for the host future the coroutine is blocked on, then hand its output
            // to the next resume.
            if let Some(future) = this.pending_future.as_mut() {
                match future.as_mut().poll(cx) {
                    Poll::Ready(result) => {
                        this.pending_future = None;
                        this.next_resume = Some(AsyncResume::HostFutureReady(result));
                    }
                    Poll::Pending => return Poll::Pending,
                }
            }

            if this.coroutine.is_none() {
                this.remove_from_wakers_list();
                return Poll::Ready(this.result.take().expect("polled after completion"));
            }

            // Acquire the store context. The install future is kept across polls, so a
            // contended store lock keeps being awaited instead of being requested again.
            if this.pending_store_install.is_none() {
                this.pending_store_install = Some(Box::pin(install_store_context(StoreAsync {
                    id: this.store.id,
                    inner: this.store.inner.clone(),
                })));
            }
            let store_context_guard = match this
                .pending_store_install
                .as_mut()
                .unwrap()
                .as_mut()
                .poll(cx)
            {
                Poll::Ready(guard) => {
                    this.pending_store_install = None;
                    guard
                }
                Poll::Pending => return Poll::Pending,
            };

            let resume_arg = this.next_resume.take().expect("no resume arg available");
            let coroutine = this.coroutine.as_mut().unwrap();
            match coroutine.resume(resume_arg) {
                CoroutineResult::Yield(AsyncYield(fut)) => {
                    this.pending_future = Some(fut);
                }
                CoroutineResult::Return(result) => {
                    this.coroutine = None;
                    this.result = Some(result);
                }
            }
            // Drop the store-context guard before waiting on the yielded host future (or
            // returning). It presumably releases the write lock taken in
            // `install_store_context`.
            drop(store_context_guard);
        }
    }
}
impl Drop for AsyncCallFuture {
    /// Unregisters this future's waker, so a future dropped before completion leaves no
    /// stale entry in `ASYNC_CALL_FUTURE_WAKERS`.
    fn drop(&mut self) {
        Self::remove_from_wakers_list(self);
    }
}
/// Waits for write access to the store and installs it as the current async store
/// context, returning the guard produced by `StoreContext::install_async`.
///
/// # Panics
///
/// Panics if `StoreContext::try_get_current_async(store.id)` returns anything other than
/// `NotInstalled`. That presumably means a context for this store is already installed;
/// per the panic message, this happens when a `call_async` future is polled from inside
/// another imported function.
async fn install_store_context(store: StoreAsync) -> ForcedStoreInstallGuard {
    match unsafe { crate::StoreContext::try_get_current_async(store.id) } {
        crate::GetStoreAsyncGuardResult::NotInstalled => {
            // Only await the (possibly contended) write lock when nothing is installed yet.
            let store_guard = store.inner.write().await;
            unsafe { crate::StoreContext::install_async(store_guard) }
        }
        _ => {
            panic!(
                "Function::call_async futures cannot be polled recursively \
                from within another imported function. If you need to await \
                a recursive call_async, consider spawning the future into \
                your async runtime and awaiting the resulting task; \
                e.g. tokio::task::spawn(func.call_async(...)).await"
            );
        }
    }
}
/// Error returned by `block_on_host_future`.
#[derive(Debug)]
pub enum AsyncRuntimeError {
    /// The host future was not immediately ready, and no enclosing `call_async` coroutine
    /// exists that could suspend while it completes.
    YieldOutsideAsyncContext,
    /// The host future completed with an error.
    RuntimeError(RuntimeError),
}
/// Runs `future` to completion on behalf of a host function.
///
/// Inside a `call_async` coroutine, the coroutine is suspended and the future is handed to
/// the enclosing `AsyncCallFuture` to be awaited. Outside any coroutine the future is
/// polled exactly once. If it is not immediately ready,
/// [`AsyncRuntimeError::YieldOutsideAsyncContext`] is returned.
pub(crate) fn block_on_host_future<Fut>(future: Fut) -> Result<Vec<Value>, AsyncRuntimeError>
where
    Fut: Future<Output = Result<Vec<Value>, RuntimeError>> + 'static,
{
    let Some(context_ptr) = CoroutineContext::get_current() else {
        return run_immediate(future);
    };
    // SAFETY: the innermost entry of the context stack is the `CoroutineContext` of the
    // coroutine body currently executing on this thread; that body pushes it on entry and
    // pops it before suspending or returning.
    let context = unsafe { context_ptr.as_ref() }.expect("valid context pointer");
    context
        .block_on_future(Box::pin(future))
        .map_err(AsyncRuntimeError::RuntimeError)
}
pub(crate) fn notify_pending_futures_of_interrupt(store_id: StoreId) {
let dashmap::Entry::Occupied(entry) = ASYNC_CALL_FUTURE_WAKERS.entry(store_id) else {
return;
};
for waker in entry.get().values() {
waker.wake_by_ref();
}
}
thread_local! {
    /// Per-thread stack of coroutine contexts that are currently executing on this thread.
    /// The innermost (most recently entered) context is last.
    static CURRENT_CONTEXT: RefCell<Vec<*const CoroutineContext>> = const { RefCell::new(Vec::new()) };
}
/// Handle that a running coroutine publishes on `CURRENT_CONTEXT`, so that host code can
/// suspend the coroutine while awaiting a host future.
struct CoroutineContext {
    /// Pointer to the coroutine's `Yielder`. It is only valid while the coroutine body that
    /// received the yielder is executing.
    yielder: *const Yielder<AsyncResume, AsyncYield>,
}
impl CoroutineContext {
    /// Wraps a coroutine's yielder. The stored pointer is only valid while the coroutine
    /// body that owns `yielder` is running.
    fn new(yielder: &Yielder<AsyncResume, AsyncYield>) -> Self {
        let yielder: *const Yielder<AsyncResume, AsyncYield> = yielder;
        Self { yielder }
    }

    /// Pushes this context onto the current thread's context stack.
    fn enter(&self) {
        let me: *const Self = self;
        CURRENT_CONTEXT.with_borrow_mut(|stack| stack.push(me));
    }

    /// Pops this context off the current thread's context stack.
    ///
    /// # Panics
    ///
    /// Panics if the innermost entry was not `self`.
    fn leave(&self) {
        let me: *const Self = self;
        let popped = CURRENT_CONTEXT.with_borrow_mut(|stack| stack.pop());
        assert_eq!(popped, Some(me), "Active coroutine stack corrupted");
    }

    /// Returns the innermost context on this thread, if any.
    fn get_current() -> Option<*const Self> {
        CURRENT_CONTEXT.with_borrow(|stack| stack.last().copied())
    }

    /// Suspends the coroutine, handing `future` to `AsyncCallFuture::poll`, and returns the
    /// future's output once the coroutine is resumed with it.
    ///
    /// The context is unpublished while suspended, so nothing on this thread can reach
    /// it in the meantime.
    fn block_on_future(&self, future: HostFuture) -> Result<Vec<Value>, RuntimeError> {
        self.leave();
        // SAFETY: the yielder pointer comes from the `&Yielder` passed to the running
        // coroutine body. This method is only reachable while that body executes.
        let yielder = unsafe { self.yielder.as_ref() }.expect("yielder pointer valid");
        let resumed_with = yielder.suspend(AsyncYield(future));
        let output = match resumed_with {
            AsyncResume::HostFutureReady(output) => output,
            AsyncResume::Start => unreachable!("coroutine resumed without start"),
        };
        self.enter();
        output
    }
}
/// Polls `future` exactly once with a no-op waker.
///
/// Returns its output if it is immediately ready. Otherwise the future is dropped and
/// [`AsyncRuntimeError::YieldOutsideAsyncContext`] is returned, since there is no
/// coroutine that could wait for it.
fn run_immediate(
    future: impl Future<Output = Result<Vec<Value>, RuntimeError>> + 'static,
) -> Result<Vec<Value>, AsyncRuntimeError> {
    let noop = futures::task::noop_waker();
    let mut context = Context::from_waker(&noop);
    let mut pinned = Box::pin(future);
    if let Poll::Ready(outcome) = pinned.as_mut().poll(&mut context) {
        outcome.map_err(AsyncRuntimeError::RuntimeError)
    } else {
        Err(AsyncRuntimeError::YieldOutsideAsyncContext)
    }
}