use std::{marker::PhantomData, sync::Arc};
use memmap2::Mmap;
use serde::{Deserialize, Serialize};
/// A general-purpose register, `X0` through `X31`.
///
/// Stored as a single byte (`#[repr(u8)]`). Discriminants are implicit and follow declaration
/// order starting at zero, so `RiscRegister::Xn as u8 == n` for every variant — do not reorder
/// the variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum RiscRegister {
    X0,
    X1,
    X2,
    X3,
    X4,
    X5,
    X6,
    X7,
    X8,
    X9,
    X10,
    X11,
    X12,
    X13,
    X14,
    X15,
    X16,
    X17,
    X18,
    X19,
    X20,
    X21,
    X22,
    X23,
    X24,
    X25,
    X26,
    X27,
    X28,
    X29,
    X30,
    X31,
}
impl RiscRegister {
pub fn all_registers() -> &'static [RiscRegister] {
&[
RiscRegister::X0,
RiscRegister::X1,
RiscRegister::X2,
RiscRegister::X3,
RiscRegister::X4,
RiscRegister::X5,
RiscRegister::X6,
RiscRegister::X7,
RiscRegister::X8,
RiscRegister::X9,
RiscRegister::X10,
RiscRegister::X11,
RiscRegister::X12,
RiscRegister::X13,
RiscRegister::X14,
RiscRegister::X15,
RiscRegister::X16,
RiscRegister::X17,
RiscRegister::X18,
RiscRegister::X19,
RiscRegister::X20,
RiscRegister::X21,
RiscRegister::X22,
RiscRegister::X23,
RiscRegister::X24,
RiscRegister::X25,
RiscRegister::X26,
RiscRegister::X27,
RiscRegister::X28,
RiscRegister::X29,
RiscRegister::X30,
RiscRegister::X31,
]
}
}
/// An operand that is either a register or a signed 32-bit immediate.
///
/// Usually built through the `From` impls for `RiscRegister`, `u32`, `i32` and `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RiscOperand {
    /// A register operand.
    Register(RiscRegister),
    /// An immediate operand, stored as `i32`.
    Immediate(i32),
}
impl From<RiscRegister> for RiscOperand {
fn from(reg: RiscRegister) -> Self {
RiscOperand::Register(reg)
}
}
impl From<u32> for RiscOperand {
fn from(imm: u32) -> Self {
RiscOperand::Immediate(imm as i32)
}
}
impl From<i32> for RiscOperand {
fn from(imm: i32) -> Self {
RiscOperand::Immediate(imm)
}
}
impl From<u64> for RiscOperand {
fn from(imm: u64) -> Self {
RiscOperand::Immediate(imm as i32)
}
}
/// One memory-read record: a `clk` value paired with the `value` that was read.
///
/// `#[repr(C)]` fixes the layout to two consecutive native-endian `u64`s (16 bytes). This
/// matters because `TraceChunkRaw::mem_reads` reinterprets the bytes that follow the header of
/// a raw chunk directly as `MemValue`s. Do not reorder or add fields without changing the raw
/// encoding as well.
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct MemValue {
    /// Clock value recorded with this read. NOTE(review): exact clock semantics are not visible
    /// in this file.
    pub clk: u64,
    /// The value that was read.
    pub value: u64,
}
/// Fixed-size header at the start of a raw trace-chunk encoding.
///
/// `#[repr(C)]` pins the field order and offsets. `TraceChunkRaw` reads individual fields at
/// `offset_of!` positions, and `TraceChunk::copy_from_bytes` reads the header as a whole. The
/// header is 36 `u64`s (288 bytes) with no padding. `num_mem_reads` [`MemValue`] records
/// follow immediately after it.
///
/// NOTE(review): fields are read in native byte order, so the encoding is not portable across
/// machines with different endianness.
#[repr(C)]
pub struct TraceChunkHeader {
    /// Register values at the start of the chunk. Presumably indexed by `RiscRegister`
    /// discriminant — confirm with the producer.
    pub start_registers: [u64; 32],
    /// Program counter at the start of the chunk.
    pub pc_start: u64,
    /// Clock value at the start of the chunk.
    pub clk_start: u64,
    /// Clock value at the end of the chunk. NOTE(review): inclusive vs exclusive is not visible
    /// here.
    pub clk_end: u64,
    /// Number of `MemValue` records that follow this header.
    pub num_mem_reads: u64,
}
/// A trace chunk backed directly by a memory-mapped raw encoding.
///
/// The mapping holds a [`TraceChunkHeader`] followed by the memory reads. Accessors decode
/// header fields from the mapping on every call, and memory reads are borrowed in place rather
/// than copied. Cloning only bumps the `Arc` refcount, so clones share one mapping. Convert to
/// [`TraceChunk`] via `From` to obtain an owned copy.
#[repr(C)]
#[derive(Clone)]
pub struct TraceChunkRaw(Arc<Mmap>);
impl TraceChunkRaw {
    /// Wraps a memory map that contains one encoded trace chunk.
    ///
    /// The map is moved behind an `Arc`, so clones of the returned value share it.
    ///
    /// # Safety
    ///
    /// NOTE(review): contract inferred from how this type's `MinimalTrace` impl reads the
    /// mapping — confirm with the producer of the file. The mapping should:
    /// - hold a [`TraceChunkHeader`] followed by `num_mem_reads` [`MemValue`]s;
    /// - start at an address aligned for `MemValue`;
    /// - not be modified (e.g. through the backing file) while any clone is alive.
    pub unsafe fn new(inner: Mmap) -> Self {
        let shared: Arc<Mmap> = Arc::from(inner);
        Self(shared)
    }
}
impl TraceChunkRaw {
    /// Copies a `T` out of the mapping at byte `offset`, with no alignment requirement.
    ///
    /// # Panics
    ///
    /// Panics with "TraceChunk header too small" if the mapping is shorter than
    /// `offset + size_of::<T>()` bytes. This matches the message `TraceChunk::copy_from_bytes`
    /// uses for a truncated header.
    ///
    /// # Safety
    ///
    /// Every bit pattern of `size_of::<T>()` bytes must be a valid `T`. That holds for `u64` and
    /// `[u64; N]`, the only types read here.
    unsafe fn read_header_field<T: Copy>(&self, offset: usize) -> T {
        let in_bounds = offset
            .checked_add(std::mem::size_of::<T>())
            .is_some_and(|end| end <= self.0.len());
        assert!(in_bounds, "TraceChunk header too small");
        // SAFETY: the byte range was bounds-checked above. `read_unaligned` has no alignment
        // requirement. The caller guarantees that any bit pattern is a valid `T`.
        unsafe { std::ptr::read_unaligned(self.0.as_ptr().add(offset).cast::<T>()) }
    }
}

