use super::trapcode::TrapCode;
use crate::vmcontext::{VMFunctionBody, VMFunctionEnvironment, VMTrampoline};
use backtrace::Backtrace;
use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::error::Error;
use std::mem::{self, MaybeUninit};
use std::ptr;
pub use tls::TlsRestore;
// Non-local control-flow primitives implemented outside this file (presumably a
// native shim built on setjmp/longjmp — NOTE(review): confirm against its source).
extern "C" {
    /// Registers a resumption point, makes a handle to it available through
    /// `*jmp_buf`, and then invokes `callback(payload)`.
    ///
    /// As interpreted by `CallThreadState::with`, a non-zero return value means
    /// `callback` returned normally, while `0` means control came back here through
    /// `unc_vm_unwind` on the stored handle.
    fn unc_vm_register_setjmp(
        jmp_buf: *mut *const u8,
        callback: extern "C" fn(*mut u8),
        payload: *mut u8,
    ) -> i32;
    /// Transfers control back to the resumption point identified by `jmp_buf` (a
    /// handle previously stored by `unc_vm_register_setjmp`). Never returns.
    /// NOTE(review): being longjmp-style, this presumably skips Rust destructors of
    /// the abandoned frames — confirm against the native implementation.
    fn unc_vm_unwind(jmp_buf: *const u8) -> !;
}
/// Aborts the current guarded call with a user-provided error.
///
/// The innermost active [`catch_traps`] call on this thread returns
/// `Err(Trap::User(data))`.
///
/// # Safety
///
/// A [`catch_traps`] call must be active on the current thread. Control jumps
/// straight back to it through `unc_vm_unwind` instead of returning normally, so the
/// frames in between must tolerate being abandoned (presumably without their
/// destructors running, since the unwind is setjmp/longjmp-style).
///
/// # Panics
///
/// Panics if no [`catch_traps`] call is active on the current thread.
pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
    tls::with(|info| {
        info.expect("raise_user_trap called outside of catch_traps")
            .unwind_with(UnwindReason::UserTrap(data))
    })
}
/// Aborts the current guarded call with an already-built [`Trap`].
///
/// The innermost active [`catch_traps`] call on this thread returns `Err(trap)`
/// unchanged.
///
/// # Safety
///
/// A [`catch_traps`] call must be active on the current thread. Control jumps
/// straight back to it through `unc_vm_unwind` instead of returning normally, so the
/// frames in between must tolerate being abandoned (presumably without their
/// destructors running, since the unwind is setjmp/longjmp-style).
///
/// # Panics
///
/// Panics if no [`catch_traps`] call is active on the current thread.
pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
    tls::with(|info| {
        info.expect("raise_lib_trap called outside of catch_traps")
            .unwind_with(UnwindReason::LibTrap(trap))
    })
}
/// Carries a panic payload across the guarded call.
///
/// Unwinds to the innermost active [`catch_traps`] call on this thread, which then
/// re-raises `payload` with [`std::panic::resume_unwind`].
///
/// # Safety
///
/// A [`catch_traps`] call must be active on the current thread. Control jumps
/// straight back to it through `unc_vm_unwind` instead of returning normally, so the
/// frames in between must tolerate being abandoned (presumably without their
/// destructors running, since the unwind is setjmp/longjmp-style).
///
/// # Panics
///
/// Panics if no [`catch_traps`] call is active on the current thread.
pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
    tls::with(|info| {
        info.expect("resume_panic called outside of catch_traps")
            .unwind_with(UnwindReason::Panic(payload))
    })
}
/// The reason a call guarded by [`catch_traps`] (or one of its wrappers) did not
/// complete normally.
#[derive(Debug)]
pub enum Trap {
    /// An error supplied by host code, e.g. through [`raise_user_trap`].
    User(Box<dyn Error + Send + Sync>),
    /// A trap reported at a program counter. Within this file it is produced by the
    /// signal-less handler returned from [`get_trap_handler`]; also built by
    /// [`Trap::wasm`].
    Wasm {
        /// The program counter passed to the trap handler.
        pc: usize,
        /// Backtrace captured when the trap was reported (unresolved when produced by
        /// this file's handler).
        backtrace: Backtrace,
        /// The trap code, if one was supplied (always `Some` from this file's handler).
        signal_trap: Option<TrapCode>,
    },
    /// A trap identified by a [`TrapCode`]; built by [`Trap::lib`].
    Lib {
        /// The code identifying the trap.
        trap_code: TrapCode,
        /// Unresolved backtrace captured by [`Trap::lib`].
        backtrace: Backtrace,
    },
    /// Built by [`Trap::oom`]; by its name, signals an out-of-memory condition.
    OOM {
        /// Unresolved backtrace captured by [`Trap::oom`].
        backtrace: Backtrace,
    },
}
impl Trap {
pub fn wasm(pc: usize, backtrace: Backtrace, signal_trap: Option<TrapCode>) -> Self {
Self::Wasm { pc, backtrace, signal_trap }
}
pub fn lib(trap_code: TrapCode) -> Self {
let backtrace = Backtrace::new_unresolved();
Self::Lib { trap_code, backtrace }
}
pub fn oom() -> Self {
let backtrace = Backtrace::new_unresolved();
Self::OOM { backtrace }
}
}
/// Invokes `callee` through `trampoline`, catching any trap raised during the call.
///
/// `trampoline` is reinterpreted as
/// `extern "C" fn(VMFunctionEnvironment, *const VMFunctionBody, *mut u8)` and called
/// as `trampoline(callee_env, callee, values_vec)` inside [`catch_traps`].
///
/// # Safety
///
/// `trampoline` must really have that signature, and `callee_env`, `callee` and
/// `values_vec` must be valid arguments for it. The requirements of [`catch_traps`]
/// apply as well.
pub unsafe fn unc_vm_call_trampoline(
    callee_env: VMFunctionEnvironment,
    trampoline: VMTrampoline,
    callee: *const VMFunctionBody,
    values_vec: *mut u8,
) -> Result<(), Trap> {
    type RawTrampoline = extern "C" fn(VMFunctionEnvironment, *const VMFunctionBody, *mut u8);
    let invoke = mem::transmute::<VMTrampoline, RawTrampoline>(trampoline);
    catch_traps(|| invoke(callee_env, callee, values_vec))
}
/// Runs `closure` and returns `Ok(())` if it completes, or `Err` with the reported
/// [`Trap`] if a trap unwinds it.
///
/// While `closure` runs, a fresh `CallThreadState` is this thread's innermost active
/// state. As a result, [`raise_user_trap`], [`raise_lib_trap`], [`resume_panic`] and the
/// handler returned by [`get_trap_handler`] all unwind back to this call. A panic
/// payload delivered through [`resume_panic`] is re-raised with
/// [`std::panic::resume_unwind`] rather than returned.
///
/// # Safety
///
/// `closure` is called from an `extern "C"` callback, so it must not let a panic
/// escape. It can also be abandoned part-way through by `unc_vm_unwind`, so it must
/// not rely on destructors of its in-progress frames running when it traps.
pub unsafe fn catch_traps<F>(mut closure: F) -> Result<(), Trap>
where
    F: FnMut(),
{
    // `extern "C"` shim through which the native side invokes the Rust closure that
    // `payload` points at.
    extern "C" fn call_closure<F>(payload: *mut u8)
    where
        F: FnMut(),
    {
        // SAFETY: `payload` is the type-erased `&mut F` created below, and `closure`
        // stays alive for the entire `catch_traps` call.
        unsafe { (*(payload as *mut F))() }
    }

    let payload = &mut closure as *mut F as *mut u8;
    CallThreadState::new().with(|cx| {
        unc_vm_register_setjmp(cx.jmp_buf.as_ptr(), call_closure::<F>, payload)
    })
}
/// Like [`catch_traps`], but also hands back the value produced by `closure`.
///
/// # Safety
///
/// The same requirements as [`catch_traps`] apply.
pub unsafe fn catch_traps_with_result<F, R>(mut closure: F) -> Result<R, Trap>
where
    F: FnMut() -> R,
{
    let mut slot = MaybeUninit::<R>::uninit();
    let outcome = catch_traps(|| {
        slot.write(closure());
    });
    // `catch_traps` only reports `Ok` when the closure above ran to completion, so
    // `slot` has been written whenever `map` invokes this.
    outcome.map(|()| slot.assume_init())
}
/// Bookkeeping for one active [`catch_traps`] call, linked into a per-thread stack
/// of such states.
pub struct CallThreadState {
    /// Why the call was unwound. It is written just before `unc_vm_unwind` and read
    /// by `CallThreadState::with` only when the call did not complete normally.
    unwind: UnsafeCell<MaybeUninit<UnwindReason>>,
    /// Resumption handle filled in through `unc_vm_register_setjmp` and later passed
    /// to `unc_vm_unwind`; null until then.
    jmp_buf: Cell<*const u8>,
    /// The state that was current before this one; restored when this one is popped.
    prev: Cell<tls::Ptr>,
}
/// Why an active [`catch_traps`] call was unwound.
enum UnwindReason {
    /// A panic payload, re-raised with `std::panic::resume_unwind` (see [`resume_panic`]).
    Panic(Box<dyn Any + Send>),
    /// A user error, surfaced as [`Trap::User`].
    UserTrap(Box<dyn Error + Send + Sync>),
    /// A ready-made trap, returned unchanged.
    LibTrap(Trap),
    /// A trap reported by the signal-less trap handler, surfaced as [`Trap::Wasm`].
    WasmTrap { backtrace: Backtrace, pc: usize, signal_trap: Option<TrapCode> },
}
impl<'a> CallThreadState {
#[inline]
fn new() -> Self {
Self {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
prev: Cell::new(ptr::null()),
}
}
fn with(self, closure: impl FnOnce(&Self) -> i32) -> Result<(), Trap> {
let ret = tls::set(&self, || closure(&self))?;
if ret != 0 {
return Ok(());
}
match unsafe { (*self.unwind.get()).as_ptr().read() } {
UnwindReason::UserTrap(data) => Err(Trap::User(data)),
UnwindReason::LibTrap(trap) => Err(trap),
UnwindReason::WasmTrap { backtrace, pc, signal_trap } => {
Err(Trap::wasm(pc, backtrace, signal_trap))
}
UnwindReason::Panic(panic) => std::panic::resume_unwind(panic),
}
}
fn unwind_with(&self, reason: UnwindReason) -> ! {
unsafe {
(*self.unwind.get()).as_mut_ptr().write(reason);
unc_vm_unwind(self.jmp_buf.get());
}
}
}
/// Per-thread tracking of the innermost active [`CallThreadState`].
mod tls {
    use super::CallThreadState;
    use crate::Trap;
    use std::ptr;
    pub use raw::Ptr;

