// async_ebpf/program.rs

1//! This module contains the core logic for eBPF program execution. A lot of
2//! `unsafe` code is used - be careful when making changes here.
3
4use std::{
5  any::{Any, TypeId},
6  cell::{Cell, RefCell},
7  collections::HashMap,
8  ffi::CStr,
9  marker::PhantomData,
10  mem::ManuallyDrop,
11  ops::{Deref, DerefMut},
12  os::raw::c_void,
13  pin::Pin,
14  ptr::NonNull,
15  rc::Rc,
16  sync::{
17    atomic::{compiler_fence, AtomicBool, AtomicU64, Ordering},
18    mpsc::RecvTimeoutError,
19    Arc, Once,
20  },
21  thread::ThreadId,
22  time::{Duration, Instant},
23};
24
25use corosensei::{
26  stack::{DefaultStack, Stack},
27  Coroutine, CoroutineResult, ScopedCoroutine, Yielder,
28};
29use futures::{Future, FutureExt};
30use memmap2::{MmapOptions, MmapRaw};
31use rand::prelude::SliceRandom;
32use rand::SeedableRng;
33use rand_chacha::ChaCha8Rng;
34
35use crate::{
36  error::{Error, RuntimeError},
37  helpers::Helper,
38  linker::link_elf,
39  pointer_cage::PointerCage,
40  util::nonnull_bytes_overlap,
41};
42
/// Size of the host-side (native) stack each coroutine executes on.
const NATIVE_STACK_SIZE: usize = 16384;
/// Size of the eBPF shadow stack and of its save/restore copy buffer.
const SHADOW_STACK_SIZE: usize = 4096;
/// Upper bound on calldata copied onto the top of the shadow stack.
const MAX_CALLDATA_SIZE: usize = 512;
/// Maximum number of simultaneously dereferenced user-memory regions per
/// helper call (see `HelperScope::user_memory`/`user_memory_mut`).
const MAX_DEREF_REGIONS: usize = 4;
47
/// Per-invocation storage for helper state during a program run.
pub struct InvokeScope {
  data: HashMap<TypeId, Box<dyn Any + Send>>,
}

impl InvokeScope {
  /// Gets or creates typed data scoped to this invocation.
  ///
  /// The first call for a given `T` inserts `T::default()`; later calls
  /// return a mutable reference to the same value.
  pub fn data_mut<T: Default + Send + 'static>(&mut self) -> &mut T {
    let slot = self
      .data
      .entry(TypeId::of::<T>())
      .or_insert_with(|| Box::<T>::default());
    // The entry keyed by TypeId::of::<T>() only ever holds a boxed T.
    slot
      .downcast_mut::<T>()
      .expect("InvokeScope::data_mut: downcast failed")
  }
}
65
/// Context passed to helpers while a program is executing.
pub struct HelperScope<'a, 'b> {
  /// The program being executed.
  pub program: &'a Program,
  /// Mutable per-invocation data for helpers.
  pub invoke: RefCell<&'a mut InvokeScope>,
  // Caller-provided resources that helpers can borrow by type via
  // `with_resource_mut`.
  resources: RefCell<&'a mut [&'b mut dyn Any]>,
  // Regions of user memory handed out mutably during the current helper
  // call; consulted to reject later aliasing reads/writes.
  mutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
  // Same bookkeeping for immutably dereferenced regions.
  immutable_dereferenced_regions: [Cell<Option<NonNull<[u8]>>>; MAX_DEREF_REGIONS],
  // Whether `post_task` may be called from this scope.
  can_post_task: bool,
}
77
/// A validated mutable view into user memory.
///
/// Holds a borrow of the `HelperScope` that produced it, so the scope's
/// dereferenced-region bookkeeping stays alive while the view exists.
pub struct MutableUserMemory<'a, 'b, 'c> {
  _scope: &'c HelperScope<'a, 'b>,
  region: NonNull<[u8]>,
}

impl<'a, 'b, 'c> Deref for MutableUserMemory<'a, 'b, 'c> {
  type Target = [u8];

  fn deref(&self) -> &Self::Target {
    // SAFETY: `region` was validated by `HelperScope::user_memory_mut` and is
    // recorded in the scope's mutable-region table, which rejects any later
    // overlapping view while the scope is borrowed.
    unsafe { self.region.as_ref() }
  }
}

impl<'a, 'b, 'c> DerefMut for MutableUserMemory<'a, 'b, 'c> {
  fn deref_mut(&mut self) -> &mut Self::Target {
    // SAFETY: same justification as `deref`; the region was handed out for
    // writing and no aliasing view can be created afterwards.
    unsafe { self.region.as_mut() }
  }
}
97
impl<'a, 'b> HelperScope<'a, 'b> {
  /// Posts an async task to be run between timeslices.
  ///
  /// The future is parked in the thread-local `PENDING_ASYNC_TASK` slot; the
  /// executor awaits it after the shadow stack has been saved, then invokes
  /// the resulting closure with a fresh `HelperScope` to write the result
  /// back into machine state (see `Program::_run`).
  ///
  /// # Panics
  /// Panics if posting is not allowed in this scope (`can_post_task` is
  /// false) or if another task is already pending.
  pub fn post_task(
    &self,
    task: impl Future<Output = impl FnOnce(&HelperScope) -> Result<u64, ()> + 'static> + 'static,
  ) {
    if !self.can_post_task {
      panic!("HelperScope::post_task() called in a context where posting task is not allowed");
    }

    PENDING_ASYNC_TASK.with(|x| {
      let mut x = x.borrow_mut();
      if x.is_some() {
        panic!("post_task called while another task is pending");
      }
      // Erase the concrete output type so the executor can store and await it.
      *x = Some(async move { Box::new(task.await) as AsyncTaskOutput }.boxed_local());
    });
  }

  /// Calls `callback` with a mutable resource of type `T`, if present.
  ///
  /// The first caller-supplied resource that downcasts to `T` is used; if
  /// none matches, a warning is logged and the callback receives `Err(())`.
  pub fn with_resource_mut<'c, T: 'static, R>(
    &'c self,
    callback: impl FnOnce(Result<&mut T, ()>) -> R,
  ) -> R {
    let mut resources = self.resources.borrow_mut();
    let Some(res) = resources
      .iter_mut()
      .filter_map(|x| x.downcast_mut::<T>())
      .next()
    else {
      tracing::warn!(resource_type = ?TypeId::of::<T>(), "resource not found");
      return callback(Err(()));
    };

    callback(Ok(res))
  }

