use alloc::{collections::BTreeMap, vec::Vec};
use core::{borrow::BorrowMut, fmt::Debug};
use miden_air::{
MemoryCols,
trace::{
RowIndex,
chiplets::memory::{
MEMORY_ACCESS_ELEMENT, MEMORY_ACCESS_WORD, MEMORY_READ, MEMORY_WRITE,
TRACE_WIDTH as MEMORY_TRACE_WIDTH,
},
},
};
use super::{
super::utils::{split_element_u32_into_u16, split_u32_into_u16},
ChipletTraceFragment, RangeChecker,
};
use crate::{
ContextId, EMPTY_WORD, Felt, MemoryAddress, MemoryError, ONE, WORD_SIZE, Word, ZERO,
field::batch_inversion_allow_zeros,
};
mod segment;
use segment::{MemoryOperation, MemorySegmentTrace};
#[cfg(test)]
mod tests;
/// Initial memory value, defined as [`EMPTY_WORD`].
// NOTE(review): not referenced in this file; presumably used by the `segment` submodule as the
// value of words that have never been written — confirm.
const INIT_MEM_VALUE: Word = EMPTY_WORD;
/// Memory chiplet state: the memory accesses made during execution, grouped by execution
/// context, together with the number of trace rows the chiplet will occupy.
#[derive(Debug, Default)]
pub struct Memory {
    /// Per-context access traces. Being a `BTreeMap`, iteration (and hence the order in which
    /// rows are emitted into the trace) follows ascending `ContextId` order.
    trace: BTreeMap<ContextId, MemorySegmentTrace>,
    /// Count of memory accesses issued through this struct; reported by `trace_len` as the
    /// chiplet's trace length.
    num_trace_rows: usize,
}
impl Memory {
    // ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of rows this chiplet contributes to the execution trace.
    ///
    /// The counter advances once per successful `read`, `read_word`, `write` or `write_word`
    /// call, i.e. once per recorded memory access.
    pub fn trace_len(&self) -> usize {
        self.num_trace_rows
    }

    /// Returns the element at `addr` in context `ctx`, as reported by that context's segment.
    ///
    /// Returns `None` if `ctx` has never been accessed, or if the segment has no value for
    /// `addr`.
    pub fn get_value(&self, ctx: ContextId, addr: u32) -> Option<Felt> {
        self.trace.get(&ctx).and_then(|segment| segment.get_value(addr))
    }

    /// Returns the word at `addr` in context `ctx`, as reported by that context's segment.
    ///
    /// Returns `Ok(None)` if `ctx` has never been accessed.
    ///
    /// # Errors
    /// Returns [`MemoryError::UnalignedWordAccess`] if the segment rejects `addr`; whatever error
    /// the segment reports is mapped to this variant.
    pub fn get_word(&self, ctx: ContextId, addr: u32) -> Result<Option<Word>, MemoryError> {
        let Some(segment) = self.trace.get(&ctx) else {
            return Ok(None);
        };
        segment.get_word(addr).map_err(|_| MemoryError::UnalignedWordAccess { addr, ctx })
    }

    /// Returns the `(address, value)` pairs reported by the segment of context `ctx` for clock
    /// cycle `clk`.
    ///
    /// Returns an empty list when `clk` is 0 or when `ctx` has never been accessed.
    pub fn get_state_at(&self, ctx: ContextId, clk: RowIndex) -> Vec<(MemoryAddress, Felt)> {
        // NOTE(review): presumably no access can have happened before clock 0, so the segment
        // lookup is skipped entirely — confirm.
        if clk == 0 {
            return vec![];
        }
        self.trace.get(&ctx).map(|segment| segment.get_state_at(clk)).unwrap_or_default()
    }

    // MUTATORS
    // --------------------------------------------------------------------------------------------

    /// Reads the element at `addr` in context `ctx` at clock cycle `clk`, recording the access.
    ///
    /// # Errors
    /// - [`MemoryError::AddressOutOfBounds`] if `addr` does not fit in a `u32`;
    /// - any error returned by the context's segment; in that case no trace row is counted.
    pub fn read(&mut self, ctx: ContextId, addr: Felt, clk: RowIndex) -> Result<Felt, MemoryError> {
        let addr = Self::felt_to_addr(addr)?;
        let clk = Felt::from(clk);
        self.record_access(ctx, |segment| segment.read(ctx, addr, clk))
    }

