1use std::{
5 any::{Any, TypeId},
6 cell::{Cell, RefCell},
7 collections::HashMap,
8 ffi::CStr,
9 marker::PhantomData,
10 mem::ManuallyDrop,
11 ops::{Deref, DerefMut},
12 os::raw::c_void,
13 pin::Pin,
14 ptr::NonNull,
15 rc::Rc,
16 sync::{
17 atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18 Arc, Once,
19 },
20 thread::ThreadId,
21 time::{Duration, Instant},
22};
23
24use corosensei::{
25 stack::{DefaultStack, Stack},
26 Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
27};
28use futures::{Future, FutureExt};
29use memmap2::{MmapOptions, MmapRaw};
30use parking_lot::{Condvar, Mutex};
31use rand::prelude::SliceRandom;
32
33use crate::{
34 error::{Error, RuntimeError},
35 helpers::Helper,
36 linker::link_elf,
37 pointer_cage::PointerCage,
38 util::nonnull_bytes_overlap,
39};
40
/// Size in bytes of the host-side (native) stack backing each program coroutine.
const NATIVE_STACK_SIZE: usize = 16384;
/// Size in bytes of the guest shadow stack that is saved/restored across await points.
const SHADOW_STACK_SIZE: usize = 4096;
/// Maximum calldata accepted by `Program::run`; larger inputs are rejected up front.
const MAX_CALLDATA_SIZE: usize = 512;
/// Max simultaneously outstanding mutable user-memory borrows per helper call.
const MAX_MUTABLE_DEREF_REGIONS: usize = 4;
/// Max simultaneously outstanding immutable user-memory borrows per helper call.
const MAX_IMMUTABLE_DEREF_REGIONS: usize = 16;
46
/// Per-invocation scratch storage shared by all helper calls of a single
/// `Program::_run` invocation, keyed by type.
pub struct InvokeScope {
    // Type-erased slots, lazily created by `data_mut`.
    data: HashMap<TypeId, Box<dyn Any + Send>>,
}
51
52impl InvokeScope {
53 pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
55 let ty = TypeId::of::<T>();
56 self
57 .data
58 .entry(ty)
59 .or_insert_with(|| Box::new(T::default()))
60 .downcast_mut()
61 .expect("InvokeScope::data_mut: downcast failed")
62 }
63}
64
/// Context handed to helper functions and event listeners while the program
/// is suspended; mediates access to guest (user) memory and host resources.
pub struct HelperScope<'a, 'b> {
    /// The program currently being executed.
    pub program: &'a Program,
    /// Per-invocation scratch storage (see `InvokeScope`).
    pub invoke: RefCell<&'a mut InvokeScope>,
    // Host resources made available to helpers for this run.
    resources: RefCell<&'a mut [&'b mut dyn Any]>,
    // Guest regions currently borrowed mutably; used to reject overlapping
    // derefs within one helper call.
    mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_MUTABLE_DEREF_REGIONS],
    // Guest regions currently borrowed immutably.
    immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_IMMUTABLE_DEREF_REGIONS],
    // Whether `post_task` is allowed (true only while a synchronous helper runs).
    can_post_task: bool,
}
76
/// Mutable view into guest (user) memory; borrows the `HelperScope` that
/// produced it so the scope's overlap bookkeeping outlives the view.
pub struct MutableUserMemory<'a, 'b, 'c> {
    // Held only to tie this view's lifetime to the scope.
    _scope: &'c HelperScope<'a, 'b>,
    // The cage-validated region.
    region: NonNull<[u8]>,
}
82
impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `region` was validated by `HelperScope::user_memory_mut`
        // and registered in the scope's overlap tracking, which rejects any
        // aliasing borrow while this view is alive.
        unsafe { self.region.as_ref() }
    }
}
90
impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same argument as `Deref` — exclusive access is guaranteed
        // by the scope's overlap rejection at creation time.
        unsafe { self.region.as_mut() }
    }
}
96
impl<'a, 'b> HelperScope<'a, 'b> {
    /// Schedules an async task to run while the program is suspended.
    ///
    /// When the future completes, its output closure is invoked with a fresh
    /// `HelperScope`, and the `u64` it returns becomes the helper's return
    /// value inside the program.
    ///
    /// # Panics
    /// Panics if called outside a synchronous helper (`can_post_task` is
    /// false) or if a task is already pending for this helper call.
    pub fn post_task(
        &self,
        task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
    ) {
        if !self.can_post_task {
            panic!("HelperScope::post_task() called in a context where posting task is not allowed");
        }

        PENDING_ASYNC_TASK.with(|x| {
            let mut x = x.borrow_mut();
            if x.is_some() {
                panic!("post_task called while another task is pending");
            }
            // Type-erase the completion closure so the host loop can store it.
            *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
        });
    }

    /// Finds the first host resource of type `T` and passes it (or `Err(())`
    /// if absent) to `callback`, returning the callback's result.
    pub fn with_resource_mut<'c, T: 'static, R>(
        &'c self,
        callback: impl FnOnce(Result<&mut T, ()>) -> R,
    ) -> R {
        let mut resources = self.resources.borrow_mut();
        let Some(res) = resources
            .iter_mut()
            .filter_map(|x| x.downcast_mut::<T>())
            .next()
        else {
            tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
            return callback(Err(()));
        };

        callback(Ok(res))
    }