  /// Validates and returns an immutable view into user memory.
  ///
  /// The `(ptr, size)` pair is bounds-checked against the pointer cage, and
  /// the resulting region is recorded so later mutable dereferences that
  /// would alias it are rejected. Immutable views may overlap each other but
  /// not a previously handed-out mutable view. At most `MAX_DEREF_REGIONS`
  /// non-empty reads can be recorded per helper call; zero-sized reads are
  /// not tracked.
  pub fn user_memory(&self, ptr: u64, size: u64) -> Result<&[u8], ()> {
    let Some(region) = self
      .program
      .unbound
      .cage
      .safe_deref_for_read(ptr as usize, size as usize)
    else {
      tracing::warn!(ptr, size, "invalid read");
      return Err(());
    };

    if size != 0 {
      // The region must not overlap with any previously dereferenced mutable regions
      if self
        .mutable_dereferenced_regions
        .iter()
        .filter_map(|x| x.get())
        .any(|x| nonnull_bytes_overlap(x, region))
      {
        tracing::warn!(ptr, size, "read overlapped with previous write");
        return Err(());
      }

      // Find a slot to record this dereference
      let Some(slot) = self
        .immutable_dereferenced_regions
        .iter()
        .find(|x| x.get().is_none())
      else {
        tracing::warn!(ptr, size, "too many reads");
        return Err(());
      };
      slot.set(Some(region));
    }

    // SAFETY: the cage validated the region, and the bookkeeping above
    // guarantees no mutable view aliases it for the lifetime of `&self`.
    Ok(unsafe { region.as_ref() })
  }

  /// Validates and returns a mutable view into user memory.
  ///
  /// Like `user_memory`, but the region must not overlap *any* previously
  /// dereferenced region (mutable or immutable), because the returned view
  /// grants exclusive access to those bytes.
  pub fn user_memory_mut<'c>(
    &'c self,
    ptr: u64,
    size: u64,
  ) -> Result<MutableUserMemory<'a, 'b, 'c>, ()> {
    let Some(region) = self
      .program
      .unbound
      .cage
      .safe_deref_for_write(ptr as usize, size as usize)
    else {
      tracing::warn!(ptr, size, "invalid write");
      return Err(());
    };

    if size != 0 {
      // The region must not overlap with any other previously dereferenced mutable or immutable regions
      if self
        .mutable_dereferenced_regions
        .iter()
        .chain(self.immutable_dereferenced_regions.iter())
        .filter_map(|x| x.get())
        .any(|x| nonnull_bytes_overlap(x, region))
      {
        tracing::warn!(ptr, size, "write overlapped with previous read/write");
        return Err(());
      }

      // Find a slot to record this dereference
      let Some(slot) = self
        .mutable_dereferenced_regions
        .iter()
        .find(|x| x.get().is_none())
      else {
        tracing::warn!(ptr, size, "too many writes");
        return Err(());
      };
      slot.set(Some(region));
    }

    Ok(MutableUserMemory {
      _scope: self,
      region,
    })
  }
}
221
// Wrapper that unconditionally asserts `Send` for its contents.
//
// SAFETY(review): soundness relies on each use site only moving the wrapped
// value across await points while the coroutine holding the aliased data is
// suspended (see usage in `Program::_run`); this cannot be checked here, so
// every new use needs its own justification.
#[derive(Copy, Clone)]
struct AssumeSend<T>(T);
unsafe impl<T> Send for AssumeSend<T> {}

// Reusable per-run execution state: the host stack the coroutine runs on,
// plus the buffer the shadow stack is saved to / restored from around yields.
struct ExecContext {
  native_stack: DefaultStack,
  copy_stack: Box<[u8; SHADOW_STACK_SIZE]>,
}
230
231impl ExecContext {
232  fn new() -> Self {
233    Self {
234      native_stack: DefaultStack::new(NATIVE_STACK_SIZE)
235        .expect("failed to initialize native stack"),
236      copy_stack: Box::new([0u8; SHADOW_STACK_SIZE]),
237    }
238  }
239}
240
/// A pending async task spawned by a helper.
pub type PendingAsyncTask = Pin<Box<dyn Future<Output = AsyncTaskOutput>>>;
/// The callback produced by a helper async task when it resumes.
///
/// Invoked with a fresh `HelperScope`; its `Ok(u64)` becomes the value the
/// coroutine is resumed with.
pub type AsyncTaskOutput = Box<dyn FnOnce(&HelperScope) -> Result<u64, ()>>;

// Monotonically increasing source of program ids, starting at 1.
// NOTE(review): the consumer is not visible in this chunk — presumably
// fetch_add'ed when a program is loaded; confirm in `_load`.
static NEXT_PROGRAM_ID: AtomicU64 = AtomicU64::new(1);
247
248thread_local! {
249  static RUST_TID: ThreadId = std::thread::current().id();
250  static SIGUSR1_COUNTER: Cell< u64> = Cell::new(0);
251  static ACTIVE_JIT_CODE_ZONE: ActiveJitCodeZone = ActiveJitCodeZone::default();
252  static EXEC_CONTEXT_POOL: RefCell<Vec<ExecContext>> = Default::default();
253  static PENDING_ASYNC_TASK: RefCell<Option<PendingAsyncTask>> = RefCell::new(None);
254}
255
// An `ExecContext` checked out of the thread-local pool; returned on drop.
struct BorrowedExecContext {
  ctx: ManuallyDrop<ExecContext>,
}

impl BorrowedExecContext {
  // Takes a context from `EXEC_CONTEXT_POOL` (allocating a fresh one if the
  // pool is empty) and fills its shadow-stack buffer with random bytes —
  // presumably so a program cannot observe residue from a previous run.
  fn new(stack_rng: &mut impl rand::Rng) -> Self {
    let mut me = Self {
      ctx: ManuallyDrop::new(
        EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().pop().unwrap_or_else(ExecContext::new)),
      ),
    };
    stack_rng.fill(&mut *me.ctx.copy_stack);
    me
  }
}

