sp1-jit 6.0.2

JIT compilation for SP1 trace generation
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
#![cfg_attr(not(target_os = "linux"), allow(unused))]

#[cfg(not(target_endian = "little"))]
compile_error!("This crate is only supported on little endian targets.");

pub mod backends;
pub mod context;
pub mod debug;
pub mod instructions;
mod macros;
pub mod risc;

use dynasmrt::ExecutableBuffer;
use hashbrown::HashMap;
use memmap2::{MmapMut, MmapOptions};
use std::{
    collections::VecDeque,
    io,
    os::fd::AsRawFd,
    ptr::NonNull,
    sync::{mpsc, Arc},
};

pub use backends::*;
pub use context::*;
pub use instructions::*;
pub use risc::*;

/// An `extern "C"` callback that receives a raw pointer to the [`JitContext`].
///
/// This is the handler type accepted by [`RiscvTranspiler::call_extern_fn`].
pub type ExternFn = extern "C" fn(*mut JitContext);

/// An `extern "C"` callback that receives a raw pointer to the [`JitContext`] and returns a
/// `u64`.
///
/// This is the handler type accepted by [`RiscvTranspiler::register_ecall_handler`].
// NOTE(review): the meaning of the returned `u64` is not visible in this file — confirm against
// the backends that invoke the handler.
pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;

/// A debugging callback that receives a single `u64`.
///
/// This is the handler type accepted by [`RiscvTranspiler::inspect_register`] and
/// [`RiscvTranspiler::inspect_immediate`].
pub type DebugFn = extern "C" fn(u64);

/// A transpiler for RISC-V instructions.
///
/// This trait is implemented for each target architecture supported by the JIT transpiler.
///
/// The transpiler is responsible for translating the RISC-V instructions into the target
/// architecture's instruction set.
///
/// This transpiler should generate an entrypoint of the form: `fn(*mut JitContext)`.
///
/// For each instruction, you will typically want to call [`RiscvTranspiler::start_instr`]
/// before transpiling the instruction. This maps a "riscv instruction index" to some physical
/// native address, as there are multiple native instructions per riscv instruction.
/// Once the instruction has been emitted, call [`RiscvTranspiler::end_instr`].
///
/// You will also likely want to call `bump_clk` to increment the clock counter, and `set_pc` to
/// set the PC.
///
/// # Note
/// Some instructions will directly modify the PC, such as `jal` and `jalr`, and all the branch
/// instructions. For these instructions, you would not want to call `set_pc` as it will be called
/// for you.
///
/// (`bump_clk`, `set_pc`, `jal` and `jalr` are not declared on this trait; presumably they come
/// from the instruction supertraits — verify against their definitions.)
///
/// # Examples
///
/// ```rust,no_run,ignore
/// pub fn add_program() {
///     // `SP1RiscvTranspiler` stands in for any concrete `RiscvTranspiler` implementation.
///     let mut transpiler = SP1RiscvTranspiler::new(
///         program_size,
///         memory_size,
///         max_trace_size,
///         pc_start,
///         pc_base,
///         clk_bump,
///     )
///     .unwrap();
///
///     // Transpile the first instruction.
///     transpiler.start_instr();
///     transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
///     transpiler.end_instr();
///
///     // Transpile the second instruction.
///     transpiler.start_instr();
///
///     transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
///     transpiler.end_instr();
///
///     let mut func = transpiler.finalize().unwrap();
///
///     // Call the function (`JitFunction::call` is an `unsafe fn`).
///     let traces = unsafe { func.call() };
///
///     // do stuff with the traces.
/// }
/// ```
pub trait RiscvTranspiler:
    TraceCollector
    + ComputeInstructions
    + ControlFlowInstructions
    + MemoryInstructions
    + SystemInstructions
    + Sized
{
    /// Create a new transpiler.
    ///
    /// The program size is used for the jump-table and is not a hard limit on the size of the
    /// program. The memory size is the exact amount that will be allocated for the program.
    ///
    /// # Errors
    ///
    /// Returns a [`std::io::Error`] if the implementation fails to set itself up.
    // NOTE(review): the exact semantics of `max_trace_size`, `pc_start`, `pc_base` and `clk_bump`
    // are defined by the implementors and are not visible here — document them once confirmed.
    fn new(
        program_size: usize,
        memory_size: usize,
        max_trace_size: u64,
        pc_start: u64,
        pc_base: u64,
        clk_bump: u64,
    ) -> Result<Self, std::io::Error>;

    /// Register a rust function of the form [`EcallHandler`] that will be used as the ECALL.
    fn register_ecall_handler(&mut self, handler: EcallHandler);

    /// Populates a jump table entry for the current instruction being transpiled.
    ///
    /// Effectively should create a mapping from RISCV PC -> absolute address of the instruction.
    ///
    /// This method should be called for "each pc" in the program.
    fn start_instr(&mut self);

    /// Handles the logic needed when an instruction finishes executing, such as bumping the clk
    /// and jumping to the branch destination.
    ///
    /// This method should be called for "each pc" in the program.
    fn end_instr(&mut self);

    /// Inspect a [`RiscRegister`] using a function pointer.
    ///
    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
    fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);

    /// Inspect an immediate value using a function pointer.
    ///
    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
    fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);

    /// Call an [`ExternFn`] from the outputted assembly.
    ///
    /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
    fn call_extern_fn(&mut self, handler: ExternFn);

    /// Consumes the transpiler and returns the compiled code wrapped in a [`JitFunction`].
    ///
    /// The generated entrypoint is expected to be of the form: `fn(*mut JitContext)`.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] if the generated code cannot be turned into a [`JitFunction`].
    fn finalize(self) -> io::Result<JitFunction>;
}

