1use std::{
5 any::{Any, TypeId},
6 cell::{Cell, RefCell},
7 collections::HashMap,
8 ffi::CStr,
9 marker::PhantomData,
10 mem::ManuallyDrop,
11 ops::{Deref, DerefMut},
12 os::raw::c_void,
13 pin::Pin,
14 ptr::NonNull,
15 rc::Rc,
16 sync::{
17 atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18 mpsc::RecvTimeoutError,
19 Arc, Once,
20 },
21 thread::ThreadId,
22 time::{Duration, Instant},
23};
24
25use corosensei::{
26 stack::{DefaultStack, Stack},
27 Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
28};
29use futures::{Future, FutureExt};
30use memmap2::{MmapOptions, MmapRaw};
31use rand::prelude::SliceRandom;
32use rand::SeedableRng;
33use rand_chacha::ChaCha8Rng;
34
35use crate::{
36 error::{Error, RuntimeError},
37 helpers::Helper,
38 linker::link_elf,
39 pointer_cage::PointerCage,
40 util::nonnull_bytes_overlap,
41};
42
// Size in bytes of the host-side (native) stack each program coroutine runs on.
const NATIVE_STACK_SIZE: usize = 16384;
// Size in bytes of the guest shadow stack inside the pointer cage; also the
// size of the save/restore buffer (`ExecContext::copy_stack`).
const SHADOW_STACK_SIZE: usize = 4096;
// Upper bound on calldata copied onto the shadow stack before program entry.
const MAX_CALLDATA_SIZE: usize = 512;
// Maximum number of simultaneously outstanding user-memory dereferences per
// helper invocation (tracked separately for reads and writes in HelperScope).
const MAX_DEREF_REGIONS: usize = 4;
47
/// Type-keyed scratch storage that lives for the duration of a single
/// program invocation. Slots are created lazily via `Default`.
pub struct InvokeScope {
    data: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl InvokeScope {
    /// Returns a mutable reference to the slot for `T`, inserting
    /// `T::default()` the first time that type is requested.
    pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
        let slot = self
            .data
            .entry(TypeId::of::<T>())
            .or_insert_with(|| Box::new(T::default()));
        // The entry for TypeId::of::<T>() only ever holds a boxed T, so the
        // downcast failing would be an internal invariant violation.
        match slot.downcast_mut::<T>() {
            Some(value) => value,
            None => panic!("InvokeScope::data_mut: downcast failed"),
        }
    }
}
65
/// Borrowed view handed to helper functions while the guest program is
/// suspended. Tracks which cage regions have been dereferenced so that
/// overlapping read/write aliasing within one helper call can be rejected.
pub struct HelperScope<'a, 'b> {
    /// The program currently being executed.
    pub program: &'a Program,
    /// Per-invocation type-keyed scratch storage (see [`InvokeScope`]).
    pub invoke: RefCell<&'a mut InvokeScope>,
    // Caller-provided resources, looked up by concrete type in
    // `with_resource_mut`.
    resources: RefCell<&'a mut [&'b mut dyn Any]>,
    // Regions handed out mutably via `user_memory_mut`; at most
    // MAX_DEREF_REGIONS may be live at once.
    mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
    // Regions handed out immutably via `user_memory`.
    immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
    // Whether `post_task` may be called in the current context.
    can_post_task: bool,
}
77
/// Mutable slice of guest (cage) memory obtained from
/// `HelperScope::user_memory_mut`. Borrowing the scope keeps the region
/// registered in the scope's dereferenced-region table for the lifetime of
/// this handle.
pub struct MutableUserMemory<'a, 'b, 'c> {
    // Held only for its borrow; prevents the scope from being dropped or
    // reused while the region pointer is live.
    _scope: &'c HelperScope<'a, 'b>,
    region: NonNull<[u8]>,
}

impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: `region` was produced by `safe_deref_for_write` and
        // registered in `mutable_dereferenced_regions`, whose overlap checks
        // ensure it does not alias any other handed-out region.
        unsafe { self.region.as_ref() }
    }
}

impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: same invariant as `deref`; exclusivity is guaranteed by
        // the overlap checks in `user_memory_mut`.
        unsafe { self.region.as_mut() }
    }
}
97
impl<'a, 'b> HelperScope<'a, 'b> {
    /// Schedules an async task to run after the program is suspended. The
    /// task's output closure is invoked on the next resume and its `u64`
    /// becomes the helper's return value seen by the guest.
    ///
    /// # Panics
    /// Panics if called while `can_post_task` is false (e.g. from an event
    /// listener callback) or if another task is already pending.
    pub fn post_task(
        &self,
        task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
    ) {
        if !self.can_post_task {
            panic!("HelperScope::post_task() called in a context where posting task is not allowed");
        }

        PENDING_ASYNC_TASK.with(|x| {
            let mut x = x.borrow_mut();
            if x.is_some() {
                panic!("post_task called while another task is pending");
            }
            // Erase the concrete future/closure types into the boxed
            // PendingAsyncTask form `_run` knows how to await.
            *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
        });
    }

