use super::{
super::ZERO, Felt, FieldElement, Memory, TraceFragment, ADDR_COL_IDX, CLK_COL_IDX, CTX_COL_IDX,
D0_COL_IDX, D1_COL_IDX, D_INV_COL_IDX, EMPTY_WORD, ONE, V_COL_RANGE,
};
use crate::ContextId;
use alloc::vec::Vec;
use miden_air::trace::chiplets::memory::{
Selectors, MEMORY_COPY_READ, MEMORY_INIT_READ, MEMORY_WRITE, TRACE_WIDTH as MEMORY_TRACE_WIDTH,
};
use vm_core::Word;
/// A default-constructed `Memory` reports a size of zero and a trace length of zero.
#[test]
fn mem_init() {
    let memory = Memory::default();

    let (size, trace_len) = (memory.size(), memory.trace_len());
    assert_eq!(0, size);
    assert_eq!(0, trace_len);
}
/// Reads from addresses that were never written and checks the returned words, the reported
/// size / trace length after every access, and the rows of the resulting trace.
#[test]
fn mem_read() {
    let root = ContextId::root();
    let mut mem = Memory::default();

    // read from address 0; clk = 1
    let addr0 = 0;
    assert_eq!(EMPTY_WORD, mem.read(root, addr0, 1));
    assert_eq!(1, mem.size());
    assert_eq!(1, mem.trace_len());

    // read from address 3; clk = 2
    let addr3 = 3;
    assert_eq!(EMPTY_WORD, mem.read(root, addr3, 2));
    assert_eq!(2, mem.size());
    assert_eq!(2, mem.trace_len());

    // read from address 0 again; clk = 3 (size stays the same, trace grows)
    assert_eq!(EMPTY_WORD, mem.read(root, addr0, 3));
    assert_eq!(2, mem.size());
    assert_eq!(3, mem.trace_len());

    // read from address 2; clk = 4
    let addr2 = 2;
    assert_eq!(EMPTY_WORD, mem.read(root, addr2, 4));
    assert_eq!(3, mem.size());
    assert_eq!(4, mem.trace_len());

    // Expected trace rows, in the order they must appear in the trace:
    // (row index, operation selectors, address, clock cycle).
    let trace = build_trace(mem, 4);
    let expected_rows = [
        (0, MEMORY_INIT_READ, addr0, 1),
        (1, MEMORY_COPY_READ, addr0, 3),
        (2, MEMORY_INIT_READ, addr2, 4),
        (3, MEMORY_INIT_READ, addr3, 2),
    ];

    let mut prev_row = [ZERO; MEMORY_TRACE_WIDTH];
    for (row, selectors, addr, clk) in expected_rows {
        let access = MemoryAccess::new(root, addr, clk, EMPTY_WORD);
        prev_row = verify_memory_access(&trace, row, selectors, &access, prev_row);
    }
}
/// Writes to three addresses, overwrites the first one, and checks the stored values, the
/// reported size / trace length after every access, and the rows of the resulting trace.
#[test]
fn mem_write() {
    let mut mem = Memory::default();

    // write a value into address 0; clk = 1
    let addr0 = 0;
    let value1 = [ONE, ZERO, ZERO, ZERO];
    mem.write(ContextId::root(), addr0, 1, value1);
    assert_eq!(value1, mem.get_value(ContextId::root(), addr0).unwrap());
    assert_eq!(1, mem.size());
    assert_eq!(1, mem.trace_len());

    // write a value into address 2; clk = 2
    let addr2 = 2;
    let value5 = [Felt::new(5), ZERO, ZERO, ZERO];
    mem.write(ContextId::root(), addr2, 2, value5);
    assert_eq!(value5, mem.get_value(ContextId::root(), addr2).unwrap());
    assert_eq!(2, mem.size());
    assert_eq!(2, mem.trace_len());

    // write a value into address 1; clk = 3
    let addr1 = 1;
    let value7 = [Felt::new(7), ZERO, ZERO, ZERO];
    mem.write(ContextId::root(), addr1, 3, value7);
    assert_eq!(value7, mem.get_value(ContextId::root(), addr1).unwrap());
    assert_eq!(3, mem.size());
    assert_eq!(3, mem.trace_len());

    // overwrite address 0 with a new value; clk = 4
    let value9 = [Felt::new(9), ZERO, ZERO, ZERO];
    mem.write(ContextId::root(), addr0, 4, value9);
    // the overwritten address must now hold the new value ...
    assert_eq!(value9, mem.get_value(ContextId::root(), addr0).unwrap());
    // ... while the other addresses are left untouched
    assert_eq!(value7, mem.get_value(ContextId::root(), addr1).unwrap());
    // overwriting an existing address grows the trace but not the size
    assert_eq!(3, mem.size());
    assert_eq!(4, mem.trace_len());

    // check generated trace and memory data provided to the verifier; rows are expected in
    // (address, clock) order
    let trace = build_trace(mem, 4);

    // address 0
    let mut prev_row = [ZERO; MEMORY_TRACE_WIDTH];
    let memory_access = MemoryAccess::new(ContextId::root(), addr0, 1, value1);
    prev_row = verify_memory_access(&trace, 0, MEMORY_WRITE, &memory_access, prev_row);

    let memory_access = MemoryAccess::new(ContextId::root(), addr0, 4, value9);
    prev_row = verify_memory_access(&trace, 1, MEMORY_WRITE, &memory_access, prev_row);

    // address 1
    let memory_access = MemoryAccess::new(ContextId::root(), addr1, 3, value7);
    prev_row = verify_memory_access(&trace, 2, MEMORY_WRITE, &memory_access, prev_row);

    // address 2
    let memory_access = MemoryAccess::new(ContextId::root(), addr2, 2, value5);
    verify_memory_access(&trace, 3, MEMORY_WRITE, &memory_access, prev_row);
}
/// Interleaves writes and reads over two addresses and checks that the generated trace lists
/// every access with the expected selectors, clock cycle and word.
#[test]
fn mem_write_read() {
    let root = ContextId::root();
    let mut mem = Memory::default();

    // write 1 into address 5; clk = 1
    let addr5 = 5;
    let value1 = [ONE, ZERO, ZERO, ZERO];
    mem.write(root, addr5, 1, value1);

    // write 4 into address 2; clk = 2
    let addr2 = 2;
    let value4 = [Felt::new(4), ZERO, ZERO, ZERO];
    mem.write(root, addr2, 2, value4);

    // read address 5; clk = 3
    mem.read(root, addr5, 3);

    // write 2 into address 5; clk = 4
    let value2 = [Felt::new(2), ZERO, ZERO, ZERO];
    mem.write(root, addr5, 4, value2);

    // read address 2; clk = 5
    mem.read(root, addr2, 5);

    // write 7 into address 2; clk = 6
    let value7 = [Felt::new(7), ZERO, ZERO, ZERO];
    mem.write(root, addr2, 6, value7);

    // read address 5; clk = 7
    mem.read(root, addr5, 7);

    // read address 2; clk = 8
    mem.read(root, addr2, 8);

    // read address 5; clk = 9
    mem.read(root, addr5, 9);

    // Expected trace rows, in the order they must appear in the trace:
    // (row index, operation selectors, address, clock cycle, word).
    let trace = build_trace(mem, 9);
    let expected_rows = [
        // address 2
        (0, MEMORY_WRITE, addr2, 2, value4),
        (1, MEMORY_COPY_READ, addr2, 5, value4),
        (2, MEMORY_WRITE, addr2, 6, value7),
        (3, MEMORY_COPY_READ, addr2, 8, value7),
        // address 5
        (4, MEMORY_WRITE, addr5, 1, value1),
        (5, MEMORY_COPY_READ, addr5, 3, value1),
        (6, MEMORY_WRITE, addr5, 4, value2),
        (7, MEMORY_COPY_READ, addr5, 7, value2),
        (8, MEMORY_COPY_READ, addr5, 9, value2),
    ];

    let mut prev_row = [ZERO; MEMORY_TRACE_WIDTH];
    for (row, selectors, addr, clk, word) in expected_rows {
        let access = MemoryAccess::new(root, addr, clk, word);
        prev_row = verify_memory_access(&trace, row, selectors, &access, prev_row);
    }
}
/// Accesses memory from the root context and from context 3 and checks the stored values, the
/// reported size / trace length after every access, and the rows of the resulting trace.
#[test]
fn mem_multi_context() {
    let root = ContextId::root();
    let ctx3: ContextId = 3.into();
    let mut mem = Memory::default();

    // write a value into ctx = root, addr = 0; clk = 1
    let value1 = [ONE, ZERO, ZERO, ZERO];
    mem.write(root, 0, 1, value1);
    assert_eq!(value1, mem.get_value(root, 0).unwrap());
    assert_eq!(1, mem.size());
    assert_eq!(1, mem.trace_len());

    // write a value into ctx = 3, addr = 1; clk = 4
    let value2 = [ZERO, ONE, ZERO, ZERO];
    mem.write(ctx3, 1, 4, value2);
    assert_eq!(value2, mem.get_value(ctx3, 1).unwrap());
    assert_eq!(2, mem.size());
    assert_eq!(2, mem.trace_len());

    // read the value back from ctx = 3, addr = 1; clk = 6
    let read_back = mem.read(ctx3, 1, 6);
    assert_eq!(value2, read_back);
    assert_eq!(2, mem.size());
    assert_eq!(3, mem.trace_len());

    // write a value into ctx = 3, addr = 0; clk = 7
    let value3 = [ZERO, ZERO, ONE, ZERO];
    mem.write(ctx3, 0, 7, value3);
    assert_eq!(value3, mem.get_value(ctx3, 0).unwrap());
    assert_eq!(3, mem.size());
    assert_eq!(4, mem.trace_len());

    // read the value back from ctx = root, addr = 0; clk = 9
    let read_back = mem.read(root, 0, 9);
    assert_eq!(value1, read_back);
    assert_eq!(3, mem.size());
    assert_eq!(5, mem.trace_len());

    // Expected trace rows, in the order they must appear in the trace:
    // (row index, operation selectors, context, address, clock cycle, word).
    let trace = build_trace(mem, 5);
    let expected_rows = [
        // ctx = root
        (0, MEMORY_WRITE, root, 0, 1, value1),
        (1, MEMORY_COPY_READ, root, 0, 9, value1),
        // ctx = 3
        (2, MEMORY_WRITE, ctx3, 0, 7, value3),
        (3, MEMORY_WRITE, ctx3, 1, 4, value2),
        (4, MEMORY_COPY_READ, ctx3, 1, 6, value2),
    ];

    let mut prev_row = [ZERO; MEMORY_TRACE_WIDTH];
    for (row, selectors, ctx, addr, clk, word) in expected_rows {
        let access = MemoryAccess::new(ctx, addr, clk, word);
        prev_row = verify_memory_access(&trace, row, selectors, &access, prev_row);
    }
}
/// Checks the per-context memory snapshots returned by `get_state_at` for several clock cycles.
///
/// The expectations below treat a write performed at clock cycle `c` as visible only in
/// snapshots taken at a clock cycle greater than `c`.
#[test]
fn mem_get_state_at() {
    let root = ContextId::root();
    let ctx3: ContextId = 3.into();
    let mut mem = Memory::default();

    // write 1 into (ctx = root, addr = 5); clk = 1
    let value1 = [ONE, ZERO, ZERO, ZERO];
    mem.write(root, 5, 1, value1);

    // write 4 into (ctx = root, addr = 2); clk = 2
    let value4 = [Felt::new(4), ZERO, ZERO, ZERO];
    mem.write(root, 2, 2, value4);

    // write 7 into (ctx = 3, addr = 3); clk = 4
    let value7 = [Felt::new(7), ZERO, ZERO, ZERO];
    mem.write(ctx3, 3, 4, value7);

    // clk = 2: only the write from clk = 1 is visible
    assert_eq!(mem.get_state_at(root, 2), vec![(5, value1)]);
    assert_eq!(mem.get_state_at(ctx3, 2), vec![]);

    // clk = 3 and clk = 4: both root writes are visible, context 3 is still empty
    for clk in 3..=4 {
        assert_eq!(mem.get_state_at(root, clk), vec![(2, value4), (5, value1)]);
        assert_eq!(mem.get_state_at(ctx3, clk), vec![]);
    }

    // clk = 5: the write into context 3 is visible as well
    assert_eq!(mem.get_state_at(root, 5), vec![(2, value4), (5, value1)]);
    assert_eq!(mem.get_state_at(ctx3, 5), vec![(3, value7)]);
}
/// Description of a single memory access used to build the expected trace row in tests.
pub struct MemoryAccess {
    /// Context in which the access happened.
    ctx: ContextId,
    /// Accessed address, as a field element.
    addr: Felt,
    /// Clock cycle of the access, as a field element.
    clk: Felt,
    /// Word associated with the access.
    word: Word,
}