    /// Reads the word starting at `addr` in context `ctx` at clock cycle `clk`, recording the
    /// access.
    ///
    /// # Errors
    /// - [`MemoryError::AddressOutOfBounds`] if `addr` does not fit in a `u32`;
    /// - [`MemoryError::UnalignedWordAccess`] if `addr` is not a multiple of `WORD_SIZE`;
    /// - any error returned by the context's segment; in that case no trace row is counted.
    pub fn read_word(
        &mut self,
        ctx: ContextId,
        addr: Felt,
        clk: RowIndex,
    ) -> Result<Word, MemoryError> {
        let addr = Self::felt_to_word_addr(ctx, addr)?;
        let clk = Felt::from(clk);
        self.record_access(ctx, |segment| segment.read_word(ctx, addr, clk))
    }

    /// Writes `value` to the element at `addr` in context `ctx` at clock cycle `clk`, recording
    /// the access.
    ///
    /// # Errors
    /// - [`MemoryError::AddressOutOfBounds`] if `addr` does not fit in a `u32`;
    /// - any error returned by the context's segment; in that case no trace row is counted.
    pub fn write(
        &mut self,
        ctx: ContextId,
        addr: Felt,
        clk: RowIndex,
        value: Felt,
    ) -> Result<(), MemoryError> {
        let addr = Self::felt_to_addr(addr)?;
        let clk = Felt::from(clk);
        self.record_access(ctx, |segment| segment.write(ctx, addr, clk, value))
    }

    /// Writes the word `value` starting at `addr` in context `ctx` at clock cycle `clk`,
    /// recording the access.
    ///
    /// # Errors
    /// - [`MemoryError::AddressOutOfBounds`] if `addr` does not fit in a `u32`;
    /// - [`MemoryError::UnalignedWordAccess`] if `addr` is not a multiple of `WORD_SIZE`;
    /// - any error returned by the context's segment; in that case no trace row is counted.
    pub fn write_word(
        &mut self,
        ctx: ContextId,
        addr: Felt,
        clk: RowIndex,
        value: Word,
    ) -> Result<(), MemoryError> {
        let addr = Self::felt_to_word_addr(ctx, addr)?;
        let clk = Felt::from(clk);
        self.record_access(ctx, |segment| segment.write_word(ctx, addr, clk, value))
    }

    // TRACE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Adds to `range` the 16-bit values the memory chiplet needs range-checked.
    ///
    /// Rows are visited in the same order as [`Self::fill_trace`] emits them. For each row this
    /// adds:
    /// - the low and high 16-bit limbs of the delta to the previous row: the context delta if the
    ///   context changed, else the address delta if the address changed, else the clock delta
    ///   (the first row's delta is 1);
    /// - the low (`w0`) and high (`w1`) 16-bit limbs of the word index `addr / WORD_SIZE`, plus
    ///   `w1 << 2`.
    ///
    /// `_memory_start_row` is not used by this implementation; it is kept so the signature is
    /// unchanged for callers.
    pub fn append_range_checks(&self, _memory_start_row: RowIndex, range: &mut RangeChecker) {
        let Some((first_ctx, first_addr, first_clk)) = self.get_first_row_info() else {
            return;
        };
        // Seed the "previous row" one clock cycle before the first access so the first delta is
        // 1, matching `fill_trace` (wrapping keeps this correct when the first clock is 0).
        let mut prev_ctx = first_ctx;
        let mut prev_addr = first_addr;
        let mut prev_clk = first_clk.as_canonical_u64().wrapping_sub(1);

        for (&ctx, segment) in self.trace.iter() {
            for (&addr, addr_trace) in segment.inner().iter() {
                for memory_access in addr_trace {
                    let clk = memory_access.clk().as_canonical_u64();

                    // The unchecked subtractions rely on contexts (BTreeMap order) and, within a
                    // context, addresses being visited in ascending order.
                    let delta = if prev_ctx != ctx {
                        u64::from(u32::from(ctx) - u32::from(prev_ctx))
                    } else if prev_addr != addr {
                        u64::from(addr - prev_addr)
                    } else {
                        clk.wrapping_sub(prev_clk)
                    };
                    let (delta_hi, delta_lo) = split_u32_into_u16(delta);
                    range.add_range_checks(&[delta_lo, delta_hi]);

                    // `addr` is a u32, so `word_index < 2^30` and `w1 < 2^14`; `w1 << 2`
                    // therefore never loses bits. NOTE(review): presumably this extra value lets
                    // the AIR bound `w1` to 14 bits — confirm against the constraints.
                    let word_index = addr / WORD_SIZE as u32;
                    let w0 = (word_index & 0xffff) as u16;
                    let w1 = (word_index >> 16) as u16;
                    range.add_value(w0);
                    range.add_value(w1);
                    range.add_value(w1 << 2);

                    prev_ctx = ctx;
                    prev_addr = addr;
                    prev_clk = clk;
                }
            }
        }
    }