    /// Looks up the first caller-provided resource of type `T` and invokes
    /// `callback` with it, or with `Err(())` if no such resource exists.
    pub fn with_resource_mut<'c, T: 'static, R>(
        &'c self,
        callback: impl FnOnce(Result<&mut T, ()>) -> R,
    ) -> R {
        let mut resources = self.resources.borrow_mut();
        let Some(res) = resources
            .iter_mut()
            .filter_map(|x| x.downcast_mut::<T>())
            .next()
        else {
            tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
            return callback(Err(()));
        };

        callback(Ok(res))
    }

    /// Resolves a guest pointer/size pair to an immutable byte slice.
    ///
    /// Fails if the range is invalid, overlaps a region already handed out
    /// mutably, or more than MAX_DEREF_REGIONS reads are outstanding.
    /// Zero-sized reads are not tracked.
    pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_read(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid read");
            return Err(());
        };

        if size != 0 {
            // A shared borrow must not alias any outstanding mutable one.
            if self
                .mutable_dereferenced_regions
                .iter()
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "read overlapped with previous write");
                return Err(());
            }

            // Record the region in the first free immutable slot.
            let Some(slot) = self
                .immutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many reads");
                return Err(());
            };
            slot.set(Some(region));
        }

        // SAFETY: the range was validated by `safe_deref_for_read` and the
        // overlap bookkeeping above excludes any live mutable alias.
        Ok(unsafe { region.as_ref() })
    }

    /// Resolves a guest pointer/size pair to a mutable memory handle.
    ///
    /// Fails if the range is invalid, overlaps any outstanding read or
    /// write, or more than MAX_DEREF_REGIONS writes are outstanding.
    /// Zero-sized writes are not tracked.
    pub fn user_memory_mut<'c>(
        &'c self,
        ptr: u64,
        size: u64,
    ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
        let Some(region) = self
            .program
            .unbound
            .cage
            .safe_deref_for_write(ptr as usize, size as usize)
        else {
            tracing::warn!(ptr, size, "invalid write");
            return Err(());
        };

        if size != 0 {
            // A mutable borrow must not alias anything already handed out,
            // mutable or immutable.
            if self
                .mutable_dereferenced_regions
                .iter()
                .chain(self.immutable_dereferenced_regions.iter())
                .filter_map(|x| x.get())
                .any(|x| nonnull_bytes_overlap(x, region))
            {
                tracing::warn!(ptr, size, "write overlapped with previous read/write");
                return Err(());
            }

            // Record the region in the first free mutable slot.
            let Some(slot) = self
                .mutable_dereferenced_regions
                .iter()
                .find(|x| x.get().is_none())
            else {
                tracing::warn!(ptr, size, "too many writes");
                return Err(());
            };
            slot.set(Some(region));
        }

        Ok(MutableUserMemory {
            _scope: self,
            region,
        })
    }
}
221
/// Wrapper asserting that a value may be treated as `Send`.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
// SAFETY(review): used in `_run` to move raw pointers/coroutines across
// await points; soundness relies on the whole execution staying pinned to
// one thread (the tid change is detected via RUST_TID) — confirm.
unsafe impl<T> Send for AssumeSend<T> {}

/// Reusable per-execution context: a native coroutine stack plus a buffer
/// the shadow stack is saved to / restored from across suspensions.
struct ExecContext {
    native_stack: DefaultStack,
    copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
230
231impl ExecContext {
232 fn new() -> Self {
233 Self {
234 native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
235 .expect("failed to initialize native stack"),
236 copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
237 }
238 }
239}
240
/// Boxed future produced by `HelperScope::post_task`; resolves to the
/// completion callback that computes the helper's return value.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// Completion callback run on resume; its `u64` is fed back to the guest.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

// Monotonic source for `UnboundProgram::id`.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);

thread_local! {
    // Cached thread id; compared across awaits to detect thread migration.
    static RUST_TID: ThreadId = std::thread::current().id();
    // Incremented by the SIGUSR1 handler; `_run` polls it to notice preemption.
    static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
    // Published JIT zone consulted by the signal handlers.
    static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
    // Pool of reusable execution contexts (see BorrowedExecContext).
    static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
    // At most one task posted per helper call (see HelperScope::post_task).
    static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
}
255
/// RAII handle over a pooled `ExecContext`; returns the context to the
/// thread-local pool on drop.
struct BorrowedExecContext {
    ctx: ManuallyDrop<ExecContext>,
}

impl BorrowedExecContext {
    /// Pops a context from the pool (or allocates one) and scrambles the
    /// shadow-stack buffer with `stack_rng`.
    fn new(stack_rng: &mut impl rand::Rng) -> Self {
        let mut me = Self {
            ctx: ManuallyDrop::new(
                EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
            ),
        };
        // Fill the save buffer with random bytes so stale data from a
        // previous run is never observable by the next program.
        stack_rng.fill(&mut *me.ctx.copy_stack);
        me
    }
}

impl Drop for BorrowedExecContext {
    fn drop(&mut self) {
        // SAFETY: `ctx` is taken exactly once, here; `self` is being dropped
        // and never touches `ctx` afterwards.
        let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
    }
}
278
/// Thread-local descriptor of the JIT code currently executing, consulted
/// by the SIGSEGV/SIGUSR1 handlers. `valid` gates the other fields: `_run`
/// publishes with a release fence before setting it, handlers read it with
/// an acquire fence after checking it.
#[derive(Default)]
struct ActiveJitCodeZone {
    valid: AtomicBool,
    // [start, end) address range of the JIT-compiled section.
    code_range: Cell<(usize, usize)>,
    // [start, end) of the cage's protected range; faults here are reported
    // as guest memory faults rather than crashing the process.
    pointer_cage_protected_range: Cell<(usize, usize)>,
    // Yielder used to suspend out of the JIT coroutine, including from
    // inside a signal handler.
    yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
286
/// Hooks invoked at notable points of a program's execution. All methods
/// have no-op defaults.
pub trait ProgramEventListener: Send + Sync + 'static {
    /// Called after the program was asynchronously preempted (SIGUSR1).
    fn did_async_preempt(&self, _scope: &HelperScope) {}
    /// Called after the program yielded its timeslice.
    fn did_yield(&self) {}
    /// Called when the program is throttled; a returned future is awaited
    /// before execution resumes.
    fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
        None
    }
    /// Called after the shadow stack was copied out of the cage.
    fn did_save_shadow_stack(&self) {}
    /// Called after the shadow stack was copied back into the cage.
    fn did_restore_shadow_stack(&self) {}
}

