use super::{
Felt, HasherState, MerklePath, MerkleRootUpdate, OpBatch, TraceFragment, Word, ONE, ZERO,
};
use alloc::collections::BTreeMap;
use miden_air::trace::chiplets::hasher::{
Digest, Selectors, DIGEST_LEN, DIGEST_RANGE, LINEAR_HASH, MP_VERIFY, MR_UPDATE_NEW,
MR_UPDATE_OLD, RATE_LEN, RETURN_HASH, RETURN_STATE, STATE_WIDTH, TRACE_WIDTH,
};
mod trace;
use trace::HasherTrace;
#[cfg(test)]
mod tests;
/// Builds the hasher execution trace and caches the trace rows produced for previously
/// hashed blocks, so that re-hashing a block with the same expected digest can copy rows
/// instead of recomputing them.
#[derive(Default)]
pub struct Hasher {
// Trace rows appended so far; all hashing methods append to (or copy within) this trace.
trace: HasherTrace,
// Maps a digest (converted to 32 bytes) to the `(start_row, end_row)` range of trace rows
// recorded when that digest was first computed; `end_row` is exclusive, since the pair is
// used as `start_row..end_row` when copying.
memoized_trace_map: BTreeMap<[u8; 32], (usize, usize)>,
}
impl Hasher {
    /// Returns the length of the hasher execution trace built so far, as reported by the
    /// underlying `HasherTrace`.
    pub(super) fn trace_len(&self) -> usize {
        self.trace.trace_len()
    }

    /// Appends a single permutation of `state` to the trace.
    ///
    /// The permutation is recorded with `LINEAR_HASH` selectors on its first row and
    /// `RETURN_STATE` selectors on its last row. `state` is passed by mutable reference to
    /// `HasherTrace::append_permutation`, and its updated value is returned.
    ///
    /// Returns the trace address at which the permutation starts, and the resulting state.
    pub(super) fn permute(&mut self, mut state: HasherState) -> (Felt, HasherState) {
        let addr = self.trace.next_row_addr();
        self.trace.append_permutation(&mut state, LINEAR_HASH, RETURN_STATE);
        (addr, state)
    }

    /// Merges the words `h1` and `h2` under `domain` and returns the resulting digest.
    ///
    /// If a computation with the same `expected_hash` was recorded before, its trace rows
    /// are copied. Otherwise a new permutation (`LINEAR_HASH` -> `RETURN_HASH`) is appended
    /// and memoized under `expected_hash`.
    ///
    /// Returns the trace address of the first row used, and the digest read from the state.
    pub(super) fn hash_control_block(
        &mut self,
        h1: Word,
        h2: Word,
        domain: Felt,
        expected_hash: Digest,
    ) -> (Felt, Word) {
        let addr = self.trace.next_row_addr();
        let mut state = init_state_from_words_with_domain(&h1, &h2, domain);

        match self.get_memoized_trace(expected_hash).copied() {
            // Same expected hash seen before: replay the recorded rows instead of recomputing.
            Some((start_row, end_row)) => {
                self.trace.copy_trace(&mut state, start_row..end_row);
            }
            None => {
                self.trace.append_permutation(&mut state, LINEAR_HASH, RETURN_HASH);
                self.insert_to_memoized_trace_map(addr, expected_hash);
            }
        }

        (addr, get_digest(&state))
    }

    /// Hashes the operation batches of a span block and returns the resulting digest.
    ///
    /// The first batch initializes the state. Every further batch is absorbed into the
    /// state before one more permutation. Selector flags per permutation (first row -> last
    /// row):
    /// - a lone batch: `START` -> `RETURN`;
    /// - first of several: `START` -> `ABSORB`;
    /// - middle batches: `CONTINUE` -> `ABSORB`;
    /// - last of several: `CONTINUE` -> `RETURN`.
    ///
    /// If a computation with the same `expected_hash` was recorded before, its rows are
    /// copied instead, and the new trace is not rebuilt.
    ///
    /// Returns the trace address of the first row used, and the digest read from the state.
    ///
    /// # Panics
    /// Panics if `op_batches` is empty.
    pub(super) fn hash_span_block(
        &mut self,
        op_batches: &[OpBatch],
        expected_hash: Digest,
    ) -> (Felt, Word) {
        const START: Selectors = LINEAR_HASH;
        const RETURN: Selectors = RETURN_HASH;
        // Same flags as LINEAR_HASH, but always passed as the *final* (last-row) selectors.
        const ABSORB: Selectors = LINEAR_HASH;
        // LINEAR_HASH with its first flag cleared, used on the first row of every
        // permutation after the first one.
        const CONTINUE: Selectors = [ZERO, LINEAR_HASH[1], LINEAR_HASH[2]];

        let addr = self.trace.next_row_addr();
        let (first_batch, remaining_batches) = op_batches
            .split_first()
            .expect("span block must contain at least one operation batch");
        let mut state = init_state(first_batch.groups(), ZERO);

        match self.get_memoized_trace(expected_hash).copied() {
            Some((start_row, end_row)) => {
                self.trace.copy_trace(&mut state, start_row..end_row);
            }
            None => {
                match remaining_batches.split_last() {
                    // A single batch needs just one permutation, which returns the hash.
                    None => {
                        self.trace.append_permutation(&mut state, START, RETURN);
                    }
                    Some((last_batch, middle_batches)) => {
                        self.trace.append_permutation(&mut state, START, ABSORB);
                        for batch in middle_batches {
                            absorb_into_state(&mut state, batch.groups());
                            self.trace.append_permutation(&mut state, CONTINUE, ABSORB);
                        }
                        absorb_into_state(&mut state, last_batch.groups());
                        self.trace.append_permutation(&mut state, CONTINUE, RETURN);
                    }
                }
                self.insert_to_memoized_trace_map(addr, expected_hash);
            }
        }

        (addr, get_digest(&state))
    }