impl Drop for BorrowedExecContext {
  fn drop(&mut self) {
    // SAFETY: `ctx` is taken out exactly once, here, while `self` is being
    // dropped and will never be used again.
    let ctx = unsafe { ManuallyDrop::take(&mut self.ctx) };
    EXEC_CONTEXT_POOL.with(|x| x.borrow_mut().push(ctx));
  }
}
278
// Per-thread description of the JIT code currently allowed to run, published
// for asynchronous observers. `Program::_run` stores `valid = true` behind a
// Release compiler fence before resuming the coroutine and clears it after —
// NOTE(review): presumably read by the SIGUSR1/SIGSEGV handlers (not visible
// in this chunk) to decide whether a fault/preemption belongs to JIT code.
#[derive(Default)]
struct ActiveJitCodeZone {
  // True only while the fields below describe live JIT execution.
  valid: AtomicBool,
  // Start/end host addresses of the current entrypoint's native code.
  code_range: Cell<(usize, usize)>,
  // Protected address range of the pointer cage, excluding guard margins.
  pointer_cage_protected_range: Cell<(usize, usize)>,
  // Yielder through which the coroutine can be suspended with a `Dispatch`.
  yielder: Cell<Option<NonNull<Yielder<u64, Dispatch>>>>,
}
286
/// Hooks for observing program execution events.
///
/// All methods have no-op defaults, so implementors override only what they
/// need.
pub trait ProgramEventListener: Send + Sync + 'static {
  /// Called after an async preemption is triggered.
  fn did_async_preempt(&self, _scope: &HelperScope) {}
  /// Called after yielding back to the async runtime.
  fn did_yield(&self) {}
  /// Called after throttling a program's execution.
  ///
  /// May return a future that the executor awaits before resuming the
  /// program.
  fn did_throttle(&self, _scope: &HelperScope) -> Option<Pin<Box<dyn Future<Output = ()>>>> {
    None
  }
  /// Called after saving the shadow stack before yielding.
  fn did_save_shadow_stack(&self) {}
  /// Called after restoring the shadow stack on resume.
  fn did_restore_shadow_stack(&self) {}
}

/// No-op event listener implementation.
pub struct DummyProgramEventListener;
impl ProgramEventListener for DummyProgramEventListener {}
306
/// Prepares helper tables and loads eBPF programs.
pub struct ProgramLoader {
  // Maps a helper name to the obfuscated i32 call immediate the linker
  // patches into call sites (see `ProgramLoader::new` for the encoding).
  helpers_inverse: HashMap<&'static str, i32>,
  // Listener shared with every program loaded by this loader.
  event_listener: Arc<dyn ProgramEventListener>,
  // Per-loader secret xored into helper indexes.
  helper_id_xor: u16,
  // Helper table: (entropy, name, function); dispatch indexes it with
  // ((index & 0xffff) ^ helper_id_xor) - 1.
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
}

/// A loaded program that is not yet pinned to a thread.
pub struct UnboundProgram {
  // Unique program id.
  id: u64,
  // Keeps the JIT code mapping alive for the program's lifetime.
  _code_mem: MmapRaw,
  // Sandboxed memory region (data + shadow stack) the program may touch.
  cage: PointerCage,
  // Copied from the loader; needed to decode helper indexes at dispatch time.
  helper_id_xor: u16,
  helpers: Arc<Vec<(u16, &'static str, Helper)>>,
  event_listener: Arc<dyn ProgramEventListener>,
  // Section name -> native code location for each JIT-compiled entrypoint.
  entrypoints: HashMap<String, Entrypoint>,
  // Seed for the per-thread RNG that scrambles borrowed shadow-stack buffers.
  stack_rng_seed: [u8; 32],
}

/// A program pinned to a specific thread and ready to execute.
pub struct Program {
  unbound: UnboundProgram,
  // Lazily created, type-keyed shared state (see `Program::data`).
  data: RefCell<HashMap<TypeId, Rc<dyn Any>>>,
  // Deterministic RNG seeded from `stack_rng_seed`; used by
  // `BorrowedExecContext::new` to fill the shadow-stack buffer.
  stack_rng: RefCell<ChaCha8Rng>,
}

// Native code location of a single JIT-compiled entrypoint.
#[derive(Copy, Clone)]
struct Entrypoint {
  code_ptr: usize,
  code_len: usize,
}
339
/// Time limits used to yield or throttle execution.
#[derive(Clone, Debug)]
pub struct TimesliceConfig {
  /// Maximum runtime before yielding to the async scheduler.
  pub max_run_time_before_yield: Duration,
  /// Maximum runtime before a throttle sleep is forced.
  pub max_run_time_before_throttle: Duration,
  /// Duration of the throttle sleep once triggered.
  pub throttle_duration: Duration,
}

/// Async runtime integration for yielding and sleeping.
pub trait Timeslicer {
  /// Sleep for the provided duration.
  fn sleep(&self, duration: Duration) -> impl Future<Output = ()>;
  /// Yield to the async scheduler.
  fn yield_now(&self) -> impl Future<Output = ()>;
}

/// Global runtime environment for signal handlers.
///
/// Zero-sized token proving `GlobalEnv::new` has run (handlers installed).
#[derive(Copy, Clone)]
pub struct GlobalEnv(());

/// Per-thread runtime environment for preemption handling.
///
/// Token proving `GlobalEnv::init_thread` has run on the current thread.
#[derive(Copy, Clone)]
pub struct ThreadEnv {
  // Raw-pointer PhantomData makes this type !Send + !Sync, so the token
  // cannot leave the thread it was created on.
  _not_send_sync: std::marker::PhantomData<*const ()>,
}
368
impl GlobalEnv {
  /// Initializes global state and installs signal handlers.
  ///
  /// Idempotent: the handlers are installed at most once per process.
  ///
  /// # Safety
  /// Must be called in a process that can install SIGUSR1/SIGSEGV handlers.
  pub unsafe fn new() -> Self {
    static INIT: Once = Once::new();

    // SIGUSR1 must be blocked during exception handling
    // Otherwise it seems that Linux gives up and throws an uncatchable SI_KERNEL SIGSEGV:
    //
    // [pid 517110] tgkill(517109, 517112, SIGUSR1 <unfinished ...>
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SEGV_ACCERR, si_addr=0x793789be227e} ---
    // [pid 517109] write(15, "\1\0\0\0\0\0\0\0", 8 <unfinished ...>
    // [pid 517110] <... tgkill resumed>)      = 0
    // [pid 517109] <... write resumed>)       = 8
    // [pid 517112] --- SIGUSR1 {si_signo=SIGUSR1, si_code=SI_TKILL, si_pid=517109, si_uid=1000} ---
    // [pid 517109] recvfrom(57,  <unfinished ...>
    // [pid 517110] futex(0x79378abff5a8, FUTEX_WAIT_PRIVATE, 1, {tv_sec=0, tv_nsec=5999909} <unfinished ...>
    // [pid 517109] <... recvfrom resumed>"GET /write_rodata HTTP/1.1\r\nHost"..., 8192, 0, NULL, NULL) = 61
    // [pid 517112] --- SIGSEGV {si_signo=SIGSEGV, si_code=SI_KERNEL, si_addr=NULL} ---

