use crate::prelude::*;
use crate::store::{AsStoreOpaque, Asyncness, Executor, StoreId, StoreOpaque};
use crate::vm::mpk::{self, ProtectionMask};
use crate::vm::{AlwaysMut, AsyncWasmCallState};
use crate::{Engine, StoreContextMut};
use core::mem;
use core::ops::Range;
use core::pin::Pin;
use core::ptr::{self, NonNull};
use core::task::{Context, Poll};
use wasmtime_fiber::{Fiber, FiberStack, Suspend};
/// Value a store fiber is resumed with: a pointer to the `Context` of the
/// poll driving it, or an error (e.g. "future dropped" from
/// `StoreFiber::dispose`) telling the fiber to stop.
type WasmtimeResume = Result<NonNull<Context<'static>>>;
/// Value a store fiber yields when it suspends.
type WasmtimeYield = StoreFiberYield;
/// Value a store fiber produces when its body returns.
type WasmtimeComplete = Result<()>;
/// Suspend handle available to code running inside a store fiber.
type WasmtimeSuspend = Suspend<WasmtimeResume, WasmtimeYield, WasmtimeComplete>;
/// The underlying `wasmtime_fiber` fiber type used for store fibers.
type WasmtimeFiber<'a> = Fiber<'a, WasmtimeResume, WasmtimeYield, WasmtimeComplete>;
/// Per-store bookkeeping for running code on store fibers.
///
/// `Default` is derived: every field starts out empty (`None` / `false`),
/// exactly matching the previous hand-written impl.
#[derive(Default)]
pub(crate) struct AsyncState {
    /// Suspend handle of the fiber currently running against this store.
    ///
    /// Installed by the fiber body in `make_fiber_unchecked`. It is taken
    /// and put back by `BlockingContext::with`, and cleared when the fiber
    /// body exits.
    current_suspend: Option<NonNull<WasmtimeSuspend>>,
    /// `Context` of the poll that most recently resumed the current fiber.
    /// It is managed alongside `current_suspend`.
    current_future_cx: Option<NonNull<Context<'static>>>,
    /// Slot exposed through `AsyncState::last_fiber_stack`.
    ///
    /// NOTE(review): presumably a cached fiber stack kept for reuse; confirm
    /// against the callers of that accessor.
    last_fiber_stack: Option<wasmtime_fiber::FiberStack>,
    /// Set by `StoreOpaque::set_async_required(Asyncness::Yes)`.
    pub(crate) async_required: bool,
}

// SAFETY: the `NonNull` fields make this type `!Send`/`!Sync` by default.
// They are only populated while a fiber is actively running against this
// store:
// - set on fiber entry in `make_fiber_unchecked`;
// - reset to `None` around every resume by
//   `FiberResumeState::replace` / `PriorFiberResumeState::replace`;
// - dereferenced only by code running inside that fiber.
// NOTE(review): confirm no other code dereferences them.
unsafe impl Send for AsyncState {}
// SAFETY: see the `Send` impl above; shared access never dereferences the
// pointers.
unsafe impl Sync for AsyncState {}
impl AsyncState {
    /// Mutable access to the stored fiber-stack slot, so callers can take
    /// or replace the stack kept there.
    pub(crate) fn last_fiber_stack(&mut self) -> &mut Option<wasmtime_fiber::FiberStack> {
        &mut self.last_fiber_stack
    }