/// Listener that relies entirely on the trait's default (no-op) methods.
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
306
/// Loads and JIT-compiles programs; owns the obfuscated helper table shared
/// by every program it produces.
pub struct ProgramLoader {
    // helper name -> obfuscated 32-bit id handed to the linker:
    // (entropy << 16) | ((slot + 1) ^ helper_id_xor). See `new`.
    helpers_inverse: HashMap<&'static str, i32>,
    event_listener: Arc<dyn ProgramEventListener>,
    // Random mask XORed into helper indices embedded in guest code.
    helper_id_xor: u16,
    // (entropy tag, name, fn) in shuffled slot order; indexed by decoded id.
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}

/// A loaded, JIT-compiled program not yet pinned to an executor thread.
pub struct UnboundProgram {
    id: u64,
    // Held only to keep the JIT code mapping alive for the program's lifetime.
    _code_mem: MmapRaw,
    cage: PointerCage,
    helper_id_xor: u16,
    helpers: Arc<Vec<(u16, &'static str, Helper)>>,
    event_listener: Arc<dyn ProgramEventListener>,
    // Section name -> location of its compiled code.
    entrypoints: HashMap<String, Entrypoint>,
    // Seeds the per-program shadow-stack scrambling RNG.
    stack_rng_seed: [u8; 32],
}

/// A program pinned to the current thread and ready to run.
pub struct Program {
    unbound: UnboundProgram,
    // Per-program type-keyed storage (see `Program::data`).
    data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
    // Deterministic RNG used to scramble the shadow-stack save buffer.
    stack_rng: RefCell<ChaCha8Rng>,
}
333
/// Location of one JIT-compiled section inside the code mapping.
#[derive(Copy, Clone)]
struct Entrypoint {
    code_ptr: usize,
    code_len: usize,
}

/// Timeslicing policy applied by `Program::run`.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
    /// Continuous run time after which the program yields to the executor.
    pub max_run_time_before_yield: Duration,
    /// Continuous run time after which the program is throttled (slept).
    pub max_run_time_before_throttle: Duration,
    /// How long a throttled program sleeps before resuming.
    pub throttle_duration: Duration,
}

/// Abstraction over the async executor's sleep/yield primitives.
pub trait Timeslicer {
    fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
    fn yield_now(&self) -> impl Future<Output = ()>;
}

/// Token proving the process-wide signal handlers have been installed
/// (see `GlobalEnv::new`).
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Token proving the current thread has a preemption watcher. The raw
/// pointer PhantomData makes it !Send + !Sync, so it cannot leak to
/// another thread.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
    _not_send_sync: std::marker::PhantomData<*const ()>,
}
368
impl GlobalEnv {
    /// Installs the process-wide SIGUSR1/SIGSEGV handlers exactly once.
    ///
    /// # Safety
    /// Replaces the process signal dispositions; the caller must ensure no
    /// other component of the process depends on its own SIGUSR1/SIGSEGV
    /// handling.
    pub unsafe fn new() -> Self {
        static INIT: Once = Once::new();

        INIT.call_once(|| {
            // Both signals are masked while either handler runs; they are
            // explicitly unblocked in `_run` after a suspend out of a
            // handler frame.
            let sa_mask = get_blocked_sigset();

            for (sig, handler) in [
                (libc::SIGUSR1, sigusr1_handler as usize),
                (libc::SIGSEGV, sigsegv_handler as usize),
            ] {
                let act = libc::sigaction {
                    sa_sigaction: handler,
                    sa_flags: libc::SA_SIGINFO,
                    sa_mask,
                    sa_restorer: None,
                };
                if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
                    panic!("failed to setup handler for signal {}", sig);
                }
            }
        });