    INIT.call_once(|| {
      let sa_mask = get_blocked_sigset();

      // Install both handlers with SA_SIGINFO (handlers receive a siginfo_t)
      // and with the mask above, so SIGUSR1 stays blocked while they run.
      for (sig, handler) in [
        (libc::SIGUSR1, sigusr1_handler as usize),
        (libc::SIGSEGV, sigsegv_handler as usize),
      ] {
        let act = libc::sigaction {
          sa_sigaction: handler,
          sa_flags: libc::SA_SIGINFO,
          sa_mask,
          sa_restorer: None,
        };
        if libc::sigaction(sig, &act, std::ptr::null_mut()) != 0 {
          panic!("failed to setup handler for signal {}", sig);
        }
      }
    });

    Self(())
  }

  /// Initializes per-thread state and starts the async preemption watcher.
  ///
  /// Spawns a watcher thread that sends SIGUSR1 to the calling thread every
  /// `async_preemption_interval` until this thread exits. Idempotent per
  /// thread: subsequent calls return immediately.
  pub fn init_thread(self, async_preemption_interval: Duration) -> ThreadEnv {
    // Stored in TLS; when the owning thread exits, dropping this closes the
    // shutdown channel and then blocks until the watcher thread confirms exit
    // by dropping its completion sender. This guarantees no SIGUSR1 is
    // delivered to a tid that has already been reused.
    struct DeferDrop(
      Option<std::sync::mpsc::Sender<()>>,
      std::sync::mpsc::Receiver<()>,
    );
    impl Drop for DeferDrop {
      fn drop(&mut self) {
        self.0.take();
        let _ = self.1.recv();
        // tracing::info! uses TLS state and it is not allowed during thread exit
      }
    }

    thread_local! {
      static WATCHER: RefCell<Option<DeferDrop>> = RefCell::new(None);
    }

    if WATCHER.with(|x| x.borrow().is_some()) {
      return ThreadEnv {
        _not_send_sync: PhantomData,
      };
    }

    unsafe {
      let tgid = libc::getpid();
      let tid = libc::gettid();
      let (watcher_completion_tx, watcher_completion_rx) = std::sync::mpsc::channel::<()>();
      let (shutdown_tx, shutdown_rx) = std::sync::mpsc::channel::<()>();

      std::thread::Builder::new()
        .name("preempt-watcher".to_string())
        .spawn(move || {
          // Dropped when this closure returns, which unblocks `DeferDrop`.
          let _watcher_completion_tx = watcher_completion_tx;
          loop {
            // A timeout means "keep running"; a message or a disconnected
            // channel means the owning thread is shutting down.
            if !matches!(
              shutdown_rx.recv_timeout(async_preemption_interval),
              Err(RecvTimeoutError::Timeout)
            ) {
              break;
            }
            // tgkill targets the specific kernel thread that owns us.
            let ret = libc::syscall(libc::SYS_tgkill, tgid, tid, libc::SIGUSR1);
            if ret != 0 {
              break;
            }
          }
        })
        .expect("failed to spawn preemption watcher");

      WATCHER.with(|x| {
        x.borrow_mut()
          .replace(DeferDrop(Some(shutdown_tx), watcher_completion_rx));
      });

      ThreadEnv {
        _not_send_sync: PhantomData,
      }
    }
  }
}
473
474impl UnboundProgram {
475  /// Pins the program to the current thread using a prepared `ThreadEnv`.
476  pub fn pin_to_current_thread(self, _: ThreadEnv) -> Program {
477    let stack_rng = RefCell::new(ChaCha8Rng::from_seed(self.stack_rng_seed));
478    Program {
479      unbound: self,
480      data: RefCell::new(HashMap::new()),
481      stack_rng,
482    }
483  }
484}
485
486impl Program {
  /// Returns the unique program identifier.
  ///
  /// Copied from the underlying `UnboundProgram` (presumably assigned from
  /// `NEXT_PROGRAM_ID` at load time — assignment not visible in this chunk).
  pub fn id(&self) -> u64 {
    self.unbound.id
  }
491
492  /// Gets or creates shared typed data for this program instance.
493  pub fn data<T: Default + 'static>(&self) -> Rc<T> {
494    let mut data = self.data.borrow_mut();
495    let entry = data.entry(TypeId::of::<T>());
496    let entry = entry.or_insert_with(|| Rc::new(T::default()));
497    entry.clone().downcast().unwrap()
498  }
499
  /// Returns `true` if the loaded ELF contained an entrypoint with this
  /// section name.
  pub fn has_section(&self, name: &str) -> bool {
    self.unbound.entrypoints.contains_key(name)
  }
