use super::{
super::utils::get_trace_len, get_num_groups_in_next_batch, Felt, Operation, Word, DIGEST_LEN,
MIN_TRACE_LEN, NUM_HASHER_COLUMNS, NUM_OP_BATCH_FLAGS, NUM_OP_BITS, NUM_OP_BITS_EXTRA_COLS,
ONE, OP_BATCH_1_GROUPS, OP_BATCH_2_GROUPS, OP_BATCH_4_GROUPS, OP_BATCH_8_GROUPS, OP_BATCH_SIZE,
ZERO,
};
use alloc::vec::Vec;
use core::ops::Range;
use vm_core::utils::new_array_vec;
#[cfg(test)]
use miden_air::trace::decoder::NUM_USER_OP_HELPERS;
/// Index range of the hasher-trace columns that serve as helper registers for user operations.
pub const USER_OP_HELPERS: Range<usize> = 2..NUM_HASHER_COLUMNS;
/// Buffer for the decoder's execution trace, stored as a set of columns which grow in lock-step
/// as rows are appended by the `append_*` methods.
pub struct DecoderTrace {
    /// Block/span address recorded for each row.
    addr_trace: Vec<Felt>,
    /// One column per bit of the operation code (see `append_opcode()`).
    op_bits_trace: [Vec<Felt>; NUM_OP_BITS],
    /// Hasher-state columns; while inside a span, columns `USER_OP_HELPERS` double as helper
    /// registers for user operations.
    hasher_trace: [Vec<Felt>; NUM_HASHER_COLUMNS],
    /// ONE for rows appended via `append_user_op()` (i.e., inside a span), ZERO otherwise.
    in_span_trace: Vec<Felt>,
    /// Number of operation groups still left to be processed in the current span.
    group_count_trace: Vec<Felt>,
    /// Index of the current operation within its op group.
    op_idx_trace: Vec<Felt>,
    /// Flags describing how many op groups are in the current batch (see `get_op_batch_flags()`).
    op_batch_flag_trace: [Vec<Felt>; NUM_OP_BATCH_FLAGS],
    /// Degree-reduction columns derived from the top three op bits (see `append_opcode()`).
    op_bit_extra_trace: [Vec<Felt>; NUM_OP_BITS_EXTRA_COLS],
}
impl DecoderTrace {
/// Returns an empty [DecoderTrace] with each column's capacity reserved for the minimum
/// trace length.
pub fn new() -> Self {
    Self {
        addr_trace: Vec::with_capacity(MIN_TRACE_LEN),
        op_bits_trace: new_array_vec(MIN_TRACE_LEN),
        hasher_trace: new_array_vec(MIN_TRACE_LEN),
        in_span_trace: Vec::with_capacity(MIN_TRACE_LEN),
        group_count_trace: Vec::with_capacity(MIN_TRACE_LEN),
        op_idx_trace: Vec::with_capacity(MIN_TRACE_LEN),
        op_batch_flag_trace: new_array_vec(MIN_TRACE_LEN),
        op_bit_extra_trace: new_array_vec(MIN_TRACE_LEN),
    }
}
/// Returns the number of rows currently in this trace (the length of the address column;
/// all columns are kept at the same length).
pub fn trace_len(&self) -> usize {
    self.addr_trace.len()
}
/// Returns the hash of the executed program, read from the last row of the first
/// `DIGEST_LEN` hasher columns.
pub fn program_hash(&self) -> [Felt; DIGEST_LEN] {
    core::array::from_fn(|i| self.last_hasher_value(i))
}
/// Appends a row marking the start of a control block.
///
/// The row records the parent block's address, the op bits of `op`, and places `h1` into the
/// first 4 hasher columns and `h2` into the next 4. In-span flag, group count, op index, and
/// batch flags are all ZERO for a block-start row.
pub fn append_block_start(&mut self, parent_addr: Felt, op: Operation, h1: Word, h2: Word) {
    self.addr_trace.push(parent_addr);
    self.append_opcode(op);

    // h1 occupies hasher columns 0..4, h2 occupies columns 4..8
    for i in 0..4 {
        self.hasher_trace[i].push(h1[i]);
        self.hasher_trace[i + 4].push(h2[i]);
    }

    self.in_span_trace.push(ZERO);
    self.group_count_trace.push(ZERO);
    self.op_idx_trace.push(ZERO);
    self.op_batch_flag_trace.iter_mut().for_each(|column| column.push(ZERO));
}
/// Appends a row marking the end of a control block.
///
/// The first 4 hasher columns receive the block's hash; the next 4 receive the binary flags
/// describing the block kind. The group count carried over from the previous row must already
/// be zero.
pub fn append_block_end(
    &mut self,
    block_addr: Felt,
    block_hash: Word,
    is_loop_body: Felt,
    is_loop: Felt,
    is_call: Felt,
    is_syscall: Felt,
) {
    // all kind flags must be binary
    debug_assert!(is_loop_body.as_int() <= 1, "invalid is_loop_body");
    debug_assert!(is_loop.as_int() <= 1, "invalid is_loop");
    debug_assert!(is_call.as_int() <= 1, "invalid is_call");
    debug_assert!(is_syscall.as_int() <= 1, "invalid is_syscall");

    self.addr_trace.push(block_addr);
    self.append_opcode(Operation::End);

    // hash into columns 0..4, kind flags into columns 4..8
    for i in 0..4 {
        self.hasher_trace[i].push(block_hash[i]);
    }
    self.hasher_trace[4].push(is_loop_body);
    self.hasher_trace[5].push(is_loop);
    self.hasher_trace[6].push(is_call);
    self.hasher_trace[7].push(is_syscall);

    self.in_span_trace.push(ZERO);

    // all op groups must have been consumed by the time the block ends
    let group_count = self.last_group_count();
    debug_assert!(group_count == ZERO, "group count not zero");
    self.group_count_trace.push(group_count);

    self.op_idx_trace.push(ZERO);
    self.op_batch_flag_trace.iter_mut().for_each(|column| column.push(ZERO));
}
/// Appends a REPEAT row for the loop at `loop_addr`, carrying the hasher state of the
/// previous row over unchanged.
pub fn append_loop_repeat(&mut self, loop_addr: Felt) {
    self.addr_trace.push(loop_addr);
    self.append_opcode(Operation::Repeat);

    // duplicate the previous row of the hasher state
    let prev_row = get_trace_len(&self.hasher_trace) - 1;
    for column in self.hasher_trace.iter_mut() {
        let carried = column[prev_row];
        column.push(carried);
    }

    self.in_span_trace.push(ZERO);
    self.group_count_trace.push(ZERO);
    self.op_idx_trace.push(ZERO);
    self.op_batch_flag_trace.iter_mut().for_each(|column| column.push(ZERO));
}
/// Appends a SPAN row which loads the first batch of op groups into the hasher columns and
/// initializes the group count and batch flags.
pub fn append_span_start(
    &mut self,
    parent_addr: Felt,
    first_op_batch: &[Felt; OP_BATCH_SIZE],
    num_op_groups: Felt,
) {
    self.addr_trace.push(parent_addr);
    self.append_opcode(Operation::Span);

    // hasher columns carry the op groups of the first batch
    for (column, &op_group) in self.hasher_trace.iter_mut().zip(first_op_batch.iter()) {
        column.push(op_group);
    }

    self.in_span_trace.push(ZERO);
    self.group_count_trace.push(num_op_groups);
    self.op_idx_trace.push(ZERO);

    // batch flags describe how many groups the first batch contains
    let flags = get_op_batch_flags(num_op_groups);
    for (column, &flag) in self.op_batch_flag_trace.iter_mut().zip(flags.iter()) {
        column.push(flag);
    }
}
/// Appends a RESPAN row which loads the next batch of op groups into the hasher columns
/// while keeping the address and group count of the previous row.
pub fn append_respan(&mut self, op_batch: &[Felt; OP_BATCH_SIZE]) {
    // a RESPAN row keeps the address of the span being continued
    self.addr_trace.push(self.last_addr());
    self.append_opcode(Operation::Respan);

    // load the next batch of op groups into the hasher columns
    for (column, &op_group) in self.hasher_trace.iter_mut().zip(op_batch.iter()) {
        column.push(op_group);
    }

    let group_count = self.last_group_count();
    self.in_span_trace.push(ZERO);
    self.group_count_trace.push(group_count);
    self.op_idx_trace.push(ZERO);

    let flags = get_op_batch_flags(group_count);
    for (column, &flag) in self.op_batch_flag_trace.iter_mut().zip(flags.iter()) {
        column.push(flag);
    }
}
/// Appends a row for a user operation executed inside a span.
///
/// The first hasher column carries the ops still left in the current group, the second the
/// parent block address; the remaining hasher columns act as helper registers and start out
/// as ZERO (they can be overwritten later via `set_user_op_helpers()`).
pub fn append_user_op(
    &mut self,
    op: Operation,
    span_addr: Felt,
    parent_addr: Felt,
    num_groups_left: Felt,
    group_ops_left: Felt,
    op_idx: Felt,
) {
    self.addr_trace.push(span_addr);
    self.append_opcode(op);

    self.hasher_trace[0].push(group_ops_left);
    self.hasher_trace[1].push(parent_addr);
    // helper registers are initialized to ZERO
    self.hasher_trace[USER_OP_HELPERS].iter_mut().for_each(|column| column.push(ZERO));

    self.in_span_trace.push(ONE);
    self.group_count_trace.push(num_groups_left);
    self.op_idx_trace.push(op_idx);
    self.op_batch_flag_trace.iter_mut().for_each(|column| column.push(ZERO));
}
/// Appends an END row for a span block.
///
/// The first 4 hasher columns receive the span's hash and the 5th the loop-body flag; the
/// remaining hasher columns are ZERO. The group count carried over from the previous row
/// must already be zero.
pub fn append_span_end(&mut self, span_hash: Word, is_loop_body: Felt) {
    debug_assert!(is_loop_body.as_int() <= 1, "invalid loop body");

    // a span ends at the same address it was running at
    self.addr_trace.push(self.last_addr());
    self.append_opcode(Operation::End);

    for i in 0..4 {
        self.hasher_trace[i].push(span_hash[i]);
    }
    self.hasher_trace[4].push(is_loop_body);
    for i in 5..8 {
        self.hasher_trace[i].push(ZERO);
    }

    self.in_span_trace.push(ZERO);

    // all op groups must have been consumed by the time the span ends
    let group_count = self.last_group_count();
    debug_assert!(group_count == ZERO, "group count not zero");
    self.group_count_trace.push(group_count);

    self.op_idx_trace.push(ZERO);
    self.op_batch_flag_trace.iter_mut().for_each(|column| column.push(ZERO));
}
pub fn into_vec(mut self, trace_len: usize, num_rand_rows: usize) -> Vec<Vec<Felt>> {
let own_len = self.trace_len();
assert!(own_len + num_rand_rows <= trace_len, "target trace length too small");
let mut trace = Vec::new();
self.addr_trace.resize(trace_len, ZERO);
trace.push(self.addr_trace);
let halt_opcode = Operation::Halt.op_code();
for (i, mut column) in self.op_bits_trace.into_iter().enumerate() {
debug_assert_eq!(own_len, column.len());
let value = Felt::from((halt_opcode >> i) & 1);
column.resize(trace_len, value);
trace.push(column);
}
for (i, mut column) in self.hasher_trace.into_iter().enumerate() {
debug_assert_eq!(own_len, column.len());
if i < 4 {
let last_value = *column.last().expect("no last hasher trace value");
column.resize(trace_len, last_value);
} else {
column.resize(trace_len, ZERO);
}
trace.push(column);
}
debug_assert_eq!(own_len, self.in_span_trace.len());
self.in_span_trace.resize(trace_len, ZERO);
trace.push(self.in_span_trace);
debug_assert_eq!(own_len, self.group_count_trace.len());
self.group_count_trace.resize(trace_len, ZERO);
trace.push(self.group_count_trace);
debug_assert_eq!(own_len, self.op_idx_trace.len());
self.op_idx_trace.resize(trace_len, ZERO);
trace.push(self.op_idx_trace);
for mut column in self.op_batch_flag_trace.into_iter() {
debug_assert_eq!(own_len, column.len());
column.resize(trace_len, ZERO);
trace.push(column);
}
let [mut op_bit_extra_0, mut op_bit_extra_1] = self.op_bit_extra_trace;
debug_assert_eq!(own_len, op_bit_extra_0.len());
debug_assert_eq!(own_len, op_bit_extra_1.len());
op_bit_extra_0.resize(trace_len, ZERO);
trace.push(op_bit_extra_0);
debug_assert_eq!(1, (halt_opcode >> 6) & 1);
debug_assert_eq!(1, (halt_opcode >> 5) & 1);
op_bit_extra_1.resize(trace_len, ONE);
trace.push(op_bit_extra_1);
trace
}
/// Returns the value in the last row of the address column.
fn last_addr(&self) -> Felt {
    self.addr_trace.last().copied().expect("no last addr")
}
/// Returns the value in the last row of the group count column.
fn last_group_count(&self) -> Felt {
    self.group_count_trace.last().copied().expect("no group count")
}
/// Returns the value in the last row of hasher column `idx`.
fn last_hasher_value(&self, idx: usize) -> Felt {
    debug_assert!(idx < NUM_HASHER_COLUMNS, "invalid hasher register index");
    self.hasher_trace[idx].last().copied().expect("no last hasher value")
}
/// Returns a mutable reference to the last row of helper register `idx` (helper registers
/// occupy hasher columns `USER_OP_HELPERS`).
fn last_helper_mut(&mut self, idx: usize) -> &mut Felt {
    debug_assert!(idx < USER_OP_HELPERS.len(), "invalid helper register index");
    let column = &mut self.hasher_trace[USER_OP_HELPERS.start + idx];
    column.last_mut().expect("no last helper value")
}
/// Appends one row to each op bit column for the specified operation, and pushes the two
/// degree-reduction products derived from the top three bits of the new row.
fn append_opcode(&mut self, op: Operation) {
    let opcode = op.op_code();
    for (i, column) in self.op_bits_trace.iter_mut().enumerate() {
        column.push(Felt::from((opcode >> i) & 1));
    }

    // read back the three most significant bits of the row just written and build the
    // two product columns: bit6 * (1 - bit5) * bit4 and bit6 * bit5
    let row = self.op_bit_extra_trace[0].len();
    let bit4 = self.op_bits_trace[NUM_OP_BITS - 3][row];
    let bit5 = self.op_bits_trace[NUM_OP_BITS - 2][row];
    let bit6 = self.op_bits_trace[NUM_OP_BITS - 1][row];
    self.op_bit_extra_trace[0].push(bit6 * (ONE - bit5) * bit4);
    self.op_bit_extra_trace[1].push(bit6 * bit5);
}
/// Overwrites the helper registers in the last trace row with the provided values.
///
/// # Panics
/// Panics if more values are provided than there are helper columns.
pub fn set_user_op_helpers(&mut self, values: &[Felt]) {
    assert!(values.len() <= USER_OP_HELPERS.len(), "too many values for helper columns");
    for (idx, &value) in values.iter().enumerate() {
        *self.last_helper_mut(idx) = value;
    }
}
/// Test helper: adds a row of ZEROs to the address, op bit, in_span, hasher, group count,
/// and op index columns.
///
/// NOTE(review): this does not push rows onto the op batch flag or extra op-bit columns,
/// so those fall out of sync with the rest of the trace — presumably the tests using this
/// never convert the trace afterwards (`into_vec` debug-asserts equal column lengths);
/// confirm before reusing.
#[cfg(test)]
pub fn add_dummy_row(&mut self) {
    self.addr_trace.push(ZERO);
    for column in self.op_bits_trace.iter_mut() {
        column.push(ZERO);
    }
    self.in_span_trace.push(ZERO);
    for column in self.hasher_trace.iter_mut() {
        column.push(ZERO);
    }
    self.group_count_trace.push(ZERO);
    self.op_idx_trace.push(ZERO);
}
/// Test helper: returns the values of the user-op helper registers in the last trace row.
#[cfg(test)]
pub fn get_user_op_helpers(&self) -> [Felt; NUM_USER_OP_HELPERS] {
    core::array::from_fn(|idx| {
        *self.hasher_trace[USER_OP_HELPERS.start + idx]
            .last()
            .expect("no last helper value")
    })
}
}
/// Maps the number of op groups remaining to the flag values describing the size of the next
/// batch; valid batch sizes are 1, 2, 4, and 8 groups.
///
/// # Panics
/// Panics if the computed batch size is not one of the valid sizes.
fn get_op_batch_flags(num_groups_left: Felt) -> [Felt; NUM_OP_BATCH_FLAGS] {
    match get_num_groups_in_next_batch(num_groups_left) {
        1 => OP_BATCH_1_GROUPS,
        2 => OP_BATCH_2_GROUPS,
        4 => OP_BATCH_4_GROUPS,
        8 => OP_BATCH_8_GROUPS,
        num_groups => panic!(
            "invalid number of groups in a batch: {num_groups}, group count: {num_groups_left}"
        ),
    }
}