use alloc::vec::Vec;
use core::ops::Range;
use miden_air::trace::chiplets::hasher::{HASH_CYCLE_LEN, TRACE_WIDTH};
use miden_core::chiplets::hasher::Hasher;
use super::{Felt, HasherState, ONE, STATE_WIDTH, Selectors, TraceFragment, ZERO};
/// Column-oriented builder for the hasher execution trace.
///
/// Each field holds one column (or a group of columns) of the trace; row `i` of the trace is
/// made of element `i` of every column. All append methods push exactly one value to every
/// column, so the columns always have the same length. The field order below is also the order
/// in which `fill_trace` writes the columns into the output fragment.
///
/// Two kinds of rows are stored, distinguished by `s_perm`: controller rows (`s_perm = ZERO`)
/// and permutation rows (`s_perm = ONE`). Several columns are reused with a different meaning
/// depending on the row kind, as noted per field.
#[derive(Debug, Default)]
pub struct HasherTrace {
/// Three columns. Controller rows store the caller-supplied selector values; permutation rows
/// store the three per-row witness values instead.
selectors: [Vec<Felt>; 3],
/// `STATE_WIDTH` columns holding the hasher state recorded for each row.
hasher_state: [Vec<Felt>; STATE_WIDTH],
/// Controller rows store the caller-supplied node index; permutation rows store the
/// multiplicity passed to `append_permutation_cycle`.
node_index: Vec<Felt>,
/// Controller rows store the caller-supplied id; permutation rows store `ZERO`.
mrupdate_id: Vec<Felt>,
/// Controller rows store the caller-supplied flag; permutation rows store `ZERO`.
is_boundary: Vec<Felt>,
/// Controller rows store the caller-supplied bit; permutation rows store `ZERO`.
direction_bit: Vec<Felt>,
/// Row-kind flag: `ZERO` for controller rows, `ONE` for permutation rows.
s_perm: Vec<Felt>,
}
impl HasherTrace {
    /// Returns the number of rows currently stored in the trace.
    ///
    /// Every append method pushes one value to each column, so the length of the first selector
    /// column is used as the length of the whole trace.
    pub fn trace_len(&self) -> usize {
        let [first_selector, ..] = &self.selectors;
        first_selector.len()
    }

    /// Returns the address of the row that will be appended next, computed as the current trace
    /// length plus one (so the first row has address 1).
    pub fn next_row_addr(&self) -> Felt {
        let next_row = self.trace_len() as u64 + 1;
        Felt::new_unchecked(next_row)
    }

    /// Appends a single controller row to the trace.
    ///
    /// The provided values are written to the corresponding columns as-is, and the `s_perm`
    /// column is set to `ZERO` to mark the row as a controller row.
    pub fn append_controller_row(
        &mut self,
        selectors: Selectors,
        state: &HasherState,
        node_index: Felt,
        mrupdate_id: Felt,
        is_boundary: Felt,
        direction_bit: Felt,
    ) {
        self.selectors
            .iter_mut()
            .zip(selectors)
            .for_each(|(column, value)| column.push(value));
        self.hasher_state
            .iter_mut()
            .zip(state)
            .for_each(|(column, &value)| column.push(value));

        self.node_index.push(node_index);
        self.mrupdate_id.push(mrupdate_id);
        self.is_boundary.push(is_boundary);
        self.direction_bit.push(direction_bit);
        // Controller rows are flagged with s_perm = 0.
        self.s_perm.push(ZERO);
    }