    /// Fills `trace` with the memory chiplet's rows, consuming `self`.
    ///
    /// Rows are emitted in ascending context order (BTreeMap iteration), then in the order each
    /// context's segment yields addresses, then in the order accesses were recorded for each
    /// address. For every row this writes:
    /// - the read/write flag and the element/word flag;
    /// - the context, word address, element-index bits (`idx1`, `idx0`), clock and word value;
    /// - the 16-bit limbs (`d0`, `d1`) of the delta to the previous row, and its inverse
    ///   (`d_inv`), computed with `batch_inversion_allow_zeros`;
    /// - the same-context-and-address flag;
    /// - in the last two columns, the low and high 16-bit limbs of `addr / WORD_SIZE`.
    ///
    /// # Panics
    /// - In debug builds, if `self.trace_len()` differs from `trace.len()`.
    /// - If an element access records an index within its word outside `0..=3`.
    pub fn fill_trace(self, trace: &mut ChipletTraceFragment) {
        debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");

        let Some((first_ctx, first_addr, first_clk)) = self.get_first_row_info() else {
            return;
        };
        // Seed the "previous row" one clock cycle before the first access so the first row's
        // delta is `clk - (clk - 1) = 1`.
        let mut prev_ctx = Felt::from(first_ctx);
        let mut prev_addr = Felt::from_u32(first_addr);
        let mut prev_clk = first_clk - ONE;

        let num_rows = self.trace_len();
        let mut buffer = vec![ZERO; num_rows * MEMORY_TRACE_WIDTH];
        let (out_rows, _) = buffer.as_chunks_mut::<MEMORY_TRACE_WIDTH>();
        // Deltas are collected so their inverses can be computed in a single batch afterwards.
        let mut deltas: Vec<Felt> = Vec::with_capacity(num_rows);
        let mut row: usize = 0;

        for (ctx, segment) in self.trace {
            let ctx = Felt::from(ctx);
            for (addr, addr_trace) in segment.into_inner() {
                let felt_addr = Felt::from_u32(addr);
                for memory_access in addr_trace {
                    let clk = memory_access.clk();
                    let value = memory_access.word();

                    // The last two columns hold the word-index limbs; the rest are `MemoryCols`.
                    let (mem_slice, aux_slice) =
                        out_rows[row].split_at_mut(MEMORY_TRACE_WIDTH - 2);
                    let cols: &mut MemoryCols<Felt> = mem_slice.borrow_mut();

                    cols.is_read = match memory_access.operation() {
                        MemoryOperation::Read => MEMORY_READ,
                        MemoryOperation::Write => MEMORY_WRITE,
                    };

                    // `idx1`/`idx0` are the high/low bits of the element's index within its word;
                    // word accesses leave both at zero.
                    let (idx1, idx0) = match memory_access.access_type() {
                        segment::MemoryAccessType::Element { addr_idx_in_word } => {
                            cols.is_word = MEMORY_ACCESS_ELEMENT;
                            match addr_idx_in_word {
                                0 => (ZERO, ZERO),
                                1 => (ZERO, ONE),
                                2 => (ONE, ZERO),
                                3 => (ONE, ONE),
                                _ => panic!("invalid address index in word: {addr_idx_in_word}"),
                            }
                        },
                        segment::MemoryAccessType::Word => {
                            cols.is_word = MEMORY_ACCESS_WORD;
                            (ZERO, ZERO)
                        },
                    };

                    cols.ctx = ctx;
                    cols.word_addr = felt_addr;
                    cols.idx0 = idx0;
                    cols.idx1 = idx1;
                    cols.clk = clk;
                    cols.values = *value;

                    let delta = if prev_ctx != ctx {
                        ctx - prev_ctx
                    } else if prev_addr != felt_addr {
                        felt_addr - prev_addr
                    } else {
                        clk - prev_clk
                    };
                    let (delta_hi, delta_lo) = split_element_u32_into_u16(delta);
                    cols.d0 = delta_lo;
                    cols.d1 = delta_hi;
                    deltas.push(delta);

                    cols.is_same_ctx_and_addr = if prev_ctx == ctx && prev_addr == felt_addr {
                        ONE
                    } else {
                        ZERO
                    };

                    let word_index = addr / WORD_SIZE as u32;
                    aux_slice[0] = Felt::from_u16((word_index & 0xffff) as u16);
                    aux_slice[1] = Felt::from_u16((word_index >> 16) as u16);

                    prev_ctx = ctx;
                    prev_addr = felt_addr;
                    prev_clk = clk;
                    row += 1;
                }
            }
        }

        // Invert all deltas at once and write each inverse into its row.
        batch_inversion_allow_zeros(&mut deltas);
        for (out_row, d_inv) in out_rows.iter_mut().zip(deltas) {
            let cols: &mut MemoryCols<Felt> = out_row[..MEMORY_TRACE_WIDTH - 2].borrow_mut();
            cols.d_inv = d_inv;
        }

        trace.copy_rows_from(&buffer);
    }

