// On non-Linux targets `JitFunction` is an empty stub and its `impl` is compiled out,
// so silence the resulting unused-item warnings there.
#![cfg_attr(not(target_os = "linux"), allow(unused))]
// Host code writes guest words with `to_le_bytes` and reads memory entries back with
// native-endian pointer reads; the two only agree on little-endian hosts.
#[cfg(not(target_endian = "little"))]
compile_error!("This crate is only supported on little endian targets.");
pub mod backends;
pub mod context;
pub mod debug;
pub mod instructions;
mod macros;
pub mod risc;
use dynasmrt::ExecutableBuffer;
use hashbrown::HashMap;
use memmap2::{MmapMut, MmapOptions};
use std::{
collections::VecDeque,
io,
os::fd::AsRawFd,
ptr::NonNull,
sync::{mpsc, Arc},
};
pub use backends::*;
pub use context::*;
pub use instructions::*;
pub use risc::*;
/// Host callback that receives a pointer to the live [`JitContext`]; installed with
/// [`RiscvTranspiler::call_extern_fn`] (presumably invoked from the generated code at that point).
pub type ExternFn = extern "C" fn(*mut JitContext);
/// Host handler registered with [`RiscvTranspiler::register_ecall_handler`].
/// NOTE(review): the meaning of the returned `u64` is defined by the backends — document it there.
pub type EcallHandler = extern "C" fn(*mut JitContext) -> u64;
/// Debug callback receiving a single `u64`; used by [`RiscvTranspiler::inspect_register`]
/// and [`RiscvTranspiler::inspect_immediate`].
pub type DebugFn = extern "C" fn(u64);
/// Code generator that turns RISC-V instructions into native machine code and packages the
/// result as a [`JitFunction`] via [`RiscvTranspiler::finalize`].
///
/// Implementors must also provide the tracing hooks of [`TraceCollector`] and the emitters for
/// every instruction family.
pub trait RiscvTranspiler:
TraceCollector
+ ComputeInstructions
+ ControlFlowInstructions
+ MemoryInstructions
+ SystemInstructions
+ Sized
{
/// Creates a new transpiler.
///
/// NOTE(review): parameter semantics are not visible in this file — sizes are presumably in
/// bytes and `pc_start`/`pc_base` guest addresses; confirm against the backend implementations.
fn new(
program_size: usize,
memory_size: usize,
max_trace_size: u64,
pc_start: u64,
pc_base: u64,
clk_bump: u64,
) -> Result<Self, std::io::Error>;
/// Registers the host handler used for environment calls (see [`EcallHandler`]).
fn register_ecall_handler(&mut self, handler: EcallHandler);
/// Hook emitted at the start of a guest instruction (presumably; see the backends).
fn start_instr(&mut self);
/// Hook emitted at the end of a guest instruction (presumably; see the backends).
fn end_instr(&mut self);
/// Arranges for `handler` to be called with a `u64` for register `reg` (presumably its
/// current value — confirm in the backends).
fn inspect_register(&mut self, reg: RiscRegister, handler: DebugFn);
/// Arranges for `handler` to be called with the constant `imm`.
fn inspect_immediate(&mut self, imm: u64, handler: DebugFn);
/// Arranges for `handler` to be called with a pointer to the live [`JitContext`].
fn call_extern_fn(&mut self, handler: ExternFn);
/// Finishes code generation and returns the runnable [`JitFunction`].
fn finalize(self) -> io::Result<JitFunction>;
}
/// Hooks a transpiler uses to record execution-trace data from the generated code.
///
/// NOTE(review): the exact trace layout is defined by the backends; the descriptions below are
/// inferred from method names and should be confirmed there.
pub trait TraceCollector {
/// Records the register file (presumably all 32 registers).
fn trace_registers(&mut self);
/// Records a memory value addressed by `rs1` and `imm` (presumably the word at `rs1 + imm`).
fn trace_mem_value(&mut self, rs1: RiscRegister, imm: u64);
/// Records the program counter at the start of a trace chunk (presumably).
fn trace_pc_start(&mut self);
/// Records the clock at the start of a trace chunk (presumably).
fn trace_clk_start(&mut self);
/// Records the clock at the end of a trace chunk (presumably).
fn trace_clk_end(&mut self);
}
/// Debugging helpers for code generators; a blanket impl provides them for every
/// [`RiscvTranspiler`].
pub trait Debuggable {
/// Arranges for the current `pc`, `clk` and register file to be printed to stderr.
fn print_ctx(&mut self);
}
impl<T: RiscvTranspiler> Debuggable for T {
    /// Installs (via `call_extern_fn`) a host callback that dumps the live context to stderr:
    /// the program counter in hex, the clock, then the register file in `Debug` form.
    fn print_ctx(&mut self) {
        /// Host-side callback: writes `pc`, `clk` and the registers of `raw` to stderr.
        extern "C" fn dump_context(raw: *mut JitContext) {
            // SAFETY: `call_extern_fn` callbacks are handed the pointer to the context that is
            // live for the duration of the call (assumed from `ExternFn`'s contract).
            let context = unsafe { &mut *raw };
            let (pc, clk) = (context.pc, context.clk);
            eprintln!("pc: {pc:x}");
            eprintln!("clk: {clk}");
            eprintln!("{:?}", *context.registers());
        }

        self.call_extern_fn(dump_context);
    }
}
/// Placeholder on non-Linux targets: the JIT runtime relies on Linux-only facilities
/// (a `memfd`-backed guest memory), so this stub has no fields and no methods.
#[cfg(not(target_os = "linux"))]
pub struct JitFunction {}
/// Runnable result of JIT-compiling a RISC-V program, together with the guest state
/// (registers, clocks, memory, I/O buffers) that persists between calls.
#[cfg(target_os = "linux")]
pub struct JitFunction {
/// Absolute addresses inside `code`, built from the backend's offsets in `new`; handed to
/// the generated code as `JitContext::jump_table`.
jump_table: Vec<*const u8>,
/// Size in bytes of the trace buffer mapped on each `call`; `0` disables tracing.
trace_buf_size: usize,
/// The generated machine code; its first byte is the entry point invoked by `call`.
code: ExecutableBuffer,
/// Initial guest memory contents (8-byte-aligned address -> `u64` value), written into
/// `memory` when supplied and again on every `reset`.
initial_memory_image: Arc<HashMap<u64, u64>>,
/// Program counter restored by `reset`; also used to detect "not yet run" in the setters.
pc_start: u64,
/// Pending program inputs, exposed to the generated code as `JitContext::input_buffer`.
input_buffer: VecDeque<Vec<u8>>,
/// Public-values bytes, exposed to the generated code as `JitContext::public_values_stream`.
pub public_values_stream: Vec<u8>,
/// `memfd` backing `memory`; its raw fd is passed as `JitContext::memory_fd`.
mem_fd: memfd::Memfd,
/// Hints, exposed to the generated code as `JitContext::hints`.
/// NOTE(review): the meaning of the `u64` in each pair is defined elsewhere — document it there.
pub hints: Vec<(u64, Vec<u8>)>,
/// Writable mapping of `mem_fd` that holds guest memory.
pub memory: MmapMut,
/// Current program counter; `call` returns `None` immediately when it equals `1`
/// (apparently a halt sentinel — confirm).
pub pc: u64,
/// Register file, copied into the context before and back out after each `call`.
pub registers: [u64; 32],
/// Clock, copied into and out of the context around each `call`; starts at `1`.
pub clk: u64,
/// Global clock, copied into and out of the context around each `call`; starts at `0`.
pub global_clk: u64,
/// Exit code, copied into the context and read back after each `call`; starts at `0`.
pub exit_code: u32,
/// Optional debug channel, cloned into the context on each `call`.
pub debug_sender: Option<mpsc::SyncSender<Option<debug::State>>>,
}
// SAFETY: `jump_table` holds raw pointers (which are not `Send`) into the executable buffer
// owned by this same value (built from `code.as_ptr()` in `JitFunction::new`), so moving the
// struct to another thread moves the pointee with it. NOTE(review): this assumes the remaining
// fields (`ExecutableBuffer`, `Memfd`, `MmapMut`, `SyncSender`, ...) are themselves `Send` — confirm.
unsafe impl Send for JitFunction {}
#[cfg(target_os = "linux")]
impl JitFunction {
pub(crate) fn new(
code: ExecutableBuffer,
jump_table: Vec<usize>,
memory_size: usize,
trace_buf_size: usize,
pc_start: u64,
) -> std::io::Result<Self> {
let buf_ptr = code.as_ptr();
let jump_table =
jump_table.into_iter().map(|offset| unsafe { buf_ptr.add(offset) }).collect();
let fd = memfd::MemfdOptions::default()
.create(uuid::Uuid::new_v4().to_string())
.expect("Failed to create jit memory");
fd.as_file().set_len((memory_size + std::mem::align_of::<u64>()) as u64)?;
Ok(Self {
jump_table,
code,
memory: unsafe { MmapOptions::new().no_reserve_swap().map_mut(fd.as_file())? },
mem_fd: fd,
trace_buf_size,
pc: pc_start,
clk: 1,
global_clk: 0,
registers: [0; 32],
initial_memory_image: Arc::new(HashMap::new()),
pc_start,
input_buffer: VecDeque::new(),
hints: Vec::new(),
public_values_stream: Vec::new(),
debug_sender: None,
exit_code: 0,
})
}
pub fn with_initial_memory_image(&mut self, memory: Arc<HashMap<u64, u64>>) {
assert!(
self.pc == self.pc_start,
"The initial memory should only be supplied before using the JIT function."
);
self.initial_memory_image = memory;
self.insert_memory_image();
}
pub fn push_input(&mut self, input: Vec<u8>) {
assert!(
self.pc == self.pc_start,
"The input buffer should only be supplied before using the JIT function."
);
self.input_buffer.push_back(input);
self.hints.reserve(1);
}
pub fn set_input_buffer(&mut self, input: VecDeque<Vec<u8>>) {
assert!(
self.pc == self.pc_start,
"The input buffer should only be supplied before using the JIT function."
);
self.hints.reserve(input.len());
self.input_buffer = input;
}
pub unsafe fn call(&mut self) -> Option<TraceChunkRaw> {
if self.pc == 1 {
return None;
}
let as_fn = std::mem::transmute::<*const u8, fn(*mut JitContext)>(self.code.as_ptr());
let mut trace_buf =
MmapMut::map_anon(self.trace_buf_size + std::mem::align_of::<MemValue>())
.expect("Failed to create trace buf mmap");
let trace_buf_offset = trace_buf.as_ptr().align_offset(std::mem::align_of::<MemValue>());
let trace_buf_ptr = trace_buf.as_mut_ptr().add(trace_buf_offset);
let align_offset = self.memory.as_ptr().align_offset(std::mem::align_of::<u64>());
let mem_ptr = self.memory.as_mut_ptr().add(align_offset);
let tracing = self.trace_buf_size > 0;
let mut ctx = JitContext {
jump_table: NonNull::new_unchecked(self.jump_table.as_mut_ptr()),
memory: NonNull::new_unchecked(mem_ptr),
trace_buf: NonNull::new_unchecked(trace_buf_ptr),
input_buffer: NonNull::new_unchecked(&mut self.input_buffer),
hints: NonNull::new_unchecked(&mut self.hints),
maybe_unconstrained: None,
public_values_stream: NonNull::new_unchecked(&mut self.public_values_stream),
memory_fd: self.mem_fd.as_raw_fd(),
registers: self.registers,
pc: self.pc,
clk: self.clk,
global_clk: self.global_clk,
is_unconstrained: 0,
tracing,
debug_sender: self.debug_sender.clone(),
exit_code: self.exit_code,
};
tracing::debug_span!("JIT function", pc = ctx.pc, clk = ctx.clk).in_scope(|| {
as_fn(&mut ctx);
});
self.pc = ctx.pc;
self.registers = ctx.registers;
self.clk = ctx.clk;
self.global_clk = ctx.global_clk;
self.exit_code = ctx.exit_code;
tracing.then_some(TraceChunkRaw::new(
trace_buf.make_read_only().expect("Failed to make trace buf read only"),
))
}
pub fn reset(&mut self) {
self.pc = self.pc_start;
self.registers = [0; 32];
self.clk = 1;
self.global_clk = 0;
self.input_buffer = VecDeque::new();
self.hints = Vec::new();
self.public_values_stream = Vec::new();
let memory_size = self.memory.len();
self.mem_fd = memfd::MemfdOptions::default()
.create(uuid::Uuid::new_v4().to_string())
.expect("Failed to create jit memory");
self.mem_fd
.as_file()
.set_len(memory_size as u64)
.expect("Failed to set memfd size for backing memory.");
self.memory = unsafe {
MmapOptions::new()
.no_reserve_swap()
.map_mut(self.mem_fd.as_file())
.expect("Failed to map memory")
};
self.insert_memory_image();
}
fn insert_memory_image(&mut self) {
for (addr, val) in self.initial_memory_image.iter() {
let bytes = val.to_le_bytes();
#[cfg(debug_assertions)]
if addr % 8 > 0 {
panic!("Address {addr} is not aligned to 8");
}
let actual_addr = 2 * addr + 8;
unsafe {
std::ptr::copy_nonoverlapping(
bytes.as_ptr(),
self.memory.as_mut_ptr().add(actual_addr as usize),
bytes.len(),
)
};
}
}
}
/// Read-only view over a JIT guest-memory mapping, for inspecting individual entries.
pub struct MemoryView<'a> {
/// The mapping being viewed.
pub memory: &'a MmapMut,
}
impl<'a> MemoryView<'a> {
    /// Wraps a borrowed guest-memory mapping for read-only inspection.
    pub const fn new(memory: &'a MmapMut) -> Self {
        Self { memory }
    }

