use crate::RunResult;
use std::cell::Cell;
use std::io;
use std::ptr;
/// A stack on which a fiber runs.
///
/// Either owns an anonymous memory mapping (created by `FiberStack::new`) or
/// merely refers to externally managed stack memory (`FiberStack::from_top_ptr`).
#[derive(Debug)]
pub struct FiberStack {
    // Highest address of the stack region (one past the end of the mapping
    // when allocated by `new`); the stack grows down from here.
    top: *mut u8,
    // `Some(total mapped length, including the guard page)` when this value
    // owns the mapping and must `munmap` it on drop; `None` when the memory
    // is owned by someone else.
    len: Option<usize>,
}
impl FiberStack {
    /// Allocates a new fiber stack with at least `size` usable bytes.
    ///
    /// The request is rounded up to a multiple of the system page size (a
    /// `size` of 0 yields a single page). One extra page is mapped below the
    /// usable region with no access permissions, acting as a guard page.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error if the rounded size plus the guard page
    /// would overflow `usize`, or the OS error if `mmap`/`mprotect` fails. If
    /// `mprotect` fails, the freshly created mapping is released first so it
    /// is not leaked.
    pub fn new(size: usize) -> io::Result<Self> {
        let overflow = || {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "fiber stack size overflows usize",
            )
        };
        let page_size = rustix::process::page_size();
        let size = if size == 0 {
            page_size
        } else {
            // Round up to the next page multiple. The mask assumes `page_size`
            // is a power of two. Reject requests whose round-up would wrap.
            size.checked_add(page_size - 1)
                .map(|s| s & !(page_size - 1))
                .ok_or_else(overflow)?
        };
        // One additional page at the low end serves as the guard page.
        let mmap_len = size.checked_add(page_size).ok_or_else(overflow)?;