    /// Returns whether a future `Context` is currently available.
    ///
    /// When this is `true`, `BlockingContext::with` will find the context
    /// it unwraps. This method only reads state, so it takes `&self`;
    /// callers holding `&mut AsyncState` are unaffected.
    #[inline]
    pub(crate) fn can_block(&self) -> bool {
        self.current_future_cx.is_some()
    }
}
/// Handle to the currently running store fiber.
///
/// Lets synchronous code running inside the fiber block on futures by
/// suspending the fiber. Obtained through `StoreOpaque::with_blocking` or
/// `StoreContextMut::with_blocking`.
pub(crate) struct BlockingContext<'a, 'b> {
    // Suspend handle of the running fiber, used to yield to its resumer.
    suspend: &'a mut WasmtimeSuspend,
    // `Context` of the poll that last resumed the fiber. It is `None` while
    // suspended, and stays `None` after a suspension that returned an error.
    future_cx: Option<&'a mut Context<'b>>,
}
impl<'a, 'b> BlockingContext<'a, 'b> {
    /// Runs `f` with a `BlockingContext` built from `store`'s `AsyncState`.
    ///
    /// The current suspend handle and future `Context` are taken out of the
    /// state and exposed to `f`. A drop guard writes them back when `f`
    /// returns or unwinds.
    ///
    /// # Panics
    ///
    /// Panics if the store has no current future context or suspend handle,
    /// i.e. if this is not called from inside a running store fiber.
    fn with<S, R>(store: &mut S, f: impl FnOnce(&mut S, &mut BlockingContext<'_, '_>) -> R) -> R
    where
        S: AsStoreOpaque,
    {
        let opaque = store.as_store_opaque();
        let state = opaque.fiber_async_state_mut();
        // SAFETY: the raw pointers in `AsyncState` are turned back into
        // references. The signature of `f` bounds these references to the
        // duration of the call, so they cannot escape this function.
        // NOTE(review): pointer validity relies on them having been
        // installed by the fiber that is currently running.
        let future_cx = unsafe { Some(state.current_future_cx.take().unwrap().as_mut()) };
        let suspend = unsafe { state.current_suspend.take().unwrap().as_mut() };
        // Guard that restores the (possibly refreshed) pointers on drop,
        // including when `f` panics.
        let mut reset = ResetBlockingContext {
            store,
            cx: BlockingContext { future_cx, suspend },
        };
        return f(&mut reset.store, &mut reset.cx);

        // Holds the store borrow and the extracted context while `f` runs.
        struct ResetBlockingContext<'a, 'b, S: AsStoreOpaque> {
            store: &'a mut S,
            cx: BlockingContext<'a, 'b>,
        }

        impl<S: AsStoreOpaque> Drop for ResetBlockingContext<'_, '_, S> {
            fn drop(&mut self) {
                let store = self.store.as_store_opaque();
                let state = store.fiber_async_state_mut();
                // `with` emptied both slots; nothing should have refilled them.
                debug_assert!(state.current_future_cx.is_none());
                debug_assert!(state.current_suspend.is_none());
                state.current_suspend = Some(NonNull::from(&mut *self.cx.suspend));
                // `future_cx` is `None` if a `suspend` returned an error
                // (it is taken before suspending and only re-set on a
                // successful resume). In that case the slot stays empty.
                if let Some(cx) = &mut self.cx.future_cx {
                    state.current_future_cx =
                        Some(NonNull::from(unsafe { change_context_lifetime(cx) }));
                }
            }
        }
    }

    /// Polls `future` to completion from inside the fiber.
    ///
    /// Whenever the future is pending, the fiber is suspended with
    /// `StoreFiberYield::KeepStore`. Polling continues with the `Context`
    /// supplied on the next resume.
    ///
    /// # Errors
    ///
    /// Returns the error the fiber is resumed with, if it is resumed with
    /// an `Err` instead of a new `Context`.
    ///
    /// # Panics
    ///
    /// Panics if no future context is held, e.g. after an earlier `suspend`
    /// on this context returned an error.
    pub(crate) fn block_on<F>(&mut self, future: F) -> Result<F::Output>
    where
        F: Future + Send,
    {
        let mut future = core::pin::pin!(future);
        loop {
            match future.as_mut().poll(self.future_cx.as_mut().unwrap()) {
                Poll::Ready(v) => break Ok(v),
                Poll::Pending => self.suspend(StoreFiberYield::KeepStore)?,
            }
        }
    }

