Skip to main content

async_ebpf/
program.rs

1//! This module contains the core logic for eBPF program execution. A lot of
2//! `unsafe` code is used - be careful when making changes here.
3
4use std::{
5  any::{Any, TypeId},
6  cell::{Cell, RefCell},
7  collections::HashMap,
8  ffi::CStr,
9  marker::PhantomData,
10  mem::ManuallyDrop,
11  ops::{Deref, DerefMut},
12  os::raw::c_void,
13  pin::Pin,
14  ptr::NonNull,
15  rc::Rc,
16  sync::{
17    atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18    Arc, Once,
19  },
20  thread::ThreadId,
21  time::{Duration, Instant},
22};
23
24use corosensei::{
25  stack::{DefaultStack, Stack},
26  Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
27};
28use futures::{Future, FutureExt};
29use memmap2::{MmapOptions, MmapRaw};
30use parking_lot::{Condvar, Mutex};
31use rand::prelude::SliceRandom;
32use rand::SeedableRng;
33use rand_chacha::ChaCha8Rng;
34
35use crate::{
36  error::{Error, RuntimeError},
37  helpers::Helper,
38  linker::link_elf,
39  pointer_cage::PointerCage,
40  util::nonnull_bytes_overlap,
41};
42
/// Size of the native (host) stack backing each `corosensei` coroutine.
const NATIVE_STACK_SIZE: usize = 16384;
/// Size of the guest shadow stack that is saved/restored across yields.
const SHADOW_STACK_SIZE: usize = 4096;
/// Upper bound on calldata copied onto the shadow stack (checked in `Program::_run`).
const MAX_CALLDATA_SIZE: usize = 512;
/// Max simultaneously-tracked mutable user-memory views per helper call.
const MAX_MUTABLE_DEREF_REGIONS: usize = 4;
/// Max simultaneously-tracked immutable user-memory views per helper call.
const MAX_IMMUTABLE_DEREF_REGIONS: usize = 16;
48
/// Per-invocation storage for helper state during a program run.
pub struct InvokeScope {
  data: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl InvokeScope {
  /// Gets or creates typed data scoped to this invocation.
  ///
  /// The first call for a given `T` inserts `T::default()`; later calls
  /// return the same slot.
  pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
    let slot = self
      .data
      .entry(TypeId::of::<T>())
      .or_insert_with(|| Box::<T>::default());
    slot
      .downcast_mut::<T>()
      .expect("InvokeScope::data_mut: downcast failed")
  }
}
66
/// Context passed to helpers while a program is executing.
pub struct HelperScope<'a, 'b> {
  /// The program being executed.
  pub program: &'a Program,
  /// Mutable per-invocation data for helpers.
  pub invoke: RefCell<&'a mut InvokeScope>,
  // Caller-supplied resources; looked up by type via `with_resource_mut`.
  resources: RefCell<&'a mut [&'b mut dyn Any]>,
  // Bookkeeping of user-memory regions handed out mutably during this helper
  // call; `user_memory`/`user_memory_mut` reject overlapping requests.
  mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_MUTABLE_DEREF_REGIONS],
  // Same bookkeeping for immutable views.
  immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_IMMUTABLE_DEREF_REGIONS],
  // Gates `post_task`; only true while a synchronous helper is running.
  can_post_task: bool,
}

