use std::cell::Cell;
use std::ptr::null_mut;
use std::sync::Once;
use super::JitContext;
use super::codegen::EXIT_PAGE_FAULT;
/// Per-thread description of the JIT code region that `sigsegv_handler` consults
/// to turn a hardware fault inside generated code into a guest page-fault exit.
///
/// NOTE(review): `#[repr(C)]` suggests generated code or FFI may rely on this field
/// layout — confirm before reordering or adding fields.
#[repr(C)]
pub struct SignalState {
/// Inclusive start address of the generated native code; faults with
/// `RIP < code_start` are not treated as JIT faults.
pub code_start: usize,
/// Exclusive end address of the generated native code; faults with
/// `RIP >= code_end` are not treated as JIT faults.
pub code_end: usize,
/// Native address the handler writes into `RIP`, so returning from the signal
/// resumes there instead of at the faulting instruction.
pub exit_label_addr: usize,
/// Context the handler fills in (`exit_reason`, `exit_arg`, `pc`) before
/// redirecting to `exit_label_addr`; dereferenced from signal context, so it must
/// stay valid while this state is published in `SIGNAL_STATE`.
pub ctx_ptr: *mut JitContext,
/// `(native offset from code_start, PVM pc)` pairs. Looked up with
/// `binary_search_by_key` on the first element, so it must be sorted ascending by
/// native offset; only exact offset matches are treated as recoverable faults.
pub trap_table: Vec<(u32, u32)>,
}
// SAFETY (review): the raw `ctx_ptr` makes this type `!Send` by default. This impl asserts
// that moving a `SignalState` across threads is sound — presumably because `ctx_ptr` is
// only dereferenced on the thread that publishes it via `SIGNAL_STATE`. Confirm with callers.
unsafe impl Send for SignalState {}
thread_local! {
/// Pointer to the `SignalState` of the JIT code this thread is currently running.
///
/// `sigsegv_handler` reads it on every SIGSEGV: null means the fault is forwarded to
/// `delegate_to_previous`; non-null is dereferenced, so it must point at a live
/// `SignalState` for as long as it is set.
/// TODO(review): confirm callers reset it to null before the state is dropped.
// The `const` initializer means reads never run lazy-initialization code; presumably
// chosen because this is read from inside a signal handler.
pub static SIGNAL_STATE: Cell<*mut SignalState> = const { Cell::new(null_mut()) };
}
/// Guards the one-time setup performed by `ensure_installed`.
static INIT: Once = Once::new();
/// SIGSEGV disposition that was in effect before `install_handler` replaced it; filled
/// in by its `sigaction` call and read by `delegate_to_previous` to forward faults that
/// do not belong to JIT code. Starts all-zero (which reads as `SIG_DFL` with no flags).
static mut PREV_SIGSEGV: libc::sigaction = unsafe { std::mem::zeroed() };
pub fn ensure_installed() {
INIT.call_once(|| unsafe {
install_sigaltstack();
install_handler();
});
}
/// Registers `sigsegv_handler` as the process-wide SIGSEGV handler.
///
/// `SA_SIGINFO` selects the three-argument handler form (needed to reach the
/// `ucontext`), and `SA_ONSTACK` runs the handler on the thread's alternate signal
/// stack when one is configured. The disposition being replaced is stored in
/// `PREV_SIGSEGV` so `delegate_to_previous` can hand foreign faults back to it.
///
/// # Panics
///
/// Panics if `sigaction` fails.
unsafe fn install_handler() {
    let mut action = std::mem::zeroed::<libc::sigaction>();
    libc::sigemptyset(&mut action.sa_mask);
    action.sa_sigaction = sigsegv_handler as usize;
    action.sa_flags = libc::SA_ONSTACK | libc::SA_SIGINFO;

    let rc = libc::sigaction(libc::SIGSEGV, &action, &raw mut PREV_SIGSEGV);
    assert_eq!(rc, 0, "sigaction(SIGSEGV) failed: {}", std::io::Error::last_os_error());
}
/// Gives the calling thread an alternate signal stack of at least `MIN_STACK` bytes.
///
/// If the thread already has an enabled alternate stack at least that large, nothing
/// changes. Otherwise a new anonymous mapping is created. Its lowest page is left
/// `PROT_NONE` as a guard (the stack grows downward on x86-64), and the remaining
/// `MIN_STACK` bytes are made read/write and registered with `sigaltstack(2)`.
///
/// The mapping is never unmapped, and any alternate stack it replaces is not freed here.
///
/// # Panics
///
/// Panics if `mmap`, `mprotect`, or `sigaltstack` fails.
///
/// # Safety
///
/// Performs raw libc calls that change the calling thread's signal-stack configuration.
/// Intended to be called from ordinary (non-signal-handler) context.
unsafe fn install_sigaltstack() {
    // Minimum usable size of the alternate stack (256 KiB).
    const MIN_STACK: usize = 64 * 4096;

    let mut old: libc::stack_t = std::mem::zeroed();
    // If this query fails, `old` stays zeroed (size 0), so we fall through and install.
    libc::sigaltstack(std::ptr::null(), &mut old);
    if old.ss_flags & libc::SS_DISABLE == 0 && old.ss_size >= MIN_STACK {
        return;
    }

    // `mprotect` needs a page-aligned start, so the guard must be one real page;
    // ask the OS rather than assuming 4 KiB.
    let page_size = match libc::sysconf(libc::_SC_PAGESIZE) {
        n if n > 0 => n as usize,
        _ => 4096,
    };
    let alloc_size = page_size + MIN_STACK;
    let base = libc::mmap(
        null_mut(),
        alloc_size,
        libc::PROT_NONE,
        libc::MAP_PRIVATE | libc::MAP_ANONYMOUS,
        -1,
        0,
    );
    assert_ne!(base, libc::MAP_FAILED, "mmap for sigaltstack failed");

    // Everything above the guard page becomes the usable stack.
    let stack_ptr = base.cast::<u8>().add(page_size).cast::<libc::c_void>();
    let r = libc::mprotect(stack_ptr, MIN_STACK, libc::PROT_READ | libc::PROT_WRITE);
    // If this failed silently, the stack would stay PROT_NONE and the kernel could not
    // write a signal frame on it, so the handler would never run.
    assert_eq!(r, 0, "mprotect for sigaltstack failed: {}", std::io::Error::last_os_error());

    let new_stack = libc::stack_t {
        ss_sp: stack_ptr,
        ss_flags: 0,
        ss_size: MIN_STACK,
    };
    let r = libc::sigaltstack(&new_stack, null_mut());
    assert_eq!(r, 0, "sigaltstack failed: {}", std::io::Error::last_os_error());
}
unsafe extern "C" fn sigsegv_handler(
signum: libc::c_int,
_siginfo: *mut libc::siginfo_t,
ucontext: *mut libc::c_void,
) {
let state_ptr = SIGNAL_STATE.with(|cell| cell.get());
if state_ptr.is_null() {
delegate_to_previous(signum, _siginfo, ucontext);
return;
}
let state = &*state_ptr;
let cx = &mut *(ucontext as *mut libc::ucontext_t);
let pc = cx.uc_mcontext.gregs[libc::REG_RIP as usize] as usize;
if pc < state.code_start || pc >= state.code_end {
delegate_to_previous(signum, _siginfo, ucontext);
return;
}
let native_offset = (pc - state.code_start) as u32;
let pvm_pc = match state.trap_table.binary_search_by_key(&native_offset, |&(off, _)| off) {
Ok(idx) => state.trap_table[idx].1,
Err(_) => {
delegate_to_previous(signum, _siginfo, ucontext);
return;
}
};
let guest_addr = cx.uc_mcontext.gregs[libc::REG_RDX as usize] as u32;
let ctx = &mut *state.ctx_ptr;
ctx.exit_reason = EXIT_PAGE_FAULT;
ctx.exit_arg = guest_addr;
ctx.pc = pvm_pc;
cx.uc_mcontext.gregs[libc::REG_RIP as usize] = state.exit_label_addr as i64;
}
/// Hands a SIGSEGV that does not belong to JIT code to the disposition that was
/// installed before `install_handler` (saved in `PREV_SIGSEGV`).
///
/// * Previous disposition `SIG_DFL` / `SIG_IGN`: it is re-installed and we return, so
///   the faulting instruction re-executes under the original disposition.
/// * Previous `SA_SIGINFO` handler: called with the full three-argument form.
/// * Previous one-argument handler: called with just the signal number.
///
/// `SIG_DFL`/`SIG_IGN` are checked *before* `SA_SIGINFO`. A disposition may carry the
/// `SA_SIGINFO` flag while its handler slot holds `SIG_DFL` (0). Transmuting that value
/// into a function pointer and calling it would be undefined behavior.
unsafe fn delegate_to_previous(
    signum: libc::c_int,
    siginfo: *mut libc::siginfo_t,
    context: *mut libc::c_void,
) {
    let prev = PREV_SIGSEGV;
    let action = prev.sa_sigaction;

    if action == libc::SIG_DFL || action == libc::SIG_IGN {
        if libc::sigaction(signum, &prev, null_mut()) != 0 {
            // Returning would re-run the faulting instruction straight back into this
            // handler forever; `abort` is async-signal-safe and ends the process instead.
            libc::abort();
        }
    } else if prev.sa_flags & libc::SA_SIGINFO != 0 {
        let handler: extern "C" fn(libc::c_int, *mut libc::siginfo_t, *mut libc::c_void) =
            std::mem::transmute(action);
        handler(signum, siginfo, context);
    } else {
        let handler: extern "C" fn(libc::c_int) = std::mem::transmute(action);
        handler(signum);
    }
}