async_ebpf/
program.rs

1//! This module contains the core logic for eBPF program execution. A lot of
2//! `unsafe` code is used - be careful when making changes here.
3
4use std::{
5  any::{Any, TypeId},
6  cell::{Cell, RefCell},
7  collections::HashMap,
8  ffi::CStr,
9  marker::PhantomData,
10  mem::ManuallyDrop,
11  ops::{Deref, DerefMut},
12  os::raw::c_void,
13  pin::Pin,
14  ptr::NonNull,
15  rc::Rc,
16  sync::{
17    atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18    Arc, Once,
19  },
20  thread::ThreadId,
21  time::{Duration, Instant},
22};
23
24use corosensei::{
25  stack::{DefaultStack, Stack},
26  Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
27};
28use futures::{Future, FutureExt};
29use memmap2::{MmapOptions, MmapRaw};
30use parking_lot::{Condvar, Mutex};
31use rand::prelude::SliceRandom;
32use rand::SeedableRng;
33use rand_chacha::ChaCha8Rng;
34
35use crate::{
36  error::{Error, RuntimeError},
37  helpers::Helper,
38  linker::link_elf,
39  pointer_cage::PointerCage,
40  util::nonnull_bytes_overlap,
41};
42
// Size of the host-side stack each coroutine runs JIT-compiled code on.
const NATIVE_STACK_SIZE: usize = 16384;
// Size of the guest shadow stack that is copied out/in around async yields.
const SHADOW_STACK_SIZE: usize = 4096;
// Upper bound on calldata copied onto the shadow stack before entry.
const MAX_CALLDATA_SIZE: usize = 512;
// Number of slots available for tracking dereferenced user-memory regions
// per helper call (reads and writes tracked separately).
const MAX_DEREF_REGIONS: usize = 4;
47
/// Per-invocation storage for helper state during a program run.
pub struct InvokeScope {
  data: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl InvokeScope {
  /// Gets or creates typed data scoped to this invocation.
  ///
  /// The value is created via `T::default()` on first access and persists
  /// for the lifetime of this scope.
  pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
    let slot = self
      .data
      .entry(TypeId::of::<T>())
      .or_insert_with(|| Box::<T>::default());
    slot
      .downcast_mut::<T>()
      .expect("InvokeScope::data_mut: downcast failed")
  }
}
65
/// Context passed to helpers while a program is executing.
pub struct HelperScope<'a, 'b> {
  /// The program being executed.
  pub program: &'a Program,
  /// Mutable per-invocation data for helpers.
  pub invoke: RefCell<&'a mut InvokeScope>,
  // Caller-provided resources, looked up by concrete type via `with_resource_mut`.
  resources: RefCell<&'a mut [&'b mut dyn Any]>,
  // Regions handed out mutably via `user_memory_mut`. Tracked so later
  // dereferences in the same helper call can reject aliasing overlaps.
  mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
  // Regions handed out immutably via `user_memory`; overlapping reads are
  // allowed, but a later write overlapping any of these is rejected.
  immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
  // Whether `post_task` may be called right now (only true while a helper
  // body is running — see `_run`).
  can_post_task: bool,
}

/// A validated mutable view into user memory.
pub struct MutableUserMemory<'a, 'b, 'c> {
  // Borrowing the scope ties this view's lifetime to the helper call that
  // validated it, so the overlap bookkeeping stays live while it exists.
  _scope: &'c HelperScope<'a, 'b>,
  region: NonNull<[u8]>,
}

impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
  type Target = [u8];

  fn deref(&self) -> &Self::Target {
    // SAFETY: `region` was validated by `HelperScope::user_memory_mut` and
    // recorded in the overlap-tracking slots, so no other view aliases it
    // while `self` (which borrows the scope) is alive.
    unsafe { self.region.as_ref() }
  }
}

impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
  fn deref_mut(&mut self) -> &mut Self::Target {
    // SAFETY: same validation/tracking argument as `deref`, and `&mut self`
    // guarantees exclusive access through this view.
    unsafe { self.region.as_mut() }
  }
}
97
98impl<'a, 'b> HelperScope<'a, 'b> {
99  /// Posts an async task to be run between timeslices.
100  pub fn post_task(
101    &self,
102    task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
103  ) {
104    if !self.can_post_task {
105      panic!("HelperScope::post_task() called in a context where posting task is not allowed");
106    }
107
108    PENDING_ASYNC_TASK.with(|x| {
109      let mut x = x.borrow_mut();
110      if x.is_some() {
111        panic!("post_task called while another task is pending");
112      }
113      *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
114    });
115  }
116
117  /// Calls `callback` with a mutable resource of type `T`, if present.
118  pub fn with_resource_mut<'c, T: 'static, R>(
119    &'c self,
120    callback: impl FnOnce(Result<&mut T, ()>) -> R,
121  ) -> R {
122    let mut resources = self.resources.borrow_mut();
123    let Some(res) = resources
124      .iter_mut()
125      .filter_map(|x| x.downcast_mut::<T>())
126      .next()
127    else {
128      tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
129      return callback(Err(()));
130    };
131
132    callback(Ok(res))
133  }
134
135  /// Validates and returns an immutable view into user memory.
136  pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
137    let Some(region) = self
138      .program
139      .unbound
140      .cage
141      .safe_deref_for_read(ptr as usize, size as usize)
142    else {
143      tracing::warn!(ptr, size, "invalid read");
144      return Err(());
145    };
146
147    if size != 0 {
148      // The region must not overlap with any previously dereferenced mutable regions
149      if self
150        .mutable_dereferenced_regions
151        .iter()
152        .filter_map(|x| x.get())
153        .any(|x| nonnull_bytes_overlap(x, region))
154      {
155        tracing::warn!(ptr, size, "read overlapped with previous write");
156        return Err(());
157      }
158
159      // Find a slot to record this dereference
160      let Some(slot) = self
161        .immutable_dereferenced_regions
162        .iter()
163        .find(|x| x.get().is_none())
164      else {
165        tracing::warn!(ptr, size, "too many reads");
166        return Err(());
167      };
168      slot.set(Some(region));
169    }
170
171    Ok(unsafe { region.as_ref() })
172  }
173
174  /// Validates and returns a mutable view into user memory.
175  pub fn user_memory_mut<'c>(
176    &'c self,
177    ptr: u64,
178    size: u64,
179  ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
180    let Some(region) = self
181      .program
182      .unbound
183      .cage
184      .safe_deref_for_write(ptr as usize, size as usize)
185    else {
186      tracing::warn!(ptr, size, "invalid write");
187      return Err(());
188    };
189
190    if size != 0 {
191      // The region must not overlap with any other previously dereferenced mutable or immutable regions
192      if self
193        .mutable_dereferenced_regions
194        .iter()
195        .chain(self.immutable_dereferenced_regions.iter())
196        .filter_map(|x| x.get())
197        .any(|x| nonnull_bytes_overlap(x, region))
198      {
199        tracing::warn!(ptr, size, "write overlapped with previous read/write");
200        return Err(());
201      }
202
203      // Find a slot to record this dereference
204      let Some(slot) = self
205        .mutable_dereferenced_regions
206        .iter()
207        .find(|x| x.get().is_none())
208      else {
209        tracing::warn!(ptr, size, "too many writes");
210        return Err(());
211      };
212      slot.set(Some(region));
213    }
214
215    Ok(MutableUserMemory {
216      _scope: self,
217      region,
218    })
219  }
220}
221
// Wrapper that force-implements `Send` for raw-pointer-bearing values so they
// can live inside the `async fn _run` state machine.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
// SAFETY(review): this is a blanket assertion, not a proof. Soundness relies
// on how the wrapped values are used in `_run` (pointers into the cage and
// coroutine handles accessed only by the task that owns them) — confirm
// against the executor's threading model.
unsafe impl<T> Send for AssumeSend<T> {}

