use alloc::{collections::BTreeMap, vec::Vec};
use core::fmt::Debug;
use miden_air::{
RowIndex,
trace::chiplets::memory::{
CLK_COL_IDX, CTX_COL_IDX, D_INV_COL_IDX, D0_COL_IDX, D1_COL_IDX,
FLAG_SAME_CONTEXT_AND_WORD, IDX0_COL_IDX, IDX1_COL_IDX, IS_READ_COL_IDX,
IS_WORD_ACCESS_COL_IDX, MEMORY_ACCESS_ELEMENT, MEMORY_ACCESS_WORD, MEMORY_READ,
MEMORY_WRITE, V_COL_RANGE, WORD_COL_IDX,
},
};
use miden_core::{WORD_SIZE, ZERO};
use super::{
EMPTY_WORD, Felt, FieldElement, ONE, RangeChecker, TraceFragment, Word,
utils::{split_element_u32_into_u16, split_u32_into_u16},
};
use crate::{MemoryAddress, errors::ErrorContext, system::ContextId};
mod errors;
pub use errors::MemoryError;
mod segment;
use segment::{MemoryOperation, MemorySegmentTrace};
#[cfg(test)]
mod tests;
/// Initial value of memory words; an alias of `EMPTY_WORD`.
///
/// NOTE(review): not referenced in this part of the file; presumably used by the `segment`
/// module (and tests) as the value observed for words that were never written — confirm.
const INIT_MEM_VALUE: Word = EMPTY_WORD;
/// State of the memory chiplet: every memory access, grouped by execution context.
///
/// The recorded accesses are later turned into the chiplet's execution trace by
/// [`Memory::fill_trace`] and into range-check requests by [`Memory::append_range_checks`].
#[derive(Debug, Default)]
pub struct Memory {
    /// Access traces keyed by context. A `BTreeMap` is used so that contexts are visited in
    /// ascending order when the trace and range checks are generated.
    trace: BTreeMap<ContextId, MemorySegmentTrace>,
    /// Number of recorded accesses; this equals the number of rows the chiplet contributes to
    /// the execution trace (one row per access).
    num_trace_rows: usize,
}
impl Memory {
pub fn trace_len(&self) -> usize {
self.num_trace_rows
}
pub fn get_value(&self, ctx: ContextId, addr: u32) -> Option<Felt> {
match self.trace.get(&ctx) {
Some(segment) => segment.get_value(addr),
None => None,
}
}
pub fn get_word(&self, ctx: ContextId, addr: u32) -> Result<Option<Word>, MemoryError> {
match self.trace.get(&ctx) {
Some(segment) => segment
.get_word(addr)
.map_err(|_| MemoryError::UnalignedWordAccessNoClk { addr, ctx }),
None => Ok(None),
}
}
pub fn get_state_at(&self, ctx: ContextId, clk: RowIndex) -> Vec<(MemoryAddress, Felt)> {
if clk == 0 {
return vec![];
}
match self.trace.get(&ctx) {
Some(segment) => segment.get_state_at(clk),
None => vec![],
}
}
pub fn read(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
err_ctx: &impl ErrorContext,
) -> Result<Felt, MemoryError> {
let addr: u32 = addr
.as_int()
.try_into()
.map_err(|_| MemoryError::address_out_of_bounds(addr.as_int(), err_ctx))?;
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().read(ctx, addr, Felt::from(clk))
}
pub fn read_word(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
err_ctx: &impl ErrorContext,
) -> Result<Word, MemoryError> {
let addr: u32 = addr
.as_int()
.try_into()
.map_err(|_| MemoryError::address_out_of_bounds(addr.as_int(), err_ctx))?;
if !addr.is_multiple_of(WORD_SIZE as u32) {
return Err(MemoryError::unaligned_word_access(addr, ctx, clk.into(), err_ctx));
}
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().read_word(ctx, addr, Felt::from(clk))
}
pub fn write(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
value: Felt,
err_ctx: &impl ErrorContext,
) -> Result<(), MemoryError> {
let addr: u32 = addr
.as_int()
.try_into()
.map_err(|_| MemoryError::address_out_of_bounds(addr.as_int(), err_ctx))?;
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().write(ctx, addr, Felt::from(clk), value)
}
pub fn write_word(
&mut self,
ctx: ContextId,
addr: Felt,
clk: RowIndex,
value: Word,
err_ctx: &impl ErrorContext,
) -> Result<(), MemoryError> {
let addr: u32 = addr
.as_int()
.try_into()
.map_err(|_| MemoryError::address_out_of_bounds(addr.as_int(), err_ctx))?;
if !addr.is_multiple_of(WORD_SIZE as u32) {
return Err(MemoryError::unaligned_word_access(addr, ctx, clk.into(), err_ctx));
}
self.num_trace_rows += 1;
self.trace.entry(ctx).or_default().write_word(ctx, addr, Felt::from(clk), value)
}
pub fn append_range_checks(&self, memory_start_row: RowIndex, range: &mut RangeChecker) {
let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
Some((ctx, addr, clk)) => (ctx, addr, clk.as_int() - 1),
None => return,
};
let mut row = memory_start_row;
for (&ctx, segment) in self.trace.iter() {
for (&addr, addr_trace) in segment.inner().iter() {
for memory_access in addr_trace {
let clk = memory_access.clk().as_int();
let delta = if prev_ctx != ctx {
(u32::from(ctx) - u32::from(prev_ctx)).into()
} else if prev_addr != addr {
u64::from(addr - prev_addr)
} else {
clk - prev_clk
};
let (delta_hi, delta_lo) = split_u32_into_u16(delta);
range.add_range_checks(row, &[delta_lo, delta_hi]);
prev_ctx = ctx;
prev_addr = addr;
prev_clk = clk;
row += 1_u32;
}
}
}
}
pub fn fill_trace(self, trace: &mut TraceFragment) {
debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
let (mut prev_ctx, mut prev_addr, mut prev_clk) = match self.get_first_row_info() {
Some((ctx, addr, clk)) => (Felt::from(ctx), Felt::from(addr), clk - ONE),
None => return,
};
let mut row: RowIndex = 0.into();
for (ctx, segment) in self.trace {
let ctx = Felt::from(ctx);
for (addr, addr_trace) in segment.into_inner() {
let felt_addr = Felt::from(addr);
for memory_access in addr_trace {
let clk = memory_access.clk();
let value = memory_access.word();
match memory_access.operation() {
MemoryOperation::Read => trace.set(row, IS_READ_COL_IDX, MEMORY_READ),
MemoryOperation::Write => trace.set(row, IS_READ_COL_IDX, MEMORY_WRITE),
}
let (idx1, idx0) = match memory_access.access_type() {
segment::MemoryAccessType::Element { addr_idx_in_word } => {
trace.set(row, IS_WORD_ACCESS_COL_IDX, MEMORY_ACCESS_ELEMENT);
match addr_idx_in_word {
0 => (ZERO, ZERO),
1 => (ZERO, ONE),
2 => (ONE, ZERO),
3 => (ONE, ONE),
_ => panic!("invalid address index in word: {addr_idx_in_word}"),
}
},
segment::MemoryAccessType::Word => {
trace.set(row, IS_WORD_ACCESS_COL_IDX, MEMORY_ACCESS_WORD);
(ZERO, ZERO)
},
};
trace.set(row, CTX_COL_IDX, ctx);
trace.set(row, WORD_COL_IDX, felt_addr);
trace.set(row, IDX0_COL_IDX, idx0);
trace.set(row, IDX1_COL_IDX, idx1);
trace.set(row, CLK_COL_IDX, clk);
for (idx, col) in V_COL_RANGE.enumerate() {
trace.set(row, col, value[idx]);
}
let delta = if prev_ctx != ctx {
ctx - prev_ctx
} else if prev_addr != felt_addr {
felt_addr - prev_addr
} else {
clk - prev_clk
};
let (delta_hi, delta_lo) = split_element_u32_into_u16(delta);
trace.set(row, D0_COL_IDX, delta_lo);
trace.set(row, D1_COL_IDX, delta_hi);
trace.set(row, D_INV_COL_IDX, delta.inv());
if prev_ctx == ctx && prev_addr == felt_addr {
trace.set(row, FLAG_SAME_CONTEXT_AND_WORD, ONE);
} else {
trace.set(row, FLAG_SAME_CONTEXT_AND_WORD, ZERO);
};
prev_ctx = ctx;
prev_addr = felt_addr;
prev_clk = clk;
row += 1_u32;
}
}
}
}
fn get_first_row_info(&self) -> Option<(ContextId, u32, Felt)> {
let (ctx, segment) = match self.trace.iter().next() {
Some((&ctx, segment)) => (ctx, segment),
None => return None,
};
let (&addr, addr_trace) = segment.inner().iter().next().expect("empty memory segment");
Some((ctx, addr, addr_trace[0].clk()))
}
#[cfg(test)]
pub fn num_accessed_words(&self) -> usize {
self.trace.iter().fold(0, |acc, (_, s)| acc + s.num_accessed_words())
}
}