/// A trait that collects traces, in the form of a [`TraceChunk`].
///
/// Implementors are expected to follow the conventions described in the [`TraceChunk`]
/// documentation.
pub trait TraceCollector {
    /// Write the current state of the registers into the trace buf.
    ///
    /// For SP1 this is only called once in the beginning of a "chunk".
    fn trace_registers(&mut self);

    /// Write the value located at `rs1 + imm` into the trace buf.
    fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);

    /// Write the start pc of the trace chunk.
    fn trace_pc_start(&mut self);

    /// Write the start clk of the trace chunk.
    fn trace_clk_start(&mut self);

    /// Write the end clk of the trace chunk.
    fn trace_clk_end(&mut self);
}

/// Debugging helpers for transpilers.
///
/// A blanket implementation is provided for every [`RiscvTranspiler`].
pub trait Debuggable {
    /// Emit code that prints the JIT context (pc, clk and registers) when it is executed.
    fn print_ctx(&mut self);
}

impl<T: RiscvTranspiler> Debuggable for T {
    /// Emits a call to a host-side helper that writes `pc` (hex), `clk` and the registers to
    /// stderr.
    ///
    /// Intended purely as a debugging aid.
    fn print_ctx(&mut self) {
        /// Host-side callback invoked from the generated code.
        extern "C" fn dump_to_stderr(raw_ctx: *mut JitContext) {
            // SAFETY: NOTE(review): assumes the generated code hands over a valid, exclusive
            // pointer to the live `JitContext` — confirm against the backends.
            let context = unsafe { &mut *raw_ctx };

            eprintln!("pc: {:x}", context.pc);
            eprintln!("clk: {}", context.clk);
            eprintln!("{:?}", *context.registers());
        }

        self.call_extern_fn(dump_to_stderr);
    }
}

#[cfg(not(target_os = "linux"))]
/// Empty stand-in for `JitFunction` so that the crate still compiles on non-linux targets.
///
/// The real type and its methods are only defined when `target_os = "linux"`.
pub struct JitFunction {}

/// A type representing a JIT compiled function.
///
/// The underlying function should be of the form `fn(*mut JitContext)`.
///
/// Besides the generated code, this type owns the guest memory mapping and the execution state
/// (`pc`, `registers`, `clk`, ...) that is copied into the [`JitContext`] before each `call` and
/// copied back afterwards, so execution can be resumed across calls.
#[cfg(target_os = "linux")]
pub struct JitFunction {
    /// Absolute addresses into `code`; built in `new` by offsetting the buffer's base pointer.
    jump_table: Vec<*const u8>,
    /// Size of the trace buffer mapped for every `call`. Zero disables tracing.
    trace_buf_size: usize,
    /// The generated machine code. `call` enters at the start of this buffer.
    code: ExecutableBuffer,

    /// The initial memory image.
    initial_memory_image: Arc<HashMap<u64, u64>>,
    /// The PC the function starts at; `reset` restores `pc` to this value.
    pc_start: u64,
    /// Inputs supplied via `push_input` / `set_input_buffer`; shared with the JIT context by
    /// pointer during `call`.
    input_buffer: VecDeque<Vec<u8>>,

