// async_ebpf/program.rs
1//! This module contains the core logic for eBPF program execution. A lot of
2//! `unsafe` code is used - be careful when making changes here.
3
4use std::{
5  any::{Any, TypeId},
6  cell::{Cell, RefCell},
7  collections::HashMap,
8  ffi::CStr,
9  marker::PhantomData,
10  mem::ManuallyDrop,
11  ops::{Deref, DerefMut},
12  os::raw::c_void,
13  pin::Pin,
14  ptr::NonNull,
15  rc::Rc,
16  sync::{
17    atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18    Arc, Once,
19  },
20  thread::ThreadId,
21  time::{Duration, Instant},
22};
23
24use corosensei::{
25  stack::{DefaultStack, Stack},
26  Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
27};
28use futures::{Future, FutureExt};
29use memmap2::{MmapOptions, MmapRaw};
30use parking_lot::{Condvar, Mutex};
31use rand::prelude::SliceRandom;
32
33use crate::{
34  error::{Error, RuntimeError},
35  helpers::Helper,
36  linker::link_elf,
37  pointer_cage::PointerCage,
38  util::nonnull_bytes_overlap,
39};
40
/// Size in bytes of the native (host) coroutine stack (see `ExecContext`).
const NATIVE_STACK_SIZE: usize = 16384;
/// Size in bytes of the guest shadow stack that is saved/restored across yields.
const SHADOW_STACK_SIZE: usize = 4096;
/// Maximum number of calldata bytes copied onto the shadow stack.
const MAX_CALLDATA_SIZE: usize = 512;
/// Max mutable user-memory regions tracked per helper call (see `HelperScope`).
const MAX_MUTABLE_DEREF_REGIONS: usize = 4;
/// Max immutable user-memory regions tracked per helper call (see `HelperScope`).
const MAX_IMMUTABLE_DEREF_REGIONS: usize = 16;
46
/// Per-invocation storage for helper state during a program run.
pub struct InvokeScope {
  // Type-indexed bag of values, created on demand by `data_mut`.
  data: HashMap<TypeId, Box<dyn Any + Send>>,
}
51
52impl InvokeScope {
53  /// Gets or creates typed data scoped to this invocation.
54  pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
55    let ty = TypeId::of::<T>();
56    self
57      .data
58      .entry(ty)
59      .or_insert_with(|| Box::new(T::default()))
60      .downcast_mut()
61      .expect("InvokeScope::data_mut: downcast failed")
62  }
63}
64
/// Context passed to helpers while a program is executing.
pub struct HelperScope<'a, 'b> {
  /// The program being executed.
  pub program: &'a Program,
  /// Mutable per-invocation data for helpers.
  pub invoke: RefCell<&'a mut InvokeScope>,
  // Caller-provided resources, looked up by concrete type via `with_resource_mut`.
  resources: RefCell<&'a mut [&'b mut dyn Any]>,
  // Regions handed out mutably via `user_memory_mut`; recorded so that
  // subsequent reads/writes in the same helper call cannot alias them.
  mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_MUTABLE_DEREF_REGIONS],
  // Regions handed out immutably via `user_memory`; reads may overlap each
  // other, but later writes must not overlap these.
  immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_IMMUTABLE_DEREF_REGIONS],
  // Whether `post_task` is permitted in the current context.
  can_post_task: bool,
}
76
/// A validated mutable view into user memory.
///
/// Holding this value borrows the `HelperScope`, which keeps the region
/// recorded in the scope's overlap bookkeeping while the view is alive.
pub struct MutableUserMemory<'a, 'b, 'c> {
  _scope: &'c HelperScope<'a, 'b>,
  region: NonNull<[u8]>,
}
82
impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
  type Target = [u8];

  fn deref(&self) -> &Self::Target {
    // SAFETY: `region` was validated by `HelperScope::user_memory_mut`, and
    // the `'c` borrow of the scope keeps its overlap bookkeeping alive.
    unsafe { self.region.as_ref() }
  }
}
90
impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
  fn deref_mut(&mut self) -> &mut Self::Target {
    // SAFETY: same as `deref`; `user_memory_mut` additionally rejected any
    // overlap with previously dereferenced regions before handing this out.
    unsafe { self.region.as_mut() }
  }
}
96
97impl<'a, 'b> HelperScope<'a, 'b> {
98  /// Posts an async task to be run between timeslices.
99  pub fn post_task(
100    &self,
101    task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
102  ) {
103    if !self.can_post_task {
104      panic!("HelperScope::post_task() called in a context where posting task is not allowed");
105    }
106
107    PENDING_ASYNC_TASK.with(|x| {
108      let mut x = x.borrow_mut();
109      if x.is_some() {
110        panic!("post_task called while another task is pending");
111      }
112      *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
113    });
114  }
115
116  /// Calls `callback` with a mutable resource of type `T`, if present.
117  pub fn with_resource_mut<'c, T: 'static, R>(
118    &'c self,
119    callback: impl FnOnce(Result<&mut T, ()>) -> R,
120  ) -> R {
121    let mut resources = self.resources.borrow_mut();
122    let Some(res) = resources
123      .iter_mut()
124      .filter_map(|x| x.downcast_mut::<T>())
125      .next()
126    else {
127      tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
128      return callback(Err(()));
129    };
130
131    callback(Ok(res))
132  }
133
134  /// Validates and returns an immutable view into user memory.
135  pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
136    let Some(region) = self
137      .program
138      .unbound
139      .cage
140      .safe_deref_for_read(ptr as usize, size as usize)
141    else {
142      tracing::warn!(ptr, size, "invalid read");
143      return Err(());
144    };
145
146    if size != 0 {
147      // The region must not overlap with any previously dereferenced mutable regions
148      if self
149        .mutable_dereferenced_regions
150        .iter()
151        .filter_map(|x| x.get())
152        .any(|x| nonnull_bytes_overlap(x, region))
153      {
154        tracing::warn!(ptr, size, "read overlapped with previous write");
155        return Err(());
156      }
157
158      // Find a slot to record this dereference
159      let Some(slot) = self
160        .immutable_dereferenced_regions
161        .iter()
162        .find(|x| x.get().is_none())
163      else {
164        tracing::warn!(ptr, size, "too many reads");
165        return Err(());
166      };
167      slot.set(Some(region));
168    }
169
170    Ok(unsafe { region.as_ref() })
171  }
172
173  /// Validates and returns a mutable view into user memory.
174  pub fn user_memory_mut<'c>(
175    &'c self,
176    ptr: u64,
177    size: u64,
178  ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
179    let Some(region) = self
180      .program
181      .unbound
182      .cage
183      .safe_deref_for_write(ptr as usize, size as usize)
184    else {
185      tracing::warn!(ptr, size, "invalid write");
186      return Err(());
187    };
188
189    if size != 0 {
190      // The region must not overlap with any other previously dereferenced mutable or immutable regions
191      if self
192        .mutable_dereferenced_regions
193        .iter()
194        .chain(self.immutable_dereferenced_regions.iter())
195        .filter_map(|x| x.get())
196        .any(|x| nonnull_bytes_overlap(x, region))
197      {
198        tracing::warn!(ptr, size, "write overlapped with previous read/write");
199        return Err(());
200      }
201
202      // Find a slot to record this dereference
203      let Some(slot) = self
204        .mutable_dereferenced_regions
205        .iter()
206        .find(|x| x.get().is_none())
207      else {
208        tracing::warn!(ptr, size, "too many writes");
209        return Err(());
210      };
211      slot.set(Some(region));
212    }
213
214    Ok(MutableUserMemory {
215      _scope: self,
216      region,
217    })
218  }
219}
220
/// Wrapper that unconditionally asserts `Send` for its contents.
// SAFETY(review): the blanket `unsafe impl Send` is only sound if wrapped
// values are never actually accessed from two threads at once. In this file
// it wraps pointers used by `_run` on one task at a time — confirm no
// concurrent access is possible before reusing this type elsewhere.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
unsafe impl<T> Send for AssumeSend<T> {}
224
// Reusable per-execution resources: the native stack the coroutine runs on,
// and a host-side buffer the guest shadow stack is saved into across yields.
struct ExecContext {
  native_stack: DefaultStack,
  copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
229
230impl ExecContext {
231  fn new() -> Self {
232    Self {
233      native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
234        .expect("failed to initialize native stack"),
235      copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
236    }
237  }
238}
239
/// A pending async task spawned by a helper.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// The callback produced by a helper async task when it resumes.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

// Source of unique program ids; starts at 1 so 0 never names a program.
// Not consumed in this chunk — presumably read where `UnboundProgram.id` is
// assigned (in `_load`, past this view); confirm.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);
246
// State machine driving the per-thread preemption watcher thread.
#[derive(Copy, Clone, Debug)]
enum PreemptionState {
  // No `PreemptionEnabled` guard alive; the watcher blocks on the condvar.
  Inactive,
  // `n` nested `PreemptionEnabled` guards alive; the watcher periodically
  // sends SIGUSR1 to the program thread.
  Active(usize),
  // The owning thread is gone; the watcher must exit.
  Shutdown,
}

// The mutex-protected state plus the condvar used to wake the watcher.
type PreemptionStateSignal = (Mutex<PreemptionState>, Condvar);
255
256thread_local! {
257  static RUST_TID: ThreadId = std::thread::current().id();
258  static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
259  static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
260  static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
261  static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
262  static PREEMPTION_STATE: Arc<PreemptionStateSignal> = Arc::new((Mutex::new(PreemptionState::Inactive), Condvar::new()));
263}
264
// An `ExecContext` checked out of the thread-local pool; `Drop` returns it.
// `ManuallyDrop` lets `Drop` move the context back out by value.
struct BorrowedExecContext {
  ctx: ManuallyDrop<ExecContext>,
}
268
269impl BorrowedExecContext {
270  fn new() -> Self {
271    let mut me = Self {
272      ctx: ManuallyDrop::new(
273        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
274      ),
275    };
276    me.ctx.copy_stack.fill(0x8e);
277    me
278  }
279}
280
impl Drop for BorrowedExecContext {
  fn drop(&mut self) {
    // SAFETY: `ctx` is taken exactly once here and never touched again;
    // `drop` runs at most once.
    let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
    // Return the context to the thread-local pool for reuse.
    EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
  }
}
287
// Per-thread description of the JIT code currently executing. `_run` flips
// `valid` around each `resume()` with compiler fences; presumably the
// SIGUSR1/SIGSEGV handlers (not visible in this chunk) read these fields to
// decide whether the fault/preemption happened inside JIT code — confirm.
#[derive(Default)]
struct ActiveJitCodeZone {
  // Publication flag for the fields below.
  valid: AtomicBool,
  // Start/end addresses of the JIT-compiled code for the active entrypoint.
  code_range: Cell<(usize, usize)>,
  // Protected address range of the active pointer cage.
  pointer_cage_protected_range: Cell<(usize, usize)>,
  // The coroutine yielder, so out-of-band code can yield a `Dispatch`.
  yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
295
/// Hooks for observing program execution events.
///
/// All methods default to no-ops; implementors override what they need.
pub trait ProgramEventListener: Send + Sync + 'static {
  /// Called after an async preemption is triggered.
  fn did_async_preempt(&self, _scope: &HelperScope) {}
  /// Called after yielding back to the async runtime.
  fn did_yield(&self) {}
  /// Called after throttling a program's execution.
  ///
  /// A returned future, if any, is awaited by the run loop before resuming.
  fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
    None
  }
  /// Called after saving the shadow stack before yielding.
  fn did_save_shadow_stack(&self) {}
  /// Called after restoring the shadow stack on resume.
  fn did_restore_shadow_stack(&self) {}
}
311
/// No-op event listener implementation.
pub struct DummyProgramEventListener;
// All trait methods have default no-op bodies, so nothing to override.
impl ProgramEventListener for DummyProgramEventListener {}
315
/// Prepares helper tables and loads eBPF programs.
pub struct ProgramLoader {
  // Maps helper name -> obfuscated id handed to the linker (see `new`).
  helpers_inverse: HashMap<&'static str, i32>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Random XOR key applied to helper indices (see `new`).
  helper_id_xor: u16,
  // Helper dispatch table in shuffled order: (entropy, name, function).
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}
323
/// A loaded program that is not yet pinned to a thread.
pub struct UnboundProgram {
  // Unique program id (see `Program::id`).
  id: u64,
  // Keeps the JIT code mapping alive for the program's lifetime.
  _code_mem: MmapRaw,
  // Caged guest memory: data region plus shadow stack.
  cage: PointerCage,
  // XOR key obfuscating helper indices (see `ProgramLoader::new`).
  helper_id_xor: u16,
  // Helper dispatch table: (entropy, name, function).
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Named entrypoints into the JIT-compiled code.
  entrypoints: HashMap<String, Entrypoint>,
}
334
/// A program pinned to a specific thread and ready to execute.
pub struct Program {
  unbound: UnboundProgram,
  // Lazily-created shared typed data (see `Program::data`).
  data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
  // Proof that the per-thread preemption infrastructure was initialized;
  // `ThreadEnv` is `!Send + !Sync`, pinning this value to its thread.
  t: ThreadEnv,
}
341
// A JIT-compiled entrypoint: address and length of its native code.
#[derive(Copy, Clone)]
struct Entrypoint {
  code_ptr: usize,
  code_len: usize,
}
347
/// Time limits used to yield or throttle execution.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
  /// Maximum runtime before yielding to the async scheduler.
  pub max_run_time_before_yield: Duration,
  /// Maximum runtime before a throttle sleep is forced.
  pub max_run_time_before_throttle: Duration,
  /// Duration of the throttle sleep once triggered.
  pub throttle_duration: Duration,
}
358
/// Async runtime integration for yielding and sleeping.
///
/// Abstracts the executor (e.g. tokio) so the run loop never blocks a thread.
pub trait Timeslicer {
  /// Sleep for the provided duration.
  fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
  /// Yield to the async scheduler.
  fn yield_now(&self) -> impl Future<Output = ()>;
}
366
/// Global runtime environment for signal handlers.
///
/// A zero-sized token proving `GlobalEnv::new` ran and handlers are installed.
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Per-thread runtime environment for preemption handling.
///
/// A zero-sized token proving `GlobalEnv::init_thread` ran on this thread;
/// the `PhantomData<*const ()>` makes it `!Send + !Sync`.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
  _not_send_sync: std::marker::PhantomData<*const ()>,
}
376
impl GlobalEnv {
  /// Initializes global state and installs signal handlers.
  ///
  /// Idempotent: the handlers are installed at most once per process.
  ///
  /// # Safety
  /// Must be called in a process that can install SIGUSR1/SIGSEGV handlers.
  pub unsafe fn new() -> Self {
    static INIT: Once = Once::new();

    // SIGUSR1 must be blocked during exception handling
    // Otherwise it seems that Linux gives up and throws an uncatchable SI_KERNEL SIGSEGV:
    //
    // [pid 517110] tgkill(517109, 517112, SIGUSR1 <unfinished ...>
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SEGV_ACCERR, si_addr=0x793789be227e} ---
    // [pid 517109] write(15, "\1\0\0\0\0\0\0\0", 8 <unfinished ...>
    // [pid 517110] <... tgkill resumed>)      = 0
    // [pid 517109] <... write resumed>)       = 8
    // [pid 517112] --- SIGUSR1 {si_signo=SIGUSR1, si_code=SI_TKILL, si_pid=517109, si_uid=1000} ---
    // [pid 517109] recvfrom(57,  <unfinished ...>
    // [pid 517110] futex(0x79378abff5a8, FUTEX_WAIT_PRIVATE, 1, {tv_sec=0, tv_nsec=5999909} <unfinished ...>
    // [pid 517109] <... recvfrom resumed>"GET /write_rodata HTTP/1.1\r\nHost"..., 8192, 0, NULL, NULL) = 61
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SI_KERNEL, si_addr=NULL} ---

    INIT.call_once(|| {
      // `sa_mask` blocks these signals while either handler runs (see above).
      let sa_mask = get_blocked_sigset();

      for (sig, handler) in [
        (libc::SIGUSR1, sigusr1_handler as usize),
        (libc::SIGSEGV, sigsegv_handler as usize),
      ] {
        let act = libc::sigaction {
          sa_sigaction: handler,
          // SA_SIGINFO: handlers receive siginfo (fault address etc.).
          sa_flags: libc::SA_SIGINFO,
          sa_mask,
          sa_restorer: None,
        };
        if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
          panic!("failed to setup handler for signal {}", sig);
        }
      }
    });

    Self(())
  }

  /// Initializes per-thread state and starts the async preemption watcher.
  ///
  /// Idempotent per thread. The watcher is a dedicated OS thread that sends
  /// SIGUSR1 to *this* thread every `async_preemption_interval` while the
  /// preemption state is `Active`; it shuts down when this thread exits.
  pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
    // Thread-local guard: when the owning thread exits, its destructor flips
    // the shared state to Shutdown and wakes the watcher so it can exit.
    struct DeferDrop(Arc<PreemptionStateSignal>);
    impl Drop for DeferDrop {
      fn drop(&mut self) {
        let x = &self.0;
        *x.0.lock() = PreemptionState::Shutdown;
        x.1.notify_one();
      }
    }

    thread_local! {
      static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
    }

    // Already initialized on this thread: nothing to do.
    if WATCHER.with(|x| x.borrow().is_some()) {
      return ThreadEnv {
        _not_send_sync: PhantomData,
      };
    }

    let preemption_state = PREEMPTION_STATE.with(|x| x.clone());

    unsafe {
      // Capture the ids of *this* (program) thread; the watcher targets it
      // with tgkill from its own thread.
      let tgid = libc::getpid();
      let tid = libc::gettid();

      std::thread::Builder::new()
        .name("preempt-watcher".to_string())
        .spawn(move || {
          let mut state = preemption_state.0.lock();
          loop {
            match *state {
              PreemptionState::Shutdown => break,
              PreemptionState::Inactive => {
                // Nothing running: sleep until a guard activates us.
                preemption_state.1.wait(&mut state);
              }
              PreemptionState::Active(_) => {
                // Wait out one preemption interval; if the state stayed
                // Active the whole time, fire a SIGUSR1 at the program
                // thread to force a preemption point.
                let timeout = preemption_state.1.wait_while_for(
                  &mut state,
                  |x| matches!(x, PreemptionState::Active(_)),
                  async_preemption_interval,
                );
                if timeout.timed_out() {
                  let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                  if ret != 0 {
                    // Target thread is gone (or signal failed): stop watching.
                    break;
                  }
                }
              }
            }
          }
        })
        .expect("failed to spawn preemption watcher");

      WATCHER.with(|x| {
        x.borrow_mut()
          .replace(DeferDrop(PREEMPTION_STATE.with(|x| x.clone())));
      });

      ThreadEnv {
        _not_send_sync: PhantomData,
      }
    }
  }
}
487
488impl UnboundProgram {
489  /// Pins the program to the current thread using a prepared `ThreadEnv`.
490  pub fn pin_to_current_thread(self, t: ThreadEnv) -> Program {
491    Program {
492      unbound: self,
493      data: RefCell::new(HashMap::new()),
494      t,
495    }
496  }
497}
498
/// RAII guard that keeps the preemption watcher active while it exists.
/// Created by `PreemptionEnabled::new`; nesting is counted, and the watcher
/// goes back to sleep when the last guard is dropped.
pub struct PreemptionEnabled(());
500
impl PreemptionEnabled {
  /// Enables async preemption for the current thread; guards nest via a counter.
  ///
  /// Requires a `ThreadEnv` as proof that `GlobalEnv::init_thread` ran here.
  pub fn new(_: ThreadEnv) -> Self {
    PREEMPTION_STATE.with(|x| {
      let mut notify = false;
      {
        let mut st = x.0.lock();
        let next = match *st {
          PreemptionState::Inactive => {
            // First activation: the watcher is parked on the condvar and
            // must be woken.
            notify = true;
            PreemptionState::Active(1)
          }
          PreemptionState::Active(n) => PreemptionState::Active(n + 1),
          // Shutdown only happens when this thread exits (see `init_thread`),
          // so it cannot be observed while the thread is still running here.
          PreemptionState::Shutdown => unreachable!(),
        };
        *st = next;
      }

      // Notify after releasing the lock so the watcher does not wake into a
      // contended mutex.
      if notify {
        x.1.notify_one();
      }
    });
    Self(())
  }
}
525
impl Drop for PreemptionEnabled {
  fn drop(&mut self) {
    PREEMPTION_STATE.with(|x| {
      let mut st = x.0.lock();
      let next = match *st {
        // Last guard: the watcher times out of its wait and goes back to
        // sleeping on the condvar; no explicit notify needed.
        PreemptionState::Active(1) => PreemptionState::Inactive,
        PreemptionState::Active(n) => {
          assert!(n > 1);
          PreemptionState::Active(n - 1)
        }
        // A live guard implies the state is Active with count >= 1.
        PreemptionState::Inactive | PreemptionState::Shutdown => unreachable!(),
      };
      *st = next;
    });
  }
}
542
impl Program {
  /// Returns the unique program identifier.
  pub fn id(&self) -> u64 {
    self.unbound.id
  }

