use super::{Felt, HasherState, Selectors, TraceFragment, STATE_WIDTH, TRACE_WIDTH, ZERO};
use alloc::vec::Vec;
use core::ops::Range;
use miden_air::trace::chiplets::hasher::NUM_ROUNDS;
use vm_core::chiplets::hasher::apply_round;
/// Column-oriented buffer for the rows of a hasher trace.
///
/// Every column is stored as its own `Vec<Felt>`; rows are appended to all columns at once, so
/// all columns are expected to have the same length. The default value is a trace with no rows.
#[derive(Default)]
pub struct HasherTrace {
    /// The three selector columns; the length of the first one defines the trace length.
    selectors: [Vec<Felt>; 3],
    /// One column per element of the hasher state (`STATE_WIDTH` columns in total).
    hasher_state: [Vec<Felt>; STATE_WIDTH],
    /// The node index column, holding one value per row.
    node_index: Vec<Felt>,
}
impl HasherTrace {
    /// Returns the number of rows currently stored in this trace.
    ///
    /// The length is read from the first selector column; every other column is kept at the
    /// same length by the append/copy methods below.
    pub fn trace_len(&self) -> usize {
        let [first_selector, ..] = &self.selectors;
        first_selector.len()
    }

    /// Returns the address of the next row to be appended to this trace.
    ///
    /// The address is the current trace length plus one, i.e. addresses are counted from one.
    pub fn next_row_addr(&self) -> Felt {
        let next_addr = self.trace_len() as u64 + 1;
        Felt::new(next_addr)
    }

    /// Appends the rows of one full permutation to the trace, `NUM_ROUNDS + 1` rows in total.
    ///
    /// - The first row records `state` as provided, together with `init_selectors` and
    ///   `init_index`.
    /// - Each of the first `NUM_ROUNDS - 1` rounds is applied to `state` and recorded as a row
    ///   whose first selector is `ZERO`, whose other two selectors are copied from
    ///   `init_selectors`, and whose node index is `rest_index`.
    /// - The last round is applied and recorded with `final_selectors` and `rest_index`.
    ///
    /// `state` is updated in place and holds the result of all rounds on return.
    pub fn append_permutation_with_index(
        &mut self,
        state: &mut HasherState,
        init_selectors: Selectors,
        final_selectors: Selectors,
        init_index: Felt,
        rest_index: Felt,
    ) {
        // first row: the state before any round has been applied
        self.append_row(init_selectors, state, init_index);

        // intermediate rows: first selector cleared, the other two carried over from the first row
        let [_, second, third] = init_selectors;
        let middle_selectors = [ZERO, second, third];
        let last_round = NUM_ROUNDS - 1;
        for round in 0..last_round {
            apply_round(state, round);
            self.append_row(middle_selectors, state, rest_index);
        }

        // final row: the state after the last round, tagged with the final selectors
        apply_round(state, last_round);
        self.append_row(final_selectors, state, rest_index);
    }

    /// Appends the rows of one full permutation to the trace with the node index column set to
    /// `ZERO` in every row.
    ///
    /// This is [`Self::append_permutation_with_index`] with both index arguments set to `ZERO`;
    /// `state` is updated in place in the same way.
    #[inline(always)]
    pub fn append_permutation(
        &mut self,
        state: &mut HasherState,
        init_selectors: Selectors,
        final_selectors: Selectors,
    ) {
        let (init_index, rest_index) = (ZERO, ZERO);
        self.append_permutation_with_index(
            state,
            init_selectors,
            final_selectors,
            init_index,
            rest_index,
        );
    }

    /// Appends a single row to the trace: one value per selector column, one value per hasher
    /// state column, and `index` in the node index column.
    fn append_row(&mut self, selectors: Selectors, state: &HasherState, index: Felt) {
        self.selectors
            .iter_mut()
            .zip(selectors)
            .for_each(|(column, value)| column.push(value));
        self.hasher_state
            .iter_mut()
            .zip(state)
            .for_each(|(column, value)| column.push(*value));
        self.node_index.push(index);
    }

    /// Appends a copy of the rows in `range` to the end of the trace and overwrites `state`
    /// with the hasher state stored in the last row of `range` (row `range.end - 1`).
    ///
    /// # Panics
    /// Panics if `range` is not a valid range of existing rows (see
    /// `Vec::extend_from_within`), or if `range.end` is zero.
    pub fn copy_trace(&mut self, state: &mut [Felt; STATE_WIDTH], range: Range<usize>) {
        // duplicate the selected rows in every column: selectors, then hasher state, then index
        let all_columns = self
            .selectors
            .iter_mut()
            .chain(self.hasher_state.iter_mut())
            .chain(core::iter::once(&mut self.node_index));
        for column in all_columns {
            column.extend_from_within(range.clone());
        }

        // bring the caller's state in sync with the final row of the copied range
        let last_row = range.end - 1;
        for (slot, column) in state.iter_mut().zip(&self.hasher_state) {
            *slot = column[last_row];
        }
    }

    /// Consumes this trace and copies its columns into `trace`.
    ///
    /// Columns are written in the order: selectors, hasher state, node index. In debug builds
    /// the dimensions of `trace` are asserted to match this trace.
    ///
    /// # Panics
    /// Panics if a destination column's length differs from the source column's length
    /// (`copy_from_slice` requires equal lengths).
    pub fn fill_trace(self, trace: &mut TraceFragment) {
        // the fragment must have exactly the dimensions of this trace
        debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
        debug_assert_eq!(TRACE_WIDTH, trace.width(), "inconsistent trace widths");

        // line the owned columns up in output order without an intermediate collection
        let Self { selectors, hasher_state, node_index } = self;
        let sources = selectors
            .into_iter()
            .chain(hasher_state)
            .chain(core::iter::once(node_index));

        // copy into the fragment column by column
        for (destination, source) in trace.columns().zip(sources) {
            destination.copy_from_slice(&source);
        }
    }
}