        Self(())
    }

    /// Ensures the current thread has a watcher thread that periodically
    /// delivers SIGUSR1 to it (via tgkill) to force async preemption.
    /// Idempotent per thread.
    pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
        // On thread-local destruction: drops the sender to signal shutdown,
        // then blocks until the watcher exits (its completion sender drops).
        struct DeferDrop(
            Option<std::sync::mpsc::Sender<()>>,
            std::sync::mpsc::Receiver<()>,
        );
        impl Drop for DeferDrop {
            fn drop(&mut self) {
                self.0.take();
                let _ = self.1.recv();
            }
        }

        thread_local! {
            static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
        }

        // Already initialized for this thread.
        if WATCHER.with(|x| x.borrow().is_some()) {
            return ThreadEnv {
                _not_send_sync: PhantomData,
            };
        }

        unsafe {
            let tgid = libc::getpid();
            let tid = libc::gettid();
            let (watcher_completion_tx, watcher_completion_rx) = std::sync::mpsc::channel::<()>();
            let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>();

            std::thread::Builder::new()
                .name("preempt-watcher".to_string())
                .spawn(move || {
                    // Held so the channel disconnects (and DeferDrop's recv
                    // returns) exactly when this thread exits.
                    let _watcher_completion_tx = watcher_completion_tx;
                    loop {
                        // Any result other than a plain timeout (shutdown
                        // message or disconnect) ends the watcher.
                        if !matches!(
                            shutdown_rx.recv_timeout(async_preemption_interval),
                            Err(RecvTimeoutError::Timeout)
                        ) {
                            break;
                        }
                        // Signal the specific owning thread, not the process.
                        let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                        if ret != 0 {
                            break;
                        }
                    }
                })
                .expect("failed to spawn preemption watcher");

            WATCHER.with(|x| {
                x.borrow_mut()
                    .replace(DeferDrop(Some(shutdown_tx), watcher_completion_rx));
            });

            ThreadEnv {
                _not_send_sync: PhantomData,
            }
        }
    }
}
473
474impl UnboundProgram {
475 pub fn pin_to_current_thread(self, _: ThreadEnv) -> Program {
477 let stack_rng = RefCell::new(ChaCha8Rng::from_seed(self.stack_rng_seed));
478 Program {
479 unbound: self,
480 data: RefCell::new(HashMap::new()),
481 stack_rng,
482 }
483 }
484}
485
impl Program {
    /// Globally unique id assigned at load time.
    pub fn id(&self) -> u64 {
        self.unbound.id
    }

