use memmap2::{Mmap, MmapMut};
use sp1_jit::{MemValue, TraceChunkHeader};
/// Fixed-size anonymous memory mapping that holds a [`TraceChunkHeader`]
/// at offset 0, followed by a flat run of [`MemValue`] entries appended
/// via [`TraceChunkBuffer::extend`].
pub(super) struct TraceChunkBuffer {
    // Backing anonymous mapping; `new` guarantees it is at least
    // `size_of::<TraceChunkHeader>()` bytes long.
    inner: MmapMut,
}
impl TraceChunkBuffer {
    /// Creates an anonymous memory-mapped buffer of `size` bytes.
    ///
    /// # Panics
    /// Panics if `size` cannot hold a [`TraceChunkHeader`], or if the
    /// anonymous mapping cannot be created.
    pub fn new(size: usize) -> Self {
        assert!(
            size >= std::mem::size_of::<TraceChunkHeader>(),
            "Trace chunk buffer size must be at least the size of the header"
        );
        Self { inner: MmapMut::map_anon(size).expect("Failed to create trace buf mmap") }
    }

    /// Copies `start_registers` into the header's `start_registers` field.
    ///
    /// # Safety
    /// Writes go through a raw pointer derived from `&self`; the caller must
    /// ensure no other access to this buffer happens concurrently.
    pub unsafe fn write_start_registers(&self, start_registers: &[u64; 32]) {
        // SAFETY: `new` asserts the mapping is at least header-sized, so the
        // destination range lies within the mapping; the source is the
        // caller's slice and cannot overlap the mmap destination.
        unsafe {
            std::ptr::copy_nonoverlapping(
                start_registers.as_ptr().cast::<u8>(),
                self.as_mut_ptr().add(std::mem::offset_of!(TraceChunkHeader, start_registers)),
                std::mem::size_of::<[u64; 32]>(),
            );
        }
    }

    /// Writes the header's `pc_start` field.
    ///
    /// # Safety
    /// Same contract as [`Self::write_start_registers`]: caller must ensure
    /// exclusive access to the buffer for the duration of the call.
    pub unsafe fn write_pc_start(&self, pc_start: u64) {
        // SAFETY: field offset is within the header, which `new` guarantees
        // fits in the mapping; `write_unaligned` tolerates any alignment.
        unsafe {
            std::ptr::write_unaligned(
                self.as_mut_ptr()
                    .add(std::mem::offset_of!(TraceChunkHeader, pc_start))
                    .cast::<u64>(),
                pc_start,
            );
        }
    }

    /// Writes the header's `clk_start` field.
    ///
    /// # Safety
    /// Caller must ensure exclusive access to the buffer for the duration of
    /// the call.
    pub unsafe fn write_clk_start(&self, clk_start: u64) {
        // SAFETY: in-bounds header field (see `new`); unaligned write is fine.
        unsafe {
            std::ptr::write_unaligned(
                self.as_mut_ptr()
                    .add(std::mem::offset_of!(TraceChunkHeader, clk_start))
                    .cast::<u64>(),
                clk_start,
            );
        }
    }

    /// Writes the header's `clk_end` field.
    ///
    /// # Safety
    /// Caller must ensure exclusive access to the buffer for the duration of
    /// the call.
    pub unsafe fn write_clk_end(&self, clk_end: u64) {
        // SAFETY: in-bounds header field (see `new`); unaligned write is fine.
        unsafe {
            std::ptr::write_unaligned(
                self.as_mut_ptr()
                    .add(std::mem::offset_of!(TraceChunkHeader, clk_end))
                    .cast::<u64>(),
                clk_end,
            );
        }
    }

    /// Writes the header's `global_clk_end` field.
    ///
    /// # Safety
    /// Caller must ensure exclusive access to the buffer for the duration of
    /// the call.
    pub unsafe fn write_global_clk_end(&self, global_clk_end: u64) {
        // SAFETY: in-bounds header field (see `new`); unaligned write is fine.
        unsafe {
            std::ptr::write_unaligned(
                self.as_mut_ptr()
                    .add(std::mem::offset_of!(TraceChunkHeader, global_clk_end))
                    .cast::<u64>(),
                global_clk_end,
            );
        }
    }

    /// Appends `values` to the entry region that follows the header and
    /// advances the header's `num_mem_reads` counter accordingly.
    ///
    /// # Safety
    /// Caller must ensure exclusive access to the buffer for the duration of
    /// the call.
    ///
    /// # Panics
    /// Panics if the new entry count would overflow `u64` or exceed the
    /// buffer's capacity.
    pub unsafe fn extend(&self, values: &[MemValue]) {
        // SAFETY: in-bounds header field (see `new`); unaligned read is fine.
        let num_mem_reads = unsafe {
            std::ptr::read_unaligned(
                self.as_mut_ptr().add(std::mem::offset_of!(TraceChunkHeader, num_mem_reads))
                    as *const u64,
            )
        };
        let new_num_mem_reads =
            num_mem_reads.checked_add(values.len() as u64).expect("Num mem reads too large");
        // checked_mul: an unchecked multiply here would wrap in release mode
        // for a huge count and silently defeat the capacity assert below,
        // allowing an out-of-bounds copy.
        let bytes_needed = new_num_mem_reads
            .checked_mul(std::mem::size_of::<MemValue>() as u64)
            .expect("Num mem reads too large");
        assert!(
            bytes_needed
                <= self.inner.len() as u64 - std::mem::size_of::<TraceChunkHeader>() as u64,
            "Num mem reads ({new_num_mem_reads}) would exceed buffer capacity of {} entries",
            (self.inner.len() - std::mem::size_of::<TraceChunkHeader>())
                / std::mem::size_of::<MemValue>()
        );
        // SAFETY: in-bounds header field (see `new`); unaligned write is fine.
        unsafe {
            std::ptr::write_unaligned(
                self.as_mut_ptr()
                    .add(std::mem::offset_of!(TraceChunkHeader, num_mem_reads))
                    .cast::<u64>(),
                new_num_mem_reads,
            );
        }
        // SAFETY: the assert above proves the destination range
        // [header_end + old_count * entry_size, header_end + new_count * entry_size)
        // lies within the mapping; the source is the caller's slice and
        // cannot overlap the mmap destination.
        unsafe {
            std::ptr::copy_nonoverlapping(
                values.as_ptr().cast::<u8>(),
                self.as_mut_ptr()
                    .add(std::mem::size_of::<TraceChunkHeader>())
                    .add(num_mem_reads as usize * std::mem::size_of::<MemValue>()),
                std::mem::size_of_val(values),
            );
        }
    }

    /// Returns the current value of the header's `num_mem_reads` counter.
    pub fn num_mem_reads(&self) -> u64 {
        // SAFETY: in-bounds header field (see `new`); unaligned read is fine.
        unsafe {
            std::ptr::read_unaligned(
                self.as_mut_ptr().add(std::mem::offset_of!(TraceChunkHeader, num_mem_reads))
                    as *const u64,
            )
        }
    }

    // Raw base pointer of the mapping. NOTE(review): this casts away const on
    // a pointer obtained from `&self`; soundness relies on the `# Safety`
    // exclusive-access contracts of the writers above — confirm callers
    // uphold them.
    fn as_mut_ptr(&self) -> *mut u8 {
        self.inner.as_ptr().cast_mut()
    }
}
impl From<TraceChunkBuffer> for Mmap {
fn from(buffer: TraceChunkBuffer) -> Self {
buffer.inner.make_read_only().expect("Failed to make trace buf read only")
}
}