use super::{
BTreeMap, ChipletsBus, Felt, FieldElement, StarkField, TraceFragment, Vec, Word, ONE, ZERO,
};
use crate::{
range::RangeChecker,
trace::LookupTableRow,
utils::{split_element_u32_into_u16, split_u32_into_u16},
Matrix,
};
use vm_core::chiplets::memory::{
ADDR_COL_IDX, CLK_COL_IDX, CTX_COL_IDX, D0_COL_IDX, D1_COL_IDX, D_INV_COL_IDX, V_COL_RANGE,
};
mod segment;
use segment::MemorySegmentTrace;
#[cfg(test)]
mod tests;
/// Word reported for an address with no recorded value: four zero field elements.
const INIT_MEM_VALUE: Word = [ZERO, ZERO, ZERO, ZERO];
/// Memory chiplet state: every memory read and write, grouped by context, which is later laid out
/// as one row per access in a trace fragment.
#[derive(Default)]
pub struct Memory {
    /// Per-context access traces. A `BTreeMap` iterates in ascending key order, which the
    /// `ctx - prev_ctx` delta computations below rely on.
    trace: BTreeMap<u32, MemorySegmentTrace>,
    /// Total number of recorded reads and writes; each one becomes exactly one trace row.
    num_trace_rows: usize,
}

impl Memory {
    // --- accessors ---------------------------------------------------------------------------

    /// Returns the number of trace rows needed for the accesses recorded so far (one row per read
    /// or write).
    pub fn trace_len(&self) -> usize {
        self.num_trace_rows
    }

    /// Returns the value that the segment trace of context `ctx` reports for `addr`, or `None` if
    /// context `ctx` has never been accessed or its segment reports no value for `addr`.
    pub fn get_value(&self, ctx: u32, addr: u64) -> Option<Word> {
        self.trace.get(&ctx).and_then(|segment| segment.get_value(addr))
    }

    /// Like [`Memory::get_value`], but falls back to `INIT_MEM_VALUE` (a word of four zeros) when
    /// no value is available.
    pub fn get_old_value(&self, ctx: u32, addr: u64) -> Word {
        self.get_value(ctx, addr).unwrap_or(INIT_MEM_VALUE)
    }

    /// Returns the `(address, word)` pairs that the segment trace of context `ctx` reports for
    /// clock cycle `clk`.
    ///
    /// The result is empty when `clk` is 0 or when context `ctx` has never been accessed.
    pub fn get_state_at(&self, ctx: u32, clk: u32) -> Vec<(u64, Word)> {
        if clk == 0 {
            return vec![];
        }
        self.trace.get(&ctx).map_or_else(Vec::new, |segment| segment.get_state_at(clk))
    }

    // --- mutators ----------------------------------------------------------------------------

    /// Records a read of `addr` in context `ctx` at clock cycle `clk` and returns the word the
    /// segment trace reports for it.
    ///
    /// The context's segment trace is created on first access, and one trace row is counted.
    pub fn read(&mut self, ctx: u32, addr: Felt, clk: u32) -> Word {
        self.num_trace_rows += 1;
        self.trace
            .entry(ctx)
            .or_insert_with(MemorySegmentTrace::default)
            .read(addr, Felt::from(clk))
    }

    /// Records a write of `value` to `addr` in context `ctx` at clock cycle `clk`.
    ///
    /// The context's segment trace is created on first access, and one trace row is counted.
    pub fn write(&mut self, ctx: u32, addr: Felt, clk: u32, value: Word) {
        self.num_trace_rows += 1;
        self.trace
            .entry(ctx)
            .or_insert_with(MemorySegmentTrace::default)
            .write(addr, Felt::from(clk), value);
    }

    // --- trace generation --------------------------------------------------------------------