    /// Returns this program's storage slot for `T`, creating it with
    /// `T::default()` on first access.
    pub fn data<T: Default + 'static>(&self) -> Rc<T> {
        let mut data = self.data.borrow_mut();
        let entry = data.entry(TypeId::of::<T>());
        let entry = entry.or_insert_with(|| Rc::new(T::default()));
        entry.clone().downcast().unwrap()
    }

    /// Whether a JIT-compiled section with this name exists.
    pub fn has_section(&self, name: &str) -> bool {
        self.unbound.entrypoints.contains_key(name)
    }

    /// Runs `entrypoint` with `calldata` to completion under the given
    /// timeslicing policy. `resources` are exposed to helpers via
    /// `HelperScope::with_resource_mut`.
    pub async fn run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
    ) -> Result<i64, Error> {
        self
            ._run(timeslice, timeslicer, entrypoint, resources, calldata)
            .await
            .map_err(Error)
    }

    async fn _run(
        &self,
        timeslice: &TimesliceConfig,
        timeslicer: &impl Timeslicer,
        entrypoint: &str,
        resources: &mut [&mut dyn Any],
        calldata: &[u8],
    ) -> Result<i64, RuntimeError> {
        let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
            return Err(RuntimeError::InvalidArgument("entrypoint not found"));
        };

        // SAFETY(review): `code_ptr` points at JIT output produced in
        // `ProgramLoader::_load`; assumes the generated code follows this
        // two-argument convention — confirm against the ubpf translator.
        let entry = unsafe {
            std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
                entrypoint.code_ptr,
            )
        };
        // Force-resets the coroutine on drop so a still-suspended program
        // (error path, cancellation) does not unwind through JIT frames.
        struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
            ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
        );
        impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
            for CoDropper<'a, Input, Yield, Return, DefaultStack>
        {
            fn drop(&mut self) {
                // SAFETY(review): force_reset abandons the coroutine stack
                // without unwinding; assumes JIT frames hold nothing that
                // needs Drop — confirm against corosensei's contract.
                unsafe {
                    self.0.force_reset();
                }
            }
        }

        let mut ectx = BorrowedExecContext::new(&mut *self.stack_rng.borrow_mut());

        if calldata.len() > MAX_CALLDATA_SIZE {
            return Err(RuntimeError::InvalidArgument("calldata too large"));
        }
        // Calldata occupies the top of the shadow-stack image; the guest
        // entry receives a pointer just below it.
        ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
        let calldata_len = calldata.len();

        let program_ret: u64 = {
            let shadow_stack_top = self.unbound.cage.stack_top();
            let shadow_stack_ptr = AssumeSend(
                self
                    .unbound
                    .cage
                    .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
                    .unwrap(),
            );
            let ctx = &mut *ectx.ctx;

            // The guest runs inside this coroutine; helpers/signals suspend
            // out of it with a `Dispatch` and are resumed with a u64.
            let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
                &mut ctx.native_stack,
                move |yielder, _input| unsafe {
                    ACTIVE_JIT_CODE_ZONE.with(|x| {
                        x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
                    });
                    entry(
                        shadow_stack_top - calldata_len,
                        shadow_stack_top - calldata_len,
                    )
                },
            )));

            let mut last_yield_time = Instant::now();
            let mut last_throttle_time = Instant::now();
            let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
            let mut resume_input: u64 = 0;
            let mut did_throttle = false;
            // True whenever the live shadow stack is in `copy_stack` rather
            // than in the cage (initially, and after every save).
            let mut shadow_stack_saved = true;
            // (thread id, SIGUSR1 count) snapshot; a change means the task
            // migrated threads or a preemption signal arrived.
            let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
            let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
            let mut invoke_scope = InvokeScope {
                data: HashMap::new(),
            };

            loop {
                // Publish the active JIT zone for the signal handlers:
                // release fence before the `valid` flag flips on.
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.code_range.set((
                        entrypoint.code_ptr,
                        entrypoint.code_ptr + entrypoint.code_len,
                    ));
                    x.yielder.set(yielder.map(|x| x.0));
                    x.pointer_cage_protected_range
                        .set(self.unbound.cage.protected_range_without_margins());
                    compiler_fence(Ordering::Release);
                    x.valid.store(true, Ordering::Relaxed);
                });

                // Restore the shadow stack into the cage if it was saved
                // (first iteration, or after an await).
                if shadow_stack_saved {
                    shadow_stack_saved = false;

                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            ctx.copy_stack.as_ptr() as *const u8,
                            shadow_stack_ptr.0.as_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }

                    self.unbound.event_listener.did_restore_shadow_stack();
                }

                // Run a completed async task's callback; its u64 becomes the
                // resume value the guest sees from the original helper call.
                if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
                    resume_input = prev_async_task_output(&HelperScope {
                        program: self,
                        invoke: RefCell::new(&mut invoke_scope),
                        resources: RefCell::new(resources),
                        // SAFETY: all-zero bytes are a valid value here —
                        // each element is Cell<Option<NonNull<..>>> and the
                        // null-pointer niche encodes None.
                        mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                        can_post_task: false,
                    })
                    .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
                }

                // Enter the guest; returns on completion, helper call, or
                // signal-initiated suspend.
                let ret = co.0 .0.resume(resume_input);
                // Unpublish the zone before doing anything else; capture the
                // yielder the coroutine body stored on first entry.
                ACTIVE_JIT_CODE_ZONE.with(|x| {
                    x.valid.store(false, Ordering::Relaxed);
                    compiler_fence(Ordering::Release);
                    yielder = x.yielder.get().map(AssumeSend);
                });

                let dispatch: Dispatch = match ret {
                    CoroutineResult::Return(x) => break x,
                    CoroutineResult::Yield(x) => x,
                };

                // These dispatches came from inside a signal handler, where
                // the signal is blocked; the suspend skipped the normal
                // sigreturn, so unblock explicitly.
                if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
                    unsafe {
                        let unblock = get_blocked_sigset();
                        libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
                    }
                }

                if let Some(si_addr) = dispatch.memory_access_error {
                    // Report the guest-visible virtual address, not the host one.
                    let vaddr = si_addr - self.unbound.cage.offset();
                    return Err(RuntimeError::MemoryFault(vaddr));
                }

                // Clear any stale pending task before running the helper.
                PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
                let mut helper_name: &'static str = "";

                let mut helper_scope = HelperScope {
                    program: self,
                    invoke: RefCell::new(&mut invoke_scope),
                    resources: RefCell::new(resources),
                    // SAFETY: see the zeroed() note above — None niche.
                    mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
                    can_post_task: false,
                };

                if dispatch.async_preemption {
                    self
                        .unbound
                        .event_listener
                        .did_async_preempt(&mut helper_scope);
                } else {
                    // Decode the obfuscated helper index (inverse of the
                    // encoding in ProgramLoader::new) and dispatch.
                    let Some((_, got_helper_name, helper)) = self
                        .unbound
                        .helpers
                        .get(
                            ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                                as usize,
                        )
                        .copied()
                    else {
                        panic!("unknown helper index: {}", dispatch.index);
                    };
                    helper_name = got_helper_name;

                    helper_scope.can_post_task = true;
                    resume_input = helper(
                        &mut helper_scope,
                        dispatch.arg1,
                        dispatch.arg2,
                        dispatch.arg3,
                        dispatch.arg4,
                        dispatch.arg5,
                    )
                    .map_err(|()| RuntimeError::HelperError(helper_name))?;
                    helper_scope.can_post_task = false;
                }

                let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

                // Fast path: no preemption signal, no thread change, no task
                // posted — resume the guest immediately without clock reads.
                let new_rust_tid_sigusr1_counter =
                    (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
                if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
                {
                    continue;
                }

                rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

                let now = Instant::now();
                let should_throttle = now > last_throttle_time
                    && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
                let should_yield = now > last_yield_time
                    && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
                if should_throttle || should_yield || pending_async_task.is_some() {
                    // About to await: save the shadow stack out of the cage
                    // so other programs on this thread can use it.
                    shadow_stack_saved = true;
                    unsafe {
                        std::ptr::copy_nonoverlapping(
                            shadow_stack_ptr.0.as_ptr() as *const u8,
                            ctx.copy_stack.as_mut_ptr() as *mut u8,
                            SHADOW_STACK_SIZE,
                        );
                    }
                    self.unbound.event_listener.did_save_shadow_stack();

                    if should_throttle {
                        if !did_throttle {
                            did_throttle = true;
                            tracing::warn!("throttling program");
                        }
                        timeslicer.sleep(timeslice.throttle_duration).await;
                        let now = Instant::now();
                        last_throttle_time = now;
                        last_yield_time = now;
                        let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
                        if let Some(task) = task {
                            task.await;
                        }
                    } else if should_yield {
                        timeslicer.yield_now().await;
                        let now = Instant::now();
                        last_yield_time = now;
                        self.unbound.event_listener.did_yield();
                    }

                    if let Some(pending_async_task) = pending_async_task {
                        let async_start = Instant::now();
                        prev_async_task_output = Some((helper_name, pending_async_task.await));
                        // Time spent in the task doesn't count against the
                        // program's own timeslice.
                        let async_dur = async_start.elapsed();
                        last_throttle_time += async_dur;
                        last_yield_time += async_dur;
                    }
                }
            }
        };

        Ok(program_ret as i64)
    }
}
778
/// Owning wrapper around a raw `ubpf_vm` handle; destroyed on drop.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);

