use super::{
utils::{split_element_u32_into_u16, split_u32_into_u16},
Felt, FieldElement, RangeChecker, TraceFragment, Word, EMPTY_WORD, ONE,
};
use crate::system::ContextId;
use alloc::{collections::BTreeMap, vec::Vec};
use miden_air::trace::chiplets::memory::{
ADDR_COL_IDX, CLK_COL_IDX, CTX_COL_IDX, D0_COL_IDX, D1_COL_IDX, D_INV_COL_IDX, V_COL_RANGE,
};
mod segment;
use segment::MemorySegmentTrace;
#[cfg(test)]
mod tests;
/// Value reported for a memory address that has never been written in a context.
const INIT_MEM_VALUE: Word = EMPTY_WORD;
/// Memory access recorder for the memory chiplet: collects all reads and writes,
/// grouped by execution context, so they can later be materialized as trace rows
/// and accompanying range checks.
#[derive(Default)]
pub struct Memory {
    // Per-context access traces. A BTreeMap keeps contexts in ascending order,
    // which the delta computations in `append_range_checks`/`fill_trace` depend on.
    trace: BTreeMap<ContextId, MemorySegmentTrace>,
    // Total number of accesses recorded so far; one trace row per access.
    num_trace_rows: usize,
}
impl Memory {
pub fn trace_len(&self) -> usize {
self.num_trace_rows
}
pub fn get_value(&self, ctx: ContextId, addr: u32) -> Option<Word> {
match self.trace.get(&ctx) {
Some(segment) => segment.get_value(addr),
None => None,
}
}
pub fn get_old_value(&self, ctx: ContextId, addr: u32) -> Word {
self.get_value(ctx, addr).unwrap_or(INIT_MEM_VALUE)
}
pub fn get_state_at(&self, ctx: ContextId, clk: u32) -> Vec<(u64, Word)> {
if clk == 0 {
return vec![];
}
match self.trace.get(&ctx) {
Some(segment) => segment.get_state_at(clk),
None => vec![],
}
}
pub fn read(&mut self, ctx: ContextId, addr: u32, clk: u32) -> Word {
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().read(addr, Felt::from(clk))
}
pub fn write(&mut self, ctx: ContextId, addr: u32, clk: u32, value: Word) {
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().write(addr, Felt::from(clk), value);
}
pub fn append_range_checks(&self, memory_start_row: usize, range: &mut RangeChecker) {
let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
Some((ctx, addr, clk)) => (ctx, addr, clk.as_int() - 1),
None => return,
};
let mut row = memory_start_row as u32;
for (&ctx, segment) in self.trace.iter() {
for (&addr, addr_trace) in segment.inner().iter() {
for memory_access in addr_trace {
let clk = memory_access.clk().as_int();
let delta = if prev_ctx != ctx {
(u32::from(ctx) - u32::from(prev_ctx)).into()
} else if prev_addr != addr {
(addr - prev_addr) as u64
} else {
clk - prev_clk - 1
};
let (delta_hi, delta_lo) = split_u32_into_u16(delta);
range.add_range_checks(row, &[delta_lo, delta_hi]);
prev_ctx = ctx;
prev_addr = addr;
prev_clk = clk;
row += 1;
}
}
}
}
pub fn fill_trace(self, trace: &mut TraceFragment) {
debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
Some((ctx, addr, clk)) => (Felt::from(ctx), Felt::from(addr), clk - ONE),
None => return,
};
let mut row = 0;
for (ctx, segment) in self.trace {
let ctx = Felt::from(ctx);
for (addr, addr_trace) in segment.into_inner() {
let felt_addr = Felt::from(addr);
for memory_access in addr_trace {
let clk = memory_access.clk();
let value = memory_access.value();
let selectors = memory_access.op_selectors();
trace.set(row, 0, selectors[0]);
trace.set(row, 1, selectors[1]);
trace.set(row, CTX_COL_IDX, ctx);
trace.set(row, ADDR_COL_IDX, felt_addr);
trace.set(row, CLK_COL_IDX, clk);
for (idx, col) in V_COL_RANGE.enumerate() {
trace.set(row, col, value[idx]);
}
let delta = if prev_ctx != ctx {
ctx - prev_ctx
} else if prev_addr != felt_addr {
felt_addr - prev_addr
} else {
clk - prev_clk - ONE
};
let (delta_hi, delta_lo) = split_element_u32_into_u16(delta);
trace.set(row, D0_COL_IDX, delta_lo);
trace.set(row, D1_COL_IDX, delta_hi);
trace.set(row, D_INV_COL_IDX, delta.inv());
prev_ctx = ctx;
prev_addr = felt_addr;
prev_clk = clk;
row += 1;
}
}
}
}
fn get_first_row_info(&self) -> Option<(ContextId, u32, Felt)> {
let (ctx, segment) = match self.trace.iter().next() {
Some((&ctx, segment)) => (ctx, segment),
None => return None,
};
let (&addr, addr_trace) = segment.inner().iter().next().expect("empty memory segment");
Some((ctx, addr, addr_trace[0].clk()))
}
#[cfg(test)]
pub fn size(&self) -> usize {
self.trace.iter().fold(0, |acc, (_, s)| acc + s.size())
}
}