use ff::Field as _;
use pasta_curves::Fp;
use rayon::prelude::*;
pub(crate) use crate::hasher::PoseidonHasher;
#[cfg(test)]
mod tests;
/// Length of the empty-subtree hash table: [`precompute_empty_hashes`] returns `TREE_DEPTH`
/// entries, and [`build_levels`] takes a table of exactly this length.
/// NOTE(review): presumably also the production tree depth passed to `build_levels` — confirm.
pub const TREE_DEPTH: usize = 29;
/// One K=2 punctured range `[nf_lo, nf_mid, nf_hi]`: three consecutive entries of a strictly
/// increasing nullifier list (see [`build_punctured_ranges`]). A value is covered by the range
/// when it lies strictly between `nf_lo` and `nf_hi` and differs from `nf_mid`
/// (see [`find_punctured_range_for_value`]).
pub type PuncturedRange = [Fp; 3];
/// Splits a strictly increasing list of nullifiers into K=2 punctured ranges.
///
/// Range `k` is `[sorted_nfs[2k], sorted_nfs[2k + 1], sorted_nfs[2k + 2]]`. Neighbouring
/// ranges therefore share an endpoint, and `n` nullifiers yield `(n - 1) / 2` ranges.
///
/// # Panics
///
/// - if fewer than 3 nullifiers are supplied;
/// - if the nullifier count is even;
/// - if any produced range is not strictly increasing (input unsorted or containing
///   duplicates). Because ranges share endpoints, this checks the whole input.
pub fn build_punctured_ranges(sorted_nfs: &[Fp]) -> Vec<PuncturedRange> {
    let n = sorted_nfs.len();
    assert!(
        n >= 3,
        "need at least 3 sorted nullifiers for K=2 punctured ranges, got {n}"
    );
    assert!(
        n % 2 == 1,
        "sorted nullifier count must be odd for K=2 (got {n}); \
         inject an additional sentinel to fix"
    );
    // Width-3 windows advancing by 2 give exactly the ranges `[2k, 2k + 1, 2k + 2]`.
    // For odd `n` there are `n - 2` windows, of which every second one is kept:
    // `(n - 1) / 2` ranges in total.
    sorted_nfs
        .windows(3)
        .step_by(2)
        .enumerate()
        .map(|(i, window)| {
            let (lo, mid, hi) = (window[0], window[1], window[2]);
            assert!(
                lo < mid && mid < hi,
                "punctured range {i} violates strict ordering: \
                 nf_lo={lo:?}, nf_mid={mid:?}, nf_hi={hi:?} \
                 (input must be sorted and deduplicated)"
            );
            [lo, mid, hi]
        })
        .collect()
}
pub fn commit_punctured_ranges(ranges: &[PuncturedRange]) -> Vec<Fp> {
ranges
.par_iter()
.map_init(PoseidonHasher::new, |hasher, &[a, b, c]| {
hasher.hash3(a, b, c)
})
.collect()
}
/// Returns the index of the punctured range that contains `value`, or `None`.
///
/// The candidate is the last range whose `nf_lo` is strictly below `value`. It is found by
/// binary search, so `ranges` is assumed to be sorted by `nf_lo`, as produced by
/// `build_punctured_ranges`. The candidate contains `value` when:
///
/// - `nf_lo < value < nf_hi`, and
/// - `value != nf_mid`.
///
/// Values equal to any range endpoint or midpoint are therefore never covered.
///
/// The bounds test uses field differences (`value - nf_lo` against `nf_hi - nf_lo`). For
/// well-formed ranges (`nf_lo < nf_hi`) this is the same as the integer comparison.
/// NOTE(review): presumably mirrors the in-circuit check — confirm.
pub fn find_punctured_range_for_value(ranges: &[PuncturedRange], value: Fp) -> Option<usize> {
    // No range starts below `value` -> nothing can contain it.
    let idx = ranges
        .partition_point(|range| range[0] < value)
        .checked_sub(1)?;
    let [nf_lo, nf_mid, nf_hi] = ranges[idx];

    let offset = value - nf_lo;
    let strictly_inside = offset != Fp::zero() && offset < nf_hi - nf_lo;
    if strictly_inside && value != nf_mid {
        Some(idx)
    } else {
        None
    }
}
/// Checks that every punctured range has a span `nf_hi - nf_lo` of at most `2^250`.
///
/// The span is computed as a field difference, so a range with `nf_hi < nf_lo` wraps to a
/// huge value and is rejected as well. `nf_mid` is not inspected.
///
/// # Errors
///
/// Returns an error identifying the first range (by index) whose span exceeds `2^250`.
pub fn verify_punctured_range_spans(ranges: &[PuncturedRange]) -> anyhow::Result<()> {
    // 2^250, built as 2 raised to a little-endian u64-limb exponent.
    let limit = Fp::from(2u64).pow([250, 0, 0, 0]);
    for (i, range) in ranges.iter().enumerate() {
        let [nf_lo, _, nf_hi] = *range;
        if nf_hi - nf_lo > limit {
            anyhow::bail!(
                "punctured range {i} has span > 2^250: nf_lo={nf_lo:?}, nf_hi={nf_hi:?}"
            );
        }
    }
    Ok(())
}
/// Precomputes the empty-subtree hash for each tree height `0..TREE_DEPTH`.
///
/// Each height is defined as follows:
///
/// - Height 0 is the commitment of the all-zero punctured range, `hash3(0, 0, 0)`.
/// - Height `k > 0` is `hash(e, e)`, where `e` is the entry for height `k - 1`.
pub fn precompute_empty_hashes() -> [Fp; TREE_DEPTH] {
    let hasher = PoseidonHasher::new();
    let mut node = hasher.hash3(Fp::zero(), Fp::zero(), Fp::zero());
    // Start with every slot holding the height-0 value, then overwrite heights 1.. in order.
    let mut table = [node; TREE_DEPTH];
    for slot in table.iter_mut().skip(1) {
        node = hasher.hash(node, node);
        *slot = node;
    }
    table
}
/// Builds every level of a binary Poseidon Merkle tree of height `depth` over `leaves`.
///
/// Returns `(root, levels)`. `levels[0]` holds the (padded) leaves, and `levels[k]` holds
/// the nodes at height `k`, for `k` in `0..depth`. The root is returned separately and is
/// not stored in `levels`.
///
/// Padding works as follows:
///
/// - A level with an odd number of nodes is padded on the right with `empty[k]`, the
///   empty-subtree hash for that height. Every level therefore has an even, non-zero length.
/// - An empty `leaves` vector is treated as two empty leaves.
///
/// Levels with at least `PAR_THRESHOLD` pairs are hashed in parallel with rayon. Smaller
/// levels are hashed sequentially with a single hasher.
///
/// # Panics
///
/// - if `depth == 0`;
/// - if `leaves.len() > 2^depth`. Such leaves cannot fit under a root of that height and
///   would otherwise be silently left out of the root hash.
/// - if padding is required at a height `>= TREE_DEPTH`, since `empty` has no entry there.
pub fn build_levels(
    mut leaves: Vec<Fp>,
    empty: &[Fp; TREE_DEPTH],
    depth: usize,
) -> (Fp, Vec<Vec<Fp>>) {
    /// Minimum number of pairs in a level before hashing is spread across rayon workers.
    const PAR_THRESHOLD: usize = 1024;

    // `depth - 1` below would underflow for a zero-height tree.
    assert!(depth >= 1, "tree depth must be at least 1, got {depth}");
    // A tree of height `depth` has 2^depth leaf slots. Surplus leaves would reach the top
    // level, where only `top[0]` and `top[1]` feed the root, so they would be dropped
    // without any error. When 2^depth overflows `usize`, every Vec fits and no check is needed.
    if let Some(capacity) = u32::try_from(depth).ok().and_then(|d| 1usize.checked_shl(d)) {
        assert!(
            leaves.len() <= capacity,
            "{} leaves do not fit in a tree of depth {depth} (capacity {capacity})",
            leaves.len()
        );
    }

    let hasher = PoseidonHasher::new();
    if leaves.is_empty() {
        leaves.push(empty[0]);
    }
    if leaves.len() % 2 == 1 {
        leaves.push(empty[0]);
    }

    let mut levels: Vec<Vec<Fp>> = Vec::with_capacity(depth);
    levels.push(leaves);
    for height in 1..depth {
        // Every stored level has even length, so it splits exactly into pairs.
        let prev = &levels[height - 1];
        let mut next: Vec<Fp> = if prev.len() / 2 >= PAR_THRESHOLD {
            prev.par_chunks_exact(2)
                .map_init(PoseidonHasher::new, |h, pair| h.hash(pair[0], pair[1]))
                .collect()
        } else {
            prev.chunks_exact(2)
                .map(|pair| hasher.hash(pair[0], pair[1]))
                .collect()
        };
        if next.len() % 2 == 1 {
            next.push(empty[height]);
        }
        levels.push(next);
    }

    let top = &levels[depth - 1];
    // Guaranteed by the capacity check above: at most 2^depth leaves collapse to exactly
    // two nodes at height `depth - 1`.
    debug_assert_eq!(top.len(), 2, "top level must hold exactly the root's two children");
    let root = hasher.hash(top[0], top[1]);
    (root, levels)
}