//! JIT execution context for the `sp1_jit` crate (`context.rs`).
1use crate::{debug, MemValue, RiscRegister, TraceChunkHeader};
2use memmap2::{MmapMut, MmapOptions};
3use std::{collections::VecDeque, io, os::fd::RawFd, ptr::NonNull, sync::mpsc};
4
/// The syscall interface the JIT exposes to syscall/precompile implementations.
///
/// Implemented by [`JitContext`]; gives handlers access to registers, emulated
/// memory, the input/output streams, and unconstrained-mode control.
pub trait SyscallContext {
    /// Read a value from a register.
    fn rr(&self, reg: RiscRegister) -> u64;
    /// Read a value from memory.
    fn mr(&mut self, addr: u64) -> u64;
    /// Write a value to memory.
    fn mw(&mut self, addr: u64, val: u64);
    /// Read a slice of values from memory.
    fn mr_slice(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Read a slice of values from memory, without updating the memory clock.
    /// Note that it still traces the access when tracing is enabled.
    fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Read a slice of values from memory, without updating the memory clock or tracing the access.
    fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Write a slice of values to memory.
    fn mw_slice(&mut self, addr: u64, vals: &[u64]);
    /// Get the input buffer.
    fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>>;
    /// Get the public values stream.
    fn public_values_stream(&mut self) -> &mut Vec<u8>;
    /// Enter the unconstrained context.
    fn enter_unconstrained(&mut self) -> io::Result<()>;
    /// Exit the unconstrained context.
    fn exit_unconstrained(&mut self);
    /// Trace a hint.
    fn trace_hint(&mut self, addr: u64, value: Vec<u8>);
    /// Trace a dummy value.
    fn trace_value(&mut self, value: u64);
    /// Write a hint to memory, which is like setting uninitialized memory to a nonzero value.
    /// The clk will be set to 0, just like for uninitialized memory.
    fn mw_hint(&mut self, addr: u64, val: u64);
    /// Used for precompiles that access memory, that need to bump the clk.
    /// This increment is local to the precompile, and does not affect the number of cycles
    /// the precompile itself takes up.
    fn bump_memory_clk(&mut self);
    /// Set the exit code of the program.
    fn set_exit_code(&mut self, exit_code: u32);
    /// Returns whether we are currently in unconstrained mode.
    fn is_unconstrained(&self) -> bool;
    /// Get the global clock (total cycles executed).
    fn global_clk(&self) -> u64;

    /// Start tracking cycles for a label (profiling only).
    /// Records the current `global_clk` as the start time.
    /// Returns the nesting depth (0 for top-level, 1 for first nested, etc.).
    #[cfg(feature = "profiling")]
    fn cycle_tracker_start(&mut self, name: &str) -> u32;

    /// End tracking cycles for a label (profiling only).
    /// Returns (cycles_elapsed, depth) or None if no matching start.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_end(&mut self, name: &str) -> Option<(u64, u32)>;

    /// End tracking cycles for a label and accumulate to report totals (profiling only).
    /// This is for "report" variants that should be included in ExecutionReport.
    /// Returns (cycles_elapsed, depth) or None if no matching start.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_report_end(&mut self, name: &str) -> Option<(u64, u32)>;
}
64
impl SyscallContext for JitContext {
    #[inline]
    fn bump_memory_clk(&mut self) {
        self.clk += 1;
    }

    fn rr(&self, reg: RiscRegister) -> u64 {
        self.registers[reg as usize]
    }

    fn mr(&mut self, addr: u64) -> u64 {
        // SAFETY: `self.memory` is the JIT-allocated memory region, exclusively
        // borrowed through `&mut self` for the duration of the call.
        unsafe { ContextMemory::new(self).mr(addr) }
    }

    fn mw(&mut self, addr: u64, val: u64) {
        // SAFETY: Same invariants as in `mr` above.
        unsafe { ContextMemory::new(self).mw(addr, val) };
    }

    fn mr_slice(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        // Convert the byte address to the word address.
        let word_address = addr / 8;

        let ptr = self.memory.as_ptr() as *mut MemValue;
        let ptr = unsafe { ptr.add(word_address as usize) };

        // SAFETY: The pointer is valid to read from, as it was aligned by us during allocation.
        // See [JitFunction::new] for more details.
        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        if self.tracing() {
            unsafe {
                // Record the pre-access entries (value + old clk) in the trace.
                self.trace_mem_access(slice);

                // Bump the clk on all the current entries; the values are unchanged.
                for (i, entry) in slice.iter().enumerate() {
                    let new_entry = MemValue { value: entry.value, clk: self.clk };
                    std::ptr::write(ptr.add(i), new_entry)
                }
            }
        }

        // Only the `value` halves are exposed; the clk is internal bookkeeping.
        slice.iter().map(|val| &val.value)
    }

    fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        // Convert the byte address to the word address.
        let word_address = addr / 8;

        let ptr = self.memory.as_ptr() as *mut MemValue;
        let ptr = unsafe { ptr.add(word_address as usize) };

        // SAFETY: The pointer is valid to read from, as it was aligned by us during allocation.
        // See [JitFunction::new] for more details.
        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        slice.iter().map(|val| &val.value)
    }

    fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        // Convert the byte address to the word address.
        let word_address = addr / 8;

        let ptr = self.memory.as_ptr() as *mut MemValue;
        let ptr = unsafe { ptr.add(word_address as usize) };

        // SAFETY: The pointer is valid to read from, as it was aligned by us during allocation.
        // See [JitFunction::new] for more details.
        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };

        // Unlike `mr_slice`, the entries are traced but their clks are NOT bumped.
        if self.tracing() {
            unsafe {
                self.trace_mem_access(slice);
            }
        }

        slice.iter().map(|val| &val.value)
    }

    fn mw_slice(&mut self, addr: u64, vals: &[u64]) {
        // SAFETY: Same invariants as in `mr` above.
        unsafe { ContextMemory::new(self).mw_slice(addr, vals) };
    }

    fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
        // Delegates to the inherent `JitContext::input_buffer` (inherent methods
        // take precedence over trait methods, so this is not recursive).
        // SAFETY: The pointer is required to be non-null and valid; see the
        // inherent method's safety docs.
        unsafe { self.input_buffer() }
    }

    fn public_values_stream(&mut self) -> &mut Vec<u8> {
        // SAFETY: Same as `input_buffer` above — delegates to the inherent method.
        unsafe { self.public_values_stream() }
    }

    fn enter_unconstrained(&mut self) -> io::Result<()> {
        // Delegates to the inherent `JitContext::enter_unconstrained`.
        self.enter_unconstrained()
    }

    fn exit_unconstrained(&mut self) {
        // Delegates to the inherent `JitContext::exit_unconstrained`.
        self.exit_unconstrained()
    }

    fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
        if self.tracing {
            // SAFETY: Hints pointer is owned by the JIT runtime and not aliased here;
            // see the inherent method's safety docs.
            unsafe { self.trace_hint(addr, value) };
        }
    }

    fn trace_value(&mut self, value: u64) {
        if self.tracing {
            unsafe {
                // u64::MAX is used as the clock so this dummy entry can be
                // distinguished from real memory values in the trace.
                self.trace_mem_access(&[MemValue { clk: u64::MAX, value }]);
            }
        }
    }

    fn mw_hint(&mut self, addr: u64, val: u64) {
        // SAFETY: Same invariants as in `mr` above.
        unsafe { ContextMemory::new(self).mw_hint(addr, val) };
    }

    fn set_exit_code(&mut self, exit_code: u32) {
        self.exit_code = exit_code;
    }

    fn is_unconstrained(&self) -> bool {
        // `is_unconstrained` is stored as a u64 flag: 1 when unconstrained, 0 otherwise.
        self.is_unconstrained == 1
    }

    fn global_clk(&self) -> u64 {
        self.global_clk
    }

    #[cfg(feature = "profiling")]
    fn cycle_tracker_start(&mut self, _name: &str) -> u32 {
        // JitContext is not used when profiling is enabled (portable executor is used instead).
        // This is a no-op implementation for trait completeness.
        0
    }

    #[cfg(feature = "profiling")]
    fn cycle_tracker_end(&mut self, _name: &str) -> Option<(u64, u32)> {
        // JitContext is not used when profiling is enabled (portable executor is used instead).
        // This is a no-op implementation for trait completeness.
        None
    }

    #[cfg(feature = "profiling")]
    fn cycle_tracker_report_end(&mut self, _name: &str) -> Option<(u64, u32)> {
        // JitContext is not used when profiling is enabled (portable executor is used instead).
        // This is a no-op implementation for trait completeness.
        None
    }
}
222
#[repr(C)]
#[derive(Debug)]
pub struct JitContext {
    /// The current program counter.
    pub pc: u64,
    /// The number of cycles executed (saved and restored around unconstrained mode).
    pub clk: u64,
    /// The global clock: total cycles executed across the whole program.
    pub global_clk: u64,
    /// This context is in unconstrained mode.
    /// 1 if unconstrained, 0 otherwise.
    pub is_unconstrained: u64,
    /// Mapping from (pc - pc_base) / 4 => absolute address of the instruction.
    pub(crate) jump_table: NonNull<*const u8>,
    /// The pointer to the program memory.
    pub(crate) memory: NonNull<u8>,
    /// The pointer to the trace buffer.
    pub(crate) trace_buf: NonNull<u8>,
    /// The registers to start the execution with,
    /// these are loaded into real native registers at the start of execution.
    pub(crate) registers: [u64; 32],
    /// The input buffer to the program.
    pub(crate) input_buffer: NonNull<VecDeque<Vec<u8>>>,
    /// A stream of public values from the program (global to entire program).
    pub(crate) public_values_stream: NonNull<Vec<u8>>,
    /// The hints read by the program, with their corresponding start address.
    pub(crate) hints: NonNull<Vec<(u64, Vec<u8>)>>,
    /// The memory file descriptor, this is used to create the COW memory at runtime.
    pub(crate) memory_fd: RawFd,
    /// The unconstrained context, this is used to create the COW memory at runtime.
    pub(crate) maybe_unconstrained: Option<UnconstrainedCtx>,
    /// Whether the JIT is tracing.
    pub(crate) tracing: bool,
    /// Whether the JIT is sending debug state every instruction.
    pub(crate) debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
    /// The exit code of the program.
    pub(crate) exit_code: u32,
}
261
impl JitContext {
    /// Append `reads` to the trace buffer and advance its `num_mem_reads` counter.
    ///
    /// # Safety
    /// - `self.trace_buf` must point to a writable buffer laid out as a
    ///   [`TraceChunkHeader`] followed by a tail of [`MemValue`] entries, with room
    ///   for `reads.len()` more entries past the current `num_mem_reads` tail —
    ///   TODO confirm which component guarantees the remaining capacity.
    /// - The trace buffer must not be concurrently accessed through another reference.
    pub unsafe fn trace_mem_access(&self, reads: &[MemValue]) {
        // QUESTIONABLE: I think as long as Self is not `Sync` you're mostly fine, but it's
        // unclear how to actually call this method safely without taking a `&mut self`.

        // Read the current num reads from the trace buf.
        let raw = self.trace_buf.as_ptr();
        let num_reads_offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
        let num_reads_ptr = raw.add(num_reads_offset);
        let num_reads = std::ptr::read_unaligned(num_reads_ptr as *mut u64);

        // Write the new num reads to the trace buf.
        let new_num_reads = num_reads + reads.len() as u64;
        std::ptr::write_unaligned(num_reads_ptr as *mut u64, new_num_reads);

        // Write the new reads to the trace buf: the tail starts right after the
        // header, offset by the previous `num_reads` entries.
        let reads_start = std::mem::size_of::<TraceChunkHeader>();
        let tail_ptr = raw.add(reads_start) as *mut MemValue;
        let tail_ptr = tail_ptr.add(num_reads as usize);

        for (i, read) in reads.iter().enumerate() {
            std::ptr::write(tail_ptr.add(i), *read);
        }
    }