impl MinimalTrace for TraceChunkRaw {
    fn start_registers(&self) -> [u64; 32] {
        // SAFETY: `[u64; 32]` is valid for every bit pattern.
        unsafe { self.read_header_field(std::mem::offset_of!(TraceChunkHeader, start_registers)) }
    }

    fn pc_start(&self) -> u64 {
        // SAFETY: `u64` is valid for every bit pattern.
        unsafe { self.read_header_field(std::mem::offset_of!(TraceChunkHeader, pc_start)) }
    }

    fn clk_start(&self) -> u64 {
        // SAFETY: `u64` is valid for every bit pattern.
        unsafe { self.read_header_field(std::mem::offset_of!(TraceChunkHeader, clk_start)) }
    }

    fn clk_end(&self) -> u64 {
        // SAFETY: `u64` is valid for every bit pattern.
        unsafe { self.read_header_field(std::mem::offset_of!(TraceChunkHeader, clk_end)) }
    }

    fn num_mem_reads(&self) -> u64 {
        // SAFETY: `u64` is valid for every bit pattern.
        unsafe { self.read_header_field(std::mem::offset_of!(TraceChunkHeader, num_mem_reads)) }
    }

    /// Borrows the memory reads that follow the header, without copying them.
    ///
    /// # Panics
    ///
    /// - Panics with "Num mem reads too large" if the byte size of the reads overflows `usize`.
    /// - Panics with "TraceChunk tail too small" if the mapping is shorter than the header plus
    ///   all `num_mem_reads` records.
    fn mem_reads(&self) -> MemReads<'_> {
        const HEADER_END: usize = std::mem::size_of::<TraceChunkHeader>();
        // The reads are reinterpreted in place, so the header must end on a `MemValue` boundary.
        const _: () = assert!(HEADER_END % std::mem::align_of::<MemValue>() == 0);

        let len = usize::try_from(self.num_mem_reads()).expect("Num mem reads too large");
        // Compare *bytes* with bytes: the tail holds `len` records of `size_of::<MemValue>()`
        // bytes each, not `len` bytes.
        let total = len
            .checked_mul(std::mem::size_of::<MemValue>())
            .and_then(|tail_bytes| tail_bytes.checked_add(HEADER_END))
            .expect("Num mem reads too large");
        assert!(self.0.len() >= total, "TraceChunk tail too small");