  /// Returns the `ThreadEnv` this program was pinned with.
  pub fn thread_env(&self) -> ThreadEnv {
    self.t
  }

  /// Gets or creates shared typed data for this program instance.
  pub fn data<T: Default + 'static>(&self) -> Rc<T> {
    let mut data = self.data.borrow_mut();
    let entry = data.entry(TypeId::of::<T>());
    let entry = entry.or_insert_with(|| Rc::new(T::default()));
    // The entry keyed by `TypeId::of::<T>()` is always an `Rc<T>`, so the
    // downcast cannot fail.
    entry.clone().downcast().unwrap()
  }

  /// Returns whether an entrypoint section with this name was loaded.
  pub fn has_section(&self, name: &str) -> bool {
    self.unbound.entrypoints.contains_key(name)
  }

  /// Runs the program entrypoint with the provided resources and calldata.
  ///
  /// `resources` are made available to helpers via
  /// `HelperScope::with_resource_mut`; `calldata` is copied to the top of the
  /// guest shadow stack. The `PreemptionEnabled` guard proves the watcher is
  /// active for the duration of the run.
  pub async fn run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
    preemption: &PreemptionEnabled,
  ) -> Result<i64, Error> {
    self
      ._run(
        timeslice, timeslicer, entrypoint, resources, calldata, preemption,
      )
      .await
      .map_err(Error)
  }

  // Core execution loop: drives the JIT-compiled code on a coroutine,
  // handling helper dispatch, async preemption, memory faults, timeslicing
  // and helper-posted async tasks between resumes.
  async fn _run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
    _: &PreemptionEnabled,
  ) -> Result<i64, RuntimeError> {
    let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
      return Err(RuntimeError::InvalidArgument("entrypoint not found"));
    };

    // SAFETY(review): `code_ptr` points at JIT-compiled native code with this
    // calling convention, produced by the loader — confirm the `_load` path
    // always populates `entrypoints` from translated code.
    let entry = unsafe {
      std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
        entrypoint.code_ptr,
      )
    };
    struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
      ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
    );
    impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
      for CoDropper<'a, Input, Yield, Return, DefaultStack>
    {
      fn drop(&mut self) {
        // Prevent the coroutine library from attempting to unwind the stack of the coroutine
        // and run destructors, because this stack might be running a signal handler and
        // it's not allowed to unwind from there.
        //
        // SAFETY: The coroutine stack only contains stack frames for JIT-compiled code and
        // carefully chosen Rust code that do not hold Droppable values, so it's safe to
        // skip destructors.
        unsafe {
          self.0.force_reset();
        }
      }
    }

    let mut ectx = BorrowedExecContext::new();

    if calldata.len() > MAX_CALLDATA_SIZE {
      return Err(RuntimeError::InvalidArgument("calldata too large"));
    }
    // Place calldata at the very top of the (copy of the) shadow stack; it is
    // restored into guest memory before the first resume below.
    ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
    let calldata_len = calldata.len();

    let program_ret: u64 = {
      let shadow_stack_top = self.unbound.cage.stack_top();
      let shadow_stack_ptr = AssumeSend(
        self
          .unbound
          .cage
          .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
          .unwrap(),
      );
      let ctx = &mut *ectx.ctx;

      let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
        &mut ctx.native_stack,
        move |yielder, _input| unsafe {
          // Publish the yielder so out-of-band code can yield a `Dispatch`.
          ACTIVE_JIT_CODE_ZONE.with(|x| {
            x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
          });
          // Both arguments start at the shadow-stack pointer just below the
          // calldata copied above.
          entry(
            shadow_stack_top - calldata_len,
            shadow_stack_top - calldata_len,
          )
        },
      )));

      let mut last_yield_time = Instant::now();
      let mut last_throttle_time = Instant::now();
      let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
      let mut resume_input: u64 = 0;
      let mut did_throttle = false;
      // Starts `true` so the first iteration copies the calldata-initialized
      // shadow stack into guest memory.
      let mut shadow_stack_saved = true;
      let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
      let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
      let mut invoke_scope = InvokeScope {
        data: HashMap::new(),
      };

      loop {
        // Publish the active JIT code zone. The Release fence orders the
        // field writes before the `valid` store.
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.code_range.set((
            entrypoint.code_ptr,
            entrypoint.code_ptr + entrypoint.code_len,
          ));
          x.yielder.set(yielder.map(|x| x.0));
          x.pointer_cage_protected_range
            .set(self.unbound.cage.protected_range_without_margins());
          compiler_fence(Ordering::Release);
          x.valid.store(true, Ordering::Relaxed);
        });

        if shadow_stack_saved {
          shadow_stack_saved = false;

          // restore shadow stack
          // SAFETY: both buffers are exactly SHADOW_STACK_SIZE bytes;
          // `shadow_stack_ptr` came from `safe_deref_for_write` above.
          unsafe {
            std::ptr::copy_nonoverlapping(
              ctx.copy_stack.as_ptr() as *const u8,
              shadow_stack_ptr.0.as_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }

          self.unbound.event_listener.did_restore_shadow_stack();
        }

        // If the previous iteration wants to write back to machine state
        if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
          // NOTE(review): zeroed `[Cell<Option<NonNull<[u8]>>>; N]` relies on
          // the all-zero bit pattern being a valid `None`; consider
          // `Default::default()` instead of `mem::zeroed()`.
          resume_input = prev_async_task_output(&HelperScope {
            program: self,
            invoke: RefCell::new(&mut invoke_scope),
            resources: RefCell::new(resources),
            mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            can_post_task: false,
          })
          .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
        }

        // Run the JIT code until it returns or yields a `Dispatch`. The zone
        // is invalidated immediately after; the Release fence orders the
        // `valid` clear before the subsequent field reads.
        let ret = co.0 .0.resume(resume_input);
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.valid.store(false, Ordering::Relaxed);
          compiler_fence(Ordering::Release);
          yielder = x.yielder.get().map(AssumeSend);
        });

        let dispatch: Dispatch = match ret {
          CoroutineResult::Return(x) => break x,
          CoroutineResult::Yield(x) => x,
        };

        // restore signal mask of current thread
        // (the signal handler yields here with SIGUSR1/SIGSEGV still blocked)
        if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
          unsafe {
            let unblock = get_blocked_sigset();
            libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
          }
        }

        if let Some(si_addr) = dispatch.memory_access_error {
          // Translate the host fault address back to a guest virtual address.
          let vaddr = si_addr - self.unbound.cage.offset();
          return Err(RuntimeError::MemoryFault(vaddr));
        }

        // Clear pending task if something else has set it
        PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
        let mut helper_name: &'static str = "";

        let mut helper_scope = HelperScope {
          program: self,
          invoke: RefCell::new(&mut invoke_scope),
          resources: RefCell::new(resources),
          mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          can_post_task: false,
        };

        if dispatch.async_preemption {
          self
            .unbound
            .event_listener
            .did_async_preempt(&mut helper_scope);
        } else {
          // validator should ensure all helper indexes are present in the table
          // Decode the obfuscated helper id: low 16 bits XOR key, minus 1
          // (inverse of the encoding in `ProgramLoader::new`).
          let Some((_, got_helper_name, helper)) = self
            .unbound
            .helpers
            .get(
              ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                as usize,
            )
            .copied()
          else {
            panic!("unknown helper index: {}", dispatch.index);
          };
          helper_name = got_helper_name;

          // `post_task` is only legal while the helper itself is running.
          helper_scope.can_post_task = true;
          resume_input = helper(
            &mut helper_scope,
            dispatch.arg1,
            dispatch.arg2,
            dispatch.arg3,
            dispatch.arg4,
            dispatch.arg5,
          )
          .map_err(|()| RuntimeError::HelperError(helper_name))?;
          helper_scope.can_post_task = false;
        }

        let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

        // Fast path: do not read timestamp if no thread migration or async preemption happened
        let new_rust_tid_sigusr1_counter =
          (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
        if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
        {
          continue;
        }

        rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

        let now = Instant::now();
        let should_throttle = now > last_throttle_time
          && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
        let should_yield = now > last_yield_time
          && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
        if should_throttle || should_yield || pending_async_task.is_some() {
          // We are about to yield control to tokio. Save the shadow stack, and release the guard.
          shadow_stack_saved = true;
          // SAFETY: same buffers and size as the restore copy above.
          unsafe {
            std::ptr::copy_nonoverlapping(
              shadow_stack_ptr.0.as_ptr() as *const u8,
              ctx.copy_stack.as_mut_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }
          self.unbound.event_listener.did_save_shadow_stack();

          // we are now free to give up control of current thread to other async tasks

          if should_throttle {
            if !did_throttle {
              // Only log the first throttle of this run.
              did_throttle = true;
              tracing::warn!("throttling program");
            }
            timeslicer.sleep(timeslice.throttle_duration).await;
            let now = Instant::now();
            last_throttle_time = now;
            last_yield_time = now;
            let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
            if let Some(task) = task {
              task.await;
            }
          } else if should_yield {
            timeslicer.yield_now().await;
            let now = Instant::now();
            last_yield_time = now;
            self.unbound.event_listener.did_yield();
          }

          // Now we have released all exclusive resources and can safely execute the async task
          if let Some(pending_async_task) = pending_async_task {
            let async_start = Instant::now();
            prev_async_task_output = Some((helper_name, pending_async_task.await));
            // Don't bill time spent awaiting the helper's task against the
            // program's throttle/yield budgets.
            let async_dur = async_start.elapsed();
            last_throttle_time += async_dur;
            last_yield_time += async_dur;
          }
        }
      }
    };

    Ok(program_ret as i64)
  }
}
843
// RAII owner of a `ubpf_vm` instance; destroyed via `ubpf_destroy` on drop.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);
845
impl Vm {
  // Creates a ubpf VM configured for this pointer cage: built-in bounds
  // checks off (presumably the cage's mask/offset provides isolation instead
  // — confirm), JIT shadow stack on, guest pointers translated via the
  // cage's mask and offset.
  fn new(cage: &PointerCage) -> Self {
    let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
    // SAFETY: `vm` is a live, exclusively-owned ubpf_vm handle.
    unsafe {
      crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
      crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
      crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
    }
    Self(vm)
  }
}
857
impl Drop for Vm {
  fn drop(&mut self) {
    // SAFETY: the handle was created by `ubpf_create` and is destroyed
    // exactly once here.
    unsafe {
      crate::ubpf::ubpf_destroy(self.0.as_ptr());
    }
  }
}
865
866impl ProgramLoader {
  /// Creates a new `ProgramLoader` to load eBPF code.
  ///
  /// Helper ids are obfuscated per loader instance: the table order is
  /// randomized, and the id handed to the linker encodes
  /// `(entropy << 16) | ((index + 1) ^ helper_id_xor)`; `_run` decodes the
  /// low 16 bits to recover the table index.
  pub fn new(
    rng: &mut impl rand::Rng,
    event_listener: Arc<dyn ProgramEventListener>,
    raw_helpers: &[&[(&'static str, Helper)]],
  ) -> Self {
    let helper_id_xor = rng.gen::<u16>();
    let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
    // Collect first to a HashMap then to a Vec to deduplicate
    let mut shuffled_helpers = raw_helpers
      .iter()
      .flat_map(|x| x.iter().copied())
      .collect::<HashMap<_, _>>()
      .into_iter()
      .collect::<Vec<_>>();
    shuffled_helpers.shuffle(rng);
    let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

    // index + 1 must fit in 16 bits for the xor/decode scheme used by `_run`.
    assert!(shuffled_helpers.len() <= 65535);

    for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
      // Per-helper entropy lives in the upper 16 bits of the encoded id;
      // the top bit is masked off to keep the i32 value non-negative.
      let entropy = rng.gen::<u16>() & 0x7fff;
      helpers.push((entropy, name, helper));
      helpers_inverse.insert(
        name,
        (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
      );
    }

    tracing::info!(?helpers_inverse, "generated helper table");
    Self {
      helper_id_xor,
      helpers: Arc::new(helpers),
      helpers_inverse,
      event_listener,
    }
  }
904
  /// Loads an ELF image into a new `UnboundProgram`.
  ///
  /// Thin wrapper around `_load` that wraps the internal `RuntimeError` in
  /// the public `Error` type.
  pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
    self._load(rng, elf).map_err(Error)
  }
909
  /// JIT-compiles an ELF image into native code and returns it as an
  /// [`UnboundProgram`]. Called by [`Self::load`], which converts the
  /// internal `RuntimeError` into the public error type.
  fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
    let start_time = Instant::now();
    // Allocate the pointer cage that holds the shadow stack plus a data
    // region sized for the ELF image, then create the ubpf VM over it.
    let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
    let vm = Vm::new(&cage);

    // Relocate ELF: copy the image into the cage's data region and link it,
    // resolving helper references via `helpers_inverse`.
    let code_sections = {
      // XXX: Although we are writing to the data region, we need to use `safe_deref_for_read`
      // here because the `_write` variant checks that the requested region is within the
      // stack. It's safe here because `freeze_data` is not yet called.
      let mut data = cage
        .safe_deref_for_read(cage.data_bottom(), elf.len())
        .unwrap();
      let data = unsafe { data.as_mut() };
      data.copy_from_slice(elf);

      link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
    };
    // No further writes to the data region are allowed past this point.
    cage.freeze_data();

    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    if page_size < 0 {
      return Err(RuntimeError::PlatformError("failed to get page size"));
    }
    let page_size = page_size as usize;

    // Allocate code memory, surrounded by randomly sized (16..128 pages)
    // regions that will be made PROT_NONE below to act as guard zones.
    let guard_size_before = rng.gen_range(16..128) * page_size;
    let mut guard_size_after = rng.gen_range(16..128) * page_size;

    let code_len_allocated: usize = 65536;
    let code_mem = MmapRaw::from(
      MmapOptions::new()
        .len(code_len_allocated + guard_size_before + guard_size_after)
        .map_anon()
        .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
    );

    unsafe {
      // Route helper calls through `tls_dispatcher` and validate helper ids
      // with `std_validator`; the cookie pointer lets the validator recover
      // `self` (see `std_validator`, which casts it back to `ProgramLoader`).
      if crate::ubpf::ubpf_register_external_dispatcher(
        vm.0.as_ptr(),
        Some(tls_dispatcher),
        Some(std_validator),
        self as *const _ as *mut c_void,
      ) != 0
      {
        return Err(RuntimeError::PlatformError(
          "ubpf: failed to register external dispatcher",
        ));
      }
      // Arm the guard regions on both sides of the code area.
      if libc::mprotect(
        code_mem.as_mut_ptr() as *mut _,
        guard_size_before,
        libc::PROT_NONE,
      ) != 0
        || libc::mprotect(
          code_mem
            .as_mut_ptr()
            .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
          guard_size_after,
          libc::PROT_NONE,
        ) != 0
      {
        return Err(RuntimeError::PlatformError("failed to protect guard pages"));
      }
    }

    let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

    unsafe {
      // Translate eBPF to native code, one section at a time, packing each
      // section's output back-to-back into the remaining `code_slice`.
      let mut code_slice = std::slice::from_raw_parts_mut(
        code_mem.as_mut_ptr().offset(guard_size_before as isize),
        code_len_allocated,
      );
      for (section_name, code_vaddr_size) in code_sections {
        if code_slice.is_empty() {
          return Err(RuntimeError::InvalidArgument(
            "no space left for jit compilation",
          ));
        }

        // Drop the previously loaded section's bytecode before loading the
        // next one into the same VM.
        crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

        let mut errmsg_ptr = std::ptr::null_mut();
        // (vaddr, size) of this section's bytecode inside the cage.
        let code = cage
          .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
          .unwrap();
        let ret = crate::ubpf::ubpf_load(
          vm.0.as_ptr(),
          code.as_ptr() as *const _,
          code.len() as u32,
          &mut errmsg_ptr,
        );
        if ret != 0 {
          // The error string is allocated by ubpf; copy it out for logging
          // and release it with `libc::free`.
          let errmsg = if errmsg_ptr.is_null() {
            "".to_string()
          } else {
            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
          };
          if !errmsg_ptr.is_null() {
            libc::free(errmsg_ptr as _);
          }
          tracing::error!(section_name, error = errmsg, "failed to load code");
          return Err(RuntimeError::PlatformError("ubpf: code load failed"));
        }

        // `written_len` is in/out: remaining capacity on entry, bytes
        // actually emitted on successful return.
        let mut written_len = code_slice.len();
        let ret = crate::ubpf::ubpf_translate(
          vm.0.as_ptr(),
          code_slice.as_mut_ptr(),
          &mut written_len,
          &mut errmsg_ptr,
        );
        if ret != 0 {
          let errmsg = if errmsg_ptr.is_null() {
            "".to_string()
          } else {
            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
          };
          if !errmsg_ptr.is_null() {
            libc::free(errmsg_ptr as _);
          }
          tracing::error!(section_name, error = errmsg, "failed to translate code");
          return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
        }

        assert!(written_len <= code_slice.len());
        // The section's entrypoint is where this chunk started: end of the
        // code area minus the space that was still free before this write.
        entrypoints.insert(
          section_name,
          Entrypoint {
            code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
              - code_slice.len(),
            code_len: written_len,
          },
        );
        code_slice = &mut code_slice[written_len..];
      }

      // Align up code_len to page size
      let unpadded_code_len = code_len_allocated - code_slice.len();
      let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
      assert!(code_len <= code_len_allocated);

      // RW- -> R-X
      // Also make the unused part of the pre-allocated code region PROT_NONE
      if libc::mprotect(
        code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
        code_len,
        libc::PROT_READ | libc::PROT_EXEC,
      ) != 0
        || (code_len < code_len_allocated
          && libc::mprotect(
            code_mem
              .as_mut_ptr()
              .offset((guard_size_before + code_len) as isize) as *mut _,
            code_len_allocated - code_len,
            libc::PROT_NONE,
          ) != 0)
      {
        return Err(RuntimeError::PlatformError("failed to protect code memory"));
      }

      // Fold the now-PROT_NONE tail of the code area into the trailing guard
      // size; only used for the log line below.
      guard_size_after += code_len_allocated - code_len;

      tracing::info!(
        elf_size = elf.len(),
        native_code_addr = ?code_mem.as_ptr(),
        native_code_size = code_len,
        native_code_size_unpadded = unpadded_code_len,
        guard_size_before,
        guard_size_after,
        duration = ?start_time.elapsed(),
        cage_ptr = ?cage.region().as_ptr(),
        cage_mapped_size = cage.region().len(),
        "jit compiled program"
      );

      Ok(UnboundProgram {
        id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
        // Kept alive only so the mapping is not unmapped while the program exists.
        _code_mem: code_mem,
        cage,
        helper_id_xor: self.helper_id_xor,
        helpers: self.helpers.clone(),
        event_listener: self.event_listener.clone(),
        entrypoints,
      })
    }
  }
1099}
1100
/// A request yielded from the JIT'd code's coroutine back to its scheduler.
/// Exactly one of three cases applies: a helper call (`index`/`arg*`, see
/// `tls_dispatcher`), an async preemption (`async_preemption`, see
/// `sigusr1_handler`), or a cage fault (`memory_access_error`, see
/// `sigsegv_handler`).
#[derive(Default)]
struct Dispatch {
  // True when the suspension was triggered by SIGUSR1-based preemption
  // rather than by a helper call.
  async_preemption: bool,
  // The faulting address, when the suspension was triggered by a SIGSEGV
  // inside the pointer cage's protected range.
  memory_access_error: Option<usize>,

  // Helper call: the helper index as received by the dispatcher, plus the
  // five u64 arguments of the eBPF helper calling convention.
  index: u32,
  arg1: u64,
  arg2: u64,
  arg3: u64,
  arg4: u64,
  arg5: u64,
}
1113
1114unsafe extern "C" fn tls_dispatcher(
1115  arg1: u64,
1116  arg2: u64,
1117  arg3: u64,
1118  arg4: u64,
1119  arg5: u64,
1120  index: std::os::raw::c_uint,
1121  _cookie: *mut std::os::raw::c_void,
1122) -> u64 {
1123  let yielder = ACTIVE_JIT_CODE_ZONE
1124    .with(|x| x.yielder.get())
1125    .expect("no yielder");
1126  let yielder = yielder.as_ref();
1127  let ret = yielder.suspend(Dispatch {
1128    async_preemption: false,
1129    memory_access_error: None,
1130    index,
1131    arg1,
1132    arg2,
1133    arg3,
1134    arg4,
1135    arg5,
1136  });
1137  ret
1138}
1139
1140unsafe extern "C" fn std_validator(
1141  index: std::os::raw::c_uint,
1142  loader: *mut std::os::raw::c_void,
1143) -> bool {
1144  let loader = &*(loader as *const ProgramLoader);
1145  let entropy = (index >> 16) & 0xffff;
1146  let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1147  loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1148}
1149
/// Extracts the interrupted program counter from a signal handler's
/// `ucontext_t` (x86_64 Linux: the saved RIP register).
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}
1154
/// Extracts the interrupted program counter from a signal handler's
/// `ucontext_t` (aarch64 Linux: the saved `pc` field).
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.pc as usize
}
1159
1160unsafe extern "C" fn sigsegv_handler(
1161  _sig: i32,
1162  siginfo: *mut libc::siginfo_t,
1163  uctx: *mut libc::ucontext_t,
1164) {
1165  let fail = || restore_default_signal_handler(libc::SIGSEGV);
1166
1167  let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
1168    if x.valid.load(Ordering::Relaxed) {
1169      compiler_fence(Ordering::Acquire);
1170      Some((
1171        x.code_range.get(),
1172        x.pointer_cage_protected_range.get(),
1173        x.yielder.get(),
1174      ))
1175    } else {
1176      None
1177    }
1178  }) else {
1179    return fail();
1180  };
1181
1182  let pc = program_counter(uctx);
1183
1184  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
1185    return fail();
1186  }
1187
1188  // SEGV_ACCERR
1189  if (*siginfo).si_code != 2 {
1190    return fail();
1191  }
1192
1193  let si_addr = (*siginfo).si_addr() as usize;
1194  if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
1195    return fail();
1196  }
1197
1198  let yielder = yielder.expect("no yielder").as_ref();
1199  yielder.suspend(Dispatch {
1200    memory_access_error: Some(si_addr),
1201    ..Default::default()
1202  });
1203}
1204
/// SIGUSR1 handler: asynchronously preempts JIT'd eBPF code by suspending its
/// coroutine with [`Dispatch::async_preemption`], but only when the signal
/// landed while the PC was inside the active JIT code zone. Otherwise the
/// signal is counted and otherwise ignored.
unsafe extern "C" fn sigusr1_handler(
  _sig: i32,
  _siginfo: *mut libc::siginfo_t,
  uctx: *mut libc::ucontext_t,
) {
  // Count every delivery, including the ones we don't act on.
  SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

  let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
    if x.valid.load(Ordering::Relaxed) {
      // The zone state is thread-local and the handler runs on the same
      // thread that published it, so a compiler fence is enough to keep the
      // reads below from being reordered before the validity check.
      compiler_fence(Ordering::Acquire);
      Some((x.code_range.get(), x.yielder.get()))
    } else {
      None
    }
  }) else {
    return;
  };
  // Only preempt when the interrupted PC lies inside the JIT'd code.
  let pc = program_counter(uctx);
  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
    return;
  }

  let yielder = yielder.expect("no yielder").as_ref();
  yielder.suspend(Dispatch {
    async_preemption: true,
    ..Default::default()
  });
}
1233
1234unsafe fn restore_default_signal_handler(signum: i32) {
1235  let act = libc::sigaction {
1236    sa_sigaction: libc::SIG_DFL,
1237    sa_flags: libc::SA_SIGINFO,
1238    sa_mask: std::mem::zeroed(),
1239    sa_restorer: None,
1240  };
1241  if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
1242    libc::abort();
1243  }
1244}
1245
1246fn get_blocked_sigset() -> libc::sigset_t {
1247  unsafe {
1248    let mut s: libc::sigset_t = std::mem::zeroed();
1249    libc::sigaddset(&mut s, libc::SIGUSR1);
1250    libc::sigaddset(&mut s, libc::SIGSEGV);
1251    s
1252  }
1253}