        unsafe {
            // Reserve the whole region with no permissions...
            let mmap = rustix::io::mmap_anonymous(
                ptr::null_mut(),
                mmap_len,
                rustix::io::ProtFlags::empty(),
                rustix::io::MapFlags::PRIVATE,
            )?;
            // ...then make everything above the guard page readable/writable.
            if let Err(e) = rustix::io::mprotect(
                mmap.cast::<u8>().add(page_size).cast(),
                size,
                rustix::io::MprotectFlags::READ | rustix::io::MprotectFlags::WRITE,
            ) {
                // Don't leak the mapping. The mprotect error is the one worth
                // reporting, so the munmap result is deliberately ignored.
                let _ = rustix::io::munmap(mmap, mmap_len);
                return Err(e.into());
            }
            Ok(Self {
                top: mmap.cast::<u8>().add(mmap_len),
                len: Some(mmap_len),
            })
        }
    }

    /// Wraps an existing stack whose highest address is `top`.
    ///
    /// The returned value does not own the memory and will not unmap it on
    /// drop.
    ///
    /// # Safety
    ///
    /// `top` must point one past the end of a writable region that is large
    /// enough to serve as a fiber stack. The region must remain valid for as
    /// long as the returned `FiberStack`, and any fiber using it, is alive.
    pub unsafe fn from_top_ptr(top: *mut u8) -> io::Result<Self> {
        Ok(Self { top, len: None })
    }

    /// Returns the highest address of the stack. This is always `Some` on
    /// this platform.
    pub fn top(&self) -> Option<*mut u8> {
        Some(self.top)
    }
}
impl Drop for FiberStack {
    /// Unmaps the stack region if this `FiberStack` owns it; stacks adopted
    /// via `from_top_ptr` (`len == None`) are left untouched.
    fn drop(&mut self) {
        let len = match self.len {
            Some(len) => len,
            None => return,
        };
        // SAFETY: `len` is only `Some` for stacks built by `new`, where `top`
        // is exactly `len` bytes past the start of the mapping.
        let ret = unsafe { rustix::io::munmap(self.top.sub(len).cast(), len) };
        debug_assert!(ret.is_ok());
    }
}
/// Handle for a fiber; all per-fiber state lives on its `FiberStack`.
pub struct Fiber;
/// Handle used from inside a running fiber to switch back to its resumer.
/// Wraps the top-of-stack pointer of the fiber's stack.
pub struct Suspend(*mut u8);
// Context-switching primitives implemented outside Rust (presumably in
// assembly — NOTE(review): confirm the defining file).
extern "C" {
    // Prepares the stack whose top is `top_of_stack` so that the first switch
    // into it calls `entry`. Judging by `fiber_start`'s signature, `entry` is
    // presumably called with (`entry_arg0`, `top_of_stack`) — TODO confirm.
    fn wasmtime_fiber_init(
        top_of_stack: *mut u8,
        entry: extern "C" fn(*mut u8, *mut u8),
        entry_arg0: *mut u8,
    );
    // Switches execution to or from the fiber whose stack top is
    // `top_of_stack`. Used both to resume (`Fiber::resume`) and to suspend
    // (`Suspend::switch`).
    fn wasmtime_fiber_switch(top_of_stack: *mut u8);
}
/// Entry point registered with `wasmtime_fiber_init`; runs on the fiber's own
/// stack the first time it is switched to.
///
/// `arg0` is the `Box<F>` leaked by `Fiber::new`, and `top_of_stack` is the
/// fiber's stack top. Ownership of the closure is reclaimed here and handed to
/// `super::Suspend::execute` together with the initial resume value.
extern "C" fn fiber_start<F, A, B, C>(arg0: *mut u8, top_of_stack: *mut u8)
where
    F: FnOnce(A, &super::Suspend<A, B, C>) -> C,
{
    unsafe {
        let suspend = Suspend(top_of_stack);
        // The resumer has already stored a `Resuming` value; take it out.
        let first_resume = suspend.take_resume::<A, B, C>();
        // SAFETY: `arg0` came from `Box::into_raw(Box::new(func))` in
        // `Fiber::new` with this same `F`, and is consumed exactly once here.
        let entry = Box::from_raw(arg0.cast::<F>());
        super::Suspend::<A, B, C>::execute(suspend, first_resume, entry)
    }
}
impl Fiber {
pub fn new<F, A, B, C>(stack: &FiberStack, func: F) -> io::Result<Self>
where
F: FnOnce(A, &super::Suspend<A, B, C>) -> C,
{
unsafe {
let data = Box::into_raw(Box::new(func)).cast();
wasmtime_fiber_init(stack.top, fiber_start::<F, A, B, C>, data);
}
Ok(Self)
}
pub(crate) fn resume<A, B, C>(&self, stack: &FiberStack, result: &Cell<RunResult<A, B, C>>) {
unsafe {
let addr = stack.top.cast::<usize>().offset(-1);
addr.write(result as *const _ as usize);
wasmtime_fiber_switch(stack.top);
addr.write(0);
}
}
}
impl Suspend {
    /// Publishes `result` to the resumer, switches back to it, and returns
    /// the value passed in on the next resume.
    pub(crate) fn switch<A, B, C>(&self, result: RunResult<A, B, C>) -> A {
        unsafe {
            let outgoing = &*self.result_location::<A, B, C>();
            outgoing.set(result);
            wasmtime_fiber_switch(self.0);
            // The resumer may have published a different cell, so look the
            // location up again rather than reusing `outgoing`.
            self.take_resume::<A, B, C>()
        }
    }

    /// Takes the resume value out of the shared cell, marking it `Executing`.
    ///
    /// Panics if the cell does not hold `RunResult::Resuming`.
    unsafe fn take_resume<A, B, C>(&self) -> A {
        let cell = &*self.result_location::<A, B, C>();
        match cell.replace(RunResult::Executing) {
            RunResult::Resuming(resumed) => resumed,
            _ => panic!("not in resuming state"),
        }
    }

    /// Reads the cell address that `Fiber::resume` stored in the word just
    /// below this fiber's stack top. Panics if that slot is null.
    unsafe fn result_location<A, B, C>(&self) -> *const Cell<RunResult<A, B, C>> {
        let slot = self.0.cast::<*const u8>().sub(1);
        let ret = slot.read();
        assert!(!ret.is_null());
        ret.cast()
    }
}