1use std::{
5 any::{Any, TypeId},
6 cell::{Cell, RefCell},
7 collections::HashMap,
8 ffi::CStr,
9 marker::PhantomData,
10 mem::ManuallyDrop,
11 ops::{Deref, DerefMut},
12 os::raw::c_void,
13 pin::Pin,
14 ptr::NonNull,
15 rc::Rc,
16 sync::{
17 atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18 Arc, Once,
19 },
20 thread::ThreadId,
21 time::{Duration, Instant},
22};
23
24use corosensei::{
25 stack::{DefaultStack, Stack},
26 Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
27};
28use futures::{Future, FutureExt};
29use memmap2::{MmapOptions, MmapRaw};
30use parking_lot::{Condvar, Mutex};
31use rand::prelude::SliceRandom;
32use rand::SeedableRng;
33use rand_chacha::ChaCha8Rng;
34
35use crate::{
36 error::{Error, RuntimeError},
37 helpers::Helper,
38 linker::link_elf,
39 pointer_cage::PointerCage,
40 util::nonnull_bytes_overlap,
41};
42
/// Size in bytes of the host-side native stack backing each guest coroutine.
const NATIVE_STACK_SIZE: usize = 16384;
/// Size in bytes of the guest shadow stack (and its host-side copy buffer).
const SHADOW_STACK_SIZE: usize = 4096;
/// Maximum number of calldata bytes copied onto the shadow stack per run.
const MAX_CALLDATA_SIZE: usize = 512;
/// Max outstanding mutable guest-memory derefs within one helper call.
const MAX_MUTABLE_DEREF_REGIONS: usize = 4;
/// Max outstanding immutable guest-memory derefs within one helper call.
const MAX_IMMUTABLE_DEREF_REGIONS: usize = 16;
48
/// Per-invocation scratch storage shared across helper calls, keyed by type.
pub struct InvokeScope {
    data: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl InvokeScope {
    /// Returns a mutable reference to the `T` stored for this invocation,
    /// default-constructing it on first access.
    pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
        let slot = self
            .data
            .entry(TypeId::of::<T>())
            .or_insert_with(|| Box::new(T::default()));
        // The entry for TypeId::of::<T>() only ever holds a boxed T, so the
        // downcast cannot fail unless the map was corrupted.
        slot.downcast_mut::<T>()
            .expect("InvokeScope::data_mut: downcast failed")
    }
}
66
/// Context handed to helper implementations while the guest is suspended.
/// Tracks which guest memory regions have been dereferenced during the
/// current helper call so overlapping reads/writes can be rejected.
pub struct HelperScope<'a, 'b> {
    /// The program currently being executed.
    pub program: &'a Program,
    /// Per-invocation scratch state (see `InvokeScope`).
    pub invoke: RefCell<&'a mut InvokeScope>,
    /// Caller-provided resources, looked up by type in `with_resource_mut`.
    resources: RefCell<&'a mut [&'b mut dyn Any]>,
    /// Guest regions handed out mutably by `user_memory_mut` this call.
    mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_MUTABLE_DEREF_REGIONS],
    /// Guest regions handed out immutably by `user_memory` this call.
    immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_IMMUTABLE_DEREF_REGIONS],
    /// Whether `post_task` is currently permitted (true only while a
    /// synchronous helper body is running).
    can_post_task: bool,
}
78
/// A mutable view into guest memory, tied to the `HelperScope` whose
/// overlap bookkeeping vouches for its exclusivity.
pub struct MutableUserMemory<'a, 'b, 'c> {
    _scope: &'c HelperScope<'a, 'b>,
    region: NonNull<[u8]>,
}

impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `region` was produced by the pointer cage's checked deref
        // and registered in the scope's overlap tracking, so no other live
        // borrow aliases it.
        unsafe { self.region.as_ref() }
    }
}

impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same as `deref`; the region was vetted for writing.
        unsafe { self.region.as_mut() }
    }
}
98
impl<'a, 'b> HelperScope<'a, 'b> {
    /// Queues an async task to run after the guest is suspended. On
    /// completion, the task's output closure is invoked with a fresh scope
    /// to produce the value the guest is resumed with.
    ///
    /// # Panics
    /// Panics if called outside a synchronous helper (`can_post_task` is
    /// false) or while another task is already pending.
    pub fn post_task(
        &self,
        task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
    ) {
        if !self.can_post_task {
            panic!("HelperScope::post_task() called in a context where posting task is not allowed");
        }

        PENDING_ASYNC_TASK.with(|x| {
            let mut x = x.borrow_mut();
            if x.is_some() {
                panic!("post_task called while another task is pending");
            }
            // Type-erase the future and its output so the run loop can hold
            // them without knowing the concrete types.
            *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
        });
    }

    /// Finds the first caller-provided resource of type `T` and passes it
    /// to `callback`; passes `Err(())` when no such resource exists.
    pub fn with_resource_mut<'c, T: 'static, R>(
        &'c self,
        callback: impl FnOnce(Result<&mut T, ()>) -> R,
    ) -> R {
        let mut resources = self.resources.borrow_mut();
        let Some(res) = resources
            .iter_mut()
            .filter_map(|x| x.downcast_mut::<T>())
            .next()
        else {
            tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
            return callback(Err(()));
        };

        callback(Ok(res))
    }