/// A validated mutable view into user memory.
pub struct MutableUserMemory<'a, 'b, 'c> {
  // Borrow ties the view's lifetime to the issuing scope.
  _scope: &'c HelperScope<'a, 'b>,
  // Region previously validated by `HelperScope::user_memory_mut`.
  region: NonNull<[u8]>,
}
84
85impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
86  type Target = [u8];
87
88  fn deref(&self) -> &Self::Target {
89    unsafe { self.region.as_ref() }
90  }
91}
92
93impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
94  fn deref_mut(&mut self) -> &mut Self::Target {
95    unsafe { self.region.as_mut() }
96  }
97}
98
99impl<'a, 'b> HelperScope<'a, 'b> {
100  /// Posts an async task to be run between timeslices.
101  pub fn post_task(
102    &self,
103    task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
104  ) {
105    if !self.can_post_task {
106      panic!("HelperScope::post_task() called in a context where posting task is not allowed");
107    }
108
109    PENDING_ASYNC_TASK.with(|x| {
110      let mut x = x.borrow_mut();
111      if x.is_some() {
112        panic!("post_task called while another task is pending");
113      }
114      *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
115    });
116  }
117
118  /// Calls `callback` with a mutable resource of type `T`, if present.
119  pub fn with_resource_mut<'c, T: 'static, R>(
120    &'c self,
121    callback: impl FnOnce(Result<&mut T, ()>) -> R,
122  ) -> R {
123    let mut resources = self.resources.borrow_mut();
124    let Some(res) = resources
125      .iter_mut()
126      .filter_map(|x| x.downcast_mut::<T>())
127      .next()
128    else {
129      tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
130      return callback(Err(()));
131    };
132
133    callback(Ok(res))
134  }
135
136  /// Validates and returns an immutable view into user memory.
137  pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
138    let Some(region) = self
139      .program
140      .unbound
141      .cage
142      .safe_deref_for_read(ptr as usize, size as usize)
143    else {
144      tracing::warn!(ptr, size, "invalid read");
145      return Err(());
146    };
147
148    if size != 0 {
149      // The region must not overlap with any previously dereferenced mutable regions
150      if self
151        .mutable_dereferenced_regions
152        .iter()
153        .filter_map(|x| x.get())
154        .any(|x| nonnull_bytes_overlap(x, region))
155      {
156        tracing::warn!(ptr, size, "read overlapped with previous write");
157        return Err(());
158      }
159
160      // Find a slot to record this dereference
161      let Some(slot) = self
162        .immutable_dereferenced_regions
163        .iter()
164        .find(|x| x.get().is_none())
165      else {
166        tracing::warn!(ptr, size, "too many reads");
167        return Err(());
168      };
169      slot.set(Some(region));
170    }
171
172    Ok(unsafe { region.as_ref() })
173  }
174
175  /// Validates and returns a mutable view into user memory.
176  pub fn user_memory_mut<'c>(
177    &'c self,
178    ptr: u64,
179    size: u64,
180  ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
181    let Some(region) = self
182      .program
183      .unbound
184      .cage
185      .safe_deref_for_write(ptr as usize, size as usize)
186    else {
187      tracing::warn!(ptr, size, "invalid write");
188      return Err(());
189    };
190
191    if size != 0 {
192      // The region must not overlap with any other previously dereferenced mutable or immutable regions
193      if self
194        .mutable_dereferenced_regions
195        .iter()
196        .chain(self.immutable_dereferenced_regions.iter())
197        .filter_map(|x| x.get())
198        .any(|x| nonnull_bytes_overlap(x, region))
199      {
200        tracing::warn!(ptr, size, "write overlapped with previous read/write");
201        return Err(());
202      }
203
204      // Find a slot to record this dereference
205      let Some(slot) = self
206        .mutable_dereferenced_regions
207        .iter()
208        .find(|x| x.get().is_none())
209      else {
210        tracing::warn!(ptr, size, "too many writes");
211        return Err(());
212      };
213      slot.set(Some(region));
214    }
215
216    Ok(MutableUserMemory {
217      _scope: self,
218      region,
219    })
220  }
221}
222
// Wrapper asserting that a value may cross the (non-Send) coroutine boundary.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
// SAFETY: blanket and therefore caller-audited — each construction site must
// guarantee the wrapped value is only used on the owning thread or is in
// fact thread-safe. NOTE(review): consider documenting per-site invariants.
unsafe impl<T> Send for AssumeSend<T> {}
226
227struct ExecContext {
228  native_stack: DefaultStack,
229  copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
230}
231
232impl ExecContext {
233  fn new() -> Self {
234    Self {
235      native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
236        .expect("failed to initialize native stack"),
237      copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
238    }
239  }
240}
241
/// A pending async task spawned by a helper.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// The callback produced by a helper async task when it resumes.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

// Source of unique program ids; starts at 1. NOTE(review): the allocation
// site is not visible in this chunk — presumably consumed when building
// `UnboundProgram::id`; confirm at the load path.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);

// State machine shared between a worker thread and its preemption watcher.
#[derive(Copy, Clone, Debug)]
enum PreemptionState {
  // No program running; the watcher parks on the condvar.
  Inactive,
  // Count of live `PreemptionEnabled` guards; while Active the watcher
  // sends SIGUSR1 whenever its wait times out without a state change.
  Active(usize),
  // Owning thread is shutting down; the watcher thread exits.
  Shutdown,
}

// Mutex-guarded state plus the condvar used to wake the watcher thread.
type PreemptionStateSignal = (Mutex<PreemptionState>, Condvar);
257
258thread_local! {
259  static RUST_TID: ThreadId = std::thread::current().id();
260  static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
261  static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
262  static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
263  static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
264  static PREEMPTION_STATE: Arc<PreemptionStateSignal> = Arc::new((Mutex::new(PreemptionState::Inactive), Condvar::new()));
265}
266
267struct BorrowedExecContext {
268  ctx: ManuallyDrop<ExecContext>,
269}
270
271impl BorrowedExecContext {
272  fn new(stack_rng: &mut impl rand::Rng) -> Self {
273    let mut me = Self {
274      ctx: ManuallyDrop::new(
275        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
276      ),
277    };
278    stack_rng.fill(&mut *me.ctx.copy_stack);
279    me
280  }
281}
282
283impl Drop for BorrowedExecContext {
284  fn drop(&mut self) {
285    let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
286    EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
287  }
288}
289
// Thread-local descriptor of the JIT code currently executing. Published by
// `Program::_run` with a release fence before `valid` is set; presumably
// consulted by `sigusr1_handler`/`sigsegv_handler` (not visible in this
// chunk) — confirm at the handler definitions.
#[derive(Default)]
struct ActiveJitCodeZone {
  // True only while the coroutine is resumed into JIT code.
  valid: AtomicBool,
  // Start/end addresses of the entrypoint's JIT-compiled code.
  code_range: Cell<(usize, usize)>,
  // Protected address range of the pointer cage, without guard margins.
  pointer_cage_protected_range: Cell<(usize, usize)>,
  // Yielder for suspending the running coroutine; set inside the coroutine
  // body and refreshed after each resume in `Program::_run`.
  yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}

/// Hooks for observing program execution events.
pub trait ProgramEventListener: Send + Sync + 'static {
  /// Called after an async preemption is triggered.
  fn did_async_preempt(&self, _scope: &HelperScope) {}
  /// Called after yielding back to the async runtime.
  fn did_yield(&self) {}
  /// Called after throttling a program's execution. May return a future the
  /// run loop awaits before resuming.
  fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
    None
  }
  /// Called after saving the shadow stack before yielding.
  fn did_save_shadow_stack(&self) {}
  /// Called after restoring the shadow stack on resume.
  fn did_restore_shadow_stack(&self) {}
}

