use crate::vmcontext::{VMFunctionContext, VMTrampoline};
use crate::{Trap, VMFunctionBody};
use backtrace::Backtrace;
use core::ptr::{read, read_unaligned};
use corosensei::stack::DefaultStack;
use corosensei::trap::{CoroutineTrapHandler, TrapHandlerRegs};
use corosensei::{CoroutineResult, ScopedCoroutine, Yielder};
use scopeguard::defer;
use std::any::Any;
use std::cell::Cell;
use std::error::Error;
use std::io;
use std::mem;
#[cfg(unix)]
use std::mem::MaybeUninit;
use std::ptr::{self, NonNull};
use std::sync::atomic::{compiler_fence, AtomicPtr, Ordering};
use std::sync::{Mutex, Once};
use wasmer_types::TrapCode;
/// Marker bits that must all be set in the payload byte of an illegal
/// instruction for `process_illegal_op` to treat it as an encoded trap; the
/// low nibble of that byte then selects the `TrapCode`.
static MAGIC: u8 = 0xc0;
// Signature of an optional embedder-supplied trap handler. It is consulted
// before the built-in handling (see `TrapHandlerContext::handle_trap`); a
// return value of `true` means the trap was fully handled.
cfg_if::cfg_if! {
if #[cfg(unix)] {
// Unix: receives the raw signal number, `siginfo_t` and `ucontext_t` pointer
// exactly as passed to the signal handler.
pub type TrapHandlerFn<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + Send + Sync + 'a;
} else if #[cfg(target_os = "windows")] {
// Windows: receives the exception pointers passed to the vectored handler.
pub type TrapHandlerFn<'a> = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool + Send + Sync + 'a;
}
}
/// Decodes the trap code embedded in an illegal instruction at `addr`.
///
/// * x86_64: recognizes `0F B9 <payload>` (UD1), optionally preceded by a
///   REX prefix byte (`0x4?`), and takes the byte after the opcode.
/// * aarch64: recognizes a 32-bit word whose upper 16 bits are zero and takes
///   its lowest byte.
///
/// The payload byte must contain all `MAGIC` bits; its low nibble then maps to
/// a `TrapCode`. Returns `None` for anything else (including other
/// architectures).
///
/// # Safety
///
/// `addr` must point to readable memory covering the inspected bytes.
unsafe fn process_illegal_op(addr: usize) -> Option<TrapCode> {
    // Reads the instruction byte at `addr + offset`.
    let byte_at = |offset: usize| read((addr + offset) as *mut u8);

    let mut payload: Option<u8> = None;
    if cfg!(target_arch = "x86_64") {
        let rex_ud1 = byte_at(0) & 0xf0 == 0x40 && byte_at(1) == 0x0f && byte_at(2) == 0xb9;
        payload = if rex_ud1 {
            Some(byte_at(3))
        } else if byte_at(0) == 0x0f && byte_at(1) == 0xb9 {
            Some(byte_at(2))
        } else {
            None
        };
    }
    if cfg!(target_arch = "aarch64") {
        let word = read_unaligned(addr as *mut u32);
        payload = if word & 0xffff0000 == 0 {
            Some(byte_at(0))
        } else {
            None
        };
    }

    let encoded = payload.filter(|&byte| byte & MAGIC == MAGIC)?;
    let code = match encoded & 0xf {
        0 => TrapCode::StackOverflow,
        1 => TrapCode::HeapAccessOutOfBounds,
        2 => TrapCode::HeapMisaligned,
        3 => TrapCode::TableAccessOutOfBounds,
        4 => TrapCode::IndirectCallToNull,
        5 => TrapCode::BadSignature,
        6 => TrapCode::IntegerOverflow,
        7 => TrapCode::IntegerDivisionByZero,
        8 => TrapCode::BadConversionToInteger,
        9 => TrapCode::UnreachableCodeReached,
        10 => TrapCode::UnalignedAtomic,
        _ => return None,
    };
    Some(code)
}
/// Types that can supply a custom trap handler of type [`TrapHandlerFn`].
///
/// NOTE(review): this trait is not invoked anywhere in this file. Presumably
/// `custom_trap_handler` should pass its handler to `call` and return the
/// result, with `true` meaning "trap handled" (the meaning `TrapHandlerFn`
/// results have in `TrapHandlerContext::handle_trap`). Confirm against
/// implementors.
///
/// # Safety
///
/// NOTE(review): the unsafe contract is not spelled out here. Custom handlers
/// are invoked from inside the signal / vectored-exception handler, so
/// implementations presumably must only do async-signal-safe work. Confirm.
pub unsafe trait TrapHandler {
fn custom_trap_handler(&self, call: &dyn Fn(&TrapHandlerFn) -> bool) -> bool;
}
cfg_if::cfg_if! {
if #[cfg(unix)] {
// Signal actions that were installed before `platform_init` registered
// `trap_handler`. `sigaction` fills them in, and `trap_handler` forwards
// signals it does not handle to them.
// PREV_SIGFPE is only written on x86/x86_64, and PREV_SIGBUS only on arm/Apple
// targets; those are the only cases where the matching signal is registered.
static mut PREV_SIGSEGV: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGBUS: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGILL: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
static mut PREV_SIGFPE: MaybeUninit<libc::sigaction> = MaybeUninit::uninit();
/// Installs `trap_handler` for the synchronous fault signals that can be
/// turned into Wasm traps. Each previous action is saved in the matching
/// `PREV_*` slot so that unhandled signals can be forwarded.
///
/// Panics if any `sigaction` call fails. On Apple targets it also resets the
/// task's Mach exception ports for bad-access / arithmetic / bad-instruction
/// exceptions (see the note below).
unsafe fn platform_init() {
// Registers `trap_handler` for `signal`, storing the old action in `slot`.
// SA_SIGINFO: use the three-argument handler form (siginfo + ucontext).
// SA_NODEFER: do not block `signal` while the handler runs.
// SA_ONSTACK: run on the alternate signal stack configured by
// `lazy_per_thread_init`, so that a Wasm stack overflow can still be handled.
let register = |slot: &mut MaybeUninit<libc::sigaction>, signal: i32| {
let mut handler: libc::sigaction = mem::zeroed();
handler.sa_flags = libc::SA_SIGINFO | libc::SA_NODEFER | libc::SA_ONSTACK;
handler.sa_sigaction = trap_handler as usize;
libc::sigemptyset(&mut handler.sa_mask);
if libc::sigaction(signal, &handler, slot.as_mut_ptr()) != 0 {
panic!(
"unable to install signal handler: {}",
io::Error::last_os_error(),
);
}
};
register(&mut PREV_SIGSEGV, libc::SIGSEGV);
register(&mut PREV_SIGILL, libc::SIGILL);
// Integer division traps arrive as SIGFPE on x86.
if cfg!(target_arch = "x86") || cfg!(target_arch = "x86_64") {
register(&mut PREV_SIGFPE, libc::SIGFPE);
}
// On these targets some memory faults are reported as SIGBUS.
if cfg!(target_arch = "arm") || cfg!(target_vendor = "apple") {
register(&mut PREV_SIGBUS, libc::SIGBUS);
}
// NOTE(review): setting these exception ports to MACH_PORT_NULL presumably
// makes the kernel deliver the faults as Unix signals instead of to a Mach
// exception handler (e.g. a debugger or crash reporter); confirm. The
// returned kern_return_t is ignored.
#[cfg(target_vendor = "apple")]
{
use mach::exception_types::*;
use mach::kern_return::*;
use mach::port::*;
use mach::thread_status::*;
use mach::traps::*;
use mach::mach_types::*;
extern "C" {
fn task_set_exception_ports(
task: task_t,
exception_mask: exception_mask_t,
new_port: mach_port_t,
behavior: exception_behavior_t,
new_flavor: thread_state_flavor_t,
) -> kern_return_t;
}
#[allow(non_snake_case)]
#[cfg(target_arch = "x86_64")]
let MACHINE_THREAD_STATE = x86_THREAD_STATE64;
// NOTE(review): 6 is presumably ARM_THREAD_STATE64; confirm.
#[allow(non_snake_case)]
#[cfg(target_arch = "aarch64")]
let MACHINE_THREAD_STATE = 6;
task_set_exception_ports(
mach_task_self(),
EXC_MASK_BAD_ACCESS | EXC_MASK_ARITHMETIC | EXC_MASK_BAD_INSTRUCTION,
MACH_PORT_NULL,
EXCEPTION_STATE_IDENTITY as exception_behavior_t,
MACHINE_THREAD_STATE,
);
}
}
/// Signal handler installed by `platform_init` for SIGSEGV/SIGILL (and
/// SIGFPE/SIGBUS on some targets).
///
/// It asks the active `TrapHandlerContext` to handle the fault. When that
/// succeeds, the saved `ucontext` has been rewritten to resume in the trap
/// unwinding path and the handler simply returns. Otherwise the signal is
/// forwarded to the action that was installed before ours:
/// * an `SA_SIGINFO` handler is called with the original arguments;
/// * `SIG_DFL` / `SIG_IGN` are reinstalled and the handler returns, so that
///   re-executing the faulting instruction raises the signal under that
///   action instead of re-entering this handler forever;
/// * any other handler is called with just the signal number.
unsafe extern "C" fn trap_handler(
    signum: libc::c_int,
    siginfo: *mut libc::siginfo_t,
    context: *mut libc::c_void,
) {
    let previous = match signum {
        libc::SIGSEGV => &PREV_SIGSEGV,
        libc::SIGBUS => &PREV_SIGBUS,
        libc::SIGFPE => &PREV_SIGFPE,
        libc::SIGILL => &PREV_SIGILL,
        _ => panic!("unknown signal: {}", signum),
    };
    // Memory faults carry the faulting data address. It is used to tell a
    // stack overflow apart from an out-of-bounds heap access.
    let maybe_fault_address = match signum {
        libc::SIGSEGV | libc::SIGBUS => Some((*siginfo).si_addr() as usize),
        _ => None,
    };
    // For SIGILL, `si_addr` is the faulting instruction; decode any trap code
    // embedded in it.
    let trap_code = match signum {
        libc::SIGILL => process_illegal_op((*siginfo).si_addr() as usize),
        _ => None,
    };
    let ucontext = &mut *(context as *mut libc::ucontext_t);
    let (pc, sp) = get_pc_sp(ucontext);
    let handled = TrapHandlerContext::handle_trap(
        pc,
        sp,
        maybe_fault_address,
        trap_code,
        |regs| update_context(ucontext, regs),
        |handler| handler(signum, siginfo, context),
    );
    if handled {
        return;
    }
    let previous = &*previous.as_ptr();
    if previous.sa_flags & libc::SA_SIGINFO != 0 {
        mem::transmute::<
            usize,
            extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut libc::c_void),
        >(previous.sa_sigaction)(signum, siginfo, context)
    } else if previous.sa_sigaction == libc::SIG_DFL || previous.sa_sigaction == libc::SIG_IGN {
        // Restore the previous action. Returning re-executes the faulting
        // instruction, which then faults under that action rather than ours.
        libc::sigaction(signum, previous, ptr::null_mut());
    } else {
        mem::transmute::<usize, extern "C" fn(libc::c_int)>(previous.sa_sigaction)(signum)
    }
}
/// Extracts the program counter and stack pointer of the interrupted thread
/// from the signal handler's `ucontext_t`.
///
/// Compilation fails on platforms without a branch below. The FreeBSD branch
/// is limited to x86_64: it reads the 64-bit `mc_rip`/`mc_rsp` fields, and
/// `update_context` also supports only FreeBSD x86_64.
unsafe fn get_pc_sp(context: &libc::ucontext_t) -> (usize, usize) {
    let (pc, sp);
    cfg_if::cfg_if! {
        if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "x86_64",
        ))] {
            pc = context.uc_mcontext.gregs[libc::REG_RIP as usize] as usize;
            sp = context.uc_mcontext.gregs[libc::REG_RSP as usize] as usize;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "x86",
        ))] {
            pc = context.uc_mcontext.gregs[libc::REG_EIP as usize] as usize;
            sp = context.uc_mcontext.gregs[libc::REG_ESP as usize] as usize;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
            pc = context.uc_mcontext.mc_rip as usize;
            sp = context.uc_mcontext.mc_rsp as usize;
        } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
            pc = (*context.uc_mcontext).__ss.__rip as usize;
            sp = (*context.uc_mcontext).__ss.__rsp as usize;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "aarch64",
        ))] {
            pc = context.uc_mcontext.pc as usize;
            sp = context.uc_mcontext.sp as usize;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "arm",
        ))] {
            pc = context.uc_mcontext.arm_pc as usize;
            sp = context.uc_mcontext.arm_sp as usize;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            any(target_arch = "riscv64", target_arch = "riscv32"),
        ))] {
            pc = context.uc_mcontext.__gregs[libc::REG_PC] as usize;
            sp = context.uc_mcontext.__gregs[libc::REG_SP] as usize;
        } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
            pc = (*context.uc_mcontext).__ss.__pc as usize;
            sp = (*context.uc_mcontext).__ss.__sp as usize;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
            // On FreeBSD arm64 the PC is held in the exception link register.
            pc = context.uc_mcontext.mc_gpregs.gp_elr as usize;
            sp = context.uc_mcontext.mc_gpregs.gp_sp as usize;
        } else {
            compile_error!("Unsupported platform");
        }
    };
    (pc, sp)
}
/// Writes the registers chosen by the coroutine trap handler back into the
/// interrupted thread's saved `ucontext_t`. Returning from the signal handler
/// then resumes at the trap-unwinding entry point instead of the faulting
/// instruction.
///
/// Every branch destructures `TrapHandlerRegs` exhaustively. If the set of
/// registers changes, the code fails to compile instead of silently leaving a
/// register unwritten.
unsafe fn update_context(context: &mut libc::ucontext_t, regs: TrapHandlerRegs) {
    cfg_if::cfg_if! {
        if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "x86_64",
        ))] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            context.uc_mcontext.gregs[libc::REG_RIP as usize] = rip as i64;
            context.uc_mcontext.gregs[libc::REG_RSP as usize] = rsp as i64;
            context.uc_mcontext.gregs[libc::REG_RBP as usize] = rbp as i64;
            context.uc_mcontext.gregs[libc::REG_RDI as usize] = rdi as i64;
            context.uc_mcontext.gregs[libc::REG_RSI as usize] = rsi as i64;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "x86",
        ))] {
            let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
            context.uc_mcontext.gregs[libc::REG_EIP as usize] = eip as i32;
            context.uc_mcontext.gregs[libc::REG_ESP as usize] = esp as i32;
            context.uc_mcontext.gregs[libc::REG_EBP as usize] = ebp as i32;
            context.uc_mcontext.gregs[libc::REG_ECX as usize] = ecx as i32;
            context.uc_mcontext.gregs[libc::REG_EDX as usize] = edx as i32;
        } else if #[cfg(all(target_vendor = "apple", target_arch = "x86_64"))] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            (*context.uc_mcontext).__ss.__rip = rip;
            (*context.uc_mcontext).__ss.__rsp = rsp;
            (*context.uc_mcontext).__ss.__rbp = rbp;
            (*context.uc_mcontext).__ss.__rdi = rdi;
            (*context.uc_mcontext).__ss.__rsi = rsi;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "x86_64"))] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            context.uc_mcontext.mc_rip = rip as libc::register_t;
            context.uc_mcontext.mc_rsp = rsp as libc::register_t;
            context.uc_mcontext.mc_rbp = rbp as libc::register_t;
            context.uc_mcontext.mc_rdi = rdi as libc::register_t;
            context.uc_mcontext.mc_rsi = rsi as libc::register_t;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "aarch64",
        ))] {
            let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
            context.uc_mcontext.pc = pc;
            context.uc_mcontext.sp = sp;
            context.uc_mcontext.regs[0] = x0;
            context.uc_mcontext.regs[1] = x1;
            context.uc_mcontext.regs[29] = x29;
            context.uc_mcontext.regs[30] = lr;
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            target_arch = "arm",
        ))] {
            let TrapHandlerRegs {
                pc,
                r0,
                r1,
                r7,
                r11,
                r13,
                r14,
                cpsr_thumb,
                cpsr_endian,
            } = regs;
            context.uc_mcontext.arm_pc = pc;
            context.uc_mcontext.arm_r0 = r0;
            context.uc_mcontext.arm_r1 = r1;
            context.uc_mcontext.arm_r7 = r7;
            context.uc_mcontext.arm_fp = r11;
            context.uc_mcontext.arm_sp = r13;
            context.uc_mcontext.arm_lr = r14;
            // CPSR bit 5 (0x20) is the Thumb state bit; bit 9 (0x200) is the
            // data endianness bit.
            if cpsr_thumb {
                context.uc_mcontext.arm_cpsr |= 0x20;
            } else {
                context.uc_mcontext.arm_cpsr &= !0x20;
            }
            if cpsr_endian {
                context.uc_mcontext.arm_cpsr |= 0x200;
            } else {
                context.uc_mcontext.arm_cpsr &= !0x200;
            }
        } else if #[cfg(all(
            any(target_os = "linux", target_os = "android"),
            any(target_arch = "riscv64", target_arch = "riscv32"),
        ))] {
            let TrapHandlerRegs { pc, ra, sp, a0, a1, s0 } = regs;
            context.uc_mcontext.__gregs[libc::REG_PC] = pc as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_RA] = ra as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_SP] = sp as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_A0] = a0 as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_A0 + 1] = a1 as libc::c_ulong;
            context.uc_mcontext.__gregs[libc::REG_S0] = s0 as libc::c_ulong;
        } else if #[cfg(all(target_vendor = "apple", target_arch = "aarch64"))] {
            let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
            (*context.uc_mcontext).__ss.__pc = pc;
            (*context.uc_mcontext).__ss.__sp = sp;
            (*context.uc_mcontext).__ss.__x[0] = x0;
            (*context.uc_mcontext).__ss.__x[1] = x1;
            (*context.uc_mcontext).__ss.__fp = x29;
            (*context.uc_mcontext).__ss.__lr = lr;
        } else if #[cfg(all(target_os = "freebsd", target_arch = "aarch64"))] {
            // FreeBSD's arm64 `gpregs` holds x0..=x29 in `gp_x` (30 entries).
            // The link register, stack pointer and PC have their own fields:
            // `gp_lr`, `gp_sp` and `gp_elr`. `get_pc_sp` also reads the PC
            // from `gp_elr`.
            let TrapHandlerRegs { pc, sp, x0, x1, x29, lr } = regs;
            context.uc_mcontext.mc_gpregs.gp_elr = pc as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_sp = sp as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_x[0] = x0 as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_x[1] = x1 as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_x[29] = x29 as libc::register_t;
            context.uc_mcontext.mc_gpregs.gp_lr = lr as libc::register_t;
        } else {
            compile_error!("Unsupported platform");
        }
    };
}
} else if #[cfg(target_os = "windows")] {
use winapi::um::errhandlingapi::*;
use winapi::um::winnt::*;
use winapi::um::minwinbase::*;
use winapi::vc::excpt::*;
/// Registers `exception_handler` as a vectored exception handler.
///
/// The first argument `1` asks Windows to call it before previously
/// registered vectored handlers. Panics if registration fails.
unsafe fn platform_init() {
    let registration = AddVectoredExceptionHandler(1, Some(exception_handler));
    if registration.is_null() {
        panic!("failed to add exception handler: {}", io::Error::last_os_error());
    }
}
unsafe extern "system" fn exception_handler(
exception_info: PEXCEPTION_POINTERS
) -> LONG {
let record = &*(*exception_info).ExceptionRecord;
if record.ExceptionCode != EXCEPTION_ACCESS_VIOLATION &&
record.ExceptionCode != EXCEPTION_ILLEGAL_INSTRUCTION &&
record.ExceptionCode != EXCEPTION_STACK_OVERFLOW &&
record.ExceptionCode != EXCEPTION_INT_DIVIDE_BY_ZERO &&
record.ExceptionCode != EXCEPTION_INT_OVERFLOW
{
return EXCEPTION_CONTINUE_SEARCH;
}
let context = &mut *(*exception_info).ContextRecord;
let (pc, sp) = get_pc_sp(context);
let maybe_fault_address = match record.ExceptionCode {
EXCEPTION_ACCESS_VIOLATION => Some(record.ExceptionInformation[1]),
EXCEPTION_STACK_OVERFLOW => Some(sp),
_ => None,
};
let trap_code = match record.ExceptionCode {
EXCEPTION_ILLEGAL_INSTRUCTION => {
process_illegal_op(pc)
}
_ => None,
};
let handled = TrapHandlerContext::handle_trap(
pc,
sp,
maybe_fault_address,
trap_code,
|regs| update_context(context, regs),
|handler| handler(exception_info),
);
if handled {
EXCEPTION_CONTINUE_EXECUTION
} else {
EXCEPTION_CONTINUE_SEARCH
}
}
/// Extracts the program counter and stack pointer from a Windows `CONTEXT`.
unsafe fn get_pc_sp(context: &CONTEXT) -> (usize, usize) {
    let (pc, sp);
    cfg_if::cfg_if! {
        if #[cfg(target_arch = "x86_64")] {
            pc = context.Rip as usize;
            sp = context.Rsp as usize;
        } else if #[cfg(target_arch = "x86")] {
            // The 32-bit CONTEXT has no Rip/Rsp fields; the instruction and
            // stack pointers are Eip/Esp (the same fields `update_context`
            // writes on x86).
            pc = context.Eip as usize;
            sp = context.Esp as usize;
        } else {
            compile_error!("Unsupported platform");
        }
    };
    (pc, sp)
}
/// Copies the registers chosen by the coroutine trap handler into the saved
/// Windows `CONTEXT`, so that continuing execution resumes in the unwind path.
///
/// `TrapHandlerRegs` is destructured exhaustively so that a change in its
/// fields is a compile error.
unsafe fn update_context(context: &mut CONTEXT, regs: TrapHandlerRegs) {
    cfg_if::cfg_if! {
        if #[cfg(target_arch = "x86_64")] {
            let TrapHandlerRegs { rip, rsp, rbp, rdi, rsi } = regs;
            // Stack/frame registers first, then argument registers, then the
            // instruction pointer (the assignments are independent).
            context.Rsp = rsp;
            context.Rbp = rbp;
            context.Rdi = rdi;
            context.Rsi = rsi;
            context.Rip = rip;
        } else if #[cfg(target_arch = "x86")] {
            let TrapHandlerRegs { eip, esp, ebp, ecx, edx } = regs;
            context.Esp = esp;
            context.Ebp = ebp;
            context.Ecx = ecx;
            context.Edx = edx;
            context.Eip = eip;
        } else {
            compile_error!("Unsupported platform");
        }
    };
}
}
}
/// Installs the process-wide trap handlers (signal handlers on Unix, a
/// vectored exception handler on Windows).
///
/// Safe to call repeatedly: only the first call does any work.
pub fn init_traps() {
    static INIT: Once = Once::new();
    INIT.call_once(|| {
        // SAFETY: `Once` guarantees `platform_init` runs exactly once per
        // process.
        unsafe { platform_init() }
    });
}
/// Aborts the running Wasm code with an embedder-supplied error. It surfaces
/// from `catch_traps` as `Trap::User`.
///
/// # Safety
///
/// Must be called while running on a Wasm stack set up by `catch_traps`;
/// otherwise `unwind_with` panics. The suspended Wasm frames are abandoned
/// rather than unwound normally.
pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
    let reason = UnwindReason::UserTrap(data);
    unwind_with(reason)
}