    /// Borrows `size` bytes of guest memory at guest address `ptr` for
    /// reading. Fails if the range is invalid, overlaps a region already
    /// borrowed mutably during this helper call, or the per-call budget of
    /// immutable regions is exhausted.
    pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_read(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid read");
            return Err(());
        };

        // Zero-length reads need no aliasing bookkeeping.
        if size != 0 {
            // An immutable borrow may not alias any prior mutable borrow.
            if self
                .mutable_dereferenced_regions
                .iter()
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "read overlapped with previous write");
                return Err(());
            }

            let Some(slot) = self
                .immutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many reads");
                return Err(());
            };
            slot.set(Some(region));
        }

        // SAFETY: the cage vetted the range and the overlap tracking above
        // guarantees no live mutable borrow aliases it.
        Ok(unsafe { region.as_ref() })
    }

    /// Borrows `size` bytes of guest memory at guest address `ptr` for
    /// writing. Fails if the range is invalid, overlaps *any* region
    /// already borrowed (read or write) during this helper call, or the
    /// per-call budget of mutable regions is exhausted.
    pub fn user_memory_mut<'c>(
        &'c self,
        ptr: u64,
        size: u64,
    ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_write(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid write");
            return Err(());
        };

        if size != 0 {
            // A mutable borrow must be exclusive against both prior reads
            // and prior writes.
            if self
                .mutable_dereferenced_regions
                .iter()
                .chain(self.immutable_dereferenced_regions.iter())
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "write overlapped with previous read/write");
                return Err(());
            }

            let Some(slot) = self
                .mutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many writes");
                return Err(());
            };
            slot.set(Some(region));
        }

        Ok(MutableUserMemory {
            _scope: self,
            region,
        })
    }
}
222
/// Wrapper asserting that the contained value may be treated as `Send`.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
// SAFETY: wrapped values in this file are only ever used from the thread
// that created them even when the enclosing future is moved.
// NOTE(review): this is an unchecked assumption — confirm at each use site.
unsafe impl<T> Send for AssumeSend<T> {}

/// Pooled per-execution state: the native coroutine stack plus a host-side
/// copy buffer for the guest shadow stack.
struct ExecContext {
    native_stack: DefaultStack,
    copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
231
232impl ExecContext {
233 fn new() -> Self {
234 Self {
235 native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
236 .expect("failed to initialize native stack"),
237 copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
238 }
239 }
240}
241
/// A future produced by `HelperScope::post_task`, awaited by the run loop
/// while the guest is suspended.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// Completion callback of an async task; produces the value the guest
/// program is resumed with (or fails the run).
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

/// Monotonic source for `UnboundProgram::id`.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);
248
/// State of the per-thread preemption watcher.
#[derive(Copy, Clone, Debug)]
enum PreemptionState {
    /// No `PreemptionEnabled` guard is alive; the watcher sleeps until notified.
    Inactive,
    /// `n` nested `PreemptionEnabled` guards are alive; the watcher sends
    /// SIGUSR1 whenever a full preemption interval elapses.
    Active(usize),
    /// The owning thread is going away; the watcher exits.
    Shutdown,
}

/// Mutex-guarded state plus the condvar used to wake the watcher thread.
type PreemptionStateSignal = (Mutex<PreemptionState>, Condvar);
257
thread_local! {
    // Cached id of the current OS thread as seen by Rust.
    static RUST_TID: ThreadId = std::thread::current().id();
    // Counts SIGUSR1 deliveries; compared across resumes to detect preemption.
    static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
    // Signal-handler-visible description of the currently running JIT code.
    static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
    // Reuse pool for native stacks / shadow-stack copy buffers.
    static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
    // Task queued by `HelperScope::post_task`, drained by the run loop.
    static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
    // Shared state between this thread and its preemption watcher thread.
    static PREEMPTION_STATE: Arc<PreemptionStateSignal> = Arc::new((Mutex::new(PreemptionState::Inactive), Condvar::new()));
}
266
/// RAII handle to an `ExecContext` taken from the thread-local pool;
/// returned to the pool on drop.
struct BorrowedExecContext {
    ctx: ManuallyDrop<ExecContext>,
}

impl BorrowedExecContext {
    /// Pops a context from the pool (creating one if the pool is empty) and
    /// fills the shadow-stack copy buffer with random bytes so stale data
    /// from a previous run never leaks into the next guest.
    fn new(stack_rng: &mut impl rand::Rng) -> Self {
        let mut me = Self {
            ctx: ManuallyDrop::new(
                EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
            ),
        };
        stack_rng.fill(&mut *me.ctx.copy_stack);
        me
    }
}

impl Drop for BorrowedExecContext {
    fn drop(&mut self) {
        // SAFETY: `ctx` is taken exactly once, here, and never used again.
        let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
    }
}
289
/// Description of the JIT code currently executing on this thread, read by
/// the SIGSEGV/SIGUSR1 handlers. `valid` gates the other fields: they are
/// written first, then `valid` is set after a release compiler fence, and
/// handlers issue an acquire fence after observing `valid`.
#[derive(Default)]
struct ActiveJitCodeZone {
    valid: AtomicBool,
    // Start/end host addresses of the JIT-compiled code.
    code_range: Cell<(usize, usize)>,
    // Host address range of the pointer cage; faults inside it are treated
    // as guest memory faults rather than host bugs.
    pointer_cage_protected_range: Cell<(usize, usize)>,
    // Yielder used to suspend the guest coroutine from signal handlers and
    // the helper dispatcher.
    yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
297
/// Instrumentation hooks invoked by the program run loop.
pub trait ProgramEventListener: Send + Sync + 'static {
    /// Guest was interrupted by SIGUSR1-driven async preemption.
    fn did_async_preempt(&self, _scope: &HelperScope) {}
    /// Run loop yielded to the timeslicer.
    fn did_yield(&self) {}
    /// Program exceeded its throttle budget; an optionally returned future
    /// is awaited after the throttle sleep.
    fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
        None
    }
    /// Shadow stack was copied out of the cage (before suspending).
    fn did_save_shadow_stack(&self) {}
    /// Shadow stack was copied back into the cage (before resuming).
    fn did_restore_shadow_stack(&self) {}
}

