1use std::{
5 any::{Any, TypeId},
6 cell::{Cell, RefCell},
7 collections::HashMap,
8 ffi::CStr,
9 marker::PhantomData,
10 mem::ManuallyDrop,
11 ops::{Deref, DerefMut},
12 os::raw::c_void,
13 pin::Pin,
14 ptr::NonNull,
15 rc::Rc,
16 sync::{
17 atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18 Arc, Once,
19 },
20 thread::ThreadId,
21 time::{Duration, Instant},
22};
23
24use corosensei::{
25 stack::{DefaultStack, Stack},
26 Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
27};
28use futures::{Future, FutureExt};
29use memmap2::{MmapOptions, MmapRaw};
30use parking_lot::{Condvar, Mutex};
31use rand::prelude::SliceRandom;
32use rand::SeedableRng;
33use rand_chacha::ChaCha8Rng;
34
35use crate::{
36 error::{Error, RuntimeError},
37 helpers::Helper,
38 linker::link_elf,
39 pointer_cage::PointerCage,
40 util::nonnull_bytes_overlap,
41};
42
/// Size of the host-side coroutine stack each guest program runs on.
const NATIVE_STACK_SIZE: usize = 16384;
/// Size of the guest shadow stack that is copied in/out of the pointer cage.
const SHADOW_STACK_SIZE: usize = 4096;
/// Upper bound on calldata copied onto the top of the shadow stack.
const MAX_CALLDATA_SIZE: usize = 512;
/// Maximum number of simultaneously outstanding user-memory dereferences
/// (tracked per helper call for read/write overlap detection).
const MAX_DEREF_REGIONS: usize = 4;
47
/// Type-keyed scratch storage that lives for the duration of a single
/// `Program::_run` invocation and is shared across helper calls within it.
pub struct InvokeScope {
    // One slot per concrete type, lazily populated via `data_mut`.
    data: HashMap<TypeId, Box<dyn Any + Send>>,
}
52
53impl InvokeScope {
54 pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
56 let ty = TypeId::of::<T>();
57 self
58 .data
59 .entry(ty)
60 .or_insert_with(|| Box::new(T::default()))
61 .downcast_mut()
62 .expect("InvokeScope::data_mut: downcast failed")
63 }
64}
65
/// Context handed to helper implementations for a single helper call.
/// Tracks every user-memory region dereferenced during the call so that
/// overlapping read/write aliasing can be rejected.
pub struct HelperScope<'a, 'b> {
    /// The program that issued the helper call.
    pub program: &'a Program,
    /// Per-invocation scratch storage (see `InvokeScope`).
    pub invoke: RefCell<&'a mut InvokeScope>,
    // Caller-provided resources, looked up by concrete type.
    resources: RefCell<&'a mut [&'b mut dyn Any]>,
    // Regions handed out mutably during this call; bounded by MAX_DEREF_REGIONS.
    mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
    // Regions handed out immutably during this call.
    immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
    // Whether `post_task` is permitted in this context (only during a
    // synchronous helper dispatch).
    can_post_task: bool,
}
77
/// A mutable view into guest memory, valid only while the `HelperScope`
/// it was obtained from is alive (enforced by the `'c` borrow).
pub struct MutableUserMemory<'a, 'b, 'c> {
    // Held only to tie this view's lifetime to the scope's aliasing tracking.
    _scope: &'c HelperScope<'a, 'b>,
    region: NonNull<[u8]>,
}
83
impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `region` was produced by the cage's safe_deref_for_write and
        // overlap with other outstanding regions was rejected at creation time;
        // the `_scope` borrow keeps the tracking alive for this view's lifetime.
        unsafe { self.region.as_ref() }
    }
}
91
impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same reasoning as `Deref`; exclusivity is guaranteed by the
        // scope's overlap checks when the region was handed out.
        unsafe { self.region.as_mut() }
    }
}
97
impl<'a, 'b> HelperScope<'a, 'b> {
    /// Queues an async task to run once control returns to the host loop.
    /// The future's output is a closure that is later invoked with a fresh
    /// `HelperScope`; its `Ok(u64)` becomes the value the guest resumes with.
    ///
    /// # Panics
    /// Panics if posting is not allowed in this context (`can_post_task` is
    /// false) or if another task is already pending for this dispatch.
    pub fn post_task(
        &self,
        task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
    ) {
        if !self.can_post_task {
            panic!("HelperScope::post_task() called in a context where posting task is not allowed");
        }

        PENDING_ASYNC_TASK.with(|x| {
            let mut x = x.borrow_mut();
            if x.is_some() {
                panic!("post_task called while another task is pending");
            }
            // Erase the concrete future/closure types into the boxed task form
            // consumed by the run loop.
            *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
        });
    }

    /// Finds the first caller-provided resource of type `T` and passes it to
    /// `callback`; passes `Err(())` when no such resource exists.
    pub fn with_resource_mut<'c, T: 'static, R>(
        &'c self,
        callback: impl FnOnce(Result<&mut T, ()>) -> R,
    ) -> R {
        let mut resources = self.resources.borrow_mut();
        let Some(res) = resources
            .iter_mut()
            .filter_map(|x| x.downcast_mut::<T>())
            .next()
        else {
            tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
            return callback(Err(()));
        };

        callback(Ok(res))
    }

