use super::{
Felt, FieldElement, HasherState, LookupTableRow, OpBatch, StarkField, TraceFragment, Vec, Word,
ZERO,
};
use vm_core::{
chiplets::hasher::{
absorb_into_state, get_digest, init_state, init_state_from_words, Digest, Selectors,
HASH_CYCLE_LEN, LINEAR_HASH, LINEAR_HASH_LABEL, MP_VERIFY, MP_VERIFY_LABEL, MR_UPDATE_NEW,
MR_UPDATE_NEW_LABEL, MR_UPDATE_OLD, MR_UPDATE_OLD_LABEL, RETURN_HASH, RETURN_HASH_LABEL,
RETURN_STATE, RETURN_STATE_LABEL, STATE_WIDTH, TRACE_WIDTH,
},
utils::collections::BTreeMap,
};
mod lookups;
pub use lookups::HasherLookup;
use lookups::HasherLookupContext;
mod trace;
use trace::HasherTrace;
mod aux_trace;
pub use aux_trace::{AuxTraceBuilder, SiblingTableRow, SiblingTableUpdate};
#[cfg(test)]
mod tests;
/// Hasher chiplet state: the main hasher execution trace, the builder for its auxiliary
/// (sibling table) trace, and a cache of already-built traces keyed by block digest.
#[derive(Default)]
pub struct Hasher {
/// Main execution trace of the hasher; permutations are appended to (or copied within) it.
trace: HasherTrace,
/// Builder for the auxiliary trace; receives sibling additions/removals recorded while
/// executing Merkle root updates.
aux_trace: AuxTraceBuilder,
/// Maps a block digest (converted to 32 bytes) to the `(start_row, end_row)` range of the
/// trace previously built for it, so a repeated block can be copied instead of recomputed.
memoized_trace_map: BTreeMap<[u8; 32], (usize, usize)>,
}
impl Hasher {
    // STATE ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the number of rows currently in the hasher execution trace.
    pub(super) fn trace_len(&self) -> usize {
        self.trace.trace_len()
    }

    // LOOKUP HELPERS
    // --------------------------------------------------------------------------------------------

    /// Builds a [HasherLookup] with the given label, node index and context.
    ///
    /// For `Start` lookups the address is `next_row_addr()`, because callers record the start
    /// lookup before appending the permutation rows. For every other context the address is the
    /// current trace length, i.e. the lookup refers to rows already present in the trace.
    #[inline(always)]
    fn get_lookup(&self, label: u8, index: Felt, context: HasherLookupContext) -> HasherLookup {
        let addr = match context {
            HasherLookupContext::Start => self.trace.next_row_addr().as_int() as u32,
            _ => self.trace_len() as u32,
        };
        HasherLookup::new(label, addr, index, context)
    }

    // HASHING METHODS
    // --------------------------------------------------------------------------------------------

    /// Runs `state` through a single permutation cycle (LINEAR_HASH -> RETURN_STATE).
    ///
    /// Pushes a start lookup and a return lookup onto `lookups`. Returns the address of the first
    /// row of the permutation together with the resulting state.
    pub(super) fn permute(
        &mut self,
        mut state: HasherState,
        lookups: &mut Vec<HasherLookup>,
    ) -> (Felt, HasherState) {
        let addr = self.trace.next_row_addr();

        // the start lookup is recorded before the permutation rows are appended
        let lookup = self.get_lookup(LINEAR_HASH_LABEL, ZERO, HasherLookupContext::Start);
        lookups.push(lookup);

        self.trace.append_permutation(&mut state, LINEAR_HASH, RETURN_STATE);

        let lookup = self.get_lookup(RETURN_STATE_LABEL, ZERO, HasherLookupContext::Return);
        lookups.push(lookup);

        (addr, state)
    }