/// No-op listener for callers that don't need instrumentation.
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
317
/// Compiles program ELF images to native code with a randomized,
/// obfuscated helper dispatch table.
pub struct ProgramLoader {
    /// Helper name -> obfuscated dispatch id handed to the linker.
    helpers_inverse: HashMap<&'static str, i32>,
    event_listener: Arc<dyn ProgramEventListener>,
    /// Random value XORed into helper indices (see `std_validator`).
    helper_id_xor: u16,
    /// Shuffled helper table: (per-entry entropy tag, name, function).
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}
325
/// A loaded, JIT-compiled program that is not yet pinned to a thread.
pub struct UnboundProgram {
    id: u64,
    /// Keeps the JIT code mapping alive; `entrypoints` point into it.
    _code_mem: MmapRaw,
    /// Masked/offset memory region holding guest data and the shadow stack.
    cage: PointerCage,
    helper_id_xor: u16,
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
    event_listener: Arc<dyn ProgramEventListener>,
    /// Section name -> native code location for each compiled section.
    entrypoints: HashMap<String, Entrypoint>,
    /// Seed for the per-program RNG that randomizes stack buffers.
    stack_rng_seed: [u8; 32],
}
337
/// A program pinned to the current thread and ready to run.
pub struct Program {
    unbound: UnboundProgram,
    /// Lazily created per-program data, keyed by type (see `Program::data`).
    data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
    /// RNG used to randomize shadow-stack copy buffers between runs.
    stack_rng: RefCell<ChaCha8Rng>,
    /// Thread token; `ThreadEnv` is `!Send + !Sync`, pinning this struct.
    t: ThreadEnv,
}

/// Location of one compiled section inside the JIT code mapping.
#[derive(Copy, Clone)]
struct Entrypoint {
    code_ptr: usize,
    code_len: usize,
}
351
/// Budgets controlling when a running program yields or gets throttled.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
    /// Continuous run time after which the loop yields to the timeslicer.
    pub max_run_time_before_yield: Duration,
    /// Continuous run time after which the loop sleeps to throttle.
    pub max_run_time_before_throttle: Duration,
    /// Length of each throttle sleep.
    pub throttle_duration: Duration,
}