    /// Dereferences guest memory for reading. Fails if the range is invalid,
    /// overlaps a region already handed out for writing, or if too many
    /// immutable regions are already outstanding for this helper call.
    pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_read(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid read");
            return Err(());
        };

        // Zero-length reads cannot alias anything, so they skip the tracking.
        if size != 0 {
            // A read must not alias any region previously handed out mutably.
            if self
                .mutable_dereferenced_regions
                .iter()
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "read overlapped with previous write");
                return Err(());
            }

            // Record the region in the first free immutable slot.
            let Some(slot) = self
                .immutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many reads");
                return Err(());
            };
            slot.set(Some(region));
        }

        // SAFETY: region comes from the cage's checked deref and was just
        // verified not to alias any outstanding mutable region.
        Ok(unsafe { region.as_ref() })
    }

    /// Dereferences guest memory for writing. Fails if the range is invalid,
    /// overlaps any previously dereferenced region (read or write), or if too
    /// many mutable regions are already outstanding for this helper call.
    pub fn user_memory_mut<'c>(
        &'c self,
        ptr: u64,
        size: u64,
    ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_write(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid write");
            return Err(());
        };

        if size != 0 {
            // A write must not alias any outstanding region of either kind.
            if self
                .mutable_dereferenced_regions
                .iter()
                .chain(self.immutable_dereferenced_regions.iter())
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "write overlapped with previous read/write");
                return Err(());
            }

            let Some(slot) = self
                .mutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many writes");
                return Err(());
            };
            slot.set(Some(region));
        }

        Ok(MutableUserMemory {
            _scope: self,
            region,
        })
    }
}
221
/// Wrapper asserting that a value is safe to treat as `Send`.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
// SAFETY(review): blanket assertion for any T — soundness relies on each use
// site never actually moving the value across threads (the run loop pins the
// program to one thread via ThreadEnv). Confirm each wrapped value's usage.
unsafe impl<T> Send for AssumeSend<T> {}
225
/// Reusable per-execution state: the coroutine's native stack plus a host-side
/// copy buffer for the guest shadow stack. Pooled in EXEC_CONTEXT_POOL.
struct ExecContext {
    native_stack: DefaultStack,
    copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
230
231impl ExecContext {
232 fn new() -> Self {
233 Self {
234 native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
235 .expect("failed to initialize native stack"),
236 copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
237 }
238 }
239}
240
/// A task posted by a helper, awaited by the run loop between resumes.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// The continuation a pending task resolves to; produces the guest's resume value.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

// Monotonic source of `UnboundProgram::id` values.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);

/// State machine shared between program threads and their preempt-watcher.
#[derive(Copy, Clone, Debug)]
enum PreemptionState {
    /// No program is running; the watcher sleeps on the condvar.
    Inactive,
    /// N nested `PreemptionEnabled` guards are live; watcher fires SIGUSR1 periodically.
    Active(usize),
    /// Owning thread is gone; the watcher exits.
    Shutdown,
}

// Mutex-protected state plus the condvar the watcher waits on.
type PreemptionStateSignal = (Mutex<PreemptionState>, Condvar);
256
thread_local! {
    // Cached thread id, used to detect coroutine migration across threads.
    static RUST_TID: ThreadId = std::thread::current().id();
    // Count of SIGUSR1 deliveries on this thread; compared across resumes to
    // detect that an async preemption actually happened.
    static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
    // Signal-handler-visible description of the currently executing JIT code.
    static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
    // Pool of reusable ExecContexts to avoid stack reallocation per run.
    static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
    // At most one task posted via HelperScope::post_task per dispatch.
    static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
    // Shared with this thread's preempt-watcher thread.
    static PREEMPTION_STATE: Arc<PreemptionStateSignal> = Arc::new((Mutex::new(PreemptionState::Inactive), Condvar::new()));
}
265
/// RAII handle over a pooled `ExecContext`; returns it to the pool on drop.
struct BorrowedExecContext {
    // ManuallyDrop so Drop can move the context back into the pool.
    ctx: ManuallyDrop<ExecContext>,
}
269
270impl BorrowedExecContext {
271 fn new(stack_rng: &mut impl rand::Rng) -> Self {
272 let mut me = Self {
273 ctx: ManuallyDrop::new(
274 EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
275 ),
276 };
277 stack_rng.fill(&mut *me.ctx.copy_stack);
278 me
279 }
280}
281
impl Drop for BorrowedExecContext {
    fn drop(&mut self) {
        // SAFETY: `ctx` is taken exactly once, here; the field is never used
        // again after Drop runs.
        let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
        // Return the context to the thread-local pool for reuse.
        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
    }
}
288
/// Thread-local descriptor of the JIT code currently running, read by the
/// SIGSEGV/SIGUSR1 handlers to decide whether a signal belongs to guest code.
#[derive(Default)]
struct ActiveJitCodeZone {
    // Set (Release-fenced) just before resuming guest code, cleared after.
    // Handlers must observe `valid` before trusting the other fields.
    valid: AtomicBool,
    // [start, end) of the JIT-compiled code for the active entrypoint.
    code_range: Cell<(usize, usize)>,
    // [start, end) of the pointer cage's protected region (faults inside it
    // are guest memory faults, not host bugs).
    pointer_cage_protected_range: Cell<(usize, usize)>,
    // The coroutine yielder the handlers suspend through.
    yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
296
/// Hooks into the program run loop's lifecycle events. All methods default
/// to no-ops so implementors can override only what they need.
pub trait ProgramEventListener: Send + Sync + 'static {
    /// Called after the guest was asynchronously preempted by SIGUSR1.
    fn did_async_preempt(&self, _scope: &HelperScope) {}
    /// Called after the run loop yielded to the timeslicer.
    fn did_yield(&self) {}
    /// Called after a throttle sleep; may return an extra future to await.
    fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
        None
    }
    /// Called after the shadow stack was copied out of the cage.
    fn did_save_shadow_stack(&self) {}
    /// Called after the shadow stack was copied back into the cage.
    fn did_restore_shadow_stack(&self) {}
}
312
/// No-op listener for callers that don't care about run-loop events.
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
316
/// Builds `UnboundProgram`s from ELF bytes. Holds the randomized helper table
/// shared by every program it loads.
pub struct ProgramLoader {
    // Maps helper name -> encoded i32 id (entropy in high 16 bits, xored
    // 1-based index in low 16); consumed by the linker.
    helpers_inverse: HashMap<&'static str, i32>,
    event_listener: Arc<dyn ProgramEventListener>,
    // Random mask applied to helper indices as a hardening measure.
    helper_id_xor: u16,
    // (entropy, name, fn) triples, in shuffled order.
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}
324
/// A loaded, JIT-compiled program not yet pinned to a thread.
pub struct UnboundProgram {
    id: u64,
    // Keeps the JIT code pages mapped for the program's lifetime.
    _code_mem: MmapRaw,
    cage: PointerCage,
    helper_id_xor: u16,
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
    event_listener: Arc<dyn ProgramEventListener>,
    // Section name -> compiled code location.
    entrypoints: HashMap<String, Entrypoint>,
    // Seed for the per-program shadow-stack fill RNG.
    stack_rng_seed: [u8; 32],
}
336
/// An `UnboundProgram` pinned to the current thread (ThreadEnv is !Send/!Sync).
pub struct Program {
    unbound: UnboundProgram,
    // Type-keyed per-program storage, see `Program::data`.
    data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
    // Deterministic RNG used to randomize the shadow-stack copy buffer.
    stack_rng: RefCell<ChaCha8Rng>,
    t: ThreadEnv,
}
344
/// Location of one compiled section inside the JIT code mapping.
#[derive(Copy, Clone)]
struct Entrypoint {
    code_ptr: usize,
    code_len: usize,
}
350
/// Time budgets governing when the run loop yields or throttles a program.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
    /// Continuous run time after which the loop yields to the timeslicer.
    pub max_run_time_before_yield: Duration,
    /// Continuous run time after which the loop sleeps to throttle the program.
    pub max_run_time_before_throttle: Duration,
    /// How long each throttle sleep lasts.
    pub throttle_duration: Duration,
}
361
/// Abstraction over the host async runtime's sleep/yield primitives.
pub trait Timeslicer {
    /// Sleeps for `duration` without blocking the executor.
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
    /// Yields once to let other tasks run.
    fn yield_now(&self) -> impl Future<Output = ()>;
}
369
/// Proof token that process-global setup (signal handlers) has run.
#[derive(Copy, Clone)]
pub struct GlobalEnv(());
373
/// Proof token that per-thread setup (preempt watcher) has run on this thread.
/// Deliberately !Send/!Sync so it cannot leak to another thread.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
    _not_send_sync: std::marker::PhantomData<*const ()>,
}
379
impl GlobalEnv {
    /// Installs the process-wide SIGUSR1/SIGSEGV handlers (once).
    ///
    /// # Safety
    /// Replaces the process's SIGSEGV and SIGUSR1 dispositions; the caller
    /// must ensure nothing else in the process depends on its own handlers
    /// for these signals.
    pub unsafe fn new() -> Self {
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Block SIGUSR1/SIGSEGV while a handler runs so the two handlers
            // cannot re-enter each other mid-suspend.
            let sa_mask = get_blocked_sigset();

            for (sig, handler) in [
                (libc::SIGUSR1, sigusr1_handler as usize),
                (libc::SIGSEGV, sigsegv_handler as usize),
            ] {
                let act = libc::sigaction {
                    sa_sigaction: handler,
                    sa_flags: libc::SA_SIGINFO,
                    sa_mask,
                    sa_restorer: None,
                };
                if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
                    panic!("failed to setup handler for signal {}", sig);
                }
            }
        });

        Self(())
    }

    /// Starts this thread's preempt-watcher (once per thread) and returns the
    /// per-thread token. The watcher sends SIGUSR1 to this thread every
    /// `async_preemption_interval` while preemption is Active, and exits when
    /// the owning thread drops its WATCHER thread-local (Shutdown).
    pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
        // Dropped with the owning thread's TLS: flips the shared state to
        // Shutdown and wakes the watcher so it can exit.
        struct DeferDrop(Arc<PreemptionStateSignal>);
        impl Drop for DeferDrop {
            fn drop(&mut self) {
                let x = &self.0;
                *x.0.lock() = PreemptionState::Shutdown;
                x.1.notify_one();
            }
        }

        thread_local! {
            static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
        }

        // Idempotent per thread: a second call just returns the token.
        if WATCHER.with(|x| x.borrow().is_some()) {
            return ThreadEnv {
                _not_send_sync: PhantomData,
            };
        }

        let preemption_state = PREEMPTION_STATE.with(|x| x.clone());

        unsafe {
            // Captured now so the watcher signals *this* thread via tgkill.
            let tgid = libc::getpid();
            let tid = libc::gettid();

            std::thread::Builder::new()
                .name("preempt-watcher".to_string())
                .spawn(move || {
                    let mut state = preemption_state.0.lock();
                    loop {
                        match *state {
                            PreemptionState::Shutdown => break,
                            PreemptionState::Inactive => {
                                // Nothing to preempt; sleep until state changes.
                                preemption_state.1.wait(&mut state);
                            }
                            PreemptionState::Active(_) => {
                                // Wait out one interval; on timeout (state
                                // still Active) fire SIGUSR1 at the owner.
                                let timeout = preemption_state
                                    .1
                                    .wait_while_for(
                                        &mut state,
                                        |x| matches!(x, PreemptionState::Active(_)),
                                        async_preemption_interval,
                                    );
                                if timeout.timed_out() {
                                    let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                                    if ret != 0 {
                                        // Target thread is gone; stop watching.
                                        break;
                                    }
                                }
                            }
                        }
                    }
                })
                .expect("failed to spawn preemption watcher");

            WATCHER.with(|x| {
                x.borrow_mut()
                    .replace(DeferDrop(PREEMPTION_STATE.with(|x| x.clone())));
            });

            ThreadEnv {
                _not_send_sync: PhantomData,
            }
        }
    }
}
492
493impl UnboundProgram {
494 pub fn pin_to_current_thread(self, t: ThreadEnv) -> Program {
496 let stack_rng = RefCell::new(ChaCha8Rng::from_seed(self.stack_rng_seed));
497 Program {
498 unbound: self,
499 data: RefCell::new(HashMap::new()),
500 stack_rng,
501 t,
502 }
503 }
504}
505
506pub struct PreemptionEnabled(());
507
508impl PreemptionEnabled {
509 pub fn new(_: ThreadEnv) -> Self {
510 PREEMPTION_STATE.with(|x| {
511 let mut notify = false;
512 {
513 let mut st = x.0.lock();
514 let next = match *st {
515 PreemptionState::Inactive => {
516 notify = true;
517 PreemptionState::Active(1)
518 }
519 PreemptionState::Active(n) => PreemptionState::Active(n + 1),
520 PreemptionState::Shutdown => unreachable!(),
521 };
522 *st = next;
523 }
524
525 if notify {
526 x.1.notify_one();
527 }
528 });
529 Self(())
530 }
531}
532
533impl Drop for PreemptionEnabled {
534 fn drop(&mut self) {
535 PREEMPTION_STATE.with(|x| {
536 let mut st = x.0.lock();
537 let next = match *st {
538 PreemptionState::Active(1) => PreemptionState::Inactive,
539 PreemptionState::Active(n) => {
540 assert!(n > 1);
541 PreemptionState::Active(n - 1)
542 }
543 PreemptionState::Inactive | PreemptionState::Shutdown => unreachable!(),
544 };
545 *st = next;
546 });
547 }
548}
549
impl Program {
    /// Globally unique id assigned at load time.
    pub fn id(&self) -> u64 {
        self.unbound.id
    }