    /// Computes the hash of a control block from the two words `h1` and `h2`.
    ///
    /// If a trace for `expected_hash` has already been memoized, its rows are copied instead of
    /// being rebuilt. Otherwise one permutation (LINEAR_HASH -> RETURN_HASH) is appended and its
    /// row range is memoized under `expected_hash`. Returns the address of the first row and the
    /// resulting digest.
    pub(super) fn hash_control_block(
        &mut self,
        h1: Word,
        h2: Word,
        expected_hash: Digest,
        lookups: &mut Vec<HasherLookup>,
    ) -> (Felt, Word) {
        let addr = self.trace.next_row_addr();
        let mut state = init_state_from_words(&h1, &h2);

        let lookup = self.get_lookup(LINEAR_HASH_LABEL, ZERO, HasherLookupContext::Start);
        lookups.push(lookup);

        if let Some((start_row, end_row)) = self.get_memoized_trace(expected_hash) {
            // reuse the trace of a previously hashed block with the same digest
            self.trace.copy_trace(&mut state, *start_row..*end_row);
        } else {
            self.trace.append_permutation(&mut state, LINEAR_HASH, RETURN_HASH);
            self.insert_to_memoized_trace_map(addr, expected_hash);
        }

        let lookup = self.get_lookup(RETURN_HASH_LABEL, ZERO, HasherLookupContext::Return);
        lookups.push(lookup);

        let result = get_digest(&state);
        (addr, result)
    }

    /// Computes the hash of a span block from its operation batches.
    ///
    /// The first batch (together with `num_op_groups`) initializes the state. Each following
    /// batch is absorbed and then processed in its own permutation cycle. If a trace for
    /// `expected_hash` is already memoized, its rows are copied in one go and the absorb lookups
    /// are synthesized at the addresses the rebuilt trace would have used.
    ///
    /// Returns the address of the first row and the resulting digest.
    ///
    /// # Panics
    /// Panics if `op_batches` is empty (it indexes `op_batches[0]`).
    pub(super) fn hash_span_block(
        &mut self,
        op_batches: &[OpBatch],
        num_op_groups: usize,
        expected_hash: Digest,
        lookups: &mut Vec<HasherLookup>,
    ) -> (Felt, Word) {
        const START: Selectors = LINEAR_HASH;
        const START_LABEL: u8 = LINEAR_HASH_LABEL;
        const RETURN: Selectors = RETURN_HASH;
        const RETURN_LABEL: u8 = RETURN_HASH_LABEL;
        // absorb selectors share their values (and label) with the linear hash selectors
        const ABSORB: Selectors = LINEAR_HASH;
        const ABSORB_LABEL: u8 = LINEAR_HASH_LABEL;
        // continuing a linear hash keeps the 2nd and 3rd flags and clears the 1st one
        const CONTINUE: Selectors = [ZERO, LINEAR_HASH[1], LINEAR_HASH[2]];

        let addr = self.trace.next_row_addr();
        let mut state = init_state(op_batches[0].groups(), num_op_groups);

        let lookup = self.get_lookup(START_LABEL, ZERO, HasherLookupContext::Start);
        lookups.push(lookup);

        let (start_row, end_row, is_memoized) =
            if let Some((start_row, end_row)) = self.get_memoized_trace(expected_hash) {
                (*start_row, *end_row, true)
            } else {
                (0, 0, false)
            };

        let num_batches = op_batches.len();

        if !is_memoized {
            if num_batches == 1 {
                // a single batch needs exactly one permutation
                self.trace.append_permutation(&mut state, START, RETURN);
            } else {
                // first batch: START -> ABSORB; middle batches: CONTINUE -> ABSORB;
                // last batch: CONTINUE -> RETURN
                self.trace.append_permutation(&mut state, START, ABSORB);

                for batch in op_batches.iter().take(num_batches - 1).skip(1) {
                    absorb_into_state(&mut state, batch.groups());
                    let lookup = self.get_lookup(ABSORB_LABEL, ZERO, HasherLookupContext::Absorb);
                    lookups.push(lookup);
                    self.trace.append_permutation(&mut state, CONTINUE, ABSORB);
                }

                absorb_into_state(&mut state, op_batches[num_batches - 1].groups());
                let lookup = self.get_lookup(ABSORB_LABEL, ZERO, HasherLookupContext::Absorb);
                lookups.push(lookup);
                self.trace.append_permutation(&mut state, CONTINUE, RETURN);
            }
            self.insert_to_memoized_trace_map(addr, expected_hash);
        } else if num_batches == 1 {
            self.trace.copy_trace(&mut state, start_row..end_row);
        } else {
            // the memoized rows are copied all at once, so the absorb lookups are recorded up
            // front at the address where each batch's cycle will end up (one cycle per batch)
            for i in 1..num_batches {
                let lookup_addr = self.trace_len() + i * HASH_CYCLE_LEN;
                let lookup = HasherLookup::new(
                    ABSORB_LABEL,
                    lookup_addr as u32,
                    ZERO,
                    HasherLookupContext::Absorb,
                );
                lookups.push(lookup);
            }
            self.trace.copy_trace(&mut state, start_row..end_row);
        }

        let lookup = self.get_lookup(RETURN_LABEL, ZERO, HasherLookupContext::Return);
        lookups.push(lookup);

        let result = get_digest(&state);
        (addr, result)
    }