    /// A stream of public values from the program (global to entire program).
    pub public_values_stream: Vec<u8>,

    /// Keep around the memfd, and pass it to the JIT context,
    /// we can use this to create the COW memory at runtime.
    mem_fd: memfd::Memfd,

    /// During execution, the hints are read by the program, and we store them here.
    /// This is effectively a mapping from start address to the value of the hint.
    pub hints: Vec<(u64, Vec<u8>)>,

    /// The JIT function may stop "in the middle" of a program,
    /// we want to be able to resume it, so this is the information needed to do so.
    ///
    /// The guest memory, mapped from `mem_fd`.
    pub memory: MmapMut,
    /// The current program counter.
    pub pc: u64,
    /// The current register file.
    pub registers: [u64; 32],
    /// The current clock; starts at 1.
    pub clk: u64,
    /// The current global clock; starts at 0.
    pub global_clk: u64,
    /// The exit code as last reported by the JIT context; starts at 0.
    pub exit_code: u32,

    /// Optional channel for debug states; a clone is handed to the JIT context on every `call`.
    pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
}

// NOTE(review): on linux `JitFunction` stores raw pointers (`jump_table: Vec<*const u8>`), which
// are not `Send` by default; this impl asserts that moving the value to another thread is sound.
// No justification is recorded in this file — confirm that nothing it owns is thread-affine.
unsafe impl Send for JitFunction {}

#[cfg(target_os = "linux")]
impl JitFunction {
    /// Builds a [`JitFunction`] from finalized machine code and its jump table.
    ///
    /// `jump_table` holds offsets into `code`; they are converted into absolute addresses here.
    /// The guest memory is backed by a freshly created memfd of `memory_size` plus
    /// `align_of::<u64>()` bytes (the slack leaves room for `call` to align the memory pointer),
    /// which is then mapped read-write.
    ///
    /// The function starts at `pc_start` with `clk == 1`, `global_clk == 0`, zeroed registers, an
    /// empty initial memory image, no inputs and no debug sender.
    ///
    /// # Errors
    ///
    /// Returns an error if the memfd cannot be created or resized, or if mapping it fails.
    pub(crate) fn new(
        code: ExecutableBuffer,
        jump_table: Vec<usize>,
        memory_size: usize,
        trace_buf_size: usize,
        pc_start: u64,
    ) -> std::io::Result<Self> {
        // Adjust the jump table to be absolute addresses.
        let buf_ptr = code.as_ptr();
        // SAFETY: NOTE(review): relies on every offset produced by the transpiler lying within
        // `code` — confirm against the backends.
        let jump_table =
            jump_table.into_iter().map(|offset| unsafe { buf_ptr.add(offset) }).collect();

        // Creating the memfd is a syscall that can legitimately fail (e.g. fd exhaustion), so
        // report it through the `io::Result` instead of panicking.
        let fd = memfd::MemfdOptions::default()
            .create(uuid::Uuid::new_v4().to_string())
            .map_err(|err| {
                std::io::Error::other(format!("Failed to create jit memory: {err}"))
            })?;

        fd.as_file().set_len((memory_size + std::mem::align_of::<u64>()) as u64)?;

        Ok(Self {
            jump_table,
            code,
            memory: unsafe { MmapOptions::new().no_reserve_swap().map_mut(fd.as_file())? },
            mem_fd: fd,
            trace_buf_size,
            pc: pc_start,
            clk: 1,
            global_clk: 0,
            registers: [0; 32],
            initial_memory_image: Arc::new(HashMap::new()),
            pc_start,
            input_buffer: VecDeque::new(),
            hints: Vec::new(),
            public_values_stream: Vec::new(),
            debug_sender: None,
            exit_code: 0,
        })
    }

    /// Installs `memory` as the initial memory image and copies it into the JIT memory.
    ///
    /// # Panics
    ///
    /// Panics if the current PC differs from the starting PC.
    pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
        let untouched = self.pc == self.pc_start;
        assert!(
            untouched,
            "The initial memory should only be supplied before using the JIT function."
        );