    /// Suspends the current fiber, yielding `yield_` to whoever resumed it.
    ///
    /// The held `Context` (from the previous poll) is dropped before
    /// suspending. It is replaced with the one provided on resumption.
    ///
    /// # Errors
    ///
    /// Propagates the `Err` the fiber is resumed with. In that case no
    /// `Context` is held afterwards.
    pub(crate) fn suspend(&mut self, yield_: StoreFiberYield) -> Result<()> {
        // Never keep a stale `Context` across a suspension point.
        self.future_cx.take();
        let mut new_future_cx: NonNull<Context<'static>> = self.suspend.suspend(yield_)?;
        // SAFETY: the resumer passes a pointer to its live `Context`; it is
        // re-typed to this context's `'b` lifetime. NOTE(review): soundness
        // relies on it being dropped before the next suspension (done above).
        unsafe {
            self.future_cx = Some(change_context_lifetime(new_future_cx.as_mut()));
        }
        Ok(())
    }
}
impl<T> StoreContextMut<'_, T> {
#[cfg(feature = "component-model")]
pub(crate) fn block_on<R>(
self,
f: impl FnOnce(StoreContextMut<'_, T>) -> Pin<Box<dyn Future<Output = R> + Send + '_>>,
) -> Result<R> {
self.with_blocking(|store, cx| cx.block_on(f(store).as_mut()))
}
pub(crate) fn with_blocking<R>(
self,
f: impl FnOnce(StoreContextMut<'_, T>, &mut BlockingContext<'_, '_>) -> R,
) -> R {
BlockingContext::with(self.0, |store, cx| f(StoreContextMut(store), cx))
}
}
impl StoreOpaque {
    /// Runs `f` with this store and a `BlockingContext` for the fiber that
    /// is currently executing.
    ///
    /// # Panics
    ///
    /// Panics if no store fiber is currently running.
    pub(crate) fn with_blocking<R>(
        &mut self,
        f: impl FnOnce(&mut Self, &mut BlockingContext<'_, '_>) -> R,
    ) -> R {
        BlockingContext::with(self, f)
    }

    /// Marks the store as requiring async execution when `asyncness` is
    /// `Asyncness::Yes`. `Asyncness::No` leaves the flag as it was; it never
    /// clears it.
    pub(crate) fn set_async_required(&mut self, asyncness: Asyncness) {
        match asyncness {
            Asyncness::No => {}
            Asyncness::Yes => self.fiber_async_state_mut().async_required = true,
        }
    }
}
/// Value a store fiber yields each time it suspends.
pub(crate) enum StoreFiberYield {
    /// Suspend while (presumably) keeping exclusive use of the store.
    /// `FiberFuture` always reports this as `Poll::Pending`.
    KeepStore,
    /// Suspend and release the store. Depending on its `OnRelease` setting,
    /// `FiberFuture` either reports `Pending` or hands the fiber back.
    #[cfg(feature = "component-model-async")]
    ReleaseStore,
}
/// A fiber bound to a specific store.
///
/// Holds the execution state that is swapped in whenever the fiber is
/// resumed.
pub(crate) struct StoreFiber<'a> {
    // The fiber itself. It becomes `None` once its stack has been taken back
    // (after the fiber completes).
    fiber: Option<AlwaysMut<RawFiber<'a>>>,
    // State to install on the next resume. It is temporarily `None` while
    // `resume_fiber` is running the fiber.
    state: Option<AlwaysMut<FiberResumeState>>,
    // Engine whose allocator receives the stack in `Drop`.
    engine: Engine,
    // Store the fiber was created for; checked on every resume.
    id: StoreId,
}
/// Newtype around the raw fiber, so that `Send` can be asserted for it
/// (see its `unsafe impl Send`).
struct RawFiber<'a>(WasmtimeFiber<'a>);
impl<'a> StoreFiber<'a> {
    /// Borrows the underlying fiber, or returns `None` after its stack has
    /// been taken back.
    fn fiber(&mut self) -> Option<&mut WasmtimeFiber<'a>> {
        let raw = self.fiber.as_mut()?;
        Some(&mut raw.get_mut().0)
    }

    /// Removes the fiber (if still present) and returns its stack.
    fn take_fiber_stack(&mut self) -> Option<FiberStack> {
        let raw = self.fiber.take()?;
        Some(raw.into_inner().0.into_stack())
    }

    /// Cancels the fiber if it exists and is not done.
    ///
    /// The fiber is resumed one last time with a "future dropped" error:
    /// - A fiber that never started returns immediately.
    /// - A suspended fiber sees the error come out of its `suspend` call.
    pub(crate) fn dispose(&mut self, store: &mut StoreOpaque) {
        let unfinished = self.fiber().is_some_and(|fiber| !fiber.done());
        if unfinished {
            let result = resume_fiber(store, self, Err(format_err!("future dropped")));
            debug_assert!(result.is_ok());
        }
    }
}
/// Releases a store fiber's resources.
///
/// If the fiber was already taken (e.g. `resume_fiber` reclaims the stack
/// on completion), nothing happens. Otherwise the fiber must be done;
/// dropping an in-progress fiber panics.
impl Drop for StoreFiber<'_> {
    fn drop(&mut self) {
        if self.fiber.is_none() {
            return;
        }
        assert!(
            self.fiber().unwrap().done(),
            "attempted to drop in-progress fiber without first calling `StoreFiber::dispose`"
        );
        // Consumes the saved state; `dispose` asserts its TLS state is null.
        self.state.take().unwrap().into_inner().dispose();
        // SAFETY: the fiber is done (asserted above), so nothing runs on its
        // stack any more. NOTE(review): this assumes the stack returned by
        // `StoreOpaque::allocate_fiber_stack` came from this engine's
        // allocator; confirm.
        unsafe {
            let stack = self.take_fiber_stack().unwrap();
            self.engine.allocator().deallocate_fiber_stack(stack);
        }
    }
}
// SAFETY: NOTE(review) this asserts that a `WasmtimeFiber` may be moved
// between threads while it is not running. The closure it runs is required
// to be `Send` by `make_fiber_unchecked`. Confirm against the
// `wasmtime_fiber` documentation.
unsafe impl Send for RawFiber<'_> {}
/// Execution state belonging to a fiber.
///
/// It is saved while the fiber is suspended and installed while it runs.
struct FiberResumeState {
    // The fiber's TLS async-wasm-call state.
    tls: crate::runtime::vm::AsyncWasmCallState,
    // MPK mask to use while running. It is `None` when the store doesn't
    // track protection keys.
    mpk: Option<ProtectionMask>,
    // Stack limit to install into the store while running.
    stack_limit: usize,
    // Executor swapped into the store while running.
    executor: Executor,
}
impl FiberResumeState {
    /// Installs this fiber's state into the current thread and `store`.
    ///
    /// Returns what was active before, so it can be restored afterwards.
    ///
    /// Swapped in:
    /// - TLS async-call state;
    /// - MPK mask (if tracked);
    /// - executor;
    /// - stack limit;
    /// - async guard range, taken from `fiber`'s stack.
    ///
    /// The store's current suspend / future-context pointers are cleared,
    /// and their old values are saved.
    ///
    /// # Safety
    ///
    /// NOTE(review): inferred from usage. The returned
    /// `PriorFiberResumeState` must be restored (via its `replace`) once the
    /// fiber stops running; `resume_fiber` does this with a drop guard.
    ///
    /// # Panics
    ///
    /// Panics if `fiber` no longer holds its underlying fiber.
    unsafe fn replace(
        self,
        store: &mut StoreOpaque,
        fiber: &mut StoreFiber<'_>,
    ) -> PriorFiberResumeState {
        let tls = unsafe { self.tls.push() };
        let mpk = swap_mpk_states(self.mpk);
        // Fibers without a guard region get an empty range.
        let async_guard_range = fiber
            .fiber()
            .unwrap()
            .stack()
            .guard_range()
            .unwrap_or(ptr::null_mut()..ptr::null_mut());
        let mut executor = self.executor;
        store.swap_executor(&mut executor);
        PriorFiberResumeState {
            tls,
            mpk,
            executor,
            stack_limit: store.replace_stack_limit(self.stack_limit),
            async_guard_range: store.replace_async_guard_range(async_guard_range),
            // The fiber starts each resume with empty pointer slots; it
            // fills them itself. The caller's values are saved here.
            current_suspend: store.replace_current_suspend(None),
            current_future_cx: store.replace_current_future_cx(None),
        }
    }

