use crate::{VMContext, VMInterrupts};
use backtrace::Backtrace;
use std::any::Any;
use std::cell::{Cell, UnsafeCell};
use std::error::Error;
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Once;
use wasmtime_environ::ir;
pub use self::tls::{tls_eager_initialize, TlsRestore};
// Foreign helpers that implement the non-local control transfer used by this
// module. NOTE(review): their definitions are not in this file (presumably a
// C source compiled into the crate); the behavior described below is inferred
// from how this file calls them — confirm against the implementation.
extern "C" {
    /// Runs `callback(payload, callee)` with a jump target registered.
    ///
    /// `jmp_buf` points at the slot (`CallThreadState::jmp_buf`) that this
    /// file later reads, null-checks, and hands to `Unwind`; presumably the
    /// helper stores the address of its jump buffer there before calling
    /// `callback`. Callers in this file treat a nonzero return as normal
    /// completion of `callback` and `0` as "control came back via `Unwind`".
    // NOTE(review): the reason for this allow is not visible here; presumably
    // the `*mut VMContext` parameters trip the lint because `VMContext` is not
    // considered FFI-safe — confirm before removing.
    #[allow(improper_ctypes)]
    fn RegisterSetjmp(
        jmp_buf: *mut *const u8,
        callback: extern "C" fn(*mut u8, *mut VMContext),
        payload: *mut u8,
        callee: *mut VMContext,
    ) -> i32;
    /// Transfers control to the jump target identified by `jmp_buf` and never
    /// returns. Presumably this resumes the matching `RegisterSetjmp` call,
    /// making it return 0.
    fn Unwind(jmp_buf: *const u8) -> !;
}
// Select the platform-specific trap-handling backend and alias it as `sys`.
// This file relies on `sys::SignalHandler`, `sys::platform_init`, and
// `sys::lazy_per_thread_init` from whichever backend is chosen.
// NOTE(review): no branch matches targets that are neither macOS, another
// unix, nor windows, so `sys` is undefined (and the crate fails to build)
// there.
cfg_if::cfg_if! {
    if #[cfg(target_os = "macos")] {
        mod macos;
        use macos as sys;
    } else if #[cfg(unix)] {
        mod unix;
        use unix as sys;
    } else if #[cfg(target_os = "windows")] {
        mod windows;
        use windows as sys;
    }
}
pub use sys::SignalHandler;
/// Predicate that `CallThreadState::jmp_buf_if_trap` uses to decide whether a
/// faulting program counter lies in wasm code. It starts as "never" and is
/// overwritten once by `init_traps`.
static mut IS_WASM_PC: fn(usize) -> bool = |_| false;
/// Performs the one-time, process-wide setup for trap handling.
///
/// Stores `is_wasm_pc` as the predicate consulted by
/// `CallThreadState::jmp_buf_if_trap`, then runs the platform backend's
/// `platform_init`. Only the first call does anything; later calls, and the
/// predicates they pass, are ignored.
pub fn init_traps(is_wasm_pc: fn(usize) -> bool) {
    static INIT: Once = Once::new();

    INIT.call_once(|| {
        // SAFETY: `Once` runs this closure at most once and makes concurrent
        // `call_once` callers wait for it, so this is the only write to
        // `IS_WASM_PC` in this file and it cannot race another `init_traps`.
        unsafe {
            IS_WASM_PC = is_wasm_pc;
            sys::platform_init();
        }
    });
}
/// Aborts the active `catch_traps` call on this thread with a host error.
///
/// The error is stored in the current `CallThreadState`, and control jumps
/// back to `catch_traps`, which returns it as `Trap::User`.
///
/// # Safety
///
/// A `catch_traps` call must be active on this thread; otherwise this panics.
/// NOTE(review): the jump goes through the foreign `Unwind` helper, which
/// presumably does not run destructors of the frames it skips — confirm.
pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
    let reason = UnwindReason::UserTrap(data);
    tls::with(move |state| {
        let state = state.unwrap();
        state.unwind_with(reason)
    })
}
/// Aborts the active `catch_traps` call on this thread with a runtime trap.
///
/// The trap is stored in the current `CallThreadState`, and control jumps
/// back to `catch_traps`, which returns it unchanged as the error.
///
/// # Safety
///
/// A `catch_traps` call must be active on this thread; otherwise this panics.
/// NOTE(review): the jump goes through the foreign `Unwind` helper, which
/// presumably does not run destructors of the frames it skips — confirm.
pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
    let reason = UnwindReason::LibTrap(trap);
    tls::with(move |state| {
        let state = state.unwrap();
        state.unwind_with(reason)
    })
}
/// Carries a caught Rust panic out of the active `catch_traps` call.
///
/// The payload is stored in the current `CallThreadState`, and control jumps
/// back to `catch_traps`, which re-raises it with
/// `std::panic::resume_unwind`.
///
/// # Safety
///
/// A `catch_traps` call must be active on this thread; otherwise this panics.
/// NOTE(review): the jump goes through the foreign `Unwind` helper, which
/// presumably does not run destructors of the frames it skips — confirm.
pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
    let reason = UnwindReason::Panic(payload);
    tls::with(move |state| {
        let state = state.unwrap();
        state.unwind_with(reason)
    })
}
/// Error returned by `catch_traps` when the wrapped call does not complete
/// normally.
#[derive(Debug)]
pub enum Trap {
    /// A host-supplied error raised through `raise_user_trap`.
    User(Box<dyn Error + Send + Sync>),
    /// A trap recorded by `CallThreadState::capture_backtrace`.
    /// NOTE(review): presumably reported by the platform signal handler for a
    /// fault in JIT-compiled code — confirm against the `sys` backends.
    Jit {
        /// Program counter passed to `capture_backtrace`.
        pc: usize,
        /// Unresolved backtrace captured when the trap was recorded.
        backtrace: Backtrace,
        /// Whether `VMInterrupts::stack_limit` equalled
        /// `wasmtime_environ::INTERRUPTED` when `catch_traps` converted the
        /// unwind. Presumably this means the trap may stem from an interrupt
        /// request rather than a genuine fault.
        maybe_interrupted: bool,
    },
    /// A trap with a specific wasm trap code, built by `Trap::wasm`.
    Wasm {
        /// The wasm trap code describing the failure.
        trap_code: ir::TrapCode,
        /// Unresolved backtrace captured in `Trap::wasm`.
        backtrace: Backtrace,
    },
    /// An out-of-memory condition, built by `Trap::oom`.
    OOM {
        /// Unresolved backtrace captured in `Trap::oom`.
        backtrace: Backtrace,
    },
}
impl Trap {
pub fn wasm(trap_code: ir::TrapCode) -> Self {
let backtrace = Backtrace::new_unresolved();
Trap::Wasm {
trap_code,
backtrace,
}
}
pub fn oom() -> Self {
let backtrace = Backtrace::new_unresolved();
Trap::OOM { backtrace }
}
}
/// Runs `closure(callee)` with trap handling enabled on the current thread.
///
/// A fresh `CallThreadState` (using `signal_handler` for custom handling) is
/// installed for the duration of the call. The closure is invoked through the
/// foreign `RegisterSetjmp` helper. If it returns normally, `Ok(())` is
/// returned. If control instead comes back via `Unwind`, the recorded reason
/// is converted:
/// - `raise_user_trap` becomes `Trap::User`;
/// - `raise_lib_trap` is returned as-is;
/// - a trap recorded by `capture_backtrace` becomes `Trap::Jit`;
/// - `resume_panic` re-raises the panic with `std::panic::resume_unwind`.
///
/// # Errors
///
/// Returns the trap described above. It also returns an error if this thread
/// still needs per-thread trap initialization and that initialization fails;
/// in that case `closure` is not run.
///
/// # Safety
///
/// `vminterrupts` must be valid to read when a JIT trap is reported, because
/// it is dereferenced to compute `maybe_interrupted`. `signal_handler`, if
/// present, must stay valid for the whole call. NOTE(review): `closure` may be
/// exited by a non-local jump (`Unwind`), which presumably skips destructors
/// in the skipped frames — confirm.
pub unsafe fn catch_traps<F>(
    vminterrupts: *mut VMInterrupts,
    signal_handler: Option<*const SignalHandler<'static>>,
    callee: *mut VMContext,
    mut closure: F,
) -> Result<(), Trap>
where
    F: FnMut(*mut VMContext),
{
    // C-ABI trampoline handed to `RegisterSetjmp`: recovers the concrete
    // closure type from the type-erased `payload` and invokes it.
    extern "C" fn call_closure<F>(payload: *mut u8, callee: *mut VMContext)
    where
        F: FnMut(*mut VMContext),
    {
        // SAFETY: `payload` is the `&mut closure` erased below, with the same
        // `F`, and it outlives the `RegisterSetjmp` call that invokes us.
        unsafe { (*(payload as *mut F))(callee) }
    }

    CallThreadState::new(signal_handler).with(vminterrupts, |cx| {
        // Register `cx.jmp_buf` as the jump slot and run the closure behind
        // the trampoline; `with` interprets the returned status.
        RegisterSetjmp(
            cx.jmp_buf.as_ptr(),
            call_closure::<F>,
            &mut closure as *mut F as *mut u8,
            callee,
        )
    })
}
/// Per-call trap-handling state, installed as this thread's current state
/// for the duration of a `catch_traps` call.
pub struct CallThreadState {
    /// Why the call unwound. Written by `unwind_with` or `capture_backtrace`,
    /// and moved out by `with` only when the registered call returns 0;
    /// otherwise left uninitialized.
    unwind: UnsafeCell<MaybeUninit<UnwindReason>>,
    /// Jump target slot passed to `RegisterSetjmp`; null until registered.
    jmp_buf: Cell<*const u8>,
    /// True while `jmp_buf_if_trap` is running, so a fault during trap
    /// handling is not handled recursively.
    handling_trap: Cell<bool>,
    /// Optional custom handler consulted first by `jmp_buf_if_trap`.
    signal_handler: Option<*const SignalHandler<'static>>,
    /// The state that was current before this one was installed. This links
    /// nested calls into a stack, and the entry is restored when this state
    /// is removed.
    prev: Cell<tls::Ptr>,
}
/// Why a `catch_traps` call was left through `Unwind` instead of returning
/// normally.
enum UnwindReason {
    /// A Rust panic payload from `resume_panic`; re-raised by `with`.
    Panic(Box<dyn Any + Send>),
    /// A host error from `raise_user_trap`; becomes `Trap::User`.
    UserTrap(Box<dyn Error + Send + Sync>),
    /// A runtime trap from `raise_lib_trap`; returned unchanged.
    LibTrap(Trap),
    /// A trap recorded by `capture_backtrace`; becomes `Trap::Jit`.
    JitTrap { backtrace: Backtrace, pc: usize },
}
impl CallThreadState {
    /// Creates state with no jump target, no recorded unwind reason, no
    /// previous entry, and `signal_handler` for custom trap handling.
    #[inline]
    fn new(signal_handler: Option<*const SignalHandler<'static>>) -> CallThreadState {
        CallThreadState {
            unwind: UnsafeCell::new(MaybeUninit::uninit()),
            jmp_buf: Cell::new(ptr::null()),
            handling_trap: Cell::new(false),
            signal_handler,
            prev: Cell::new(ptr::null()),
        }
    }
    /// Installs `self` as this thread's current state, runs `closure`, and
    /// translates the status it returns.
    ///
    /// A nonzero status means normal completion and yields `Ok(())`. A status
    /// of 0 means the call unwound; the reason in `unwind` is then moved out
    /// and converted:
    /// - user and library traps become `Err`;
    /// - JIT traps become `Trap::Jit`, with `interrupts` consulted to set
    ///   `maybe_interrupted`;
    /// - panics are resumed.
    ///
    /// Errors from installing the state (per-thread initialization) are
    /// propagated before `closure` runs.
    fn with(
        self,
        interrupts: *mut VMInterrupts,
        closure: impl FnOnce(&CallThreadState) -> i32,
    ) -> Result<(), Trap> {
        let ret = tls::set(&self, || closure(&self))?;
        if ret != 0 {
            return Ok(());
        }
        // A 0 status is expected only after `unwind` was written (by
        // `unwind_with` or `capture_backtrace`); the read moves that value
        // out. NOTE(review): relies on `RegisterSetjmp` returning 0 only in
        // that case — confirm against the foreign helper.
        match unsafe { (*self.unwind.get()).as_ptr().read() } {
            UnwindReason::UserTrap(data) => Err(Trap::User(data)),
            UnwindReason::LibTrap(trap) => Err(trap),
            UnwindReason::JitTrap { backtrace, pc } => {
                // `interrupts` is dereferenced only on this path.
                let maybe_interrupted = unsafe {
                    (*interrupts).stack_limit.load(SeqCst) == wasmtime_environ::INTERRUPTED
                };
                Err(Trap::Jit {
                    pc,
                    backtrace,
                    maybe_interrupted,
                })
            }
            UnwindReason::Panic(panic) => std::panic::resume_unwind(panic),
        }
    }
    /// Records `reason` in `unwind`, then jumps to the registered target via
    /// `Unwind`; never returns.
    ///
    /// NOTE(review): if no target has been registered yet (`jmp_buf` null),
    /// null is passed to `Unwind`; callers must only use this inside the
    /// running closure.
    fn unwind_with(&self, reason: UnwindReason) -> ! {
        unsafe {
            (*self.unwind.get()).as_mut_ptr().write(reason);
            Unwind(self.jmp_buf.get());
        }
    }
    /// Decides how a fault at `pc` should be handled for this call.
    ///
    /// Returns:
    /// - null if a fault is already being handled (recursive fault), if no
    ///   jump target is registered yet, or if `IS_WASM_PC` says `pc` is not
    ///   wasm code;
    /// - the sentinel `1 as *const u8` if `call_handler` reports that the
    ///   custom signal handler handled the fault. This is a non-null value
    ///   that is not a real jump target. NOTE(review): presumably the `sys`
    ///   backends compare against this exact value — keep them in sync;
    /// - otherwise the registered jump target.
    ///
    /// The checks run in this order, and `handling_trap` is cleared again on
    /// every exit path by the `ResetCell` guard.
    // Unused on macOS, hence the cfg'd `dead_code` allowance.
    #[cfg_attr(target_os = "macos", allow(dead_code))] fn jmp_buf_if_trap(
        &self,
        pc: *const u8,
        call_handler: impl Fn(&SignalHandler) -> bool,
    ) -> *const u8 {
        // A fault while already handling one: decline and let the system deal
        // with it.
        if self.handling_trap.replace(true) {
            return ptr::null();
        }
        let _reset = ResetCell(&self.handling_trap, false);
        // No jump target registered yet, so there is nowhere to unwind to.
        if self.jmp_buf.get().is_null() {
            return ptr::null();
        }
        // The custom handler gets the first chance to handle the fault.
        if let Some(handler) = self.signal_handler {
            if unsafe { call_handler(&*handler) } {
                return 1 as *const _;
            }
        }
        // Faults outside wasm code are not ours to handle.
        if unsafe { !IS_WASM_PC(pc as usize) } {
            return ptr::null();
        }
        self.jmp_buf.get()
    }
    /// Records a `JitTrap` reason for `pc` with a freshly captured unresolved
    /// backtrace, without unwinding.
    /// NOTE(review): presumably the caller then jumps to the target returned
    /// by `jmp_buf_if_trap` — confirm in the `sys` backends.
    fn capture_backtrace(&self, pc: *const u8) {
        let backtrace = Backtrace::new_unresolved();
        unsafe {
            (*self.unwind.get())
                .as_mut_ptr()
                .write(UnwindReason::JitTrap {
                    backtrace,
                    pc: pc as usize,
                });
        }
    }
}
/// Scope guard that writes a fixed value back into a `Cell` when dropped.
///
/// Field `.0` is the cell to reset and field `.1` is the value restored. It is
/// used by `CallThreadState::jmp_buf_if_trap` to clear `handling_trap` on
/// every exit path.
struct ResetCell<'a, T: Copy>(&'a Cell<T>, T);