    /// Borrows `size` bytes of guest memory at guest address `ptr` for reading.
    ///
    /// Fails when the range is outside the cage, overlaps a region already
    /// borrowed mutably in this helper call, or the per-call immutable budget
    /// (`MAX_IMMUTABLE_DEREF_REGIONS`) is exhausted. Zero-sized reads skip
    /// the bookkeeping entirely.
    pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_read(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid read");
            return Err(());
        };

        if size != 0 {
            // Reject aliasing with any outstanding mutable borrow.
            if self
                .mutable_dereferenced_regions
                .iter()
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "read overlapped with previous write");
                return Err(());
            }

            // Record this read in the first free immutable slot.
            let Some(slot) = self
                .immutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many reads");
                return Err(());
            };
            slot.set(Some(region));
        }

        // SAFETY: the cage validated the range, and the bookkeeping above
        // guarantees no overlapping mutable view exists in this scope.
        Ok(unsafe { region.as_ref() })
    }

    /// Borrows `size` bytes of guest memory at guest address `ptr` for
    /// writing, returning a guard that derefs to `&mut [u8]`.
    ///
    /// Fails when the range is outside the cage, overlaps any region already
    /// borrowed (read or write) in this helper call, or the per-call mutable
    /// budget (`MAX_MUTABLE_DEREF_REGIONS`) is exhausted.
    pub fn user_memory_mut<'c>(
        &'c self,
        ptr: u64,
        size: u64,
    ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_write(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid write");
            return Err(());
        };

        if size != 0 {
            // A mutable borrow may not alias any prior borrow of either kind.
            if self
                .mutable_dereferenced_regions
                .iter()
                .chain(self.immutable_dereferenced_regions.iter())
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "write overlapped with previous read/write");
                return Err(());
            }

            // Record this write in the first free mutable slot.
            let Some(slot) = self
                .mutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many writes");
                return Err(());
            };
            slot.set(Some(region));
        }

        Ok(MutableUserMemory {
            _scope: self,
            region,
        })
    }
}
220
/// Wrapper that unconditionally asserts `Send` for its contents.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
// SAFETY(review): this is an *unchecked* assertion. Soundness relies on the
// wrapped values (coroutine handle, cage pointer) only ever being used from
// the thread that created them — confirm at each use site in `Program::_run`.
unsafe impl<T> Send for AssumeSend<T> {}
224
/// Reusable per-run execution resources (pooled in `EXEC_CONTEXT_POOL`).
struct ExecContext {
    // Native stack the guest coroutine runs on.
    native_stack: DefaultStack,
    // Host-side buffer holding the guest shadow stack while suspended.
    copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
229
230impl ExecContext {
231 fn new() -> Self {
232 Self {
233 native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
234 .expect("failed to initialize native stack"),
235 copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
236 }
237 }
238}
239
/// Locally-boxed future produced by `HelperScope::post_task`; awaited by the
/// host loop while the program is suspended.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// Completion callback of an async task; runs with a fresh `HelperScope` and
/// yields the value the program is resumed with.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

/// Monotonic source of `UnboundProgram::id` values.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);
246
/// State machine driven by `PreemptionEnabled` guards and consumed by the
/// per-thread preemption watcher spawned in `GlobalEnv::init_thread`.
#[derive(Copy, Clone, Debug)]
enum PreemptionState {
    // No guard alive; the watcher sleeps until notified.
    Inactive,
    // Count of live `PreemptionEnabled` guards; the watcher sends SIGUSR1
    // each time its wait times out while in this state.
    Active(usize),
    // Owning thread is exiting; the watcher terminates.
    Shutdown,
}

/// Shared state plus the condvar used to wake the watcher thread.
type PreemptionStateSignal = (Mutex<PreemptionState>, Condvar);
255
thread_local! {
    // Cached Rust thread id; compared across awaits to detect task migration.
    static RUST_TID: ThreadId = std::thread::current().id();
    // Count of SIGUSR1 deliveries on this thread; a change signals that the
    // preemption interval elapsed.
    static SIGUSR1_COUNTER: Cell<u64> = Cell::new(0);
    // Snapshot read by the signal handlers (see `ActiveJitCodeZone`).
    static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
    // Pool of reusable `ExecContext`s (see `BorrowedExecContext`).
    static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
    // Task posted by the currently running helper, if any.
    static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
    // Shared with this thread's preemption watcher.
    static PREEMPTION_STATE: Arc<PreemptionStateSignal> = Arc::new((Mutex::new(PreemptionState::Inactive), Condvar::new()));
}
264
/// RAII borrow of an `ExecContext` from the thread-local pool; returned to
/// the pool on drop.
struct BorrowedExecContext {
    // ManuallyDrop so `drop` can move the context back into the pool.
    ctx: ManuallyDrop<ExecContext>,
}
268
269impl BorrowedExecContext {
270 fn new() -> Self {
271 let mut me = Self {
272 ctx: ManuallyDrop::new(
273 EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
274 ),
275 };
276 me.ctx.copy_stack.fill(0x8e);
277 me
278 }
279}
280
impl Drop for BorrowedExecContext {
    fn drop(&mut self) {
        // SAFETY: `ctx` is taken exactly once, here, and `self` is being
        // dropped so it is never accessed again.
        let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
        // Return the context to the thread-local pool for reuse.
        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
    }
}
287
/// Thread-local snapshot of the currently executing JIT code region, read by
/// the SIGSEGV/SIGUSR1 handlers to decide whether a fault or preemption
/// belongs to guest code.
#[derive(Default)]
struct ActiveJitCodeZone {
    // Gate: handlers read the other fields only when this is true. Paired
    // with compiler fences at the publish/unpublish sites in `Program::_run`.
    valid: AtomicBool,
    // [start, end) of the JIT-compiled native code.
    code_range: Cell<(usize, usize)>,
    // [start, end) of the pointer cage's protected host address range.
    pointer_cage_protected_range: Cell<(usize, usize)>,
    // Yielder the handlers use to suspend the guest coroutine.
    yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
295
/// Hooks into the program execution lifecycle. Every method has a no-op
/// default implementation.
pub trait ProgramEventListener: Send + Sync + 'static {
    /// Called after the program was asynchronously preempted via SIGUSR1.
    fn did_async_preempt(&self, _scope: &HelperScope) {}
    /// Called after the program's task yielded to the executor.
    fn did_yield(&self) {}
    /// Called when the program is throttled; may return an extra future that
    /// is awaited before resuming.
    fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
        None
    }
    /// Called after the guest shadow stack was copied out to the host buffer.
    fn did_save_shadow_stack(&self) {}
    /// Called after the guest shadow stack was copied back into the cage.
    fn did_restore_shadow_stack(&self) {}
}
311
/// `ProgramEventListener` that ignores every event (all defaults).
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
315
/// Builds `UnboundProgram`s; owns the randomized helper table shared by all
/// programs it loads.
pub struct ProgramLoader {
    // Helper name -> obfuscated 32-bit call index handed to the linker.
    helpers_inverse: HashMap<&'static str, i32>,
    event_listener: Arc<dyn ProgramEventListener>,
    // Random value XORed into the low 16 bits of every helper index.
    helper_id_xor: u16,
    // (entropy tag, name, function) per helper, in randomized order.
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}
323
/// A loaded, linked, and JIT-compiled program not yet pinned to a thread.
pub struct UnboundProgram {
    // Unique id, see NEXT_PROGRAM_ID.
    id: u64,
    // Keeps the JIT code mapping alive; `entrypoints` point into it.
    _code_mem: MmapRaw,
    // Sandboxed guest address space (data + shadow stack).
    cage: PointerCage,
    helper_id_xor: u16,
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
    event_listener: Arc<dyn ProgramEventListener>,
    // Section name -> compiled code location.
    entrypoints: HashMap<String, Entrypoint>,
}
334
/// An `UnboundProgram` pinned to one thread and ready to run.
pub struct Program {
    unbound: UnboundProgram,
    // Program-lifetime typed storage, see `Program::data`.
    data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
    // Proof the thread environment was initialized; its `*const ()` marker
    // also keeps `Program` !Send/!Sync.
    t: ThreadEnv,
}
341
/// Location of one JIT-compiled section inside the code mapping.
#[derive(Copy, Clone)]
struct Entrypoint {
    // Host address of the compiled entry function.
    code_ptr: usize,
    // Length in bytes of the compiled code.
    code_len: usize,
}
347
/// Cooperative-scheduling budget applied by `Program::run`.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
    /// Run time after which the task yields to the executor.
    pub max_run_time_before_yield: Duration,
    /// Run time after which the program is put to sleep ("throttled").
    pub max_run_time_before_throttle: Duration,
    /// How long a throttled program sleeps per throttle.
    pub throttle_duration: Duration,
}
358
/// Abstraction over the async runtime's sleep/yield primitives, so the
/// runner is executor-agnostic.
pub trait Timeslicer {
    /// Sleeps for `duration`.
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
    /// Yields once to the executor.
    fn yield_now(&self) -> impl Future<Output = ()>;
}
366
/// Proof token that the process-wide signal handlers are installed
/// (obtained from `GlobalEnv::new`).
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Proof token that the current thread's preemption watcher is running;
/// the raw-pointer marker keeps it deliberately !Send and !Sync.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
    _not_send_sync: std::marker::PhantomData<*const ()>,
}
376
impl GlobalEnv {
    /// Installs the process-wide SIGUSR1/SIGSEGV handlers (exactly once).
    ///
    /// # Safety
    /// Replaces the process's SIGSEGV and SIGUSR1 dispositions; the caller
    /// must ensure nothing else in the process relies on the previous
    /// handlers for those signals.
    pub unsafe fn new() -> Self {
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Block both of our signals while either handler runs.
            let sa_mask = get_blocked_sigset();

            for (sig, handler) in [
                (libc::SIGUSR1, sigusr1_handler as usize),
                (libc::SIGSEGV, sigsegv_handler as usize),
            ] {
                let act = libc::sigaction {
                    sa_sigaction: handler,
                    sa_flags: libc::SA_SIGINFO,
                    sa_mask,
                    sa_restorer: None,
                };
                if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
                    panic!("failed to setup handler for signal {}", sig);
                }
            }
        });

        Self(())
    }

    /// Starts (once per thread) a watcher thread that delivers SIGUSR1 to
    /// the current thread whenever preemption stays active for longer than
    /// `async_preemption_interval`.
    pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
        // Flips the shared state to Shutdown when the owning thread exits,
        // so the watcher thread terminates.
        struct DeferDrop(Arc<PreemptionStateSignal>);
        impl Drop for DeferDrop {
            fn drop(&mut self) {
                let x = &self.0;
                *x.0.lock() = PreemptionState::Shutdown;
                x.1.notify_one();
            }
        }

        thread_local! {
            static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
        }

        // Already initialized on this thread: nothing to do.
        if WATCHER.with(|x| x.borrow().is_some()) {
            return ThreadEnv {
                _not_send_sync: PhantomData,
            };
        }

        let preemption_state = PREEMPTION_STATE.with(|x| x.clone());

        unsafe {
            let tgid = libc::getpid();
            let tid = libc::gettid();

            std::thread::Builder::new()
                .name("preempt-watcher".to_string())
                .spawn(move || {
                    let mut state = preemption_state.0.lock();
                    loop {
                        match *state {
                            PreemptionState::Shutdown => break,
                            PreemptionState::Inactive => {
                                // Sleep until a guard activates preemption.
                                preemption_state.1.wait(&mut state);
                            }
                            PreemptionState::Active(_) => {
                                // Wait for deactivation; on timeout, poke the
                                // watched thread with SIGUSR1 to force a yield.
                                let timeout = preemption_state.1.wait_while_for(
                                    &mut state,
                                    |x| matches!(x, PreemptionState::Active(_)),
                                    async_preemption_interval,
                                );
                                if timeout.timed_out() {
                                    let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                                    if ret != 0 {
                                        // Target thread is gone; stop watching.
                                        break;
                                    }
                                }
                            }
                        }
                    }
                })
                .expect("failed to spawn preemption watcher");

            WATCHER.with(|x| {
                x.borrow_mut()
                    .replace(DeferDrop(PREEMPTION_STATE.with(|x| x.clone())));
            });

            ThreadEnv {
                _not_send_sync: PhantomData,
            }
        }
    }
}
487
488impl UnboundProgram {
489 pub fn pin_to_current_thread(self, t: ThreadEnv) -> Program {
491 Program {
492 unbound: self,
493 data: RefCell::new(HashMap::new()),
494 t,
495 }
496 }
497}
498
/// RAII guard: while at least one instance is alive on a thread, the watcher
/// thread periodically sends that thread SIGUSR1 (see `GlobalEnv::init_thread`).
pub struct PreemptionEnabled(());
500
501impl PreemptionEnabled {
502 pub fn new(_: ThreadEnv) -> Self {
503 PREEMPTION_STATE.with(|x| {
504 let mut notify = false;
505 {
506 let mut st = x.0.lock();
507 let next = match *st {
508 PreemptionState::Inactive => {
509 notify = true;
510 PreemptionState::Active(1)
511 }
512 PreemptionState::Active(n) => PreemptionState::Active(n + 1),
513 PreemptionState::Shutdown => unreachable!(),
514 };
515 *st = next;
516 }
517
518 if notify {
519 x.1.notify_one();
520 }
521 });
522 Self(())
523 }
524}
525
impl Drop for PreemptionEnabled {
    fn drop(&mut self) {
        // Decrement the refcount; the last guard deactivates preemption.
        PREEMPTION_STATE.with(|x| {
            let mut st = x.0.lock();
            let next = match *st {
                // NOTE(review): no condvar notify on this transition — the
                // watcher only re-checks on its wait timeout and may fire one
                // spurious SIGUSR1 after deactivation; presumably harmless
                // (the handler bails when no JIT zone is active) — confirm
                // this is intentional.
                PreemptionState::Active(1) => PreemptionState::Inactive,
                PreemptionState::Active(n) => {
                    assert!(n > 1);
                    PreemptionState::Active(n - 1)
                }
                // Dropping a guard in these states indicates a refcount bug.
                PreemptionState::Inactive | PreemptionState::Shutdown => unreachable!(),
            };
            *st = next;
        });
    }
}
542
impl Program {
    /// Globally unique id of this program instance.
    pub fn id(&self) -> u64 {
        self.unbound.id
    }