// Pooled per-execution state: the host stack the coroutine runs on, plus a
// buffer the shadow stack is saved into / restored from around yields.
struct ExecContext {
  native_stack: DefaultStack,
  copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
230
231impl ExecContext {
232  fn new() -> Self {
233    Self {
234      native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
235        .expect("failed to initialize native stack"),
236      copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
237    }
238  }
239}
240
/// A pending async task spawned by a helper.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// The callback produced by a helper async task when it resumes.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

// Source of unique program ids; starts at 1. NOTE(review): the allocation
// site is outside this view (presumably in `_load`) — confirm.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);

// State shared between program-running threads and the preempt-watcher thread.
#[derive(Copy, Clone, Debug)]
enum PreemptionState {
  // No `PreemptionEnabled` guard alive; watcher sleeps until notified.
  Inactive,
  // Count of live `PreemptionEnabled` guards; watcher periodically sends
  // SIGUSR1 to the program thread while in this state.
  Active(usize),
  // Thread is exiting; watcher loop terminates.
  Shutdown,
}

type PreemptionStateSignal = (Mutex<PreemptionState>, Condvar);

thread_local! {
  // Cached thread id; compared across helper calls to detect task migration.
  static RUST_TID: ThreadId = std::thread::current().id();
  // Incremented by the SIGUSR1 handler (outside this view); changes signal
  // that an async preemption happened since the last check.
  static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
  // Describes the currently-executing JIT code region for the signal handlers.
  static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
  // Reusable pool of execution contexts (stacks are expensive to allocate).
  static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
  // At most one async task may be posted per helper call; see `post_task`.
  static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
  // Shared with the preempt-watcher thread spawned in `init_thread`.
  static PREEMPTION_STATE: Arc<PreemptionStateSignal> = Arc::new((Mutex::new(PreemptionState::Inactive), Condvar::new()));
}
265
// RAII handle over a pooled `ExecContext`: taken from the thread-local pool
// on construction, returned to it on drop. `ManuallyDrop` lets `Drop` move
// the context back out by value.
struct BorrowedExecContext {
  ctx: ManuallyDrop<ExecContext>,
}

impl BorrowedExecContext {
  // Pops a pooled context (or builds a new one) and fills the shadow-stack
  // copy buffer with random bytes so stale data never leaks between runs.
  fn new(stack_rng: &mut impl rand::Rng) -> Self {
    let mut me = Self {
      ctx: ManuallyDrop::new(
        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
      ),
    };
    stack_rng.fill(&mut *me.ctx.copy_stack);
    me
  }
}

impl Drop for BorrowedExecContext {
  fn drop(&mut self) {
    // SAFETY: `ctx` is only taken here, and `drop` runs at most once, so the
    // `ManuallyDrop` contents are never used after being moved out.
    let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
    EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
  }
}
288
// Per-thread description of the JIT code currently executing, published by
// `_run` before each coroutine resume. NOTE(review): the readers are the
// SIGUSR1/SIGSEGV handlers registered in `GlobalEnv::new` (bodies outside
// this view) — confirm. `valid` is flipped with compiler fences around the
// other fields so an asynchronous reader never observes a half-written zone.
#[derive(Default)]
struct ActiveJitCodeZone {
  // True only while the coroutine is resumed into JIT code.
  valid: AtomicBool,
  // [start, end) address range of the entrypoint's native code.
  code_range: Cell<(usize, usize)>,
  // Protected address range of the pointer cage (without guard margins).
  pointer_cage_protected_range: Cell<(usize, usize)>,
  // Yielder for the active coroutine; updated by `_run` after every resume.
  yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
296
/// Hooks for observing program execution events.
///
/// All methods have no-op defaults, so implementors only override the events
/// they care about.
pub trait ProgramEventListener: Send + Sync + 'static {
  /// Called after an async preemption is triggered.
  fn did_async_preempt(&self, _scope: &HelperScope) {}
  /// Called after yielding back to the async runtime.
  fn did_yield(&self) {}
  /// Called after throttling a program's execution.
  ///
  /// May return a future which is awaited before the program resumes.
  fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
    None
  }
  /// Called after saving the shadow stack before yielding.
  fn did_save_shadow_stack(&self) {}
  /// Called after restoring the shadow stack on resume.
  fn did_restore_shadow_stack(&self) {}
}

