use alloc::vec::Vec;
use miden_air::trace::{
CHIPLETS_RANGE, CHIPLETS_WIDTH,
chiplets::{
KERNEL_ROM_TRACE_WIDTH, NUM_BITWISE_SELECTORS, NUM_KERNEL_ROM_SELECTORS,
NUM_MEMORY_SELECTORS,
bitwise::{self, BITWISE_XOR, OP_CYCLE_LEN},
hasher::{CONTROLLER_ROWS_PER_PERMUTATION, HASH_CYCLE_LEN, LINEAR_HASH, S_PERM_COL_IDX},
memory,
},
};
use miden_core::{
Felt, ONE, Word, ZERO,
mast::{BasicBlockNodeBuilder, CallNodeBuilder, MastForest, MastForestContributor},
program::{Program, StackInputs},
};
use crate::{
AdviceInputs, DefaultHost, ExecutionOptions, FastProcessor, Kernel, operation::Operation,
};
/// The chiplets segment of an execution trace, stored column-major: `trace[col][row]`.
///
/// There is one column vector per chiplets column (`CHIPLETS_WIDTH` in total), taken from the
/// `CHIPLETS_RANGE` columns of the full trace by `build_trace`.
type ChipletsTrace = [Vec<Felt>; CHIPLETS_WIDTH];
/// Expected height of the hasher region of the chiplets trace.
///
/// The controller rows are rounded up to a whole number of `HASH_CYCLE_LEN`-row cycles. One
/// full `HASH_CYCLE_LEN`-row cycle is then appended for every unique permutation.
fn hasher_trace_len(controller_rows: usize, unique_perms: usize) -> usize {
    let perm_rows = HASH_CYCLE_LEN * unique_perms;
    controller_rows.next_multiple_of(HASH_CYCLE_LEN) + perm_rows
}
/// A program consisting of a single `HPerm` should produce a 48-row hasher region whose
/// layout passes `validate_hasher_trace`.
#[test]
fn hasher_chiplet_trace() {
    let stack = [2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0];
    let (chiplets_trace, _trace_len) =
        build_trace(&stack, vec![Operation::HPerm], Kernel::default());

    // Two controller-side permutations, both unique (presumably the basic-block hash and the
    // `HPerm` itself — TODO confirm).
    let (controller_rows, unique_perms) = (2 * CONTROLLER_ROWS_PER_PERMUTATION, 2);
    let hasher_len = hasher_trace_len(controller_rows, unique_perms);
    assert_eq!(hasher_len, 48);

    validate_hasher_trace(&chiplets_trace, hasher_len, controller_rows, unique_perms);
}
/// A single `U32xor` should occupy exactly one bitwise operation cycle (`OP_CYCLE_LEN` rows),
/// placed directly after a 32-row hasher region.
#[test]
fn bitwise_chiplet_trace() {
    let stack = [4, 8];
    let (chiplets_trace, _trace_len) =
        build_trace(&stack, vec![Operation::U32xor], Kernel::default());

    let hasher_len = hasher_trace_len(CONTROLLER_ROWS_PER_PERMUTATION, 1);
    assert_eq!(hasher_len, 32);

    // The bitwise region begins where the hasher region ends.
    validate_bitwise_trace(&chiplets_trace, hasher_len, hasher_len + OP_CYCLE_LEN);
}
/// Pushing address 4 and issuing one `MStoreW` should produce exactly one memory row, placed
/// directly after a 32-row hasher region.
#[test]
fn memory_chiplet_trace() {
    let stack = [1, 2, 3, 4];
    let operations = vec![Operation::Push(Felt::from_u32(4)), Operation::MStoreW];
    let (chiplets_trace, _trace_len) = build_trace(&stack, operations, Kernel::default());

    let hasher_len = hasher_trace_len(CONTROLLER_ROWS_PER_PERMUTATION, 1);
    assert_eq!(hasher_len, 32);

    // No bitwise rows here, so memory starts right after the hasher region.
    let memory_rows = hasher_len..hasher_len + 1;
    validate_memory_trace(&chiplets_trace, memory_rows.start, memory_rows.end);
}
/// Exercises the hasher, bitwise, memory and kernel-ROM chiplets in one program. It then checks
/// each region in trace order: hasher, bitwise, memory, kernel ROM and trailing padding.
#[test]
fn stacked_chiplet_trace() {
    let stack = [8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 1];
    let ops = vec![
        Operation::U32xor,
        Operation::Push(ZERO),
        Operation::MStoreW,
        Operation::HPerm,
    ];
    let (chiplets_trace, _trace_len) = build_trace(&stack, ops, build_kernel());

    // Hasher region: rows [0, hasher_end).
    let controller_rows = 2 * CONTROLLER_ROWS_PER_PERMUTATION;
    let unique_perms = 2;
    let hasher_end = hasher_trace_len(controller_rows, unique_perms);
    assert_eq!(hasher_end, 48);
    validate_hasher_trace(&chiplets_trace, hasher_end, controller_rows, unique_perms);

    // Bitwise region: one operation cycle for the single `U32xor`.
    let bitwise_end = hasher_end + OP_CYCLE_LEN;
    validate_bitwise_trace(&chiplets_trace, hasher_end, bitwise_end);

    // Memory region: one row for the single `MStoreW`.
    let memory_end = bitwise_end + 1;
    validate_memory_trace(&chiplets_trace, bitwise_end, memory_end);

    // Kernel ROM region: two rows, matching the two procedures registered by `build_kernel`.
    let kernel_rom_end = memory_end + 2;
    validate_kernel_rom_trace(&chiplets_trace, memory_end, kernel_rom_end);

    // Every remaining row up to the end of the trace is padding.
    let total_rows = chiplets_trace[0].len();
    validate_padding(&chiplets_trace, kernel_rom_end, total_rows);
}
/// Regression guard (per the test name): building the trace must not panic in the case where
/// the first memory access happens at clock cycle zero.
///
/// The program is a `Call` node targeting a basic block that contains a single `Noop`.
#[test]
fn regression_trace_build_does_not_panic_when_first_memory_access_clk_is_zero() {
    let processor = FastProcessor::new(StackInputs::default());
    let mut host = DefaultHost::default();

    let mut mast_forest = MastForest::new();
    let noop_block = BasicBlockNodeBuilder::new(vec![Operation::Noop], Vec::new())
        .add_to_forest(&mut mast_forest)
        .unwrap();
    mast_forest.make_root(noop_block);
    let call_node = CallNodeBuilder::new(noop_block).add_to_forest(&mut mast_forest).unwrap();
    mast_forest.make_root(call_node);
    let program = Program::with_kernel(mast_forest.into(), call_node, Kernel::default());

    let trace_inputs = processor.execute_trace_inputs_sync(&program, &mut host).unwrap();
    let _trace = crate::trace::build_trace(trace_inputs).unwrap();
}
/// Builds a kernel with two fixed procedure digests, `[1, 0, 1, 0]` and `[1, 1, 1, 1]`.
///
/// # Panics
///
/// Panics if `Kernel::new` rejects the procedure list.
fn build_kernel() -> Kernel {
    let procedures = [Word::from([1_u32, 0, 1, 0]), Word::from([1_u32, 1, 1, 1])];
    Kernel::new(&procedures).unwrap()
}
/// Executes `operations` as a single basic block, using the given stack inputs and kernel.
///
/// Returns the chiplets columns of the resulting trace together with the full trace length.
///
/// # Panics
///
/// Panics if the stack inputs or execution options are rejected, if execution or trace building
/// fails, or if the chiplets column range cannot be converted into a `ChipletsTrace`.
fn build_trace(
    stack_inputs: &[u64],
    operations: Vec<Operation>,
    kernel: Kernel,
) -> (ChipletsTrace, usize) {
    let felts: Vec<Felt> = stack_inputs.iter().map(|&value| Felt::new_unchecked(value)).collect();
    let stack = StackInputs::new(&felts).unwrap();
    let options = ExecutionOptions::default().with_core_trace_fragment_size(1 << 10).unwrap();
    let processor = FastProcessor::new_with_options(stack, AdviceInputs::default(), options)
        .expect("processor advice inputs should fit advice map limits");
    let mut host = DefaultHost::default();

    // The program's entry point is one basic block holding every requested operation.
    let mut forest = MastForest::new();
    let block_id = BasicBlockNodeBuilder::new(operations, Vec::new())
        .add_to_forest(&mut forest)
        .unwrap();
    forest.make_root(block_id);
    let program = Program::with_kernel(forest.into(), block_id, kernel);

    let trace_inputs = processor.execute_trace_inputs_sync(&program, &mut host).unwrap();
    let trace = crate::trace::build_trace(trace_inputs).unwrap();

    let trace_len = trace.get_trace_len();
    let chiplets: ChipletsTrace = trace
        .get_column_range(CHIPLETS_RANGE)
        .try_into()
        .expect("failed to convert vector to array");
    (chiplets, trace_len)
}
fn validate_hasher_trace(
trace: &ChipletsTrace,
expected_len: usize,
controller_rows: usize,
unique_perms: usize,
) {
let s0_col = 1; let s1_col = 2; let s2_col = 3; let s_perm_col = 1 + S_PERM_COL_IDX;
let controller_padded = controller_rows.next_multiple_of(HASH_CYCLE_LEN);
let perm_segment_start = controller_padded;
let perm_segment_len = unique_perms * HASH_CYCLE_LEN;
assert_eq!(expected_len, controller_padded + perm_segment_len);
for row in 0..controller_padded {
assert_eq!(trace[0][row], ONE, "s_ctrl should be 1 for controller row {row}");
assert_eq!(trace[s_perm_col][row], ZERO, "s_perm should be 0 for controller row {row}");
}
for row in perm_segment_start..expected_len {
assert_eq!(trace[0][row], ZERO, "s_ctrl should be 0 for perm row {row}");
assert_eq!(trace[s_perm_col][row], ONE, "s_perm should be 1 for perm row {row}");
}
for row in 0..controller_rows {
let is_input_row = row % CONTROLLER_ROWS_PER_PERMUTATION == 0;
if is_input_row {
assert_eq!(
trace[s0_col][row], LINEAR_HASH[0],
"controller input row {row}: s0 should be {} (LINEAR_HASH)",
LINEAR_HASH[0]
);
} else {
assert_eq!(
trace[s0_col][row], ZERO,
"controller output row {row}: s0 should be 0 (RETURN_*)"
);
}
}
let padding_start = controller_rows;
for row in padding_start..controller_padded {
assert_eq!(trace[s0_col][row], ZERO, "padding row {row}: s0 should be 0");
assert_eq!(trace[s1_col][row], ONE, "padding row {row}: s1 should be 1");
assert_eq!(trace[s2_col][row], ZERO, "padding row {row}: s2 should be 0");
for col in 4..=CHIPLETS_WIDTH - 1 {
assert_eq!(trace[col][row], ZERO, "padding row {row}, col {col} should be zero");
}
}
for row in perm_segment_start..expected_len {
let offset_in_cycle = (row - perm_segment_start) % HASH_CYCLE_LEN;
match offset_in_cycle {
0..=3 | 12..=15 => {
assert_eq!(trace[s0_col][row], ZERO, "perm row {row}: s0 should be 0");
assert_eq!(trace[s1_col][row], ZERO, "perm row {row}: s1 should be 0");
assert_eq!(trace[s2_col][row], ZERO, "perm row {row}: s2 should be 0");
},
4..=10 => {
},
11 => {
assert_eq!(trace[s1_col][row], ZERO, "perm row {row}: s1 should be 0");
assert_eq!(trace[s2_col][row], ZERO, "perm row {row}: s2 should be 0");
},
_ => unreachable!(),
}
}
}
/// Checks that rows `start..end` of the chiplets trace are bitwise-chiplet XOR rows.
///
/// Each row must have `s_ctrl = 0`, `s1 = 0`, and the bitwise op column (index
/// `NUM_BITWISE_SELECTORS`) equal to `BITWISE_XOR`. Every column past the bitwise chiplet's
/// selectors and data columns must be zero.
fn validate_bitwise_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    let first_unused_col = NUM_BITWISE_SELECTORS + bitwise::TRACE_WIDTH;
    (start..end).for_each(|row| {
        assert_eq!(ZERO, trace[0][row], "bitwise s_ctrl at row {row}");
        assert_eq!(ZERO, trace[1][row], "bitwise s1 at row {row}");
        assert_eq!(BITWISE_XOR, trace[NUM_BITWISE_SELECTORS][row], "bitwise op at row {row}");
        (first_unused_col..CHIPLETS_WIDTH).for_each(|col| {
            assert_eq!(
                trace[col][row], ZERO,
                "bitwise padding col {col} at row {row} should be zero"
            );
        });
    });
}
/// Checks that rows `start..end` of the chiplets trace belong to the memory chiplet.
///
/// Each row must carry the selectors `(s_ctrl, s1, s2) = (0, 1, 0)`. Every column past the
/// memory chiplet's selectors and data columns must be zero.
fn validate_memory_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    let first_unused_col = NUM_MEMORY_SELECTORS + memory::TRACE_WIDTH;
    (start..end).for_each(|row| {
        assert_eq!(ZERO, trace[0][row], "memory s_ctrl at row {row}");
        assert_eq!(ONE, trace[1][row], "memory s1 at row {row}");
        assert_eq!(ZERO, trace[2][row], "memory s2 at row {row}");
        (first_unused_col..CHIPLETS_WIDTH).for_each(|col| {
            assert_eq!(
                trace[col][row], ZERO,
                "memory padding col {col} at row {row} should be zero"
            );
        });
    });
}
/// Checks that rows `start..end` of the chiplets trace belong to the kernel ROM chiplet.
///
/// Each row must carry the selectors `(s_ctrl, s1, s2, s3, s4) = (0, 1, 1, 1, 0)`. Every column
/// past the kernel ROM's selectors and data columns must be zero.
fn validate_kernel_rom_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    let first_unused_col = NUM_KERNEL_ROM_SELECTORS + KERNEL_ROM_TRACE_WIDTH;
    (start..end).for_each(|row| {
        assert_eq!(ZERO, trace[0][row], "kernel_rom s_ctrl at row {row}");
        assert_eq!(ONE, trace[1][row], "kernel_rom s1 at row {row}");
        assert_eq!(ONE, trace[2][row], "kernel_rom s2 at row {row}");
        assert_eq!(ONE, trace[3][row], "kernel_rom s3 at row {row}");
        assert_eq!(ZERO, trace[4][row], "kernel_rom s4 at row {row}");
        (first_unused_col..CHIPLETS_WIDTH).for_each(|col| {
            assert_eq!(
                trace[col][row], ZERO,
                "kernel_rom padding col {col} at row {row} should be zero"
            );
        });
    });
}
/// Checks that rows `start..end` of the chiplets trace are trailing padding rows.
///
/// A padding row has `s_ctrl = 0`, selector columns 1 through 4 set to one, and every remaining
/// data column zero.
fn validate_padding(trace: &ChipletsTrace, start: usize, end: usize) {
    (start..end).for_each(|row| {
        assert_eq!(ZERO, trace[0][row], "padding s_ctrl at row {row}");
        (1..=4).for_each(|col| {
            assert_eq!(ONE, trace[col][row], "padding s{col} at row {row}");
        });
        (5..CHIPLETS_WIDTH).for_each(|col| {
            assert_eq!(ZERO, trace[col][row], "padding data col {col} at row {row} should be zero");
        });
    });
}