    /// The thread environment this program is pinned to.
    pub fn thread_env(&self) -> ThreadEnv {
        self.t
    }

    /// Returns the program-lifetime `T` slot, default-constructing it on
    /// first access.
    pub fn data<T: Default + 'static>(&self) -> Rc<T> {
        let mut data = self.data.borrow_mut();
        let entry = data.entry(TypeId::of::<T>());
        let entry = entry.or_insert_with(|| Rc::new(T::default()));
        entry.clone().downcast().unwrap()
    }

    /// Whether the loaded ELF contained a section (entrypoint) named `name`.
    pub fn has_section(&self, name: &str) -> bool {
        self.unbound.entrypoints.contains_key(name)
    }

    /// Runs `entrypoint` to completion under the given timeslice budget,
    /// making `resources` available to helpers and passing `calldata` on the
    /// guest shadow stack. Thin public wrapper around `_run`.
    pub async fn run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
        preemption: &PreemptionEnabled,
    ) -> Result<i64, Error> {
        self._run(
            timeslice, timeslicer, entrypoint, resources, calldata, preemption,
        )
        .await
        .map_err(Error)
    }

    /// Core run loop: drives the guest coroutine, services helper calls,
    /// faults, preemption, async tasks, and timeslicing until the program
    /// returns.
    async fn _run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
        _: &PreemptionEnabled,
    ) -> Result<i64, RuntimeError> {
        let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
            return Err(RuntimeError::InvalidArgument("entrypoint not found"));
        };

        // SAFETY(review): reinterprets the JIT code address as a function
        // with the ubpf JIT calling convention (ctx, shadow stack pointer).
        let entry = unsafe {
            std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
                entrypoint.code_ptr,
            )
        };
        // Ensures the coroutine's stack is reclaimed even when `_run` exits
        // early (error/drop) while the guest is still suspended.
        struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
            ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
        );
        impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
            for CoDropper<'a, Input, Yield, Return, DefaultStack>
        {
            fn drop(&mut self) {
                // SAFETY(review): force_reset abandons the suspended guest
                // without unwinding its (JIT) frames — confirm no host
                // destructors can live on that stack.
                unsafe {
                    self.0.force_reset();
                }
            }
        }

        let mut ectx = BorrowedExecContext::new();

        if calldata.len() > MAX_CALLDATA_SIZE {
            return Err(RuntimeError::InvalidArgument("calldata too large"));
        }
        // Calldata occupies the top of the (host copy of the) shadow stack.
        ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
        let calldata_len = calldata.len();

        let program_ret: u64 = {
            let shadow_stack_top = self.unbound.cage.stack_top();
            let shadow_stack_ptr = AssumeSend(
                self.unbound
                    .cage
                    .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
                    .unwrap(),
            );
            let ctx = &mut *ectx.ctx;

            // The coroutine body installs its yielder for the signal
            // handlers, then jumps into JIT code.
            let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
                &mut ctx.native_stack,
                move |yielder, _input| unsafe {
                    ACTIVE_JIT_CODE_ZONE.with(|x| {
                        x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
                    });
                    entry(
                        shadow_stack_top - calldata_len,
                        shadow_stack_top - calldata_len,
                    )
                },
            )));

            let mut last_yield_time = Instant::now();
            let mut last_throttle_time = Instant::now();
            let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
            let mut resume_input: u64 = 0;
            let mut did_throttle = false;
            // True whenever the guest shadow stack lives in `copy_stack`
            // rather than in the cage (initially, and around awaits).
            let mut shadow_stack_saved = true;
            let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
            let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
            let mut invoke_scope = InvokeScope {
                data: HashMap::new(),
            };

            loop {
                // Publish the JIT zone so the signal handlers can attribute
                // faults/preemptions to guest code. The release fence orders
                // the field writes before the `valid` flag.
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.code_range.set((
                        entrypoint.code_ptr,
                        entrypoint.code_ptr + entrypoint.code_len,
                    ));
                    x.yielder.set(yielder.map(|x| x.0));
                    x.pointer_cage_protected_range
                        .set(self.unbound.cage.protected_range_without_margins());
                    compiler_fence(Ordering::Release);
                    x.valid.store(true, Ordering::Relaxed);
                });

                // Restore the guest shadow stack from the host-side copy if
                // it was saved before an await point.
                if shadow_stack_saved {
                    shadow_stack_saved = false;

                    // SAFETY(review): both pointers cover SHADOW_STACK_SIZE
                    // bytes; the cage region and our Box cannot overlap.
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            ctx.copy_stack.as_ptr() as *const u8,
                            shadow_stack_ptr.0.as_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }

                    self.unbound.event_listener.did_restore_shadow_stack();
                }

                // Deliver the completion of the previous async task, if any;
                // its value resumes the program as the helper's return.
                if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
                    resume_input = prev_async_task_output(&HelperScope {
                        program: self,
                        invoke: RefCell::new(&mut invoke_scope),
                        resources: RefCell::new(resources),
                        // Zeroed Option<NonNull> arrays are all-None (null niche).
                        mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        can_post_task: false,
                    })
                    .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
                }

                // Run guest code until it returns or suspends; then
                // invalidate the zone and capture the current yielder.
                let ret = co.0 .0.resume(resume_input);
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.valid.store(false, Ordering::Relaxed);
                    compiler_fence(Ordering::Release);
                    yielder = x.yielder.get().map(AssumeSend);
                });

                let dispatch: Dispatch = match ret {
                    CoroutineResult::Return(x) => break x,
                    CoroutineResult::Yield(x) => x,
                };

                // Suspending from inside a signal handler left our signals
                // blocked (sa_mask); unblock them before anything else.
                if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
                    unsafe {
                        let unblock = get_blocked_sigset();
                        libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
                    }
                }

                if let Some(si_addr) = dispatch.memory_access_error {
                    // Translate the host fault address back into the guest
                    // address space before reporting.
                    let vaddr = si_addr - self.unbound.cage.offset();
                    return Err(RuntimeError::MemoryFault(vaddr));
                }

                // Discard any stale pending task before running the helper.
                PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
                let mut helper_name: &'static str = "";

                let mut helper_scope = HelperScope {
                    program: self,
                    invoke: RefCell::new(&mut invoke_scope),
                    resources: RefCell::new(resources),
                    mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    can_post_task: false,
                };

                if dispatch.async_preemption {
                    self.unbound
                        .event_listener
                        .did_async_preempt(&mut helper_scope);
                } else {
                    // Decode the obfuscated helper index; the encoding is
                    // produced in `ProgramLoader::new`.
                    let Some((_, got_helper_name, helper)) = self
                        .unbound
                        .helpers
                        .get(
                            ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                                as usize,
                        )
                        .copied()
                    else {
                        panic!("unknown helper index: {}", dispatch.index);
                    };
                    helper_name = got_helper_name;

                    // Posting tasks is only legal while the helper runs.
                    helper_scope.can_post_task = true;
                    resume_input = helper(
                        &mut helper_scope,
                        dispatch.arg1,
                        dispatch.arg2,
                        dispatch.arg3,
                        dispatch.arg4,
                        dispatch.arg5,
                    )
                    .map_err(|()| RuntimeError::HelperError(helper_name))?;
                    helper_scope.can_post_task = false;
                }

                let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

                // Fast path: no SIGUSR1 arrived, we are on the same thread,
                // and no task was posted — resume immediately.
                let new_rust_tid_sigusr1_counter =
                    (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
                if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
                {
                    continue;
                }

                rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

                let now = Instant::now();
                let should_throttle = now > last_throttle_time
                    && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
                let should_yield = now > last_yield_time
                    && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
                if should_throttle || should_yield || pending_async_task.is_some() {
                    // About to await: copy the guest shadow stack out of the
                    // cage so another program may run there meanwhile.
                    shadow_stack_saved = true;
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            shadow_stack_ptr.0.as_ptr() as *const u8,
                            ctx.copy_stack.as_mut_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }
                    self.unbound.event_listener.did_save_shadow_stack();

                    if should_throttle {
                        if !did_throttle {
                            did_throttle = true;
                            tracing::warn!("throttling program");
                        }
                        timeslicer.sleep(timeslice.throttle_duration).await;
                        let now = Instant::now();
                        last_throttle_time = now;
                        last_yield_time = now;
                        let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
                        if let Some(task) = task {
                            task.await;
                        }
                    } else if should_yield {
                        timeslicer.yield_now().await;
                        let now = Instant::now();
                        last_yield_time = now;
                        self.unbound.event_listener.did_yield();
                    }

                    if let Some(pending_async_task) = pending_async_task {
                        // Time spent in the task does not count against the
                        // program's own timeslice budget.
                        let async_start = Instant::now();
                        prev_async_task_output = Some((helper_name, pending_async_task.await));
                        let async_dur = async_start.elapsed();
                        last_throttle_time += async_dur;
                        last_yield_time += async_dur;
                    }
                }
            }
        };

        Ok(program_ret as i64)
    }
}
843
/// Owning handle to a `ubpf_vm`; destroyed on drop.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);
845
impl Vm {
    /// Creates a ubpf VM configured for this pointer cage: built-in bounds
    /// checks off (the cage's pointer mask/offset sandbox replaces them) and
    /// the JIT shadow stack on.
    fn new(cage: &PointerCage) -> Self {
        let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
        unsafe {
            crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
            crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
            crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
        }
        Self(vm)
    }
}
857
impl Drop for Vm {
    fn drop(&mut self) {
        // SAFETY: the pointer came from ubpf_create and is destroyed once.
        unsafe {
            crate::ubpf::ubpf_destroy(self.0.as_ptr());
        }
    }
}
865
impl ProgramLoader {
    /// Builds a loader with a randomized, obfuscated helper table.
    ///
    /// Each helper gets a random 15-bit entropy tag and a table slot; the
    /// 32-bit index handed to the linker packs `entropy << 16 | ((slot + 1)
    /// ^ helper_id_xor)`, which `std_validator` and the run loop decode.
    pub fn new(
        rng: &mut impl rand::Rng,
        event_listener: Arc<dyn ProgramEventListener>,
        raw_helpers: &[&[(&'static str, Helper)]],
    ) -> Self {
        let helper_id_xor = rng.gen::<u16>();
        let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
        // Dedup by name via HashMap, then shuffle the slot order.
        let mut shuffled_helpers = raw_helpers
            .iter()
            .flat_map(|x| x.iter().copied())
            .collect::<HashMap<_, _>>()
            .into_iter()
            .collect::<Vec<_>>();
        shuffled_helpers.shuffle(rng);
        let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

        // Slot indices must fit in the low 16 bits of the encoded index.
        assert!(shuffled_helpers.len() <= 65535);

        for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
            // 15-bit tag keeps the full encoded value positive as i32.
            let entropy = rng.gen::<u16>() & 0x7fff;
            helpers.push((entropy, name, helper));
            helpers_inverse.insert(
                name,
                (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
            );
        }

        tracing::info!(?helpers_inverse, "generated helper table");
        Self {
            helper_id_xor,
            helpers: Arc::new(helpers),
            helpers_inverse,
            event_listener,
        }
    }

    /// Links and JIT-compiles `elf` into a new `UnboundProgram`.
    pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
        self._load(rng, elf).map_err(Error)
    }

    /// Implementation of `load`: builds the pointer cage, links the ELF,
    /// JIT-compiles every section into guarded executable memory.
    fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
        let start_time = Instant::now();
        let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
        let vm = Vm::new(&cage);

        // Copy the ELF into the cage's data area and link it (resolving
        // helper names via `helpers_inverse`); then freeze the data.
        let code_sections = {
            let mut data = cage
                .safe_deref_for_read(cage.data_bottom(), elf.len())
                .unwrap();
            // NOTE(review): writes through a pointer obtained via
            // safe_deref_for_read — presumably permitted before
            // `freeze_data()`; confirm against the cage's contract.
            let data = unsafe { data.as_mut() };
            data.copy_from_slice(elf);

            link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
        };
        cage.freeze_data();

        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        if page_size < 0 {
            return Err(RuntimeError::PlatformError("failed to get page size"));
        }
        let page_size = page_size as usize;

        // Randomly sized PROT_NONE guard areas around the code region.
        let guard_size_before = rng.gen_range(16..128) * page_size;
        let mut guard_size_after = rng.gen_range(16..128) * page_size;

        let code_len_allocated: usize = 65536;
        let code_mem = MmapRaw::from(
            MmapOptions::new()
                .len(code_len_allocated + guard_size_before + guard_size_after)
                .map_anon()
                .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
        );

        unsafe {
            // Route helper calls through our dispatcher/validator; `self`
            // is the validator's cookie.
            if crate::ubpf::ubpf_register_external_dispatcher(
                vm.0.as_ptr(),
                Some(tls_dispatcher),
                Some(std_validator),
                self as *const _ as *mut c_void,
            ) != 0
            {
                return Err(RuntimeError::PlatformError(
                    "ubpf: failed to register external dispatcher",
                ));
            }
            // Arm the guard pages on both sides of the code area.
            if libc::mprotect(
                code_mem.as_mut_ptr() as *mut _,
                guard_size_before,
                libc::PROT_NONE,
            ) != 0
                || libc::mprotect(
                    code_mem
                        .as_mut_ptr()
                        .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
                    guard_size_after,
                    libc::PROT_NONE,
                ) != 0
            {
                return Err(RuntimeError::PlatformError("failed to protect guard pages"));
            }
        }

        let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

        unsafe {
            // JIT-compile each linked section, packing the native code back
            // to back into the code area.
            let mut code_slice = std::slice::from_raw_parts_mut(
                code_mem.as_mut_ptr().offset(guard_size_before as isize),
                code_len_allocated,
            );
            for (section_name, code_vaddr_size) in code_sections {
                if code_slice.is_empty() {
                    return Err(RuntimeError::InvalidArgument(
                        "no space left for jit compilation",
                    ));
                }

                crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

                let mut errmsg_ptr = std::ptr::null_mut();
                let code = cage
                    .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
                    .unwrap();
                let ret = crate::ubpf::ubpf_load(
                    vm.0.as_ptr(),
                    code.as_ptr() as *const _,
                    code.len() as u32,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    // ubpf allocates the error message with malloc.
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to load code");
                    return Err(RuntimeError::PlatformError("ubpf: code load failed"));
                }

                // Translate (JIT) directly into the remaining code area.
                let mut written_len = code_slice.len();
                let ret = crate::ubpf::ubpf_translate(
                    vm.0.as_ptr(),
                    code_slice.as_mut_ptr(),
                    &mut written_len,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to translate code");
                    return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
                }

                assert!(written_len <= code_slice.len());
                entrypoints.insert(
                    section_name,
                    Entrypoint {
                        // Start of this section = end of area minus what is
                        // still unwritten.
                        code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
                            - code_slice.len(),
                        code_len: written_len,
                    },
                );
                code_slice = &mut code_slice[written_len..];
            }

            // Round the used size up to a page, make it R+X, and keep the
            // unused tail inaccessible (it joins the trailing guard).
            let unpadded_code_len = code_len_allocated - code_slice.len();
            let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
            assert!(code_len <= code_len_allocated);

            if libc::mprotect(
                code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
                code_len,
                libc::PROT_READ | libc::PROT_EXEC,
            ) != 0
                || (code_len < code_len_allocated
                    && libc::mprotect(
                        code_mem
                            .as_mut_ptr()
                            .offset((guard_size_before + code_len) as isize) as *mut _,
                        code_len_allocated - code_len,
                        libc::PROT_NONE,
                    ) != 0)
            {
                return Err(RuntimeError::PlatformError("failed to protect code memory"));
            }

            guard_size_after += code_len_allocated - code_len;

            tracing::info!(
                elf_size = elf.len(),
                native_code_addr = ?code_mem.as_ptr(),
                native_code_size = code_len,
                native_code_size_unpadded = unpadded_code_len,
                guard_size_before,
                guard_size_after,
                duration = ?start_time.elapsed(),
                cage_ptr = ?cage.region().as_ptr(),
                cage_mapped_size = cage.region().len(),
                "jit compiled program"
            );

            Ok(UnboundProgram {
                id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
                _code_mem: code_mem,
                cage,
                helper_id_xor: self.helper_id_xor,
                helpers: self.helpers.clone(),
                event_listener: self.event_listener.clone(),
                entrypoints,
            })
        }
    }
}
1100
/// Message carried from guest context (dispatcher or a signal handler) to
/// the host loop when the coroutine suspends.
#[derive(Default)]
struct Dispatch {
    // True when suspended by the SIGUSR1 preemption handler.
    async_preemption: bool,
    // Host fault address when suspended by the SIGSEGV handler.
    memory_access_error: Option<usize>,