/// No-op event listener implementation.
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
316
/// Prepares helper tables and loads eBPF programs.
pub struct ProgramLoader {
  // Maps helper name -> obfuscated 32-bit dispatch id handed to the linker
  // (high 16 bits: per-helper entropy; low 16 bits: 1-based index XOR
  // `helper_id_xor`). See `ProgramLoader::new`.
  helpers_inverse: HashMap<&'static str, i32>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Random mask applied to helper indices to make dispatch ids unpredictable.
  helper_id_xor: u16,
  // Shuffled helper table: (entropy tag, name, function).
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}

/// A loaded program that is not yet pinned to a thread.
pub struct UnboundProgram {
  id: u64,
  // Keeps the JIT code mapping alive for the program's lifetime.
  _code_mem: MmapRaw,
  // Guest address space: data, stack, and pointer-validation logic.
  cage: PointerCage,
  helper_id_xor: u16,
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Section name -> native code location for each callable entrypoint.
  entrypoints: HashMap<String, Entrypoint>,
  // Seed for the per-program RNG used to fill shadow stacks.
  stack_rng_seed: [u8; 32],
}

/// A program pinned to a specific thread and ready to execute.
pub struct Program {
  unbound: UnboundProgram,
  // Shared typed data for this program instance; see `Program::data`.
  data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
  stack_rng: RefCell<ChaCha8Rng>,
  t: ThreadEnv,
}

// Location of an entrypoint's JIT-compiled native code.
#[derive(Copy, Clone)]
struct Entrypoint {
  code_ptr: usize,
  code_len: usize,
}

/// Time limits used to yield or throttle execution.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
  /// Maximum runtime before yielding to the async scheduler.
  pub max_run_time_before_yield: Duration,
  /// Maximum runtime before a throttle sleep is forced.
  pub max_run_time_before_throttle: Duration,
  /// Duration of the throttle sleep once triggered.
  pub throttle_duration: Duration,
}
361
/// Async runtime integration for yielding and sleeping.
pub trait Timeslicer {
  /// Sleep for the provided duration.
  fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
  /// Yield to the async scheduler.
  fn yield_now(&self) -> impl Future<Output = ()>;
}

/// Global runtime environment for signal handlers.
///
/// Zero-sized token proving `GlobalEnv::new` has installed the handlers.
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Per-thread runtime environment for preemption handling.
///
/// Zero-sized token proving `init_thread` ran on the current thread; the
/// `PhantomData<*const ()>` keeps it `!Send + !Sync` so it cannot leave
/// that thread.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
  _not_send_sync: std::marker::PhantomData<*const ()>,
}
379
impl GlobalEnv {
  /// Initializes global state and installs signal handlers.
  ///
  /// Idempotent: the handlers are installed exactly once per process.
  ///
  /// # Safety
  /// Must be called in a process that can install SIGUSR1/SIGSEGV handlers.
  pub unsafe fn new() -> Self {
    static INIT: Once = Once::new();

    // SIGUSR1 must be blocked during exception handling
    // Otherwise it seems that Linux gives up and throws an uncatchable SI_KERNEL SIGSEGV:
    //
    // [pid 517110] tgkill(517109, 517112, SIGUSR1 <unfinished ...>
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SEGV_ACCERR, si_addr=0x793789be227e} ---
    // [pid 517109] write(15, "\1\0\0\0\0\0\0\0", 8 <unfinished ...>
    // [pid 517110] <... tgkill resumed>)      = 0
    // [pid 517109] <... write resumed>)       = 8
    // [pid 517112] --- SIGUSR1 {si_signo=SIGUSR1, si_code=SI_TKILL, si_pid=517109, si_uid=1000} ---
    // [pid 517109] recvfrom(57,  <unfinished ...>
    // [pid 517110] futex(0x79378abff5a8, FUTEX_WAIT_PRIVATE, 1, {tv_sec=0, tv_nsec=5999909} <unfinished ...>
    // [pid 517109] <... recvfrom resumed>"GET /write_rodata HTTP/1.1\r\nHost"..., 8192, 0, NULL, NULL) = 61
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SI_KERNEL, si_addr=NULL} ---

    INIT.call_once(|| {
      let sa_mask = get_blocked_sigset();

      // Install both handlers with the same blocked-signal mask (see above).
      for (sig, handler) in [
        (libc::SIGUSR1, sigusr1_handler as usize),
        (libc::SIGSEGV, sigsegv_handler as usize),
      ] {
        let act = libc::sigaction {
          sa_sigaction: handler,
          // SA_SIGINFO: handlers receive siginfo_t (fault address, etc.).
          sa_flags: libc::SA_SIGINFO,
          sa_mask,
          sa_restorer: None,
        };
        if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
          panic!("failed to setup handler for signal {}", sig);
        }
      }
    });

    Self(())
  }

  /// Initializes per-thread state and starts the async preemption watcher.
  ///
  /// Idempotent per thread: subsequent calls return immediately. The watcher
  /// thread periodically sends SIGUSR1 to *this* thread while preemption is
  /// active (see `PreemptionEnabled`), and shuts down when this thread exits.
  pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
    // Thread-exit hook: flips the shared state to Shutdown so the watcher
    // thread's loop terminates when this (program) thread goes away.
    struct DeferDrop(Arc<PreemptionStateSignal>);
    impl Drop for DeferDrop {
      fn drop(&mut self) {
        let x = &self.0;
        *x.0.lock() = PreemptionState::Shutdown;
        x.1.notify_one();
      }
    }

    thread_local! {
      static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
    }

    // Already initialized on this thread — nothing to do.
    if WATCHER.with(|x| x.borrow().is_some()) {
      return ThreadEnv {
        _not_send_sync: PhantomData,
      };
    }

    let preemption_state = PREEMPTION_STATE.with(|x| x.clone());

    unsafe {
      // Capture ids of the *current* thread: the watcher signals us, not itself.
      let tgid = libc::getpid();
      let tid = libc::gettid();

      std::thread::Builder::new()
        .name("preempt-watcher".to_string())
        .spawn(move || {
          let mut state = preemption_state.0.lock();
          loop {
            match *state {
              PreemptionState::Shutdown => break,
              PreemptionState::Inactive => {
                // Park until a PreemptionEnabled guard activates us.
                preemption_state.1.wait(&mut state);
              }
              PreemptionState::Active(_) => {
                // Wait one interval; if no state change happened, fire SIGUSR1
                // at the program thread to force an async preemption.
                let timeout = preemption_state
                  .1
                  .wait_while_for(
                    &mut state,
                    |x| matches!(x, PreemptionState::Active(_)),
                    async_preemption_interval,
                  );
                if timeout.timed_out() {
                  let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                  if ret != 0 {
                    // Target thread is gone — stop watching.
                    break;
                  }
                }
              }
            }
          }
        })
        .expect("failed to spawn preemption watcher");

      WATCHER.with(|x| {
        x.borrow_mut()
          .replace(DeferDrop(PREEMPTION_STATE.with(|x| x.clone())));
      });

      ThreadEnv {
        _not_send_sync: PhantomData,
      }
    }
  }
}
492
493impl UnboundProgram {
494  /// Pins the program to the current thread using a prepared `ThreadEnv`.
495  pub fn pin_to_current_thread(self, t: ThreadEnv) -> Program {
496    let stack_rng = RefCell::new(ChaCha8Rng::from_seed(self.stack_rng_seed));
497    Program {
498      unbound: self,
499      data: RefCell::new(HashMap::new()),
500      stack_rng,
501      t,
502    }
503  }
504}
505
/// Guard that keeps async preemption active while it is alive.
///
/// Construction increments the `Active` count shared with the preempt-watcher
/// thread (waking it if it was idle); dropping decrements it, returning to
/// `Inactive` when the last guard goes away.
pub struct PreemptionEnabled(());