    /// Computes the Merkle root reached from node `value` at position `index` along `path`,
    /// recording the computation with `MP_VERIFY` selectors.
    ///
    /// Returns the trace address of the first row used, and the computed root.
    ///
    /// # Panics
    /// Panics if `path` is empty, or if `index` has bits set at or above position
    /// `path.len()`.
    pub(super) fn build_merkle_root(
        &mut self,
        value: Word,
        path: &MerklePath,
        index: Felt,
    ) -> (Felt, Word) {
        let addr = self.trace.next_row_addr();
        let root =
            self.verify_merkle_path(value, path, index.as_int(), MerklePathContext::MpVerify);
        (addr, root)
    }

    /// Computes the Merkle roots for `old_value` and then `new_value`, both at position
    /// `index` along the same `path`.
    ///
    /// The two computations are recorded with `MR_UPDATE_OLD` and `MR_UPDATE_NEW` selectors
    /// respectively.
    ///
    /// Returns the trace address of the first row used, together with both roots.
    ///
    /// # Panics
    /// Panics under the same conditions as [`Self::build_merkle_root`].
    pub(super) fn update_merkle_root(
        &mut self,
        old_value: Word,
        new_value: Word,
        path: &MerklePath,
        index: Felt,
    ) -> MerkleRootUpdate {
        let address = self.trace.next_row_addr();
        let index = index.as_int();
        let old_root =
            self.verify_merkle_path(old_value, path, index, MerklePathContext::MrUpdateOld);
        let new_root =
            self.verify_merkle_path(new_value, path, index, MerklePathContext::MrUpdateNew);
        MerkleRootUpdate {
            address,
            old_root,
            new_root,
        }
    }

    /// Consumes the hasher and writes its accumulated trace into `trace`.
    pub(super) fn fill_trace(self, trace: &mut TraceFragment) {
        self.trace.fill_trace(trace)
    }

    /// Walks `path` from `value` upward, appending one permutation per sibling, and returns
    /// the computed root.
    ///
    /// The first leg starts with `context.main_selectors()`, and later legs with
    /// `context.part_selectors()`. Every leg ends with the main selectors, except the last
    /// one, which ends with `RETURN_HASH`.
    ///
    /// # Panics
    /// Panics if `path` is empty, or if `index` does not fit in `path.len()` bits.
    fn verify_merkle_path(
        &mut self,
        value: Word,
        path: &MerklePath,
        mut index: u64,
        context: MerklePathContext,
    ) -> Word {
        assert!(!path.is_empty(), "path is empty");
        // `checked_shr` yields `None` for shifts of 64+ bits. Every u64 index fits such a
        // deep path, so that case counts as "no excess bits".
        assert!(
            index.checked_shr(path.len() as u32).unwrap_or(0) == 0,
            "invalid index for the path"
        );

        let main_selectors = context.main_selectors();
        let part_selectors = context.part_selectors();
        let last_depth = path.len() - 1;

        let mut root = value;
        for (depth, &sibling) in path[..].iter().enumerate() {
            let init_selectors = if depth == 0 { main_selectors } else { part_selectors };
            let final_selectors = if depth == last_depth { RETURN_HASH } else { main_selectors };
            root = self.verify_mp_leg(root, sibling, &mut index, init_selectors, final_selectors);
        }
        root
    }

    /// Appends one permutation that merges `root` with `sibling`, and returns the resulting
    /// digest.
    ///
    /// The least significant bit of `*index` selects the input order (see
    /// `build_merge_state`). On return, `*index` has been shifted right by one bit.
    fn verify_mp_leg(
        &mut self,
        root: Word,
        sibling: Digest,
        index: &mut u64,
        init_selectors: Selectors,
        final_selectors: Selectors,
    ) -> Word {
        let index_bit = *index & 1;
        let mut state = build_merge_state(&root, &sibling, index_bit);

        // The first row records the unshifted index only when `init_selectors[0]` is
        // non-zero. Otherwise, and on every remaining row, the shifted index is recorded.
        let parent_index = *index >> 1;
        let init_index = if init_selectors[0] == ZERO { parent_index } else { *index };
        self.trace.append_permutation_with_index(
            &mut state,
            init_selectors,
            final_selectors,
            Felt::new(init_index),
            Felt::new(parent_index),
        );

        *index = parent_index;
        get_digest(&state)
    }