impl<T: Copy> Drop for ResetCell<'_, T> {
    #[inline]
    fn drop(&mut self) {
        let (cell, restore_to) = (self.0, self.1);
        Cell::set(cell, restore_to);
    }
}
/// Thread-local tracking of the innermost active `CallThreadState`.
///
/// Nested `catch_traps` calls form a stack linked through
/// `CallThreadState::prev`; the thread-local slot holds only the top entry.
mod tls {
    use super::CallThreadState;
    use crate::Trap;
    use std::ptr;
    pub use raw::Ptr;
    // Low-level accessors for the thread-local slot.
    // NOTE(review): the `#[inline(never)]` on these accessors is presumably
    // deliberate (e.g. so a TLS address is not cached across a stack switch);
    // keep it unless confirmed unnecessary.
    mod raw {
        use super::CallThreadState;
        use crate::Trap;
        use std::cell::Cell;
        use std::ptr;
        /// Pointer to the innermost `CallThreadState` (null when none).
        pub type Ptr = *const CallThreadState;
        // (current state pointer, whether `lazy_per_thread_init` has already
        // succeeded on this thread)
        thread_local!(static PTR: Cell<(Ptr, bool)> = Cell::new((ptr::null(), false)));
        /// Stores `val` as the current pointer and returns the previous one.
        ///
        /// Runs `lazy_per_thread_init` first if this thread has not been
        /// initialized yet. On failure the error is returned and the stored
        /// pointer is left unchanged.
        #[inline(never)] pub fn replace(val: Ptr) -> Result<Ptr, Trap> {
            PTR.with(|p| {
                let (prev, mut initialized) = p.get();
                if !initialized {
                    super::super::sys::lazy_per_thread_init()?;
                    initialized = true;
                }
                p.set((val, initialized));
                Ok(prev)
            })
        }
        /// Runs per-thread initialization if it has not succeeded yet; a
        /// no-op afterwards. Re-exported as `tls_eager_initialize`.
        #[inline(never)]
        pub fn initialize() -> Result<(), Trap> {
            PTR.with(|p| {
                let (state, initialized) = p.get();
                if initialized {
                    return Ok(());
                }
                super::super::sys::lazy_per_thread_init()?;
                p.set((state, true));
                Ok(())
            })
        }
        /// Returns the current pointer (null if no state is installed).
        #[inline(never)] pub fn get() -> Ptr {
            PTR.with(|p| p.get().0)
        }
    }
    pub use raw::initialize as tls_eager_initialize;
    /// A state detached from this thread's stack by `take`, to be reinstalled
    /// later with `replace`.
    pub struct TlsRestore(raw::Ptr);
    impl TlsRestore {
        /// Pops the current state off this thread's stack: its `prev` becomes
        /// current again and its own `prev` link is cleared.
        ///
        /// # Safety
        ///
        /// A state must be installed (asserted; panics otherwise). The
        /// returned value holds a raw pointer to it, which must remain valid
        /// until `replace` is called.
        pub unsafe fn take() -> Result<TlsRestore, Trap> {
            let raw = raw::get();
            assert!(!raw.is_null());
            let prev = (*raw).prev.replace(ptr::null());
            raw::replace(prev)?;
            Ok(TlsRestore(raw))
        }
        /// Pushes the saved state back on top of whatever is current on this
        /// thread.
        ///
        /// Panics if the saved state's `prev` link is not null, i.e. it was
        /// not detached by `take`.
        ///
        /// # Safety
        ///
        /// The saved pointer must still refer to a live `CallThreadState`.
        // NOTE(review): returns `super::Trap` while the rest of this module
        // uses `Trap` (imported from `crate`); presumably the same type —
        // consider unifying.
        pub unsafe fn replace(self) -> Result<(), super::Trap> {
            let prev = raw::get();
            assert!((*self.0).prev.get().is_null());
            (*self.0).prev.set(prev);
            raw::replace(self.0)?;
            Ok(())
        }
    }
    /// Installs `state` as current while `closure` runs, then restores the
    /// previous pointer.
    ///
    /// The restore happens in a `Drop` guard, so it also runs if `closure`
    /// panics. Returns an error, without running `closure`, only if
    /// per-thread initialization fails.
    #[inline]
    pub fn set<R>(state: &CallThreadState, closure: impl FnOnce() -> R) -> Result<R, Trap> {
        // Restores `prev` into the thread-local slot when dropped.
        struct Reset<'a>(&'a CallThreadState);
        impl Drop for Reset<'_> {
            #[inline]
            fn drop(&mut self) {
                // Cannot fail: the `raw::replace` in `set` already ran
                // per-thread initialization on this thread.
                raw::replace(self.0.prev.replace(ptr::null()))
                    .expect("tls should be previously initialized");
            }
        }
        let prev = raw::replace(state)?;
        state.prev.set(prev);
        let _reset = Reset(state);
        Ok(closure())
    }
    /// Calls `closure` with the current state, or `None` if none is installed.
    pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
        let p = raw::get();
        unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
    }
}