1use crate::{debug, ElfInfo, Interrupt, MemValue, PageProtValue, RiscRegister, TraceChunkHeader};
2use memmap2::{MmapMut, MmapOptions};
3use sp1_primitives::consts::{PROT_READ, PROT_WRITE};
4use std::{collections::VecDeque, io, os::fd::RawFd, ptr::NonNull, sync::mpsc};
5
/// Number of 32-bit words in [`JitContext::public_value_digest`].
pub const PUBLIC_VALUE_DIGEST_WORDS: usize = 8;
11
/// Interface through which syscall handlers read and mutate the state of the
/// executing program: registers, memory, I/O streams, clocks and exit status.
pub trait SyscallContext {
    /// Returns the current value of register `reg`.
    fn rr(&self, reg: RiscRegister) -> u64;
    /// Writes `value` into register `reg`.
    fn rw(&mut self, reg: RiscRegister, value: u64);
    /// Sets the program counter execution continues from.
    fn set_next_pc(&mut self, pc: u64);
    /// Reads the word at `addr` without any page-protection check.
    fn mr_without_prot(&mut self, addr: u64) -> u64;
    /// Writes `val` to the word at `addr` without any page-protection check.
    fn mw_without_prot(&mut self, addr: u64, val: u64);
    /// Reads `len` consecutive words starting at `addr`, after checking the
    /// range against `PROT_READ` via [`Self::prot_slice_check`].
    ///
    /// # Errors
    ///
    /// Propagates the [`Interrupt`] returned by the protection check.
    fn mr_slice(
        &mut self,
        addr: u64,
        len: usize,
    ) -> Result<impl IntoIterator<Item = &u64>, Interrupt> {
        self.prot_slice_check(addr, len, PROT_READ)?;
        Ok(self.mr_slice_without_prot(addr, len))
    }
    /// Reads `len` consecutive words starting at `addr` without a protection check.
    fn mr_slice_without_prot(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Reads `len` consecutive words starting at `addr` with implementation-specific
    /// (reduced) bookkeeping. NOTE(review): document the exact contract per implementor.
    fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Reads `len` consecutive words starting at `addr`; per its name, the access is
    /// not meant to be recorded in any execution trace.
    fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64>;
    /// Writes `vals` to consecutive words starting at `addr`, after checking the
    /// range against `PROT_WRITE` via [`Self::prot_slice_check`].
    ///
    /// # Errors
    ///
    /// Propagates the [`Interrupt`] returned by the protection check; nothing is
    /// written in that case.
    fn mw_slice(&mut self, addr: u64, vals: &[u64]) -> Result<(), Interrupt> {
        self.prot_slice_check(addr, vals.len(), PROT_WRITE)?;
        self.mw_slice_without_prot(addr, vals);
        Ok(())
    }
    /// Writes `vals` to consecutive words starting at `addr` without a protection check.
    fn mw_slice_without_prot(&mut self, addr: u64, vals: &[u64]);
    /// Checks that `len` words starting at `addr` may be read (`PROT_READ`).
    #[inline]
    fn read_slice_check(&mut self, addr: u64, len: usize) -> Result<(), Interrupt> {
        self.prot_slice_check(addr, len, PROT_READ)
    }
    /// Checks that `len` words starting at `addr` may be written (`PROT_WRITE`).
    #[inline]
    fn write_slice_check(&mut self, addr: u64, len: usize) -> Result<(), Interrupt> {
        self.prot_slice_check(addr, len, PROT_WRITE)
    }
    /// Checks that `len` words starting at `addr` may be read and written.
    #[inline]
    fn read_write_slice_check(&mut self, addr: u64, len: usize) -> Result<(), Interrupt> {
        self.prot_slice_check(addr, len, PROT_READ | PROT_WRITE)
    }
    /// Checks the page protections of the `len`-word range starting at `addr`
    /// against `prot_bitmap` (a combination of `PROT_*` flags), returning an
    /// [`Interrupt`] on failure.
    fn prot_slice_check(&mut self, addr: u64, len: usize, prot_bitmap: u8)
        -> Result<(), Interrupt>;
    /// Records the page-protection value `val` for the page containing `addr`
    /// (presumably; confirm page granularity with implementors).
    fn page_prot_write(&mut self, addr: u64, val: u8);
    /// Hook for committing pending page-protection changes; the default does nothing.
    fn page_prot_flush(&mut self) {}
    /// Returns the queue of input buffers made available to the program.
    fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>>;
    /// Returns the stream of public values committed by the program.
    fn public_values_stream(&mut self) -> &mut Vec<u8>;
    /// Switches execution into unconstrained mode.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the required state could not be set up.
    fn enter_unconstrained(&mut self) -> io::Result<()>;
    /// Leaves unconstrained mode, restoring the state saved on entry.
    fn exit_unconstrained(&mut self);
    /// Records hint bytes `value` associated with address `addr`.
    fn trace_hint(&mut self, addr: u64, value: Vec<u8>);
    /// Records a standalone `value` in the execution trace.
    fn trace_value(&mut self, value: u64);
    /// Writes a hint word `val` to memory at `addr`.
    fn mw_hint(&mut self, addr: u64, val: u64);
    /// Advances the clock used to timestamp memory accesses.
    fn bump_memory_clk(&mut self);
    /// Returns the clock used to timestamp memory accesses.
    fn get_current_clk(&self) -> u64;
    /// Overwrites the clock used to timestamp memory accesses.
    fn set_clk(&mut self, clk: u64);
    /// Records the program's exit code.
    fn set_exit_code(&mut self, exit_code: u32);
    /// Stores word `word_idx` of the public-values digest.
    fn set_public_value_digest_word(&mut self, word_idx: u64, digest_word: u32);
    /// Returns `true` while executing in unconstrained mode.
    fn is_unconstrained(&self) -> bool;
    /// Returns the global clock.
    fn global_clk(&self) -> u64;

    /// Starts the cycle tracker named `name`. NOTE(review): document the meaning
    /// of the returned `u32`.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_start(&mut self, name: &str) -> u32;

    /// Stops the cycle tracker named `name`, returning tracker data if it was active
    /// (tuple meaning to be confirmed with implementors).
    #[cfg(feature = "profiling")]
    fn cycle_tracker_end(&mut self, name: &str) -> Option<(u64, u32)>;

    /// Like [`Self::cycle_tracker_end`], for trackers whose result is reported.
    /// NOTE(review): confirm the reporting semantics.
    #[cfg(feature = "profiling")]
    fn cycle_tracker_report_end(&mut self, name: &str) -> Option<(u64, u32)>;

    /// Returns information about the loaded ELF.
    fn elf_info(&self) -> ElfInfo;
    /// Iterates over addresses that were initialized (exact meaning per implementor).
    fn init_addr_iter(&self) -> impl IntoIterator<Item = u64>;
    /// Iterates over recorded `(page, protection)` entries.
    fn page_prot_iter(&self) -> impl IntoIterator<Item = (&u64, &PageProtValue)>;
    /// Dumps profiler data, if a profiler is active.
    fn maybe_dump_profiler_data(&self) -> (Vec<(String, u64, u64)>, Vec<u64>);
    /// Registers profiler symbols, if a profiler is active.
    fn maybe_insert_profiler_symbols<I: Iterator<Item = (String, u64, u64)>>(&mut self, iter: I);
    /// Removes profiler symbols, if a profiler is active.
    fn maybe_delete_profiler_symbols<I: Iterator<Item = u64>>(&mut self, iter: I);
}
133
134impl SyscallContext for JitContext {
135 #[inline]
136 fn bump_memory_clk(&mut self) {
137 self.clk += 1;
138 }
139
140 #[inline]
141 fn get_current_clk(&self) -> u64 {
142 self.clk
143 }
144
145 #[inline]
146 fn set_clk(&mut self, clk: u64) {
147 self.clk = clk;
148 }
149
150 fn rr(&self, reg: RiscRegister) -> u64 {
151 self.registers[reg as usize]
152 }
153
154 fn rw(&mut self, _reg: RiscRegister, _value: u64) {
155 unimplemented!()
156 }
157
158 fn set_next_pc(&mut self, _pc: u64) {
159 unimplemented!()
160 }
161
162 fn mr_without_prot(&mut self, addr: u64) -> u64 {
163 unsafe { ContextMemory::new(self).mr(addr) }
164 }
165
166 fn mw_without_prot(&mut self, addr: u64, val: u64) {
167 unsafe { ContextMemory::new(self).mw(addr, val) };
168 }
169
170 fn mr_slice_without_prot(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
171 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
172
173 let word_address = addr / 8;
175
176 let ptr = self.memory.as_ptr() as *mut MemValue;
177 let ptr = unsafe { ptr.add(word_address as usize) };
178
179 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
182
183 if self.tracing() {
184 unsafe {
185 self.trace_mem_access(slice);
186
187 for (i, entry) in slice.iter().enumerate() {
189 let new_entry = MemValue { value: entry.value, clk: self.clk };
190 std::ptr::write(ptr.add(i), new_entry)
191 }
192 }
193 }
194
195 slice.iter().map(|val| &val.value)
196 }
197
198 fn mr_slice_no_trace(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
199 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
200
201 let word_address = addr / 8;
203
204 let ptr = self.memory.as_ptr() as *mut MemValue;
205 let ptr = unsafe { ptr.add(word_address as usize) };
206
207 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
210
211 slice.iter().map(|val| &val.value)
212 }
213
214 fn mr_slice_unsafe(&mut self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> {
215 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
216
217 let word_address = addr / 8;
219
220 let ptr = self.memory.as_ptr() as *mut MemValue;
221 let ptr = unsafe { ptr.add(word_address as usize) };
222
223 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
226
227 if self.tracing() {
228 unsafe {
229 self.trace_mem_access(slice);
230 }
231 }
232
233 slice.iter().map(|val| &val.value)
234 }
235
236 fn mw_slice_without_prot(&mut self, addr: u64, vals: &[u64]) {
237 unsafe { ContextMemory::new(self).mw_slice(addr, vals) };
238 }
239
240 fn prot_slice_check(
241 &mut self,
242 _addr: u64,
243 _len: usize,
244 _prot_bitmap: u8,
245 ) -> Result<(), Interrupt> {
246 Ok(())
249 }
250
251 fn page_prot_write(&mut self, _addr: u64, _val: u8) {
252 unimplemented!("page_prot_write not implemented for JitContext")
253 }
254
255 fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
256 unsafe { self.input_buffer() }
257 }
258
259 fn public_values_stream(&mut self) -> &mut Vec<u8> {
260 unsafe { self.public_values_stream() }
261 }
262
263 fn enter_unconstrained(&mut self) -> io::Result<()> {
264 self.enter_unconstrained()
265 }
266
267 fn exit_unconstrained(&mut self) {
268 self.exit_unconstrained()
269 }
270
271 fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
272 if self.tracing {
273 unsafe { self.trace_hint(addr, value) };
274 }
275 }
276
277 fn trace_value(&mut self, value: u64) {
278 if self.tracing {
279 unsafe {
280 self.trace_mem_access(&[MemValue { clk: u64::MAX, value }]);
283 }
284 }
285 }
286
287 fn mw_hint(&mut self, addr: u64, val: u64) {
288 unsafe { ContextMemory::new(self).mw_hint(addr, val) };
289 }
290
291 fn set_exit_code(&mut self, exit_code: u32) {
292 self.exit_code = exit_code;
293 }
294
295 fn set_public_value_digest_word(&mut self, word_idx: u64, digest_word: u32) {
296 let idx = word_idx as usize;
297 debug_assert!(
298 idx < PUBLIC_VALUE_DIGEST_WORDS,
299 "public value digest word index out of bounds: {idx}"
300 );
301 self.public_value_digest[idx] = digest_word;
302 }
303
304 fn is_unconstrained(&self) -> bool {
305 self.is_unconstrained == 1
306 }
307
308 fn global_clk(&self) -> u64 {
309 self.global_clk
310 }
311
312 #[cfg(feature = "profiling")]
313 fn cycle_tracker_start(&mut self, _name: &str) -> u32 {
314 0
317 }
318
319 #[cfg(feature = "profiling")]
320 fn cycle_tracker_end(&mut self, _name: &str) -> Option<(u64, u32)> {
321 None
324 }
325
326 #[cfg(feature = "profiling")]
327 fn cycle_tracker_report_end(&mut self, _name: &str) -> Option<(u64, u32)> {
328 None
331 }
332
333 fn elf_info(&self) -> ElfInfo {
334 unimplemented!()
335 }
336
337 fn init_addr_iter(&self) -> impl IntoIterator<Item = u64> {
338 Vec::new()
339 }
340
341 fn page_prot_iter(&self) -> impl IntoIterator<Item = (&u64, &PageProtValue)> {
342 Vec::new()
343 }
344
345 fn maybe_dump_profiler_data(&self) -> (Vec<(String, u64, u64)>, Vec<u64>) {
346 unimplemented!()
347 }
348
349 fn maybe_insert_profiler_symbols<I: Iterator<Item = (String, u64, u64)>>(&mut self, _iter: I) {
350 unimplemented!()
351 }
352
353 fn maybe_delete_profiler_symbols<I: Iterator<Item = u64>>(&mut self, _iter: I) {
354 unimplemented!()
355 }
356}
357
/// Execution state of a JIT-run program, as exposed to syscall handlers through
/// its `SyscallContext` implementation.
///
/// NOTE(review): `#[repr(C)]` fixes the field layout; generated code presumably
/// accesses fields by offset, so do not reorder fields without updating it.
#[repr(C)]
#[derive(Debug)]
pub struct JitContext {
    /// Current program counter.
    pub pc: u64,
    /// Clock stamped into `MemValue::clk` on traced memory accesses.
    pub clk: u64,
    /// Global clock, reported by `SyscallContext::global_clk`.
    pub global_clk: u64,
    /// `1` while in unconstrained mode, `0` otherwise.
    pub is_unconstrained: u64,
    /// Jump table used by generated code. NOTE(review): document its layout.
    pub(crate) jump_table: NonNull<*const u8>,
    /// Base of emulated memory, interpreted as an array of [`MemValue`]
    /// indexed by `addr / 8`.
    pub(crate) memory: NonNull<u8>,
    /// Current trace chunk: a [`TraceChunkHeader`] followed by [`MemValue`]
    /// records; written only when `tracing` is set.
    pub(crate) trace_buf: *mut u8,
    /// Register file.
    pub(crate) registers: [u64; 32],
    /// Queue of input buffers available to the program.
    pub(crate) input_buffer: NonNull<VecDeque<Vec<u8>>>,
    /// Bytes committed as public values.
    pub(crate) public_values_stream: NonNull<Vec<u8>>,
    /// `(addr, bytes)` hints recorded by `trace_hint`.
    pub(crate) hints: NonNull<Vec<(u64, Vec<u8>)>>,
    /// File descriptor mapped copy-on-write when entering unconstrained mode
    /// (presumably the file backing `memory`; confirm).
    pub(crate) memory_fd: RawFd,
    /// State saved while in unconstrained mode; `None` otherwise.
    pub(crate) maybe_unconstrained: Option<UnconstrainedCtx>,
    /// Whether memory accesses are recorded into `trace_buf`.
    pub(crate) tracing: bool,
    /// Optional channel for debug state snapshots (not used in this module).
    pub(crate) debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
    /// Exit code set via `SyscallContext::set_exit_code`.
    pub(crate) exit_code: u32,
    /// Public-values digest words set via `set_public_value_digest_word`.
    pub public_value_digest: [u32; PUBLIC_VALUE_DIGEST_WORDS],
}
398
399impl JitContext {
400 pub unsafe fn trace_mem_access(&self, reads: &[MemValue]) {
403 let raw = self.trace_buf;
408 let num_reads_offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
409 let num_reads_ptr = raw.add(num_reads_offset);
410 let num_reads = std::ptr::read_unaligned(num_reads_ptr as *mut u64);
411
412 let new_num_reads = num_reads + reads.len() as u64;
414 std::ptr::write_unaligned(num_reads_ptr as *mut u64, new_num_reads);
415
416 let reads_start = std::mem::size_of::<TraceChunkHeader>();
418 let tail_ptr = raw.add(reads_start) as *mut MemValue;
419 let tail_ptr = tail_ptr.add(num_reads as usize);
420
421 for (i, read) in reads.iter().enumerate() {
422 std::ptr::write(tail_ptr.add(i), *read);
423 }
424 }
425
426 pub fn enter_unconstrained(&mut self) -> io::Result<()> {
429 let mut cow_memory =
432 unsafe { MmapOptions::new().no_reserve_swap().map_copy(self.memory_fd)? };
433 let cow_memory_ptr = cow_memory.as_mut_ptr();
434
435 let align_offset = cow_memory_ptr.align_offset(std::mem::align_of::<u64>());
438 let cow_memory_ptr = unsafe { cow_memory_ptr.add(align_offset) };
439
440 self.maybe_unconstrained = Some(UnconstrainedCtx {
442 cow_memory,
443 actual_memory_ptr: self.memory,
444 pc: self.pc,
445 clk: self.clk,
446 global_clk: self.global_clk,
447 registers: self.registers,
448 });
449
450 self.pc = self.pc.wrapping_add(4);
452
453 self.memory = unsafe { NonNull::new_unchecked(cow_memory_ptr) };
457
458 self.is_unconstrained = 1;
460
461 Ok(())
462 }
463
464 pub fn exit_unconstrained(&mut self) {
466 let unconstrained = std::mem::take(&mut self.maybe_unconstrained)
467 .expect("Exit unconstrained called but no context is present, this is a bug.");
468
469 self.memory = unconstrained.actual_memory_ptr;
470 self.pc = unconstrained.pc;
471 self.registers = unconstrained.registers;
472 self.clk = unconstrained.clk;
473 self.is_unconstrained = 0;
474 }
475
476 pub unsafe fn trace_hint(&mut self, addr: u64, value: Vec<u8>) {
484 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
485 self.hints.as_mut().push((addr, value));
486 }
487
488 pub const fn memory(&mut self) -> ContextMemory<'_> {
490 unsafe { ContextMemory::new(self) }
491 }
492
493 pub const unsafe fn input_buffer(&mut self) -> &mut VecDeque<Vec<u8>> {
496 self.input_buffer.as_mut()
497 }
498
499 pub const unsafe fn public_values_stream(&mut self) -> &mut Vec<u8> {
502 self.public_values_stream.as_mut()
503 }
504
505 pub const fn registers(&self) -> &[u64; 32] {
507 &self.registers
508 }
509
510 pub const fn rw(&mut self, reg: RiscRegister, val: u64) {
511 self.registers[reg as usize] = val;
512 }
513
514 pub const fn rr(&self, reg: RiscRegister) -> u64 {
515 self.registers[reg as usize]
516 }
517
518 #[inline]
519 pub const fn tracing(&self) -> bool {
520 self.tracing
521 }
522}
523
/// State captured by `JitContext::enter_unconstrained`; `exit_unconstrained`
/// restores the context from it and then drops it.
#[derive(Debug)]
pub struct UnconstrainedCtx {
    /// Copy-on-write mapping used as memory while unconstrained; dropping it
    /// unmaps the region.
    pub cow_memory: MmapMut,
    /// The `memory` pointer in effect before entering unconstrained mode.
    pub actual_memory_ptr: NonNull<u8>,
    /// Program counter at entry.
    pub pc: u64,
    /// `clk` at entry.
    pub clk: u64,
    /// `global_clk` at entry.
    pub global_clk: u64,
    /// Register file at entry.
    pub registers: [u64; 32],
}
540
/// Word-granular view of a [`JitContext`]'s memory. When tracing is enabled,
/// accesses made through it are recorded in the context's trace buffer.
pub struct ContextMemory<'a> {
    /// Context whose `memory` (and, when tracing, `trace_buf`) is accessed.
    ctx: &'a mut JitContext,
}
547
548impl<'a> ContextMemory<'a> {
549 const unsafe fn new(ctx: &'a mut JitContext) -> Self {
560 Self { ctx }
561 }
562
563 #[inline]
564 pub const fn tracing(&self) -> bool {
565 self.ctx.tracing()
566 }
567
568 pub fn mr(&self, addr: u64) -> u64 {
570 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
571 let word_address = addr / 8;
573
574 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
575 let ptr = unsafe { ptr.add(word_address as usize) };
576
577 let entry = unsafe { std::ptr::read(ptr) };
580
581 if self.tracing() {
582 unsafe {
583 self.ctx.trace_mem_access(&[entry]);
584
585 let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
587 std::ptr::write(ptr, new_entry);
588 }
589 }
590
591 entry.value
592 }
593
594 pub fn mw(&mut self, addr: u64, val: u64) {
596 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
597
598 let word_address = addr / 8;
600
601 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
602 let ptr = unsafe { ptr.add(word_address as usize) };
603
604 let value = MemValue { value: val, clk: self.ctx.clk };
606
607 if self.tracing() {
609 unsafe {
610 let current_entry = std::ptr::read(ptr);
612 self.ctx.trace_mem_access(&[current_entry, value]);
613 }
614 }
615
616 unsafe { std::ptr::write(ptr, value) };
619 }
620
621 pub fn mr_slice(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
623 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
624
625 let word_address = addr / 8;
627
628 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
629 let ptr = unsafe { ptr.add(word_address as usize) };
630
631 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
634
635 if self.tracing() {
636 unsafe {
637 self.ctx.trace_mem_access(slice);
638
639 for (i, entry) in slice.iter().enumerate() {
641 let new_entry = MemValue { value: entry.value, clk: self.ctx.clk };
642 std::ptr::write(ptr.add(i), new_entry)
643 }
644 }
645 }
646
647 slice.iter().map(|val| &val.value)
648 }
649
650 pub fn mr_slice_unsafe(&self, addr: u64, len: usize) -> impl IntoIterator<Item = &u64> + Clone {
652 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
653
654 let word_address = addr / 8;
656
657 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
658 let ptr = unsafe { ptr.add(word_address as usize) };
659
660 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
663
664 if self.tracing() {
665 unsafe {
666 self.ctx.trace_mem_access(slice);
667 }
668 }
669
670 slice.iter().map(|val| &val.value)
671 }
672
673 pub fn mw_slice(&mut self, addr: u64, vals: &[u64]) {
675 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
676
677 let word_address = addr / 8;
679
680 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
681 let ptr = unsafe { ptr.add(word_address as usize) };
682
683 let values = vals.iter().map(|val| MemValue { value: *val, clk: self.ctx.clk });
685
686 if self.tracing() {
689 unsafe {
690 let current_entries = std::slice::from_raw_parts(ptr, vals.len());
691
692 for (curr, new) in current_entries.iter().zip(values.clone()) {
693 self.ctx.trace_mem_access(&[*curr, new]);
694 }
695 }
696 }
697
698 for (i, val) in values.enumerate() {
699 unsafe { std::ptr::write(ptr.add(i), val) };
700 }
701 }
702
703 pub fn mr_slice_no_trace(
705 &self,
706 addr: u64,
707 len: usize,
708 ) -> impl IntoIterator<Item = &u64> + Clone {
709 debug_assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
710
711 let word_address = addr / 8;
713
714 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
715 let ptr = unsafe { ptr.add(word_address as usize) };
716
717 let slice = unsafe { std::slice::from_raw_parts(ptr, len) };
720
721 slice.iter().map(|val| &val.value)
722 }
723
724 pub fn mw_hint(&mut self, addr: u64, val: u64) {
726 let words = addr / 8;
727
728 let ptr = self.ctx.memory.as_ptr() as *mut MemValue;
729 let ptr = unsafe { ptr.add(words as usize) };
730
731 let new_entry = MemValue { value: val, clk: 0 };
732 unsafe { std::ptr::write(ptr, new_entry) };
733 }
734}