    /// Direct access to the thread-local slot.
    mod raw {
        use super::CallThreadState;
        use crate::Trap;
        use std::cell::Cell;
        use std::ptr;

        /// Pointer to the thread's innermost active state; null when none is active.
        pub type Ptr = *const CallThreadState;

        // Per-thread slot holding the innermost active state; starts out null.
        thread_local!(static PTR: Cell<Ptr> = const { Cell::new(ptr::null()) });

        /// Stores `val` in the slot and returns the value it replaced.
        ///
        /// This implementation never fails; the `Result` is part of the signature that
        /// callers use.
        // NOTE(review): `#[inline(never)]` on these accessors looks deliberate (it keeps
        // the thread-local access out of line); confirm the reason before removing it.
        #[inline(never)]
        pub fn replace(val: Ptr) -> Result<Ptr, Trap> {
            PTR.with(|slot| Ok(slot.replace(val)))
        }

        /// Returns the value currently stored in the slot.
        #[inline(never)]
        pub fn get() -> Ptr {
            PTR.with(|slot| slot.get())
        }
    }

    /// A `CallThreadState` detached from the current thread's stack by
    /// [`TlsRestore::take`], to be re-installed later with [`TlsRestore::replace`].
    pub struct TlsRestore(raw::Ptr);

    impl TlsRestore {
        /// Detaches the current state. The thread's current pointer reverts to that
        /// state's predecessor, and the state's own `prev` link is cleared.
        ///
        /// # Safety
        ///
        /// The current pointer must refer to a live `CallThreadState`.
        ///
        /// # Panics
        ///
        /// Panics if no state is currently active.
        pub unsafe fn take() -> Result<Self, Trap> {
            let raw = raw::get();
            assert!(!raw.is_null());
            let prev = (*raw).prev.replace(ptr::null());
            raw::replace(prev)?;
            Ok(Self(raw))
        }

