use alloc::vec::Vec;
use miden_air::trace::chiplets::hasher::{
DIGEST_LEN, HASH_CYCLE_LEN, NUM_ROUNDS, NUM_SELECTORS, STATE_COL_RANGE,
};
use miden_core::{
ONE, ZERO,
chiplets::hasher,
crypto::merkle::{MerkleTree, NodeIndex},
mast::{
BasicBlockNodeBuilder, DecoratorId, JoinNodeBuilder, LoopNodeBuilder, MastForest,
MastForestContributor, MastNode, MastNodeExt, SplitNodeBuilder,
},
operations::Operation,
};
use miden_utils_testing::rand::rand_array;
use super::{
Digest, Felt, Hasher, HasherState, LINEAR_HASH, MP_VERIFY, MR_UPDATE_NEW, MR_UPDATE_OLD,
MerklePath, RETURN_HASH, RETURN_STATE, Selectors, TRACE_WIDTH, TraceFragment,
init_state_from_words,
};
#[test]
/// Checks that `Hasher::permute` returns the expected addresses and permuted states, and that
/// each permutation occupies exactly one hash cycle of the trace.
fn hasher_permute() {
    // --- a single permutation -----------------------------------------------------------------
    let mut hasher = Hasher::default();
    let state: HasherState = rand_array();
    let (addr, permuted) = hasher.permute(state);
    assert_eq!(ONE, addr);
    assert_eq!(apply_permutation(state), permuted);

    let trace = build_trace(hasher, HASH_CYCLE_LEN);
    check_selector_trace(&trace, 0, LINEAR_HASH, RETURN_STATE);
    check_hasher_state_trace(&trace, 0, state);
    // plain permutations leave the last (node index) column zeroed
    assert_eq!(trace.last().unwrap(), &[ZERO; HASH_CYCLE_LEN]);

    // --- two back-to-back permutations ----------------------------------------------------------
    let mut hasher = Hasher::default();
    let states: [HasherState; 2] = [rand_array(), rand_array()];
    for (i, &state) in states.iter().enumerate() {
        let (addr, permuted) = hasher.permute(state);
        // addresses are 1-based and advance by one hash cycle per permutation
        assert_eq!(Felt::new((i * HASH_CYCLE_LEN) as u64 + 1), addr);
        assert_eq!(apply_permutation(state), permuted);
    }

    let trace = build_trace(hasher, 2 * HASH_CYCLE_LEN);
    for (i, &state) in states.iter().enumerate() {
        let first_row = i * HASH_CYCLE_LEN;
        check_selector_trace(&trace, first_row, LINEAR_HASH, RETURN_STATE);
        check_hasher_state_trace(&trace, first_row, state);
    }
    assert_eq!(trace.last().unwrap(), &[ZERO; 2 * HASH_CYCLE_LEN]);
}
#[test]
/// Checks the trace produced by `Hasher::build_merkle_root` for a depth-1 tree (one leg per
/// leaf) and for several paths of a depth-3 tree.
fn hasher_build_merkle_root() {
    // --- depth-1 tree: verify a path for each of the two leaves -------------------------------
    let leaves = init_leaves(&[1, 2]);
    let tree = MerkleTree::new(&leaves).unwrap();
    let mut hasher = Hasher::default();

    let path0 = tree.get_path(NodeIndex::new(1, 0).unwrap()).unwrap();
    let _ = hasher.build_merkle_root(leaves[0], &path0, ZERO);

    let path1 = tree.get_path(NodeIndex::new(1, 1).unwrap()).unwrap();
    let _ = hasher.build_merkle_root(leaves[1], &path1, ONE);

    let trace = build_trace(hasher, 2 * HASH_CYCLE_LEN);

    check_selector_trace(&trace, 0, MP_VERIFY, RETURN_HASH);
    check_selector_trace(&trace, HASH_CYCLE_LEN, MP_VERIFY, RETURN_HASH);

    // Leaf 0 (index bit 0) is the left input of its merge; leaf 1 (index bit 1) is the right
    // input. The second verification starts one hash cycle after the first, so its initial
    // state must be checked at row HASH_CYCLE_LEN (not row 0, which belongs to the first leg).
    check_hasher_state_trace(&trace, 0, init_state_from_words(&leaves[0], &path0[0]));
    check_hasher_state_trace(
        &trace,
        HASH_CYCLE_LEN,
        init_state_from_words(&path1[0], &leaves[1]),
    );

    // node index column: 0 for the first leg; 1 only on the first row of the second leg
    let node_idx_column = trace.last().unwrap();
    assert_eq!(node_idx_column.len(), 2 * HASH_CYCLE_LEN);
    assert!(node_idx_column[..HASH_CYCLE_LEN].iter().all(|&v| v == ZERO));
    assert_eq!(node_idx_column[HASH_CYCLE_LEN], ONE);
    assert!(node_idx_column[HASH_CYCLE_LEN + 1..].iter().all(|&v| v == ZERO));

    // --- depth-3 tree: a single path ------------------------------------------------------------
    let leaves = init_leaves(&[1, 2, 3, 4, 5, 6, 7, 8]);
    let tree = MerkleTree::new(&leaves).unwrap();
    let mut hasher = Hasher::default();

    let path = tree.get_path(NodeIndex::new(3, 5).unwrap()).unwrap();
    let _ = hasher.build_merkle_root(leaves[5], &path, Felt::new(5));

    let trace = build_trace(hasher, path.len() * HASH_CYCLE_LEN);
    check_merkle_path(&trace, 0, leaves[5], &path, 5, MP_VERIFY);

    // --- depth-3 tree: several paths, including a repeated one ----------------------------------
    let mut hasher = Hasher::default();

    let path0 = tree.get_path(NodeIndex::new(3, 0).unwrap()).unwrap();
    let _ = hasher.build_merkle_root(leaves[0], &path0, ZERO);

    let path3 = tree.get_path(NodeIndex::new(3, 3).unwrap()).unwrap();
    let _ = hasher.build_merkle_root(leaves[3], &path3, Felt::new(3));

    let path7 = tree.get_path(NodeIndex::new(3, 7).unwrap()).unwrap();
    let _ = hasher.build_merkle_root(leaves[7], &path7, Felt::new(7));

    // the same path is verified a second time
    let _ = hasher.build_merkle_root(leaves[3], &path3, Felt::new(3));

    let leg_rows = path0.len() * HASH_CYCLE_LEN;
    let trace = build_trace(hasher, 4 * leg_rows);
    check_merkle_path(&trace, 0, leaves[0], &path0, 0, MP_VERIFY);
    check_merkle_path(&trace, leg_rows, leaves[3], &path3, 3, MP_VERIFY);
    check_merkle_path(&trace, 2 * leg_rows, leaves[7], &path7, 7, MP_VERIFY);
    check_merkle_path(&trace, 3 * leg_rows, leaves[3], &path3, 3, MP_VERIFY);
}
#[test]
/// Checks the trace produced by `Hasher::update_merkle_root`: each update emits an "old" leg
/// (`MR_UPDATE_OLD`) followed by a "new" leg (`MR_UPDATE_NEW`).
fn hasher_update_merkle_root() {
    // --- depth-1 tree: replace each of the two leaves in turn ---------------------------------
    let leaves = init_leaves(&[1, 2]);
    let mut tree = MerkleTree::new(&leaves).unwrap();
    let mut hasher = Hasher::default();

    let new_leaf0 = init_leaf(3);
    let path0 = tree.get_path(NodeIndex::new(1, 0).unwrap()).unwrap();
    hasher.update_merkle_root(leaves[0], new_leaf0, &path0, ZERO);
    tree.update_leaf(0, new_leaf0).unwrap();

    let new_leaf1 = init_leaf(4);
    let path1 = tree.get_path(NodeIndex::new(1, 1).unwrap()).unwrap();
    hasher.update_merkle_root(leaves[1], new_leaf1, &path1, ONE);
    tree.update_leaf(1, new_leaf1).unwrap();

    let trace = build_trace(hasher, 4 * HASH_CYCLE_LEN);

    // one hash cycle per leg; leaf 0 is the left merge input, leaf 1 the right merge input
    let expected_legs = [
        (MR_UPDATE_OLD, init_state_from_words(&leaves[0], &path0[0])),
        (MR_UPDATE_NEW, init_state_from_words(&new_leaf0, &path0[0])),
        (MR_UPDATE_OLD, init_state_from_words(&path1[0], &leaves[1])),
        (MR_UPDATE_NEW, init_state_from_words(&path1[0], &new_leaf1)),
    ];
    for (leg, (selectors, init_state)) in expected_legs.into_iter().enumerate() {
        let first_row = leg * HASH_CYCLE_LEN;
        check_selector_trace(&trace, first_row, selectors, RETURN_HASH);
        check_hasher_state_trace(&trace, first_row, init_state);
    }

    // the node index column holds 1 only on the first row of each leg that updates leaf 1
    let node_idx_column = trace.last().unwrap();
    assert_eq!(node_idx_column.len(), 4 * HASH_CYCLE_LEN);
    for (row, &value) in node_idx_column.iter().enumerate() {
        let leg_start = row % HASH_CYCLE_LEN == 0;
        let updates_leaf1 = row >= 2 * HASH_CYCLE_LEN;
        let expected = if leg_start && updates_leaf1 { ONE } else { ZERO };
        assert_eq!(expected, value, "unexpected node index at row {row}");
    }

    // --- depth-3 tree: update leaf 3, then leaf 6, then leaf 3 again ----------------------------
    let leaves = init_leaves(&[1, 2, 3, 4, 5, 6, 7, 8]);
    let mut tree = MerkleTree::new(&leaves).unwrap();
    let mut hasher = Hasher::default();

    let path3 = tree.get_path(NodeIndex::new(3, 3).unwrap()).unwrap();
    let new_leaf3 = init_leaf(23);
    hasher.update_merkle_root(leaves[3], new_leaf3, &path3, Felt::new(3));
    tree.update_leaf(3, new_leaf3).unwrap();

    let path6 = tree.get_path(NodeIndex::new(3, 6).unwrap()).unwrap();
    let new_leaf6 = init_leaf(25);
    hasher.update_merkle_root(leaves[6], new_leaf6, &path6, Felt::new(6));
    tree.update_leaf(6, new_leaf6).unwrap();

    let path3_2 = tree.get_path(NodeIndex::new(3, 3).unwrap()).unwrap();
    let new_leaf3_2 = init_leaf(27);
    hasher.update_merkle_root(new_leaf3, new_leaf3_2, &path3_2, Felt::new(3));
    tree.update_leaf(3, new_leaf3_2).unwrap();

    // updating leaf 6 must have changed a sibling on leaf 3's path
    assert_ne!(path3, path3_2);

    let leg_rows = path3.len() * HASH_CYCLE_LEN;
    let trace = build_trace(hasher, 6 * leg_rows);
    let expected_legs = [
        (leaves[3], &path3, 3, MR_UPDATE_OLD),
        (new_leaf3, &path3, 3, MR_UPDATE_NEW),
        (leaves[6], &path6, 6, MR_UPDATE_OLD),
        (new_leaf6, &path6, 6, MR_UPDATE_NEW),
        (new_leaf3, &path3_2, 3, MR_UPDATE_OLD),
        (new_leaf3_2, &path3_2, 3, MR_UPDATE_NEW),
    ];
    for (leg, (leaf, path, index, selectors)) in expected_legs.into_iter().enumerate() {
        check_merkle_path(&trace, leg * leg_rows, leaf, path, index, selectors);
    }
}
#[test]
/// Hashes a JOIN of two identical SPLIT nodes and checks that the second SPLIT's trace rows are
/// a copy of the first SPLIT's rows (memoization).
fn hash_memoization_control_blocks() {
    /// Returns the digest of `node` as a fixed-size array of field elements.
    fn digest_array(node: &MastNode) -> [Felt; DIGEST_LEN] {
        node.digest()
            .as_elements()
            .try_into()
            .expect("Could not convert slice to array")
    }

    // Forest: JOIN(split1, split2) where split1 and split2 are both SPLIT(t_branch, f_branch).
    let mut mast_forest = MastForest::new();

    let t_branch_id = BasicBlockNodeBuilder::new(vec![Operation::Push(ZERO)], Vec::new())
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let t_branch = mast_forest[t_branch_id].clone();

    let f_branch_id = BasicBlockNodeBuilder::new(vec![Operation::Push(ONE)], Vec::new())
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let f_branch = mast_forest[f_branch_id].clone();

    let split1_id = SplitNodeBuilder::new([t_branch_id, f_branch_id])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let split1 = mast_forest[split1_id].clone();

    let split2_id = SplitNodeBuilder::new([t_branch_id, f_branch_id])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let split2 = mast_forest[split2_id].clone();

    let join_node_id = JoinNodeBuilder::new([split1_id, split2_id])
        .add_to_forest(&mut mast_forest)
        .unwrap();
    let join_node = mast_forest[join_node_id].clone();

    let mut hasher = Hasher::default();

    // Hash the JOIN node over its two SPLIT children.
    let expected_hash = join_node.digest();
    let (_, final_state) = hasher.hash_control_block(
        digest_array(&split1).into(),
        digest_array(&split2).into(),
        join_node.domain(),
        expected_hash,
    );
    assert_eq!(final_state, expected_hash);

    // Both SPLIT nodes have the same children, so the child digests are computed once.
    let t_digest = digest_array(&t_branch);
    let f_digest = digest_array(&f_branch);

    // First SPLIT: hashed normally; record the rows it occupies.
    let expected_hash = split1.digest();
    let (addr, final_state) =
        hasher.hash_control_block(t_digest.into(), f_digest.into(), split1.domain(), expected_hash);
    let first_block_final_state = final_state;
    assert_eq!(final_state, expected_hash);
    // `addr` is 1-based while trace rows are 0-based
    let start_row = addr.as_canonical_u64() as usize - 1;
    let end_row = hasher.trace_len() - 1;

    // Second SPLIT: same digest as the first, so its rows are expected to be a memoized copy.
    let expected_hash = split2.digest();
    let (addr, final_state) =
        hasher.hash_control_block(t_digest.into(), f_digest.into(), split2.domain(), expected_hash);
    assert_eq!(final_state, expected_hash);
    assert_eq!(first_block_final_state, final_state);
    let copied_start_row = addr.as_canonical_u64() as usize - 1;
    let copied_end_row = hasher.trace_len() - 1;

    let trace = build_trace(hasher, copied_end_row + 1);

    // three hash cycles precede nothing else: JOIN, SPLIT, copied SPLIT
    assert_eq!(Felt::new(2 * HASH_CYCLE_LEN as u64 + 1), addr);
    assert_eq!(trace.last().unwrap(), &[ZERO; HASH_CYCLE_LEN * 3]);
    check_memoized_trace(&trace, start_row, end_row, copied_start_row, copied_end_row);
}
#[test]
/// Runs the basic-block memoization check on a short block and on a longer block.
fn hash_memoization_basic_blocks() {
    // a short basic block
    hash_memoization_basic_blocks_check(
        vec![Operation::Push(Felt::new(10)), Operation::Drop],
        Vec::new(),
    );

    // a longer basic block: pushes of 1..=18 followed by 18 drops
    let pushes = (1..=18u64).map(|value| Operation::Push(Felt::new(value)));
    let drops = (0..18).map(|_| Operation::Drop);
    let ops: Vec<Operation> = pushes.chain(drops).collect();
    hash_memoization_basic_blocks_check(ops, Vec::new());
}
fn hash_memoization_basic_blocks_check(
operations: Vec<Operation>,
decorators: Vec<(usize, DecoratorId)>,
) {
let mut mast_forest = MastForest::new();
let basic_block_1_id = BasicBlockNodeBuilder::new(operations.clone(), decorators.clone())
.add_to_forest(&mut mast_forest)
.unwrap();
let basic_block_1 = mast_forest[basic_block_1_id].clone();
let loop_body_id =
BasicBlockNodeBuilder::new(vec![Operation::Pad, Operation::Eq, Operation::Not], Vec::new())
.add_to_forest(&mut mast_forest)
.unwrap();
let loop_block_id = LoopNodeBuilder::new(loop_body_id).add_to_forest(&mut mast_forest).unwrap();
let loop_block = mast_forest[loop_block_id].clone();
let join2_block_id = JoinNodeBuilder::new([basic_block_1_id, loop_block_id])
.add_to_forest(&mut mast_forest)
.unwrap();
let join2_block = mast_forest[join2_block_id].clone();
let basic_block_2_id = BasicBlockNodeBuilder::new(operations.clone(), decorators.clone())
.add_to_forest(&mut mast_forest)
.unwrap();
let basic_block_2 = mast_forest[basic_block_2_id].clone();
let join1_block_id = JoinNodeBuilder::new([join2_block_id, basic_block_2_id])
.add_to_forest(&mut mast_forest)
.unwrap();
let join1_block = mast_forest[join1_block_id].clone();
let mut hasher = Hasher::default();
let h1: [Felt; DIGEST_LEN] = join2_block
.digest()
.as_elements()
.try_into()
.expect("Could not convert slice to array");
let h2: [Felt; DIGEST_LEN] = basic_block_2
.digest()
.as_elements()
.try_into()
.expect("Could not convert slice to array");
let expected_hash = join1_block.digest();
let (_, final_state) =
hasher.hash_control_block(h1.into(), h2.into(), join1_block.domain(), expected_hash);
assert_eq!(final_state, expected_hash);
let h1: [Felt; DIGEST_LEN] = basic_block_1
.digest()
.as_elements()
.try_into()
.expect("Could not convert slice to array");
let h2: [Felt; DIGEST_LEN] = loop_block
.digest()
.as_elements()
.try_into()
.expect("Could not convert slice to array");
let expected_hash = join2_block.digest();
let (_, final_state) =
hasher.hash_control_block(h1.into(), h2.into(), join2_block.domain(), expected_hash);
assert_eq!(final_state, expected_hash);
let basic_block_1_val = if let MastNode::Block(basic_block) = basic_block_1.clone() {
basic_block
} else {
unreachable!()
};
let (addr, final_state) =
hasher.hash_basic_block(basic_block_1_val.op_batches(), basic_block_1.digest());
let first_basic_block_final_state = final_state;
let expected_hash = basic_block_1.digest();
assert_eq!(final_state, expected_hash);
let start_row = addr.as_canonical_u64() as usize - 1;
let end_row = hasher.trace_len() - 1;
let basic_block_2_val = if let MastNode::Block(basic_block) = basic_block_2.clone() {
basic_block
} else {
unreachable!()
};
let (addr, final_state) =
hasher.hash_basic_block(basic_block_2_val.op_batches(), basic_block_2.digest());
let _num_batches = basic_block_2_val.op_batches().len();
let expected_hash = basic_block_2.digest();
assert_eq!(final_state, expected_hash);
assert_eq!(first_basic_block_final_state, final_state);
let copied_start_row = addr.as_canonical_u64() as usize - 1;
let copied_end_row = hasher.trace_len() - 1;
let trace = build_trace(hasher, copied_end_row + 1);
check_memoized_trace(&trace, start_row, end_row, copied_start_row, copied_end_row);
}
/// Allocates a zero-filled `TRACE_WIDTH x num_rows` column-major trace and lets `hasher` fill it.
fn build_trace(hasher: Hasher, num_rows: usize) -> Vec<Vec<Felt>> {
    let mut columns = vec![vec![ZERO; num_rows]; TRACE_WIDTH];
    {
        // the fragment mutably borrows `columns` only for the duration of the fill
        let mut fragment = TraceFragment::trace_to_fragment(&mut columns);
        hasher.fill_trace(&mut fragment);
    }
    columns
}
/// Asserts that the trace starting at `row_idx` contains a full Merkle path computation for
/// `leaf` at position `node_index`, using one hash cycle per node of `path`.
///
/// `init_selectors` are the selectors expected on the first row (e.g. `MP_VERIFY`,
/// `MR_UPDATE_OLD`, `MR_UPDATE_NEW`). Every cycle except the last ends with `init_selectors`;
/// the last cycle ends with `RETURN_HASH`. A single-node path is one cycle that starts with
/// `init_selectors` and ends with `RETURN_HASH`.
///
/// # Panics
/// Panics if `path` is empty or if any checked trace cell does not match.
fn check_merkle_path(
    trace: &[Vec<Felt>],
    row_idx: usize,
    leaf: Digest,
    path: &MerklePath,
    node_index: u64,
    init_selectors: Selectors,
) {
    let depth = path.len();
    assert!(depth > 0, "a Merkle path must contain at least one node");

    // --- selectors ---
    // Only the first cycle starts with the full `init_selectors`; later cycles start with the
    // first selector cleared. Handling the first and last cycle per-leg (rather than as two
    // fixed checks) keeps single-node paths consistent.
    let mid_selectors = [ZERO, init_selectors[1], init_selectors[2]];
    for leg in 0..depth {
        let leg_init = if leg == 0 { init_selectors } else { mid_selectors };
        let leg_final = if leg + 1 == depth { RETURN_HASH } else { init_selectors };
        check_selector_trace(trace, row_idx + leg * HASH_CYCLE_LEN, leg_init, leg_final);
    }

    // --- hasher state ---
    // At each level the running node is the left merge input when its index bit is 0 and the
    // right merge input otherwise.
    let mut root = leaf;
    for (i, &node) in path.iter().enumerate() {
        let old_root = root;
        let init_state = if (node_index >> i) & 1 == 0 {
            root = hasher::merge(&[root, node]);
            init_state_from_words(&old_root, &node)
        } else {
            root = hasher::merge(&[node, root]);
            init_state_from_words(&node, &old_root)
        };
        check_hasher_state_trace(trace, row_idx + i * HASH_CYCLE_LEN, init_state);
    }

    // --- node index column ---
    // first row holds the full index; the rest of the first cycle holds index >> 1, and each
    // subsequent cycle holds the index shifted right once more
    let node_idx_column = trace.last().unwrap();
    assert_eq!(Felt::new(node_index), node_idx_column[row_idx]);
    let mut node_index = node_index >> 1;
    for i in 1..HASH_CYCLE_LEN {
        assert_eq!(Felt::new(node_index), node_idx_column[row_idx + i])
    }
    for i in 1..depth {
        node_index >>= 1;
        for j in 0..HASH_CYCLE_LEN {
            assert_eq!(Felt::new(node_index), node_idx_column[row_idx + i * HASH_CYCLE_LEN + j])
        }
    }
}
fn check_selector_trace(
trace: &[Vec<Felt>],
row_idx: usize,
init_selectors: Selectors,
final_selectors: Selectors,
) {
let trace = &trace[0..3];
let mid_selectors = [ZERO, init_selectors[1], init_selectors[2]];
assert_row_equal(trace, row_idx, &init_selectors);
for i in 0..NUM_ROUNDS - 1 {
assert_row_equal(trace, row_idx + i + 1, &mid_selectors);
}
assert_row_equal(trace, row_idx + NUM_ROUNDS, &final_selectors);
}
/// Asserts the hasher-state columns of one hash cycle starting at `row_idx`.
///
/// Row `row_idx` must hold `init_state`; row `row_idx + r + 1` must hold the state after
/// applying rounds `0..=r` to `init_state`.
fn check_hasher_state_trace(trace: &[Vec<Felt>], row_idx: usize, init_state: HasherState) {
    let state_columns = &trace[STATE_COL_RANGE];
    let mut expected = init_state;
    assert_row_equal(state_columns, row_idx, &expected);

    let round_rows = row_idx + 1..=row_idx + NUM_ROUNDS;
    for (round, row) in round_rows.enumerate() {
        hasher::apply_round(&mut expected, round);
        assert_row_equal(state_columns, row, &expected);
    }
}
/// Asserts that rows `copied_start_row..=copied_end_row` are an exact copy of rows
/// `start_row..=end_row` in the selector and hasher-state columns.
///
/// Both ranges are inclusive. Callers pass `hasher.trace_len() - 1`, the index of the last row
/// written, as the end row. Exclusive ranges would therefore skip the final (result) row of
/// each block.
fn check_memoized_trace(
    trace: &[Vec<Felt>],
    start_row: usize,
    end_row: usize,
    copied_start_row: usize,
    copied_end_row: usize,
) {
    // the copy must span exactly as many rows as the original
    assert_eq!(end_row - start_row, copied_end_row - copied_start_row);

    let selector_columns = &trace[0..NUM_SELECTORS];
    let state_columns = &trace[STATE_COL_RANGE];
    for column in selector_columns.iter().chain(state_columns) {
        assert_eq!(column[start_row..=end_row], column[copied_start_row..=copied_end_row]);
    }
}
/// Asserts that, for every column `c` of `trace`, `trace[c][row_idx] == values[c]`.
///
/// # Panics
/// Panics if the number of columns differs from `values.len()`, which would otherwise make the
/// element-wise comparison silently check only a prefix. Also panics on the first mismatching
/// cell, reporting its column and row.
fn assert_row_equal(trace: &[Vec<Felt>], row_idx: usize, values: &[Felt]) {
    assert_eq!(
        trace.len(),
        values.len(),
        "number of columns must match the number of expected values"
    );
    for (col_idx, (column, value)) in trace.iter().zip(values).enumerate() {
        assert_eq!(column[row_idx], *value, "mismatch in column {col_idx} at row {row_idx}");
    }
}
/// Returns a copy of `state` with the full hasher permutation applied.
fn apply_permutation(state: HasherState) -> HasherState {
    let mut permuted = state;
    hasher::apply_permutation(&mut permuted);
    permuted
}
fn init_leaves(values: &[u64]) -> Vec<Digest> {
values.iter().map(|&v| init_leaf(v)).collect()
}
fn init_leaf(value: u64) -> Digest {
[Felt::new(value), ZERO, ZERO, ZERO].into()
}