    /// The thread token this program was pinned with.
    pub fn thread_env(&self) -> ThreadEnv {
        self.t
    }

    /// Per-program typed storage: returns the shared `Rc<T>`, inserting
    /// `T::default()` on first access for that type.
    pub fn data<T: Default + 'static>(&self) -> Rc<T> {
        let mut data = self.data.borrow_mut();
        let entry = data.entry(TypeId::of::<T>());
        let entry = entry.or_insert_with(|| Rc::new(T::default()));
        entry.clone().downcast().unwrap()
    }

    /// Whether the loaded ELF had a section (entrypoint) with this name.
    pub fn has_section(&self, name: &str) -> bool {
        self.unbound.entrypoints.contains_key(name)
    }

    /// Runs `entrypoint` to completion under the given timeslice policy,
    /// dispatching helper calls against `resources`. Thin wrapper over `_run`
    /// that converts the error type.
    pub async fn run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
        preemption: &PreemptionEnabled,
    ) -> Result<i64, Error> {
        self
            ._run(
                timeslice, timeslicer, entrypoint, resources, calldata, preemption,
            )
            .await
            .map_err(Error)
    }

    /// Core run loop. Drives the JIT code on a coroutine, handling each yield
    /// (helper call, async preemption, or memory fault) on the host side, and
    /// cooperating with the async runtime for yielding/throttling.
    /// The `&PreemptionEnabled` parameter is a proof token only.
    async fn _run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
        _: &PreemptionEnabled,
    ) -> Result<i64, RuntimeError> {
        let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
            return Err(RuntimeError::InvalidArgument("entrypoint not found"));
        };

        // SAFETY(review): code_ptr points at JIT output produced by
        // ubpf_translate with this exact calling convention — confirm the ABI
        // stays in sync with the translator.
        let entry = unsafe {
            std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
                entrypoint.code_ptr,
            )
        };
        // Force-resets the coroutine on drop so an unfinished guest (e.g. an
        // early error return from this function) doesn't leave the stack in a
        // "still running" state when it is returned to the pool.
        struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
            ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
        );
        impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
            for CoDropper<'a, Input, Yield, Return, DefaultStack>
        {
            fn drop(&mut self) {
                unsafe {
                    self.0.force_reset();
                }
            }
        }

        let mut ectx = BorrowedExecContext::new(&mut *self.stack_rng.borrow_mut());

        if calldata.len() > MAX_CALLDATA_SIZE {
            return Err(RuntimeError::InvalidArgument("calldata too large"));
        }
        // Calldata lives at the very top of the shadow stack.
        ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
        let calldata_len = calldata.len();

        let program_ret: u64 = {
            let shadow_stack_top = self.unbound.cage.stack_top();
            let shadow_stack_ptr = AssumeSend(
                self
                    .unbound
                    .cage
                    .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
                    .unwrap(),
            );
            let ctx = &mut *ectx.ctx;

            let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
                &mut ctx.native_stack,
                move |yielder, _input| unsafe {
                    // Publish the yielder so signal handlers / the dispatcher
                    // can suspend the guest from inside JIT code.
                    ACTIVE_JIT_CODE_ZONE.with(|x| {
                        x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
                    });
                    // Guest starts with ctx == stack pointer just below calldata.
                    entry(
                        shadow_stack_top - calldata_len,
                        shadow_stack_top - calldata_len,
                    )
                },
            )));

            let mut last_yield_time = Instant::now();
            let mut last_throttle_time = Instant::now();
            let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
            // Value the guest resumes with (helper return / async task result).
            let mut resume_input: u64 = 0;
            let mut did_throttle = false;
            // True when the authoritative shadow stack is in copy_stack (host
            // side) rather than in the cage; starts true (calldata was staged).
            let mut shadow_stack_saved = true;
            // (thread id, SIGUSR1 count) snapshot; a change means the coroutine
            // migrated threads or a preemption signal actually fired.
            let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
            let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
            let mut invoke_scope = InvokeScope {
                data: HashMap::new(),
            };

            loop {
                // Publish the active JIT zone for the signal handlers; fields
                // are written before `valid` flips true (Release fence).
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.code_range.set((
                        entrypoint.code_ptr,
                        entrypoint.code_ptr + entrypoint.code_len,
                    ));
                    x.yielder.set(yielder.map(|x| x.0));
                    x.pointer_cage_protected_range
                        .set(self.unbound.cage.protected_range_without_margins());
                    compiler_fence(Ordering::Release);
                    x.valid.store(true, Ordering::Relaxed);
                });

                // Restore the shadow stack into the cage if it was saved out
                // (first iteration, or after a yield/throttle/async task).
                if shadow_stack_saved {
                    shadow_stack_saved = false;

                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            ctx.copy_stack.as_ptr() as *const u8,
                            shadow_stack_ptr.0.as_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }

                    self.unbound.event_listener.did_restore_shadow_stack();
                }

                // Finish a previously awaited async helper: its continuation
                // produces the guest's resume value.
                if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
                    resume_input = prev_async_task_output(&HelperScope {
                        program: self,
                        invoke: RefCell::new(&mut invoke_scope),
                        resources: RefCell::new(resources),
                        mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        can_post_task: false,
                    })
                    .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
                }

                // Run guest code until it returns or yields a Dispatch.
                let ret = co.0 .0.resume(resume_input);
                // Invalidate the zone and recapture the yielder pointer.
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.valid.store(false, Ordering::Relaxed);
                    compiler_fence(Ordering::Release);
                    yielder = x.yielder.get().map(AssumeSend);
                });

                let dispatch: Dispatch = match ret {
                    CoroutineResult::Return(x) => break x,
                    CoroutineResult::Yield(x) => x,
                };

                // Signal handlers suspended with SIGSEGV/SIGUSR1 still blocked
                // (sa_mask); unblock them now that we're back on the host side.
                if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
                    unsafe {
                        let unblock = get_blocked_sigset();
                        libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
                    }
                }

                if let Some(si_addr) = dispatch.memory_access_error {
                    // Translate the host fault address back to a guest vaddr.
                    let vaddr = si_addr - self.unbound.cage.offset();
                    return Err(RuntimeError::MemoryFault(vaddr));
                }

                // Clear any stale pending task before dispatching.
                PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
                let mut helper_name: &'static str = "";

                let mut helper_scope = HelperScope {
                    program: self,
                    invoke: RefCell::new(&mut invoke_scope),
                    resources: RefCell::new(resources),
                    mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    can_post_task: false,
                };

                if dispatch.async_preemption {
                    self
                        .unbound
                        .event_listener
                        .did_async_preempt(&mut helper_scope);
                } else {
                    // Decode the obfuscated helper index (xor mask, 1-based).
                    let Some((_, got_helper_name, helper)) = self
                        .unbound
                        .helpers
                        .get(
                            ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                                as usize,
                        )
                        .copied()
                    else {
                        // The validator rejects unknown indices at load time,
                        // so this indicates corrupted JIT state.
                        panic!("unknown helper index: {}", dispatch.index);
                    };
                    helper_name = got_helper_name;

                    // post_task is only allowed during synchronous dispatch.
                    helper_scope.can_post_task = true;
                    resume_input = helper(
                        &mut helper_scope,
                        dispatch.arg1,
                        dispatch.arg2,
                        dispatch.arg3,
                        dispatch.arg4,
                        dispatch.arg5,
                    )
                    .map_err(|()| RuntimeError::HelperError(helper_name))?;
                    helper_scope.can_post_task = false;
                }

                let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

                // Fast path: no preemption signal fired, same thread, nothing
                // async pending — resume the guest immediately without even
                // checking the clock.
                let new_rust_tid_sigusr1_counter =
                    (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
                if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
                {
                    continue;
                }

                rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

                let now = Instant::now();
                let should_throttle = now > last_throttle_time
                    && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
                let should_yield = now > last_yield_time
                    && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
                if should_throttle || should_yield || pending_async_task.is_some() {
                    // We may migrate threads across the .await points below,
                    // so copy the shadow stack out of this thread's cage first.
                    shadow_stack_saved = true;
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            shadow_stack_ptr.0.as_ptr() as *const u8,
                            ctx.copy_stack.as_mut_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }
                    self.unbound.event_listener.did_save_shadow_stack();

                    if should_throttle {
                        if !did_throttle {
                            did_throttle = true;
                            tracing::warn!("throttling program");
                        }
                        timeslicer.sleep(timeslice.throttle_duration).await;
                        let now = Instant::now();
                        last_throttle_time = now;
                        last_yield_time = now;
                        let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
                        if let Some(task) = task {
                            task.await;
                        }
                    } else if should_yield {
                        timeslicer.yield_now().await;
                        let now = Instant::now();
                        last_yield_time = now;
                        self.unbound.event_listener.did_yield();
                    }

                    if let Some(pending_async_task) = pending_async_task {
                        let async_start = Instant::now();
                        prev_async_task_output = Some((helper_name, pending_async_task.await));
                        // Don't charge time spent awaiting the helper's task
                        // against the program's yield/throttle budget.
                        let async_dur = async_start.elapsed();
                        last_throttle_time += async_dur;
                        last_yield_time += async_dur;
                    }
                }
            }
        };

        Ok(program_ret as i64)
    }
}
850
851struct Vm(NonNull<crate::ubpf::ubpf_vm>);
852
impl Vm {
    /// Creates a ubpf VM configured for cage-relative pointers: runtime bounds
    /// checks off (the cage's guard pages catch faults), JIT shadow stack on,
    /// and the cage's mask/offset applied to every guest pointer.
    fn new(cage: &PointerCage) -> Self {
        let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
        unsafe {
            crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
            crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
            crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
        }
        Self(vm)
    }
}
864
impl Drop for Vm {
    fn drop(&mut self) {
        // SAFETY: the pointer was produced by ubpf_create and is destroyed
        // exactly once, here.
        unsafe {
            crate::ubpf::ubpf_destroy(self.0.as_ptr());
        }
    }
}
872
impl ProgramLoader {
    /// Builds a loader with a randomized helper table: helper order is
    /// shuffled, each helper gets a random 15-bit entropy tag, and indices are
    /// masked with a random xor — so helper ids differ between processes.
    /// Duplicate helper names across `raw_helpers` are deduplicated (last
    /// entry wins via the intermediate HashMap).
    pub fn new(
        rng: &mut impl rand::Rng,
        event_listener: Arc<dyn ProgramEventListener>,
        raw_helpers: &[&[(&'static str, Helper)]],
    ) -> Self {
        let helper_id_xor = rng.gen::<u16>();
        let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
        let mut shuffled_helpers = raw_helpers
            .iter()
            .flat_map(|x| x.iter().copied())
            .collect::<HashMap<_, _>>()
            .into_iter()
            .collect::<Vec<_>>();
        shuffled_helpers.shuffle(rng);
        let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

        // Indices are stored 1-based in 16 bits, so at most 65535 helpers.
        assert!(shuffled_helpers.len() <= 65535);

        for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
            // 15-bit entropy tag checked by std_validator at load time.
            let entropy = rng.gen::<u16>() & 0x7fff;
            helpers.push((entropy, name, helper));
            // Encoded id: entropy in the high 16 bits, xored 1-based index in
            // the low 16. The linker patches these into guest call sites.
            helpers_inverse.insert(
                name,
                (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
            );
        }

        tracing::info!(?helpers_inverse, "generated helper table");
        Self {
            helper_id_xor,
            helpers: Arc::new(helpers),
            helpers_inverse,
            event_listener,
        }
    }

    /// Loads and JIT-compiles an ELF image; wrapper over `_load` that converts
    /// the error type.
    pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
        self._load(rng, elf).map_err(Error)
    }

    /// Copies the ELF into a fresh pointer cage, links it, JIT-compiles every
    /// code section into a guard-page-wrapped executable mapping, and returns
    /// the resulting `UnboundProgram`.
    fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
        let start_time = Instant::now();
        let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
        let vm = Vm::new(&cage);

        // Stage the ELF inside the cage's data region and link it in place.
        let code_sections = {
            let mut data = cage
                .safe_deref_for_read(cage.data_bottom(), elf.len())
                .unwrap();
            let data = unsafe { data.as_mut() };
            data.copy_from_slice(elf);

            link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
        };
        // Data region becomes read-only from here on.
        cage.freeze_data();

        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        if page_size < 0 {
            return Err(RuntimeError::PlatformError("failed to get page size"));
        }
        let page_size = page_size as usize;

        // Randomly sized PROT_NONE guard regions around the JIT code.
        let guard_size_before = rng.gen_range(16..128) * page_size;
        let mut guard_size_after = rng.gen_range(16..128) * page_size;

        let code_len_allocated: usize = 65536;
        let code_mem = MmapRaw::from(
            MmapOptions::new()
                .len(code_len_allocated + guard_size_before + guard_size_after)
                .map_anon()
                .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
        );

        unsafe {
            if crate::ubpf::ubpf_register_external_dispatcher(
                vm.0.as_ptr(),
                Some(tls_dispatcher),
                Some(std_validator),
                // Cookie passed back to std_validator; only read during _load.
                self as *const _ as *mut c_void,
            ) != 0
            {
                return Err(RuntimeError::PlatformError(
                    "ubpf: failed to register external dispatcher",
                ));
            }
            if libc::mprotect(
                code_mem.as_mut_ptr() as *mut _,
                guard_size_before,
                libc::PROT_NONE,
            ) != 0
                || libc::mprotect(
                    code_mem
                        .as_mut_ptr()
                        .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
                    guard_size_after,
                    libc::PROT_NONE,
                ) != 0
            {
                return Err(RuntimeError::PlatformError("failed to protect guard pages"));
            }
        }

        let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

        unsafe {
            // Sections are compiled back-to-back into this window; code_slice
            // shrinks to the remaining free space after each section.
            let mut code_slice = std::slice::from_raw_parts_mut(
                code_mem.as_mut_ptr().offset(guard_size_before as isize),
                code_len_allocated,
            );
            for (section_name, code_vaddr_size) in code_sections {
                if code_slice.is_empty() {
                    return Err(RuntimeError::InvalidArgument(
                        "no space left for jit compilation",
                    ));
                }

                crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

                let mut errmsg_ptr = std::ptr::null_mut();
                let code = cage
                    .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
                    .unwrap();
                let ret = crate::ubpf::ubpf_load(
                    vm.0.as_ptr(),
                    code.as_ptr() as *const _,
                    code.len() as u32,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    // errmsg is malloc'd by ubpf; copy then free it.
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to load code");
                    return Err(RuntimeError::PlatformError("ubpf: code load failed"));
                }

                // In: free space; out: bytes of machine code actually emitted.
                let mut written_len = code_slice.len();
                let ret = crate::ubpf::ubpf_translate(
                    vm.0.as_ptr(),
                    code_slice.as_mut_ptr(),
                    &mut written_len,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to translate code");
                    return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
                }

                assert!(written_len <= code_slice.len());
                entrypoints.insert(
                    section_name,
                    Entrypoint {
                        // Start of this section = end of window minus what was
                        // still free before this translation.
                        code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
                            - code_slice.len(),
                        code_len: written_len,
                    },
                );
                code_slice = &mut code_slice[written_len..];
            }

            // Round used code up to a page; the unused tail becomes guard.
            let unpadded_code_len = code_len_allocated - code_slice.len();
            let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
            assert!(code_len <= code_len_allocated);

            // Code pages become read+exec; the unused remainder PROT_NONE.
            if libc::mprotect(
                code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
                code_len,
                libc::PROT_READ | libc::PROT_EXEC,
            ) != 0
                || (code_len < code_len_allocated
                    && libc::mprotect(
                        code_mem
                            .as_mut_ptr()
                            .offset((guard_size_before + code_len) as isize) as *mut _,
                        code_len_allocated - code_len,
                        libc::PROT_NONE,
                    ) != 0)
            {
                return Err(RuntimeError::PlatformError("failed to protect code memory"));
            }

            guard_size_after += code_len_allocated - code_len;

            tracing::info!(
                elf_size = elf.len(),
                native_code_addr = ?code_mem.as_ptr(),
                native_code_size = code_len,
                native_code_size_unpadded = unpadded_code_len,
                guard_size_before,
                guard_size_after,
                duration = ?start_time.elapsed(),
                cage_ptr = ?cage.region().as_ptr(),
                cage_mapped_size = cage.region().len(),
                "jit compiled program"
            );

            Ok(UnboundProgram {
                id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
                _code_mem: code_mem,
                cage,
                helper_id_xor: self.helper_id_xor,
                helpers: self.helpers.clone(),
                event_listener: self.event_listener.clone(),
                entrypoints,
                stack_rng_seed: rng.gen(),
            })
        }
    }
}
1108
/// Payload yielded from guest to host: either a helper call (index + args),
/// an async preemption, or a memory fault — exactly one of which is set.
#[derive(Default)]
struct Dispatch {
    // Set by sigusr1_handler when the watcher preempted the guest.
    async_preemption: bool,
    // Set by sigsegv_handler: faulting host address inside the cage.
    memory_access_error: Option<usize>,