/// No-op event listener implementation.
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
317
/// Prepares helper tables and loads eBPF programs.
pub struct ProgramLoader {
  // Maps helper name -> encoded id handed to the linker; built in `new`.
  helpers_inverse: HashMap<&'static str, i32>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Random value XORed into helper indexes; see `new` for encoding and
  // `Program::_run` for the matching decode.
  helper_id_xor: u16,
  // Shuffled helper table; entries are (entropy, name, fn).
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}

/// A loaded program that is not yet pinned to a thread.
pub struct UnboundProgram {
  id: u64,
  // Held only to keep the JIT code mapping (and its guard pages) alive.
  _code_mem: MmapRaw,
  cage: PointerCage,
  helper_id_xor: u16,
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Section name -> JIT code range, populated at load time.
  entrypoints: HashMap<String, Entrypoint>,
  // Seeds the per-thread RNG created in `pin_to_current_thread`.
  stack_rng_seed: [u8; 32],
}

/// A program pinned to a specific thread and ready to execute.
pub struct Program {
  unbound: UnboundProgram,
  // Per-instance typed storage; `Rc`, so single-threaded by construction.
  data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
  // Fills shadow-stack copies with random bytes before each run.
  stack_rng: RefCell<ChaCha8Rng>,
  t: ThreadEnv,
}

// Address and length of one entrypoint's JIT-compiled native code.
#[derive(Copy, Clone)]
struct Entrypoint {
  code_ptr: usize,
  code_len: usize,
}

/// Time limits used to yield or throttle execution.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
  /// Maximum runtime before yielding to the async scheduler.
  pub max_run_time_before_yield: Duration,
  /// Maximum runtime before a throttle sleep is forced.
  pub max_run_time_before_throttle: Duration,
  /// Duration of the throttle sleep once triggered.
  pub throttle_duration: Duration,
}
362
/// Async runtime integration for yielding and sleeping.
pub trait Timeslicer {
  /// Sleep for the provided duration.
  fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
  /// Yield to the async scheduler.
  fn yield_now(&self) -> impl Future<Output = ()>;
}