        // Remember the image, then copy it into the mapped memory.
        self.initial_memory_image = memory;
        self.insert_memory_image();
    }

    /// Appends a single input to the back of the input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the current PC differs from the starting PC.
    pub fn push_input(&mut self, input: Vec<u8>) {
        let at_start = self.pc == self.pc_start;
        assert!(
            at_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        // Make room for one more hint entry alongside the new input.
        self.hints.reserve(1);
        self.input_buffer.push_back(input);
    }

    /// Replaces the whole input buffer with `input`.
    ///
    /// # Panics
    ///
    /// Panics if the current PC differs from the starting PC.
    pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
        let at_start = self.pc == self.pc_start;
        assert!(
            at_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        let incoming = input.len();
        self.input_buffer = input;

        // Reserve one hint slot per supplied input.
        self.hints.reserve(incoming);
    }

    /// Call the function, resuming at the current `pc`, and return the trace buffer it filled.
    ///
    /// Returns `None` without executing anything if `pc` is `1`, which this method treats as the
    /// "program has completed" marker. Also returns `None` when tracing is disabled
    /// (`trace_buf_size == 0`); the code is still executed in that case.
    ///
    /// After the generated code returns, `pc`, `registers`, `clk`, `global_clk` and `exit_code`
    /// are copied back from the context so that a later call resumes where this one stopped.
    ///
    /// # Safety
    ///
    /// Relies on the builder to emit valid assembly
    /// and that the pointer is valid for the duration of the function call.
    ///
    /// # Panics
    ///
    /// Panics if the trace buffer cannot be mapped, or cannot be made read-only afterwards.
    pub unsafe fn call(&mut self) -> Option<TraceChunkRaw> {
        if self.pc == 1 {
            return None;
        }

        // NOTE(review): the entry point is transmuted to a Rust-ABI `fn` pointer; if the backends
        // emit a C-ABI entry point, `extern "C" fn(*mut JitContext)` would be the precise type.
        let as_fn = std::mem::transmute::<*const u8, fn(*mut JitContext)>(self.code.as_ptr());

        // Ensure the pointer is aligned to the alignment of the MemValue.
        let mut trace_buf =
            MmapMut::map_anon(self.trace_buf_size + std::mem::align_of::<MemValue>())
                .expect("Failed to create trace buf mmap");
        let trace_buf_offset = trace_buf.as_ptr().align_offset(std::mem::align_of::<MemValue>());
        let trace_buf_ptr = trace_buf.as_mut_ptr().add(trace_buf_offset);

        // Ensure the memory pointer is aligned to the alignment of the u64.
        let align_offset = self.memory.as_ptr().align_offset(std::mem::align_of::<u64>());
        let mem_ptr = self.memory.as_mut_ptr().add(align_offset);
        let tracing = self.trace_buf_size > 0;

        // SAFETY:
        // - The jump table is valid for the duration of the function call, its owned by self.
        // - The memory is valid for the duration of the function call, its owned by self.
        // - The trace buf is valid for the duration of the function call, we just allocated it
        // - The input buffer is valid for the duration of the function call, its owned by self.
        let mut ctx = JitContext {
            jump_table: NonNull::new_unchecked(self.jump_table.as_mut_ptr()),
            memory: NonNull::new_unchecked(mem_ptr),
            trace_buf: NonNull::new_unchecked(trace_buf_ptr),
            input_buffer: NonNull::new_unchecked(&mut self.input_buffer),
            hints: NonNull::new_unchecked(&mut self.hints),
            maybe_unconstrained: None,
            public_values_stream: NonNull::new_unchecked(&mut self.public_values_stream),
            memory_fd: self.mem_fd.as_raw_fd(),
            registers: self.registers,
            pc: self.pc,
            clk: self.clk,
            global_clk: self.global_clk,
            is_unconstrained: 0,
            tracing,
            debug_sender: self.debug_sender.clone(),
            exit_code: self.exit_code,
        };

        tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
            as_fn(&mut ctx);
        });

        // Update the values we want to preserve.
        self.pc = ctx.pc;
        self.registers = ctx.registers;
        self.clk = ctx.clk;
        self.global_clk = ctx.global_clk;
        self.exit_code = ctx.exit_code;

        // Build the chunk lazily: when tracing is off there is nothing to hand out, so skip
        // re-protecting the buffer and constructing a chunk that would be dropped immediately.
        tracing.then(|| {
            TraceChunkRaw::new(
                trace_buf.make_read_only().expect("Failed to make trace buf read only"),
            )
        })
    }

    /// Reset the JIT function to the initial state.
    ///
    /// This will clear the registers, the program counter, the clocks, the exit code, the input
    /// buffer, the hints, the public values stream and the memory, restoring the initial memory
    /// image. The memory is replaced by a mapping of a brand new memfd of the same length.
    ///
    /// The `debug_sender` is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the new memfd cannot be created, resized or mapped.
    pub fn reset(&mut self) {
        self.pc = self.pc_start;
        self.registers = [0; 32];
        self.clk = 1;
        self.global_clk = 0;
        // `new` starts with an exit code of 0 and `call` feeds the stored value back into the
        // context, so a stale exit code from a previous run must not survive a reset.
        self.exit_code = 0;
        self.input_buffer = VecDeque::new();
        self.hints = Vec::new();
        self.public_values_stream = Vec::new();

        // Store the original size of the memory.
        let memory_size = self.memory.len();

        // Create a new memfd for the backing memory.
        self.mem_fd = memfd::MemfdOptions::default()
            .create(uuid::Uuid::new_v4().to_string())
            .expect("Failed to create jit memory");

        self.mem_fd
            .as_file()
            .set_len(memory_size as u64)
            .expect("Failed to set memfd size for backing memory.");

        self.memory = unsafe {
            MmapOptions::new()
                .no_reserve_swap()
                .map_mut(self.mem_fd.as_file())
                .expect("Failed to map memory")
        };

        self.insert_memory_image();
    }

    /// Copies every `(address, value)` pair of the initial memory image into the mapped memory.
    ///
    /// The value is stored as little-endian bytes at byte offset `2 * addr + 8` of the mapping.
    /// (Presumably this is the value half of a 16-byte `MemValue` entry per 8-byte word; the
    /// layout is defined elsewhere — verify against `MemValue`.)
    ///
    /// # Panics
    ///
    /// Panics if the destination of an entry does not lie within the mapped memory. In debug
    /// builds, also panics if an address is not aligned to 8.
    fn insert_memory_image(&mut self) {
        for (&addr, &val) in self.initial_memory_image.iter() {
            debug_assert!(addr % 8 == 0, "Address {addr} is not aligned to 8");

            // Technically, this crate is probably only used on little endian targets, but just
            // to be sure.
            let bytes = val.to_le_bytes();

            // Checked arithmetic: a wrapped offset must never be turned into a write location.
            let range = addr
                .checked_mul(2)
                .and_then(|entry| entry.checked_add(8))
                .and_then(|start| usize::try_from(start).ok())
                .and_then(|start| start.checked_add(bytes.len()).map(|end| start..end));

            // Bounds-checked view of the destination; replaces an unchecked raw-pointer copy.
            let slot = match range {
                Some(range) => self.memory.get_mut(range),
                None => None,
            };

            match slot {
                Some(slot) => slot.copy_from_slice(&bytes),
                None => panic!("Address {addr} does not fit in the JIT memory"),
            }
        }
    }
}

