Skip to main content

sp1_jit/
risc.rs

1use crate::shm::ConsumerGuard;
2
3use std::{marker::PhantomData, ops::Deref, sync::Arc};
4
5use memmap2::Mmap;
6use serde::{Deserialize, Serialize};
7
/// The 32 RISC-V integer (`x`) registers, numbered 0..=31.
///
/// `#[repr(u8)]` pins each variant's discriminant to its register index,
/// so `reg as u8` yields the architectural register number directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RiscRegister {
    X0 = 0,
    X1 = 1,
    X2 = 2,
    X3 = 3,
    X4 = 4,
    X5 = 5,
    X6 = 6,
    X7 = 7,
    X8 = 8,
    X9 = 9,
    X10 = 10,
    X11 = 11,
    X12 = 12,
    X13 = 13,
    X14 = 14,
    X15 = 15,
    X16 = 16,
    X17 = 17,
    X18 = 18,
    X19 = 19,
    X20 = 20,
    X21 = 21,
    X22 = 22,
    X23 = 23,
    X24 = 24,
    X25 = 25,
    X26 = 26,
    X27 = 27,
    X28 = 28,
    X29 = 29,
    X30 = 30,
    X31 = 31,
}
44
45impl RiscRegister {
46    pub fn all_registers() -> &'static [RiscRegister] {
47        &[
48            RiscRegister::X0,
49            RiscRegister::X1,
50            RiscRegister::X2,
51            RiscRegister::X3,
52            RiscRegister::X4,
53            RiscRegister::X5,
54            RiscRegister::X6,
55            RiscRegister::X7,
56            RiscRegister::X8,
57            RiscRegister::X9,
58            RiscRegister::X10,
59            RiscRegister::X11,
60            RiscRegister::X12,
61            RiscRegister::X13,
62            RiscRegister::X14,
63            RiscRegister::X15,
64            RiscRegister::X16,
65            RiscRegister::X17,
66            RiscRegister::X18,
67            RiscRegister::X19,
68            RiscRegister::X20,
69            RiscRegister::X21,
70            RiscRegister::X22,
71            RiscRegister::X23,
72            RiscRegister::X24,
73            RiscRegister::X25,
74            RiscRegister::X26,
75            RiscRegister::X27,
76            RiscRegister::X28,
77            RiscRegister::X29,
78            RiscRegister::X30,
79            RiscRegister::X31,
80        ]
81    }
82}
83
/// ALU operations can either have register or immediate operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiscOperand {
    // Operand taken from a register.
    Register(RiscRegister),
    // Signed 32-bit immediate operand.
    Immediate(i32),
}
90
91impl From<RiscRegister> for RiscOperand {
92    fn from(reg: RiscRegister) -> Self {
93        RiscOperand::Register(reg)
94    }
95}
96
97impl From<u32> for RiscOperand {
98    fn from(imm: u32) -> Self {
99        RiscOperand::Immediate(imm as i32)
100    }
101}
102
103impl From<i32> for RiscOperand {
104    fn from(imm: i32) -> Self {
105        RiscOperand::Immediate(imm)
106    }
107}
108
impl From<u64> for RiscOperand {
    // NOTE(review): `as i32` silently keeps only the low 32 bits of `imm`;
    // callers presumably only pass values that fit in 32 bits — TODO confirm
    // at the call sites before relying on this for larger values.
    fn from(imm: u64) -> Self {
        RiscOperand::Immediate(imm as i32)
    }
}
114
/// A single recorded memory read: the clock at which it occurred and the
/// value observed.
///
/// `#[repr(C)]` fixes the layout so values can be read straight out of a raw
/// trace buffer (see [TraceChunk::copy_from_bytes] and [MemReads]).
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemValue {
    // Clock tick of the read.
    pub clk: u64,
    // Value read at that tick.
    pub value: u64,
}
121
/// A convenience structure for getting offsets of fields in the actual [TraceChunk].
///
/// `#[repr(C)]` fixes the field order so `offset_of!` on these fields matches
/// the byte layout of a serialized trace-chunk buffer.
#[repr(C)]
pub struct TraceChunkHeader {
    pub start_registers: [u64; 32],
    pub pc_start: u64,
    pub clk_start: u64,
    pub clk_end: u64,
    pub num_mem_reads: u64,
    pub global_clk_end: u64,
    // Keeps size_of::<TraceChunkHeader>() a multiple of 16 (checked by the
    // test at the bottom of this file), so the MemValue payload that follows
    // the header in a buffer preserves the buffer's 16-byte granularity.
    // (Note: the struct's *alignment* is that of u64; only the size is padded.)
    _padding: u64,
}
134
/// A raw, zero-copy view over a trace-chunk buffer.
///
/// The bytes are expected to be a [`TraceChunkHeader`] followed by
/// `num_mem_reads` [`MemValue`]s; see the safety notes on the constructors.
#[derive(Clone)]
pub enum TraceChunkRaw {
    // Backed by a memory-mapped file.
    Mmap(Arc<Mmap>),
    // Backed by a shared-memory consumer guard.
    Shm(Arc<ConsumerGuard>),
}
140
141impl TraceChunkRaw {
142    /// # Safety
143    ///
144    /// - The mmap must be a valid [`TraceChunkHeader`].
145    /// - The mmap must contain valid [`MemValue`]s in after the header.
146    /// - The `num_mem_reads` must be the number of [`MemValue`]s in the mmap after the header.
147    pub unsafe fn new(inner: Mmap) -> Self {
148        Self::Mmap(Arc::new(inner))
149    }
150
151    /// # Safety
152    ///
153    /// See TraceChunkRaw::new for similar safety requirements
154    pub unsafe fn from_shm(inner: ConsumerGuard) -> Self {
155        Self::Shm(Arc::new(inner))
156    }
157
158    fn as_ref(&self) -> &[u8] {
159        match self {
160            Self::Mmap(mmap) => mmap.as_ref(),
161            Self::Shm(shm) => shm.deref(),
162        }
163    }
164
165    fn as_ptr(&self) -> *const u8 {
166        match self {
167            Self::Mmap(mmap) => mmap.as_ptr(),
168            Self::Shm(shm) => shm.deref().as_ptr(),
169        }
170    }
171
172    fn len(&self) -> usize {
173        match self {
174            Self::Mmap(mmap) => mmap.len(),
175            Self::Shm(shm) => shm.deref().len(),
176        }
177    }
178
179    /// Fetching global_clk when trace ends.
180    /// For now, only native executor requires this data. So we implemented
181    /// it as a method of TraceChunkRaw, not as a trait method of MinimalTrace.
182    /// Adding it to MinimalTrace would complicate SplicingVM, while SplicingVM
183    /// does not really need global_clk now.
184    pub fn global_clk_end(&self) -> u64 {
185        let offset = std::mem::offset_of!(TraceChunkHeader, global_clk_end);
186
187        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
188    }
189}
190
191impl MinimalTrace for TraceChunkRaw {
192    fn start_registers(&self) -> [u64; 32] {
193        let offset = std::mem::offset_of!(TraceChunkHeader, start_registers);
194
195        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const [u64; 32]) }
196    }
197
198    fn pc_start(&self) -> u64 {
199        let offset = std::mem::offset_of!(TraceChunkHeader, pc_start);
200
201        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
202    }
203
204    fn clk_start(&self) -> u64 {
205        let offset = std::mem::offset_of!(TraceChunkHeader, clk_start);
206
207        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
208    }
209
210    fn clk_end(&self) -> u64 {
211        let offset = std::mem::offset_of!(TraceChunkHeader, clk_end);
212
213        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
214    }
215
216    fn num_mem_reads(&self) -> u64 {
217        let offset = std::mem::offset_of!(TraceChunkHeader, num_mem_reads);
218
219        unsafe { std::ptr::read_unaligned(self.as_ptr().add(offset) as *const u64) }
220    }
221
222    fn mem_reads(&self) -> MemReads<'_> {
223        let header_end = std::mem::size_of::<TraceChunkHeader>();
224        let len = self.num_mem_reads() as usize;
225
226        debug_assert!(self.len() - header_end >= len);
227
228        // SAFETY:
229        // - The memory is valid assuming num_mem_reads is correct.
230        // - The memory is technically always valid for reads since all bitpatterns are valid for
231        //   `MemValue`.
232        unsafe { MemReads::new(self.as_ptr().add(header_end) as *const MemValue, len) }
233    }
234}
235
/// A cursor/iterator over the `MemValue`s recorded in a trace buffer.
///
/// Holds raw begin/end pointers into memory owned elsewhere; the lifetime
/// parameter ties the cursor to that owner.
pub struct MemReads<'a> {
    // Current read position.
    inner: *const MemValue,
    // One past the last element.
    end: *const MemValue,
    /// Capture the lifetime of the buffer for safety reasons.
    _phantom: PhantomData<&'a ()>,
}
242
243impl<'a> MemReads<'a> {
244    /// # Safety
245    ///
246    /// - The underlying memory is valid and contains valid `MemValue`s.
247    /// - The length is the number of `MemValue`s in the underlying memory.
248    pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
249        debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
250
251        Self { inner, end: inner.add(len), _phantom: PhantomData }
252    }
253
254    /// Advance the pointer by `n` elements.
255    ///
256    /// # Panics
257    ///
258    /// Panics if `n` is greater than the purported length of the underlying buffer.
259    pub fn advance(&mut self, n: usize) {
260        unsafe {
261            let advanced = self.inner.add(n);
262
263            if advanced > self.end {
264                panic!("Cannot advance by more than the length of the slice");
265            }
266
267            self.inner = advanced;
268        }
269    }
270
271    /// Get the raw pointer to the head of the slice.
272    pub fn head_raw(&self) -> *const MemValue {
273        self.inner
274    }
275
276    /// The remaining length of the slice from our current position.
277    #[must_use]
278    pub fn len(&self) -> usize {
279        unsafe { self.end.offset_from_unsigned(self.inner) }
280    }
281
282    /// Check if the iterator is empty.
283    #[must_use]
284    pub fn is_empty(&self) -> bool {
285        self.inner == self.end
286    }
287}
288
289impl<'a> Iterator for MemReads<'a> {
290    type Item = MemValue;
291
292    fn next(&mut self) -> Option<Self::Item> {
293        if self.inner == self.end {
294            None
295        } else {
296            let value = unsafe { std::ptr::read(self.inner) };
297            self.inner = unsafe { self.inner.add(1) };
298
299            Some(value)
300        }
301    }
302}
303
/// A trace chunk is all the data needed to continue the execution of a program at
/// pc_start/clk_start.
///
/// We transmute this type directly from bytes, and the buffer should be of
/// [TraceChunkHeader] form, plus, a slice of the memory reads.
///
/// When we read this type from the buffer, we will copy the registers, the pc/clk start and end,
/// and take a pointer to the memory reads, by reading the num_mem_reads field.
///
/// The fields should be placed in the buffer according to the layout of [TraceChunkHeader].
#[repr(C)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceChunk {
    pub start_registers: [u64; 32],
    pub pc_start: u64,
    pub clk_start: u64,
    pub clk_end: u64,
    // Arc<[_]> makes cloning a chunk O(1) (no copy of the reads); serde goes
    // through the custom helpers in the `ser` module below.
    #[serde(serialize_with = "ser::serialize_mem_reads")]
    #[serde(deserialize_with = "ser::deserialize_mem_reads")]
    pub mem_reads: Arc<[MemValue]>,
}
325
326impl From<TraceChunkRaw> for TraceChunk {
327    fn from(raw: TraceChunkRaw) -> Self {
328        TraceChunk::copy_from_bytes(raw.as_ref())
329    }
330}
331
impl TraceChunk {
    /// Copy the bytes into a [TraceChunk]. We don't just back it with the original bytes,
    /// since this type is likely to be sent off to a worker for proving.
    ///
    /// # Note:
    /// This method will panic if the buffer is not large enough,
    /// or the number of reads causes an overflow.
    pub fn copy_from_bytes(src: &[u8]) -> Self {
        const HDR: usize = size_of::<TraceChunkHeader>();

        /* ---------- 1. header must fit ---------- */
        if src.len() < HDR {
            panic!("TraceChunk header too small");
        }

        /* ---------- 2. copy-out the header ---------- */
        // SAFETY:
        // we just checked that `src` contains at least `HDR` bytes,
        // and `read_unaligned` imposes no alignment requirement on the source.
        //
        // Note: All bit patterns are valid for `TraceChunkHeader`.
        let raw: TraceChunkHeader =
            unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };

