sp1_jit/lib.rs
// Without the native executor, the `JitFunction` implementations in this file are compiled out
// (only a stub struct remains), so silence the resulting unused-item warnings in that config.
#![cfg_attr(not(sp1_native_executor_available), allow(unused))]

// The initial memory image is written with `to_le_bytes`; NOTE(review): the generated native
// code presumably also assumes the host byte order is little endian, so reject other targets.
#[cfg(not(target_endian = "little"))]
compile_error!("This crate is only supported on little endian targets.");
5
6pub mod backends;
7pub mod context;
8pub mod debug;
9pub mod instructions;
10mod macros;
11pub mod memory;
12pub mod risc;
13pub mod shm;
14
15use dynasmrt::ExecutableBuffer;
16use hashbrown::HashMap;
17use std::{
18 collections::VecDeque,
19 io,
20 ops::{Deref, DerefMut},
21 os::fd::AsRawFd,
22 ptr::NonNull,
23 sync::{mpsc, Arc},
24};
25
26pub use backends::*;
27pub use context::*;
28pub use instructions::*;
29pub use risc::*;
30
31/// A function that accepts the memory pointer.
32pub type ExternFn = extern "C" fn(*mut JitContext);
33
34pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;
35
36/// A debugging utility to inspect registers
37pub type DebugFn = extern "C" fn(u64);
38
39/// A transpiler for risc32 instructions.
40///
41/// This trait is implemented for each target architecture supported by the JIT transpiler.
42///
43/// The transpiler is responsible for translating the risc32 instructions into the target
44/// architecture's instruction set.
45///
46/// This transpiler should generate an entrypoint of the form: [`fn(*mut JitContext)`]
47///
48/// For each instruction, you will typically want to call [`SP1RiscvTranspiler::start_instr`]
49/// before transpiling the instruction. This maps a "riscv instruction index" to some physical
50/// native address, as there are multiple native instructions per riscv instruction.
51///
52/// You will also likely want to call [`SP1RiscvTranspiler::bump_clk`] to increment the clock
53/// counter, and [`SP1RiscvTranspiler::set_pc`] to set the PC.
54///
55/// # Note
56/// Some instructions will directly modify the PC, such as [`SP1RiscvTranspiler::jal`] and
57/// [`SP1RiscvTranspiler::jalr`], and all the branch instructions, for these instructions, you would
58/// not want to call [`SP1RiscvTranspiler::set_pc`] as it will be called for you.
59///
60///
61/// ```rust,no_run,ignore
62/// pub fn add_program() {
63/// let mut transpiler = SP1RiscvTranspiler::new(program_size, memory_size, trace_buf_size, 100, 100).unwrap();
64///
65/// // Transpile the first instruction.
66/// transpiler.start_instr();
67/// transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
68/// transpiler.end_instr();
69///
70/// // Transpile the second instruction.
71/// transpiler.start_instr();
72///
73/// transpiler.add(RiscOperand::Reg(RiscRegister::A), RiscOperand::Reg(RiscRegister::B), RiscRegister::C);
74/// transpiler.end_instr();
75///
76/// let mut func = transpiler.finalize();
77///
78/// // Call the function.
79/// let traces = func.call();
80///
81/// // do stuff with the traces.
82/// }
83/// ```
84pub trait RiscvTranspiler:
85 TraceCollector
86 + ComputeInstructions
87 + ControlFlowInstructions
88 + MemoryInstructions
89 + SystemInstructions
90 + Sized
91{
92 /// Create a new transpiler.
93 ///
94 /// The program is used for the jump-table and is not a hard limit on the size of the program.
95 /// The memory size is the exact amount that will be allocated for the program.
96 fn new(
97 program_size: usize,
98 memory_size: usize,
99 max_trace_size: u64,
100 pc_start: u64,
101 pc_base: u64,
102 clk_bump: u64,
103 ) -> Result<Self, std::io::Error>;
104
105 /// Register a rust function of the form [`EcallHandler`] that will be used as the ECALL.
106 fn register_ecall_handler(&mut self, handler: EcallHandler);
107
108 /// Populates a jump table entry for the current instruction being transpiled.
109 ///
110 /// Effectively should create a mapping from RISCV PC -> absolute address of the instruction.
111 ///
112 /// This method should be called for "each pc" in the program.
113 fn start_instr(&mut self);
114
115 /// This method should be called for "each pc" in the program.
116 /// Handle logics when finishing execution of an instruction such as bumping clk and jump to
117 /// branch destination.
118 fn end_instr(&mut self);
119
120 /// Inspcet a [RiscRegister] using a function pointer.
121 ///
122 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
123 fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);
124
125 /// Print an immediate value.
126 ///
127 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
128 fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);
129
130 /// Call an [ExternFn] from the outputted assembly.
131 ///
132 /// Implementors should ensure that [`RiscvTranspiler::start_instr`] is called before this.
133 fn call_extern_fn(&mut self, handler: ExternFn);
134
135 /// Returns the function pointer to the generated code.
136 ///
137 /// This function is expected to be of the form: `fn(*mut JitContext)`.
138 fn finalize<M: JitMemory>(self) -> io::Result<JitFunction<M>>;
139}
140
141/// A trait the collects traces, in the form [TraceChunk].
142///
143/// This type is expected to follow the conventions as described in the [TraceChunk] documentation.
144pub trait TraceCollector {
145 /// Write the current state of the registers into the trace buf.
146 ///
147 /// For SP1 this is only called once in the beginning of a "chunk".
148 fn trace_registers(&mut self);
149
150 /// Write the value located at rs1 + imm into the trace buf.
151 fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);
152
153 /// Write the start pc of the trace chunk.
154 fn trace_pc_start(&mut self);
155
156 /// Write the start clk of the trace chunk.
157 fn trace_clk_start(&mut self);
158
159 /// Write the end clk of the trace chunk.
160 fn trace_clk_end(&mut self);
161}
162
/// Debugging helpers for code generators.
///
/// In this module it is implemented for every [`RiscvTranspiler`] through a blanket impl.
pub trait Debuggable {
    /// Emit code that prints the current execution context (PC, clock and registers) when the
    /// generated program reaches this point. Intended for debugging only.
    fn print_ctx(&mut self);
}
166
/// Backing storage for the JIT's guest memory.
///
/// The memory is accessed as a byte slice (via `Deref`/`DerefMut`), and its file descriptor
/// (via `AsRawFd`) is handed to the generated code's context.
pub trait JitMemory: Sized + Deref<Target = [u8]> + DerefMut + AsRawFd {
    /// Allocate guest memory of `memory_size` bytes.
    fn new(memory_size: usize) -> Self;
}

