near_vm_vm/trap/traphandlers.rs
1// This file contains code from external sources.
2// Attributions: https://github.com/wasmerio/wasmer/blob/2.3.0/ATTRIBUTIONS.md
3
//! WebAssembly trap handling, which is built on top of the lower-level
//! signal-handling mechanisms.
6
7use super::trapcode::TrapCode;
8use crate::vmcontext::{VMFunctionBody, VMFunctionEnvironment, VMTrampoline};
9use backtrace::Backtrace;
10use std::any::Any;
11use std::cell::{Cell, UnsafeCell};
12use std::error::Error;
13use std::mem::{self, MaybeUninit};
14use std::ptr;
15pub use tls::TlsRestore;
16
// NOTE(review): both functions are implemented outside this file; the
// descriptions below are inferred from how this file calls them and should be
// confirmed against the native implementation.
unsafe extern "C" {
    /// Invokes `callback(payload)` with a jump target registered.
    ///
    /// `jmp_buf` points at the slot that later supplies the argument to
    /// `near_vm_unwind`; presumably the native side stores the jump buffer
    /// there before running `callback`.
    ///
    /// Callers in this file treat a non-zero return value as "`callback`
    /// returned normally" and zero as "`callback` was unwound".
    fn near_vm_register_setjmp(
        jmp_buf: *mut *const u8,
        callback: extern "C" fn(*mut u8),
        payload: *mut u8,
    ) -> i32;
    /// Transfers control to the jump target identified by `jmp_buf` and never
    /// returns to the caller.
    fn near_vm_unwind(jmp_buf: *const u8) -> !;
}
25
26/// Raises a user-defined trap immediately.
27///
28/// This function performs as-if a wasm trap was just executed, only the trap
29/// has a dynamic payload associated with it which is user-provided. This trap
30/// payload is then returned from `catch_traps` below.
31///
32/// # Safety
33///
34/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
35/// have been previous called and not yet returned.
36/// Additionally no Rust destructors may be on the stack.
37/// They will be skipped and not executed.
38pub unsafe fn raise_user_trap(data: Box<dyn Error + Send + Sync>) -> ! {
39 tls::with(|info| info.unwrap().unwind_with(UnwindReason::UserTrap(data)))
40}
41
42/// Raises a trap from inside library code immediately.
43///
44/// This function performs as-if a wasm trap was just executed. This trap
45/// payload is then returned from `catch_traps` below.
46///
47/// # Safety
48///
49/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
50/// have been previous called and not yet returned.
51/// Additionally no Rust destructors may be on the stack.
52/// They will be skipped and not executed.
53pub unsafe fn raise_lib_trap(trap: Trap) -> ! {
54 tls::with(|info| info.unwrap().unwind_with(UnwindReason::LibTrap(trap)))
55}
56
57/// Carries a Rust panic across wasm code and resumes the panic on the other
58/// side.
59///
60/// # Safety
61///
62/// Only safe to call when wasm code is on the stack, aka `catch_traps` must
63/// have been previously called and not returned. Additionally no Rust destructors may be on the
64/// stack. They will be skipped and not executed.
65pub unsafe fn resume_panic(payload: Box<dyn Any + Send>) -> ! {
66 tls::with(|info| info.unwrap().unwind_with(UnwindReason::Panic(payload)))
67}
68
/// The error produced when a call guarded by `catch_traps` is aborted.
///
/// Each variant records where the trap came from; all variants except `User`
/// also carry the native backtrace captured for the trap.
#[derive(Debug)]
pub enum Trap {
    /// A user-raised trap through `raise_user_trap`, carrying the error that
    /// was passed to it.
    User(Box<dyn Error + Send + Sync>),

    /// A trap raised from the Wasm generated code
    ///
    /// Note: this trap is deterministic (assuming a deterministic host implementation)
    Wasm {
        /// The program counter in generated code where this trap happened.
        pc: usize,
        /// Native stack backtrace at the time the trap occurred
        backtrace: Backtrace,
        /// Optional trapcode associated to the signal that caused the trap
        signal_trap: Option<TrapCode>,
    },