impl Vm {
    /// Creates a VM configured for this cage: bounds checks off (the cage's
    /// pointer mask/offset scheme replaces them) and the JIT shadow stack on.
    fn new(cage: &PointerCage) -> Self {
        let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
        // SAFETY: `vm` is a valid, freshly created handle owned by us.
        unsafe {
            crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
            crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
            crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
        }
        Self(vm)
    }
}

impl Drop for Vm {
    fn drop(&mut self) {
        // SAFETY: the handle is uniquely owned and destroyed exactly once.
        unsafe {
            crate::ubpf::ubpf_destroy(self.0.as_ptr());
        }
    }
}
800
impl ProgramLoader {
    /// Builds a loader with an obfuscated helper table.
    ///
    /// Helpers are deduplicated by name, shuffled, and each slot gets a
    /// random 15-bit entropy tag. The id published to the linker packs
    /// `(entropy << 16) | ((slot + 1) ^ helper_id_xor)`; `std_validator`
    /// and `_run` invert this encoding.
    pub fn new(
        rng: &mut impl rand::Rng,
        event_listener: Arc<dyn ProgramEventListener>,
        raw_helpers: &[&[(&'static str, Helper)]],
    ) -> Self {
        let helper_id_xor = rng.gen::<u16>();
        let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
        // Collect into a HashMap first to dedupe by name (later slices win),
        // then shuffle so slot order leaks nothing about registration order.
        let mut shuffled_helpers = raw_helpers
            .iter()
            .flat_map(|x| x.iter().copied())
            .collect::<HashMap<_, _>>()
            .into_iter()
            .collect::<Vec<_>>();
        shuffled_helpers.shuffle(rng);
        let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

        // slot + 1 must fit in the low 16 bits of the packed id.
        assert!(shuffled_helpers.len() <= 65535);

        for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
            // 15-bit tag: keeps the packed i32 id non-negative.
            let entropy = rng.gen::<u16>() & 0x7fff;
            helpers.push((entropy, name, helper));
            helpers_inverse.insert(
                name,
                (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
            );
        }

        tracing::info!(?helpers_inverse, "generated helper table");
        Self {
            helper_id_xor,
            helpers: Arc::new(helpers),
            helpers_inverse,
            event_listener,
        }
    }

    /// Links and JIT-compiles an ELF image into an `UnboundProgram`.
    pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
        self._load(rng, elf).map_err(Error)
    }

    fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
        let start_time = Instant::now();
        let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
        let vm = Vm::new(&cage);

        // Copy the ELF into the cage's data region and link it there so
        // relocations use cage-relative addresses.
        let code_sections = {
            // NOTE(review): a mutable slice is derived from
            // `safe_deref_for_read`; presumably the region is writable until
            // `freeze_data()` below — confirm against PointerCage.
            let mut data = cage
                .safe_deref_for_read(cage.data_bottom(), elf.len())
                .unwrap();
            let data = unsafe { data.as_mut() };
            data.copy_from_slice(elf);

            link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
        };
        cage.freeze_data();

        let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        if page_size < 0 {
            return Err(RuntimeError::PlatformError("failed to get page size"));
        }
        let page_size = page_size as usize;

        // Randomly sized PROT_NONE guard regions on both sides of the JIT
        // code, so the code address itself is not predictable.
        let guard_size_before = rng.gen_range(16..128) * page_size;
        let mut guard_size_after = rng.gen_range(16..128) * page_size;

        let code_len_allocated: usize = 65536;
        let code_mem = MmapRaw::from(
            MmapOptions::new()
                .len(code_len_allocated + guard_size_before + guard_size_after)
                .map_anon()
                .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
        );

        unsafe {
            // Helper calls from JIT code route through tls_dispatcher;
            // std_validator vets helper ids at compile time.
            if crate::ubpf::ubpf_register_external_dispatcher(
                vm.0.as_ptr(),
                Some(tls_dispatcher),
                Some(std_validator),
                self as *const _ as *mut c_void,
            ) != 0
            {
                return Err(RuntimeError::PlatformError(
                    "ubpf: failed to register external dispatcher",
                ));
            }
            if libc::mprotect(
                code_mem.as_mut_ptr() as *mut _,
                guard_size_before,
                libc::PROT_NONE,
            ) != 0
                || libc::mprotect(
                    code_mem
                        .as_mut_ptr()
                        .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
                    guard_size_after,
                    libc::PROT_NONE,
                ) != 0
            {
                return Err(RuntimeError::PlatformError("failed to protect guard pages"));
            }
        }

        let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

        unsafe {
            // Sections are translated one after another into this window;
            // `code_slice` shrinks to the remaining free space.
            let mut code_slice = std::slice::from_raw_parts_mut(
                code_mem.as_mut_ptr().offset(guard_size_before as isize),
                code_len_allocated,
            );
            for (section_name, code_vaddr_size) in code_sections {
                if code_slice.is_empty() {
                    return Err(RuntimeError::InvalidArgument(
                        "no space left for jit compilation",
                    ));
                }

                // The VM holds one program at a time; drop the previous
                // section's bytecode before loading the next.
                crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

                let mut errmsg_ptr = std::ptr::null_mut();
                let code = cage
                    .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
                    .unwrap();
                let ret = crate::ubpf::ubpf_load(
                    vm.0.as_ptr(),
                    code.as_ptr() as *const _,
                    code.len() as u32,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    // errmsg is malloc'd by ubpf; copy it out and free it.
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to load code");
                    return Err(RuntimeError::PlatformError("ubpf: code load failed"));
                }

                let mut written_len = code_slice.len();
                let ret = crate::ubpf::ubpf_translate(
                    vm.0.as_ptr(),
                    code_slice.as_mut_ptr(),
                    &mut written_len,
                    &mut errmsg_ptr,
                );
                if ret != 0 {
                    let errmsg = if errmsg_ptr.is_null() {
                        "".to_string()
                    } else {
                        CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
                    };
                    if !errmsg_ptr.is_null() {
                        libc::free(errmsg_ptr as _);
                    }
                    tracing::error!(section_name, error = errmsg, "failed to translate code");
                    return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
                }

                assert!(written_len <= code_slice.len());
                entrypoints.insert(
                    section_name,
                    Entrypoint {
                        // Start of this section = end of window minus what
                        // was still free before this translation.
                        code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
                            - code_slice.len(),
                        code_len: written_len,
                    },
                );
                code_slice = &mut code_slice[written_len..];
            }

            // Round the used portion up to a page; the unused tail joins the
            // trailing guard region.
            let unpadded_code_len = code_len_allocated - code_slice.len();
            let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
            assert!(code_len <= code_len_allocated);

            // Code becomes read+exec (W^X); leftover pages become PROT_NONE.
            if libc::mprotect(
                code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
                code_len,
                libc::PROT_READ | libc::PROT_EXEC,
            ) != 0
                || (code_len < code_len_allocated
                    && libc::mprotect(
                        code_mem
                            .as_mut_ptr()
                            .offset((guard_size_before + code_len) as isize) as *mut _,
                        code_len_allocated - code_len,
                        libc::PROT_NONE,
                    ) != 0)
            {
                return Err(RuntimeError::PlatformError("failed to protect code memory"));
            }

            guard_size_after += code_len_allocated - code_len;

            tracing::info!(
                elf_size = elf.len(),
                native_code_addr = ?code_mem.as_ptr(),
                native_code_size = code_len,
                native_code_size_unpadded = unpadded_code_len,
                guard_size_before,
                guard_size_after,
                duration = ?start_time.elapsed(),
                cage_ptr = ?cage.region().as_ptr(),
                cage_mapped_size = cage.region().len(),
                "jit compiled program"
            );

            Ok(UnboundProgram {
                id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
                _code_mem: code_mem,
                cage,
                helper_id_xor: self.helper_id_xor,
                helpers: self.helpers.clone(),
                event_listener: self.event_listener.clone(),
                entrypoints,
                stack_rng_seed: rng.gen(),
            })
        }
    }
}
1036
/// Message carried by a coroutine yield out of JIT code back to `_run`.
#[derive(Default)]
struct Dispatch {
    // Set by the SIGUSR1 handler: the guest was preempted, no helper call.
    async_preemption: bool,
    // Set by the SIGSEGV handler: host address of a fault inside the cage.
    memory_access_error: Option<usize>,

