use alloc::vec::Vec;
use core::{borrow::BorrowMut, ops::Range};
use miden_air::{
ControllerCols, PermutationCols,
trace::chiplets::hasher::{HASH_CYCLE_LEN, TRACE_WIDTH},
};
use miden_core::chiplets::hasher::Hasher;
use super::{ChipletTraceFragment, Felt, HasherState, ONE, STATE_WIDTH, Selectors, ZERO};
/// Index of the last column of a hasher trace row (`TRACE_WIDTH - 1`).
///
/// Columns `[0, S_PERM_OFFSET)` are overlaid with either `ControllerCols` or `PermutationCols`.
/// The column at `S_PERM_OFFSET` itself distinguishes the two row kinds: controller and padding
/// rows write `ZERO` there, permutation rows write `ONE`.
const S_PERM_OFFSET: usize = TRACE_WIDTH - 1;
/// A buffered hasher-chiplet operation.
///
/// Operations are stored compactly and only expanded into full trace rows when the trace is
/// filled. `HasherOp::row_count` gives the number of rows each variant expands into.
#[derive(Debug, Clone)]
enum HasherOp {
    /// A single controller row whose selector, state and metadata columns are written verbatim.
    Controller {
        selectors: Selectors,
        state: HasherState,
        node_index: Felt,
        mrupdate_id: Felt,
        is_boundary: Felt,
        direction_bit: Felt,
    },
    /// A full permutation cycle of `HASH_CYCLE_LEN` rows.
    ///
    /// The per-row states are derived from `init_state` when the trace is filled, and every row
    /// in the cycle carries the same `multiplicity`.
    Permutation {
        init_state: HasherState,
        multiplicity: Felt,
    },
    /// `count` padding rows.
    ///
    /// Each padding row is written as a controller row with selectors `[ZERO, ONE, ZERO]`, an
    /// all-zero state, zero node index / boundary / direction columns, and the given
    /// `mrupdate_id`.
    Padding { count: usize, mrupdate_id: Felt },
}
impl HasherOp {
    /// Returns how many trace rows this operation occupies once expanded.
    ///
    /// Padding contributes its explicit `count`, a permutation always spans a whole
    /// `HASH_CYCLE_LEN` cycle, and a controller operation is exactly one row.
    fn row_count(&self) -> usize {
        match *self {
            Self::Padding { count, .. } => count,
            Self::Permutation { .. } => HASH_CYCLE_LEN,
            Self::Controller { .. } => 1,
        }
    }
}
/// Buffer of hasher-chiplet operations that also tracks how many trace rows they expand into.
///
/// Operations are recorded compactly by the `append_*` / `pad_*` / `replay_*` methods and are
/// only expanded into full rows by `HasherTrace::fill_trace`.
#[derive(Debug, Default)]
pub struct HasherTrace {
    /// Buffered operations, in the order their rows appear in the trace.
    ops: Vec<HasherOp>,
    /// Total number of trace rows the buffered operations expand into.
    row_count: usize,
}
impl HasherTrace {
pub fn trace_len(&self) -> usize {
self.row_count
}
pub fn next_row_addr(&self) -> Felt {
Felt::new_unchecked(self.row_count as u64 + 1)
}
pub fn next_op_index(&self) -> usize {
self.ops.len()
}
pub fn append_controller_row(
&mut self,
selectors: Selectors,
state: &HasherState,
node_index: Felt,
mrupdate_id: Felt,
is_boundary: Felt,
direction_bit: Felt,
) {
self.ops.push(HasherOp::Controller {
selectors,
state: *state,
node_index,
mrupdate_id,
is_boundary,
direction_bit,
});
self.row_count += 1;
}
pub fn append_permutation_cycle(&mut self, init_state: &HasherState, multiplicity: Felt) {
self.ops.push(HasherOp::Permutation { init_state: *init_state, multiplicity });
self.row_count += HASH_CYCLE_LEN;
}
pub fn pad_to_cycle_boundary(&mut self, mrupdate_id: Felt) {
let remainder = self.row_count % HASH_CYCLE_LEN;
if remainder != 0 {
let count = HASH_CYCLE_LEN - remainder;
self.ops.push(HasherOp::Padding { count, mrupdate_id });
self.row_count += count;
}
}
pub fn replay_ops_range(
&mut self,
range: Range<usize>,
new_mrupdate_id: Felt,
) -> (HasherState, Vec<HasherState>) {
let copied: Vec<HasherOp> = self.ops[range].to_vec();
let mut last_state = [ZERO; STATE_WIDTH];
let mut input_states = Vec::new();
for mut op in copied {
match &mut op {
HasherOp::Controller { mrupdate_id, selectors, state, .. } => {
*mrupdate_id = new_mrupdate_id;
if selectors[0] == ONE {
input_states.push(*state);
}
last_state = *state;
},
HasherOp::Padding { mrupdate_id, .. } => {
*mrupdate_id = new_mrupdate_id;
},
HasherOp::Permutation { .. } => {},
}
self.row_count += op.row_count();
self.ops.push(op);
}
(last_state, input_states)
}
pub fn fill_trace(self, trace: &mut ChipletTraceFragment) {
debug_assert_eq!(self.trace_len(), trace.len(), "inconsistent trace lengths");
debug_assert_eq!(TRACE_WIDTH, trace.width(), "inconsistent trace widths");
let mut chunk = [ZERO; TRACE_WIDTH * HASH_CYCLE_LEN];
let mut row_idx = 0usize;
for op in &self.ops {
let n = op.row_count();
debug_assert!(n <= HASH_CYCLE_LEN);
let is_ctrl = matches!(op, HasherOp::Controller { .. } | HasherOp::Padding { .. });
let (chunk_rows, _) = chunk.as_mut_slice().as_chunks_mut::<TRACE_WIDTH>();
match op {
HasherOp::Controller {
selectors,
state,
node_index,
mrupdate_id,
is_boundary,
direction_bit,
} => {
write_controller_row(
&mut chunk_rows[0],
*selectors,
state,
*node_index,
*mrupdate_id,
*is_boundary,
*direction_bit,
);
},
HasherOp::Permutation { init_state, multiplicity } => {
write_permutation_cycle(
&mut chunk_rows[..HASH_CYCLE_LEN],
init_state,
*multiplicity,
);
},
HasherOp::Padding { count, mrupdate_id } => {
let padding_selectors = [ZERO, ONE, ZERO];
for row in &mut chunk_rows[..*count] {
write_controller_row(
row,
padding_selectors,
&[ZERO; STATE_WIDTH],
ZERO,
*mrupdate_id,
ZERO,
ZERO,
);
}
},
}
trace.copy_rows_into(row_idx, &chunk[..n * TRACE_WIDTH]);
if is_ctrl {
for i in 0..n {
trace.set_s_01(row_idx + i);
}
}
row_idx += n;
}
debug_assert_eq!(row_idx, self.row_count);
}
}
/// Writes a single controller-format row into `row`.
///
/// The columns before `S_PERM_OFFSET` are viewed as `ControllerCols` and filled from the
/// arguments. The trailing column at `S_PERM_OFFSET` is cleared to `ZERO`; permutation rows
/// set that column to `ONE` instead.
fn write_controller_row(
    row: &mut [Felt; TRACE_WIDTH],
    selectors: Selectors,
    state: &HasherState,
    node_index: Felt,
    mrupdate_id: Felt,
    is_boundary: Felt,
    direction_bit: Felt,
) {
    // The trailing column sits outside the `ControllerCols` overlay, so it can be set first.
    row[S_PERM_OFFSET] = ZERO;

    let [sel0, sel1, sel2] = selectors;
    let (controller_span, _) = row.split_at_mut(S_PERM_OFFSET);
    let dst: &mut ControllerCols<Felt> = controller_span.borrow_mut();
    dst.s0 = sel0;
    dst.s1 = sel1;
    dst.s2 = sel2;
    dst.state = *state;
    dst.node_index = node_index;
    dst.mrupdate_id = mrupdate_id;
    dst.is_boundary = is_boundary;
    dst.direction_bit = direction_bit;
}
/// Writes one full permutation cycle (`HASH_CYCLE_LEN` rows) into `rows`, starting from
/// `init_state`. Every row carries `multiplicity`.
///
/// Row layout:
/// - row 0: `init_state`, no witnesses. After it, an external matmul is applied, followed by
///   initial external round 0.
/// - rows 1..=3: the state before initial external rounds 1..=3 respectively.
/// - rows 4..=10: the state before each group of three internal rounds. The witnesses are that
///   group's three lane-0 S-box outputs (`(state[0] + ARK_INT[i])^7`).
/// - row 11: the state before the last internal round (`ARK_INT[21]`), with that round's S-box
///   output as the first witness. Both that internal round and terminal external round 0 are
///   applied after this row.
/// - rows 12..=14: the state before terminal external rounds 1..=3 respectively.
/// - row 15: the final output state.
///
/// Here an "external round" is `add_rc` + `apply_sbox` + `apply_matmul_external`. An "internal
/// round" adds a round constant to lane 0, raises lane 0 to the 7th power, then applies
/// `matmul_internal` with `MAT_DIAG`.
///
/// NOTE(review): the hard-coded indices assume `HASH_CYCLE_LEN == 16`, 4 initial and 4 terminal
/// external rounds, and 22 internal rounds (`ARK_INT[0..=21]`). Confirm against the `Hasher`
/// constants if those ever change.
fn write_permutation_cycle(
    rows: &mut [[Felt; TRACE_WIDTH]],
    init_state: &HasherState,
    multiplicity: Felt,
) {
    debug_assert_eq!(rows.len(), HASH_CYCLE_LEN);
    let mut state = *init_state;
    // Row 0: the untouched input state.
    write_perm_row(&mut rows[0], &state, multiplicity, [ZERO; 3]);
    // External matmul on the input, then initial external round 0 (not recorded as its own row).
    Hasher::apply_matmul_external(&mut state);
    Hasher::add_rc(&mut state, &Hasher::ARK_EXT_INITIAL[0]);
    Hasher::apply_sbox(&mut state);
    Hasher::apply_matmul_external(&mut state);
    // Rows 1..=3: record the state, then apply initial external round r.
    for (r, row) in rows.iter_mut().enumerate().take(3 + 1).skip(1) {
        write_perm_row(row, &state, multiplicity, [ZERO; 3]);
        Hasher::add_rc(&mut state, &Hasher::ARK_EXT_INITIAL[r]);
        Hasher::apply_sbox(&mut state);
        Hasher::apply_matmul_external(&mut state);
    }
    // Rows 4..=10: seven groups of three internal rounds (ARK_INT[0..21]). Each row stores the
    // state *before* its group plus the three lane-0 S-box outputs as witnesses.
    for triple in 0..7_usize {
        let base = triple * 3;
        let pre_state = state;
        let mut witnesses = [ZERO; 3];
        for (k, witness) in witnesses.iter_mut().enumerate() {
            let sbox_out = (state[0] + Hasher::ARK_INT[base + k]).exp_const_u64::<7>();
            *witness = sbox_out;
            state[0] = sbox_out;
            Hasher::matmul_internal(&mut state, Hasher::MAT_DIAG);
        }
        write_perm_row(&mut rows[4 + triple], &pre_state, multiplicity, witnesses);
    }
    // Row 11 packs the final internal round (ARK_INT[21], one witness) together with terminal
    // external round 0. The row is written after both are applied but records the pre-state.
    let pre_state = state;
    let w0 = (state[0] + Hasher::ARK_INT[21]).exp_const_u64::<7>();
    state[0] = w0;
    Hasher::matmul_internal(&mut state, Hasher::MAT_DIAG);
    Hasher::add_rc(&mut state, &Hasher::ARK_EXT_TERMINAL[0]);
    Hasher::apply_sbox(&mut state);
    Hasher::apply_matmul_external(&mut state);
    write_perm_row(&mut rows[11], &pre_state, multiplicity, [w0, ZERO, ZERO]);
    // Rows 12..=14: record the state, then apply terminal external round r.
    for r in 1..=3 {
        write_perm_row(&mut rows[11 + r], &state, multiplicity, [ZERO; 3]);
        Hasher::add_rc(&mut state, &Hasher::ARK_EXT_TERMINAL[r]);
        Hasher::apply_sbox(&mut state);
        Hasher::apply_matmul_external(&mut state);
    }
    // Row 15: the permutation output.
    write_perm_row(&mut rows[15], &state, multiplicity, [ZERO; 3]);
}
/// Writes a single permutation-format row into `row`.
///
/// The columns before `S_PERM_OFFSET` are viewed as `PermutationCols` and receive the
/// witnesses, state and multiplicity; any unused overlay columns are zeroed. The trailing
/// column at `S_PERM_OFFSET` is set to `ONE`; controller and padding rows leave it at `ZERO`.
fn write_perm_row(
    row: &mut [Felt; TRACE_WIDTH],
    state: &HasherState,
    multiplicity: Felt,
    witnesses: [Felt; 3],
) {
    // The trailing column is not part of the `PermutationCols` overlay, so set it up front.
    row[S_PERM_OFFSET] = ONE;

    let (perm_span, _) = row.split_at_mut(S_PERM_OFFSET);
    let perm: &mut PermutationCols<Felt> = perm_span.borrow_mut();
    perm.witnesses = witnesses;
    perm.state = *state;
    perm.multiplicity = multiplicity;
    perm.set_unused_padding(ZERO);
}