    /// Computes the Merkle root for `value` located at `index` with authentication `path`,
    /// using the MP_VERIFY selectors.
    ///
    /// Returns the address of the first row of the computation and the computed root.
    ///
    /// # Panics
    /// Panics if `path` is empty or if `index` does not fit in `path.len()` bits.
    pub(super) fn build_merkle_root(
        &mut self,
        value: Word,
        path: &[Word],
        index: Felt,
        lookups: &mut Vec<HasherLookup>,
    ) -> (Felt, Word) {
        let addr = self.trace.next_row_addr();
        let root = self.verify_merkle_path(
            value,
            path,
            index.as_int(),
            MerklePathContext::MpVerify,
            lookups,
        );
        (addr, root)
    }

    /// Computes the Merkle roots for `old_value` and `new_value` at the same `index` and `path`.
    ///
    /// The old-root computation (MR_UPDATE_OLD) records each sibling as added to the sibling
    /// table; the new-root computation (MR_UPDATE_NEW) records each sibling as removed.
    /// Returns the address of the first row, the old root and the new root.
    ///
    /// # Panics
    /// Panics if `path` is empty or if `index` does not fit in `path.len()` bits.
    pub(super) fn update_merkle_root(
        &mut self,
        old_value: Word,
        new_value: Word,
        path: &[Word],
        index: Felt,
        lookups: &mut Vec<HasherLookup>,
    ) -> (Felt, Word, Word) {
        let addr = self.trace.next_row_addr();
        let index = index.as_int();

        let old_root = self.verify_merkle_path(
            old_value,
            path,
            index,
            MerklePathContext::MrUpdateOld,
            lookups,
        );
        let new_root = self.verify_merkle_path(
            new_value,
            path,
            index,
            MerklePathContext::MrUpdateNew,
            lookups,
        );

        (addr, old_root, new_root)
    }

    // TRACE GENERATION
    // --------------------------------------------------------------------------------------------

    /// Consumes the hasher, writes its main trace into `trace`, and returns the auxiliary trace
    /// builder.
    pub(super) fn fill_trace(self, trace: &mut TraceFragment) -> AuxTraceBuilder {
        self.trace.fill_trace(trace);
        self.aux_trace
    }

    // HELPER METHODS
    // --------------------------------------------------------------------------------------------