/// A [`JitMemory`] that can also be reset, which allows a compiled function to be reset and run
/// again.
pub trait JitResetableMemory: JitMemory {
    /// Reset the memory contents.
    /// NOTE(review): presumably back to the freshly-allocated state — confirm per implementor.
    fn reset(&mut self);
}
176
177impl<T: RiscvTranspiler> Debuggable for T {
178 // Useful only for debugging.
179 fn print_ctx(&mut self) {
180 extern "C" fn print_ctx(ctx: *mut JitContext) {
181 let ctx = unsafe { &mut *ctx };
182 eprintln!("pc: {:x}", ctx.pc);
183 eprintln!("clk: {}", ctx.clk);
184 eprintln!("{:?}", *ctx.registers());
185 }
186
187 self.call_extern_fn(print_ctx);
188 }
189}
190
/// Stub `JitFunction` used when the `sp1_native_executor_available` cfg is not set, so that
/// code naming the type still compiles on targets without the native executor.
#[cfg(not(sp1_native_executor_available))]
pub struct JitFunction<M> {
    _marker: std::marker::PhantomData<M>,
}
196
/// A JIT compiled function together with the execution state that is carried across calls.
///
/// The underlying native entrypoint takes a single `*mut JitContext` argument.
#[cfg(sp1_native_executor_available)]
pub struct JitFunction<M> {
    /// Absolute native addresses, one per jump-table entry; they all point into `code`.
    jump_table: Vec<*const u8>,
    /// The executable buffer holding the generated native code.
    code: ExecutableBuffer,

    /// The initial memory image: 8-byte-aligned guest address -> 64-bit word.
    initial_memory_image: Arc<HashMap<u64, u64>>,
    /// The PC execution starts from.
    pc_start: u64,
    /// Inputs supplied before the first call.
    input_buffer: VecDeque<Vec<u8>>,

    /// A stream of public values from the program (global to entire program).
    pub public_values_stream: Vec<u8>,

    /// The guest memory backing store.
    pub memory: M,

    /// During execution, the hints read by the program are stored here as
    /// `(start address, hint bytes)` pairs.
    pub hints: Vec<(u64, Vec<u8>)>,

    /// The PC to resume from; written back after every call.
    pub pc: u64,
    /// The register file; written back after every call.
    pub registers: [u64; 32],
    /// The clock; written back after every call.
    pub clk: u64,
    /// The global clock; written back after every call.
    pub global_clk: u64,
    /// The exit code copied back from the context after every call.
    pub exit_code: u32,

    /// The public value digest words emitted by `COMMIT` syscalls.
    pub public_value_digest: [u32; context::PUBLIC_VALUE_DIGEST_WORDS],