    // Obfuscated helper index plus its five arguments (normal helper calls).
    index: u32,
    arg1: u64,
    arg2: u64,
    arg3: u64,
    arg4: u64,
    arg5: u64,
}
1113
1114unsafe extern "C" fn tls_dispatcher(
1115 arg1: u64,
1116 arg2: u64,
1117 arg3: u64,
1118 arg4: u64,
1119 arg5: u64,
1120 index: std::os::raw::c_uint,
1121 _cookie: *mut std::os::raw::c_void,
1122) -> u64 {
1123 let yielder = ACTIVE_JIT_CODE_ZONE
1124 .with(|x| x.yielder.get())
1125 .expect("no yielder");
1126 let yielder = yielder.as_ref();
1127 let ret = yielder.suspend(Dispatch {
1128 async_preemption: false,
1129 memory_access_error: None,
1130 index,
1131 arg1,
1132 arg2,
1133 arg3,
1134 arg4,
1135 arg5,
1136 });
1137 ret
1138}
1139
/// ubpf call-validator: accepts a helper call index only if it decodes to a
/// known helper. Encoding (see `ProgramLoader::new`): high 16 bits are the
/// per-helper entropy tag; low 16 bits are `(slot + 1) XOR helper_id_xor`.
unsafe extern "C" fn std_validator(
    index: std::os::raw::c_uint,
    loader: *mut std::os::raw::c_void,
) -> bool {
    // `loader` is the ProgramLoader cookie registered with the dispatcher.
    let loader = &*(loader as *const ProgramLoader);
    let entropy = (index >> 16) & 0xffff;
    let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
    // Out-of-range slot or mismatched entropy tag both yield false.
    loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
}
1149
/// Extracts the faulting instruction pointer from a signal ucontext (x86-64 Linux).
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}
1154
/// Extracts the faulting instruction pointer from a signal ucontext (aarch64 Linux).
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.pc as usize
}
1159
/// SIGSEGV handler: converts access faults raised by guest JIT code against
/// the pointer cage into a coroutine suspension carrying the fault address.
/// Any other fault reinstalls the default handler so the process crashes
/// normally when the signal is re-delivered.
unsafe extern "C" fn sigsegv_handler(
    _sig: i32,
    siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    let fail = || restore_default_signal_handler(libc::SIGSEGV);

    // Snapshot the active JIT zone; `valid` gates the other fields (the
    // acquire fence pairs with the release fence at the publish site).
    let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            compiler_fence(Ordering::Acquire);
            Some((
                x.code_range.get(),
                x.pointer_cage_protected_range.get(),
                x.yielder.get(),
            ))
        } else {
            None
        }
    }) else {
        return fail();
    };

    let pc = program_counter(uctx);

    // Only handle faults whose PC lies inside guest JIT code.
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return fail();
    }

    // 2 == SEGV_ACCERR on Linux (page mapped, permissions violated).
    if (*siginfo).si_code != 2 {
        return fail();
    }

    // Only handle faults that hit the cage's protected pages.
    let si_addr = (*siginfo).si_addr() as usize;
    if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
        return fail();
    }

    // Suspend the guest coroutine from inside the handler; the host loop
    // turns this into RuntimeError::MemoryFault and never resumes it.
    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        memory_access_error: Some(si_addr),
        ..Default::default()
    });
}
1204
/// SIGUSR1 handler (sent by the preemption watcher): bumps the per-thread
/// counter and, if the signal landed while guest JIT code was executing,
/// suspends the guest coroutine with `async_preemption` set. Signals landing
/// in host code are counted but otherwise ignored.
unsafe extern "C" fn sigusr1_handler(
    _sig: i32,
    _siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    // The run loop compares this counter across iterations to detect that
    // a preemption interval elapsed.
    SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

    // Snapshot the active JIT zone; `valid` gates the other fields (acquire
    // fence pairs with the release fence at the publish site).
    let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            compiler_fence(Ordering::Acquire);
            Some((x.code_range.get(), x.yielder.get()))
        } else {
            None
        }
    }) else {
        return;
    };
    // Only preempt when the signal interrupted guest JIT code.
    let pc = program_counter(uctx);
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return;
    }

    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        async_preemption: true,
        ..Default::default()
    });
}
1233
/// Reinstalls the default disposition for `signum` so the next delivery is
/// handled by the kernel default (e.g. crash for SIGSEGV); aborts if even
/// that fails.
unsafe fn restore_default_signal_handler(signum: i32) {
    let act = libc::sigaction {
        sa_sigaction: libc::SIG_DFL,
        sa_flags: libc::SA_SIGINFO,
        sa_mask: std::mem::zeroed(),
        sa_restorer: None,
    };
    if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
        libc::abort();
    }
}
1245
1246fn get_blocked_sigset() -> libc::sigset_t {
1247 unsafe {
1248 let mut s: libc::sigset_t = std::mem::zeroed();
1249 libc::sigaddset(&mut s, libc::SIGUSR1);
1250 libc::sigaddset(&mut s, libc::SIGSEGV);
1251 s
1252 }
1253}