        /* ---------- 3. tail must fit ---------- */
        // Checked arithmetic so a hostile/corrupt num_mem_reads can't wrap the
        // size computation and defeat the bounds check below.
        let n_words = raw.num_mem_reads as usize;
        let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
        let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
        if src.len() < total {
            panic!("TraceChunk tail too small");
        }

        /* ---------- 4. extract tail ---------- */
        let tail = &src[HDR..total]; // only after the length check

        let mem_reads = Arc::new_uninit_slice(n_words);

        // SAFETY:
        // - The tail contains valid u64s, so doing a bitwise copy preserves the validity and
        //   endianness.
        // - tail is likely unaligned, so casting to a u8 pointer gives the alignment guarantee the
        //   compiler needs to do a copy.
        // - `mem_reads` was just allocated to have enough space.
        // - u8 has minimum alignment, so casting the pointer allocated by the vec is valid.
        // - The cast from const -> mut is valid since there are no other references to the memory.
        //
        // This trick is mostly taken from [`std::ptr::read_unaligned`]
        // see: <https://doc.rust-lang.org/src/core/ptr/mod.rs.html#1811>.
        unsafe {
            std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
        };

        Self {
            start_registers: raw.start_registers,
            pc_start: raw.pc_start,
            clk_start: raw.clk_start,
            clk_end: raw.clk_end,
            // SAFETY: We know the memory is initialized, so we can assume it.
            mem_reads: unsafe { mem_reads.assume_init() },
        }
    }
}
394
/// A trait that represents a minimal trace.
///
/// A minimal trace is the minimum required information to re-execute from
/// `pc_start` and `clk_start` -> `clk_end`.
///
/// It effectively acts as an oracle for the results of memory read operations.
pub trait MinimalTrace: Clone + Send + Sync + 'static {
    /// Register file contents at the start of the chunk.
    fn start_registers(&self) -> [u64; 32];

    /// Program counter at the start of the chunk.
    fn pc_start(&self) -> u64;

    /// Clock value at the start of the chunk.
    fn clk_start(&self) -> u64;

    /// Clock value at which the chunk ends.
    fn clk_end(&self) -> u64;

    /// Number of recorded memory reads in this chunk.
    fn num_mem_reads(&self) -> u64;

    /// Cursor over the recorded memory reads.
    fn mem_reads(&self) -> MemReads<'_>;
}
414
415impl MinimalTrace for TraceChunk {
416    fn start_registers(&self) -> [u64; 32] {
417        self.start_registers
418    }
419
420    fn pc_start(&self) -> u64 {
421        self.pc_start
422    }
423
424    fn clk_start(&self) -> u64 {
425        self.clk_start
426    }
427
428    fn clk_end(&self) -> u64 {
429        self.clk_end
430    }
431
432    fn num_mem_reads(&self) -> u64 {
433        self.mem_reads.len() as u64
434    }
435
436    fn mem_reads(&self) -> MemReads<'_> {
437        // SAFETY:
438        // - The memory is technically always valid for reads since all bitpatterns are valid for
439        //   `MemValue`.
440        // - the length comes directly from the Vec, which we know to be valid.
441        unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
442    }
443}
444
445mod ser {
446    use super::*;
447    use serde::{Deserializer, Serializer};
448
449    pub fn serialize_mem_reads<S: Serializer>(
450        mem_reads: &Arc<[MemValue]>,
451        serializer: S,
452    ) -> Result<S::Ok, S::Error> {
453        let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
454
455        Vec::serialize(&as_vec, serializer)
456    }
457
458    pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
459        deserializer: D,
460    ) -> Result<Arc<[MemValue]>, D::Error> {
461        let as_vec = Vec::deserialize(deserializer)?;
462
463        Ok(as_vec.into())
464    }
465
466    #[test]
467    #[cfg(test)]
468    fn test_mem_reads() {
469        let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
470        let trace = TraceChunk {
471            start_registers: [5; 32],
472            pc_start: 6,
473            clk_start: 7,
474            clk_end: 8,
475            mem_reads,
476        };
477
478        let serialized = bincode::serialize(&trace).unwrap();
479        let deserialized = bincode::deserialize(&serialized).unwrap();
480
481        assert_eq!(trace, deserialized);
482    }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    // TraceChunkHeader's *size* must be a multiple of 16 bytes (this asserts
    // size, not the type's alignment), matching the intent of the `_padding`
    // field so the payload after the header keeps 16-byte granularity.
    #[test]
    fn test_trace_chunk_header_alignment() {
        assert_eq!(std::mem::size_of::<TraceChunkHeader>() % 16, 0);
    }
}
494}