    // Obfuscated helper id as emitted into the JIT code, plus raw args.
    index: u32,
    arg1: u64,
    arg2: u64,
    arg3: u64,
    arg4: u64,
    arg5: u64,
}
1121
1122unsafe extern "C" fn tls_dispatcher(
1123 arg1: u64,
1124 arg2: u64,
1125 arg3: u64,
1126 arg4: u64,
1127 arg5: u64,
1128 index: std::os::raw::c_uint,
1129 _cookie: *mut std::os::raw::c_void,
1130) -> u64 {
1131 let yielder = ACTIVE_JIT_CODE_ZONE
1132 .with(|x| x.yielder.get())
1133 .expect("no yielder");
1134 let yielder = yielder.as_ref();
1135 let ret = yielder.suspend(Dispatch {
1136 async_preemption: false,
1137 memory_access_error: None,
1138 index,
1139 arg1,
1140 arg2,
1141 arg3,
1142 arg4,
1143 arg5,
1144 });
1145 ret
1146}
1147
1148unsafe extern "C" fn std_validator(
1149 index: std::os::raw::c_uint,
1150 loader: *mut std::os::raw::c_void,
1151) -> bool {
1152 let loader = &*(loader as *const ProgramLoader);
1153 let entropy = (index >> 16) & 0xffff;
1154 let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1155 loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1156}
1157
/// Extracts the faulting/interrupted instruction pointer from a signal ucontext.
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}
1162
/// Extracts the interrupted program counter from a signal ucontext (aarch64).
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.pc as usize
}
1167
/// SIGSEGV handler. If the fault happened inside active JIT code and targets
/// the pointer cage's protected range, it is a *guest* memory fault: suspend
/// the coroutine with a `memory_access_error` dispatch. Any other SIGSEGV is
/// a genuine host crash: reinstall the default handler so the re-raised
/// signal kills the process normally.
unsafe extern "C" fn sigsegv_handler(
    _sig: i32,
    siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    let fail = || restore_default_signal_handler(libc::SIGSEGV);

    // Acquire-fenced read paired with the Release in the run loop: only trust
    // the zone's fields if `valid` is observed true.
    let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            compiler_fence(Ordering::Acquire);
            Some((
                x.code_range.get(),
                x.pointer_cage_protected_range.get(),
                x.yielder.get(),
            ))
        } else {
            None
        }
    }) else {
        return fail();
    };

    let pc = program_counter(uctx);

    // Fault must have occurred while executing JIT code.
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return fail();
    }

    // si_code 2 is SEGV_ACCERR on Linux (permission fault on a mapped page);
    // anything else is not a cage-protection fault.
    if (*siginfo).si_code != 2 {
        return fail();
    }

    let si_addr = (*siginfo).si_addr() as usize;
    if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
        return fail();
    }

    // Suspend the guest coroutine; the run loop converts this to MemoryFault.
    // SIGSEGV stays blocked (sa_mask) until the host unblocks it.
    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        memory_access_error: Some(si_addr),
        ..Default::default()
    });
}
1212
/// SIGUSR1 handler, fired by the preempt-watcher thread. If the signal landed
/// while JIT code was executing, suspend the coroutine with an
/// `async_preemption` dispatch; otherwise just bump the counter and return
/// (host code will notice the counter change at the next dispatch).
unsafe extern "C" fn sigusr1_handler(
    _sig: i32,
    _siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

    // See sigsegv_handler: valid-flag + Acquire fence gate the zone's fields.
    let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            compiler_fence(Ordering::Acquire);
            Some((x.code_range.get(), x.yielder.get()))
        } else {
            None
        }
    }) else {
        return;
    };
    let pc = program_counter(uctx);
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return;
    }

    // Suspend the guest; SIGUSR1 stays blocked until the host unblocks it.
    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        async_preemption: true,
        ..Default::default()
    });
}
1241
1242unsafe fn restore_default_signal_handler(signum: i32) {
1243 let act = libc::sigaction {
1244 sa_sigaction: libc::SIG_DFL,
1245 sa_flags: libc::SA_SIGINFO,
1246 sa_mask: std::mem::zeroed(),
1247 sa_restorer: None,
1248 };
1249 if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
1250 libc::abort();
1251 }
1252}
1253
1254fn get_blocked_sigset() -> libc::sigset_t {
1255 unsafe {
1256 let mut s: libc::sigset_t = std::mem::zeroed();
1257 libc::sigaddset(&mut s, libc::SIGUSR1);
1258 libc::sigaddset(&mut s, libc::SIGSEGV);
1259 s
1260 }
1261}