impl PreemptionEnabled {
  // Requires a `ThreadEnv` to prove `init_thread` ran (i.e. a watcher exists).
  pub fn new(_: ThreadEnv) -> Self {
    PREEMPTION_STATE.with(|x| {
      let mut notify = false;
      {
        let mut st = x.0.lock();
        let next = match *st {
          PreemptionState::Inactive => {
            // First guard: the watcher is parked and must be woken.
            notify = true;
            PreemptionState::Active(1)
          }
          PreemptionState::Active(n) => PreemptionState::Active(n + 1),
          // Shutdown is only set when the owning thread exits, so no new
          // guards can be created in that state.
          PreemptionState::Shutdown => unreachable!(),
        };
        *st = next;
      }

      // Notify outside the critical section so the watcher can take the
      // lock immediately on wakeup.
      if notify {
        x.1.notify_one();
      }
    });
    Self(())
  }
}

impl Drop for PreemptionEnabled {
  fn drop(&mut self) {
    PREEMPTION_STATE.with(|x| {
      let mut st = x.0.lock();
      let next = match *st {
        // Last guard: watcher goes back to parking on the condvar.
        PreemptionState::Active(1) => PreemptionState::Inactive,
        PreemptionState::Active(n) => {
          assert!(n > 1);
          PreemptionState::Active(n - 1)
        }
        // A live guard implies the state must be Active.
        PreemptionState::Inactive | PreemptionState::Shutdown => unreachable!(),
      };
      *st = next;
    });
  }
}
549
550impl Program {
  /// Returns the unique program identifier.
  pub fn id(&self) -> u64 {
    self.unbound.id
  }

  /// Returns the `ThreadEnv` this program was pinned with.
  pub fn thread_env(&self) -> ThreadEnv {
    self.t
  }

  /// Gets or creates shared typed data for this program instance.
  ///
  /// The value is created via `T::default()` on first access and shared
  /// (via `Rc`) across all callers for this program.
  pub fn data<T: Default + 'static>(&self) -> Rc<T> {
    let mut data = self.data.borrow_mut();
    let entry = data.entry(TypeId::of::<T>());
    let entry = entry.or_insert_with(|| Rc::new(T::default()));
    entry.clone().downcast().unwrap()
  }

  /// Returns whether the loaded program has an entrypoint named `name`.
  pub fn has_section(&self, name: &str) -> bool {
    self.unbound.entrypoints.contains_key(name)
  }
571
  /// Runs the program entrypoint with the provided resources and calldata.
  ///
  /// Thin wrapper over `_run` that wraps the internal `RuntimeError` into
  /// the public `Error` type.
  ///
  /// # Errors
  /// Returns an error for an unknown entrypoint, oversized calldata, guest
  /// memory faults, or helper failures (see `_run`).
  pub async fn run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
    preemption: &PreemptionEnabled,
  ) -> Result<i64, Error> {
    self
      ._run(
        timeslice, timeslicer, entrypoint, resources, calldata, preemption,
      )
      .await
      .map_err(Error)
  }
