imt_tree/tree/mod.rs
1use ff::Field as _;
2use pasta_curves::Fp;
3use rayon::prelude::*;
4
5pub(crate) use crate::hasher::PoseidonHasher;
6
7#[cfg(test)]
8mod tests;
9
/// Depth of the nullifier Merkle tree.
///
/// Each on-chain nullifier produces approximately one gap; with K=2 punctured
/// ranges, ~n/2 leaves are needed for n nullifiers (see
/// [`build_punctured_ranges`], which emits `(n - 1) / 2` ranges). Zcash
/// mainnet currently has under 64M Orchard nullifiers. We plan for this
/// circuit to support up to 256M (`2^28`) nullifiers, i.e. roughly `2^27`
/// leaves. The depth is chosen as `log2(256 << 20) + 1 = 29`. A full-depth
/// tree built with [`build_levels`] holds up to `2^29` leaves, which leaves
/// headroom above that target.
pub const TREE_DEPTH: usize = 29;
18
19/// A punctured range `[nf_lo, nf_mid, nf_hi]` representing the interval
20/// `(nf_lo, nf_hi) \ {nf_mid}` — two adjacent gaps joined by excluding the
21/// nullifier between them.
22///
23/// With K=2, each leaf stores three sorted nullifier boundaries. The leaf
24/// commitment is `Poseidon3(nf_lo, nf_mid, nf_hi)`.
25pub type PuncturedRange = [Fp; 3];
26
27/// Build punctured ranges (K=2) from a sorted, deduplicated nullifier list.
28///
29/// Groups consecutive nullifiers into overlapping triples:
30/// `[nf_0, nf_1, nf_2]`, `[nf_2, nf_3, nf_4]`, `[nf_4, nf_5, nf_6]`, ...
31///
32/// Each triple covers the punctured interval `(nf_lo, nf_hi) \ {nf_mid}`.
33/// Consecutive triples share boundary nullifiers, so every gap between
34/// adjacent nullifiers is covered by exactly one leaf.
35///
36/// # Panics
37///
38/// Panics if `sorted_nfs` has fewer than 3 elements or an even length
39/// (which would leave a trailing gap without a matching triple — callers
40/// should ensure an odd count via sentinel injection).
41pub fn build_punctured_ranges(sorted_nfs: &[Fp]) -> Vec<PuncturedRange> {
42 let n = sorted_nfs.len();
43 assert!(
44 n >= 3,
45 "need at least 3 sorted nullifiers for K=2 punctured ranges, got {n}"
46 );
47 assert!(
48 n % 2 == 1,
49 "sorted nullifier count must be odd for K=2 (got {n}); \
50 inject an additional sentinel to fix"
51 );
52
53 let num_leaves = (n - 1) / 2;
54 (0..num_leaves)
55 .map(|i| {
56 let base = i * 2;
57 let (lo, mid, hi) = (sorted_nfs[base], sorted_nfs[base + 1], sorted_nfs[base + 2]);
58 assert!(
59 lo < mid && mid < hi,
60 "punctured range {i} violates strict ordering: \
61 nf_lo={lo:?}, nf_mid={mid:?}, nf_hi={hi:?} \
62 (input must be sorted and deduplicated)"
63 );
64 [lo, mid, hi]
65 })
66 .collect()
67}
68
69/// Hash each punctured range triple into a single leaf commitment.
70pub fn commit_punctured_ranges(ranges: &[PuncturedRange]) -> Vec<Fp> {
71 ranges
72 .par_iter()
73 .map_init(PoseidonHasher::new, |hasher, &[a, b, c]| {
74 hasher.hash3(a, b, c)
75 })
76 .collect()
77}
78
79/// Find the punctured-range index that contains `value`.
80///
81/// Returns `Some(i)` where `ranges[i]` is `[nf_lo, nf_mid, nf_hi]` and
82/// `nf_lo < value < nf_hi` and `value != nf_mid`. Returns `None` if the
83/// value is an existing nullifier.
84pub fn find_punctured_range_for_value(ranges: &[PuncturedRange], value: Fp) -> Option<usize> {
85 let i = ranges.partition_point(|[nf_lo, _, _]| *nf_lo < value);
86 if i == 0 {
87 return None;
88 }
89 let idx = i - 1;
90 let [nf_lo, nf_mid, nf_hi] = ranges[idx];
91 let offset = value - nf_lo;
92 let span = nf_hi - nf_lo;
93 if offset == Fp::zero() || offset >= span {
94 return None;
95 }
96 if value == nf_mid {
97 return None;
98 }
99 Some(idx)
100}
101
102/// Verify that every punctured range has outer span `≤ 2^250`.
103///
104/// For K=2, the outer span `nf_hi - nf_lo` covers two consecutive sentinel
105/// intervals. With sentinel spacing `2^249`, the maximum span is
106/// `2 * 2^249 = 2^250`, which matches the circuit's 250-bit range check
107/// (25 limbs × 10 bits).
108pub fn verify_punctured_range_spans(ranges: &[PuncturedRange]) -> anyhow::Result<()> {
109 let max_span = Fp::from(2u64).pow([250, 0, 0, 0]);
110 for (i, &[nf_lo, _, nf_hi]) in ranges.iter().enumerate() {
111 let span = nf_hi - nf_lo;
112 anyhow::ensure!(
113 span <= max_span,
114 "punctured range {i} has span > 2^250: nf_lo={nf_lo:?}, nf_hi={nf_hi:?}"
115 );
116 }
117 Ok(())
118}
119
120/// Pre-compute the empty subtree hash at each tree level.
121///
122/// `empty[0] = hash3(0, 0, 0)` -- the commitment of an all-zero punctured range.
123/// `empty[i] = hash(empty[i-1], empty[i-1])` for higher levels.
124pub fn precompute_empty_hashes() -> [Fp; TREE_DEPTH] {
125 let hasher = PoseidonHasher::new();
126 let mut empty = [Fp::default(); TREE_DEPTH];
127 empty[0] = hasher.hash3(Fp::zero(), Fp::zero(), Fp::zero());
128 for i in 1..TREE_DEPTH {
129 empty[i] = hasher.hash(empty[i - 1], empty[i - 1]);
130 }
131 empty
132}
133
134/// Build Merkle tree levels bottom-up from leaf hashes.
135///
136/// `depth` controls the number of tree levels (use `TREE_DEPTH` for a full
137/// depth-29 tree, or a smaller value like 25 for the PIR tree).
138/// Returns `(root, levels)` where `levels[0]` contains leaf hashes and
139/// `levels[depth-1]` contains the root's two children.
140///
141/// Each level is padded to even length using the pre-computed empty hash so
142/// that pair-wise hashing produces the next level cleanly. All intermediate
143/// layers are retained so Merkle auth paths can be extracted in O(`depth`)
144/// via simple sibling lookups.
145pub fn build_levels(mut leaves: Vec<Fp>, empty: &[Fp; TREE_DEPTH], depth: usize) -> (Fp, Vec<Vec<Fp>>) {
146 let hasher = PoseidonHasher::new();
147 let mut levels: Vec<Vec<Fp>> = Vec::with_capacity(depth);
148
149 // Level 0 = leaf commitments, padded to even length.
150 // Takes ownership of `leaves` to avoid a 1.6 GB memcpy at scale.
151 if leaves.is_empty() {
152 leaves.push(empty[0]);
153 }
154 if leaves.len() & 1 == 1 {
155 leaves.push(empty[0]);
156 }
157 levels.push(leaves);
158
159 const PAR_THRESHOLD: usize = 1024;
160
161 for i in 0..depth - 1 {
162 let prev = &levels[i];
163 let pairs = prev.len() / 2;
164 let mut next: Vec<Fp> = if pairs >= PAR_THRESHOLD {
165 prev.par_chunks_exact(2)
166 .map_init(PoseidonHasher::new, |h, pair| h.hash(pair[0], pair[1]))
167 .collect()
168 } else {
169 (0..pairs)
170 .map(|j| hasher.hash(prev[j * 2], prev[j * 2 + 1]))
171 .collect()
172 };
173 if next.len() & 1 == 1 {
174 next.push(empty[i + 1]);
175 }
176 levels.push(next);
177 }
178
179 let top = &levels[depth - 1];
180 let root = hasher.hash(top[0], top[1]);
181
182 (root, levels)
183}
184
185