use alloc::{collections::BTreeMap, vec::Vec};
use core::fmt::Debug;
use miden_air::trace::{
RowIndex,
chiplets::memory::{
CLK_COL_IDX, CTX_COL_IDX, D_INV_COL_IDX, D0_COL_IDX, D1_COL_IDX,
FLAG_SAME_CONTEXT_AND_WORD, IDX0_COL_IDX, IDX1_COL_IDX, IS_READ_COL_IDX,
IS_WORD_ACCESS_COL_IDX, MEMORY_ACCESS_ELEMENT, MEMORY_ACCESS_WORD, MEMORY_READ,
MEMORY_WRITE, V_COL_RANGE, WORD_COL_IDX,
},
};
use super::{
super::utils::{split_element_u32_into_u16, split_u32_into_u16},
RangeChecker, TraceFragment,
};
use crate::{
ContextId, EMPTY_WORD, Felt, MemoryAddress, MemoryError, ONE, WORD_SIZE, Word, ZERO,
field::Field,
};
mod segment;
use segment::{MemoryOperation, MemorySegmentTrace};
#[cfg(test)]
mod tests;
/// Initial memory value, defined as [`EMPTY_WORD`].
///
/// NOTE(review): not referenced in this file's visible code; presumably used by the `segment`
/// submodule as the value of words that have never been written — confirm there.
const INIT_MEM_VALUE: Word = EMPTY_WORD;
/// State of the memory chiplet: every recorded memory access, grouped by execution context.
///
/// Accesses are recorded through `read`, `read_word`, `write`, and `write_word`; the recorded
/// history is later turned into execution-trace rows by `fill_trace` and into range-check
/// requests by `append_range_checks`.
#[derive(Debug, Default)]
pub struct Memory {
    /// Per-context access traces. The `BTreeMap` keeps contexts sorted by ID, which is the order
    /// in which trace rows and range checks are emitted.
    trace: BTreeMap<ContextId, MemorySegmentTrace>,
    /// Counter of memory accesses, reported by `trace_len` as the number of trace rows this
    /// chiplet occupies.
    num_trace_rows: usize,
}
impl Memory {
pub fn trace_len(&self) -> usize {
self.num_trace_rows
}
pub fn get_value(&self, ctx: ContextId, addr: u32) -> Option<Felt> {
match self.trace.get(&ctx) {
Some(segment) => segment.get_value(addr),
None => None,
}
}
pub fn get_word(&self, ctx: ContextId, addr: u32) -> Result<Option<Word>, MemoryError> {
match self.trace.get(&ctx) {
Some(segment) => segment
.get_word(addr)
.map_err(|_| MemoryError::UnalignedWordAccess { addr, ctx }),
None => Ok(None),
}
}
pub fn get_state_at(&self, ctx: ContextId, clk: RowIndex) -> Vec<(MemoryAddress, Felt)> {
if clk == 0 {
return vec![];
}
match self.trace.get(&ctx) {
Some(segment) => segment.get_state_at(clk),
None => vec![],
}
}
pub fn read(&mut self, ctx: ContextId, addr: Felt, clk: RowIndex) -> Result<Felt, MemoryError> {
let addr: u32 = addr
.as_canonical_u64()
.try_into()
.map_err(|_| MemoryError::AddressOutOfBounds { addr: addr.as_canonical_u64() })?;
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().read(ctx, addr, Felt::from(clk))
}
pub fn read_word(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
) -> Result<Word, MemoryError> {
let addr: u32 = addr
.as_canonical_u64()
.try_into()
.map_err(|_| MemoryError::AddressOutOfBounds { addr: addr.as_canonical_u64() })?;
if !addr.is_multiple_of(WORD_SIZE as u32) {
return Err(MemoryError::UnalignedWordAccess { addr, ctx });
}
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().read_word(ctx, addr, Felt::from(clk))
}
pub fn write(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
value: Felt,
) -> Result<(), MemoryError> {
let addr: u32 = addr
.as_canonical_u64()
.try_into()
.map_err(|_| MemoryError::AddressOutOfBounds { addr: addr.as_canonical_u64() })?;
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().write(ctx, addr, Felt::from(clk), value)
}
pub fn write_word(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
value: Word,
) -> Result<(), MemoryError> {
let addr: u32 = addr
.as_canonical_u64()
.try_into()
.map_err(|_| MemoryError::AddressOutOfBounds { addr: addr.as_canonical_u64() })?;
if !addr.is_multiple_of(WORD_SIZE as u32) {
return Err(MemoryError::UnalignedWordAccess { addr, ctx });
}
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().write_word(ctx, addr, Felt::from(clk), value)
}
pub fn append_range_checks(&self, memory_start_row: RowIndex, range: &mut RangeChecker) {
let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
Some((ctx, addr, clk)) => (ctx, addr, clk.as_canonical_u64() - 1),
None => return,
};
let mut row = memory_start_row;
for (&ctx, segment) in self.trace.iter() {
for (&addr, addr_trace) in segment.inner().iter() {
for memory_access in addr_trace {
let clk = memory_access.clk().as_canonical_u64();
let delta = if prev_ctx != ctx {
(u32::from(ctx) - u32::from(prev_ctx)).into()
} else if prev_addr != addr {
u64::from(addr - prev_addr)
} else {
clk - prev_clk
};
let (delta_hi, delta_lo) = split_u32_into_u16(delta);
range.add_range_checks(row, &[delta_lo, delta_hi]);
prev_ctx = ctx;
prev_addr = addr;
prev_clk = clk;
row += 1_u32;
}
}
}
}
pub fn fill_trace(self, trace: &mut TraceFragment) {
debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
Some((ctx, addr, clk)) => (Felt::from(ctx), Felt::from_u32(addr), clk - ONE),
None => return,
};
let mut row: RowIndex = 0.into();
for (ctx, segment) in self.trace {
let ctx = Felt::from(ctx);
for (addr, addr_trace) in segment.into_inner() {
let felt_addr = Felt::from_u32(addr);
for memory_access in addr_trace {
let clk = memory_access.clk();
let value = memory_access.word();
match memory_access.operation() {
MemoryOperation::Read => trace.set(row, IS_READ_COL_IDX, MEMORY_READ),
MemoryOperation::Write => trace.set(row, IS_READ_COL_IDX, MEMORY_WRITE),
}
let (idx1, idx0) = match memory_access.access_type() {
segment::MemoryAccessType::Element { addr_idx_in_word } => {
trace.set(row, IS_WORD_ACCESS_COL_IDX, MEMORY_ACCESS_ELEMENT);
match addr_idx_in_word {
0 => (ZERO, ZERO),
1 => (ZERO, ONE),
2 => (ONE, ZERO),
3 => (ONE, ONE),
_ => panic!("invalid address index in word: {addr_idx_in_word}"),
}
},
segment::MemoryAccessType::Word => {
trace.set(row, IS_WORD_ACCESS_COL_IDX, MEMORY_ACCESS_WORD);
(ZERO, ZERO)
},
};
trace.set(row, CTX_COL_IDX, ctx);
trace.set(row, WORD_COL_IDX, felt_addr);
trace.set(row, IDX0_COL_IDX, idx0);
trace.set(row, IDX1_COL_IDX, idx1);
trace.set(row, CLK_COL_IDX, clk);
for (idx, col) in V_COL_RANGE.enumerate() {
trace.set(row, col, value[idx]);
}
let delta = if prev_ctx != ctx {
ctx - prev_ctx
} else if prev_addr != felt_addr {
felt_addr - prev_addr
} else {
clk - prev_clk
};
let (delta_hi, delta_lo) = split_element_u32_into_u16(delta);
trace.set(row, D0_COL_IDX, delta_lo);
trace.set(row, D1_COL_IDX, delta_hi);
trace.set(row, D_INV_COL_IDX, delta.try_inverse().unwrap_or(ZERO));
if prev_ctx == ctx && prev_addr == felt_addr {
trace.set(row, FLAG_SAME_CONTEXT_AND_WORD, ONE);
} else {
trace.set(row, FLAG_SAME_CONTEXT_AND_WORD, ZERO);
};
prev_ctx = ctx;
prev_addr = felt_addr;
prev_clk = clk;
row += 1_u32;
}
}
}
}
fn get_first_row_info(&self) -> Option<(ContextId, u32, Felt)> {
let (ctx, segment) = match self.trace.iter().next() {
Some((&ctx, segment)) => (ctx, segment),
None => return None,
};
let (&addr, addr_trace) = segment.inner().iter().next().expect("empty memory segment");
Some((ctx, addr, addr_trace[0].clk()))
}
#[cfg(test)]
pub fn num_accessed_words(&self) -> usize {
self.trace.iter().fold(0, |acc, (_, s)| acc + s.num_accessed_words())
}
}