    /// Returns the `(start_row, end_row)` range previously recorded for `hash`, if any.
    fn get_memoized_trace(&self, hash: Digest) -> Option<&(usize, usize)> {
        let key: [u8; 32] = hash.into();
        self.memoized_trace_map.get(&key)
    }

    /// Records the rows appended since `addr` as the trace for `hash`.
    ///
    /// Both bounds are converted from row addresses to row indices by subtracting 1. The end
    /// bound is exclusive, matching its later use as `start_row..end_row`.
    fn insert_to_memoized_trace_map(&mut self, addr: Felt, hash: Digest) {
        let key: [u8; 32] = hash.into();
        let start_row = addr.as_int() as usize - 1;
        let end_row = self.trace.next_row_addr().as_int() as usize - 1;
        self.memoized_trace_map.insert(key, (start_row, end_row));
    }
}
/// Identifies which Merkle-path operation a path computation belongs to. This determines
/// the selector flags used for its trace rows (see `main_selectors` / `part_selectors`).
#[derive(Debug, Clone, Copy)]
enum MerklePathContext {
// Plain path verification; uses `MP_VERIFY` selectors.
MpVerify,
// Root computation for the value being replaced; uses `MR_UPDATE_OLD` selectors.
MrUpdateOld,
// Root computation for the replacement value; uses `MR_UPDATE_NEW` selectors.
MrUpdateNew,
}
impl MerklePathContext {
pub fn main_selectors(&self) -> Selectors {
match self {
Self::MpVerify => MP_VERIFY,
Self::MrUpdateOld => MR_UPDATE_OLD,
Self::MrUpdateNew => MR_UPDATE_NEW,
}
}
pub fn part_selectors(&self) -> Selectors {
let selectors = self.main_selectors();
[ZERO, selectors[1], selectors[2]]
}
}
/// Builds the hasher state for merging node `a` with its sibling `b`.
///
/// `index_bit` selects the order of the two words: `0` puts `a` first, `1` puts `b` first.
///
/// # Panics
/// Panics if `index_bit` is neither 0 nor 1.
#[inline(always)]
fn build_merge_state(a: &Word, b: &Word, index_bit: u64) -> HasherState {
    let (left, right) = match index_bit {
        0 => (a, b),
        1 => (b, a),
        _ => panic!("index bit is not a binary value"),
    };
    init_state_from_words(left, right)
}
/// Builds an initial hasher state from the rate elements `init_values`.
///
/// Element 0 is `padding_flag`, elements 1–3 are `ZERO`, and the remaining elements are
/// `init_values` in order.
///
/// In debug builds, panics if `padding_flag` is neither `ZERO` nor `ONE`.
#[inline(always)]
pub fn init_state(init_values: &[Felt; RATE_LEN], padding_flag: Felt) -> [Felt; STATE_WIDTH] {
    debug_assert!(
        [ZERO, ONE].contains(&padding_flag),
        "first capacity element must be 0 or 1"
    );
    let [v0, v1, v2, v3, v4, v5, v6, v7] = *init_values;
    [padding_flag, ZERO, ZERO, ZERO, v0, v1, v2, v3, v4, v5, v6, v7]
}
/// Builds a hasher state from two words with the domain element left as `ZERO`.
///
/// Shorthand for `init_state_from_words_with_domain(w1, w2, ZERO)`.
#[inline(always)]
pub fn init_state_from_words(w1: &Word, w2: &Word) -> [Felt; STATE_WIDTH] {
    let no_domain = ZERO;
    init_state_from_words_with_domain(w1, w2, no_domain)
}
/// Builds a hasher state whose rate holds `w1` followed by `w2`.
///
/// Element 1 is `domain`; elements 0, 2 and 3 are `ZERO`.
#[inline(always)]
pub fn init_state_from_words_with_domain(
    w1: &Word,
    w2: &Word,
    domain: Felt,
) -> [Felt; STATE_WIDTH] {
    let [a0, a1, a2, a3] = *w1;
    let [b0, b1, b2, b3] = *w2;
    [ZERO, domain, ZERO, ZERO, a0, a1, a2, a3, b0, b1, b2, b3]
}
/// Overwrites the rate portion of `state` (indices 4–11) with `values`, leaving indices
/// 0–3 untouched.
#[inline(always)]
pub fn absorb_into_state(state: &mut [Felt; STATE_WIDTH], values: &[Felt; RATE_LEN]) {
    const RATE_START: usize = 4;
    state[RATE_START..RATE_START + RATE_LEN].copy_from_slice(values);
}
/// Returns the elements of `state` in `DIGEST_RANGE` as a fixed-size digest.
///
/// # Panics
/// Panics if `DIGEST_RANGE` does not span exactly `DIGEST_LEN` elements.
pub fn get_digest(state: &[Felt; STATE_WIDTH]) -> [Felt; DIGEST_LEN] {
    let digest_elements = &state[DIGEST_RANGE];
    digest_elements
        .try_into()
        .expect("failed to get digest from hasher state")
}