    /// Computes a Merkle root by hashing `value` with each sibling in `path`, one permutation
    /// cycle per sibling.
    ///
    /// Selector pattern per leg:
    /// - single-sibling path: main -> RETURN_HASH;
    /// - otherwise first leg main -> main, middle legs part -> main, last leg part -> RETURN_HASH.
    ///
    /// `depth` starts at `path.len() - 1` and decreases by one per leg; it is passed to the
    /// sibling-hint bookkeeping.
    ///
    /// # Panics
    /// Panics if `path` is empty or if `index` has bits set at or above `path.len()`.
    fn verify_merkle_path(
        &mut self,
        value: Word,
        path: &[Word],
        mut index: u64,
        context: MerklePathContext,
        lookups: &mut Vec<HasherLookup>,
    ) -> Word {
        assert!(!path.is_empty(), "path is empty");
        // `index >> path.len()` overflows the shift once the path is 64 or more nodes deep. Debug
        // builds panic; release builds mask the shift amount and wrongly reject valid nonzero
        // indexes. Every u64 index fits in such a path, so only check shorter paths.
        assert!(
            path.len() >= u64::BITS as usize || index >> path.len() == 0,
            "invalid index for the path"
        );

        let mut root = value;
        let mut depth = path.len() - 1;

        let main_selectors = context.main_selectors();
        let part_selectors = context.part_selectors();

        if path.len() == 1 {
            // a single-leg path both starts and finishes in the same permutation
            self.update_sibling_hints(context, index, path[0], depth);
            self.verify_mp_leg(root, path[0], &mut index, main_selectors, RETURN_HASH, lookups)
        } else {
            // first leg
            let sibling = path[0];
            self.update_sibling_hints(context, index, sibling, depth);
            root = self.verify_mp_leg(
                root,
                sibling,
                &mut index,
                main_selectors,
                main_selectors,
                lookups,
            );
            depth -= 1;

            // middle legs
            for &sibling in &path[1..path.len() - 1] {
                self.update_sibling_hints(context, index, sibling, depth);
                root = self.verify_mp_leg(
                    root,
                    sibling,
                    &mut index,
                    part_selectors,
                    main_selectors,
                    lookups,
                );
                depth -= 1;
            }

            // last leg
            let sibling = path[path.len() - 1];
            self.update_sibling_hints(context, index, sibling, depth);
            self.verify_mp_leg(root, sibling, &mut index, part_selectors, RETURN_HASH, lookups)
        }
    }

    /// Hashes `root` with `sibling` in one permutation cycle and returns the resulting digest.
    ///
    /// The lowest bit of `*index` decides the merge order (see `build_merge_state`). A start
    /// lookup is recorded with the current index if `init_selectors` map to a start label, and
    /// a return lookup is recorded with the shifted index if `final_selectors` map to a return
    /// label. `*index` is shifted right by one bit.
    fn verify_mp_leg(
        &mut self,
        root: Word,
        sibling: Word,
        index: &mut u64,
        init_selectors: Selectors,
        final_selectors: Selectors,
        lookups: &mut Vec<HasherLookup>,
    ) -> Word {
        let index_bit = *index & 1;
        let mut state = build_merge_state(&root, &sibling, index_bit);

        let context = HasherLookupContext::Start;
        if let Some(label) = get_selector_context_label(init_selectors, context) {
            let lookup = self.get_lookup(label, Felt::new(*index), context);
            lookups.push(lookup);
        }

        // when the first selector flag is set (first leg of a path), the first row carries the
        // un-shifted index; otherwise every row carries the shifted index
        let (init_index, rest_index) = if init_selectors[0] == ZERO {
            (Felt::new(*index >> 1), Felt::new(*index >> 1))
        } else {
            (Felt::new(*index), Felt::new(*index >> 1))
        };

        self.trace.append_permutation_with_index(
            &mut state,
            init_selectors,
            final_selectors,
            init_index,
            rest_index,
        );

        *index >>= 1;

        let context = HasherLookupContext::Return;
        if let Some(label) = get_selector_context_label(final_selectors, context) {
            let lookup = self.get_lookup(label, Felt::new(*index), context);
            lookups.push(lookup);
        }

        get_digest(&state)
    }

    /// Records sibling-table updates for Merkle root updates at the current trace length.
    ///
    /// `MrUpdateOld` adds `(index, sibling)`; `MrUpdateNew` removes the entry identified by
    /// `depth`; `MpVerify` records nothing.
    fn update_sibling_hints(
        &mut self,
        context: MerklePathContext,
        index: u64,
        sibling: Word,
        depth: usize,
    ) {
        let step = self.trace.trace_len() as u32;
        match context {
            MerklePathContext::MrUpdateOld => {
                self.aux_trace.sibling_added(step, Felt::new(index), sibling);
            }
            MerklePathContext::MrUpdateNew => {
                self.aux_trace.sibling_removed(step, depth);
            }
            _ => (),
        }
    }