    // Helper call: obfuscated helper id plus up to five guest arguments.
    index: u32,
    arg1: u64,
    arg2: u64,
    arg3: u64,
    arg4: u64,
    arg5: u64,
}
1049
1050unsafe extern "C" fn tls_dispatcher(
1051 arg1: u64,
1052 arg2: u64,
1053 arg3: u64,
1054 arg4: u64,
1055 arg5: u64,
1056 index: std::os::raw::c_uint,
1057 _cookie: *mut std::os::raw::c_void,
1058) -> u64 {
1059 let yielder = ACTIVE_JIT_CODE_ZONE
1060 .with(|x| x.yielder.get())
1061 .expect("no yielder");
1062 let yielder = yielder.as_ref();
1063 let ret = yielder.suspend(Dispatch {
1064 async_preemption: false,
1065 memory_access_error: None,
1066 index,
1067 arg1,
1068 arg2,
1069 arg3,
1070 arg4,
1071 arg5,
1072 });
1073 ret
1074}
1075
1076unsafe extern "C" fn std_validator(
1077 index: std::os::raw::c_uint,
1078 loader: *mut std::os::raw::c_void,
1079) -> bool {
1080 let loader = &*(loader as *const ProgramLoader);
1081 let entropy = (index >> 16) & 0xffff;
1082 let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1083 loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1084}
1085
/// Extracts the interrupted program counter from a signal ucontext (x86_64).
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}

