use crate::{
CodeBlock, DefaultHost, ExecutionOptions, ExecutionTrace, Kernel, Operation, Process,
StackInputs,
};
use alloc::vec::Vec;
use miden_air::trace::{
chiplets::{
bitwise::{BITWISE_XOR, OP_CYCLE_LEN, TRACE_WIDTH as BITWISE_TRACE_WIDTH},
hasher::{Digest, HASH_CYCLE_LEN, LINEAR_HASH, RETURN_STATE},
kernel_rom::TRACE_WIDTH as KERNEL_ROM_TRACE_WIDTH,
memory::TRACE_WIDTH as MEMORY_TRACE_WIDTH,
NUM_BITWISE_SELECTORS, NUM_KERNEL_ROM_SELECTORS, NUM_MEMORY_SELECTORS,
},
CHIPLETS_RANGE, CHIPLETS_WIDTH,
};
use vm_core::{CodeBlockTable, Felt, ONE, ZERO};
/// Column-major view of the chiplets part of an execution trace: `CHIPLETS_WIDTH` columns, each
/// stored as a `Vec<Felt>` indexed by row (i.e. `trace[column][row]`).
type ChipletsTrace = [Vec<Felt>; CHIPLETS_WIDTH];
/// Executes a span holding a single `HPerm` with the default kernel, then checks the hasher rows
/// and the padding of the resulting chiplets trace.
#[test]
fn hasher_chiplet_trace() {
    let inputs = [2, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0];
    let (trace, trace_len) = build_trace(&inputs, vec![Operation::HPerm], Kernel::default());

    // Only the second block of `HASH_CYCLE_LEN` rows is inspected.
    // NOTE(review): the first `HASH_CYCLE_LEN` rows are presumably taken by the hash of the span
    // block itself -- confirm against the trace builder.
    let hperm_rows = HASH_CYCLE_LEN..HASH_CYCLE_LEN + HASH_CYCLE_LEN;
    validate_hasher_trace(&trace, hperm_rows.start, hperm_rows.end);

    // Everything after the inspected hasher rows must be padding.
    validate_padding(&trace, hperm_rows.end, trace_len);
}
/// Executes a span holding a single `U32xor` with the default kernel, then checks the bitwise
/// rows and the padding of the resulting chiplets trace.
#[test]
fn bitwise_chiplet_trace() {
    let inputs = [4, 8];
    let (trace, trace_len) = build_trace(&inputs, vec![Operation::U32xor], Kernel::default());

    // The bitwise rows are expected right after the first `HASH_CYCLE_LEN` rows and to span
    // `OP_CYCLE_LEN` rows.
    let bitwise_start = HASH_CYCLE_LEN;
    let bitwise_end = bitwise_start + OP_CYCLE_LEN;
    validate_bitwise_trace(&trace, bitwise_start, bitwise_end);

    // NOTE(review): unlike the sibling tests, the padding check stops one row short of
    // `trace_len`, so the last row is never inspected -- confirm whether this is intentional.
    let padding_end = trace_len - 1;
    validate_padding(&trace, bitwise_end, padding_end);
}
/// Executes a span that pushes `2` and performs `MStoreW` with the default kernel, then checks
/// the memory rows and the padding of the resulting chiplets trace.
#[test]
fn memory_chiplet_trace() {
    let inputs = [1, 2, 3, 4];
    let program = vec![Operation::Push(Felt::new(2)), Operation::MStoreW];
    let (trace, trace_len) = build_trace(&inputs, program, Kernel::default());

    // A single memory row is expected, placed right after the first `HASH_CYCLE_LEN` rows.
    let memory_rows = 1;
    let memory_start = HASH_CYCLE_LEN;
    let memory_end = memory_start + memory_rows;
    validate_memory_trace(&trace, memory_start, memory_end);

    // Everything after the memory row must be padding.
    validate_padding(&trace, memory_end, trace_len);
}
/// Executes a span that touches several chiplets (`U32xor`, `MStoreW`, `HPerm`) with the kernel
/// from `build_kernel`, then checks that the chiplet sections appear back to back in the order
/// hasher, bitwise, memory, kernel ROM, followed by padding.
#[test]
fn stacked_chiplet_trace() {
    let inputs = [8, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 1];
    let program =
        vec![Operation::U32xor, Operation::Push(ZERO), Operation::MStoreW, Operation::HPerm];
    let (trace, trace_len) = build_trace(&inputs, program, build_kernel());

    // Expected number of rows for the sections whose length is not a named constant.
    // NOTE(review): the two kernel ROM rows presumably correspond to the two digests passed to
    // the kernel in `build_kernel` -- confirm.
    let memory_rows = 1;
    let kernel_rom_rows = 2;

    // Section boundaries; each section starts where the previous one ends. As in the
    // single-chiplet tests, the first `HASH_CYCLE_LEN` rows are not inspected.
    let hasher_start = HASH_CYCLE_LEN;
    let hasher_end = hasher_start + HASH_CYCLE_LEN;
    let bitwise_end = hasher_end + OP_CYCLE_LEN;
    let memory_end = bitwise_end + memory_rows;
    let kernel_rom_end = memory_end + kernel_rom_rows;

    validate_hasher_trace(&trace, hasher_start, hasher_end);
    validate_bitwise_trace(&trace, hasher_end, bitwise_end);
    validate_memory_trace(&trace, bitwise_end, memory_end);
    validate_kernel_rom_trace(&trace, memory_end, kernel_rom_end);
    validate_padding(&trace, kernel_rom_end, trace_len);
}
/// Builds a [`Kernel`] from two fixed digests.
///
/// # Panics
/// Panics if `Kernel::new` returns an error.
fn build_kernel() -> Kernel {
    let digests: [Digest; 2] = [[ONE, ZERO, ONE, ZERO].into(), [ONE, ONE, ONE, ONE].into()];
    Kernel::new(&digests).unwrap()
}
/// Executes `operations` as a single span block and returns the chiplets columns of the
/// finalized trace together with the trace length excluding `ExecutionTrace::NUM_RAND_ROWS`.
///
/// - `stack_inputs`: initial stack values, converted with `StackInputs::try_from_ints`.
/// - `operations`: operations placed into the span block that gets executed.
/// - `kernel`: kernel the process is created with.
///
/// # Panics
/// Panics if the stack inputs are rejected, if execution fails, or if the extracted column range
/// cannot be converted into a [`ChipletsTrace`].
fn build_trace(
    stack_inputs: &[u64],
    operations: Vec<Operation>,
    kernel: Kernel,
) -> (ChipletsTrace, usize) {
    let inputs = StackInputs::try_from_ints(stack_inputs.iter().copied()).unwrap();
    let mut process =
        Process::new(kernel, inputs, DefaultHost::default(), ExecutionOptions::default());

    let span = CodeBlock::new_span(operations);
    process.execute_code_block(&span, &CodeBlockTable::default()).unwrap();

    let (full_trace, _, _) = ExecutionTrace::test_finalize_trace(process);
    let len_without_rand_rows = full_trace.num_rows() - ExecutionTrace::NUM_RAND_ROWS;

    // Keep only the chiplets columns; the conversion fails if the range does not yield exactly
    // `CHIPLETS_WIDTH` columns.
    let chiplets: ChipletsTrace = full_trace
        .get_column_range(CHIPLETS_RANGE)
        .try_into()
        .expect("failed to convert vector to array");

    (chiplets, len_without_rand_rows)
}
/// Asserts that rows `start..end` of `trace` look like hasher rows.
///
/// Column 0 must be `ZERO` on every row. Columns 1..=3 must equal `LINEAR_HASH` when
/// `row % HASH_CYCLE_LEN == 0`, `RETURN_STATE` when the remainder is `7`, and
/// `[ZERO, LINEAR_HASH[1], LINEAR_HASH[2]]` on every other row.
fn validate_hasher_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    // Expected values of columns 1..=3 on rows that are neither remainder 0 nor remainder 7.
    let mid_cycle = [ZERO, LINEAR_HASH[1], LINEAR_HASH[2]];

    for row in start..end {
        assert_eq!(ZERO, trace[0][row]);

        let expected = match row % HASH_CYCLE_LEN {
            0 => LINEAR_HASH,
            7 => RETURN_STATE,
            _ => mid_cycle,
        };
        assert_eq!(expected, [trace[1][row], trace[2][row], trace[3][row]]);
    }
}
/// Asserts that rows `start..end` of `trace` look like bitwise rows.
///
/// Columns 0, 1 and 2 must hold `ONE`, `ZERO` and `BITWISE_XOR` respectively, and every column
/// from index `BITWISE_TRACE_WIDTH + NUM_BITWISE_SELECTORS` onward must be `ZERO`.
fn validate_bitwise_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    // Index of the first column that must be zero on bitwise rows.
    let first_zero_column = BITWISE_TRACE_WIDTH + NUM_BITWISE_SELECTORS;

    (start..end).for_each(|row| {
        assert_eq!(ONE, trace[0][row]);
        assert_eq!(ZERO, trace[1][row]);
        assert_eq!(BITWISE_XOR, trace[2][row]);

        trace
            .iter()
            .skip(first_zero_column)
            .for_each(|column| assert_eq!(ZERO, column[row]));
    });
}
/// Asserts that rows `start..end` of `trace` look like memory rows.
///
/// Columns 0, 1 and 2 must hold `ONE`, `ONE` and `ZERO` respectively, and every column from
/// index `MEMORY_TRACE_WIDTH + NUM_MEMORY_SELECTORS` onward must be `ZERO`.
fn validate_memory_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    // Expected values of the first three columns, in column order.
    let leading = [ONE, ONE, ZERO];
    // Index of the first column that must be zero on memory rows.
    let first_zero_column = MEMORY_TRACE_WIDTH + NUM_MEMORY_SELECTORS;

    for row in start..end {
        for (index, &expected) in leading.iter().enumerate() {
            assert_eq!(expected, trace[index][row]);
        }

        trace
            .iter()
            .skip(first_zero_column)
            .for_each(|column| assert_eq!(ZERO, column[row]));
    }
}
/// Asserts that rows `start..end` of `trace` look like kernel ROM rows.
///
/// Columns 0..=4 must hold `ONE, ONE, ONE, ZERO, ZERO` respectively, and every column from index
/// `KERNEL_ROM_TRACE_WIDTH + NUM_KERNEL_ROM_SELECTORS` onward must be `ZERO`.
fn validate_kernel_rom_trace(trace: &ChipletsTrace, start: usize, end: usize) {
    // Expected values of the first five columns, in column order.
    // NOTE(review): the first four look like the chiplet selectors and the fifth like the first
    // kernel ROM column -- confirm against the chiplets layout.
    let leading = [ONE, ONE, ONE, ZERO, ZERO];
    // Index of the first column that must be zero on kernel ROM rows.
    let first_zero_column = KERNEL_ROM_TRACE_WIDTH + NUM_KERNEL_ROM_SELECTORS;

    for row in start..end {
        for (index, &expected) in leading.iter().enumerate() {
            assert_eq!(expected, trace[index][row]);
        }

        trace
            .iter()
            .skip(first_zero_column)
            .for_each(|column| assert_eq!(ZERO, column[row]));
    }
}
/// Asserts that rows `start..end` of `trace` are padding rows: the first four columns hold `ONE`
/// and every remaining column holds `ZERO`.
fn validate_padding(trace: &ChipletsTrace, start: usize, end: usize) {
    for row in start..end {
        for (index, column) in trace.iter().enumerate() {
            // Columns 0..=3 are expected to be set; all later columns must be cleared.
            let expected = if index < 4 { ONE } else { ZERO };
            assert_eq!(expected, column[row]);
        }
    }
}