589
  /// Core execution loop: resumes the JIT coroutine, dispatches helper calls
  /// on every yield, and cooperates with the async runtime (yield/throttle)
  /// and the signal-based preemption machinery.
  ///
  /// The `PreemptionEnabled` reference is only a proof that preemption is
  /// active for the duration of the call; it is not otherwise used.
  async fn _run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
    _: &PreemptionEnabled,
  ) -> Result<i64, RuntimeError> {
    let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
      return Err(RuntimeError::InvalidArgument("entrypoint not found"));
    };

    // Reinterpret the JIT code address as a callable function pointer.
    // NOTE(review): safety relies on the loader having emitted valid native
    // code with this exact ABI at `code_ptr`.
    let entry = unsafe {
      std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
        entrypoint.code_ptr,
      )
    };
    struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
      ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
    );
    impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
      for CoDropper<'a, Input, Yield, Return, DefaultStack>
    {
      fn drop(&mut self) {
        // Prevent the coroutine library from attempting to unwind the stack of the coroutine
        // and run destructors, because this stack might be running a signal handler and
        // it's not allowed to unwind from there.
        //
        // SAFETY: The coroutine stack only contains stack frames for JIT-compiled code and
        // carefully chosen Rust code that do not hold Droppable values, so it's safe to
        // skip destructors.
        unsafe {
          self.0.force_reset();
        }
      }
    }

    // Borrow a pooled execution context; its copy buffer is randomized.
    let mut ectx = BorrowedExecContext::new(&mut *self.stack_rng.borrow_mut());

    if calldata.len() > MAX_CALLDATA_SIZE {
      return Err(RuntimeError::InvalidArgument("calldata too large"));
    }
    // Place calldata at the very top of the shadow stack image.
    ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
    let calldata_len = calldata.len();

    let program_ret: u64 = {
      let shadow_stack_top = self.unbound.cage.stack_top();
      let shadow_stack_ptr = AssumeSend(
        self
          .unbound
          .cage
          .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
          .unwrap(),
      );
      let ctx = &mut *ectx.ctx;

      // The coroutine body runs the JIT entrypoint; both ctx and the shadow
      // stack pointer start just below the calldata region.
      let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
        &mut ctx.native_stack,
        move |yielder, _input| unsafe {
          ACTIVE_JIT_CODE_ZONE.with(|x| {
            x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
          });
          entry(
            shadow_stack_top - calldata_len,
            shadow_stack_top - calldata_len,
          )
        },
      )));

      let mut last_yield_time = Instant::now();
      let mut last_throttle_time = Instant::now();
      let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
      let mut resume_input: u64 = 0;
      let mut did_throttle = false;
      // Starts true so the first iteration copies the (calldata-bearing)
      // shadow stack image into the cage.
      let mut shadow_stack_saved = true;
      let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
      let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
      let mut invoke_scope = InvokeScope {
        data: HashMap::new(),
      };

      loop {
        // Publish the zone data first, fence, then set `valid` — so an
        // asynchronous observer never sees a half-written zone.
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.code_range.set((
            entrypoint.code_ptr,
            entrypoint.code_ptr + entrypoint.code_len,
          ));
          x.yielder.set(yielder.map(|x| x.0));
          x.pointer_cage_protected_range
            .set(self.unbound.cage.protected_range_without_margins());
          compiler_fence(Ordering::Release);
          x.valid.store(true, Ordering::Relaxed);
        });

        if shadow_stack_saved {
          shadow_stack_saved = false;

          // restore shadow stack
          unsafe {
            std::ptr::copy_nonoverlapping(
              ctx.copy_stack.as_ptr() as *const u8,
              shadow_stack_ptr.0.as_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }

          self.unbound.event_listener.did_restore_shadow_stack();
        }

        // If the previous iteration wants to write back to machine state
        if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
          resume_input = prev_async_task_output(&HelperScope {
            program: self,
            invoke: RefCell::new(&mut invoke_scope),
            resources: RefCell::new(resources),
            // All-zero bytes are a valid "all None" state here because
            // `Option<NonNull<_>>` uses the null-pointer niche.
            mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            can_post_task: false,
          })
          .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
        }

        // Run guest code until it returns, calls a helper, or is preempted.
        let ret = co.0 .0.resume(resume_input);
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.valid.store(false, Ordering::Relaxed);
          compiler_fence(Ordering::Release);
          yielder = x.yielder.get().map(AssumeSend);
        });

        let dispatch: Dispatch = match ret {
          CoroutineResult::Return(x) => break x,
          CoroutineResult::Yield(x) => x,
        };

        // restore signal mask of current thread
        if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
          unsafe {
            let unblock = get_blocked_sigset();
            libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
          }
        }

        if let Some(si_addr) = dispatch.memory_access_error {
          // Translate the host fault address back into a guest address.
          let vaddr = si_addr - self.unbound.cage.offset();
          return Err(RuntimeError::MemoryFault(vaddr));
        }

        // Clear pending task if something else has set it
        PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
        let mut helper_name: &'static str = "";

        let mut helper_scope = HelperScope {
          program: self,
          invoke: RefCell::new(&mut invoke_scope),
          resources: RefCell::new(resources),
          // See the niche note above: zeroed == arrays of `Cell(None)`.
          mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          can_post_task: false,
        };

        if dispatch.async_preemption {
          self
            .unbound
            .event_listener
            .did_async_preempt(&mut helper_scope);
        } else {
          // validator should ensure all helper indexes are present in the table
          // (decode: low 16 bits XOR helper_id_xor gives the 1-based table
          // index; see `ProgramLoader::new` for the encoding).
          let Some((_, got_helper_name, helper)) = self
            .unbound
            .helpers
            .get(
              ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                as usize,
            )
            .copied()
          else {
            panic!("unknown helper index: {}", dispatch.index);
          };
          helper_name = got_helper_name;

          // Helpers may post at most one async task while they run.
          helper_scope.can_post_task = true;
          resume_input = helper(
            &mut helper_scope,
            dispatch.arg1,
            dispatch.arg2,
            dispatch.arg3,
            dispatch.arg4,
            dispatch.arg5,
          )
          .map_err(|()| RuntimeError::HelperError(helper_name))?;
          helper_scope.can_post_task = false;
        }

        let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

        // Fast path: do not read timestamp if no thread migration or async preemption happened
        let new_rust_tid_sigusr1_counter =
          (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
        if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
        {
          continue;
        }

        rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

        let now = Instant::now();
        let should_throttle = now > last_throttle_time
          && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
        let should_yield = now > last_yield_time
          && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
        if should_throttle || should_yield || pending_async_task.is_some() {
          // We are about to yield control to tokio. Save the shadow stack, and release the guard.
          shadow_stack_saved = true;
          unsafe {
            std::ptr::copy_nonoverlapping(
              shadow_stack_ptr.0.as_ptr() as *const u8,
              ctx.copy_stack.as_mut_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }
          self.unbound.event_listener.did_save_shadow_stack();

          // we are now free to give up control of current thread to other async tasks

          if should_throttle {
            if !did_throttle {
              // Only log the first throttle per run to avoid log spam.
              did_throttle = true;
              tracing::warn!("throttling program");
            }
            timeslicer.sleep(timeslice.throttle_duration).await;
            let now = Instant::now();
            last_throttle_time = now;
            last_yield_time = now;
            let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
            if let Some(task) = task {
              task.await;
            }
          } else if should_yield {
            timeslicer.yield_now().await;
            let now = Instant::now();
            last_yield_time = now;
            self.unbound.event_listener.did_yield();
          }

          // Now we have released all exclusive resources and can safely execute the async task
          if let Some(pending_async_task) = pending_async_task {
            let async_start = Instant::now();
            prev_async_task_output = Some((helper_name, pending_async_task.await));
            let async_dur = async_start.elapsed();
            // Time spent inside the helper's own task should not count
            // against the program's yield/throttle budget.
            last_throttle_time += async_dur;
            last_yield_time += async_dur;
          }
        }
      }
    };

    Ok(program_ret as i64)
  }