    /// Reads the [`MemValue`] entry for the 8-byte-aligned guest address `addr`.
    ///
    /// Entries are laid out back to back from the start of the mapping, one per 8-byte guest
    /// word, so `addr` selects entry `addr / 8`.
    ///
    /// # Panics
    ///
    /// Panics if `addr` is not a multiple of 8, or if the entry would extend past the end of
    /// the mapping. Previously that case was an unchecked out-of-bounds read.
    pub fn get(&self, addr: u64) -> MemValue {
        assert!(addr.is_multiple_of(8), "Address {addr} is not aligned to 8");
        let entry_size = std::mem::size_of::<MemValue>();
        let offset = usize::try_from(addr / 8)
            .ok()
            .and_then(|word| word.checked_mul(entry_size))
            .filter(|&start| {
                start.checked_add(entry_size).is_some_and(|end| end <= self.memory.len())
            })
            .unwrap_or_else(|| panic!("Address {addr} is outside of the JIT memory"));
        // SAFETY: `offset..offset + size_of::<MemValue>()` lies inside the mapping (checked
        // above). The mapping is page-aligned and `offset` is a multiple of the entry size,
        // hence of its alignment. NOTE(review): assumes every bit pattern is a valid `MemValue`
        // (plain integer fields) — confirm against its definition.
        unsafe { std::ptr::read(self.memory.as_ptr().add(offset).cast::<MemValue>()) }
    }
}