Skip to main content

sp1_jit/
risc.rs

1use std::{marker::PhantomData, sync::Arc};
2
3use memmap2::Mmap;
4use serde::{Deserialize, Serialize};
5
/// The 32 general-purpose integer registers of the RISC-V ISA (`x0`..`x31`).
///
/// Each variant's discriminant is the architectural register index, so
/// `reg as u8` yields the numeric register encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RiscRegister {
    X0 = 0,
    X1 = 1,
    X2 = 2,
    X3 = 3,
    X4 = 4,
    X5 = 5,
    X6 = 6,
    X7 = 7,
    X8 = 8,
    X9 = 9,
    X10 = 10,
    X11 = 11,
    X12 = 12,
    X13 = 13,
    X14 = 14,
    X15 = 15,
    X16 = 16,
    X17 = 17,
    X18 = 18,
    X19 = 19,
    X20 = 20,
    X21 = 21,
    X22 = 22,
    X23 = 23,
    X24 = 24,
    X25 = 25,
    X26 = 26,
    X27 = 27,
    X28 = 28,
    X29 = 29,
    X30 = 30,
    X31 = 31,
}
42
43impl RiscRegister {
44    pub fn all_registers() -> &'static [RiscRegister] {
45        &[
46            RiscRegister::X0,
47            RiscRegister::X1,
48            RiscRegister::X2,
49            RiscRegister::X3,
50            RiscRegister::X4,
51            RiscRegister::X5,
52            RiscRegister::X6,
53            RiscRegister::X7,
54            RiscRegister::X8,
55            RiscRegister::X9,
56            RiscRegister::X10,
57            RiscRegister::X11,
58            RiscRegister::X12,
59            RiscRegister::X13,
60            RiscRegister::X14,
61            RiscRegister::X15,
62            RiscRegister::X16,
63            RiscRegister::X17,
64            RiscRegister::X18,
65            RiscRegister::X19,
66            RiscRegister::X20,
67            RiscRegister::X21,
68            RiscRegister::X22,
69            RiscRegister::X23,
70            RiscRegister::X24,
71            RiscRegister::X25,
72            RiscRegister::X26,
73            RiscRegister::X27,
74            RiscRegister::X28,
75            RiscRegister::X29,
76            RiscRegister::X30,
77            RiscRegister::X31,
78        ]
79    }
80}
81
/// ALU operations can either have register or immediate operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiscOperand {
    /// An operand read from a general-purpose register.
    Register(RiscRegister),
    /// A 32-bit signed immediate operand.
    Immediate(i32),
}
88
89impl From<RiscRegister> for RiscOperand {
90    fn from(reg: RiscRegister) -> Self {
91        RiscOperand::Register(reg)
92    }
93}
94
95impl From<u32> for RiscOperand {
96    fn from(imm: u32) -> Self {
97        RiscOperand::Immediate(imm as i32)
98    }
99}
100
101impl From<i32> for RiscOperand {
102    fn from(imm: i32) -> Self {
103        RiscOperand::Immediate(imm)
104    }
105}
106
107impl From<u64> for RiscOperand {
108    fn from(imm: u64) -> Self {
109        RiscOperand::Immediate(imm as i32)
110    }
111}
112
/// A single recorded memory read: the clock at which it occurred and the value read.
///
/// `#[repr(C)]` because values of this type are read/written as raw bytes in the
/// trace buffer; every bit pattern is a valid `MemValue`.
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemValue {
    /// The clock cycle of the read.
    pub clk: u64,
    /// The 64-bit value that was read.
    pub value: u64,
}
119
/// A convenience structure for getting offsets of fields in the actual [TraceChunk].
///
/// This is the fixed-size prefix of a serialized trace chunk; the variable-length
/// tail of [`MemValue`]s follows immediately after it in the buffer.
#[repr(C)]
pub struct TraceChunkHeader {
    /// The values of all 32 registers at the start of the chunk.
    pub start_registers: [u64; 32],
    /// The program counter at the start of the chunk.
    pub pc_start: u64,
    /// The clock at the start of the chunk.
    pub clk_start: u64,
    /// The clock at the end of the chunk.
    pub clk_end: u64,
    /// The number of [`MemValue`]s stored after the header.
    pub num_mem_reads: u64,
}
129
/// A trace chunk backed directly by a memory-mapped buffer.
///
/// Cloning is cheap: the mmap is shared via [`Arc`].
#[repr(C)]
#[derive(Clone)]
pub struct TraceChunkRaw {
    // The raw bytes: a `TraceChunkHeader` followed by `num_mem_reads` `MemValue`s.
    inner: Arc<Mmap>,
    // Lengths of the hints consumed during this chunk's execution.
    hint_lens: Vec<usize>,
}
136
impl TraceChunkRaw {
    /// Wraps a memory-mapped buffer as a raw trace chunk.
    ///
    /// # Safety
    ///
    /// - The mmap must start with a valid [`TraceChunkHeader`].
    /// - The mmap must contain valid [`MemValue`]s after the header.
    /// - The `num_mem_reads` must be the number of [`MemValue`]s in the mmap after the header.
    pub unsafe fn new(inner: Mmap, hint_lens: Vec<usize>) -> Self {
        Self { inner: Arc::new(inner), hint_lens }
    }
}
147
148impl MinimalTrace for TraceChunkRaw {
149    fn start_registers(&self) -> [u64; 32] {
150        let offset = std::mem::offset_of!(TraceChunkHeader, start_registers);
151
152        unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const [u64; 32]) }
153    }
154
155    fn pc_start(&self) -> u64 {
156        let offset = std::mem::offset_of!(TraceChunkHeader, pc_start);
157
158        unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
159    }
160
161    fn clk_start(&self) -> u64 {
162        let offset = std::mem::offset_of!(TraceChunkHeader, clk_start);
163
164        unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
165    }
166
167    fn clk_end(&self) -> u64 {
168        let offset = std::mem::offset_of!(TraceChunkHeader, clk_end);
169
170        unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
171    }
172
173    fn num_mem_reads(&self) -> u64 {
174        let offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
175
176        unsafe { std::ptr::read_unaligned(self.inner.as_ptr().add(offset) as *const u64) }
177    }
178
179    fn mem_reads(&self) -> MemReads<'_> {
180        let header_end = std::mem::size_of::<TraceChunkHeader>();
181        let len = self.num_mem_reads() as usize;
182
183        debug_assert!(self.inner.len() - header_end >= len);
184
185        // SAFETY:
186        // - The memory is valid assuming num_mem_reads is correct.
187        // - The memory is technically always valid for reads since all bitpatterns are valid for
188        //   `MemValue`.
189        unsafe { MemReads::new(self.inner.as_ptr().add(header_end) as *const MemValue, len) }
190    }
191
192    fn hint_lens(&self) -> &[usize] {
193        &self.hint_lens
194    }
195}
196
/// A cursor-style iterator over the [`MemValue`]s recorded in a trace chunk.
pub struct MemReads<'a> {
    // Cursor: points at the next unread element.
    inner: *const MemValue,
    // One-past-the-end of the underlying buffer; `inner == end` means exhausted.
    end: *const MemValue,
    /// Capture the lifetime of the buffer for safety reasons.
    _phantom: PhantomData<&'a ()>,
}
203
204impl<'a> MemReads<'a> {
205    /// # Safety
206    ///
207    /// - The underlying memory is valid and contains valid `MemValue`s.
208    /// - The length is the number of `MemValue`s in the underlying memory.
209    pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
210        debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
211
212        Self { inner, end: inner.add(len), _phantom: PhantomData }
213    }
214
215    /// Advance the pointer by `n` elements.
216    ///
217    /// # Panics
218    ///
219    /// Panics if `n` is greater than the purported length of the underlying buffer.
220    pub fn advance(&mut self, n: usize) {
221        unsafe {
222            let advanced = self.inner.add(n);
223
224            if advanced > self.end {
225                panic!("Cannot advance by more than the length of the slice");
226            }
227
228            self.inner = advanced;
229        }
230    }
231
232    /// Get the raw pointer to the head of the slice.
233    pub fn head_raw(&self) -> *const MemValue {
234        self.inner
235    }
236
237    /// The remaining length of the slice from our current position.
238    #[must_use]
239    pub fn len(&self) -> usize {
240        unsafe { self.end.offset_from_unsigned(self.inner) }
241    }
242
243    /// Check if the iterator is empty.
244    #[must_use]
245    pub fn is_empty(&self) -> bool {
246        self.inner == self.end
247    }
248}
249
250impl<'a> Iterator for MemReads<'a> {
251    type Item = MemValue;
252
253    fn next(&mut self) -> Option<Self::Item> {
254        if self.inner == self.end {
255            None
256        } else {
257            let value = unsafe { std::ptr::read(self.inner) };
258            self.inner = unsafe { self.inner.add(1) };
259
260            Some(value)
261        }
262    }
263}
264
/// A trace chunk is all the data needed to continue the execution of a program at
/// pc_start/clk_start.
///
/// We transmute this type directly from bytes, and the buffer should be of
/// [TraceChunkHeader] form, plus, a slice of the memory reads.
///
/// When we read this type from the buffer, we will copy the registers, the pc/clk start and end,
/// and take a pointer to the memory reads, by reading the `num_mem_reads` field.
///
/// The fields should be placed in the buffer according to the layout of [TraceChunkHeader].
#[repr(C)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceChunk {
    /// The values of all 32 registers at the start of the chunk.
    pub start_registers: [u64; 32],
    /// The program counter at the start of the chunk.
    pub pc_start: u64,
    /// The clock at the start of the chunk.
    pub clk_start: u64,
    /// The clock at the end of the chunk.
    pub clk_end: u64,
    /// Lengths of the hints consumed during this chunk's execution.
    pub hint_lens: Vec<usize>,
    /// The recorded memory reads; shared so clones are O(1).
    #[serde(serialize_with = "ser::serialize_mem_reads")]
    #[serde(deserialize_with = "ser::deserialize_mem_reads")]
    pub mem_reads: Arc<[MemValue]>,
}
287
288impl From<TraceChunkRaw> for TraceChunk {
289    fn from(raw: TraceChunkRaw) -> Self {
290        TraceChunk::copy_from_bytes(raw.hint_lens, raw.inner.as_ref())
291    }
292}
293
impl TraceChunk {
    /// Copy the bytes into a [TraceChunk]. We don't just back it with the original bytes,
    /// since this type is likely to be sent off to a worker for proving.
    ///
    /// # Note:
    /// This method will panic if the buffer is not large enough,
    /// or the number of reads causes an overflow.
    pub fn copy_from_bytes(hint_lens: Vec<usize>, src: &[u8]) -> Self {
        const HDR: usize = size_of::<TraceChunkHeader>();

        /* ---------- 1. header must fit ---------- */
        if src.len() < HDR {
            panic!("TraceChunk header too small");
        }

        /* ---------- 2. copy-out the header ---------- */
        // SAFETY:
        // We just checked that `src` contains at least `HDR` bytes, and
        // `read_unaligned` places no alignment requirement on the source pointer.
        //
        // Note: All bit patterns are valid for `TraceChunkHeader`.
        let raw: TraceChunkHeader =
            unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };

        /* ---------- 3. tail must fit ---------- */
        // Checked arithmetic so a corrupt/hostile `num_mem_reads` panics instead
        // of wrapping and passing the length check below.
        let n_words = raw.num_mem_reads as usize;
        let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
        let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
        if src.len() < total {
            panic!("TraceChunk tail too small");
        }

        /* ---------- 4. extract tail ---------- */
        let tail = &src[HDR..total]; // only after the length check

        let mem_reads = Arc::new_uninit_slice(n_words);

        // SAFETY:
        // - The tail contains valid u64s, so doing a bitwise copy preserves the validity and
        //   endianness.
        // - tail is likely unaligned, so casting to a u8 pointer gives the alignment guarantee the
        //   compiler needs to do a copy.
        // - `mem_reads` was just allocated to have enough space.
        // - u8 has minimum alignment, so casting the pointer allocated by the vec is valid.
        // - The cast from const -> mut is valid since there are no other references to the memory.
        //
        // This trick is mostly taken from [`std::ptr::read_unaligned`]
        // see: <https://doc.rust-lang.org/src/core/ptr/mod.rs.html#1811>.
        unsafe {
            std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
        };

        Self {
            start_registers: raw.start_registers,
            pc_start: raw.pc_start,
            clk_start: raw.clk_start,
            clk_end: raw.clk_end,
            hint_lens,
            // SAFETY: We know the memory is initialized, so we can assume it.
            mem_reads: unsafe { mem_reads.assume_init() },
        }
    }
}
357
/// A trait that represents a minimal trace.
///
/// A minimal trace is the minimum required information to re-execute from
/// `pc_start` and `clk_start` -> `clk_end`.
///
/// It effectively acts as an oracle for the results of memory read operations.
pub trait MinimalTrace: Clone + Send + Sync + 'static {
    /// The values of all 32 registers at the start of the trace.
    fn start_registers(&self) -> [u64; 32];

    /// The program counter at which re-execution starts.
    fn pc_start(&self) -> u64;

    /// The clock at which re-execution starts.
    fn clk_start(&self) -> u64;

    /// The clock at which re-execution ends.
    fn clk_end(&self) -> u64;

    /// The number of recorded memory reads.
    fn num_mem_reads(&self) -> u64;

    /// An iterator over the recorded memory reads.
    fn mem_reads(&self) -> MemReads<'_>;

    /// The lengths of the hints consumed during this trace.
    fn hint_lens(&self) -> &[usize];
}
379
// For the owned chunk every accessor is a plain field read; only `mem_reads`
// needs care, since it hands out a raw-pointer cursor over the shared slice.
impl MinimalTrace for TraceChunk {
    fn start_registers(&self) -> [u64; 32] {
        self.start_registers
    }

