use alloc::vec::Vec;
use miden_air::trace::{
AUX_TRACE_RAND_CHALLENGES, Challenges, MainTrace,
chiplets::hasher::{HASH_CYCLE_LEN, P1_COL_IDX},
};
use miden_core::{
ONE, Word, ZERO,
crypto::merkle::{MerkleStore, MerkleTree, NodeIndex},
field::{ExtensionField, Field},
operations::Operation,
};
use rstest::rstest;
use super::{Felt, build_trace_from_ops_with_inputs, rand_array};
use crate::{AdviceInputs, StackInputs};
/// Executes a single `MpVerify` operation on a leaf of the 8-leaf test tree and checks that every
/// row of the hasher's `p1` auxiliary column equals `ONE`.
#[rstest]
#[case(5_u64)]
#[case(4_u64)]
fn hasher_p1_mp_verify(#[case] index: u64) {
    const DEPTH: u64 = 3;

    let (tree, _) = build_merkle_tree();
    let leaf = tree.get_node(NodeIndex::new(DEPTH as u8, index).unwrap()).unwrap();

    // Stack inputs are appended in this order: node word, depth, index, root word.
    let mut init_stack = Vec::new();
    append_word(&mut init_stack, leaf);
    init_stack.push(DEPTH);
    init_stack.push(index);
    append_word(&mut init_stack, tree.root());
    let stack_inputs = StackInputs::try_from_ints(init_stack).unwrap();

    // The Merkle store built from the same tree is supplied through the advice inputs.
    let advice_inputs = AdviceInputs::default().with_merkle_store(MerkleStore::from(&tree));
    let trace = build_trace_from_ops_with_inputs(
        vec![Operation::MpVerify(ZERO)],
        stack_inputs,
        advice_inputs,
    );

    let rand_elements = rand_array::<Felt, AUX_TRACE_RAND_CHALLENGES>();
    let aux_columns = trace.build_aux_trace(&rand_elements).unwrap();
    let p1 = aux_columns.get_column(P1_COL_IDX);
    p1.iter().for_each(|value| assert_eq!(ONE, *value));
}
/// Executes a single `MrUpdate` operation on a leaf of the 8-leaf test tree and checks the running
/// product in the hasher's `p1` auxiliary column row by row.
///
/// The three siblings on the leaf's path are each multiplied into `p1` once and later divided out
/// again (via their inverses). `p1` must stay constant between those rows and end at `ONE`.
#[rstest]
#[case(5_u64)]
#[case(4_u64)]
fn hasher_p1_mr_update(#[case] index: u64) {
    let depth: u8 = 3;
    let (tree, _) = build_merkle_tree();
    let node_index = NodeIndex::new(depth, index).unwrap();
    let old_node = tree.get_node(node_index).unwrap();
    let new_node = init_leaf(11);
    let path = tree.get_path(node_index).unwrap();

    // Stack inputs are appended in this order: old node, depth, index, old root, new node.
    let mut init_stack = vec![];
    append_word(&mut init_stack, old_node);
    init_stack.extend_from_slice(&[u64::from(depth), index]);
    append_word(&mut init_stack, tree.root());
    append_word(&mut init_stack, new_node);
    let stack_inputs = StackInputs::try_from_ints(init_stack).unwrap();
    let store = MerkleStore::from(&tree);
    let advice_inputs = AdviceInputs::default().with_merkle_store(store);

    let ops = vec![Operation::MrUpdate];
    let trace = build_trace_from_ops_with_inputs(ops, stack_inputs, advice_inputs);
    let challenges = rand_array::<Felt, AUX_TRACE_RAND_CHALLENGES>();
    let aux_columns = trace.build_aux_trace(&challenges).unwrap();
    let p1 = aux_columns.get_column(P1_COL_IDX);

    // Reduce each sibling-table entry using the same first two random elements.
    let challenges = Challenges::<Felt>::new(challenges[0], challenges[1]);
    let row_values = [
        SiblingTableRow::new(Felt::new(index), path[0]).to_value(&trace.main_trace, &challenges),
        SiblingTableRow::new(Felt::new(index >> 1), path[1])
            .to_value(&trace.main_trace, &challenges),
        SiblingTableRow::new(Felt::new(index >> 2), path[2])
            .to_value(&trace.main_trace, &challenges),
    ];

    // Rows at which p1 is expected to change, paired with the factor applied at that row. The
    // first three transitions add the siblings; the last three remove them in the same order.
    // Rows must be strictly increasing for the single-pass walk below.
    let transitions = [
        (HASH_CYCLE_LEN + 1, row_values[0]),
        (2 * HASH_CYCLE_LEN, row_values[1]),
        (3 * HASH_CYCLE_LEN, row_values[2]),
        (4 * HASH_CYCLE_LEN + 1, row_values[0].inverse()),
        (5 * HASH_CYCLE_LEN, row_values[1].inverse()),
        (6 * HASH_CYCLE_LEN, row_values[2].inverse()),
    ];

    // Walk the column once: apply a factor exactly at its transition row, otherwise the value
    // must equal the previous row's.
    let mut expected_value = ONE;
    let mut pending = transitions.iter().peekable();
    for (row, value) in p1.iter().enumerate() {
        if let Some((_, factor)) = pending.next_if(|(change_row, _)| *change_row == row) {
            expected_value *= *factor;
        }
        assert_eq!(expected_value, *value, "unexpected p1 value at row {row}");
    }

    // A column shorter than the last transition row would otherwise pass silently.
    assert!(
        pending.next().is_none(),
        "p1 column ended before all sibling table updates were observed"
    );
    // Every added sibling was removed again, so the running product is back to ONE.
    assert_eq!(ONE, expected_value);
}
/// Builds a Merkle tree over eight leaves created from the values `1..=8`, returning the tree
/// together with the leaf words it was built from.
///
/// Panics if the tree cannot be constructed.
fn build_merkle_tree() -> (MerkleTree, Vec<Word>) {
    let values: Vec<u64> = (1..=8).collect();
    let leaves = init_leaves(&values);
    let tree = MerkleTree::new(leaves.clone()).unwrap();
    (tree, leaves)
}
fn init_leaves(values: &[u64]) -> Vec<Word> {
values.iter().map(|&v| init_leaf(v)).collect()
}
/// Creates a leaf word whose first element is `value` and whose remaining three elements are zero.
fn init_leaf(value: u64) -> Word {
    let elements = [Felt::new(value), ZERO, ZERO, ZERO];
    elements.into()
}
/// Appends every element of `word`, in iteration order, to `target` as its canonical `u64` value.
fn append_word(target: &mut Vec<u64>, word: Word) {
    target.extend(word.iter().map(|element| element.as_canonical_u64()));
}
/// Test-side model of a sibling-table entry: an index paired with a sibling word.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SiblingTableRow {
    index: Felt,
    sibling: Word,
}

impl SiblingTableRow {
    /// Creates an entry from an index and a sibling word.
    pub fn new(index: Felt, sibling: Word) -> Self {
        Self { index, sibling }
    }

    /// Reduces this entry to a single value using the random `challenges`.
    ///
    /// The result is `alpha + beta^2 * index + sum_{i=0..4} beta^(k + i) * sibling[i]`, where the
    /// first sibling power `k` is 7 when the least significant bit of `index` is 0 and 3 when it
    /// is 1. `_main_trace` is accepted but not read.
    pub fn to_value<E: ExtensionField<Felt>>(
        &self,
        _main_trace: &MainTrace,
        challenges: &Challenges<E>,
    ) -> E {
        // The index parity selects which run of beta powers the sibling elements are paired with.
        let lsb = self.index.as_canonical_u64() & 1;
        let first_sibling_power = if lsb == 0 { 7 } else { 3 };

        let header = challenges.alpha + challenges.beta_powers[2] * self.index;
        (0..4).fold(header, |acc, i| {
            acc + challenges.beta_powers[first_sibling_power + i] * self.sibling[i]
        })
    }
}