        // SAFETY:
        // - `[HEADER_END, total)` lies inside the mapping (checked above).
        // - `MemValue` is `repr(C)` over two `u64`s, so any bytes form valid values.
        // - The returned cursor borrows `self`, which keeps the `Arc<Mmap>` alive.
        // - Alignment is debug-asserted in `MemReads::new`. NOTE(review): this relies on the
        //   mapping base being 8-byte aligned, which holds for maps at page-aligned offsets.
        unsafe { MemReads::new(self.0.as_ptr().add(HEADER_END).cast::<MemValue>(), len) }
    }
}
/// A borrowing cursor over a contiguous run of [`MemValue`]s that yields them by copy.
///
/// It stores the half-open range `[inner, end)` as raw pointers. The `'a` lifetime, carried by
/// `PhantomData`, ties the cursor to the storage it was created from, so it cannot outlive that
/// storage. Because it holds raw pointers, it is neither `Send` nor `Sync`.
pub struct MemReads<'a> {
    // Next value to yield.
    inner: *const MemValue,
    // One past the last value.
    end: *const MemValue,
    _phantom: PhantomData<&'a ()>,
}
impl<'a> MemReads<'a> {
    /// Creates a cursor over the `len` values starting at `inner`.
    ///
    /// # Safety
    ///
    /// - `inner` must be non-null and aligned for [`MemValue`] (debug builds assert alignment).
    /// - `inner..inner.add(len)` must lie within a single allocation of initialized `MemValue`s.
    /// - That memory must stay valid and unmodified for `'a`.
    pub(crate) unsafe fn new(inner: *const MemValue, len: usize) -> Self {
        debug_assert!(inner.is_aligned(), "MemReads ptr is not aligned");
        // SAFETY: the caller guarantees that `len` values are in bounds of one allocation, so the
        // one-past-the-end pointer is in bounds too.
        let end = unsafe { inner.add(len) };
        Self { inner, end, _phantom: PhantomData }
    }

    /// Skips the next `n` values without reading them.
    ///
    /// # Panics
    ///
    /// Panics if `n` is greater than [`len`](Self::len).
    pub fn advance(&mut self, n: usize) {
        // Check *before* computing the new pointer. `pointer::add` past the end of the
        // allocation is undefined behavior even if the result is never dereferenced.
        if n > self.len() {
            panic!("Cannot advance by more than the length of the slice");
        }
        // SAFETY: `n <= self.len()`, so the result is at most `self.end`.
        self.inner = unsafe { self.inner.add(n) };
    }

    /// Returns a pointer to the next value to be yielded. When the cursor is empty, this equals
    /// the end pointer.
    pub fn head_raw(&self) -> *const MemValue {
        self.inner
    }

    /// Number of values not yet yielded.
    #[must_use]
    pub fn len(&self) -> usize {
        // SAFETY: `inner` and `end` point into the same allocation with `inner <= end`. `new`
        // establishes this, and `advance`/`next` never move `inner` past `end`.
        unsafe { self.end.offset_from_unsigned(self.inner) }
    }

    /// Returns `true` if no values remain.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.inner == self.end
    }
}
impl<'a> Iterator for MemReads<'a> {
    type Item = MemValue;

    /// Yields the next value by copy, or `None` once the cursor reaches its end.
    fn next(&mut self) -> Option<Self::Item> {
        if self.inner == self.end {
            None
        } else {
            // SAFETY: `inner != end` and `inner <= end`, so `inner` points at an initialized,
            // aligned `MemValue`, as guaranteed by the contract of `MemReads::new`.
            let value = unsafe { std::ptr::read(self.inner) };
            // SAFETY: the result is at most `end`, one past the last element.
            self.inner = unsafe { self.inner.add(1) };
            Some(value)
        }
    }

    /// Returns an exact bound. The cursor always knows how many values remain, which lets
    /// `collect`/`extend` allocate once instead of growing repeatedly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Resolves to the inherent `MemReads::len` (inherent methods take precedence over
        // `ExactSizeIterator::len`), so there is no recursion back into `size_hint`.
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}

// `size_hint` is exact, so the cursor can advertise `ExactSizeIterator`.
impl ExactSizeIterator for MemReads<'_> {}