    /// Consumes the saved state, asserting (via `assert_null`) that its TLS
    /// async-call state is null.
    fn dispose(self) {
        self.tls.assert_null();
    }
}
impl StoreOpaque {
    /// Installs `stack_limit` as the store's stack limit and returns the
    /// previously installed limit.
    fn replace_stack_limit(&mut self, stack_limit: usize) -> usize {
        // `get_mut()` already yields a mutable reference to the limit. Pass
        // it straight through instead of re-borrowing it as `&mut &mut _`
        // and relying on auto-deref to undo that.
        mem::replace(
            self.vm_store_context_mut().stack_limit.get_mut(),
            stack_limit,
        )
    }

    /// Installs `range` as the store's async guard range and returns the
    /// previous range.
    fn replace_async_guard_range(&mut self, range: Range<*mut u8>) -> Range<*mut u8> {
        mem::replace(&mut self.vm_store_context_mut().async_guard_range, range)
    }

    /// Replaces the store's current fiber suspend pointer and returns the
    /// previous value.
    fn replace_current_suspend(
        &mut self,
        ptr: Option<NonNull<WasmtimeSuspend>>,
    ) -> Option<NonNull<WasmtimeSuspend>> {
        mem::replace(&mut self.fiber_async_state_mut().current_suspend, ptr)
    }

