Skip to main content

imt_tree/tree/
mod.rs

1use ff::Field as _;
2use pasta_curves::Fp;
3use rayon::prelude::*;
4
5pub(crate) use crate::hasher::PoseidonHasher;
6
7#[cfg(test)]
8mod tests;
9
/// Depth of the nullifier Merkle tree.
///
/// Each on-chain nullifier produces approximately one gap; with K=2 punctured
/// ranges (two gaps per leaf), ~n/2 leaves are needed for n nullifiers. Zcash
/// mainnet currently has under 64M Orchard nullifiers. We plan for this
/// circuit to support up to 256M (2^28) nullifiers, i.e. ~2^27 leaves.
///
/// The depth is chosen as `log2(256 << 20) + 1 = 29`. A tree of depth `d`
/// built by [`build_levels`] holds up to `2^d` leaves, so depth 29 provides
/// 2^29 leaf slots — 4x the projected ~2^27 leaves.
pub const TREE_DEPTH: usize = 29;
18
19/// A punctured range `[nf_lo, nf_mid, nf_hi]` representing the interval
20/// `(nf_lo, nf_hi) \ {nf_mid}` — two adjacent gaps joined by excluding the
21/// nullifier between them.
22///
23/// With K=2, each leaf stores three sorted nullifier boundaries. The leaf
24/// commitment is `Poseidon3(nf_lo, nf_mid, nf_hi)`.
25pub type PuncturedRange = [Fp; 3];
26
27/// Build punctured ranges (K=2) from a sorted, deduplicated nullifier list.
28///
29/// Groups consecutive nullifiers into overlapping triples:
30///   `[nf_0, nf_1, nf_2]`, `[nf_2, nf_3, nf_4]`, `[nf_4, nf_5, nf_6]`, ...
31///
32/// Each triple covers the punctured interval `(nf_lo, nf_hi) \ {nf_mid}`.
33/// Consecutive triples share boundary nullifiers, so every gap between
34/// adjacent nullifiers is covered by exactly one leaf.
35///
36/// # Panics
37///
38/// Panics if `sorted_nfs` has fewer than 3 elements or an even length
39/// (which would leave a trailing gap without a matching triple — callers
40/// should ensure an odd count via sentinel injection).
41pub fn build_punctured_ranges(sorted_nfs: &[Fp]) -> Vec<PuncturedRange> {
42    let n = sorted_nfs.len();
43    assert!(
44        n >= 3,
45        "need at least 3 sorted nullifiers for K=2 punctured ranges, got {n}"
46    );
47    assert!(
48        n % 2 == 1,
49        "sorted nullifier count must be odd for K=2 (got {n}); \
50         inject an additional sentinel to fix"
51    );
52
53    let num_leaves = (n - 1) / 2;
54    (0..num_leaves)
55        .map(|i| {
56            let base = i * 2;
57            let (lo, mid, hi) = (sorted_nfs[base], sorted_nfs[base + 1], sorted_nfs[base + 2]);
58            assert!(
59                lo < mid && mid < hi,
60                "punctured range {i} violates strict ordering: \
61                 nf_lo={lo:?}, nf_mid={mid:?}, nf_hi={hi:?} \
62                 (input must be sorted and deduplicated)"
63            );
64            [lo, mid, hi]
65        })
66        .collect()
67}
68
69/// Hash each punctured range triple into a single leaf commitment.
70pub fn commit_punctured_ranges(ranges: &[PuncturedRange]) -> Vec<Fp> {
71    ranges
72        .par_iter()
73        .map_init(PoseidonHasher::new, |hasher, &[a, b, c]| {
74            hasher.hash3(a, b, c)
75        })
76        .collect()
77}
78
79/// Find the punctured-range index that contains `value`.
80///
81/// Returns `Some(i)` where `ranges[i]` is `[nf_lo, nf_mid, nf_hi]` and
82/// `nf_lo < value < nf_hi` and `value != nf_mid`. Returns `None` if the
83/// value is an existing nullifier.
84pub fn find_punctured_range_for_value(ranges: &[PuncturedRange], value: Fp) -> Option<usize> {
85    let i = ranges.partition_point(|[nf_lo, _, _]| *nf_lo < value);
86    if i == 0 {
87        return None;
88    }
89    let idx = i - 1;
90    let [nf_lo, nf_mid, nf_hi] = ranges[idx];
91    let offset = value - nf_lo;
92    let span = nf_hi - nf_lo;
93    if offset == Fp::zero() || offset >= span {
94        return None;
95    }
96    if value == nf_mid {
97        return None;
98    }
99    Some(idx)
100}
101
102/// Verify that every punctured range has outer span `≤ 2^250`.
103///
104/// For K=2, the outer span `nf_hi - nf_lo` covers two consecutive sentinel
105/// intervals. With sentinel spacing `2^249`, the maximum span is
106/// `2 * 2^249 = 2^250`, which matches the circuit's 250-bit range check
107/// (25 limbs × 10 bits).
108pub fn verify_punctured_range_spans(ranges: &[PuncturedRange]) -> anyhow::Result<()> {
109    let max_span = Fp::from(2u64).pow([250, 0, 0, 0]);
110    for (i, &[nf_lo, _, nf_hi]) in ranges.iter().enumerate() {
111        let span = nf_hi - nf_lo;
112        anyhow::ensure!(
113            span <= max_span,
114            "punctured range {i} has span > 2^250: nf_lo={nf_lo:?}, nf_hi={nf_hi:?}"
115        );
116    }
117    Ok(())
118}
119
120/// Pre-compute the empty subtree hash at each tree level.
121///
122/// `empty[0] = hash3(0, 0, 0)` -- the commitment of an all-zero punctured range.
123/// `empty[i] = hash(empty[i-1], empty[i-1])` for higher levels.
124pub fn precompute_empty_hashes() -> [Fp; TREE_DEPTH] {
125    let hasher = PoseidonHasher::new();
126    let mut empty = [Fp::default(); TREE_DEPTH];
127    empty[0] = hasher.hash3(Fp::zero(), Fp::zero(), Fp::zero());
128    for i in 1..TREE_DEPTH {
129        empty[i] = hasher.hash(empty[i - 1], empty[i - 1]);
130    }
131    empty
132}
133
134/// Build Merkle tree levels bottom-up from leaf hashes.
135///
136/// `depth` controls the number of tree levels (use `TREE_DEPTH` for a full
137/// depth-29 tree, or a smaller value like 25 for the PIR tree).
138/// Returns `(root, levels)` where `levels[0]` contains leaf hashes and
139/// `levels[depth-1]` contains the root's two children.
140///
141/// Each level is padded to even length using the pre-computed empty hash so
142/// that pair-wise hashing produces the next level cleanly. All intermediate
143/// layers are retained so Merkle auth paths can be extracted in O(`depth`)
144/// via simple sibling lookups.
145pub fn build_levels(mut leaves: Vec<Fp>, empty: &[Fp; TREE_DEPTH], depth: usize) -> (Fp, Vec<Vec<Fp>>) {
146    let hasher = PoseidonHasher::new();
147    let mut levels: Vec<Vec<Fp>> = Vec::with_capacity(depth);
148
149    // Level 0 = leaf commitments, padded to even length.
150    // Takes ownership of `leaves` to avoid a 1.6 GB memcpy at scale.
151    if leaves.is_empty() {
152        leaves.push(empty[0]);
153    }
154    if leaves.len() & 1 == 1 {
155        leaves.push(empty[0]);
156    }
157    levels.push(leaves);
158
159    const PAR_THRESHOLD: usize = 1024;
160
161    for i in 0..depth - 1 {
162        let prev = &levels[i];
163        let pairs = prev.len() / 2;
164        let mut next: Vec<Fp> = if pairs >= PAR_THRESHOLD {
165            prev.par_chunks_exact(2)
166                .map_init(PoseidonHasher::new, |h, pair| h.hash(pair[0], pair[1]))
167                .collect()
168        } else {
169            (0..pairs)
170                .map(|j| hasher.hash(prev[j * 2], prev[j * 2 + 1]))
171                .collect()
172        };
173        if next.len() & 1 == 1 {
174            next.push(empty[i + 1]);
175        }
176        levels.push(next);
177    }
178
179    let top = &levels[depth - 1];
180    let root = hasher.hash(top[0], top[1]);
181
182    (root, levels)
183}
184
185