    /// Enter the unconstrained context, this will create a COW memory map of the memory file
    /// descriptor.
    pub fn enter_unconstrained(&mut self) -> io::Result<()> {
        // SAFETY: The memory is allocated by the [JitFunction] and is valid, not aliased, and has
        // enough space for the alignment.
        let mut cow_memory =
            unsafe { MmapOptions::new().no_reserve_swap().map_copy(self.memory_fd)? };
        let cow_memory_ptr = cow_memory.as_mut_ptr();

        // Align the ptr to u64.
        // SAFETY: u8 has the minimum alignment, so any larger alignment will be a multiple of this.
        let align_offset = cow_memory_ptr.align_offset(std::mem::align_of::<u64>());
        let cow_memory_ptr = unsafe { cow_memory_ptr.add(align_offset) };

        // Preserve the current state of the JIT context, so it can be restored on exit.
        self.maybe_unconstrained = Some(UnconstrainedCtx {
            cow_memory,
            actual_memory_ptr: self.memory,
            pc: self.pc,
            clk: self.clk,
            global_clk: self.global_clk,
            registers: self.registers,
        });

        // Bump the PC to the next instruction.
        self.pc = self.pc.wrapping_add(4);

        // Set the memory pointer used by the JIT as the COW memory.
        //
        // SAFETY: [memmap2] does not return a null pointer.
        self.memory = unsafe { NonNull::new_unchecked(cow_memory_ptr) };

        // Set the is_unconstrained flag to 1.
        self.is_unconstrained = 1;

        Ok(())
    }

