use alloc::collections::BTreeMap;
use miden_air::trace::chiplets::hasher::{
DIGEST_RANGE, HASH_CYCLE_LEN, LINEAR_HASH, MP_VERIFY, MR_UPDATE_NEW, MR_UPDATE_OLD, RATE_LEN,
RETURN_HASH, RETURN_STATE, STATE_WIDTH, Selectors,
};
use miden_core::chiplets::hasher::apply_permutation;
use super::{
Felt, HasherState, MerklePath, MerkleRootUpdate, ONE, OpBatch, TraceFragment, Word as Digest,
ZERO,
};
mod trace;
use trace::HasherTrace;
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests;
/// Map key for a digest: four `u64` values, as produced by `digest_to_key`.
type DigestKey = [u64; 4];
/// Map key for a full hasher state: one `u64` per state element, as produced by `state_to_key`.
type StateKey = [u64; STATE_WIDTH];
/// Converts a digest into a [`DigestKey`] by taking the canonical `u64` value of each of its
/// first four elements.
fn digest_to_key(digest: Digest) -> DigestKey {
    let felts = digest.as_elements();
    let mut key: DigestKey = [0; 4];
    for (i, slot) in key.iter_mut().enumerate() {
        *slot = felts[i].as_canonical_u64();
    }
    key
}
/// Converts a hasher state into a [`StateKey`] holding the canonical `u64` value of every
/// state element, in order.
fn state_to_key(state: &HasherState) -> StateKey {
    state.map(|felt| felt.as_canonical_u64())
}
/// Rebuilds a hasher state from a [`StateKey`], wrapping each stored `u64` with
/// `Felt::new_unchecked`.
fn key_to_state(key: &StateKey) -> HasherState {
    key.map(Felt::new_unchecked)
}
/// Builds the hasher trace and keeps the bookkeeping needed while building it.
#[derive(Debug, Default)]
pub struct Hasher {
/// Trace rows appended so far.
trace: HasherTrace,
/// Maps a digest key to the `(start_row, end_row)` pair stored for it by
/// `insert_to_memoized_trace_map`; `replay_memoized_trace` copies that row range.
memoized_trace_map: BTreeMap<DigestKey, (usize, usize)>,
/// Number of permutation requests recorded per distinct input state. Each entry becomes one
/// permutation cycle, with the count as its multiplicity, when the trace is finalized.
perm_request_map: BTreeMap<StateKey, u64>,
/// Incremented by one at the start of every `update_merkle_root` call and passed along with
/// every controller row that is appended.
mrupdate_id: Felt,
/// Set by `finalize_trace`; once set, `trace_len` reports the trace's own length instead of
/// an estimate.
finalized: bool,
}
impl Hasher {
/// Returns the length of the hasher trace.
///
/// Before finalization this is the value computed by `estimate_trace_len`; afterwards it is
/// the length reported by the trace itself.
pub(super) fn trace_len(&self) -> usize {
    if !self.finalized {
        return self.estimate_trace_len();
    }
    self.trace.trace_len()
}
/// Returns `(controller_len, perm_len)`.
///
/// `controller_len` is the current trace length rounded up to a multiple of `HASH_CYCLE_LEN`;
/// `perm_len` is `HASH_CYCLE_LEN` rows for every distinct state in the permutation request map.
///
/// Must be called before finalization (checked with a debug assertion).
pub(super) fn region_lengths(&self) -> (usize, usize) {
    debug_assert!(!self.finalized, "region_lengths must be called before finalization");
    let rows_so_far = self.trace.trace_len();
    let distinct_requests = self.perm_request_map.len();
    (
        rows_so_far.next_multiple_of(HASH_CYCLE_LEN),
        distinct_requests * HASH_CYCLE_LEN,
    )
}
/// Estimates the finalized trace length as the sum of the two lengths returned by
/// `region_lengths`.
fn estimate_trace_len(&self) -> usize {
    let lengths = self.region_lengths();
    lengths.0 + lengths.1
}
/// Applies a single permutation to `state` and records it in the trace.
///
/// Returns the row address at which the computation starts together with the permuted state.
pub fn permute(&mut self, state: HasherState) -> (Felt, HasherState) {
    let start_addr = self.trace.next_row_addr();

    // A standalone permutation: both rows are flagged as boundary rows, and node indices and
    // direction bits are zero.
    let (node_index, boundary, direction_bit) = (ZERO, ONE, ZERO);
    let output_state = self.append_controller_permutation(
        LINEAR_HASH,
        RETURN_STATE,
        state,
        node_index,
        node_index,
        boundary,
        boundary,
        direction_bit,
        direction_bit,
    );

    (start_addr, output_state)
}
/// Records the hashing of a control block built from `h1`, `h2` and `domain`.
///
/// If a row range is already memoized under `expected_hash`, that range is replayed and its
/// result returned. Otherwise one permutation is appended, its rows are memoized under
/// `expected_hash`, and the digest of the permuted state is returned.
///
/// Returns the starting row address and the resulting digest.
pub fn hash_control_block(
    &mut self,
    h1: Digest,
    h2: Digest,
    domain: Felt,
    expected_hash: Digest,
) -> (Felt, Digest) {
    match self.replay_memoized_trace(expected_hash) {
        Some(replayed) => replayed,
        None => {
            let start_addr = self.trace.next_row_addr();
            let output_state = self.append_controller_permutation(
                LINEAR_HASH,
                RETURN_HASH,
                init_state_from_words_with_domain(&h1, &h2, domain),
                ZERO,
                ZERO,
                ONE,
                ONE,
                ZERO,
                ZERO,
            );
            self.insert_to_memoized_trace_map(start_addr, expected_hash);
            (start_addr, get_digest(&output_state))
        },
    }
}
/// Records the hashing of a basic block made of `op_batches`.
///
/// If a row range is already memoized under `expected_hash`, that range is replayed and its
/// result returned. Otherwise one permutation is appended per batch: the state is initialized
/// from the first batch's groups, and every later batch's groups overwrite the first `RATE_LEN`
/// state elements before the next permutation. The appended rows are memoized under
/// `expected_hash`.
///
/// Returns the starting row address and the digest of the final state.
///
/// # Panics
/// Panics if `op_batches` is empty and nothing is memoized for `expected_hash`.
pub fn hash_basic_block(
    &mut self,
    op_batches: &[OpBatch],
    expected_hash: Digest,
) -> (Felt, Digest) {
    if let Some(replayed) = self.replay_memoized_trace(expected_hash) {
        return replayed;
    }

    let start_addr = self.trace.next_row_addr();
    let mut state = init_state(op_batches[0].groups(), ZERO);
    let last_idx = op_batches.len() - 1;

    for (idx, batch) in op_batches.iter().enumerate() {
        let is_first = idx == 0;
        let is_last = idx == last_idx;

        // The first batch is already part of the initial state; later batches are absorbed.
        if !is_first {
            absorb_into_state(&mut state, batch.groups());
        }

        // Only the last permutation returns a hash, and only the first input row and the last
        // output row carry the boundary flag.
        state = self.append_controller_permutation(
            LINEAR_HASH,
            if is_last { RETURN_HASH } else { RETURN_STATE },
            state,
            ZERO,
            ZERO,
            if is_first { ONE } else { ZERO },
            if is_last { ONE } else { ZERO },
            ZERO,
            ZERO,
        );
    }

    self.insert_to_memoized_trace_map(start_addr, expected_hash);
    (start_addr, get_digest(&state))
}
/// Records the computation of a Merkle root from `value`, its authentication `path` and the
/// node `index`, using the `MpVerify` context.
///
/// Returns the starting row address and the computed root.
///
/// # Panics
/// Panics if `path` is empty or if `index` has bits set at or above the path length.
pub fn build_merkle_root(
    &mut self,
    value: Digest,
    path: &MerklePath,
    index: Felt,
) -> (Felt, Digest) {
    let start_addr = self.trace.next_row_addr();
    let leaf_index = index.as_canonical_u64();
    let root = self.verify_merkle_path(value, path, leaf_index, MerklePathContext::MpVerify);
    (start_addr, root)
}
/// Records a Merkle root update: the root is computed along `path` first for `old_value` and
/// then for `new_value`, both at the same `index`.
///
/// The `mrupdate_id` counter is incremented before any rows are appended, so all rows of this
/// update are appended under the new id.
///
/// Returns the starting row address together with the old and new roots.
///
/// # Panics
/// Panics if `path` is empty or if `index` has bits set at or above the path length.
pub fn update_merkle_root(
    &mut self,
    old_value: Digest,
    new_value: Digest,
    path: &MerklePath,
    index: Felt,
) -> MerkleRootUpdate {
    self.mrupdate_id += ONE;

    let address = self.trace.next_row_addr();
    let leaf_index = index.as_canonical_u64();

    let old_root = self.verify_merkle_path(
        old_value,
        path,
        leaf_index,
        MerklePathContext::MrUpdateOld,
    );
    let new_root = self.verify_merkle_path(
        new_value,
        path,
        leaf_index,
        MerklePathContext::MrUpdateNew,
    );

    MerkleRootUpdate { address, old_root, new_root }
}
/// Consumes the hasher and writes its trace into `trace`.
///
/// If the trace has not been finalized yet, it is finalized first; in debug builds the length
/// estimated beforehand is checked against the actual finalized length.
pub(super) fn fill_trace(mut self, trace: &mut TraceFragment) {
    let needs_finalization = !self.finalized;
    if needs_finalization {
        // The estimate has to be taken before finalization changes the trace.
        let predicted_len = self.estimate_trace_len();
        self.finalize_trace();
        debug_assert_eq!(
            predicted_len,
            self.trace.trace_len(),
            "hasher trace length estimate ({}) diverged from actual ({})",
            predicted_len,
            self.trace.trace_len(),
        );
    }

    self.trace.fill_trace(trace);
}
/// Finalizes the trace; does nothing if it has already been finalized.
///
/// Pads the trace to a cycle boundary, then appends one permutation cycle per distinct
/// requested input state, using the recorded request count as the multiplicity. The request
/// map is emptied in the process and, being a `BTreeMap`, is visited in ascending key order.
fn finalize_trace(&mut self) {
    if !self.finalized {
        self.trace.pad_to_cycle_boundary(self.mrupdate_id);

        let pending_requests = core::mem::take(&mut self.perm_request_map);
        for (state_key, multiplicity) in pending_requests {
            self.trace.append_permutation_cycle(
                &key_to_state(&state_key),
                Felt::new_unchecked(multiplicity),
            );
        }

        self.finalized = true;
    }
}
/// Appends the two controller rows for one permutation and records the permutation request.
///
/// The first row carries `init_selectors`, the input `state` and the input-side values; the
/// second row carries `final_selectors`, the permuted state and the output-side values. Both
/// rows are appended with the current `mrupdate_id`.
///
/// Returns the permuted state.
fn append_controller_permutation(
    &mut self,
    init_selectors: Selectors,
    final_selectors: Selectors,
    state: HasherState,
    input_node_index: Felt,
    output_node_index: Felt,
    is_boundary_input: Felt,
    is_boundary_output: Felt,
    input_direction_bit: Felt,
    output_direction_bit: Felt,
) -> HasherState {
    // The request is keyed by the input state and only touches the request map, so it can be
    // recorded up front without affecting the rows appended below.
    self.record_perm_request(&state);

    let mrupdate_id = self.mrupdate_id;
    self.trace.append_controller_row(
        init_selectors,
        &state,
        input_node_index,
        mrupdate_id,
        is_boundary_input,
        input_direction_bit,
    );

    let output_state = {
        let mut working = state;
        apply_permutation(&mut working);
        working
    };
    self.trace.append_controller_row(
        final_selectors,
        &output_state,
        output_node_index,
        mrupdate_id,
        is_boundary_output,
        output_direction_bit,
    );

    output_state
}
/// Computes the root reached by hashing `value` with each sibling in `path`, appending one
/// permutation per path element using the selectors chosen by `context`.
///
/// At each level the lowest bit of `index` decides the operand order of the merge, and `index`
/// is shifted right by one before moving to the next level. The first input row and the last
/// output row carry the boundary flag; only the last permutation uses `RETURN_HASH`.
///
/// # Panics
/// Panics if `path` is empty or if `index` has bits set at or above the path length.
fn verify_merkle_path(
    &mut self,
    value: Digest,
    path: &MerklePath,
    mut index: u64,
    context: MerklePathContext,
) -> Digest {
    assert!(!path.is_empty(), "path is empty");
    assert!(
        index.checked_shr(path.len() as u32).unwrap_or(0) == 0,
        "invalid index for the path"
    );

    let selectors = context.main_selectors();
    let last_level = path.len() - 1;
    let mut node = value;

    for (level, &sibling) in path.iter().enumerate() {
        let at_first_level = level == 0;
        let at_last_level = level == last_level;

        let parent_index = index >> 1;
        let direction_bit = index & 1;
        // The output row of the last level carries a zero direction bit.
        let next_direction_bit = if at_last_level { 0 } else { parent_index & 1 };

        let output_state = self.append_controller_permutation(
            selectors,
            if at_last_level { RETURN_HASH } else { RETURN_STATE },
            build_merge_state(&node, &sibling, direction_bit),
            Felt::new_unchecked(index),
            Felt::new_unchecked(parent_index),
            if at_first_level { ONE } else { ZERO },
            if at_last_level { ONE } else { ZERO },
            Felt::new_unchecked(direction_bit),
            Felt::new_unchecked(next_direction_bit),
        );

        node = get_digest(&output_state);
        index = parent_index;
    }

    node
}
/// Increments the request count stored for `state`, starting at one for a state not seen
/// before.
fn record_perm_request(&mut self, state: &HasherState) {
    self.perm_request_map
        .entry(state_to_key(state))
        .and_modify(|count| *count += 1)
        .or_insert(1);
}
/// Replays the row range memoized under `expected_hash`, if there is one.
///
/// Returns `None` when nothing is memoized for the hash. Otherwise the memoized range is copied
/// via `copy_trace`, the `mrupdate_id` of the rows added by that copy is overwritten with the
/// current id, and a permutation request is recorded for every input state reported for those
/// rows. The result is the row address before the copy and the digest of the state buffer
/// passed to `copy_trace`.
// NOTE(review): presumably `copy_trace` appends the copied rows and leaves the final state in
// the buffer — confirm against `HasherTrace`.
fn replay_memoized_trace(&mut self, expected_hash: Digest) -> Option<(Felt, Digest)> {
    let (start_row, end_row) = *self.get_memoized_trace(expected_hash)?;

    let replay_addr = self.trace.next_row_addr();
    let first_new_row = self.trace.trace_len();

    let mut final_state = [ZERO; STATE_WIDTH];
    self.trace.copy_trace(&mut final_state, start_row..end_row);
    let end_new_row = self.trace.trace_len();

    self.trace
        .overwrite_mrupdate_id_in_range(first_new_row..end_new_row, self.mrupdate_id);

    let replayed_inputs = self.trace.input_states_in_range(first_new_row..end_new_row);
    for replayed_input in replayed_inputs {
        self.record_perm_request(&replayed_input);
    }

    Some((replay_addr, get_digest(&final_state)))
}
/// Looks up the `(start_row, end_row)` pair memoized for `hash`, if any.
fn get_memoized_trace(&self, hash: Digest) -> Option<&(usize, usize)> {
    let key = digest_to_key(hash);
    self.memoized_trace_map.get(&key)
}
/// Memoizes, under `hash`, the rows from the one addressed by `addr` up to the row the trace
/// would append next.
///
/// Both bounds are stored as the corresponding address minus one.
fn insert_to_memoized_trace_map(&mut self, addr: Felt, hash: Digest) {
    let row_of = |row_addr: Felt| row_addr.as_canonical_u64() as usize - 1;

    let start_row = row_of(addr);
    let end_row = row_of(self.trace.next_row_addr());

    self.memoized_trace_map.insert(digest_to_key(hash), (start_row, end_row));
}
}
/// Identifies why a Merkle path is being processed; `main_selectors` maps each variant to the
/// selectors used for the corresponding rows.
#[derive(Debug, Clone, Copy)]
enum MerklePathContext {
/// Used by `build_merkle_root`; maps to `MP_VERIFY`.
MpVerify,
/// Used by `update_merkle_root` for the old value; maps to `MR_UPDATE_OLD`.
MrUpdateOld,
/// Used by `update_merkle_root` for the new value; maps to `MR_UPDATE_NEW`.
MrUpdateNew,
}
impl MerklePathContext {
pub fn main_selectors(self) -> Selectors {
match self {
Self::MpVerify => MP_VERIFY,
Self::MrUpdateOld => MR_UPDATE_OLD,
Self::MrUpdateNew => MR_UPDATE_NEW,
}
}
}
/// Builds the hasher state for merging `a` and `b`.
///
/// With `index_bit == 0` the state is built from `(a, b)`; with `index_bit == 1` the operands
/// are swapped to `(b, a)`.
///
/// # Panics
/// Panics if `index_bit` is neither 0 nor 1.
#[inline(always)]
fn build_merge_state(a: &Digest, b: &Digest, index_bit: u64) -> HasherState {
    let (first, second) = match index_bit {
        0 => (a, b),
        1 => (b, a),
        _ => panic!("index bit is not a binary value"),
    };
    init_state_from_words(first, second)
}
/// Builds a hasher state whose first `RATE_LEN` elements are `init_values`, whose element at
/// position `RATE_LEN` (the first capacity element) is `padding_flag`, and whose remaining
/// elements are zero.
///
/// In debug builds, asserts that `padding_flag` is either zero or one.
#[inline(always)]
pub fn init_state(init_values: &[Felt; RATE_LEN], padding_flag: Felt) -> [Felt; STATE_WIDTH] {
    debug_assert!(
        [ZERO, ONE].contains(&padding_flag),
        "first capacity element must be 0 or 1"
    );

    let mut state = [ZERO; STATE_WIDTH];
    let (rate, capacity) = state.split_at_mut(RATE_LEN);
    rate.copy_from_slice(init_values);
    capacity[0] = padding_flag;
    state
}
/// Builds a hasher state from two words with a zero domain value.
///
/// Equivalent to calling `init_state_from_words_with_domain` with `ZERO` as the domain.
#[inline(always)]
pub fn init_state_from_words(w1: &Digest, w2: &Digest) -> [Felt; STATE_WIDTH] {
    let no_domain = ZERO;
    init_state_from_words_with_domain(w1, w2, no_domain)
}
/// Builds a hasher state from two words and a domain value.
///
/// Positions 0..=3 hold the elements of `w1`, positions 4..=7 hold the elements of `w2`,
/// position 9 holds `domain`, and every other position is zero.
#[inline(always)]
pub fn init_state_from_words_with_domain(
    w1: &Digest,
    w2: &Digest,
    domain: Felt,
) -> [Felt; STATE_WIDTH] {
    core::array::from_fn(|pos| match pos {
        0..=3 => w1[pos],
        4..=7 => w2[pos - 4],
        9 => domain,
        _ => ZERO,
    })
}
/// Overwrites the first `RATE_LEN` elements of `state` with `values`, leaving the remaining
/// elements untouched.
#[inline(always)]
pub fn absorb_into_state(state: &mut [Felt; STATE_WIDTH], values: &[Felt; RATE_LEN]) {
    let (rate, _rest) = state.split_at_mut(RATE_LEN);
    rate.copy_from_slice(values);
}
/// Extracts the digest from a hasher state by converting the elements in `DIGEST_RANGE`.
///
/// # Panics
/// Panics if the selected elements cannot be converted into a digest.
pub fn get_digest(state: &[Felt; STATE_WIDTH]) -> Digest {
    let digest_elements = &state[DIGEST_RANGE];
    digest_elements
        .try_into()
        .expect("failed to get digest from hasher state")
}