    /// Replaces the store's current future-context pointer and returns the
    /// previous value.
    fn replace_current_future_cx(
        &mut self,
        ptr: Option<NonNull<Context<'static>>>,
    ) -> Option<NonNull<Context<'static>>> {
        mem::replace(&mut self.fiber_async_state_mut().current_future_cx, ptr)
    }
}
/// State that was active before a fiber was resumed.
///
/// Saved by `FiberResumeState::replace` and restored when the fiber
/// suspends or completes.
struct PriorFiberResumeState {
    // Previous TLS async-call state, restored via `restore`.
    tls: crate::runtime::vm::PreviousAsyncWasmCallState,
    // Previous MPK mask, if MPK is being tracked.
    mpk: Option<ProtectionMask>,
    // Previous stack limit.
    stack_limit: usize,
    // Previous async guard range.
    async_guard_range: Range<*mut u8>,
    // Previous suspend pointer of the store's `AsyncState`.
    current_suspend: Option<NonNull<WasmtimeSuspend>>,
    // Previous future-context pointer of the store's `AsyncState`.
    current_future_cx: Option<NonNull<Context<'static>>>,
    // Previous executor.
    executor: Executor,
}
impl PriorFiberResumeState {
    /// Reverses `FiberResumeState::replace`.
    ///
    /// Reinstalls the pre-resume state and captures the fiber's current
    /// state, so it can be swapped back in on the next resume.
    ///
    /// # Safety
    ///
    /// NOTE(review): inferred from usage. `self` must be the value produced
    /// by the matching `FiberResumeState::replace` on this thread and store.
    ///
    /// # Panics
    ///
    /// Panics if the store's current suspend or future-context slot is
    /// still set, i.e. if the fiber stopped without clearing them.
    unsafe fn replace(self, store: &mut StoreOpaque) -> FiberResumeState {
        let tls = unsafe { self.tls.restore() };
        let mpk = swap_mpk_states(self.mpk);
        // The fiber's guard range isn't saved. `FiberResumeState::replace`
        // recomputes it from the fiber's stack on every resume.
        let _my_guard = store.replace_async_guard_range(self.async_guard_range);
        let prev = store.replace_current_suspend(self.current_suspend);
        assert!(prev.is_none());
        let prev = store.replace_current_future_cx(self.current_future_cx);
        assert!(prev.is_none());
        let mut executor = self.executor;
        store.swap_executor(&mut executor);
        FiberResumeState {
            tls,
            mpk,
            executor,
            stack_limit: store.replace_stack_limit(self.stack_limit),
        }
    }
}
/// Switches the active MPK mask to `mask`, returning the mask that was
/// active before.
///
/// Returns `None` without touching MPK state when `mask` is `None`.
fn swap_mpk_states(mask: Option<ProtectionMask>) -> Option<ProtectionMask> {
    let requested = mask?;
    let previous = mpk::current_mask();
    mpk::allow(requested);
    Some(previous)
}
/// Resumes `fiber` with `result`, running it against `store` until it
/// completes or suspends.
///
/// The fiber's saved state is swapped in before resuming. A drop guard
/// (`Restore`) swaps the previous state back in afterwards, also on unwind.
///
/// Returns:
/// - `Ok(completion)` when the fiber finished; its stack is then handed
///   back to `store`.
/// - `Err(yield_)` with the value the fiber suspended with.
///
/// # Panics
///
/// Panics if `fiber` was created for a different store, or if its saved
/// state or underlying fiber has already been taken.
fn resume_fiber<'a>(
    store: &mut StoreOpaque,
    fiber: &mut StoreFiber<'a>,
    result: WasmtimeResume,
) -> Result<WasmtimeComplete, StoreFiberYield> {
    assert_eq!(store.id(), fiber.id);
    // Restores the pre-resume state when dropped. The fiber's new state is
    // stored back into `fiber.state`.
    struct Restore<'a, 'b> {
        store: &'b mut StoreOpaque,
        fiber: &'b mut StoreFiber<'a>,
        state: Option<PriorFiberResumeState>,
    }
    impl Drop for Restore<'_, '_> {
        fn drop(&mut self) {
            // SAFETY: `state` is the value produced by the
            // `FiberResumeState::replace` call below, on the same store.
            self.fiber.state =
                Some(unsafe { self.state.take().unwrap().replace(self.store).into() });
        }
    }
    let result = unsafe {
        // SAFETY: the prior state is restored by `Restore`'s drop at the end
        // of this block, whether `resume` returns or unwinds.
        let prev = fiber
            .state
            .take()
            .unwrap()
            .into_inner()
            .replace(store, fiber);
        let restore = Restore {
            store,
            fiber,
            state: Some(prev),
        };
        restore.fiber.fiber().unwrap().resume(result)
    };
    match &result {
        // Finished: the stack is no longer needed, so return it to the store.
        Ok(_) => {
            if let Some(stack) = fiber.take_fiber_stack() {
                store.deallocate_fiber_stack(stack);
            }
        }
        // Suspended: check (via `assert_current_state_not_in_range`) that
        // the current TLS async-call state does not reference the
        // suspended fiber's stack.
        Err(_) => {
            if let Some(range) = fiber.fiber().unwrap().stack().range() {
                AsyncWasmCallState::assert_current_state_not_in_range(range);
            }
        }
    }
    result
}
/// Creates a `StoreFiber` that runs `fun(store)` when it is first resumed.
///
/// A fiber stack is allocated from the store. The fiber's initial saved
/// state is:
/// - a fresh `AsyncWasmCallState`;
/// - a `usize::MAX` stack limit;
/// - a new `Executor`;
/// - `ProtectionMask::all()` if the store has protection keys.
///
/// Inside the fiber:
/// - If the first resumption carries an `Err`, `fun` is never called and
///   the fiber completes with `Ok(())`.
/// - Otherwise the store's current suspend handle and future `Context` are
///   installed for the duration of `fun` and cleared afterwards, also on
///   unwind.
///
/// # Errors
///
/// Returns an error if allocating the fiber stack or creating the fiber fails.
///
/// # Safety
///
/// The fiber captures a raw pointer to `store` and dereferences it whenever
/// it runs. The lifetime of the `&mut S` borrow is not tied to the returned
/// `StoreFiber<'a>`. NOTE(review): the caller must presumably ensure that:
/// - `store` outlives the fiber;
/// - `store` is not otherwise accessed while the fiber runs;
/// - unlike `make_fiber`, which adds `S: Send`, the store is safe to use
///   from whichever thread resumes the fiber.
///
/// Confirm the exact contract.
pub(crate) unsafe fn make_fiber_unchecked<'a, S>(
    store: &mut S,
    fun: impl FnOnce(&mut S) -> Result<()> + Send + Sync + 'a,
) -> Result<StoreFiber<'a>>
where
    S: AsStoreOpaque + ?Sized + 'a,
{
    let opaque = store.as_store_opaque();
    let engine = opaque.engine().clone();
    let executor = Executor::new(&engine);
    let id = opaque.id();
    let stack = opaque.allocate_fiber_stack()?;
    let track_pkey_context_switch = opaque.has_pkey();
    // Erase the borrow so the fiber closure can capture the store.
    let store = &raw mut *store;
    let fiber = Fiber::new(stack, move |result: WasmtimeResume, suspend| {
        let future_cx = match result {
            Ok(cx) => cx,
            // Cancelled before ever running: finish without calling `fun`.
            Err(_) => return Ok(()),
        };
        // SAFETY: relies on the caller contract documented on this function.
        let store_ref = unsafe { &mut *store };
        let async_state = store_ref.as_store_opaque().fiber_async_state_mut();
        // `FiberResumeState::replace` clears both slots before each resume.
        assert!(async_state.current_suspend.is_none());
        assert!(async_state.current_future_cx.is_none());
        async_state.current_suspend = Some(NonNull::from(suspend));
        async_state.current_future_cx = Some(future_cx);
        // Clears the pointers installed above once `fun` returns or unwinds.
        struct ResetCurrentPointersToNull<'a, S>(&'a mut S)
        where
            S: AsStoreOpaque + ?Sized;
        impl<S> Drop for ResetCurrentPointersToNull<'_, S>
        where
            S: AsStoreOpaque + ?Sized,
        {
            fn drop(&mut self) {
                let state = self.0.as_store_opaque().fiber_async_state_mut();
                debug_assert!(state.current_suspend.is_some());
                state.current_suspend = None;
                state.current_future_cx = None;
            }
        }
        let reset = ResetCurrentPointersToNull(store_ref);
        fun(reset.0)
    })?;
    Ok(StoreFiber {
        state: Some(
            FiberResumeState {
                tls: crate::runtime::vm::AsyncWasmCallState::new(),
                mpk: if track_pkey_context_switch {
                    Some(ProtectionMask::all())
                } else {
                    None
                },
                stack_limit: usize::MAX,
                executor,
            }
            .into(),
        ),
        engine,
        id,
        fiber: Some(RawFiber(fiber).into()),
    })
}
/// Wrapper around `make_fiber_unchecked` for stores that are `Send`.
///
/// # Errors
///
/// Same as `make_fiber_unchecked`: fails if the fiber stack or fiber cannot
/// be created.
#[cfg(feature = "component-model-async")]
pub(crate) fn make_fiber<'a, S>(
    store: &mut S,
    fun: impl FnOnce(&mut S) -> Result<()> + Send + Sync + 'a,
) -> Result<StoreFiber<'a>>
where
    S: AsStoreOpaque + Send + ?Sized + 'a,
{
    // SAFETY: NOTE(review) this wrapper only adds the `S: Send` bound. The
    // remaining obligations of `make_fiber_unchecked` (store outlives the
    // fiber and isn't aliased while it runs) appear to be left to users of
    // the returned `StoreFiber`; confirm.
    unsafe { make_fiber_unchecked(store, fun) }
}
/// Runs `func(store)` on a freshly created fiber, awaits its completion,
/// and returns `func`'s result.
///
/// Every suspension of the fiber (including `ReleaseStore` when
/// `component-model-async` is enabled) is reported as `Poll::Pending`. If
/// this future is dropped early, `FiberFuture`'s `Drop` disposes the fiber.
///
/// # Errors
///
/// Returns an error if the fiber cannot be created.
///
/// # Panics
///
/// Panics if the fiber finishes with an error. This is not expected, since
/// the closure given to the fiber always returns `Ok(())`.
pub(crate) async fn on_fiber<S, R>(
    store: &mut S,
    func: impl FnOnce(&mut S) -> R + Send + Sync,
) -> Result<R>
where
    S: AsStoreOpaque + ?Sized,
    R: Send + Sync,
{
    let opaque = store.as_store_opaque();
    let config = opaque.engine().config();
    debug_assert!(config.async_stack_size > 0);
    let mut result = None;
    // SAFETY: NOTE(review) the fiber is awaited to completion below, or
    // disposed when this future is dropped, so it should not outlive the
    // `store` borrow.
    let fiber = unsafe {
        make_fiber_unchecked(store, |store| {
            result = Some(func(store));
            Ok(())
        })?
    };
    {
        let fiber = FiberFuture {
            store: store.as_store_opaque(),
            fiber: Some(fiber),
            #[cfg(feature = "component-model-async")]
            on_release: OnRelease::ReturnPending,
        }
        .await
        .unwrap();
        // With `ReturnPending` the fiber is never handed back.
        debug_assert!(fiber.is_none());
    }
    Ok(result.unwrap())
}
/// Drives `fiber` until it completes or releases the store.
///
/// Resolves to:
/// - `Ok(None)` when the fiber finished successfully;
/// - `Err(e)` when it finished with an error;
/// - `Ok(Some(fiber))` when it yielded `StoreFiberYield::ReleaseStore`,
///   handing the suspended fiber back to the caller.
#[cfg(feature = "component-model-async")]
pub(crate) async fn resolve_or_release<'a>(
    store: &mut StoreOpaque,
    fiber: StoreFiber<'a>,
) -> Result<Option<StoreFiber<'a>>> {
    let driver = FiberFuture {
        store,
        fiber: Some(fiber),
        on_release: OnRelease::ReturnReady,
    };
    driver.await
}
/// What `FiberFuture` does when its fiber yields
/// `StoreFiberYield::ReleaseStore`.
#[cfg(feature = "component-model-async")]
enum OnRelease {
    /// Report `Poll::Pending` and keep driving the fiber on later polls.
    ReturnPending,
    /// Complete with `Ok(Some(fiber))`, handing the suspended fiber back.
    ReturnReady,
}
/// Future that drives a `StoreFiber` by resuming it on each poll.
///
/// Dropping it while it still owns the fiber disposes of the fiber.
struct FiberFuture<'a, 'b> {
    // Store the fiber runs against.
    store: &'a mut StoreOpaque,
    // The fiber being driven. It is taken when handed back through
    // `OnRelease::ReturnReady`.
    fiber: Option<StoreFiber<'b>>,
    // Behavior when the fiber yields `ReleaseStore`.
    #[cfg(feature = "component-model-async")]
    on_release: OnRelease,
}
impl<'b> Future for FiberFuture<'_, 'b> {
    type Output = Result<Option<StoreFiber<'b>>>;
    /// Resumes the fiber with this poll's `Context`.
    ///
    /// - Fiber finished: `Ready(Ok(None))`, or `Ready(Err(e))` if it
    ///   returned an error.
    /// - Yielded `KeepStore`: `Pending`.
    /// - Yielded `ReleaseStore`: `Pending` or `Ready(Ok(Some(fiber)))`,
    ///   according to `on_release`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Ready`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let me = self.get_mut();
        // SAFETY: the `'static` lifetime is not real; the pointer is only
        // valid for this `resume_fiber` call. NOTE(review): soundness relies
        // on fiber code dropping it before suspending, as
        // `BlockingContext::suspend` does.
        let cx: &mut Context<'static> = unsafe { change_context_lifetime(cx) };
        let cx = NonNull::from(cx);
        match resume_fiber(me.store, me.fiber.as_mut().unwrap(), Ok(cx)) {
            Ok(Ok(())) => Poll::Ready(Ok(None)),
            Ok(Err(e)) => Poll::Ready(Err(e)),
            Err(StoreFiberYield::KeepStore) => Poll::Pending,
            #[cfg(feature = "component-model-async")]
            Err(StoreFiberYield::ReleaseStore) => match &me.on_release {
                OnRelease::ReturnPending => Poll::Pending,
                OnRelease::ReturnReady => Poll::Ready(Ok(me.fiber.take())),
            },
        }
    }
}
/// Disposes of a fiber this future still owns.
///
/// This cancels the fiber if it is not finished (e.g. the future was
/// dropped mid-execution).
impl Drop for FiberFuture<'_, '_> {
    fn drop(&mut self) {
        let Some(owned) = self.fiber.as_mut() else {
            return;
        };
        owned.dispose(self.store);
    }
}
/// Re-types `cx` as a `Context<'b>` for a caller-chosen `'b`, keeping the
/// outer `'a` borrow.
///
/// # Safety
///
/// Only the lifetime parameter changes; the result refers to the very same
/// `Context`. The caller must ensure the result is not used after the data
/// the original `Context` borrows (its waker) is gone.
unsafe fn change_context_lifetime<'a, 'b>(cx: &'a mut Context<'_>) -> &'a mut Context<'b> {
    let erased: *mut Context<'_> = cx;
    // SAFETY: `Context<'_>` and `Context<'b>` differ only in a lifetime
    // parameter, so they have identical layout. `erased` comes from a live
    // `&'a mut`, and the result is bounded by `'a`. The validity of `'b` is
    // the caller's obligation.
    unsafe { &mut *erased.cast::<Context<'b>>() }
}