/// Global runtime environment for signal handlers.
///
/// The private unit field makes this constructible only via `GlobalEnv::new`,
/// so holding one proves the signal handlers are installed.
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Per-thread runtime environment for preemption handling.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
  // `PhantomData<*const ()>` makes this type !Send + !Sync, tying it to the
  // thread it was created on.
  _not_send_sync: std::marker::PhantomData<*const ()>,
}
380
impl GlobalEnv {
  /// Initializes global state and installs signal handlers.
  ///
  /// Idempotent: the handler installation runs under a `Once`, so repeated
  /// calls are cheap.
  ///
  /// # Safety
  /// Must be called in a process that can install SIGUSR1/SIGSEGV handlers.
  pub unsafe fn new() -> Self {
    static INIT: Once = Once::new();

    // SIGUSR1 must be blocked during exception handling
    // Otherwise it seems that Linux gives up and throws an uncatchable SI_KERNEL SIGSEGV:
    //
    // [pid 517110] tgkill(517109, 517112, SIGUSR1 <unfinished ...>
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SEGV_ACCERR, si_addr=0x793789be227e} ---
    // [pid 517109] write(15, "\1\0\0\0\0\0\0\0", 8 <unfinished ...>
    // [pid 517110] <... tgkill resumed>)      = 0
    // [pid 517109] <... write resumed>)       = 8
    // [pid 517112] --- SIGUSR1 {si_signo=SIGUSR1, si_code=SI_TKILL, si_pid=517109, si_uid=1000} ---
    // [pid 517109] recvfrom(57,  <unfinished ...>
    // [pid 517110] futex(0x79378abff5a8, FUTEX_WAIT_PRIVATE, 1, {tv_sec=0, tv_nsec=5999909} <unfinished ...>
    // [pid 517109] <... recvfrom resumed>"GET /write_rodata HTTP/1.1\r\nHost"..., 8192, 0, NULL, NULL) = 61
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SI_KERNEL, si_addr=NULL} ---

    INIT.call_once(|| {
      let sa_mask = get_blocked_sigset();

      // Install both handlers with SA_SIGINFO so they receive extended
      // signal information via siginfo_t.
      for (sig, handler) in [
        (libc::SIGUSR1, sigusr1_handler as usize),
        (libc::SIGSEGV, sigsegv_handler as usize),
      ] {
        let act = libc::sigaction {
          sa_sigaction: handler,
          sa_flags: libc::SA_SIGINFO,
          sa_mask,
          sa_restorer: None,
        };
        if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
          panic!("failed to setup handler for signal {}", sig);
        }
      }
    });

    Self(())
  }

  /// Initializes per-thread state and starts the async preemption watcher.
  ///
  /// Idempotent per thread: the watcher is spawned at most once, and a
  /// thread-local guard flips the shared state to `Shutdown` when the
  /// calling thread exits so the watcher terminates.
  pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
    // Signals the watcher to shut down when this thread's TLS is destroyed.
    struct DeferDrop(Arc<PreemptionStateSignal>);
    impl Drop for DeferDrop {
      fn drop(&mut self) {
        let x = &self.0;
        *x.0.lock() = PreemptionState::Shutdown;
        x.1.notify_one();
      }
    }

    thread_local! {
      static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
    }

    // Already initialized on this thread — nothing to do.
    if WATCHER.with(|x| x.borrow().is_some()) {
      return ThreadEnv {
        _not_send_sync: PhantomData,
      };
    }

    let preemption_state = PREEMPTION_STATE.with(|x| x.clone());

    // SAFETY: the unsafe scope covers the raw libc calls
    // (getpid/gettid/tgkill), which have no memory-safety preconditions.
    unsafe {
      // Captured on the *current* (worker) thread; the watcher thread uses
      // them to signal this thread specifically via tgkill.
      let tgid = libc::getpid();
      let tid = libc::gettid();

      std::thread::Builder::new()
        .name("preempt-watcher".to_string())
        .spawn(move || {
          let mut state = preemption_state.0.lock();
          loop {
            match *state {
              PreemptionState::Shutdown => break,
              PreemptionState::Inactive => {
                preemption_state.1.wait(&mut state);
              }
              PreemptionState::Active(_) => {
                // Wait while Active; a timeout means the worker has run a
                // full interval without a state change, so preempt it.
                let timeout = preemption_state.1.wait_while_for(
                  &mut state,
                  |x| matches!(x, PreemptionState::Active(_)),
                  async_preemption_interval,
                );
                if timeout.timed_out() {
                  let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
                  if ret != 0 {
                    // Target thread is gone; stop watching.
                    break;
                  }
                }
              }
            }
          }
        })
        .expect("failed to spawn preemption watcher");

      WATCHER.with(|x| {
        x.borrow_mut()
          .replace(DeferDrop(PREEMPTION_STATE.with(|x| x.clone())));
      });

      ThreadEnv {
        _not_send_sync: PhantomData,
      }
    }
  }
}
491
492impl UnboundProgram {
493  /// Pins the program to the current thread using a prepared `ThreadEnv`.
494  pub fn pin_to_current_thread(self, t: ThreadEnv) -> Program {
495    let stack_rng = RefCell::new(ChaCha8Rng::from_seed(self.stack_rng_seed));
496    Program {
497      unbound: self,
498      data: RefCell::new(HashMap::new()),
499      stack_rng,
500      t,
501    }
502  }
503}
504
505pub struct PreemptionEnabled(());
506
507impl PreemptionEnabled {
508  pub fn new(_: ThreadEnv) -> Self {
509    PREEMPTION_STATE.with(|x| {
510      let mut notify = false;
511      {
512        let mut st = x.0.lock();
513        let next = match *st {
514          PreemptionState::Inactive => {
515            notify = true;
516            PreemptionState::Active(1)
517          }
518          PreemptionState::Active(n) => PreemptionState::Active(n + 1),
519          PreemptionState::Shutdown => unreachable!(),
520        };
521        *st = next;
522      }
523
524      if notify {
525        x.1.notify_one();
526      }
527    });
528    Self(())
529  }
530}
531
532impl Drop for PreemptionEnabled {
533  fn drop(&mut self) {
534    PREEMPTION_STATE.with(|x| {
535      let mut st = x.0.lock();
536      let next = match *st {
537        PreemptionState::Active(1) => PreemptionState::Inactive,
538        PreemptionState::Active(n) => {
539          assert!(n > 1);
540          PreemptionState::Active(n - 1)
541        }
542        PreemptionState::Inactive | PreemptionState::Shutdown => unreachable!(),
543      };
544      *st = next;
545    });
546  }
547}
548
impl Program {
  /// Returns the unique program identifier.
  pub fn id(&self) -> u64 {
    self.unbound.id
  }

  /// Returns the `ThreadEnv` this program was pinned with.
  pub fn thread_env(&self) -> ThreadEnv {
    self.t
  }