    /// Registers range checks for the memory section of the trace with `range`.
    ///
    /// Accesses are visited in the same order as in [`Memory::fill_trace`]. For each access, the
    /// row's delta value is split into high and low halves with `split_u32_into_u16`. Both halves
    /// are then added as memory range checks for trace row `memory_start_row + i`, where `i` is
    /// the access's index. Does nothing if no accesses were recorded.
    pub fn append_range_checks(&self, memory_start_row: usize, range: &mut RangeChecker) {
        // Seed the "previous" values so the first row's delta is zero:
        // - context and address: those of the first row;
        // - clock: one cycle before the first row's clock.
        // `wrapping_sub` keeps this well-defined when the first access happens at clock 0. The
        // seed then wraps to `u64::MAX`, and the matching `wrapping_add` below brings it back to
        // 0. A plain `clk - 1` would panic on overflow in debug builds.
        let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
            Some((ctx, addr, clk)) => (ctx, addr, clk.as_int().wrapping_sub(1)),
            None => return,
        };

        let mut row = memory_start_row as u32;
        for (&ctx, segment) in self.trace.iter() {
            for (&addr, addr_trace) in segment.inner().iter() {
                for memory_access in addr_trace {
                    let clk = memory_access.clk().as_int();

                    // Same delta definition as in `fill_trace`: a context change takes priority,
                    // then an address change, otherwise the clock gap minus one.
                    let delta = if prev_ctx != ctx {
                        (ctx - prev_ctx) as u64
                    } else if prev_addr != addr {
                        addr - prev_addr
                    } else {
                        // `prev_clk + 1` can only wrap for the clock-0 seed above. On every later
                        // row `prev_clk` is a real clock value, so this subtraction still panics
                        // in debug builds if clocks at one address do not strictly increase.
                        clk - prev_clk.wrapping_add(1)
                    };

                    let (delta_hi, delta_lo) = split_u32_into_u16(delta);
                    range.add_mem_checks(row, &[delta_lo, delta_hi]);

                    prev_ctx = ctx;
                    prev_addr = addr;
                    prev_clk = clk;
                    row += 1;
                }
            }
        }
    }

    /// Writes one row per recorded access into `trace` and provides each access to
    /// `chiplets_bus`.
    ///
    /// Row order:
    /// - contexts in ascending order;
    /// - within a context, addresses in the order of the context's segment;
    /// - within an address, accesses in the order stored for it.
    ///
    /// Each row holds:
    /// - the two operation selectors;
    /// - the context, address and clock cycle;
    /// - the four word elements;
    /// - the delta value split into two halves, and the delta's inverse.
    ///
    /// The bus entry for local row `i` is tagged with row `memory_start_row + i`. This consumes
    /// `self` because the recorded accesses are moved out.
    pub fn fill_trace(
        self,
        trace: &mut TraceFragment,
        chiplets_bus: &mut ChipletsBus,
        memory_start_row: usize,
    ) {
        debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");

        // Seed the "previous" values so the first row's delta is zero (field arithmetic, so
        // `clk - ONE` is well-defined even when the first clock is 0).
        let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
            Some((ctx, addr, clk)) => (Felt::from(ctx), Felt::from(addr), clk - ONE),
            None => return,
        };

        let mut row = 0;
        for (ctx, segment) in self.trace {
            let ctx = Felt::from(ctx);
            for (addr, addr_trace) in segment.into_inner() {
                let addr = Felt::new(addr);
                for memory_access in addr_trace {
                    let clk = memory_access.clk();
                    let value = memory_access.value();

                    // Columns 0 and 1 hold the operation selectors.
                    let selectors = memory_access.op_selectors();
                    trace.set(row, 0, selectors[0]);
                    trace.set(row, 1, selectors[1]);
                    trace.set(row, CTX_COL_IDX, ctx);
                    trace.set(row, ADDR_COL_IDX, addr);
                    trace.set(row, CLK_COL_IDX, clk);
                    for (idx, col) in V_COL_RANGE.enumerate() {
                        trace.set(row, col, value[idx]);
                    }

                    // A context change takes priority, then an address change, otherwise the
                    // clock gap minus one.
                    let delta = if prev_ctx != ctx {
                        ctx - prev_ctx
                    } else if prev_addr != addr {
                        addr - prev_addr
                    } else {
                        clk - prev_clk - ONE
                    };

                    let (delta_hi, delta_lo) = split_element_u32_into_u16(delta);
                    trace.set(row, D0_COL_IDX, delta_lo);
                    trace.set(row, D1_COL_IDX, delta_hi);
                    // NOTE(review): delta can be zero (e.g. on the first row). This assumes
                    // `inv()` maps zero to zero; confirm against the field implementation.
                    trace.set(row, D_INV_COL_IDX, delta.inv());

                    let memory_lookup =
                        MemoryLookup::new(memory_access.op_label(), ctx, addr, clk, value);
                    chiplets_bus
                        .provide_memory_operation(memory_lookup, (memory_start_row + row) as u32);

                    prev_ctx = ctx;
                    prev_addr = addr;
                    prev_clk = clk;
                    row += 1;
                }
            }
        }
    }

    // --- helpers -----------------------------------------------------------------------------

    /// Returns the context, address and clock cycle of the first row of the memory trace, or
    /// `None` if no accesses were recorded.
    ///
    /// # Panics
    /// Panics if the first context's segment has no addresses, or if its first address has no
    /// accesses.
    fn get_first_row_info(&self) -> Option<(u32, u64, Felt)> {
        let (&ctx, segment) = self.trace.iter().next()?;
        let (&addr, addr_trace) = segment.inner().iter().next().expect("empty memory segment");
        Some((ctx, addr, addr_trace[0].clk()))
    }

    /// Returns the sum of `size()` over all context segments.
    #[cfg(test)]
    pub fn size(&self) -> usize {
        self.trace.values().map(|segment| segment.size()).sum()
    }
}
/// One memory access as sent over the chiplets bus: the operation label plus the context,
/// address, clock cycle and word of the access.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct MemoryLookup {
    label: u8,
    ctx: Felt,
    addr: Felt,
    clk: Felt,
    word: Word,
}