// Once `inner == end`, `next` keeps returning `None`.
impl std::iter::FusedIterator for MemReads<'_> {}
/// An owned trace chunk: the decoded header fields plus a shared buffer of memory reads.
///
/// Cloning copies the header fields, but `mem_reads` is an `Arc`, so the reads themselves are
/// shared rather than copied. Serialized with serde. `mem_reads` goes through the `ser`
/// adapters so that it is encoded like a `Vec<MemValue>`. Field order therefore determines
/// the serialized format.
///
/// NOTE(review): unlike `TraceChunkHeader`, no code in this file depends on the `#[repr(C)]`
/// layout of this struct.
#[repr(C)]
#[derive(Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TraceChunk {
    /// Register values at the start of the chunk.
    pub start_registers: [u64; 32],
    /// Program counter at the start of the chunk.
    pub pc_start: u64,
    /// Clock value at the start of the chunk.
    pub clk_start: u64,
    /// Clock value at the end of the chunk.
    pub clk_end: u64,
    /// The chunk's memory reads, in recorded order.
    #[serde(serialize_with = "ser::serialize_mem_reads")]
    #[serde(deserialize_with = "ser::deserialize_mem_reads")]
    pub mem_reads: Arc<[MemValue]>,
}
impl From<TraceChunkRaw> for TraceChunk {
fn from(raw: TraceChunkRaw) -> Self {
TraceChunk::copy_from_bytes(raw.0.as_ref())
}
}
impl TraceChunk {
pub fn copy_from_bytes(src: &[u8]) -> Self {
const HDR: usize = size_of::<TraceChunkHeader>();
if src.len() < HDR {
panic!("TraceChunk header too small");
}
let raw: TraceChunkHeader =
unsafe { core::ptr::read_unaligned(src.as_ptr() as *const TraceChunkHeader) };
let n_words = raw.num_mem_reads as usize;
let n_bytes = n_words.checked_mul(size_of::<MemValue>()).expect("Num mem reads too large");
let total = HDR.checked_add(n_bytes).expect("Num mem reads too large");
if src.len() < total {
panic!("TraceChunk tail too small");
}
let tail = &src[HDR..total];
let mem_reads = Arc::new_uninit_slice(n_words);
unsafe {
std::ptr::copy_nonoverlapping(tail.as_ptr(), mem_reads.as_ptr() as *mut u8, n_bytes)
};
Self {
start_registers: raw.start_registers,
pc_start: raw.pc_start,
clk_start: raw.clk_start,
clk_end: raw.clk_end,
mem_reads: unsafe { mem_reads.assume_init() },
}
}
}
/// Read-only access to a trace chunk's header fields and memory reads.
///
/// Implemented here both by the memory-mapped [`TraceChunkRaw`] and by the owned
/// [`TraceChunk`]. The `Clone + Send + Sync + 'static` bounds let implementors be shared
/// across threads.
pub trait MinimalTrace: Clone + Send + Sync + 'static {
    /// Register values at the start of the chunk.
    fn start_registers(&self) -> [u64; 32];
    /// Program counter at the start of the chunk.
    fn pc_start(&self) -> u64;
    /// Clock value at the start of the chunk.
    fn clk_start(&self) -> u64;
    /// Clock value at the end of the chunk.
    fn clk_end(&self) -> u64;
    /// Number of memory reads in the chunk, i.e. the initial length of `mem_reads()`.
    fn num_mem_reads(&self) -> u64;
    /// A cursor over the chunk's memory reads, borrowing from `self`.
    fn mem_reads(&self) -> MemReads<'_>;
}
impl MinimalTrace for TraceChunk {
fn start_registers(&self) -> [u64; 32] {
self.start_registers
}
fn pc_start(&self) -> u64 {
self.pc_start
}
fn clk_start(&self) -> u64 {
self.clk_start
}
fn clk_end(&self) -> u64 {
self.clk_end
}
fn num_mem_reads(&self) -> u64 {
self.mem_reads.len() as u64
}
fn mem_reads(&self) -> MemReads<'_> {
unsafe { MemReads::new(self.mem_reads.as_ptr(), self.mem_reads.len()) }
}
}
mod ser {
use super::*;
use serde::{Deserializer, Serializer};
pub fn serialize_mem_reads<S: Serializer>(
mem_reads: &Arc<[MemValue]>,
serializer: S,
) -> Result<S::Ok, S::Error> {
let as_vec: Vec<MemValue> = Vec::from(&mem_reads[..]);
Vec::serialize(&as_vec, serializer)
}
pub fn deserialize_mem_reads<'a, D: Deserializer<'a>>(
deserializer: D,
) -> Result<Arc<[MemValue]>, D::Error> {
let as_vec = Vec::deserialize(deserializer)?;
Ok(as_vec.into())
}
#[test]
#[cfg(test)]
fn test_mem_reads() {
let mem_reads = Arc::new([MemValue { clk: 0, value: 0 }, MemValue { clk: 1, value: 1 }]);
let trace = TraceChunk {
start_registers: [5; 32],
pc_start: 6,
clk_start: 7,
clk_end: 8,
mem_reads,
};
let serialized = bincode::serialize(&trace).unwrap();
let deserialized = bincode::deserialize(&serialized).unwrap();
assert_eq!(trace, deserialized);
}
}