/// Abstraction over the host async runtime's sleep/yield primitives.
pub trait Timeslicer {
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
    fn yield_now(&self) -> impl Future<Output = ()>;
}
370
/// Token proving the process-wide signal handlers have been installed.
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Token proving the current thread has a preemption watcher; made
/// `!Send + !Sync` via the raw-pointer phantom so it cannot leave its thread.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
    _not_send_sync: std::marker::PhantomData<*const ()>,
}
380
impl GlobalEnv {
    /// Installs the process-wide SIGUSR1/SIGSEGV handlers (exactly once)
    /// and returns the global token.
    ///
    /// # Safety
    /// Replaces the process signal dispositions; the caller must ensure
    /// nothing else in the process depends on the previous handlers.
    pub unsafe fn new() -> Self {
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Block both signals while either handler runs so they cannot
            // re-enter each other.
            let sa_mask = get_blocked_sigset();

            for (sig, handler) in [
                (libc::SIGUSR1, sigusr1_handler as usize),
                (libc::SIGSEGV, sigsegv_handler as usize),
            ] {
                let act = libc::sigaction {
                    sa_sigaction: handler,
                    sa_flags: libc::SA_SIGINFO,
                    sa_mask,
                    sa_restorer: None,
                };
                if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
                    panic!("failed to setup handler for signal {}", sig);
                }
            }
        });

        Self(())
    }

    /// Starts (once per thread) a watcher thread that sends SIGUSR1 to the
    /// calling thread whenever preemption stays active for a full interval.
    pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
        // Moves the watcher to Shutdown when the owning thread exits.
        struct DeferDrop(Arc<PreemptionStateSignal>);
        impl Drop for DeferDrop {
            fn drop(&mut self) {
                let x = &self.0;
                *x.0.lock() = PreemptionState::Shutdown;
                x.1.notify_one();
            }
        }

        thread_local! {
            static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
        }

        // Already initialized on this thread: nothing to do.
        if WATCHER.with(|x| x.borrow().is_some()) {
            return ThreadEnv {
                _not_send_sync: PhantomData,
            };
        }

        let preemption_state = PREEMPTION_STATE.with(|x| x.clone());

        unsafe {
            // Capture the *calling* thread's ids before spawning; the
            // watcher signals this thread, not itself.
            let tgid = libc::getpid();
            let tid = libc::gettid();

            std::thread::Builder::new()
                .name("preempt-watcher".to_string())
                .spawn(move || {
                    let mut state = preemption_state.0.lock();
                    loop {
                        match *state {
                            PreemptionState::Shutdown => break,
                            PreemptionState::Inactive => {
                                // Sleep until a PreemptionEnabled guard appears.
                                preemption_state.1.wait(&mut state);
                            }
                            PreemptionState::Active(_) => {
                                // Wait out one interval; only preempt if it
                                // fully elapses while still Active.
                                let timeout = preemption_state.1.wait_while_for(
                                    &mut state,
                                    |x| matches!(x, PreemptionState::Active(_)),
                                    async_preemption_interval,
                                );
                                if timeout.timed_out() {
                                    let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                                    if ret != 0 {
                                        // Target thread is gone; stop watching.
                                        break;
                                    }
                                }
                            }
                        }
                    }
                })
                .expect("failed to spawn preemption watcher");

            WATCHER.with(|x| {
                x.borrow_mut()
                    .replace(DeferDrop(PREEMPTION_STATE.with(|x| x.clone())));
            });

            ThreadEnv {
                _not_send_sync: PhantomData,
            }
        }
    }
}
491
492impl UnboundProgram {
493 pub fn pin_to_current_thread(self, t: ThreadEnv) -> Program {
495 let stack_rng = RefCell::new(ChaCha8Rng::from_seed(self.stack_rng_seed));
496 Program {
497 unbound: self,
498 data: RefCell::new(HashMap::new()),
499 stack_rng,
500 t,
501 }
502 }
503}
504
/// RAII guard that keeps the preemption watcher active; nested guards are
/// reference-counted through `PreemptionState::Active(n)`.
pub struct PreemptionEnabled(());

impl PreemptionEnabled {
    pub fn new(_: ThreadEnv) -> Self {
        PREEMPTION_STATE.with(|x| {
            let mut notify = false;
            {
                let mut st = x.0.lock();
                let next = match *st {
                    PreemptionState::Inactive => {
                        // First guard: the watcher must be woken.
                        notify = true;
                        PreemptionState::Active(1)
                    }
                    PreemptionState::Active(n) => PreemptionState::Active(n + 1),
                    PreemptionState::Shutdown => unreachable!(),
                };
                *st = next;
            }

            // Notify after releasing the lock so the watcher doesn't wake
            // straight into a held mutex.
            if notify {
                x.1.notify_one();
            }
        });
        Self(())
    }
}

impl Drop for PreemptionEnabled {
    fn drop(&mut self) {
        PREEMPTION_STATE.with(|x| {
            let mut st = x.0.lock();
            let next = match *st {
                PreemptionState::Active(1) => PreemptionState::Inactive,
                PreemptionState::Active(n) => {
                    assert!(n > 1);
                    PreemptionState::Active(n - 1)
                }
                // A guard can only exist while Active; anything else is a bug.
                PreemptionState::Inactive | PreemptionState::Shutdown => unreachable!(),
            };
            *st = next;
        });
    }
}
548
impl Program {
    /// Globally unique id assigned at load time.
    pub fn id(&self) -> u64 {
        self.unbound.id
    }

    /// The thread token this program is pinned to.
    pub fn thread_env(&self) -> ThreadEnv {
        self.t
    }