  /// Gets or creates shared typed data for this program instance.
  ///
  /// Stored behind `Rc`, so the data is only usable on the pinning thread.
  pub fn data<T: Default + 'static>(&self) -> Rc<T> {
    let mut data = self.data.borrow_mut();
    let entry = data.entry(TypeId::of::<T>());
    let entry = entry.or_insert_with(|| Rc::new(T::default()));
    entry.clone().downcast().unwrap()
  }

  /// Returns whether the loaded image contains an entrypoint named `name`.
  pub fn has_section(&self, name: &str) -> bool {
    self.unbound.entrypoints.contains_key(name)
  }

  /// Runs the program entrypoint with the provided resources and calldata.
  pub async fn run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
    preemption: &PreemptionEnabled,
  ) -> Result<i64, Error> {
    self
      ._run(
        timeslice, timeslicer, entrypoint, resources, calldata, preemption,
      )
      .await
      .map_err(Error)
  }

  /// Driver loop: resumes the JIT coroutine, services helper calls and async
  /// preemptions, and yields/throttles per `timeslice`. The
  /// `&PreemptionEnabled` parameter is only a proof that the watcher is
  /// active; its value is unused.
  async fn _run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
    _: &PreemptionEnabled,
  ) -> Result<i64, RuntimeError> {
    let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
      return Err(RuntimeError::InvalidArgument("entrypoint not found"));
    };

    // NOTE(review): assumes the JIT emits code with exactly this C ABI
    // (ctx, shadow_stack) -> u64 — confirm against the translator.
    let entry = unsafe {
      std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
        entrypoint.code_ptr,
      )
    };
    struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
      ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
    );
    impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
      for CoDropper<'a, Input, Yield, Return, DefaultStack>
    {
      fn drop(&mut self) {
        // Prevent the coroutine library from attempting to unwind the stack of the coroutine
        // and run destructors, because this stack might be running a signal handler and
        // it's not allowed to unwind from there.
        //
        // SAFETY: The coroutine stack only contains stack frames for JIT-compiled code and
        // carefully chosen Rust code that do not hold Droppable values, so it's safe to
        // skip destructors.
        unsafe {
          self.0.force_reset();
        }
      }
    }

    let mut ectx = BorrowedExecContext::new(&mut *self.stack_rng.borrow_mut());

    if calldata.len() > MAX_CALLDATA_SIZE {
      return Err(RuntimeError::InvalidArgument("calldata too large"));
    }
    // Calldata sits at the very top of the shadow stack image.
    ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
    let calldata_len = calldata.len();

    let program_ret: u64 = {
      let shadow_stack_top = self.unbound.cage.stack_top();
      let shadow_stack_ptr = AssumeSend(
        self
          .unbound
          .cage
          .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
          .unwrap(),
      );
      let ctx = &mut *ectx.ctx;

      let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
        &mut ctx.native_stack,
        move |yielder, _input| unsafe {
          ACTIVE_JIT_CODE_ZONE.with(|x| {
            x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
          });
          // Both the context and shadow-stack arguments point just below
          // the calldata region.
          entry(
            shadow_stack_top - calldata_len,
            shadow_stack_top - calldata_len,
          )
        },
      )));

      let mut last_yield_time = Instant::now();
      let mut last_throttle_time = Instant::now();
      let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
      let mut resume_input: u64 = 0;
      let mut did_throttle = false;
      // Starts true so the first iteration copies the (calldata-filled)
      // shadow-stack image into the cage.
      let mut shadow_stack_saved = true;
      let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
      let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
      let mut invoke_scope = InvokeScope {
        data: HashMap::new(),
      };

      loop {
        // Publish the active JIT zone before resuming; the release fence
        // orders the field writes before the `valid` flag.
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.code_range.set((
            entrypoint.code_ptr,
            entrypoint.code_ptr + entrypoint.code_len,
          ));
          x.yielder.set(yielder.map(|x| x.0));
          x.pointer_cage_protected_range
            .set(self.unbound.cage.protected_range_without_margins());
          compiler_fence(Ordering::Release);
          x.valid.store(true, Ordering::Relaxed);
        });

        if shadow_stack_saved {
          shadow_stack_saved = false;

          // restore shadow stack
          unsafe {
            std::ptr::copy_nonoverlapping(
              ctx.copy_stack.as_ptr() as *const u8,
              shadow_stack_ptr.0.as_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }

          self.unbound.event_listener.did_restore_shadow_stack();
        }

        // If the previous iteration wants to write back to machine state
        if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
          // NOTE(review): relies on zeroed `Option<NonNull<[u8]>>` being
          // `None` (null data pointer selects the niche); a `Default`-based
          // array init would avoid the `mem::zeroed` — consider changing.
          resume_input = prev_async_task_output(&HelperScope {
            program: self,
            invoke: RefCell::new(&mut invoke_scope),
            resources: RefCell::new(resources),
            mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            can_post_task: false,
          })
          .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
        }

        // `.0 .0`: AssumeSend -> CoDropper -> coroutine.
        let ret = co.0 .0.resume(resume_input);
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.valid.store(false, Ordering::Relaxed);
          compiler_fence(Ordering::Release);
          yielder = x.yielder.get().map(AssumeSend);
        });

        let dispatch: Dispatch = match ret {
          CoroutineResult::Return(x) => break x,
          CoroutineResult::Yield(x) => x,
        };

        // restore signal mask of current thread
        if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
          unsafe {
            let unblock = get_blocked_sigset();
            libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
          }
        }

        if let Some(si_addr) = dispatch.memory_access_error {
          // Report the fault as a guest virtual address.
          let vaddr = si_addr - self.unbound.cage.offset();
          return Err(RuntimeError::MemoryFault(vaddr));
        }

        // Clear pending task if something else has set it
        PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
        let mut helper_name: &'static str = "";

        let mut helper_scope = HelperScope {
          program: self,
          invoke: RefCell::new(&mut invoke_scope),
          resources: RefCell::new(resources),
          mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          can_post_task: false,
        };

        if dispatch.async_preemption {
          self
            .unbound
            .event_listener
            .did_async_preempt(&mut helper_scope);
        } else {
          // validator should ensure all helper indexes are present in the table
          // Decode: low 16 bits XOR helper_id_xor, minus 1 (ids are 1-based;
          // see `ProgramLoader::new`).
          let Some((_, got_helper_name, helper)) = self
            .unbound
            .helpers
            .get(
              ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                as usize,
            )
            .copied()
          else {
            panic!("unknown helper index: {}", dispatch.index);
          };
          helper_name = got_helper_name;

          // Posting async tasks is only allowed while the helper runs.
          helper_scope.can_post_task = true;
          resume_input = helper(
            &mut helper_scope,
            dispatch.arg1,
            dispatch.arg2,
            dispatch.arg3,
            dispatch.arg4,
            dispatch.arg5,
          )
          .map_err(|()| RuntimeError::HelperError(helper_name))?;
          helper_scope.can_post_task = false;
        }

        let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

        // Fast path: do not read timestamp if no thread migration or async preemption happened
        let new_rust_tid_sigusr1_counter =
          (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
        if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
        {
          continue;
        }

        rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

        let now = Instant::now();
        let should_throttle = now > last_throttle_time
          && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
        let should_yield = now > last_yield_time
          && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
        if should_throttle || should_yield || pending_async_task.is_some() {
          // We are about to yield control to tokio. Save the shadow stack, and release the guard.
          shadow_stack_saved = true;
          unsafe {
            std::ptr::copy_nonoverlapping(
              shadow_stack_ptr.0.as_ptr() as *const u8,
              ctx.copy_stack.as_mut_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }
          self.unbound.event_listener.did_save_shadow_stack();

          // we are now free to give up control of current thread to other async tasks

          if should_throttle {
            if !did_throttle {
              // Warn only once per run.
              did_throttle = true;
              tracing::warn!("throttling program");
            }
            timeslicer.sleep(timeslice.throttle_duration).await;
            let now = Instant::now();
            last_throttle_time = now;
            last_yield_time = now;
            let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
            if let Some(task) = task {
              task.await;
            }
          } else if should_yield {
            timeslicer.yield_now().await;
            let now = Instant::now();
            last_yield_time = now;
            self.unbound.event_listener.did_yield();
          }

          // Now we have released all exclusive resources and can safely execute the async task
          if let Some(pending_async_task) = pending_async_task {
            // Don't charge time spent inside the helper's async task
            // against the program's yield/throttle budgets.
            let async_start = Instant::now();
            prev_async_task_output = Some((helper_name, pending_async_task.await));
            let async_dur = async_start.elapsed();
            last_throttle_time += async_dur;
            last_yield_time += async_dur;
          }
        }
      }
    };

    Ok(program_ret as i64)
  }
}
849
// RAII owner of a `ubpf_vm`; destroyed via `ubpf_destroy` on drop.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);