    fn pc_start(&self) -> u64 {
        self.pc_start
    }

    fn clk_start(&self) -> u64 {
        self.clk_start
    }

    fn clk_end(&self) -> u64 {
        self.clk_end
    }

    fn num_mem_reads(&self) -> u64 {
        self.mem_reads.len() as u64
    }

    fn mem_reads(&self) -> MemReads<'_> {
        // SAFETY:
        // - The memory is technically always valid for reads since all bitpatterns are valid for
        //   `MemValue`.
        // - The length comes directly from the slice itself, which we know to be valid.
        // - The returned `MemReads<'_>` borrows `self`, so the `Arc` outlives the cursor.
        unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
    }

    fn hint_lens(&self) -> &[usize] {
        &self.hint_lens
    }
}
413
414mod ser {
415    use super::*;
416    use serde::{Deserializer, Serializer};
417
418    pub fn serialize_mem_reads<S: Serializer>(
419        mem_reads: &Arc<[MemValue]>,
420        serializer: S,
421    ) -> Result<S::Ok, S::Error> {
422        let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
423
424        Vec::serialize(&as_vec, serializer)
425    }
426
427    pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
428        deserializer: D,
429    ) -> Result<Arc<[MemValue]>, D::Error> {
430        let as_vec = Vec::deserialize(deserializer)?;
431
432        Ok(as_vec.into())
433    }
434
435    #[test]
436    #[cfg(test)]
437    fn test_mem_reads() {
438        let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
439        let trace = TraceChunk {
440            start_registers: [5; 32],
441            pc_start: 6,
442            clk_start: 7,
443            clk_end: 8,
444            hint_lens: vec![1, 2, 3],
445            mem_reads,
446        };
447
448        let serialized = bincode::serialize(&trace).unwrap();
449        let deserialized = bincode::deserialize(&serialized).unwrap();
450
451        assert_eq!(trace, deserialized);
452    }
453}