    /// Optional channel that is cloned into the context of every call.
    /// NOTE(review): used by the debugger to receive execution state — confirm in `debug`.
    pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
}
231
// SAFETY: the raw `*const u8` pointers in `jump_table` point into the executable mapping owned
// by `code`; moving the `JitFunction` (and therefore `code`) to another thread does not move
// that mapping, so the pointers stay valid. NOTE(review): this also relies on the remaining
// fields (e.g. `debug::State` inside `debug_sender`) being `Send` — confirm when they change.
unsafe impl<M: Send> Send for JitFunction<M> {}
233
#[cfg(sp1_native_executor_available)]
impl<M: JitMemory> JitFunction<M> {
    /// Wrap freshly generated native code in a [`JitFunction`].
    ///
    /// * `code` - the executable buffer produced by the backend.
    /// * `jump_table` - byte offsets into `code`; they are converted to absolute addresses here.
    /// * `memory_size` - forwarded to [`JitMemory::new`] to allocate the guest memory.
    /// * `pc_start` - the PC execution starts from.
    ///
    /// The state starts with `clk = 1`, `global_clk = 0`, zeroed registers and an empty initial
    /// memory image. This constructor currently always returns `Ok`.
    pub(crate) fn new(
        code: ExecutableBuffer,
        jump_table: Vec<usize>,
        memory_size: usize,
        pc_start: u64,
    ) -> std::io::Result<Self> {
        // Adjust the jump table to be absolute addresses.
        let buf_ptr = code.as_ptr();
        let code_len = code.len();
        let jump_table = jump_table
            .into_iter()
            .map(|offset| {
                // `pointer::add` past the end of the buffer is UB, so catch a bad offset coming
                // from the backend in debug builds.
                debug_assert!(
                    offset <= code_len,
                    "jump table offset {offset} is outside the code buffer ({code_len} bytes)"
                );
                // SAFETY: `offset` lies within `code` (checked above in debug builds).
                unsafe { buf_ptr.add(offset) }
            })
            .collect();

        let memory = M::new(memory_size);

        Ok(Self {
            jump_table,
            code,
            memory,
            pc: pc_start,
            clk: 1,
            global_clk: 0,
            registers: [0; 32],
            initial_memory_image: Arc::new(HashMap::new()),
            pc_start,
            input_buffer: VecDeque::new(),
            hints: Vec::new(),
            public_values_stream: Vec::new(),
            debug_sender: None,
            exit_code: 0,
            public_value_digest: [0; context::PUBLIC_VALUE_DIGEST_WORDS],
        })
    }

    /// Write the initial memory image to the JIT memory.
    ///
    /// The image is also stored so that it can be re-applied later (e.g. by `reset`).
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC, or if an address of the image falls outside
    /// the JIT memory.
    pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
        assert!(
            self.pc == self.pc_start,
            "The initial memory should only be supplied before using the JIT function."
        );