impl Vm {
  // Creates a ubpf VM configured for `cage`: bounds checks off (presumably
  // the pointer cage takes over memory enforcement — confirm against the
  // cage design), JIT shadow stack on, pointer mask/offset from the cage.
  fn new(cage: &PointerCage) -> Self {
    let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
    // SAFETY: `vm` is non-null (checked above) and exclusively owned here.
    unsafe {
      crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
      crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
      crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
    }
    Self(vm)
  }
}

impl Drop for Vm {
  fn drop(&mut self) {
    // SAFETY: the pointer came from `ubpf_create` and is freed exactly once.
    unsafe {
      crate::ubpf::ubpf_destroy(self.0.as_ptr());
    }
  }
}
871
872impl ProgramLoader {
873  /// Creates a new `ProgramLoader` to load eBPF code.
874  pub fn new(
875    rng: &mut impl rand::Rng,
876    event_listener: Arc<dyn ProgramEventListener>,
877    raw_helpers: &[&[(&'static str, Helper)]],
878  ) -> Self {
879    let helper_id_xor = rng.gen::<u16>();
880    let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
881    // Collect first to a HashMap then to a Vec to deduplicate
882    let mut shuffled_helpers = raw_helpers
883      .iter()
884      .flat_map(|x| x.iter().copied())
885      .collect::<HashMap<_, _>>()
886      .into_iter()
887      .collect::<Vec<_>>();
888    shuffled_helpers.shuffle(rng);
889    let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());
890
891    assert!(shuffled_helpers.len() <= 65535);
892
893    for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
894      let entropy = rng.gen::<u16>() & 0x7fff;
895      helpers.push((entropy, name, helper));
896      helpers_inverse.insert(
897        name,
898        (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
899      );
900    }
901
902    tracing::info!(?helpers_inverse, "generated helper table");
903    Self {
904      helper_id_xor,
905      helpers: Arc::new(helpers),
906      helpers_inverse,
907      event_listener,
908    }
909  }
910
  /// Loads an ELF image into a new `UnboundProgram`.
  ///
  /// # Errors
  /// Wraps any `RuntimeError` produced during relocation, linking, or code
  /// memory setup in the public `Error` type.
  pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
    self._load(rng, elf).map_err(Error)
  }
915
  /// Loads, links, and JIT-compiles an ELF image into an `UnboundProgram`.
  ///
  /// Pipeline: copy the ELF into the pointer cage's data region and relocate
  /// it, freeze the data region, allocate a code area surrounded by
  /// randomly-sized `PROT_NONE` guard regions, translate each code section to
  /// native code via ubpf, then flip the used portion of the code area to
  /// R-X and the unused tail to `PROT_NONE`.
  fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
    let start_time = Instant::now();
    let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
    let vm = Vm::new(&cage);

    // Relocate ELF
    let code_sections = {
      // XXX: Although we are writing to the data region, we need to use `safe_deref_for_read`
      // here because the `_write` variant checks that the requested region is within the
      // stack. It's safe here because `freeze_data` is not yet called.
      let mut data = cage
        .safe_deref_for_read(cage.data_bottom(), elf.len())
        .unwrap();
      let data = unsafe { data.as_mut() };
      data.copy_from_slice(elf);

      link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
    };
    // After this point the data region is no longer writable through the
    // paths used above.
    cage.freeze_data();

    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
    if page_size < 0 {
      return Err(RuntimeError::PlatformError("failed to get page size"));
    }
    let page_size = page_size as usize;

    // Allocate code memory
    // Guard sizes are randomized (16..128 pages each) so the absolute address
    // of the JIT'd code is harder to predict, and over/under-runs fault.
    let guard_size_before = rng.gen_range(16..128) * page_size;
    let mut guard_size_after = rng.gen_range(16..128) * page_size;

    let code_len_allocated: usize = 65536;
    let code_mem = MmapRaw::from(
      MmapOptions::new()
        .len(code_len_allocated + guard_size_before + guard_size_after)
        .map_anon()
        .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
    );

    unsafe {
      // Route helper calls through `tls_dispatcher` and validate obfuscated
      // helper ids with `std_validator`; the cookie is `self` (ProgramLoader).
      if crate::ubpf::ubpf_register_external_dispatcher(
        vm.0.as_ptr(),
        Some(tls_dispatcher),
        Some(std_validator),
        self as *const _ as *mut c_void,
      ) != 0
      {
        return Err(RuntimeError::PlatformError(
          "ubpf: failed to register external dispatcher",
        ));
      }
      // Make both guard regions inaccessible.
      if libc::mprotect(
        code_mem.as_mut_ptr() as *mut _,
        guard_size_before,
        libc::PROT_NONE,
      ) != 0
        || libc::mprotect(
          code_mem
            .as_mut_ptr()
            .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
          guard_size_after,
          libc::PROT_NONE,
        ) != 0
      {
        return Err(RuntimeError::PlatformError("failed to protect guard pages"));
      }
    }

    let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();

    unsafe {
      // Translate eBPF to native code
      // `code_slice` always covers the still-unused tail of the code area;
      // each section's native code is appended and the slice advanced.
      let mut code_slice = std::slice::from_raw_parts_mut(
        code_mem.as_mut_ptr().offset(guard_size_before as isize),
        code_len_allocated,
      );
      for (section_name, code_vaddr_size) in code_sections {
        if code_slice.is_empty() {
          return Err(RuntimeError::InvalidArgument(
            "no space left for jit compilation",
          ));
        }

        // The single VM is reused across sections; drop the previously
        // loaded bytecode first.
        crate::ubpf::ubpf_unload_code(vm.0.as_ptr());

        let mut errmsg_ptr = std::ptr::null_mut();
        let code = cage
          .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
          .unwrap();
        let ret = crate::ubpf::ubpf_load(
          vm.0.as_ptr(),
          code.as_ptr() as *const _,
          code.len() as u32,
          &mut errmsg_ptr,
        );
        if ret != 0 {
          let errmsg = if errmsg_ptr.is_null() {
            "".to_string()
          } else {
            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
          };
          // The error message buffer is released with libc::free (ubpf
          // presumably allocates it with malloc — consistent with usage here).
          if !errmsg_ptr.is_null() {
            libc::free(errmsg_ptr as _);
          }
          tracing::error!(section_name, error = errmsg, "failed to load code");
          return Err(RuntimeError::PlatformError("ubpf: code load failed"));
        }

        // `written_len` is in/out: remaining capacity on entry, bytes
        // actually emitted on success (asserted below).
        let mut written_len = code_slice.len();
        let ret = crate::ubpf::ubpf_translate(
          vm.0.as_ptr(),
          code_slice.as_mut_ptr(),
          &mut written_len,
          &mut errmsg_ptr,
        );
        if ret != 0 {
          let errmsg = if errmsg_ptr.is_null() {
            "".to_string()
          } else {
            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
          };
          if !errmsg_ptr.is_null() {
            libc::free(errmsg_ptr as _);
          }
          tracing::error!(section_name, error = errmsg, "failed to translate code");
          return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
        }

        assert!(written_len <= code_slice.len());
        entrypoints.insert(
          section_name,
          Entrypoint {
            // Start of this section = end of the code area minus the space
            // that was still free before this section was written.
            code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
              - code_slice.len(),
            code_len: written_len,
          },
        );
        code_slice = &mut code_slice[written_len..];
      }

      // Align up code_len to page size
      let unpadded_code_len = code_len_allocated - code_slice.len();
      let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
      assert!(code_len <= code_len_allocated);

      // RW- -> R-X
      // Also make the unused part of the pre-allocated code region PROT_NONE
      if libc::mprotect(
        code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
        code_len,
        libc::PROT_READ | libc::PROT_EXEC,
      ) != 0
        || (code_len < code_len_allocated
          && libc::mprotect(
            code_mem
              .as_mut_ptr()
              .offset((guard_size_before + code_len) as isize) as *mut _,
            code_len_allocated - code_len,
            libc::PROT_NONE,
          ) != 0)
      {
        return Err(RuntimeError::PlatformError("failed to protect code memory"));
      }

      // The now-inaccessible tail of the code area effectively extends the
      // trailing guard region.
      guard_size_after += code_len_allocated - code_len;

      tracing::info!(
        elf_size = elf.len(),
        native_code_addr = ?code_mem.as_ptr(),
        native_code_size = code_len,
        native_code_size_unpadded = unpadded_code_len,
        guard_size_before,
        guard_size_after,
        duration = ?start_time.elapsed(),
        cage_ptr = ?cage.region().as_ptr(),
        cage_mapped_size = cage.region().len(),
        "jit compiled program"
      );

      Ok(UnboundProgram {
        id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
        _code_mem: code_mem,
        cage,
        helper_id_xor: self.helper_id_xor,
        helpers: self.helpers.clone(),
        event_listener: self.event_listener.clone(),
        entrypoints,
        stack_rng_seed: rng.gen(),
      })
    }
  }
1106}
1107
/// A request yielded from JIT'd program code back to the executor coroutine,
/// either for a helper call or because a signal handler interrupted the run.
#[derive(Default)]
struct Dispatch {
  // True when the suspension came from the SIGUSR1 async-preemption handler
  // rather than a helper dispatch.
  async_preemption: bool,
  // Faulting address when the suspension came from the SIGSEGV handler.
  memory_access_error: Option<usize>,