    /// Lazily created per-program shared data, keyed by type.
    pub fn data<T: Default + 'static>(&self) -> Rc<T> {
        let mut data = self.data.borrow_mut();
        let entry = data.entry(TypeId::of::<T>());
        let entry = entry.or_insert_with(|| Rc::new(T::default()));
        entry.clone().downcast().unwrap()
    }

    /// Whether a compiled entrypoint exists for section `name`.
    pub fn has_section(&self, name: &str) -> bool {
        self.unbound.entrypoints.contains_key(name)
    }

    /// Runs `entrypoint` with `calldata` to completion, servicing helper
    /// calls, posted async tasks, preemption and timeslicing along the way.
    pub async fn run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
        preemption: &PreemptionEnabled,
    ) -> Result<i64, Error> {
        self
            ._run(
                timeslice, timeslicer, entrypoint, resources, calldata, preemption,
            )
            .await
            .map_err(Error)
    }

    async fn _run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
        _: &PreemptionEnabled,
    ) -> Result<i64, RuntimeError> {
        let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
            return Err(RuntimeError::InvalidArgument("entrypoint not found"));
        };

        // SAFETY: `code_ptr` points at JIT-compiled code with exactly this
        // C ABI: (ctx, shadow_stack) -> u64.
        let entry = unsafe {
            std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
                entrypoint.code_ptr,
            )
        };
        // Force-resets the coroutine on drop so a still-suspended guest is
        // discarded without unwinding through JIT frames.
        struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
            ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
        );
        impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
            for CoDropper<'a, Input, Yield, Return, DefaultStack>
        {
            fn drop(&mut self) {
                // NOTE(review): force_reset skips destructors of frames on
                // the coroutine stack — presumably nothing there owns host
                // resources; confirm against corosensei's contract.
                unsafe {
                    self.0.force_reset();
                }
            }
        }

        let mut ectx = BorrowedExecContext::new(&mut *self.stack_rng.borrow_mut());

        if calldata.len() > MAX_CALLDATA_SIZE {
            return Err(RuntimeError::InvalidArgument("calldata too large"));
        }
        // Calldata occupies the top of the shadow-stack copy buffer.
        ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
        let calldata_len = calldata.len();

        let program_ret: u64 = {
            let shadow_stack_top = self.unbound.cage.stack_top();
            let shadow_stack_ptr = AssumeSend(
                self
                    .unbound
                    .cage
                    .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
                    .unwrap(),
            );
            let ctx = &mut *ectx.ctx;

            let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
                &mut ctx.native_stack,
                move |yielder, _input| unsafe {
                    // Publish the yielder so the dispatcher and signal
                    // handlers can suspend from inside JIT code.
                    ACTIVE_JIT_CODE_ZONE.with(|x| {
                        x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
                    });
                    // Both arguments start just below the calldata placed
                    // at the top of the shadow stack.
                    entry(
                        shadow_stack_top - calldata_len,
                        shadow_stack_top - calldata_len,
                    )
                },
            )));

            let mut last_yield_time = Instant::now();
            let mut last_throttle_time = Instant::now();
            let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
            let mut resume_input: u64 = 0;
            let mut did_throttle = false;
            // Starts true so the first iteration copies the calldata-
            // initialized buffer into the cage.
            let mut shadow_stack_saved = true;
            let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
            let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
            let mut invoke_scope = InvokeScope {
                data: HashMap::new(),
            };

            loop {
                // Publish the active JIT zone before resuming; `valid` is
                // stored last, after a release fence (see ActiveJitCodeZone).
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.code_range.set((
                        entrypoint.code_ptr,
                        entrypoint.code_ptr + entrypoint.code_len,
                    ));
                    x.yielder.set(yielder.map(|x| x.0));
                    x.pointer_cage_protected_range
                        .set(self.unbound.cage.protected_range_without_margins());
                    compiler_fence(Ordering::Release);
                    x.valid.store(true, Ordering::Relaxed);
                });

                if shadow_stack_saved {
                    shadow_stack_saved = false;

                    // Restore the guest shadow stack into the cage.
                    // SAFETY: both buffers are exactly SHADOW_STACK_SIZE
                    // bytes and cannot overlap (host heap vs. cage mapping).
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            ctx.copy_stack.as_ptr() as *const u8,
                            shadow_stack_ptr.0.as_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }

                    self.unbound.event_listener.did_restore_shadow_stack();
                }

                // Deliver the completion of the previous async task, if any.
                if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
                    resume_input = prev_async_task_output(&HelperScope {
                        program: self,
                        invoke: RefCell::new(&mut invoke_scope),
                        resources: RefCell::new(resources),
                        // All-zero bytes read back as `None`; relies on the
                        // null-pointer niche of Option<NonNull<[u8]>> —
                        // NOTE(review): confirm this is guaranteed for fat
                        // pointers.
                        mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        can_post_task: false,
                    })
                    .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
                }

                let ret = co.0 .0.resume(resume_input);
                // Invalidate the zone and capture the yielder the guest may
                // have published on first entry.
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.valid.store(false, Ordering::Relaxed);
                    compiler_fence(Ordering::Release);
                    yielder = x.yielder.get().map(AssumeSend);
                });

                let dispatch: Dispatch = match ret {
                    CoroutineResult::Return(x) => break x,
                    CoroutineResult::Yield(x) => x,
                };

                // A suspend that originated inside a signal handler left its
                // signal blocked (sa_mask); unblock before doing real work.
                if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
                    unsafe {
                        let unblock = get_blocked_sigset();
                        libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
                    }
                }

                if let Some(si_addr) = dispatch.memory_access_error {
                    // Translate the host fault address back to a guest vaddr.
                    let vaddr = si_addr - self.unbound.cage.offset();
                    return Err(RuntimeError::MemoryFault(vaddr));
                }

                // Discard any task left behind by a previous iteration.
                PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
                let mut helper_name: &'static str = "";

                let mut helper_scope = HelperScope {
                    program: self,
                    invoke: RefCell::new(&mut invoke_scope),
                    resources: RefCell::new(resources),
                    mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    can_post_task: false,
                };

                if dispatch.async_preemption {
                    self
                        .unbound
                        .event_listener
                        .did_async_preempt(&mut helper_scope);
                } else {
                    // Undo the load-time obfuscation: low 16 bits XORed with
                    // helper_id_xor, 1-based index into the helper table.
                    let Some((_, got_helper_name, helper)) = self
                        .unbound
                        .helpers
                        .get(
                            ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                                as usize,
                        )
                        .copied()
                    else {
                        panic!("unknown helper index: {}", dispatch.index);
                    };
                    helper_name = got_helper_name;

                    // post_task is only legal while the helper itself runs.
                    helper_scope.can_post_task = true;
                    resume_input = helper(
                        &mut helper_scope,
                        dispatch.arg1,
                        dispatch.arg2,
                        dispatch.arg3,
                        dispatch.arg4,
                        dispatch.arg5,
                    )
                    .map_err(|()| RuntimeError::HelperError(helper_name))?;
                    helper_scope.can_post_task = false;
                }

                let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

                // Fast path: same thread, no SIGUSR1 since last check, and
                // no async task — resume immediately, skipping the clocks.
                let new_rust_tid_sigusr1_counter =
                    (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
                if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
                {
                    continue;
                }

                rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

                let now = Instant::now();
                let should_throttle = now > last_throttle_time
                    && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
                let should_yield = now > last_yield_time
                    && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
                if should_throttle || should_yield || pending_async_task.is_some() {
                    // About to suspend across an await: snapshot the shadow
                    // stack out of the cage first.
                    shadow_stack_saved = true;
                    // SAFETY: same non-overlapping, fixed-size buffers as
                    // the restore copy above.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            shadow_stack_ptr.0.as_ptr() as *const u8,
                            ctx.copy_stack.as_mut_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }
                    self.unbound.event_listener.did_save_shadow_stack();

                    if should_throttle {
                        if !did_throttle {
                            did_throttle = true;
                            tracing::warn!("throttling program");
                        }
                        timeslicer.sleep(timeslice.throttle_duration).await;
                        let now = Instant::now();
                        last_throttle_time = now;
                        last_yield_time = now;
                        let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
                        if let Some(task) = task {
                            task.await;
                        }
                    } else if should_yield {
                        timeslicer.yield_now().await;
                        let now = Instant::now();
                        last_yield_time = now;
                        self.unbound.event_listener.did_yield();
                    }

                    if let Some(pending_async_task) = pending_async_task {
                        // Time spent awaiting the task is credited back so
                        // it doesn't count against the program's budgets.
                        let async_start = Instant::now();
                        prev_async_task_output = Some((helper_name, pending_async_task.await));
                        let async_dur = async_start.elapsed();
                        last_throttle_time += async_dur;
                        last_yield_time += async_dur;
                    }
                }
            }
        };

        Ok(program_ret as i64)
    }
}
849
/// Owned handle to a ubpf VM, destroyed on drop.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);