503
504  /// Runs the program entrypoint with the provided resources and calldata.
505  pub async fn run(
506    &self,
507    timeslice: &TimesliceConfig,
508    timeslicer: &impl Timeslicer,
509    entrypoint: &str,
510    resources: &mut [&mut dyn Any],
511    calldata: &[u8],
512  ) -> Result<i64, Error> {
513    self
514      ._run(timeslice, timeslicer, entrypoint, resources, calldata)
515      .await
516      .map_err(Error)
517  }
518
  /// Drives one invocation of a JIT-compiled entrypoint to completion.
  ///
  /// The JIT code runs on a coroutine with its own native stack; each
  /// `resume` executes one timeslice that ends when the code returns, calls a
  /// helper, or is preempted by a signal. Between timeslices this loop
  /// dispatches helpers, enforces the timeslice/throttle budget, and
  /// saves/restores the shadow stack around yields to the async runtime.
  async fn _run(
    &self,
    timeslice: &TimesliceConfig,
    timeslicer: &impl Timeslicer,
    entrypoint: &str,
    resources: &mut [&mut dyn Any],
    calldata: &[u8],
  ) -> Result<i64, RuntimeError> {
    let Some(entrypoint) = self.unbound.entrypoints.get(entrypoint).copied() else {
      return Err(RuntimeError::InvalidArgument("entrypoint not found"));
    };

    // SAFETY: `code_ptr` points at JIT-compiled native code with this exact
    // calling convention, produced at load time and kept alive by `_code_mem`.
    let entry = unsafe {
      std::mem::transmute::<_, unsafe extern "C" fn(ctx: usize, shadow_stack: usize) -> u64>(
        entrypoint.code_ptr,
      )
    };
    struct CoDropper<'a, Input, Yield, Return, DefaultStack: Stack>(
      ScopedCoroutine<'a, Input, Yield, Return, DefaultStack>,
    );
    impl<'a, Input, Yield, Return, DefaultStack: Stack> Drop
      for CoDropper<'a, Input, Yield, Return, DefaultStack>
    {
      fn drop(&mut self) {
        // Prevent the coroutine library from attempting to unwind the stack of the coroutine
        // and run destructors, because this stack might be running a signal handler and
        // it's not allowed to unwind from there.
        //
        // SAFETY: The coroutine stack only contains stack frames for JIT-compiled code and
        // carefully chosen Rust code that do not hold Droppable values, so it's safe to
        // skip destructors.
        unsafe {
          self.0.force_reset();
        }
      }
    }

    let mut ectx = BorrowedExecContext::new(&mut *self.stack_rng.borrow_mut());

    if calldata.len() > MAX_CALLDATA_SIZE {
      return Err(RuntimeError::InvalidArgument("calldata too large"));
    }
    // Calldata is placed at the very top of the (randomized) shadow stack;
    // the entrypoint receives a pointer just below it.
    ectx.ctx.copy_stack[SHADOW_STACK_SIZE - calldata.len()..].copy_from_slice(calldata);
    let calldata_len = calldata.len();

    let program_ret: u64 = {
      let shadow_stack_top = self.unbound.cage.stack_top();
      // Host-side pointer to the cage's shadow-stack region, used for the
      // save/restore memcpys below. AssumeSend: only touched while the
      // coroutine is suspended on this task.
      let shadow_stack_ptr = AssumeSend(
        self
          .unbound
          .cage
          .safe_deref_for_write(self.unbound.cage.stack_bottom(), SHADOW_STACK_SIZE)
          .unwrap(),
      );
      let ctx = &mut *ectx.ctx;

      let mut co = AssumeSend(CoDropper(Coroutine::with_stack(
        &mut ctx.native_stack,
        move |yielder, _input| unsafe {
          // Publish the yielder so signal handlers can suspend this coroutine.
          ACTIVE_JIT_CODE_ZONE.with(|x| {
            x.yielder.set(NonNull::new(yielder as *const _ as *mut _));
          });
          // Both the ctx argument and the initial shadow stack pointer start
          // just below the calldata.
          entry(
            shadow_stack_top - calldata_len,
            shadow_stack_top - calldata_len,
          )
        },
      )));

      let mut last_yield_time = Instant::now();
      let mut last_throttle_time = Instant::now();
      let mut yielder: Option<AssumeSend<NonNull<Yielder<u64, Dispatch>>>> = None;
      // Value the coroutine is resumed with (a helper's return value).
      let mut resume_input: u64 = 0;
      let mut did_throttle = false;
      // True whenever the authoritative shadow stack lives in `copy_stack`
      // rather than in the cage (initially, and after every yield).
      let mut shadow_stack_saved = true;
      let mut rust_tid_sigusr1_counter = (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
      // Result of a completed helper async task, applied on the next iteration.
      let mut prev_async_task_output: Option<(&'static str, AsyncTaskOutput)> = None;
      let mut invoke_scope = InvokeScope {
        data: HashMap::new(),
      };

      loop {
        // Publish the zone descriptor before entering JIT code; the Release
        // fence orders the field stores before the `valid` flag flip.
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.code_range.set((
            entrypoint.code_ptr,
            entrypoint.code_ptr + entrypoint.code_len,
          ));
          x.yielder.set(yielder.map(|x| x.0));
          x.pointer_cage_protected_range
            .set(self.unbound.cage.protected_range_without_margins());
          compiler_fence(Ordering::Release);
          x.valid.store(true, Ordering::Relaxed);
        });

        if shadow_stack_saved {
          shadow_stack_saved = false;

          // restore shadow stack
          // SAFETY: both buffers are exactly SHADOW_STACK_SIZE bytes and
          // belong to distinct allocations (heap box vs. cage mapping).
          unsafe {
            std::ptr::copy_nonoverlapping(
              ctx.copy_stack.as_ptr() as *const u8,
              shadow_stack_ptr.0.as_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }

          self.unbound.event_listener.did_restore_shadow_stack();
        }

        // If the previous iteration wants to write back to machine state
        if let Some((helper_name, prev_async_task_output)) = prev_async_task_output.take() {
          // SAFETY(zeroed): all-zero bytes are a valid value for an array of
          // Cell<Option<NonNull<[u8]>>> (null data pointer encodes None).
          // NOTE(review): fragile — Default::default() would not rely on
          // layout.
          resume_input = prev_async_task_output(&HelperScope {
            program: self,
            invoke: RefCell::new(&mut invoke_scope),
            resources: RefCell::new(resources),
            mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
            can_post_task: false,
          })
          .map_err(|_| RuntimeError::AsyncHelperError(helper_name))?;
        }

        // Run one timeslice of JIT code.
        let ret = co.0 .0.resume(resume_input);
        // Invalidate the zone before doing anything else, and capture the
        // (possibly updated) yielder for the next resume.
        ACTIVE_JIT_CODE_ZONE.with(|x| {
          x.valid.store(false, Ordering::Relaxed);
          compiler_fence(Ordering::Release);
          yielder = x.yielder.get().map(AssumeSend);
        });

        let dispatch: Dispatch = match ret {
          CoroutineResult::Return(x) => break x,
          CoroutineResult::Yield(x) => x,
        };

        // restore signal mask of current thread
        // (the signal handler suspended the coroutine with SIGUSR1 blocked;
        // unblock it now that we are back on the scheduler stack)
        if dispatch.memory_access_error.is_some() || dispatch.async_preemption {
          unsafe {
            let unblock = get_blocked_sigset();
            libc::sigprocmask(libc::SIG_UNBLOCK, &unblock, std::ptr::null_mut());
          }
        }

        if let Some(si_addr) = dispatch.memory_access_error {
          // Report the fault in guest (cage-relative) address space.
          let vaddr = si_addr - self.unbound.cage.offset();
          return Err(RuntimeError::MemoryFault(vaddr));
        }

        // Clear pending task if something else has set it
        PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());
        let mut helper_name: &'static str = "";

        // SAFETY(zeroed): see the note on the HelperScope construction above.
        let mut helper_scope = HelperScope {
          program: self,
          invoke: RefCell::new(&mut invoke_scope),
          resources: RefCell::new(resources),
          mutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          immutable_dereferenced_regions: unsafe { std::mem::zeroed() },
          can_post_task: false,
        };

        if dispatch.async_preemption {
          self
            .unbound
            .event_listener
            .did_async_preempt(&mut helper_scope);
        } else {
          // validator should ensure all helper indexes are present in the table
          // Decode: low 16 bits of the dispatch index, un-xored with the
          // per-loader secret, minus 1 (ids are 1-based; see ProgramLoader::new).
          let Some((_, got_helper_name, helper)) = self
            .unbound
            .helpers
            .get(
              ((dispatch.index & 0xffff) as u16 ^ self.unbound.helper_id_xor).wrapping_sub(1)
                as usize,
            )
            .copied()
          else {
            panic!("unknown helper index: {}", dispatch.index);
          };
          helper_name = got_helper_name;

          // Helpers may post at most one async task while they run.
          helper_scope.can_post_task = true;
          resume_input = helper(
            &mut helper_scope,
            dispatch.arg1,
            dispatch.arg2,
            dispatch.arg3,
            dispatch.arg4,
            dispatch.arg5,
          )
          .map_err(|()| RuntimeError::HelperError(helper_name))?;
          helper_scope.can_post_task = false;
        }

        let pending_async_task = PENDING_ASYNC_TASK.with(|x| x.borrow_mut().take());

        // Fast path: do not read timestamp if no thread migration or async preemption happened
        let new_rust_tid_sigusr1_counter =
          (RUST_TID.with(|x| *x), SIGUSR1_COUNTER.with(|x| x.get()));
        if new_rust_tid_sigusr1_counter == rust_tid_sigusr1_counter && pending_async_task.is_none()
        {
          continue;
        }

        rust_tid_sigusr1_counter = new_rust_tid_sigusr1_counter;

        let now = Instant::now();
        let should_throttle = now > last_throttle_time
          && now.duration_since(last_throttle_time) >= timeslice.max_run_time_before_throttle;
        let should_yield = now > last_yield_time
          && now.duration_since(last_yield_time) >= timeslice.max_run_time_before_yield;
        if should_throttle || should_yield || pending_async_task.is_some() {
          // We are about to yield control to tokio. Save the shadow stack, and release the guard.
          shadow_stack_saved = true;
          // SAFETY: same buffers and sizes as the restore copy above.
          unsafe {
            std::ptr::copy_nonoverlapping(
              shadow_stack_ptr.0.as_ptr() as *const u8,
              ctx.copy_stack.as_mut_ptr() as *mut u8,
              SHADOW_STACK_SIZE,
            );
          }
          self.unbound.event_listener.did_save_shadow_stack();

          // we are now free to give up control of current thread to other async tasks

          if should_throttle {
            if !did_throttle {
              // Warn only once per run.
              did_throttle = true;
              tracing::warn!("throttling program");
            }
            timeslicer.sleep(timeslice.throttle_duration).await;
            let now = Instant::now();
            last_throttle_time = now;
            last_yield_time = now;
            let task = self.unbound.event_listener.did_throttle(&mut helper_scope);
            if let Some(task) = task {
              task.await;
            }
          } else if should_yield {
            timeslicer.yield_now().await;
            let now = Instant::now();
            last_yield_time = now;
            self.unbound.event_listener.did_yield();
          }

          // Now we have released all exclusive resources and can safely execute the async task
          if let Some(pending_async_task) = pending_async_task {
            let async_start = Instant::now();
            prev_async_task_output = Some((helper_name, pending_async_task.await));
            // Time spent awaiting the task does not count against the
            // program's own run-time budget.
            let async_dur = async_start.elapsed();
            last_throttle_time += async_dur;
            last_yield_time += async_dur;
          }
        }
      }
    };

    Ok(program_ret as i64)
  }