    /// Appends the rows of one full permutation of `init_state` to the trace.
    ///
    /// Exactly 16 permutation rows are appended (presumably one `HASH_CYCLE_LEN` cycle; confirm
    /// against the constant). Each row records the state *before* the step(s) it covers:
    ///
    /// - row 0: the input state; followed by the initial external matrix multiplication and
    ///   the first initial external round,
    /// - rows 1-3: one initial external round each (`ARK_EXT_INITIAL[1..=3]`),
    /// - rows 4-10: three internal rounds each (`ARK_INT[0..21]`); the three values produced by
    ///   the S-box on state element 0 are stored as the row's witnesses,
    /// - row 11: the last internal round (`ARK_INT[21]`, stored as the first witness) fused with
    ///   the first terminal external round,
    /// - rows 12-14: one terminal external round each (`ARK_EXT_TERMINAL[1..=3]`),
    /// - row 15: the final state.
    ///
    /// `multiplicity` is written to every appended row (see `append_perm_row_with_witnesses`).
    pub fn append_permutation_cycle(&mut self, init_state: &HasherState, multiplicity: Felt) {
        let no_witnesses = [ZERO; 3];
        let mut current = *init_state;

        // Row 0: the untouched input state.
        self.append_perm_row_with_witnesses(&current, multiplicity, no_witnesses);

        // Initial external linear layer followed by initial external round 0.
        Hasher::apply_matmul_external(&mut current);
        Hasher::add_rc(&mut current, &Hasher::ARK_EXT_INITIAL[0]);
        Hasher::apply_sbox(&mut current);
        Hasher::apply_matmul_external(&mut current);

        // Rows 1-3: record the state, then apply initial external rounds 1 through 3.
        for round in 1..=3 {
            self.append_perm_row_with_witnesses(&current, multiplicity, no_witnesses);
            Hasher::add_rc(&mut current, &Hasher::ARK_EXT_INITIAL[round]);
            Hasher::apply_sbox(&mut current);
            Hasher::apply_matmul_external(&mut current);
        }

        // Rows 4-10: internal rounds 0..21, packed three per row. The row stores the state
        // before the three rounds plus the S-box output of each round as a witness.
        for first_round in (0..21_usize).step_by(3) {
            let row_state = current;
            let mut witnesses = [ZERO; 3];
            for (offset, witness) in witnesses.iter_mut().enumerate() {
                let round_constant = Hasher::ARK_INT[first_round + offset];
                current[0] = (current[0] + round_constant).exp_const_u64::<7>();
                *witness = current[0];
                Hasher::matmul_internal(&mut current, Hasher::MAT_DIAG);
            }
            self.append_perm_row_with_witnesses(&row_state, multiplicity, witnesses);
        }

        // Row 11: the last internal round fused with terminal external round 0. Only the first
        // witness slot is used.
        let row_state = current;
        let last_witness = (current[0] + Hasher::ARK_INT[21]).exp_const_u64::<7>();
        self.append_perm_row_with_witnesses(&row_state, multiplicity, [last_witness, ZERO, ZERO]);
        current[0] = last_witness;
        Hasher::matmul_internal(&mut current, Hasher::MAT_DIAG);
        Hasher::add_rc(&mut current, &Hasher::ARK_EXT_TERMINAL[0]);
        Hasher::apply_sbox(&mut current);
        Hasher::apply_matmul_external(&mut current);

        // Rows 12-14: record the state, then apply terminal external rounds 1 through 3.
        for round in 1..=3 {
            self.append_perm_row_with_witnesses(&current, multiplicity, no_witnesses);
            Hasher::add_rc(&mut current, &Hasher::ARK_EXT_TERMINAL[round]);
            Hasher::apply_sbox(&mut current);
            Hasher::apply_matmul_external(&mut current);
        }

        // Row 15: the permutation output.
        self.append_perm_row_with_witnesses(&current, multiplicity, no_witnesses);
    }

    /// Appends a single permutation row to the trace.
    ///
    /// In permutation rows the selector columns carry `witnesses`, the `node_index` column
    /// carries `multiplicity`, the `mrupdate_id`, `is_boundary` and `direction_bit` columns are
    /// `ZERO`, and `s_perm` is `ONE`.
    fn append_perm_row_with_witnesses(
        &mut self,
        state: &HasherState,
        multiplicity: Felt,
        witnesses: [Felt; 3],
    ) {
        for (column, witness) in self.selectors.iter_mut().zip(witnesses) {
            column.push(witness);
        }
        self.hasher_state
            .iter_mut()
            .zip(state)
            .for_each(|(column, &value)| column.push(value));

        // The node_index column doubles as the multiplicity column on permutation rows.
        self.node_index.push(multiplicity);
        self.mrupdate_id.push(ZERO);
        self.is_boundary.push(ZERO);
        self.direction_bit.push(ZERO);
        // Permutation rows are flagged with s_perm = 1.
        self.s_perm.push(ONE);
    }