    /// Returns the memoized `(start_row, end_row)` range for `hash`, if one was stored.
    fn get_memoized_trace(&self, hash: Digest) -> Option<&(usize, usize)> {
        let key: [u8; 32] = hash.into();
        self.memoized_trace_map.get(&key)
    }

    /// Memoizes the rows appended since `addr` under `hash`.
    ///
    /// The range is `addr - 1 .. next_row_addr - 1`; this assumes row addresses are 1-based
    /// (NOTE(review): confirm against `HasherTrace::next_row_addr`).
    fn insert_to_memoized_trace_map(&mut self, addr: Felt, hash: Digest) {
        let key: [u8; 32] = hash.into();
        let start_row = addr.as_int() as usize - 1;
        let end_row = self.trace.next_row_addr().as_int() as usize - 1;
        self.memoized_trace_map.insert(key, (start_row, end_row));
    }
}
/// The kind of Merkle path computation being executed. It determines the hasher selector flags
/// (via `main_selectors` / `part_selectors`) and whether sibling-table updates are recorded.
#[derive(Debug, Clone, Copy)]
enum MerklePathContext {
/// Plain Merkle path verification (MP_VERIFY selectors); no sibling-table updates.
MpVerify,
/// Root computation for the old value of a Merkle root update (MR_UPDATE_OLD selectors).
MrUpdateOld,
/// Root computation for the new value of a Merkle root update (MR_UPDATE_NEW selectors).
MrUpdateNew,
}
impl MerklePathContext {
pub fn main_selectors(&self) -> Selectors {
match self {
Self::MpVerify => MP_VERIFY,
Self::MrUpdateOld => MR_UPDATE_OLD,
Self::MrUpdateNew => MR_UPDATE_NEW,
}
}
pub fn part_selectors(&self) -> Selectors {
let selectors = self.main_selectors();
[ZERO, selectors[1], selectors[2]]
}
}
/// Builds a hasher state that merges two words. The order depends on `index_bit`: with `0`, `a`
/// is the left word; with `1`, `b` is.
///
/// # Panics
/// Panics if `index_bit` is neither 0 nor 1.
#[inline(always)]
fn build_merge_state(a: &Word, b: &Word, index_bit: u64) -> HasherState {
    let (left, right) = match index_bit {
        0 => (a, b),
        1 => (b, a),
        _ => panic!("index bit is not a binary value"),
    };
    init_state_from_words(left, right)
}
/// Maps a set of selector flags to its lookup label for the given lookup context.
///
/// - `Start`: LINEAR_HASH, MP_VERIFY, MR_UPDATE_OLD and MR_UPDATE_NEW have labels.
/// - `Return`: RETURN_HASH and RETURN_STATE have labels.
/// - any other context: only LINEAR_HASH has a label.
///
/// Returns `None` when the selectors have no label in that context.
pub fn get_selector_context_label(
    selectors: Selectors,
    context: HasherLookupContext,
) -> Option<u8> {
    match context {
        HasherLookupContext::Start if selectors == LINEAR_HASH => Some(LINEAR_HASH_LABEL),
        HasherLookupContext::Start if selectors == MP_VERIFY => Some(MP_VERIFY_LABEL),
        HasherLookupContext::Start if selectors == MR_UPDATE_OLD => Some(MR_UPDATE_OLD_LABEL),
        HasherLookupContext::Start if selectors == MR_UPDATE_NEW => Some(MR_UPDATE_NEW_LABEL),
        HasherLookupContext::Start => None,
        HasherLookupContext::Return if selectors == RETURN_HASH => Some(RETURN_HASH_LABEL),
        HasherLookupContext::Return if selectors == RETURN_STATE => Some(RETURN_STATE_LABEL),
        HasherLookupContext::Return => None,
        _ if selectors == LINEAR_HASH => Some(LINEAR_HASH_LABEL),
        _ => None,
    }
}