impl MemoryLookup {
    /// Creates a lookup from an operation label and a field-element context, address and clock.
    pub fn new(label: u8, ctx: Felt, addr: Felt, clk: Felt, word: Word) -> Self {
        Self {
            label,
            ctx,
            addr,
            clk,
            word,
        }
    }

    /// Creates a lookup from an integer context and clock cycle, converting both to field
    /// elements.
    pub fn from_ints(label: u8, ctx: u32, addr: Felt, clk: u32, word: Word) -> Self {
        let ctx = Felt::from(ctx);
        let clk = Felt::from(clk);
        Self::new(label, ctx, addr, clk, word)
    }
}

impl LookupTableRow for MemoryLookup {
    /// Collapses this lookup into a single value: a linear combination of its fields with the
    /// coefficients in `alphas`:
    ///
    /// `alphas[0] + alphas[1]*label + alphas[2]*ctx + alphas[3]*addr + alphas[4]*clk
    ///  + alphas[5]*word[0] + alphas[6]*word[1] + alphas[7]*word[2] + alphas[8]*word[3]`
    ///
    /// The main trace is not consulted.
    ///
    /// # Panics
    /// Panics if `alphas` has fewer than 9 elements.
    fn to_value<E: FieldElement<BaseField = Felt>>(
        &self,
        _main_trace: &Matrix<Felt>,
        alphas: &[E],
    ) -> E {
        // The four word coefficients follow the five header coefficients.
        let word_alphas = &alphas[5..9];
        let mut word_value = E::ZERO;
        for (&alpha, &element) in word_alphas.iter().zip(self.word.iter()) {
            word_value = word_value + alpha.mul_base(element);
        }

        let label = Felt::from(self.label);
        alphas[0]
            + alphas[1].mul_base(label)
            + alphas[2].mul_base(self.ctx)
            + alphas[3].mul_base(self.addr)
            + alphas[4].mul_base(self.clk)
            + word_value
    }
}