Skip to main content

privacy_core/commitment_tree/
poseidon.rs

1//! Poseidon-based Merkle hashing on the BN254 scalar field (`Fr`).
2//!
3//! Lifted verbatim from `orchard-bn254`'s `poseidon_merkle_bn254` (the merkle-tree parts
4//! only), re-sourced to the **standalone** crates `halo2_poseidon` (Poseidon primitives)
5//! and `halo2curves` (BN254 `Fr`) so it carries **no `halo2_proofs` / `halo2_gadgets`
6//! dependency**. The Poseidon spec/constants are byte-identical to the prover + on-chain
7//! `PoseidonT3` (verified against the live commitment-tree root).
8
9use super::poseidon_primitives::{generate_constants, ConstantLength, Hash, Mds, Spec};
10use ff::Field;
11use halo2curves::bn256::Fr;
12use std::collections::HashMap;
13use std::sync::{OnceLock, RwLock};
14
/// Depth of the incremental note commitment tree (Orchard-compatible).
///
/// Sibling paths contain exactly this many elements, one per level from the leaves up
/// to (but not including) the root.
pub const MERKLE_DEPTH_EVM: usize = 32;
17
18/// Poseidon-128, width 3, rate 2, x^5 S-box; Grain-generated constants for BN254 `Fr`.
19#[derive(Clone, Copy, Debug)]
20pub struct Bn254PoseidonMerkleSpec;
21
22impl Spec<Fr, 3, 2> for Bn254PoseidonMerkleSpec {
23    fn full_rounds() -> usize {
24        8
25    }
26    fn partial_rounds() -> usize {
27        56
28    }
29    fn sbox(val: Fr) -> Fr {
30        val.pow_vartime([5])
31    }
32    fn secure_mds() -> usize {
33        0
34    }
35    fn constants() -> (Vec<[Fr; 3]>, Mds<Fr, 3>, Mds<Fr, 3>) {
36        // Grain-based constant generation is ~10x the cost of the hash itself and the
37        // output is fixed for this spec, so generate once and clone (a 64-entry Vec plus
38        // two 3x3 matrices — negligible next to one round of field exponentiation).
39        static CONSTANTS: OnceLock<(Vec<[Fr; 3]>, Mds<Fr, 3>, Mds<Fr, 3>)> = OnceLock::new();
40        CONSTANTS
41            .get_or_init(generate_constants::<Fr, Self, 3, 2>)
42            .clone()
43    }
44}
45
46/// One Merkle layer: `H(level || left || right)` with domain separation on `level`.
47#[inline]
48pub fn merkle_compress(level: u8, left: Fr, right: Fr) -> Fr {
49    Hash::<Fr, Bn254PoseidonMerkleSpec, ConstantLength<3>, 3, 2>::init().hash([
50        Fr::from(level as u64),
51        left,
52        right,
53    ])
54}
55
56/// `Poseidon(domain, a, b)` over `ConstantLength<3>` — the width-3 two-input hash with
57/// a `u64` domain tag. Byte-identical to the prover's `poseidon_merkle_bn254::poseidon2`
58/// and to `merkle_compress` when `domain == level`. Used for the frozen Indexed MT,
59/// whose node domains exceed the `u8` range of `merkle_compress`.
60#[inline]
61pub fn poseidon_domain_pair(domain: u64, a: Fr, b: Fr) -> Fr {
62    Hash::<Fr, Bn254PoseidonMerkleSpec, ConstantLength<3>, 3, 2>::init().hash([
63        Fr::from(domain),
64        a,
65        b,
66    ])
67}
68
69/// Full Merkle root (depth [`MERKLE_DEPTH_EVM`]) over `Fr` leaves from a sibling path.
70pub fn merkle_root(position: u32, leaf: Fr, siblings: &[Fr; MERKLE_DEPTH_EVM]) -> Fr {
71    let mut node = leaf;
72    for (level, sibling) in siblings.iter().enumerate() {
73        let l = level as u8;
74        if (position >> level) & 1 == 0 {
75            node = merkle_compress(l, node, *sibling);
76        } else {
77            node = merkle_compress(l, *sibling, node);
78        }
79    }
80    node
81}
82
83/// Append-only incremental Merkle tree of fixed depth [`MERKLE_DEPTH_EVM`] (32) over BN254
84/// scalar leaves. Empty leaves default to [`Fr::ZERO`].
85#[derive(Debug)]
86pub struct Bn254IncrementalMerkleTree {
87    leaves: Vec<Fr>,
88    /// `empty[l]` is the root of a depth-`l` all-zero subtree (`empty[0]` = `Fr::ZERO`).
89    empty: [Fr; MERKLE_DEPTH_EVM + 1],
90    /// Memoized hashes of *complete* subtrees, keyed by `(level, idx)`. The tree is
91    /// append-only, so a complete subtree's hash never changes and the cache never needs
92    /// invalidation. Without this, every `root()`/`witness()` call rehashes the whole
93    /// tree — O(leaves) Poseidon compressions per query, which made indexer Merkle-path
94    /// queries degrade linearly as notes accumulated.
95    node_cache: RwLock<HashMap<(usize, usize), Fr>>,
96}
97
98impl Bn254IncrementalMerkleTree {
99    pub fn new() -> Self {
100        let mut empty = [Fr::ZERO; MERKLE_DEPTH_EVM + 1];
101        for i in 1..=MERKLE_DEPTH_EVM {
102            empty[i] = merkle_compress((i - 1) as u8, empty[i - 1], empty[i - 1]);
103        }
104        Self { leaves: Vec::new(), empty, node_cache: RwLock::new(HashMap::new()) }
105    }
106
107    pub fn append(&mut self, leaf: Fr) {
108        self.leaves.push(leaf);
109    }
110
111    pub fn len(&self) -> usize {
112        self.leaves.len()
113    }
114
115    pub fn is_empty(&self) -> bool {
116        self.leaves.is_empty()
117    }
118
119    pub fn root(&self) -> Fr {
120        self.subtree_hash(MERKLE_DEPTH_EVM, 0)
121    }
122
123    /// Authentication path (siblings) for the leaf at `pos`. Panics if `pos >= len()`.
124    pub fn witness(&self, pos: u32) -> [Fr; MERKLE_DEPTH_EVM] {
125        assert!((pos as usize) < self.leaves.len(), "position out of tree");
126        let mut siblings = [Fr::ZERO; MERKLE_DEPTH_EVM];
127        for level in 0..MERKLE_DEPTH_EVM {
128            let sibling_node_idx = ((pos >> level) ^ 1) as usize;
129            siblings[level] = self.subtree_hash(level, sibling_node_idx);
130        }
131        siblings
132    }
133
134    fn subtree_hash(&self, level: usize, idx: usize) -> Fr {
135        let start = idx << level; // idx * 2^level
136        if start >= self.leaves.len() {
137            return self.empty[level];
138        }
139        if level == 0 {
140            return self.leaves[start];
141        }
142        // Only complete subtrees are cacheable: an incomplete one (right frontier)
143        // still changes as leaves are appended.
144        let complete = start + (1usize << level) <= self.leaves.len();
145        if complete {
146            if let Some(cached) = self.node_cache.read().unwrap().get(&(level, idx)) {
147                return *cached;
148            }
149        }
150        let left = self.subtree_hash(level - 1, idx * 2);
151        let right = self.subtree_hash(level - 1, idx * 2 + 1);
152        let node = merkle_compress((level - 1) as u8, left, right);
153        if complete {
154            self.node_cache.write().unwrap().insert((level, idx), node);
155        }
156        node
157    }
158}
159
160impl Default for Bn254IncrementalMerkleTree {
161    fn default() -> Self {
162        Self::new()
163    }
164}