    /// A trap raised from a wasm libcall
    ///
    /// Note: this trap is deterministic (assuming a deterministic host implementation)
    Lib {
        /// Code of the trap.
        trap_code: TrapCode,
        /// Native stack backtrace at the time the trap occurred
        backtrace: Backtrace,
    },

    /// A trap indicating that the runtime was unable to allocate sufficient memory.
    ///
    /// Note: this trap is nondeterministic, since it depends on the host system.
    OOM {
        /// Native stack backtrace at the time the OOM occurred
        backtrace: Backtrace,
    },
}
105
106impl Trap {
107 /// Construct a new Wasm trap with the given source location and backtrace.
108 ///
109 /// Internally saves a backtrace when constructed.
110 pub fn wasm(pc: usize, backtrace: Backtrace, signal_trap: Option<TrapCode>) -> Self {
111 Self::Wasm { pc, backtrace, signal_trap }
112 }
113
114 /// Construct a new Wasm trap with the given trap code.
115 ///
116 /// Internally saves a backtrace when constructed.
117 pub fn lib(trap_code: TrapCode) -> Self {
118 let backtrace = Backtrace::new_unresolved();
119 Self::Lib { trap_code, backtrace }
120 }
121
122 /// Construct a new OOM trap with the given source location and trap code.
123 ///
124 /// Internally saves a backtrace when constructed.
125 pub fn oom() -> Self {
126 let backtrace = Backtrace::new_unresolved();
127 Self::OOM { backtrace }
128 }
129}
130
131/// Call the VM function pointed to by `callee`.
132///
133/// * `callee_env` - the function environment
134/// * `trampoline` - the jit-generated trampoline whose ABI takes 3 values, the
135/// callee funcenv, the `callee` argument below, and then the `values_vec` argument.
136/// * `callee` - the 2nd argument to the `trampoline` function
137/// * `values_vec` - points to a buffer which holds the incoming arguments, and to
138/// which the outgoing return values will be written.
139///
140/// Prefer invoking this via `Instance::invoke_trampoline`.
141///
142/// # Safety
143///
144/// Wildly unsafe because it calls raw function pointers and reads/writes raw
145/// function pointers.
146pub unsafe fn near_vm_call_trampoline(
147 callee_env: VMFunctionEnvironment,
148 trampoline: VMTrampoline,
149 callee: *const VMFunctionBody,
150 values_vec: *mut u8,
151) -> Result<(), Trap> {
152 unsafe {
153 catch_traps(|| {
154 mem::transmute::<
155 VMTrampoline,
156 extern "C" fn(VMFunctionEnvironment, *const VMFunctionBody, *mut u8),
157 >(trampoline)(callee_env, callee, values_vec);
158 })
159 }
160}
161
162/// Catches any wasm traps that happen within the execution of `closure`,
163/// returning them as a `Result`.
164///
165/// # Safety
166///
167/// Soundness must not depend on `closure` destructors being run.
168pub unsafe fn catch_traps<F>(mut closure: F) -> Result<(), Trap>
169where
170 F: FnMut(),
171{
172 return CallThreadState::new().with(|cx| unsafe {
173 near_vm_register_setjmp(
174 cx.jmp_buf.as_ptr(),
175 call_closure::<F>,
176 &mut closure as *mut F as *mut u8,
177 )
178 });
179
180 extern "C" fn call_closure<F>(payload: *mut u8)
181 where
182 F: FnMut(),
183 {
184 unsafe { (*(payload as *mut F))() }
185 }
186}
187
188/// Catches any wasm traps that happen within the execution of `closure`,
189/// returning them as a `Result`, with the closure contents.
190///
191/// The main difference from this method and `catch_traps`, is that is able
192/// to return the results from the closure.
193///
194/// # Safety
195///
196/// Check [`catch_traps`].
197pub unsafe fn catch_traps_with_result<F, R>(mut closure: F) -> Result<R, Trap>
198where
199 F: FnMut() -> R,
200{
201 let mut global_results = MaybeUninit::<R>::uninit();
202 unsafe {
203 catch_traps(|| {
204 global_results.as_mut_ptr().write(closure());
205 })?;
206 // FIXME: whoa here, what happens if `closure()` *does* trap?
207 Ok(global_results.assume_init())
208 }
209}
210
/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState {
    /// Why the call was unwound. Left uninitialized until `unwind_with` or the
    /// trap handler writes it right before unwinding; read back in `with`.
    unwind: UnsafeCell<MaybeUninit<UnwindReason>>,
    /// Jump target passed to `near_vm_unwind`. Starts out null; a pointer to
    /// this cell is handed to `near_vm_register_setjmp` by `catch_traps`.
    jmp_buf: Cell<*const u8>,
    /// The TLS pointer that was registered before this state was installed,
    /// so it can be put back afterwards. Null while not installed.
    prev: Cell<tls::Ptr>,
}
218
/// The cause of an unwind, recorded in `CallThreadState::unwind` before
/// jumping back to `catch_traps` and translated there into a `Trap` (or a
/// resumed panic).
enum UnwindReason {
    /// A panic caused by the host
    Panic(Box<dyn Any + Send>),
    /// A custom error triggered by the user
    UserTrap(Box<dyn Error + Send + Sync>),
    /// A Trap triggered by a wasm libcall
    LibTrap(Trap),
    /// A trap caused by the Wasm generated code
    WasmTrap { backtrace: Backtrace, pc: usize, signal_trap: Option<TrapCode> },
}
229
230impl<'a> CallThreadState {
231 #[inline]
232 fn new() -> Self {
233 Self {
234 unwind: UnsafeCell::new(MaybeUninit::uninit()),
235 jmp_buf: Cell::new(ptr::null()),
236 prev: Cell::new(ptr::null()),
237 }
238 }
239
240 fn with(self, closure: impl FnOnce(&Self) -> i32) -> Result<(), Trap> {
241 let ret = tls::set(&self, || closure(&self))?;
242 if ret != 0 {
243 return Ok(());
244 }
245 // We will only reach this path if ret == 0. And that will
246 // only happen if a trap did happen. As such, it's safe to
247 // assume that the `unwind` field is already initialized
248 // at this moment.
249 match unsafe { (*self.unwind.get()).as_ptr().read() } {
250 UnwindReason::UserTrap(data) => Err(Trap::User(data)),
251 UnwindReason::LibTrap(trap) => Err(trap),
252 UnwindReason::WasmTrap { backtrace, pc, signal_trap } => {
253 Err(Trap::wasm(pc, backtrace, signal_trap))
254 }
255 UnwindReason::Panic(panic) => std::panic::resume_unwind(panic),
256 }
257 }
258
259 fn unwind_with(&self, reason: UnwindReason) -> ! {
260 unsafe {
261 (*self.unwind.get()).as_mut_ptr().write(reason);
262 near_vm_unwind(self.jmp_buf.get());
263 }
264 }
265}
266
267// A private inner module for managing the TLS state that we require across
268// calls in wasm. The WebAssembly code is called from C++ and then a trap may
269// happen which requires us to read some contextual state to figure out what to
270// do with the trap. This `tls` module is used to persist that information from
271// the caller to the trap site.
272mod tls {
273 use super::CallThreadState;
274 use crate::Trap;
275 use std::mem;
276 use std::ptr;
277
278 pub use raw::Ptr;
279
280 // An even *more* inner module for dealing with TLS. This actually has the
281 // thread local variable and has functions to access the variable.
282 //
283 // Note that this is specially done to fully encapsulate that the accessors
284 // for tls must not be inlined. Wasmer's async support will employ stack
285 // switching which can resume execution on different OS threads. This means
286 // that borrows of our TLS pointer must never live across accesses because
287 // otherwise the access may be split across two threads and cause unsafety.
288 //
289 // This also means that extra care is taken by the runtime to save/restore
290 // these TLS values when the runtime may have crossed threads.
291 mod raw {
292 use super::CallThreadState;
293 use crate::Trap;
294 use std::cell::Cell;
295 use std::ptr;
296
297 pub type Ptr = *const CallThreadState;
298
299 // The first entry here is the `Ptr` which is what's used as part of the
300 // public interface of this module. The second entry is a boolean which
301 // allows the runtime to perform per-thread initialization if necessary
302 // for handling traps (e.g. setting up ports on macOS and sigaltstack on
303 // Unix).
304 thread_local!(static PTR: Cell<Ptr> = const { Cell::new(ptr::null()) });
305
306 #[inline(never)] // see module docs for why this is here
307 pub fn replace(val: Ptr) -> Result<Ptr, Trap> {
308 PTR.with(|p| {
309 // When a new value is configured that means that we may be
310 // entering WebAssembly so check to see if this thread has
311 // performed per-thread initialization for traps.
312 let prev = p.get();
313 p.set(val);
314 Ok(prev)
315 })
316 }
317
318 #[inline(never)] // see module docs for why this is here
319 pub fn get() -> Ptr {
320 PTR.with(|p| p.get())
321 }
322 }
323
324 /// Opaque state used to help control TLS state across stack switches for
325 /// async support.
326 pub struct TlsRestore(raw::Ptr);
327
328 impl TlsRestore {
329 /// Takes the TLS state that is currently configured and returns a
330 /// token that is used to replace it later.
331 ///
332 /// # Safety
333 ///
334 /// This is not a safe operation since it's intended to only be used
335 /// with stack switching found with fibers and async near_vm.
336 pub unsafe fn take() -> Result<Self, Trap> {
337 // Our tls pointer must be set at this time, and it must not be
338 // null. We need to restore the previous pointer since we're
339 // removing ourselves from the call-stack, and in the process we
340 // null out our own previous field for safety in case it's
341 // accidentally used later.
342 let raw = raw::get();
343 unsafe {
344 assert!(!raw.is_null());
345 let prev = (*raw).prev.replace(ptr::null());
346 raw::replace(prev)?;
347 }
348 Ok(Self(raw))
349 }
350
351 /// Restores a previous tls state back into this thread's TLS.
352 ///
353 /// # Safety
354 ///
355 /// This is unsafe because it's intended to only be used within the
356 /// context of stack switching within near_vm.
357 pub unsafe fn replace(self) -> Result<(), super::Trap> {
358 // We need to configure our previous TLS pointer to whatever is in
359 // TLS at this time, and then we set the current state to ourselves.
360 let prev = raw::get();
361 unsafe {
362 assert!((*self.0).prev.get().is_null());
363 (*self.0).prev.set(prev);
364 }
365 raw::replace(self.0)?;
366 Ok(())
367 }
368 }
369
370 /// Configures thread local state such that for the duration of the
371 /// execution of `closure` any call to `with` will yield `ptr`, unless this
372 /// is recursively called again.
373 pub fn set<R>(state: &CallThreadState, closure: impl FnOnce() -> R) -> Result<R, Trap> {
374 struct Reset<'a>(&'a CallThreadState);
375
376 impl Drop for Reset<'_> {
377 #[inline]
378 fn drop(&mut self) {
379 raw::replace(self.0.prev.replace(ptr::null()))
380 .expect("tls should be previously initialized");
381 }
382 }
383
384 // Note that this extension of the lifetime to `'static` should be
385 // safe because we only ever access it below with an anonymous
386 // lifetime, meaning `'static` never leaks out of this module.
387 let ptr = unsafe { mem::transmute::<*const CallThreadState, _>(state) };
388 let prev = raw::replace(ptr)?;
389 state.prev.set(prev);
390 let _reset = Reset(state);
391 Ok(closure())
392 }
393
394 /// Returns the last pointer configured with `set` above. Panics if `set`
395 /// has not been previously called and not returned.
396 pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
397 let p = raw::get();
398 unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
399 }
400}
401
402extern "C" fn signal_less_trap_handler(pc: *const u8, trap: TrapCode) {
403 let jmp_buf = tls::with(|info| {
404 let backtrace = Backtrace::new_unresolved();
405 let info = info.unwrap();
406 unsafe {
407 (*info.unwind.get()).as_mut_ptr().write(UnwindReason::WasmTrap {
408 backtrace,
409 signal_trap: Some(trap),
410 pc: pc as usize,
411 });
412 info.jmp_buf.get()
413 }
414 });
415 unsafe {
416 near_vm_unwind(jmp_buf);
417 }
418}
419
420/// Returns pointer to the trap handler used in VMContext.
421pub fn get_trap_handler() -> *const u8 {
422 signal_less_trap_handler as *const u8
423}