impl Vm {
    /// Creates a VM configured for pointer-cage execution: bounds checks
    /// off (the cage + SIGSEGV trap replace them), JIT shadow stack on, and
    /// the cage's pointer mask/offset applied to guest addresses.
    fn new(cage: &PointerCage) -> Self {
        let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
        // SAFETY: `vm` is a valid, freshly created VM handle.
        unsafe {
            crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
            crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
            crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
        }
        Self(vm)
    }
}

impl Drop for Vm {
    fn drop(&mut self) {
        // SAFETY: the pointer came from ubpf_create and is destroyed once.
        unsafe {
            crate::ubpf::ubpf_destroy(self.0.as_ptr());
        }
    }
}
871
impl ProgramLoader {
    /// Builds a loader with a freshly randomized, shuffled helper table.
    pub fn new(
        rng: &mut impl rand::Rng,
        event_listener: Arc<dyn ProgramEventListener>,
        raw_helpers: &[&[(&'static str, Helper)]],
    ) -> Self {
        let helper_id_xor = rng.gen::<u16>();
        let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
        // Dedup by name via the intermediate HashMap, then shuffle so the
        // table order differs per loader instance.
        let mut shuffled_helpers = raw_helpers
            .iter()
            .flat_map(|x| x.iter().copied())
            .collect::<HashMap<_, _>>()
            .into_iter()
            .collect::<Vec<_>>();
        shuffled_helpers.shuffle(rng);
        let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

        // Dispatch indices are 1-based and must fit in 16 bits.
        assert!(shuffled_helpers.len() <= 65535);

        for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
            // 15 bits of per-entry entropy live in the high half of the
            // dispatch id; std_validator checks them at load time.
            let entropy = rng.gen::<u16>() & 0x7fff;
            helpers.push((entropy, name, helper));
            helpers_inverse.insert(
                name,
                (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
            );
        }

        tracing::info!(?helpers_inverse, "generated helper table");
        Self {
            helper_id_xor,
            helpers: Arc::new(helpers),
            helpers_inverse,
            event_listener,
        }
    }

    /// Loads, links and JIT-compiles an ELF image into an `UnboundProgram`.
    pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
        self._load(rng, elf).map_err(Error)
    }

    fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
        let start_time = Instant::now();
        let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
        let vm = Vm::new(&cage);

        // Copy the ELF into the cage and link it against the helper table.
        let code_sections = {
            // NOTE(review): takes the *read* deref and then mutates through
            // it — presumably acceptable before freeze_data(); confirm
            // against PointerCage's contract.
            let mut data = cage
                .safe_deref_for_read(cage.data_bottom(), elf.len())
                .unwrap();
            let data = unsafe { data.as_mut() };
            data.copy_from_slice(elf);

            link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
        };
        cage.freeze_data();

        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        if page_size < 0 {
            return Err(RuntimeError::PlatformError("failed to get page size"));
        }
        let page_size = page_size as usize;

        // Randomly sized PROT_NONE guard regions flank the code mapping.
        let guard_size_before = rng.gen_range(16..128) * page_size;
        let mut guard_size_after = rng.gen_range(16..128) * page_size;

        let code_len_allocated: usize = 65536;
        let code_mem = MmapRaw::from(
            MmapOptions::new()
                .len(code_len_allocated + guard_size_before + guard_size_after)
                .map_anon()
                .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
        );

        unsafe {
            if crate::ubpf::ubpf_register_external_dispatcher(
                vm.0.as_ptr(),
                Some(tls_dispatcher),
                Some(std_validator),
                self as *const _ as *mut c_void,
            ) != 0
            {
                return Err(RuntimeError::PlatformError(
                    "ubpf: failed to register external dispatcher",
                ));
            }
            if libc::mprotect(
                code_mem.as_mut_ptr() as *mut _,
                guard_size_before,
                libc::PROT_NONE,
            ) != 0
                || libc::mprotect(
                    code_mem
                        .as_mut_ptr()
                        .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
                    guard_size_after,
                    libc::PROT_NONE,
                ) != 0
            {
                return Err(RuntimeError::PlatformError("failed to protect guard pages"));
            }
        }

        let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

        unsafe {
            // Translate each section in turn, packing the emitted native
            // code back-to-back into code_slice.
            let mut code_slice = std::slice::from_raw_parts_mut(
                code_mem.as_mut_ptr().offset(guard_size_before as isize),
                code_len_allocated,
            );
            for (section_name, code_vaddr_size) in code_sections {
                if code_slice.is_empty() {
                    return Err(RuntimeError::InvalidArgument(
                        "no space left for jit compilation",
                    ));
                }

                crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

                let mut errmsg_ptr = std::ptr::null_mut();
                let code = cage
                    .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
                    .unwrap();
                let ret = crate::ubpf::ubpf_load(
                    vm.0.as_ptr(),
                    code.as_ptr() as *const _,
                    code.len() as u32,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    // The error message is malloc'd by ubpf; free it.
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to load code");
                    return Err(RuntimeError::PlatformError("ubpf: code load failed"));
                }

                let mut written_len = code_slice.len();
                let ret = crate::ubpf::ubpf_translate(
                    vm.0.as_ptr(),
                    code_slice.as_mut_ptr(),
                    &mut written_len,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to translate code");
                    return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
                }

                assert!(written_len <= code_slice.len());
                entrypoints.insert(
                    section_name,
                    Entrypoint {
                        // Start of this section = end of the allocation
                        // minus what was still free before this iteration.
                        code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
                            - code_slice.len(),
                        code_len: written_len,
                    },
                );
                code_slice = &mut code_slice[written_len..];
            }

            // Round the used length up to a page; everything beyond joins
            // the trailing guard region.
            let unpadded_code_len = code_len_allocated - code_slice.len();
            let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
            assert!(code_len <= code_len_allocated);

            // Code becomes read+exec; the unused tail becomes PROT_NONE.
            if libc::mprotect(
                code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
                code_len,
                libc::PROT_READ | libc::PROT_EXEC,
            ) != 0
                || (code_len < code_len_allocated
                    && libc::mprotect(
                        code_mem
                            .as_mut_ptr()
                            .offset((guard_size_before + code_len) as isize) as *mut _,
                        code_len_allocated - code_len,
                        libc::PROT_NONE,
                    ) != 0)
            {
                return Err(RuntimeError::PlatformError("failed to protect code memory"));
            }

            guard_size_after += code_len_allocated - code_len;

            tracing::info!(
                elf_size = elf.len(),
                native_code_addr = ?code_mem.as_ptr(),
                native_code_size = code_len,
                native_code_size_unpadded = unpadded_code_len,
                guard_size_before,
                guard_size_after,
                duration = ?start_time.elapsed(),
                cage_ptr = ?cage.region().as_ptr(),
                cage_mapped_size = cage.region().len(),
                "jit compiled program"
            );

            Ok(UnboundProgram {
                id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
                _code_mem: code_mem,
                cage,
                helper_id_xor: self.helper_id_xor,
                helpers: self.helpers.clone(),
                event_listener: self.event_listener.clone(),
                entrypoints,
                stack_rng_seed: rng.gen(),
            })
        }
    }
}
1107
/// Payload yielded from the guest coroutine to the run loop.
#[derive(Default)]
struct Dispatch {
    /// Set when the yield came from the SIGUSR1 preemption handler.
    async_preemption: bool,
    /// Host fault address when the yield came from the SIGSEGV handler.
    memory_access_error: Option<usize>,