    // HELPERS
    // --------------------------------------------------------------------------------------------

    /// Runs `op` against the segment of `ctx` (creating an empty segment if needed) and counts
    /// one trace row only if `op` succeeds.
    ///
    /// Counting after success keeps [`Self::trace_len`] in step with the rows
    /// [`Self::fill_trace`] actually writes. NOTE(review): this assumes a failing segment
    /// operation records no access — confirm in the `segment` module.
    fn record_access<T>(
        &mut self,
        ctx: ContextId,
        op: impl FnOnce(&mut MemorySegmentTrace) -> Result<T, MemoryError>,
    ) -> Result<T, MemoryError> {
        let output = op(self.trace.entry(ctx).or_default())?;
        self.num_trace_rows += 1;
        Ok(output)
    }

    /// Converts a field element to a `u32` memory address.
    ///
    /// # Errors
    /// Returns [`MemoryError::AddressOutOfBounds`] if the canonical value of `addr` does not fit
    /// in a `u32`.
    fn felt_to_addr(addr: Felt) -> Result<u32, MemoryError> {
        let raw = addr.as_canonical_u64();
        u32::try_from(raw).map_err(|_| MemoryError::AddressOutOfBounds { addr: raw })
    }

    /// Converts a field element to a `u32` memory address that must be word-aligned.
    ///
    /// # Errors
    /// - [`MemoryError::AddressOutOfBounds`] if the address does not fit in a `u32`;
    /// - [`MemoryError::UnalignedWordAccess`] if it is not a multiple of `WORD_SIZE`.
    fn felt_to_word_addr(ctx: ContextId, addr: Felt) -> Result<u32, MemoryError> {
        let addr = Self::felt_to_addr(addr)?;
        if addr.is_multiple_of(WORD_SIZE as u32) {
            Ok(addr)
        } else {
            Err(MemoryError::UnalignedWordAccess { addr, ctx })
        }
    }

    /// Returns the context, address and clock of the first trace row, or `None` if no memory
    /// access was recorded.
    ///
    /// # Panics
    /// Panics if the first context's segment holds no addresses, or if its first address holds
    /// no accesses.
    fn get_first_row_info(&self) -> Option<(ContextId, u32, Felt)> {
        let (&ctx, segment) = self.trace.first_key_value()?;
        let (&addr, addr_trace) = segment.inner().iter().next().expect("empty memory segment");
        Some((ctx, addr, addr_trace[0].clk()))
    }

    // TEST HELPERS
    // --------------------------------------------------------------------------------------------

    /// Returns the total number of accessed words across all contexts, as reported by each
    /// segment.
    #[cfg(test)]
    pub fn num_accessed_words(&self) -> usize {
        self.trace.values().map(|segment| segment.num_accessed_words()).sum()
    }
}