/// Extracts the interrupted program counter from a signal ucontext (aarch64).
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
    (*uctx).uc_mcontext.pc as usize
}
1095
/// SIGSEGV handler: converts permission faults raised by JIT code against
/// the pointer cage into a coroutine suspend carrying the fault address.
/// Any other fault restores SIG_DFL so the process crashes normally when
/// the faulting instruction re-executes.
unsafe extern "C" fn sigsegv_handler(
    _sig: i32,
    siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    let fail = || restore_default_signal_handler(libc::SIGSEGV);

    // Read the published JIT zone; acquire fence pairs with the release
    // publish in `_run`. If no zone is active this fault isn't ours.
    let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            compiler_fence(Ordering::Acquire);
            Some((
                x.code_range.get(),
                x.pointer_cage_protected_range.get(),
                x.yielder.get(),
            ))
        } else {
            None
        }
    }) else {
        return fail();
    };

    let pc = program_counter(uctx);

    // Only handle faults whose PC lies inside the JIT-compiled code.
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return fail();
    }

    // si_code 2 == SEGV_ACCERR on Linux: mapped page, bad permissions —
    // the cage's PROT_NONE protection, not a wild pointer.
    if (*siginfo).si_code != 2 {
        return fail();
    }

    let si_addr = (*siginfo).si_addr() as usize;
    if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
        return fail();
    }

    // Suspend out of this signal frame back to `_run`, which reports a
    // MemoryFault; the coroutine is never resumed past this point (it is
    // force-reset on drop). `_run` also unblocks SIGSEGV afterwards since
    // sigreturn is skipped.
    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        memory_access_error: Some(si_addr),
        ..Default::default()
    });
}
1140
/// SIGUSR1 handler (delivered by the preemption watcher): if the signal
/// landed while JIT code was executing, suspends the coroutine so `_run`
/// can apply its timeslicing policy. Otherwise only bumps the counter,
/// which `_run` polls after each helper call.
unsafe extern "C" fn sigusr1_handler(
    _sig: i32,
    _siginfo: *mut libc::siginfo_t,
    uctx: *mut libc::ucontext_t,
) {
    SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

    // Acquire pairs with the release publish in `_run`.
    let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
        if x.valid.load(Ordering::Relaxed) {
            compiler_fence(Ordering::Acquire);
            Some((x.code_range.get(), x.yielder.get()))
        } else {
            None
        }
    }) else {
        return;
    };
    // Only preempt if we interrupted the JIT code itself; host code is left
    // to finish normally.
    let pc = program_counter(uctx);
    if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
        return;
    }

    // Suspend out of the signal frame; `_run` resumes the guest later and
    // unblocks SIGUSR1 (sigreturn is skipped by the suspend).
    let yielder = yielder.expect("no yielder").as_ref();
    yielder.suspend(Dispatch {
        async_preemption: true,
        ..Default::default()
    });
}
1169
/// Re-installs SIG_DFL for `signum`, used when a fault is not ours: the
/// faulting instruction re-executes after the handler returns and the
/// default action (crash/core dump) takes over. Aborts if sigaction fails.
unsafe fn restore_default_signal_handler(signum: i32) {
    let act = libc::sigaction {
        sa_sigaction: libc::SIG_DFL,
        sa_flags: libc::SA_SIGINFO,
        sa_mask: std::mem::zeroed(),
        sa_restorer: None,
    };
    if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
        // Cannot recover inside a signal handler; abort is async-signal-safe.
        libc::abort();
    }
}
1181
1182fn get_blocked_sigset() -> libc::sigset_t {
1183 unsafe {
1184 let mut s: libc::sigset_t = std::mem::zeroed();
1185 libc::sigaddset(&mut s, libc::SIGUSR1);
1186 libc::sigaddset(&mut s, libc::SIGSEGV);
1187 s
1188 }
1189}