    /// Obfuscated helper index (entropy tag in the high 16 bits).
    index: u32,
    arg1: u64,
    arg2: u64,
    arg3: u64,
    arg4: u64,
    arg5: u64,
}
1120
1121unsafe extern "C" fn tls_dispatcher(
1122 arg1: u64,
1123 arg2: u64,
1124 arg3: u64,
1125 arg4: u64,
1126 arg5: u64,
1127 index: std::os::raw::c_uint,
1128 _cookie: *mut std::os::raw::c_void,
1129) -> u64 {
1130 let yielder = ACTIVE_JIT_CODE_ZONE
1131 .with(|x| x.yielder.get())
1132 .expect("no yielder");
1133 let yielder = yielder.as_ref();
1134 let ret = yielder.suspend(Dispatch {
1135 async_preemption: false,
1136 memory_access_error: None,
1137 index,
1138 arg1,
1139 arg2,
1140 arg3,
1141 arg4,
1142 arg5,
1143 });
1144 ret
1145}
1146
1147unsafe extern "C" fn std_validator(
1148 index: std::os::raw::c_uint,
1149 loader: *mut std::os::raw::c_void,
1150) -> bool {
1151 let loader = &*(loader as *const ProgramLoader);
1152 let entropy = (index >> 16) & 0xffff;
1153 let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1154 loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1155}
1156
/// Extracts the interrupted instruction pointer from a signal ucontext.
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}