777}
778
// RAII owner of a `ubpf_vm` handle.
struct Vm(NonNull<crate::ubpf::ubpf_vm>);

impl Vm {
  // Creates a uBPF VM configured for this pointer cage: built-in bounds
  // checks are disabled (the cage's pointer mask/offset confine accesses
  // instead) and the JIT shadow stack is enabled.
  fn new(cage: &PointerCage) -> Self {
    let vm = NonNull::new(unsafe { crate::ubpf::ubpf_create() }).expect("failed to create ubpf_vm");
    // SAFETY: `vm` is a live handle just returned by ubpf_create.
    unsafe {
      crate::ubpf::ubpf_toggle_bounds_check(vm.as_ptr(), false);
      crate::ubpf::ubpf_toggle_jit_shadow_stack(vm.as_ptr(), true);
      crate::ubpf::ubpf_set_jit_pointer_mask_and_offset(vm.as_ptr(), cage.mask(), cage.offset());
    }
    Self(vm)
  }
}
792
impl Drop for Vm {
  fn drop(&mut self) {
    // SAFETY: `self.0` was obtained from `ubpf_create` and is destroyed
    // exactly once, here.
    unsafe {
      crate::ubpf::ubpf_destroy(self.0.as_ptr());
    }
  }
}
800
801impl ProgramLoader {
  /// Creates a new `ProgramLoader` to load eBPF code.
  ///
  /// Builds a randomized helper table: helper ordering is shuffled and ids
  /// are combined with a per-loader secret, presumably to make helper
  /// numbering unpredictable across processes (the consumer of the entropy
  /// bits is not visible here — confirm against the validator/dispatcher).
  /// `helpers_inverse` maps a helper name to the i32 immediate the linker
  /// patches into call sites: 15 bits of entropy in the upper half and
  /// `(index + 1) ^ helper_id_xor` in the low 16 bits, matching the decode
  /// in `Program::_run`.
  pub fn new(
    rng: &mut impl rand::Rng,
    event_listener: Arc<dyn ProgramEventListener>,
    raw_helpers: &[&[(&'static str, Helper)]],
  ) -> Self {
    let helper_id_xor = rng.gen::<u16>();
    let mut helpers_inverse: HashMap<&'static str, i32> = HashMap::new();
    // Collect first to a HashMap then to a Vec to deduplicate
    let mut shuffled_helpers = raw_helpers
      .iter()
      .flat_map(|x| x.iter().copied())
      .collect::<HashMap<_, _>>()
      .into_iter()
      .collect::<Vec<_>>();
    shuffled_helpers.shuffle(rng);
    let mut helpers: Vec<(u16, &'static str, Helper)> = Vec::with_capacity(shuffled_helpers.len());

    // Ids are 1-based and must fit in the 16-bit id space.
    assert!(shuffled_helpers.len() <= 65535);

    for (i, (name, helper)) in shuffled_helpers.into_iter().enumerate() {
      // Top bit kept clear so the composed i32 immediate stays non-negative.
      let entropy = rng.gen::<u16>() & 0x7fff;
      helpers.push((entropy, name, helper));
      helpers_inverse.insert(
        name,
        (((entropy as usize) << 16) | ((i + 1) ^ (helper_id_xor as usize))) as i32,
      );
    }

    tracing::info!(?helpers_inverse, "generated helper table");
    Self {
      helper_id_xor,
      helpers: Arc::new(helpers),
      helpers_inverse,
      event_listener,
    }
  }
839
840  /// Loads an ELF image into a new `UnboundProgram`.
841  pub fn load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, Error> {
842    self._load(rng, elf).map_err(Error)
843  }
844
845  fn _load(&self, rng: &mut impl rand::Rng, elf: &[u8]) -> Result<UnboundProgram, RuntimeError> {
846    let start_time = Instant::now();
847    let cage = PointerCage::new(rng, SHADOW_STACK_SIZE, elf.len())?;
848    let vm = Vm::new(&cage);
849
850    // Relocate ELF
851    let code_sections = {
852      // XXX: Although we are writing to the data region, we need to use `safe_deref_for_read`
853      // here because the `_write` variant checks that the requested region is within the
854      // stack. It's safe here because `freeze_data` is not yet called.
855      let mut data = cage
856        .safe_deref_for_read(cage.data_bottom(), elf.len())
857        .unwrap();
858      let data = unsafe { data.as_mut() };
859      data.copy_from_slice(elf);
860
861      link_elf(data, cage.data_bottom(), &self.helpers_inverse).map_err(RuntimeError::Linker)?
862    };
863    cage.freeze_data();
864
865    let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
866    if page_size < 0 {
867      return Err(RuntimeError::PlatformError("failed to get page size"));
868    }
869    let page_size = page_size as usize;
870
871    // Allocate code memory
872    let guard_size_before = rng.gen_range(16..128) * page_size;
873    let mut guard_size_after = rng.gen_range(16..128) * page_size;
874
875    let code_len_allocated: usize = 65536;
876    let code_mem = MmapRaw::from(
877      MmapOptions::new()
878        .len(code_len_allocated + guard_size_before + guard_size_after)
879        .map_anon()
880        .map_err(|_| RuntimeError::PlatformError("failed to allocate code memory"))?,
881    );
882
883    unsafe {
884      if crate::ubpf::ubpf_register_external_dispatcher(
885        vm.0.as_ptr(),
886        Some(tls_dispatcher),
887        Some(std_validator),
888        self as *const _ as *mut c_void,
889      ) != 0
890      {
891        return Err(RuntimeError::PlatformError(
892          "ubpf: failed to register external dispatcher",
893        ));
894      }
895      if libc::mprotect(
896        code_mem.as_mut_ptr() as *mut _,
897        guard_size_before,
898        libc::PROT_NONE,
899      ) != 0
900        || libc::mprotect(
901          code_mem
902            .as_mut_ptr()
903            .offset((guard_size_before + code_len_allocated) as isize) as *mut _,
904          guard_size_after,
905          libc::PROT_NONE,
906        ) != 0
907      {
908        return Err(RuntimeError::PlatformError("failed to protect guard pages"));
909      }
910    }
911
912    let mut entrypoints: HashMap<String, Entrypoint> = HashMap::new();
913
914    unsafe {
915      // Translate eBPF to native code
916      let mut code_slice = std::slice::from_raw_parts_mut(
917        code_mem.as_mut_ptr().offset(guard_size_before as isize),
918        code_len_allocated,
919      );
920      for (section_name, code_vaddr_size) in code_sections {
921        if code_slice.is_empty() {
922          return Err(RuntimeError::InvalidArgument(
923            "no space left for jit compilation",
924          ));
925        }
926
927        crate::ubpf::ubpf_unload_code(vm.0.as_ptr());
928
929        let mut errmsg_ptr = std::ptr::null_mut();
930        let code = cage
931          .safe_deref_for_read(code_vaddr_size.0, code_vaddr_size.1)
932          .unwrap();
933        let ret = crate::ubpf::ubpf_load(
934          vm.0.as_ptr(),
935          code.as_ptr() as *const _,
936          code.len() as u32,
937          &mut errmsg_ptr,
938        );
939        if ret != 0 {
940          let errmsg = if errmsg_ptr.is_null() {
941            "".to_string()
942          } else {
943            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
944          };
945          if !errmsg_ptr.is_null() {
946            libc::free(errmsg_ptr as _);
947          }
948          tracing::error!(section_name, error = errmsg, "failed to load code");
949          return Err(RuntimeError::PlatformError("ubpf: code load failed"));
950        }
951
952        let mut written_len = code_slice.len();
953        let ret = crate::ubpf::ubpf_translate(
954          vm.0.as_ptr(),
955          code_slice.as_mut_ptr(),
956          &mut written_len,
957          &mut errmsg_ptr,
958        );
959        if ret != 0 {
960          let errmsg = if errmsg_ptr.is_null() {
961            "".to_string()
962          } else {
963            CStr::from_ptr(errmsg_ptr).to_string_lossy().into_owned()
964          };
965          if !errmsg_ptr.is_null() {
966            libc::free(errmsg_ptr as _);
967          }
968          tracing::error!(section_name, error = errmsg, "failed to translate code");
969          return Err(RuntimeError::PlatformError("ubpf: code translation failed"));
970        }
971
972        assert!(written_len <= code_slice.len());
973        entrypoints.insert(
974          section_name,
975          Entrypoint {
976            code_ptr: code_mem.as_ptr() as usize + guard_size_before + code_len_allocated
977              - code_slice.len(),
978            code_len: written_len,
979          },
980        );
981        code_slice = &mut code_slice[written_len..];
982      }
983
984      // Align up code_len to page size
985      let unpadded_code_len = code_len_allocated - code_slice.len();
986      let code_len = (unpadded_code_len + page_size - 1) & !(page_size - 1);
987      assert!(code_len <= code_len_allocated);
988
989      // RW- -> R-X
990      // Also make the unused part of the pre-allocated code region PROT_NONE
991      if libc::mprotect(
992        code_mem.as_mut_ptr().offset(guard_size_before as isize) as *mut _,
993        code_len,
994        libc::PROT_READ | libc::PROT_EXEC,
995      ) != 0
996        || (code_len < code_len_allocated
997          && libc::mprotect(
998            code_mem
999              .as_mut_ptr()
1000              .offset((guard_size_before + code_len) as isize) as *mut _,
1001            code_len_allocated - code_len,
1002            libc::PROT_NONE,
1003          ) != 0)
1004      {
1005        return Err(RuntimeError::PlatformError("failed to protect code memory"));
1006      }
1007
1008      guard_size_after += code_len_allocated - code_len;
1009
1010      tracing::info!(
1011        elf_size = elf.len(),
1012        native_code_addr = ?code_mem.as_ptr(),
1013        native_code_size = code_len,
1014        native_code_size_unpadded = unpadded_code_len,
1015        guard_size_before,
1016        guard_size_after,
1017        duration = ?start_time.elapsed(),
1018        cage_ptr = ?cage.region().as_ptr(),
1019        cage_mapped_size = cage.region().len(),
1020        "jit compiled program"
1021      );
1022
1023      Ok(UnboundProgram {
1024        id: NEXT_PROGRAM_ID.fetch_add(1, Ordering::Relaxed),
1025        _code_mem: code_mem,
1026        cage,
1027        helper_id_xor: self.helper_id_xor,
1028        helpers: self.helpers.clone(),
1029        event_listener: self.event_listener.clone(),
1030        entrypoints,
1031        stack_rng_seed: rng.gen(),
1032      })
1033    }
1034  }
1035}
1036
/// Message yielded from the JIT coroutine back to the host when execution
/// suspends — either for a helper call (`tls_dispatcher`), an async
/// preemption (`sigusr1_handler`), or a cage memory fault (`sigsegv_handler`).
#[derive(Default)]
struct Dispatch {
  // Set by `sigusr1_handler`: the suspension was caused by async preemption,
  // not a helper call.
  async_preemption: bool,
  // Set by `sigsegv_handler`: faulting address inside the pointer cage's
  // protected range; `None` for non-fault suspensions.
  memory_access_error: Option<usize>,

  // Helper-call payload, filled in by `tls_dispatcher`: dispatcher index
  // plus the five raw u64 arguments passed by the generated code.
  index: u32,
  arg1: u64,
  arg2: u64,
  arg3: u64,
  arg4: u64,
  arg5: u64,
}
1049
1050unsafe extern "C" fn tls_dispatcher(
1051  arg1: u64,
1052  arg2: u64,
1053  arg3: u64,
1054  arg4: u64,
1055  arg5: u64,
1056  index: std::os::raw::c_uint,
1057  _cookie: *mut std::os::raw::c_void,
1058) -> u64 {
1059  let yielder = ACTIVE_JIT_CODE_ZONE
1060    .with(|x| x.yielder.get())
1061    .expect("no yielder");
1062  let yielder = yielder.as_ref();
1063  let ret = yielder.suspend(Dispatch {
1064    async_preemption: false,
1065    memory_access_error: None,
1066    index,
1067    arg1,
1068    arg2,
1069    arg3,
1070    arg4,
1071    arg5,
1072  });
1073  ret
1074}
1075
1076unsafe extern "C" fn std_validator(
1077  index: std::os::raw::c_uint,
1078  loader: *mut std::os::raw::c_void,
1079) -> bool {
1080  let loader = &*(loader as *const ProgramLoader);
1081  let entropy = (index >> 16) & 0xffff;
1082  let index = (((index & 0xffff) as u16) ^ loader.helper_id_xor).wrapping_sub(1);
1083  loader.helpers.get(index as usize).map(|x| x.0) == Some(entropy as u16)
1084}
1085
/// Extracts the faulting program counter (RIP) from a signal handler's
/// `ucontext_t` on x86_64 Linux.
#[cfg(all(target_arch = "x86_64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.gregs[libc::REG_RIP as usize] as usize
}
1090
/// Extracts the faulting program counter (PC) from a signal handler's
/// `ucontext_t` on aarch64 Linux.
#[cfg(all(target_arch = "aarch64", target_os = "linux"))]
unsafe fn program_counter(uctx: *mut libc::ucontext_t) -> usize {
  (*uctx).uc_mcontext.pc as usize
}
1095
1096unsafe extern "C" fn sigsegv_handler(
1097  _sig: i32,
1098  siginfo: *mut libc::siginfo_t,
1099  uctx: *mut libc::ucontext_t,
1100) {
1101  let fail = || restore_default_signal_handler(libc::SIGSEGV);
1102
1103  let Some((jit_code_zone, pointer_cage, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
1104    if x.valid.load(Ordering::Relaxed) {
1105      compiler_fence(Ordering::Acquire);
1106      Some((
1107        x.code_range.get(),
1108        x.pointer_cage_protected_range.get(),
1109        x.yielder.get(),
1110      ))
1111    } else {
1112      None
1113    }
1114  }) else {
1115    return fail();
1116  };
1117
1118  let pc = program_counter(uctx);
1119
1120  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
1121    return fail();
1122  }
1123
1124  // SEGV_ACCERR
1125  if (*siginfo).si_code != 2 {
1126    return fail();
1127  }
1128
1129  let si_addr = (*siginfo).si_addr() as usize;
1130  if si_addr < pointer_cage.0 || si_addr >= pointer_cage.1 {
1131    return fail();
1132  }
1133
1134  let yielder = yielder.expect("no yielder").as_ref();
1135  yielder.suspend(Dispatch {
1136    memory_access_error: Some(si_addr),
1137    ..Default::default()
1138  });
1139}
1140
1141unsafe extern "C" fn sigusr1_handler(
1142  _sig: i32,
1143  _siginfo: *mut libc::siginfo_t,
1144  uctx: *mut libc::ucontext_t,
1145) {
1146  SIGUSR1_COUNTER.with(|x| x.set(x.get() + 1));
1147
1148  let Some((jit_code_zone, yielder)) = ACTIVE_JIT_CODE_ZONE.with(|x| {
1149    if x.valid.load(Ordering::Relaxed) {
1150      compiler_fence(Ordering::Acquire);
1151      Some((x.code_range.get(), x.yielder.get()))
1152    } else {
1153      None
1154    }
1155  }) else {
1156    return;
1157  };
1158  let pc = program_counter(uctx);
1159  if pc < jit_code_zone.0 || pc >= jit_code_zone.1 {
1160    return;
1161  }
1162
1163  let yielder = yielder.expect("no yielder").as_ref();
1164  yielder.suspend(Dispatch {
1165    async_preemption: true,
1166    ..Default::default()
1167  });
1168}
1169
/// Resets the disposition of `signum` to `SIG_DFL` via `sigaction`,
/// aborting the process if the syscall fails.
///
/// Used by `sigsegv_handler`'s `fail` path to step aside when a fault is not
/// one we can translate into a coroutine suspension — after returning, the
/// default action applies.
unsafe fn restore_default_signal_handler(signum: i32) {
  let act = libc::sigaction {
    sa_sigaction: libc::SIG_DFL,
    sa_flags: libc::SA_SIGINFO,
    sa_mask: std::mem::zeroed(),
    sa_restorer: None,
  };
  if libc::sigaction(signum, &act, std::ptr::null_mut()) != 0 {
    // Cannot safely continue with an unknown handler state.
    libc::abort();
  }
}
1181
1182fn get_blocked_sigset() -> libc::sigset_t {
1183  unsafe {
1184    let mut s: libc::sigset_t = std::mem::zeroed();
1185    libc::sigaddset(&mut s, libc::SIGUSR1);
1186    libc::sigaddset(&mut s, libc::SIGSEGV);
1187    s
1188  }
1189}