/// Aborts the running Wasm code with a trap produced by a runtime library
/// routine. `catch_traps` returns the trap as-is.
///
/// # Safety
///
/// Same requirements as [`raise_user_trap`].
pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
    let reason = UnwindReason::LibTrap(trap);
    unwind_with(reason)
}

/// Carries a caught panic payload out of the Wasm stack. `catch_traps`
/// re-raises it with `std::panic::resume_unwind`.
///
/// # Safety
///
/// Same requirements as [`raise_user_trap`].
pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
    let reason = UnwindReason::Panic(payload);
    unwind_with(reason)
}
/// Calls `callee` through its `trampoline` while catching any trap it raises.
///
/// `values_vec` is passed straight to the trampoline (presumably the
/// argument/return value buffer; confirm with the trampoline ABI).
///
/// # Safety
///
/// `trampoline`, `callee`, `vmctx` and `values_vec` must be valid for a call
/// with this ABI.
pub unsafe fn wasmer_call_trampoline(
    trap_handler: Option<*const TrapHandlerFn<'static>>,
    vmctx: VMFunctionContext,
    trampoline: VMTrampoline,
    callee: *const VMFunctionBody,
    values_vec: *mut u8,
) -> Result<(), Trap> {
    type TrampolineFn = extern "C" fn(VMFunctionContext, *const VMFunctionBody, *mut u8);
    let invoke: TrampolineFn = mem::transmute(trampoline);
    catch_traps(trap_handler, move || invoke(vmctx, callee, values_vec))
}
/// Runs `closure` on a dedicated Wasm stack and converts any trap it raises
/// into `Err(Trap)`.
///
/// Wasm faults, user traps and library traps become `Err`. A panic carried
/// by `resume_panic` is re-raised here via `std::panic::resume_unwind`.
/// Returns `Err` early if per-thread setup fails (out of memory for the
/// alternate signal stack on Unix).
///
/// # Safety
///
/// `trap_handler`, if provided, must stay valid for the duration of the call.
pub unsafe fn catch_traps<F, R>(
    trap_handler: Option<*const TrapHandlerFn<'static>>,
    closure: F,
) -> Result<R, Trap>
where
    F: FnOnce() -> R,
{
    // Per-thread setup (alternate signal stack on Unix, stack guarantee on
    // Windows) must be in place before any trap can fire.
    lazy_per_thread_init()?;
    let outcome = on_wasm_stack(trap_handler, closure);
    outcome.map_err(|reason| reason.into_trap())
}
// Per-thread trap-handling state.
thread_local! {
// Yielder of the coroutine currently running Wasm on this thread, if any.
// `on_wasm_stack` sets it; `unwind_with` and `on_host_stack` take it.
static YIELDER: Cell<Option<NonNull<Yielder<(), UnwindReason>>>> = Cell::new(None);
// Innermost context installed by `TrapHandlerContext::install`, or null. The
// signal/exception handler reads it through `TrapHandlerContext::handle_trap`.
static TRAP_HANDLER: AtomicPtr<TrapHandlerContext> = AtomicPtr::new(ptr::null_mut());
}
/// Per-call trap-handling state. `TrapHandlerContext::install` publishes it in
/// `TRAP_HANDLER`, and the signal / exception handler consults it.
// The lint is silenced for the long `handle_trap` function-pointer type.
#[allow(clippy::type_complexity)]
struct TrapHandlerContext {
/// Type-erased pointer to a `TrapHandlerContextInner<T>` that lives in the
/// stack frame of `install`.
inner: *const u8,
/// Monomorphized trampoline that casts `inner` back to its concrete type and
/// forwards to `TrapHandlerContextInner::handle_trap`. Arguments: `inner`,
/// pc, sp, optional fault address, optional decoded trap code, and a callback
/// that writes the new register state. Returns whether the trap was handled.
handle_trap: fn(
*const u8,
usize,
usize,
Option<usize>,
Option<TrapCode>,
&mut dyn FnMut(TrapHandlerRegs),
) -> bool,
/// Optional embedder-supplied handler, tried before the built-in handling.
custom_trap: Option<*const TrapHandlerFn<'static>>,
}
/// Typed half of a `TrapHandlerContext`, generic over the coroutine's return
/// value `T`.
struct TrapHandlerContextInner<T> {
/// Handle used to check stack bounds and to redirect a faulting coroutine
/// into its trap-return path.
coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
}
impl TrapHandlerContext {
/// Publishes a trap-handling context for the current thread while `f` runs,
/// then restores the previously installed one (which supports nesting).
///
/// The context lives on this stack frame. `TRAP_HANDLER` is reset to `prev`
/// by the `defer!` guard before this function returns, including when it
/// unwinds, so the pointer never outlives `ctx`.
fn install<T, R>(
custom_trap: Option<*const TrapHandlerFn<'static>>,
coro_trap_handler: CoroutineTrapHandler<Result<T, UnwindReason>>,
f: impl FnOnce() -> R,
) -> R {
// Type-erasing trampoline stored in `TrapHandlerContext::handle_trap`:
// casts `ptr` back to `TrapHandlerContextInner<T>` and forwards the call.
fn func<T>(
ptr: *const u8,
pc: usize,
sp: usize,
maybe_fault_address: Option<usize>,
trap_code: Option<TrapCode>,
update_regs: &mut dyn FnMut(TrapHandlerRegs),
) -> bool {
unsafe {
(*(ptr as *const TrapHandlerContextInner<T>)).handle_trap(
pc,
sp,
maybe_fault_address,
trap_code,
update_regs,
)
}
}
let inner = TrapHandlerContextInner { coro_trap_handler };
let ctx = Self {
inner: &inner as *const _ as *const u8,
handle_trap: func::<T>,
custom_trap,
};
// Keep the compiler from sinking the initialization of `inner`/`ctx` past
// the store below; the reader is a signal handler on this same thread.
compiler_fence(Ordering::Release);
let prev = TRAP_HANDLER.with(|ptr| {
let prev = ptr.load(Ordering::Relaxed);
ptr.store(&ctx as *const Self as *mut Self, Ordering::Relaxed);
prev
});
// Restore the outer context on every exit path before `ctx` is dropped.
defer! {
TRAP_HANDLER.with(|ptr| ptr.store(prev, Ordering::Relaxed));
compiler_fence(Ordering::Acquire);
}
f()
}
/// Entry point used by the platform signal / exception handlers.
///
/// Returns `false` when no context is installed on this thread. Otherwise it
/// tries the custom handler first (returning `true` if that handles the
/// trap), then delegates to the typed handler. The typed handler calls
/// `update_regs` with the register state to resume at when it handles the
/// fault.
unsafe fn handle_trap(
pc: usize,
sp: usize,
maybe_fault_address: Option<usize>,
trap_code: Option<TrapCode>,
mut update_regs: impl FnMut(TrapHandlerRegs),
call_handler: impl Fn(&TrapHandlerFn<'static>) -> bool,
) -> bool {
let ptr = TRAP_HANDLER.with(|ptr| ptr.load(Ordering::Relaxed));
if ptr.is_null() {
return false;
}
let ctx = &*ptr;
if let Some(trap_handler) = ctx.custom_trap {
if call_handler(&*trap_handler) {
return true;
}
}
(ctx.handle_trap)(
ctx.inner,
pc,
sp,
maybe_fault_address,
trap_code,
&mut update_regs,
)
}
}
impl<T> TrapHandlerContextInner<T> {
    /// Handles a fault if it happened while running on this coroutine's stack.
    ///
    /// The trap code is taken from `trap_code` when present. Otherwise a fault
    /// address inside the coroutine stack is classified as a stack overflow
    /// and any other address as an out-of-bounds heap access.
    ///
    /// The coroutine is then redirected to return `Err(UnwindReason::WasmTrap)`,
    /// and the resulting registers are passed to `update_regs`. Returns
    /// `false` (leaving the fault alone) when `sp` is outside the coroutine
    /// stack.
    unsafe fn handle_trap(
        &self,
        pc: usize,
        sp: usize,
        maybe_fault_address: Option<usize>,
        trap_code: Option<TrapCode>,
        update_regs: &mut dyn FnMut(TrapHandlerRegs),
    ) -> bool {
        let handler = &self.coro_trap_handler;
        let on_wasm_stack = handler.stack_ptr_in_bounds(sp);
        if !on_wasm_stack {
            return false;
        }

        let classify_fault = |addr: usize| {
            if handler.stack_ptr_in_bounds(addr) {
                TrapCode::StackOverflow
            } else {
                TrapCode::HeapAccessOutOfBounds
            }
        };
        let signal_trap = trap_code.or_else(|| maybe_fault_address.map(classify_fault));

        // No backtrace is captured for a stack overflow; other traps get an
        // unresolved one.
        let overflowed = signal_trap == Some(TrapCode::StackOverflow);
        let backtrace = if overflowed {
            Backtrace::from(vec![])
        } else {
            Backtrace::new_unresolved()
        };

        let unwind = UnwindReason::WasmTrap {
            backtrace,
            signal_trap,
            pc,
        };
        let resume_regs = handler.setup_trap_handler(move || Err(unwind));
        update_regs(resume_regs);
        true
    }
}
/// Why execution left the Wasm stack abnormally; converted into a `Trap` by
/// `UnwindReason::into_trap`.
enum UnwindReason {
/// A panic payload, re-raised on the host side.
Panic(Box<dyn Any + Send>),
/// An embedder-supplied error (`raise_user_trap`).
UserTrap(Box<dyn Error + Send + Sync>),
/// A trap created by a runtime library routine (`raise_lib_trap`).
LibTrap(Trap),
/// A hardware fault / illegal instruction caught by the signal or exception
/// handler.
WasmTrap {
backtrace: Backtrace,
/// Program counter at the fault.
pc: usize,
/// Trap code derived from the fault, when one could be determined.
signal_trap: Option<TrapCode>,
},
}
impl UnwindReason {
fn into_trap(self) -> Trap {
match self {
Self::UserTrap(data) => Trap::User(data),
Self::LibTrap(trap) => trap,
Self::WasmTrap {
backtrace,
pc,
signal_trap,
} => Trap::wasm(pc, backtrace, signal_trap),
Self::Panic(panic) => std::panic::resume_unwind(panic),
}
}
}
/// Suspends the current Wasm coroutine, handing `reason` to `on_wasm_stack`.
///
/// The thread's yielder is taken (leaving `None`), so a second unwind
/// without re-entering the Wasm stack panics rather than reusing it.
///
/// # Safety
///
/// Must be called from code running inside `on_wasm_stack`; otherwise this
/// panics with "not running on Wasm stack".
unsafe fn unwind_with(reason: UnwindReason) -> ! {
    let current = YIELDER.with(|cell| cell.take());
    let yielder = current.expect("not running on Wasm stack");
    yielder.as_ref().suspend(reason);
    // `on_wasm_stack` resets the coroutine after a yield and never resumes it.
    unreachable!();
}
/// Runs `f` on a separate coroutine stack taken from a process-wide pool,
/// with a `TrapHandlerContext` installed for the duration.
///
/// Returns `Ok` with `f`'s result on normal completion. Returns
/// `Err(UnwindReason)` when a trap either suspends the coroutine (via
/// `unwind_with`) or redirects it to return an error (via the signal handler's
/// `setup_trap_handler`).
fn on_wasm_stack<F: FnOnce() -> T, T>(
trap_handler: Option<*const TrapHandlerFn<'static>>,
f: F,
) -> Result<T, UnwindReason> {
// Stacks are reused across calls instead of being allocated per call.
lazy_static::lazy_static! {
static ref STACK_POOL: Mutex<Vec<DefaultStack>> = Mutex::new(vec![]);
}
let stack = STACK_POOL.lock().unwrap().pop().unwrap_or_default();
// Return the stack to the pool when this guard is dropped, on every exit
// path.
let mut stack = scopeguard::guard(stack, |stack| STACK_POOL.lock().unwrap().push(stack));
let mut coro = ScopedCoroutine::with_stack(&mut *stack, move |yielder, ()| {
// Publish this coroutine's yielder so `unwind_with` can suspend out of it.
YIELDER.with(|cell| cell.set(Some(yielder.into())));
Ok(f())
});
// Clear the published yielder once control is back on the host stack.
defer! {
YIELDER.with(|cell| cell.set(None));
}
TrapHandlerContext::install(trap_handler, coro.trap_handler(), || {
match coro.resume(()) {
// A trap suspended the coroutine and it will never be resumed, so
// reset it and let the stack go back to the pool.
// NOTE(review): `force_reset` presumably abandons the suspended frames
// without running their destructors; confirm with the corosensei docs.
CoroutineResult::Yield(trap) => {
unsafe {
coro.force_reset();
}
Err(trap)
}
// Normal completion, or an `Err` injected by the signal handler via
// `setup_trap_handler`.
CoroutineResult::Return(result) => result,
}
})
}
/// Runs `f` on the host (parent) stack when called from Wasm code running
/// inside `on_wasm_stack`; otherwise `f` runs directly on the current stack.
///
/// While `f` runs on the host stack, the thread's yielder is cleared, so
/// `unwind_with` called from `f` panics instead of suspending the coroutine
/// from the wrong stack. The yielder is restored afterwards, even if `f`
/// unwinds.
pub fn on_host_stack<F: FnOnce() -> T, T>(f: F) -> T {
    // `on_parent_stack` requires `Send`; the closure never actually leaves
    // this thread, so the marker is asserted manually.
    struct SendWrapper<T>(T);
    unsafe impl<T> Send for SendWrapper<T> {}

    let saved = YIELDER.with(|cell| cell.replace(None));
    let yielder = if let Some(ptr) = saved {
        unsafe { ptr.as_ref() }
    } else {
        return f();
    };
    defer! {
        YIELDER.with(|cell| cell.set(saved));
    }
    let payload = SendWrapper(f);
    yielder.on_parent_stack(move || (payload.0)())
}
/// Windows per-thread setup: reserves stack space so the exception handler
/// can still run after a stack overflow on this thread.
///
/// Panics if `SetThreadStackGuarantee` fails; never returns `Err`.
#[cfg(windows)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
    use winapi::um::processthreadsapi::SetThreadStackGuarantee;
    // 64 KiB guarantee.
    let mut guarantee_bytes: u32 = 0x10000;
    let succeeded = unsafe { SetThreadStackGuarantee(&mut guarantee_bytes) } != 0;
    if !succeeded {
        panic!("failed to set thread stack guarantee");
    }
    Ok(())
}
/// Unix per-thread setup: makes sure this thread has an alternate signal stack
/// (`sigaltstack`) of at least `MIN_STACK_SIZE` bytes. `trap_handler` (installed
/// with SA_ONSTACK) can then run even when the regular stack is exhausted.
///
/// The stack is created at most once per thread, on first use. Returns
/// `Err(Trap::oom())` if the `mmap` for it failed.
#[cfg(unix)]
pub fn lazy_per_thread_init() -> Result<(), Trap> {
use std::ptr::null_mut;
thread_local! {
static TLS: Tls = unsafe { init_sigstack() };
}
// Minimum usable alternate signal stack size (64 KiB with 4 KiB pages).
const MIN_STACK_SIZE: usize = 16 * 4096;
// Outcome of `init_sigstack` for this thread.
enum Tls {
// `mmap` failed; no alternate stack was installed by us.
OutOfMemory,
// We mapped and registered the stack; it is unmapped in `Drop`.
Allocated {
mmap_ptr: *mut libc::c_void,
mmap_size: usize,
},
// An existing, large-enough alternate stack was already registered.
BigEnough,
}
unsafe fn init_sigstack() -> Tls {
// Query the currently registered alternate stack, if any.
let mut old_stack = mem::zeroed();
let r = libc::sigaltstack(ptr::null(), &mut old_stack);
assert_eq!(r, 0, "learning about sigaltstack failed");
if old_stack.ss_flags & libc::SS_DISABLE == 0 && old_stack.ss_size >= MIN_STACK_SIZE {
return Tls::BigEnough;
}
// Map one inaccessible guard page below the stack so that overflowing the
// signal stack faults instead of corrupting adjacent memory.
let page_size: usize = region::page::size();
let guard_size = page_size;
let alloc_size = guard_size + MIN_STACK_SIZE;
let ptr = libc::mmap(
null_mut(),
alloc_size,
libc::PROT_NONE,
libc::MAP_PRIVATE | libc::MAP_ANON,
-1,
0,
);
if ptr == libc::MAP_FAILED {
return Tls::OutOfMemory;
}
// Make everything above the guard page readable and writable.
let stack_ptr = (ptr as usize + guard_size) as *mut libc::c_void;
let r = libc::mprotect(
stack_ptr,
MIN_STACK_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
);
assert_eq!(r, 0, "mprotect to configure memory for sigaltstack failed");
let new_stack = libc::stack_t {
ss_sp: stack_ptr,
ss_flags: 0,
ss_size: MIN_STACK_SIZE,
};
let r = libc::sigaltstack(&new_stack, ptr::null_mut());
assert_eq!(r, 0, "registering new sigaltstack failed");
Tls::Allocated {
mmap_ptr: ptr,
mmap_size: alloc_size,
}
}
return TLS.with(|tls| {
if let Tls::OutOfMemory = tls {
Err(Trap::oom())
} else {
Ok(())
}
});
// Frees the mapping we created when the thread's TLS is destroyed.
// NOTE(review): the sigaltstack registration is not disabled before the
// memory is unmapped; presumably harmless at thread exit, but confirm.
impl Drop for Tls {
fn drop(&mut self) {
let (ptr, size) = match self {
Self::Allocated {
mmap_ptr,
mmap_size,
} => (*mmap_ptr, *mmap_size),
_ => return,
};
unsafe {
let r = libc::munmap(ptr, size);
debug_assert_eq!(r, 0, "munmap failed during thread shutdown");
}
}
}
}