849}
850
// RAII wrapper over the raw ubpf VM handle; destroyed via FFI on drop.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);

impl Vm {
  // Creates a VM configured for cage-based memory protection: native bounds
  // checks off (the cage and SIGSEGV handling take over), JIT shadow stack
  // on, and guest pointers masked/offset into the cage's address range.
  fn new(cage: &PointerCage) -> Self {
    let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
    unsafe {
      crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
      crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
      crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
    }
    Self(vm)
  }
}

impl Drop for Vm {
  fn drop(&mut self) {
    // SAFETY: `self.0` came from `ubpf_create` and is destroyed exactly once.
    unsafe {
      crate::ubpf::ubpf_destroy(self.0.as_ptr());
    }
  }
}
872
873impl ProgramLoader {
  /// Creates a new `ProgramLoader` to load eBPF code.
  ///
  /// Builds an obfuscated helper table: helpers are deduplicated by name,
  /// shuffled, and each assigned a 32-bit dispatch id whose high 16 bits are
  /// random entropy and whose low 16 bits are the 1-based table index XORed
  /// with a random `helper_id_xor` (decoded again in `Program::_run`).
  pub fn new(
    rng: &mut impl rand::Rng,
    event_listener: Arc<dyn ProgramEventListener>,
    raw_helpers: &[&[(&'static str, Helper)]],
  ) -> Self {
    let helper_id_xor = rng.gen::<u16>();
    let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
    // Collect first to a HashMap then to a Vec to deduplicate
    let mut shuffled_helpers = raw_helpers
      .iter()
      .flat_map(|x| x.iter().copied())
      .collect::<HashMap<_, _>>()
      .into_iter()
      .collect::<Vec<_>>();
    shuffled_helpers.shuffle(rng);
    let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

    // Indexes must fit in 16 bits after the +1 shift below.
    assert!(shuffled_helpers.len() <= 65535);

    for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
      // Keep entropy to 15 bits so the final i32 id stays non-negative.
      let entropy = rng.gen::<u16>() & 0x7fff;
      helpers.push((entropy, name, helper));
      helpers_inverse.insert(
        name,
        (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
      );
    }

    tracing::info!(?helpers_inverse, "generated helper table");
    Self {
      helper_id_xor,
      helpers: Arc::new(helpers),
      helpers_inverse,
      event_listener,
    }
  }

  /// Loads an ELF image into a new `UnboundProgram`.
  ///
  /// Thin wrapper over `_load` that wraps the internal `RuntimeError` into
  /// the public `Error` type.
  pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
    self._load(rng, elf).map_err(Error)
  }
916
  /// Implementation of [`Self::load`]: copies and relocates `elf` into a fresh
  /// pointer cage, JIT-compiles every code section via ubpf into one native
  /// code region, and seals that region (R-X) behind PROT_NONE guard pages.
  fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
    let start_time = Instant::now();
    // The cage provides the shadow stack plus a data region sized for the
    // whole ELF image.
    let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
    let vm = Vm::new(&cage);

    // Relocate ELF
    let code_sections = {
      // XXX: Although we are writing to the data region, we need to use `safe_deref_for_read`
      // here because the `_write` variant checks that the requested region is within the
      // stack. It's safe here because `freeze_data` is not yet called.
      let mut data = cage
        .safe_deref_for_read(cage.data_bottom(), elf.len())
        .unwrap();
      let data = unsafe { data.as_mut() };
      data.copy_from_slice(elf);

      // Resolves relocations against the randomized helper-id table and
      // yields the code sections (name plus (vaddr, size)) to JIT below.
      link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
    };
    // After this point the data region is read-only; no further writes.
    cage.freeze_data();

    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    if page_size < 0 {
      return Err(RuntimeError::PlatformError("failed to get page size"));
    }
    let page_size = page_size as usize;

    // Allocate code memory
    // Guard sizes are randomized (16..128 pages each side) — presumably to
    // make the native code address unpredictable within the mapping.
    let guard_size_before = rng.gen_range(16..128) * page_size;
    let mut guard_size_after = rng.gen_range(16..128) * page_size;

    let code_len_allocated: usize = 65536;
    let code_mem = MmapRaw::from(
      MmapOptions::new()
        .len(code_len_allocated + guard_size_before + guard_size_after)
        .map_anon()
        .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
    );

    unsafe {
      // Helper calls from JITed code are routed through `tls_dispatcher`;
      // helper ids are validated at load time by `std_validator`, which gets
      // `self` back via the opaque cookie pointer.
      if crate::ubpf::ubpf_register_external_dispatcher(
        vm.0.as_ptr(),
        Some(tls_dispatcher),
        Some(std_validator),
        self as *const _ as *mut c_void,
      ) != 0
      {
        return Err(RuntimeError::PlatformError(
          "ubpf: failed to register external dispatcher",
        ));
      }
      // Make both guard regions inaccessible so out-of-bounds accesses fault.
      if libc::mprotect(
        code_mem.as_mut_ptr() as *mut _,
        guard_size_before,
        libc::PROT_NONE,
      ) != 0
        || libc::mprotect(
          code_mem
            .as_mut_ptr()
            .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
          guard_size_after,
          libc::PROT_NONE,
        ) != 0
      {
        return Err(RuntimeError::PlatformError("failed to protect guard pages"));
      }
    }

    let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

    unsafe {
      // Translate eBPF to native code
      // `code_slice` is the still-unused tail of the writable code region; it
      // shrinks as each section's native code is appended.
      let mut code_slice = std::slice::from_raw_parts_mut(
        code_mem.as_mut_ptr().offset(guard_size_before as isize),
        code_len_allocated,
      );
      for (section_name, code_vaddr_size) in code_sections {
        if code_slice.is_empty() {
          return Err(RuntimeError::InvalidArgument(
            "no space left for jit compilation",
          ));
        }

        // The single VM is reused for every section; drop the previous
        // section's bytecode first.
        crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

        let mut errmsg_ptr = std::ptr::null_mut();
        let code = cage
          .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
          .unwrap();
        let ret = crate::ubpf::ubpf_load(
          vm.0.as_ptr(),
          code.as_ptr() as *const _,
          code.len() as u32,
          &mut errmsg_ptr,
        );
        if ret != 0 {
          // On failure ubpf may hand back an allocated error message; copy it
          // out, then release it with `free` before reporting.
          let errmsg = if errmsg_ptr.is_null() {
            "".to_string()
          } else {
            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
          };
          if !errmsg_ptr.is_null() {
            libc::free(errmsg_ptr as _);
          }
          tracing::error!(section_name, error = errmsg, "failed to load code");
          return Err(RuntimeError::PlatformError("ubpf: code load failed"));
        }

        // `written_len` is in/out: remaining capacity on entry, bytes of
        // native code actually emitted on exit.
        let mut written_len = code_slice.len();
        let ret = crate::ubpf::ubpf_translate(
          vm.0.as_ptr(),
          code_slice.as_mut_ptr(),
          &mut written_len,
          &mut errmsg_ptr,
        );
        if ret != 0 {
          // Same error-message handling as the `ubpf_load` failure path above.
          let errmsg = if errmsg_ptr.is_null() {
            "".to_string()
          } else {
            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
          };
          if !errmsg_ptr.is_null() {
            libc::free(errmsg_ptr as _);
          }
          tracing::error!(section_name, error = errmsg, "failed to translate code");
          return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
        }

        assert!(written_len <= code_slice.len());
        entrypoints.insert(
          section_name,
          Entrypoint {
            // Start of this section's code: region base plus however much of
            // the region was already consumed before this iteration
            // (`code_slice` has not advanced yet at this point).
            code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
              - code_slice.len(),
            code_len: written_len,
          },
        );
        code_slice = &mut code_slice[written_len..];
      }

      // Align up code_len to page size
      let unpadded_code_len = code_len_allocated - code_slice.len();
      let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
      assert!(code_len <= code_len_allocated);

      // RW- -> R-X
      // Also make the unused part of the pre-allocated code region PROT_NONE
      if libc::mprotect(
        code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
        code_len,
        libc::PROT_READ | libc::PROT_EXEC,
      ) != 0
        || (code_len < code_len_allocated
          && libc::mprotect(
            code_mem
              .as_mut_ptr()
              .offset((guard_size_before + code_len) as isize) as *mut _,
            code_len_allocated - code_len,
            libc::PROT_NONE,
          ) != 0)
      {
        return Err(RuntimeError::PlatformError("failed to protect code memory"));
      }

      // The now-PROT_NONE tail of the code region effectively extends the
      // trailing guard; account for it in the log below.
      guard_size_after += code_len_allocated - code_len;

      tracing::info!(
        elf_size = elf.len(),
        native_code_addr = ?code_mem.as_ptr(),
        native_code_size = code_len,
        native_code_size_unpadded = unpadded_code_len,
        guard_size_before,
        guard_size_after,
        duration = ?start_time.elapsed(),
        cage_ptr = ?cage.region().as_ptr(),
        cage_mapped_size = cage.region().len(),
        "jit compiled program"
      );

      Ok(UnboundProgram {
        id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
        _code_mem: code_mem,
        cage,
        helper_id_xor: self.helper_id_xor,
        helpers: self.helpers.clone(),
        event_listener: self.event_listener.clone(),
        entrypoints,
        stack_rng_seed: rng.gen(),
      })
    }
  }
1107}
1108
/// Message suspended from JITed code (or a signal handler) to the host side
/// of the coroutine via the yielder.
#[derive(Default)]
struct Dispatch {
  // Set to true by `sigusr1_handler` to request an async preemption point.
  async_preemption: bool,
  // Set by `sigsegv_handler` to the faulting address inside the protected
  // pointer-cage range.
  memory_access_error: Option<usize>,

  // Helper-call payload, filled in by `tls_dispatcher`: the (obfuscated)
  // helper index and its five raw u64 arguments.
  index: u32,
  arg1: u64,
  arg2: u64,
  arg3: u64,
  arg4: u64,
  arg5: u64,
}
1121
1122unsafe extern "C" fn tls_dispatcher(
1123  arg1: u64,
1124  arg2: u64,
1125  arg3: u64,
1126  arg4: u64,
1127  arg5: u64,
1128  index: std::os::raw::c_uint,
1129  _cookie: *mut std::os::raw::c_void,
1130) -> u64 {
1131  let yielder = ACTIVE_JIT_CODE_ZONE
1132    .with(|x| x.yielder.get())
1133    .expect("no yielder");
1134  let yielder = yielder.as_ref();
1135  let ret = yielder.suspend(Dispatch {
1136    async_preemption: false,
1137    memory_access_error: None,
1138    index,
1139    arg1,
1140    arg2,
1141    arg3,
1142    arg4,
1143    arg5,
1144  });
1145  ret
1146}
1147
1148unsafe extern "C" fn std_validator(
1149  index: std::os::raw::c_uint,
1150  loader: *mut std::os::raw::c_void,
1151) -> bool {
1152  let loader = &*(loader as *const ProgramLoader);
1153  let entropy = (index >> 16) & 0xffff;
1154  let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1155  loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1156}
1157
/// Reads the program counter (RIP) out of a signal handler's `ucontext`.
///
/// # Safety
/// `uctx` must be the valid `ucontext_t` pointer passed to a signal handler.
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}
1162
/// Reads the program counter (PC) out of a signal handler's `ucontext`.
///
/// # Safety
/// `uctx` must be the valid `ucontext_t` pointer passed to a signal handler.
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.pc as usize
}
1167
/// SIGSEGV handler: converts a permission fault raised by JITed code against
/// the protected pointer-cage range into a [`Dispatch`] suspension carrying
/// the faulting address, instead of crashing the process.
///
/// Any fault that fails one of the checks below reinstalls the default
/// SIGSEGV disposition, so re-delivery of the signal is fatal as usual.
unsafe extern "C" fn sigsegv_handler(
  _sig: i32,
  siginfo: *mut libc::siginfo_t,
  uctx: *mut libc::ucontext_t,
) {
  let fail = || restore_default_signal_handler(libc::SIGSEGV);

  // Snapshot the active JIT zone, if any. `valid` is checked first; the
  // compiler fence orders the subsequent cell reads after that flag read.
  let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
    if x.valid.load(Ordering::Relaxed) {
      compiler_fence(Ordering::Acquire);
      Some((
        x.code_range.get(),
        x.pointer_cage_protected_range.get(),
        x.yielder.get(),
      ))
    } else {
      None
    }
  }) else {
    return fail();
  };

  let pc = program_counter(uctx);

  // The fault is only recoverable if it was raised from within JITed code.
  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
    return fail();
  }

  // SEGV_ACCERR
  // Only permission faults qualify; other si_code values (e.g. unmapped
  // addresses) fall through to the default handler.
  if (*siginfo).si_code != 2 {
    return fail();
  }

  // The faulting address must lie inside the protected pointer-cage range.
  let si_addr = (*siginfo).si_addr() as usize;
  if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
    return fail();
  }

  // Suspend the JIT coroutine from inside the signal handler, handing the
  // fault address to the host; execution resumes here on coroutine resume.
  let yielder = yielder.expect("no yielder").as_ref();
  yielder.suspend(Dispatch {
    memory_access_error: Some(si_addr),
    ..Default::default()
  });
}
1212
/// SIGUSR1 handler: async preemption of JITed code.
///
/// If the signal lands while the program counter is inside the active JIT
/// code zone, the coroutine is suspended with `async_preemption: true`;
/// otherwise the delivery is only counted and the signal is ignored.
unsafe extern "C" fn sigusr1_handler(
  _sig: i32,
  _siginfo: *mut libc::siginfo_t,
  uctx: *mut libc::ucontext_t,
) {
  // Count every delivery, including ones that arrive outside JIT code.
  SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

  // Snapshot the active JIT zone; same flag-then-fence pattern as
  // `sigsegv_handler`.
  let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
    if x.valid.load(Ordering::Relaxed) {
      compiler_fence(Ordering::Acquire);
      Some((x.code_range.get(), x.yielder.get()))
    } else {
      None
    }
  }) else {
    return;
  };
  // Only preempt if the signal interrupted JITed code itself.
  let pc = program_counter(uctx);
  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
    return;
  }

  // Suspend from inside the handler; resumes here when the host continues.
  let yielder = yielder.expect("no yielder").as_ref();
  yielder.suspend(Dispatch {
    async_preemption: true,
    ..Default::default()
  });
}
1241
/// Reinstalls the default disposition for `signum`, aborting the process if
/// `sigaction` itself fails.
///
/// Called from the signal handlers above when a signal cannot be attributed
/// to JITed code, so that re-delivery produces the normal (fatal) behavior.
unsafe fn restore_default_signal_handler(signum: i32) {
  let act = libc::sigaction {
    sa_sigaction: libc::SIG_DFL,
    // NOTE(review): SA_SIGINFO is ignored by the kernel when the handler is
    // SIG_DFL; sa_flags of 0 would express the intent more clearly.
    sa_flags: libc::SA_SIGINFO,
    sa_mask: std::mem::zeroed(),
    sa_restorer: None,
  };
  if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
    // No recovery possible inside a signal context; fail hard.
    libc::abort();
  }
}
1253
1254fn get_blocked_sigset() -> libc::sigset_t {
1255  unsafe {
1256    let mut s: libc::sigset_t = std::mem::zeroed();
1257    libc::sigaddset(&mut s, libc::SIGUSR1);
1258    libc::sigaddset(&mut s, libc::SIGSEGV);
1259    s
1260  }
1261}