        /// Re-installs the detached state on top of whatever state is current on
        /// this thread.
        ///
        /// # Safety
        ///
        /// The detached state must still be alive.
        ///
        /// # Panics
        ///
        /// Panics if the detached state already has a predecessor linked.
        pub unsafe fn replace(self) -> Result<(), super::Trap> {
            let prev = raw::get();
            assert!((*self.0).prev.get().is_null());
            (*self.0).prev.set(prev);
            raw::replace(self.0)?;
            Ok(())
        }
    }

    /// Pushes `state` as the thread's current state, runs `closure`, and then pops it.
    ///
    /// The previous pointer is saved in `state.prev` and restored by a drop guard, so
    /// it is also restored if `closure` panics.
    pub fn set<R>(state: &CallThreadState, closure: impl FnOnce() -> R) -> Result<R, Trap> {
        // Drop guard: restores the predecessor recorded in the state.
        struct Reset<'a>(&'a CallThreadState);

        impl Drop for Reset<'_> {
            #[inline]
            fn drop(&mut self) {
                raw::replace(self.0.prev.replace(ptr::null()))
                    .expect("tls should be previously initialized");
            }
        }

        // Plain reference-to-pointer cast. `CallThreadState` has no lifetime parameter
        // to erase, so no `transmute` is needed.
        let ptr = state as *const CallThreadState;
        let prev = raw::replace(ptr)?;
        state.prev.set(prev);
        let _reset = Reset(state);
        Ok(closure())
    }

    /// Runs `closure` with the thread's current state, or `None` when no state is
    /// active.
    pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
        let p = raw::get();
        // SAFETY: a non-null slot value was installed by `set` (or
        // `TlsRestore::replace`) from a state that is kept alive until it is popped.
        let state = unsafe { p.as_ref() };
        closure(state)
    }
}
extern "C" fn signal_less_trap_handler(pc: *const u8, trap: TrapCode) {
let jmp_buf = tls::with(|info| {
let backtrace = Backtrace::new_unresolved();
let info = info.unwrap();
unsafe {
(*info.unwind.get()).as_mut_ptr().write(UnwindReason::WasmTrap {
backtrace,
signal_trap: Some(trap),
pc: pc as usize,
});
info.jmp_buf.get()
}
});
unsafe {
unc_vm_unwind(jmp_buf);
}
}
pub fn get_trap_handler() -> *const u8 {
signal_less_trap_handler as *const u8
}