impl MemoryAccess {
    /// Creates a new access description, converting `addr` and `clk` into field elements.
    pub fn new(ctx: ContextId, addr: u32, clk: u32, word: Word) -> Self {
        let addr = Felt::from(addr);
        let clk = Felt::from(clk);
        Self { ctx, addr, clk, word }
    }
}
/// Builds a column-major trace with `MEMORY_TRACE_WIDTH` columns of `num_rows` rows each
/// (initialized to `ZERO`) and lets `mem` fill it in.
fn build_trace(mem: Memory, num_rows: usize) -> Vec<Vec<Felt>> {
    let mut columns = vec![vec![ZERO; num_rows]; MEMORY_TRACE_WIDTH];

    let mut fragment = TraceFragment::trace_to_fragment(&mut columns);
    mem.fill_trace(&mut fragment);

    columns
}
/// Extracts the row at index `step` from a column-major trace.
///
/// Cells for which `trace` has no column are left as `ZERO`; indexing panics if `step` is out
/// of bounds for a column that is read.
fn read_trace_row(trace: &[Vec<Felt>], step: usize) -> [Felt; MEMORY_TRACE_WIDTH] {
    let mut cells = [ZERO; MEMORY_TRACE_WIDTH];
    cells
        .iter_mut()
        .zip(trace.iter())
        .for_each(|(cell, column)| *cell = column[step]);
    cells
}
/// Builds the trace row expected for `memory_access`, given the operation selectors and the
/// previously expected row.
///
/// An all-`ZERO` `prev_row` is treated as "no previous row": in that case the delta columns
/// (`D0`, `D1`, `D_INV`) are left as `ZERO`.
fn build_trace_row(
    memory_access: &MemoryAccess,
    op_selectors: Selectors,
    prev_row: [Felt; MEMORY_TRACE_WIDTH],
) -> [Felt; MEMORY_TRACE_WIDTH] {
    let ctx = memory_access.ctx;
    let mut row = [ZERO; MEMORY_TRACE_WIDTH];

    // operation selectors occupy the first two columns
    row[0] = op_selectors[0];
    row[1] = op_selectors[1];

    // context, address and clock cycle of the access
    row[CTX_COL_IDX] = ctx.into();
    row[ADDR_COL_IDX] = memory_access.addr;
    row[CLK_COL_IDX] = memory_access.clk;

    // the four elements of the word go into consecutive value columns
    for (offset, element) in memory_access.word.iter().enumerate() {
        row[V_COL_RANGE.start + offset] = *element;
    }

    // without a previous row there is nothing to compute a delta against
    if prev_row == [ZERO; MEMORY_TRACE_WIDTH] {
        return row;
    }

    // the delta is taken over the first column that changed, in the order context, address,
    // clock; for the clock an extra ONE is subtracted
    let ctx_changed = row[CTX_COL_IDX] != prev_row[CTX_COL_IDX];
    let addr_changed = row[ADDR_COL_IDX] != prev_row[ADDR_COL_IDX];
    let delta = if ctx_changed {
        row[CTX_COL_IDX] - prev_row[CTX_COL_IDX]
    } else if addr_changed {
        row[ADDR_COL_IDX] - prev_row[ADDR_COL_IDX]
    } else {
        row[CLK_COL_IDX] - prev_row[CLK_COL_IDX] - ONE
    };

    // the delta is stored as two limbs plus its inverse
    let (hi, lo) = super::split_element_u32_into_u16(delta);
    row[D0_COL_IDX] = lo;
    row[D1_COL_IDX] = hi;
    row[D_INV_COL_IDX] = delta.inv();

    row
}
/// Asserts that row `row` of `trace` equals the row expected for `memory_access` and returns
/// that expected row so it can be passed as `prev_row` for the next check.
fn verify_memory_access(
    trace: &[Vec<Felt>],
    row: u32,
    op_selectors: Selectors,
    memory_access: &MemoryAccess,
    prev_row: [Felt; MEMORY_TRACE_WIDTH],
) -> [Felt; MEMORY_TRACE_WIDTH] {
    let expected = build_trace_row(memory_access, op_selectors, prev_row);
    let actual = read_trace_row(trace, row as usize);
    assert_eq!(expected, actual);
    expected
}