/// Extracts the interrupted instruction pointer from a signal ucontext.
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.pc as usize
}
1166
/// SIGSEGV handler: when the fault happened inside currently-active JIT
/// code and touched the pointer cage, suspend the guest coroutine with a
/// memory-access error; otherwise restore the default handler so the
/// re-raised fault crashes the process normally.
unsafe extern "C" fn sigsegv_handler(
    _sig: i32,
    siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    // Unrecognized faults fall through to the default disposition instead
    // of looping back into this handler forever.
    let fail = || restore_default_signal_handler(libc::SIGSEGV);

    let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            // Pairs with the release fence written before `valid` was set.
            compiler_fence(Ordering::Acquire);
            Some((
                x.code_range.get(),
                x.pointer_cage_protected_range.get(),
                x.yielder.get(),
            ))
        } else {
            None
        }
    }) else {
        return fail();
    };

    let pc = program_counter(uctx);

    // The fault must originate from inside the JIT-compiled code.
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return fail();
    }

    // 2 == SEGV_ACCERR (access permission fault) — the code expected for a
    // touch of the PROT_NONE cage pages; anything else is not ours.
    if (*siginfo).si_code != 2 {
        return fail();
    }

    // The fault address must fall inside the cage's protected range.
    let si_addr = (*siginfo).si_addr() as usize;
    if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
        return fail();
    }

    // Suspend the guest from inside the handler; the run loop converts this
    // into RuntimeError::MemoryFault and unblocks SIGSEGV again.
    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        memory_access_error: Some(si_addr),
        ..Default::default()
    });
}
1211
/// SIGUSR1 handler: bumps the per-thread counter and, when the signal
/// interrupted active JIT code, suspends the guest with `async_preemption`
/// set. Outside JIT code the signal is a no-op beyond the counter bump.
unsafe extern "C" fn sigusr1_handler(
    _sig: i32,
    _siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    // The run loop compares this counter across resumes to detect that a
    // preemption signal arrived.
    SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

    let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            // Pairs with the release fence written before `valid` was set.
            compiler_fence(Ordering::Acquire);
            Some((x.code_range.get(), x.yielder.get()))
        } else {
            None
        }
    }) else {
        return;
    };
    // Only preempt when the signal landed inside the JIT-compiled code.
    let pc = program_counter(uctx);
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return;
    }

    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        async_preemption: true,
        ..Default::default()
    });
}
1240
/// Re-installs the default disposition for `signum`, aborting on failure.
/// Called from inside signal handlers, so it uses only async-signal-safe
/// functions (sigaction, abort).
unsafe fn restore_default_signal_handler(signum: i32) {
    let act = libc::sigaction {
        sa_sigaction: libc::SIG_DFL,
        sa_flags: libc::SA_SIGINFO,
        sa_mask: std::mem::zeroed(),
        sa_restorer: None,
    };
    if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
        libc::abort();
    }
}
1252
1253fn get_blocked_sigset() -> libc::sigset_t {
1254 unsafe {
1255 let mut s: libc::sigset_t = std::mem::zeroed();
1256 libc::sigaddset(&mut s, libc::SIGUSR1);
1257 libc::sigaddset(&mut s, libc::SIGSEGV);
1258 s
1259 }
1260}