        self.initial_memory_image = memory;
        self.insert_memory_image();
    }

    /// Push an input to the input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC.
    pub fn push_input(&mut self, input: Vec<u8>) {
        assert!(
            self.pc == self.pc_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        self.input_buffer.push_back(input);

        // Reserve the space for the corresponding hint.
        self.hints.reserve(1);
    }

    /// Set the entire input buffer.
    ///
    /// # Panics
    ///
    /// Panics if the PC is not the starting PC.
    pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
        assert!(
            self.pc == self.pc_start,
            "The input buffer should only be supplied before using the JIT function."
        );

        // Reserve the space for the hints.
        self.hints.reserve(input.len());
        self.input_buffer = input;
    }

    /// Run the generated code from the current PC, resuming where the previous call stopped.
    ///
    /// If `self.pc == 1` — the value treated here as "the program has finished" — this is a
    /// no-op. Otherwise the registers, PC, clocks, exit code and public value digest are copied
    /// into a fresh [`JitContext`], the native code runs, and those values are written back to
    /// `self` afterwards.
    ///
    /// Trace data is written through `trace_buf_ptr`; pass a null pointer to disable tracing.
    ///
    /// # Safety
    ///
    /// Relies on the backend having emitted valid native code for the entrypoint, and on
    /// `trace_buf_ptr` being either null or valid for all writes the generated code performs
    /// during this call.
    pub unsafe fn call(&mut self, trace_buf_ptr: *mut u8) {
        if self.pc == 1 {
            return;
        }

        // The entrypoint is hand-emitted machine code, so call it through the C ABI (like
        // `ExternFn`) instead of the unspecified Rust ABI.
        let entrypoint =
            std::mem::transmute::<*const u8, extern "C" fn(*mut JitContext)>(self.code.as_ptr());

        // Hand the generated code a u64-aligned view of memory; `insert_memory_image` writes
        // the initial image relative to the same offset.
        let align_offset = self.memory_align_offset();
        let mem_ptr = self.memory.as_mut_ptr().add(align_offset);
        let is_tracing = !trace_buf_ptr.is_null();

        // SAFETY:
        // - The jump table, memory, input buffer, hints and public values stream are owned by
        //   `self`, which is mutably borrowed for the whole call, so they outlive it. Neither
        //   `Vec::as_mut_ptr` nor `mem_ptr` can be null.
        // - The trace buffer is supplied by the caller, who guarantees it is null or valid for
        //   the duration of the call (see `# Safety`).
        let mut ctx = JitContext {
            jump_table: NonNull::new_unchecked(self.jump_table.as_mut_ptr()),
            memory: NonNull::new_unchecked(mem_ptr),
            trace_buf: trace_buf_ptr,
            input_buffer: NonNull::from(&mut self.input_buffer),
            hints: NonNull::from(&mut self.hints),
            maybe_unconstrained: None,
            public_values_stream: NonNull::from(&mut self.public_values_stream),
            memory_fd: self.memory.as_raw_fd(),
            registers: self.registers,
            pc: self.pc,
            clk: self.clk,
            global_clk: self.global_clk,
            is_unconstrained: 0,
            tracing: is_tracing,
            debug_sender: self.debug_sender.clone(),
            exit_code: self.exit_code,
            public_value_digest: self.public_value_digest,
        };

        tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
            entrypoint(&mut ctx);
        });

        // Update the values we want to preserve.
        self.pc = ctx.pc;
        self.registers = ctx.registers;
        self.clk = ctx.clk;
        self.global_clk = ctx.global_clk;
        self.exit_code = ctx.exit_code;
        self.public_value_digest = ctx.public_value_digest;
    }

    /// Byte offset from the start of `self.memory` to its first u64-aligned byte.
    ///
    /// Both `call` and `insert_memory_image` address guest memory relative to this offset, so
    /// they must always agree on it.
    fn memory_align_offset(&self) -> usize {
        self.memory.as_ptr().align_offset(std::mem::align_of::<u64>())
    }

    /// Copy every word of the initial memory image into the JIT memory.
    ///
    /// Guest address `addr` (expected to be 8-byte aligned) is written as a little-endian `u64`
    /// at byte offset `2 * addr + 8` from the u64-aligned base used by `call`.
    ///
    /// # Panics
    ///
    /// Panics if a word would land outside the JIT memory, and (debug builds only) if an
    /// address is not 8-byte aligned.
    fn insert_memory_image(&mut self) {
        const WORD_BYTES: usize = std::mem::size_of::<u64>();

        // Use the same aligned base as `call`; otherwise the image would be shifted relative to
        // what the generated code reads whenever the backing memory is not u64-aligned.
        let base = self.memory_align_offset();
        let memory: &mut [u8] = &mut self.memory;

        for (&addr, &val) in self.initial_memory_image.iter() {
            debug_assert!(addr % 8 == 0, "Address {addr} is not aligned to 8");

            // NOTE(review): the `2 * addr + 8` layout is kept from the original code —
            // presumably each guest word occupies 16 bytes with the value in the upper half.
            let offset = addr
                .checked_mul(2)
                .and_then(|scaled| scaled.checked_add(8))
                .and_then(|scaled| usize::try_from(scaled).ok())
                .and_then(|scaled| scaled.checked_add(base))
                .unwrap_or_else(|| panic!("Address {addr} overflows the JIT memory offset"));

            let memory_len = memory.len();
            let slot = memory
                .get_mut(offset..offset.saturating_add(WORD_BYTES))
                .unwrap_or_else(|| {
                    panic!(
                        "Address {addr} maps to byte offset {offset}, outside the JIT memory \
                         ({memory_len} bytes)"
                    )
                });

            // The crate only builds on little-endian targets, but be explicit about the layout.
            slot.copy_from_slice(&val.to_le_bytes());
        }
    }
}
394
#[cfg(sp1_native_executor_available)]
impl<M: JitResetableMemory> JitFunction<M> {
    /// Reset the JIT function to its initial state so the program can be run again.
    ///
    /// Restores the PC to the start PC, zeroes the registers, resets the clocks and the exit
    /// code, drops all inputs, hints and public values, clears the public value digest, resets
    /// the memory and re-applies the initial memory image. The generated code, the jump table,
    /// the stored initial memory image and the debug sender are kept.
    pub fn reset(&mut self) {
        self.pc = self.pc_start;
        self.registers = [0; 32];
        self.clk = 1;
        self.global_clk = 0;
        // Must match `new`; otherwise a reused function reports the previous run's exit code.
        self.exit_code = 0;
        self.input_buffer = VecDeque::new();
        self.hints = Vec::new();
        self.public_values_stream = Vec::new();
        self.public_value_digest = [0; context::PUBLIC_VALUE_DIGEST_WORDS];
        self.memory.reset();

        self.insert_memory_image();
    }
}