  // Obfuscated helper index as emitted by the linker (entropy in the high
  // 16 bits, xor'd slot in the low 16 — see `std_validator`).
  index: u32,
  arg1: u64,
  arg2: u64,
  arg3: u64,
  arg4: u64,
  arg5: u64,
}
1120
1121unsafe extern "C" fn tls_dispatcher(
1122  arg1: u64,
1123  arg2: u64,
1124  arg3: u64,
1125  arg4: u64,
1126  arg5: u64,
1127  index: std::os::raw::c_uint,
1128  _cookie: *mut std::os::raw::c_void,
1129) -> u64 {
1130  let yielder = ACTIVE_JIT_CODE_ZONE
1131    .with(|x| x.yielder.get())
1132    .expect("no yielder");
1133  let yielder = yielder.as_ref();
1134  let ret = yielder.suspend(Dispatch {
1135    async_preemption: false,
1136    memory_access_error: None,
1137    index,
1138    arg1,
1139    arg2,
1140    arg3,
1141    arg4,
1142    arg5,
1143  });
1144  ret
1145}
1146
1147unsafe extern "C" fn std_validator(
1148  index: std::os::raw::c_uint,
1149  loader: *mut std::os::raw::c_void,
1150) -> bool {
1151  let loader = &*(loader as *const ProgramLoader);
1152  let entropy = (index >> 16) & 0xffff;
1153  let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1154  loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1155}
1156
/// Reads the interrupted program counter (RIP) out of a signal `ucontext_t`
/// on x86_64 Linux.
///
/// # Safety
///
/// `uctx` must be a valid `ucontext_t` pointer as delivered to an
/// `SA_SIGINFO` signal handler on this platform.
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}
1161
/// Reads the interrupted program counter (PC) out of a signal `ucontext_t`
/// on aarch64 Linux.
///
/// # Safety
///
/// `uctx` must be a valid `ucontext_t` pointer as delivered to an
/// `SA_SIGINFO` signal handler on this platform.
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.pc as usize
}
1166
1167unsafe extern "C" fn sigsegv_handler(
1168  _sig: i32,
1169  siginfo: *mut libc::siginfo_t,
1170  uctx: *mut libc::ucontext_t,
1171) {
1172  let fail = || restore_default_signal_handler(libc::SIGSEGV);
1173
1174  let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
1175    if x.valid.load(Ordering::Relaxed) {
1176      compiler_fence(Ordering::Acquire);
1177      Some((
1178        x.code_range.get(),
1179        x.pointer_cage_protected_range.get(),
1180        x.yielder.get(),
1181      ))
1182    } else {
1183      None
1184    }
1185  }) else {
1186    return fail();
1187  };
1188
1189  let pc = program_counter(uctx);
1190
1191  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
1192    return fail();
1193  }
1194
1195  // SEGV_ACCERR
1196  if (*siginfo).si_code != 2 {
1197    return fail();
1198  }
1199
1200  let si_addr = (*siginfo).si_addr() as usize;
1201  if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
1202    return fail();
1203  }
1204
1205  let yielder = yielder.expect("no yielder").as_ref();
1206  yielder.suspend(Dispatch {
1207    memory_access_error: Some(si_addr),
1208    ..Default::default()
1209  });
1210}
1211
/// SIGUSR1 handler implementing async preemption of JIT'd program code.
///
/// If the signal lands while this thread is executing inside the active JIT
/// code zone, the program coroutine is suspended with `async_preemption`
/// set; otherwise the signal is counted and ignored.
unsafe extern "C" fn sigusr1_handler(
  _sig: i32,
  _siginfo: *mut libc::siginfo_t,
  uctx: *mut libc::ucontext_t,
) {
  // Count every delivery, including ones that don't hit JIT code.
  SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));

  let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
    if x.valid.load(Ordering::Relaxed) {
      // Keep the field reads below from being reordered before the `valid`
      // check.
      compiler_fence(Ordering::Acquire);
      Some((x.code_range.get(), x.yielder.get()))
    } else {
      None
    }
  }) else {
    return;
  };
  // Only preempt if the interrupted instruction is inside the JIT'd code.
  let pc = program_counter(uctx);
  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
    return;
  }

  let yielder = yielder.expect("no yielder").as_ref();
  yielder.suspend(Dispatch {
    async_preemption: true,
    ..Default::default()
  });
}
1240
1241unsafe fn restore_default_signal_handler(signum: i32) {
1242  let act = libc::sigaction {
1243    sa_sigaction: libc::SIG_DFL,
1244    sa_flags: libc::SA_SIGINFO,
1245    sa_mask: std::mem::zeroed(),
1246    sa_restorer: None,
1247  };
1248  if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
1249    libc::abort();
1250  }
1251}
1252
1253fn get_blocked_sigset() -> libc::sigset_t {
1254  unsafe {
1255    let mut s: libc::sigset_t = std::mem::zeroed();
1256    libc::sigaddset(&mut s, libc::SIGUSR1);
1257    libc::sigaddset(&mut s, libc::SIGSEGV);
1258    s
1259  }
1260}