/// A read-only view over a borrowed memory mapping, used to look up [`MemValue`] entries by
/// address.
pub struct MemoryView<'a> {
    /// The borrowed mapping that entries are read from.
    pub memory: &'a MmapMut,
}

impl<'a> MemoryView<'a> {
    /// Wraps a borrowed memory mapping.
    pub const fn new(memory: &'a MmapMut) -> Self {
        Self { memory }
    }

    /// Read a word from the memory at the address.
    ///
    /// The mapping is treated as an array of [`MemValue`] entries indexed by `addr / 8`.
    ///
    /// # Panics
    ///
    /// Panics if the address is not aligned to 8 bytes, or if the entry for the address does not
    /// lie entirely within the mapped memory.
    pub fn get(&self, addr: u64) -> MemValue {
        assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");

        // One-past-the-end byte offset of the entry, computed without overflow.
        let entry_end = usize::try_from(addr / 8)
            .ok()
            .and_then(|index| index.checked_add(1))
            .and_then(|next| next.checked_mul(std::mem::size_of::<MemValue>()));
        let in_bounds = entry_end.is_some_and(|end| end <= self.memory.len());
        assert!(in_bounds, "Address {addr} is outside of the mapped memory");

        // The check above guarantees the index fits in `usize`.
        let word_address = (addr / 8) as usize;
        let entry_ptr = self.memory.as_ptr().cast::<MemValue>();

        // SAFETY: the entry was just checked to lie within the mapping.
        // NOTE(review): also relies on the mapping base being aligned for `MemValue` — confirm.
        unsafe { std::ptr::read(entry_ptr.add(word_address)) }
    }
}