    /// Exit the unconstrained context, this will restore the original memory map.
    ///
    /// Dropping the saved [`UnconstrainedCtx`] unmaps the COW memory.
    pub fn exit_unconstrained(&mut self) {
        let unconstrained = std::mem::take(&mut self.maybe_unconstrained)
            .expect("Exit unconstrained called but not context is present, this is a bug.");

        // NOTE(review): `global_clk` is saved in `UnconstrainedCtx` but not restored
        // here — confirm this is intentional.
        self.memory = unconstrained.actual_memory_ptr;
        self.pc = unconstrained.pc;
        self.registers = unconstrained.registers;
        self.clk = unconstrained.clk;
        self.is_unconstrained = 0;
    }

    /// Indicate that the program has read a hint.
    ///
    /// This is used to store the hints read by the program.
    ///
    /// # Safety
    /// - The address must be aligned to 8 bytes.
    /// - The hints pointer must not be mutably aliased.
    pub unsafe fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
        self.hints.as_mut().push((addr, value));
    }

    /// Obtain a mutable view of the emulated memory.
    pub const fn memory(&mut self) -> ContextMemory<'_> {
        // SAFETY: `&mut self` guarantees exclusive access for the view's lifetime.
        unsafe { ContextMemory::new(self) }
    }

    /// # Safety
    /// - The input buffer must be non null and valid to read from.
    pub const unsafe fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
        self.input_buffer.as_mut()
    }

    /// # Safety
    /// - The public values stream must be non null and valid to read from.
    pub const unsafe fn public_values_stream(&mut self) -> &mut Vec<u8> {
        self.public_values_stream.as_mut()
    }

    /// Obtain a view of the registers.
    pub const fn registers(&self) -> &[u64; 32] {
        &self.registers
    }

    /// Write a value to a register.
    pub const fn rw(&mut self, reg: RiscRegister, val: u64) {
        self.registers[reg as usize] = val;
    }

    /// Read a value from a register.
    pub const fn rr(&self, reg: RiscRegister) -> u64 {
        self.registers[reg as usize]
    }

    /// Returns whether the JIT is tracing.
    #[inline]
    pub const fn tracing(&self) -> bool {
        self.tracing
    }
}
386
/// The saved context of the JIT runtime, when entering the unconstrained context.
#[derive(Debug)]
pub struct UnconstrainedCtx {
    /// A copy-on-write (COW) version of the memory.
    pub cow_memory: MmapMut,
    /// The pointer to the actual memory, restored on exit.
    pub actual_memory_ptr: NonNull<u8>,
    /// The program counter.
    pub pc: u64,
    /// The clock.
    pub clk: u64,
    /// The global clock.
    pub global_clk: u64,
    /// The registers.
    pub registers: [u64; 32],
}
403
/// A type representing the memory of the emulated program.
///
/// This is used to read and write to the memory in precompile impls.
pub struct ContextMemory<'a> {
    // Exclusive borrow of the JIT context; prevents aliasing the memory region.
    ctx: &'a mut JitContext,
}
410
411impl<'a> ContextMemory<'a> {
412    /// Create a new memory view.
413    ///
414    /// This type takes in a mutable refrence with a lifetime to avoid aliasing the underlying
415    /// memory region.
416    ///
417    /// # Safety
418    /// - The memory is valid for the lifetime of this type.
419    /// - The memory should be aligned to 8 bytes.
420    /// - The memory should be valid to read from and write to.
421    /// - The memory should be the expected size.
422    const unsafe fn new(ctx: &'a mut JitContext) -> Self {
423        Self { ctx }
424    }
425
426    #[inline]
427    pub const fn tracing(&self) -> bool {
428        self.ctx.tracing()
429    }
430
431    /// Read a u64 from the memory.
432    pub fn mr(&self, addr: u64) -> u64 {
433        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
434        // Convert the byte address to the word address.
435        let word_address = addr / 8;
436
437        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
438        let ptr = unsafe { ptr.add(word_address as usize) };
439
440        // SAFETY: The pointer is valid to read from, as it was aligned by us during allocation.
441        // See [JitFunction::new] for more details.
442        let entry = unsafe { std::ptr::read(ptr) };
443
444        if self.tracing() {
445            unsafe {
446                self.ctx.trace_mem_access(&[entry]);
447
448                // Bump the clk
449                let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
450                std::ptr::write(ptr, new_entry);
451            }
452        }
453
454        entry.value
455    }
456
457    /// Write a u64 to the memory.
458    pub fn mw(&mut self, addr: u64, val: u64) {
459        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
460
461        // Convert the byte address to the word address.
462        let word_address = addr / 8;
463
464        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
465        let ptr = unsafe { ptr.add(word_address as usize) };
466
467        // Bump the clk and insert the new value.
468        let value = MemValue { value: val, clk: self.ctx.clk };
469
470        // Trace the current entry.
471        if self.tracing() {
472            unsafe {
473                // Trace the current entry, the clock is bumped in the subsequent write.
474                let current_entry = std::ptr::read(ptr);
475                self.ctx.trace_mem_access(&[current_entry, value]);
476            }
477        }
478
479        // SAFETY: The pointer is valid to write to, as it was aligned by us during allocation.
480        // See [JitFunction::new] for more details.
481        unsafe { std::ptr::write(ptr, value) };
482    }
483
484    /// Read a slice of u64 from the memory.
485    pub fn mr_slice(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
486        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
487
488        // Convert the byte address to the word address.
489        let word_address = addr / 8;
490
491        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
492        let ptr = unsafe { ptr.add(word_address as usize) };
493
494        // SAFETY: The pointer is valid to write to, as it was aligned by us during allocation.
495        // See [JitFunction::new] for more details.
496        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
497
498        if self.tracing() {
499            unsafe {
500                self.ctx.trace_mem_access(slice);
501
502                // Bump the clk on the all current entries.
503                for (i, entry) in slice.iter().enumerate() {
504                    let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
505                    std::ptr::write(ptr.add(i), new_entry)
506                }
507            }
508        }
509
510        slice.iter().map(|val| &val.value)
511    }
512
513    // Read a slice from memory, without bumping the clk.
514    pub fn mr_slice_unsafe(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
515        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
516
517        // Convert the byte address to the word address.
518        let word_address = addr / 8;
519
520        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
521        let ptr = unsafe { ptr.add(word_address as usize) };
522
523        // SAFETY: The pointer is valid to write to, as it was aligned by us during allocation.
524        // See [JitFunction::new] for more details.
525        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
526
527        if self.tracing() {
528            unsafe {
529                self.ctx.trace_mem_access(slice);
530            }
531        }
532
533        slice.iter().map(|val| &val.value)
534    }
535
536    /// Write a slice of u64 to the memory.
537    pub fn mw_slice(&mut self, addr: u64, vals: &[u64]) {
538        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
539
540        // Convert the byte address to the word address.
541        let word_address = addr / 8;
542
543        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
544        let ptr = unsafe { ptr.add(word_address as usize) };
545
546        // Bump the clk and insert the new values.
547        let values = vals.iter().map(|val| MemValue { value: *val, clk: self.ctx.clk });
548
549        // Trace the current entries.
550
551        if self.tracing() {
552            unsafe {
553                let current_entries = std::slice::from_raw_parts(ptr, vals.len());
554
555                for (curr, new) in current_entries.iter().zip(values.clone()) {
556                    self.ctx.trace_mem_access(&[*curr, new]);
557                }
558            }
559        }
560
561        for (i, val) in values.enumerate() {
562            unsafe { std::ptr::write(ptr.add(i), val) };
563        }
564    }
565
566    // Read a slice from memory, without bumping the clk.
567    pub fn mr_slice_no_trace(
568        &self,
569        addr: u64,
570        len: usize,
571    ) -> impl IntoIterator<Item = &u64> + Clone {
572        debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
573
574        // Convert the byte address to the word address.
575        let word_address = addr / 8;
576
577        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
578        let ptr = unsafe { ptr.add(word_address as usize) };
579
580        // SAFETY: The pointer is valid to write to, as it was aligned by us during allocation.
581        // See [JitFunction::new] for more details.
582        let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
583
584        slice.iter().map(|val| &val.value)
585    }
586
587    /// Write a u64 to memory, without tracing and sets the clk in the entry to 0.
588    pub fn mw_hint(&mut self, addr: u64, val: u64) {
589        let words = addr / 8;
590
591        let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
592        let ptr = unsafe { ptr.add(words as usize) };
593
594        let new_entry = MemValue { value: val, clk: 0 };
595        unsafe { std::ptr::write(ptr, new_entry) };
596    }
597}