    /// Pads the trace with controller rows until its length is a multiple of `HASH_CYCLE_LEN`.
    ///
    /// Padding rows use selectors `[ZERO, ONE, ZERO]`, an all-zero hasher state, the provided
    /// `mrupdate_id`, and `ZERO` in the node index, boundary and direction columns. Nothing is
    /// appended if the trace already ends on a cycle boundary.
    pub fn pad_to_cycle_boundary(&mut self, mrupdate_id: Felt) {
        let rows_into_cycle = self.trace_len() % HASH_CYCLE_LEN;
        if rows_into_cycle == 0 {
            return;
        }

        let zero_state = [ZERO; STATE_WIDTH];
        for _ in rows_into_cycle..HASH_CYCLE_LEN {
            self.append_controller_row(
                [ZERO, ONE, ZERO],
                &zero_state,
                ZERO,
                mrupdate_id,
                ZERO,
                ZERO,
            );
        }
    }

    /// Returns the hasher states of the rows in `range` that are controller rows
    /// (`s_perm == ZERO`) with the first selector equal to `ONE`, in row order.
    ///
    /// # Panics
    /// Panics if `range` reaches past the end of the trace.
    pub fn input_states_in_range(&self, range: Range<usize>) -> Vec<HasherState> {
        range
            .filter(|&row| self.selectors[0][row] == ONE && self.s_perm[row] == ZERO)
            .map(|row| -> HasherState {
                core::array::from_fn(|col| self.hasher_state[col][row])
            })
            .collect()
    }

    /// Appends a copy of the rows in `range` to the end of the trace and writes the hasher
    /// state found at row `range.end - 1` into `state`.
    ///
    /// For a non-empty range that row is the last source row, whose values are identical to
    /// the last row just appended.
    ///
    /// # Panics
    /// Panics if `range` is out of bounds for the trace. `range` is expected to be non-empty:
    /// the state is always read from row `range.end - 1`.
    pub fn copy_trace(&mut self, state: &mut [Felt; STATE_WIDTH], range: Range<usize>) {
        // Duplicate the requested rows in every column, in the same column order as the struct.
        let columns = self
            .selectors
            .iter_mut()
            .chain(self.hasher_state.iter_mut())
            .chain([
                &mut self.node_index,
                &mut self.mrupdate_id,
                &mut self.is_boundary,
                &mut self.direction_bit,
                &mut self.s_perm,
            ]);
        for column in columns {
            column.extend_from_within(range.clone());
        }

        // Report the hasher state of the last copied row back to the caller.
        for (slot, column) in state.iter_mut().zip(&self.hasher_state) {
            *slot = column[range.end - 1];
        }
    }

    /// Sets the `mrupdate_id` column to the given value for every row in `range`.
    ///
    /// # Panics
    /// Panics if `range` reaches past the end of the trace.
    pub fn overwrite_mrupdate_id_in_range(&mut self, range: Range<usize>, mrupdate_id: Felt) {
        range.for_each(|row| self.mrupdate_id[row] = mrupdate_id);
    }

    /// Consumes the builder and copies its columns into `trace`.
    ///
    /// Columns are written in this order: the three selector columns, the `STATE_WIDTH` hasher
    /// state columns, then `node_index`, `mrupdate_id`, `is_boundary`, `direction_bit` and
    /// `s_perm`. In debug builds the fragment is asserted to have the same length as this trace
    /// and a width of `TRACE_WIDTH`.
    pub fn fill_trace(self, trace: &mut TraceFragment) {
        debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
        debug_assert_eq!(TRACE_WIDTH, trace.width(), "inconsistent trace widths");

        let columns = self
            .selectors
            .into_iter()
            .chain(self.hasher_state)
            .chain([
                self.node_index,
                self.mrupdate_id,
                self.is_boundary,
                self.direction_bit,
                self.s_perm,
            ]);

        for (out_column, column) in trace.columns().zip(columns) {
            out_column.copy_from_slice(&column);
        }
    }
}