use super::poseidon_primitives::{generate_constants, ConstantLength, Hash, Mds, Spec};
use ff::Field;
use halo2curves::bn256::Fr;
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};
/// Depth of the Merkle tree: the number of sibling hashes in an authentication path.
pub const MERKLE_DEPTH_EVM: usize = 32;

/// Zero-sized marker selecting the Poseidon parameters used for Merkle hashing
/// over the BN254 scalar field (`Spec<Fr, 3, 2>` below).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Bn254PoseidonMerkleSpec;
impl Spec<Fr, 3, 2> for Bn254PoseidonMerkleSpec {
fn full_rounds() -> usize {
8
}
fn partial_rounds() -> usize {
56
}
fn sbox(val: Fr) -> Fr {
val.pow_vartime([5])
}
fn secure_mds() -> usize {
0
}
fn constants() -> (Vec<[Fr; 3]>, Mds<Fr, 3>, Mds<Fr, 3>) {
static CONSTANTS: OnceLock<(Vec<[Fr; 3]>, Mds<Fr, 3>, Mds<Fr, 3>)> = OnceLock::new();
CONSTANTS
.get_or_init(generate_constants::<Fr, Self, 3, 2>)
.clone()
}
}
#[inline]
pub fn merkle_compress(level: u8, left: Fr, right: Fr) -> Fr {
Hash::<Fr, Bn254PoseidonMerkleSpec, ConstantLength<3>, 3, 2>::init().hash([
Fr::from(level as u64),
left,
right,
])
}
#[inline]
pub fn poseidon_domain_pair(domain: u64, a: Fr, b: Fr) -> Fr {
Hash::<Fr, Bn254PoseidonMerkleSpec, ConstantLength<3>, 3, 2>::init().hash([
Fr::from(domain),
a,
b,
])
}
/// Recomputes the Merkle root from a leaf, its position, and its authentication path.
///
/// Bit `h` of `position` says whether the running node is the right child (`1`)
/// or the left child (`0`) at height `h`. `siblings[h]` is the node paired with it there.
pub fn merkle_root(position: u32, leaf: Fr, siblings: &[Fr; MERKLE_DEPTH_EVM]) -> Fr {
    siblings
        .iter()
        .copied()
        .enumerate()
        .fold(leaf, |node, (height, sibling)| {
            let is_right_child = (position >> height) & 1 == 1;
            let (left, right) = if is_right_child { (sibling, node) } else { (node, sibling) };
            merkle_compress(height as u8, left, right)
        })
}
/// Append-only Poseidon Merkle tree of fixed depth [`MERKLE_DEPTH_EVM`].
///
/// Leaves are stored in insertion order. Every position at or beyond `len()` is
/// treated as part of an empty subtree. Hashes of subtrees whose leaves are all
/// present are memoized, so repeated `root`/`witness` calls only recompute nodes
/// on the partially-filled right edge.
#[derive(Debug)]
pub struct Bn254IncrementalMerkleTree {
    /// Appended leaves; the index in this vector is the leaf position.
    leaves: Vec<Fr>,
    /// `empty[h]` is the root of a height-`h` subtree whose leaves are all `Fr::ZERO`.
    empty: [Fr; MERKLE_DEPTH_EVM + 1],
    /// Memo of complete-subtree hashes keyed by `(height, index at that height)`.
    /// Entries never go stale: appending cannot change a subtree that is already full.
    node_cache: RwLock<HashMap<(usize, usize), Fr>>,
}

impl Bn254IncrementalMerkleTree {
    /// Creates an empty tree and precomputes the empty-subtree root for every height.
    pub fn new() -> Self {
        let mut empty = [Fr::ZERO; MERKLE_DEPTH_EVM + 1];
        for i in 1..=MERKLE_DEPTH_EVM {
            // Children of a height-`i` node live at height `i - 1`, which is the
            // level tag `merkle_compress` expects (same convention as `merkle_root`).
            empty[i] = merkle_compress((i - 1) as u8, empty[i - 1], empty[i - 1]);
        }
        Self { leaves: Vec::new(), empty, node_cache: RwLock::new(HashMap::new()) }
    }

    /// Appends `leaf` at position `len()`.
    ///
    /// The node cache needs no invalidation, because only complete subtrees are cached.
    pub fn append(&mut self, leaf: Fr) {
        self.leaves.push(leaf);
    }

    /// Number of leaves appended so far.
    pub fn len(&self) -> usize {
        self.leaves.len()
    }

    /// Returns `true` if no leaf has been appended.
    pub fn is_empty(&self) -> bool {
        self.leaves.is_empty()
    }

    /// Root of the full depth-[`MERKLE_DEPTH_EVM`] tree, with unfilled positions empty.
    pub fn root(&self) -> Fr {
        self.subtree_hash(MERKLE_DEPTH_EVM, 0)
    }

    /// Authentication path for the leaf at `pos`, ordered from the leaf level upward.
    ///
    /// The result is suitable for [`merkle_root`] together with `pos` and that leaf.
    ///
    /// # Panics
    ///
    /// Panics if `pos >= self.len()`.
    pub fn witness(&self, pos: u32) -> [Fr; MERKLE_DEPTH_EVM] {
        assert!((pos as usize) < self.leaves.len(), "position out of tree");
        let mut siblings = [Fr::ZERO; MERKLE_DEPTH_EVM];
        for level in 0..MERKLE_DEPTH_EVM {
            // Flipping the low bit of the node index at this height selects its sibling.
            let sibling_node_idx = ((pos >> level) ^ 1) as usize;
            siblings[level] = self.subtree_hash(level, sibling_node_idx);
        }
        siblings
    }

    /// Hash of the height-`level` subtree that is the `idx`-th node at that height.
    ///
    /// Leaf offsets are computed in `u64`. At the root, `level == MERKLE_DEPTH_EVM` (32),
    /// and `usize` shifts by 32 overflow on 32-bit targets. In release builds the shift
    /// amount wraps, which would misclassify the partially-filled root as complete and
    /// cache a root that later appends never refresh.
    fn subtree_hash(&self, level: usize, idx: usize) -> Fr {
        let leaf_count = self.leaves.len() as u64;
        // Number of leaf positions under one node at this height (`level` <= 32 here).
        let span = 1u64 << level;
        // Saturate so an out-of-range index simply lands in the empty region.
        let start = (idx as u64).saturating_mul(span);
        if start >= leaf_count {
            return self.empty[level];
        }
        if level == 0 {
            // `start < leaf_count <= usize::MAX`, so the cast is lossless.
            return self.leaves[start as usize];
        }
        // Subtraction form: cannot overflow, unlike `start + span`.
        let complete = leaf_count - start >= span;
        let key = (level, idx);
        if complete {
            // Copy the hit out so the read guard is dropped before recursing.
            // A poisoned lock is tolerated because the cache only holds pure hash results.
            let cached = self
                .node_cache
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .get(&key)
                .copied();
            if let Some(node) = cached {
                return node;
            }
        }
        let left = self.subtree_hash(level - 1, idx * 2);
        let right = self.subtree_hash(level - 1, idx * 2 + 1);
        let node = merkle_compress((level - 1) as u8, left, right);
        if complete {
            self.node_cache
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .insert(key, node);
        }
        node
    }
}
impl Default for Bn254IncrementalMerkleTree {
fn default() -> Self {
Self::new()
}
}