use pasta_curves::pallas;
use std::string::String;
use crate::protocol_hash::{poseidon_hash_2, poseidon_hash_3};
/// Depth of the indexed Merkle tree: the number of sibling hashes in an
/// authentication path (the length of `ImtProofData::path`).
pub const IMT_DEPTH: usize = 29;
pub(super) use crate::domain_tags::governance_authorization as gov_auth_domain_tag;
/// Derives the nullifier domain for a voting round.
///
/// Hashes the governance-authorization domain tag together with
/// `vote_round_id` using the two-input Poseidon hash (tag first).
pub fn derive_nullifier_domain(vote_round_id: pallas::Base) -> pallas::Base {
    let domain_tag = gov_auth_domain_tag();
    poseidon_hash_2(domain_tag, vote_round_id)
}
pub fn gov_null_hash(nk: pallas::Base, dom: pallas::Base, real_nf: pallas::Base) -> pallas::Base {
poseidon_hash_3(nk, dom, real_nf)
}
/// Result of [`ImtProvider::non_membership_proof`] for a single queried value.
#[derive(Clone, Debug)]
pub struct ImtProofData {
    /// Root of the tree that the authentication path hashes up to.
    pub root: pallas::Base,
    /// The `[lo, mid, hi]` bounds of the punctured range containing the queried value.
    pub nf_bounds: [pallas::Base; 3],
    /// Position of the leaf whose bounds are `nf_bounds`.
    pub leaf_pos: u32,
    /// Authentication path: `path[l]` is the sibling hash at level `l`,
    /// where level 0 is the leaf level.
    pub path: [pallas::Base; IMT_DEPTH],
}
/// Error reported by an [`ImtProvider`]; carries a human-readable message.
#[derive(Clone, Debug)]
pub struct ImtError(pub String);

impl std::fmt::Display for ImtError {
    /// Renders as `IMT error: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("IMT error: ")?;
        f.write_str(&self.0)
    }
}

impl std::error::Error for ImtError {}
/// Source of an IMT root and of non-membership proofs against it.
pub trait ImtProvider {
    /// Returns the tree root.
    fn root(&self) -> pallas::Base;
    /// Returns a non-membership proof for `nf`, or an [`ImtError`] when the
    /// provider cannot produce one for that value.
    fn non_membership_proof(&self, nf: pallas::Base) -> Result<ImtProofData, ImtError>;
}
use ff::Field;
use std::vec::Vec;
/// Returns the hash of an all-empty subtree for every height `0..=IMT_DEPTH`.
///
/// Entry 0 is the empty-leaf hash `poseidon_hash_3(0, 0, 0)`. Entry `h` is
/// `poseidon_hash_2(e, e)`, where `e` is entry `h - 1`. The result has
/// `IMT_DEPTH + 1` elements.
fn empty_imt_hashes() -> Vec<pallas::Base> {
    let zero = pallas::Base::zero();
    let empty_leaf = poseidon_hash_3(zero, zero, zero);
    std::iter::once(empty_leaf)
        .chain((0..IMT_DEPTH).scan(empty_leaf, |node, _| {
            *node = poseidon_hash_2(*node, *node);
            Some(*node)
        }))
        .collect()
}
/// Sentinels are spaced `2^SENTINEL_EXPONENT` apart.
const SENTINEL_EXPONENT: u64 = 249;
/// Largest multiplier `k` in the sentinel sequence `k * 2^SENTINEL_EXPONENT`.
const SENTINEL_COUNT: u64 = 32;

/// Builds the sorted, deduplicated list of sentinel nullifiers.
///
/// The list contains `k * 2^SENTINEL_EXPONENT` for each `k` in
/// `0..=SENTINEL_COUNT`, plus `-1` (the largest field element). If the
/// resulting length is even, the value `2` is inserted at index 1 (right
/// after `0`) so that the list has odd length.
pub fn build_sentinel_list() -> Vec<pallas::Base> {
    let spacing = pallas::Base::from(2u64).pow([SENTINEL_EXPONENT, 0, 0, 0]);
    let mut sentinels: Vec<pallas::Base> = (0..=SENTINEL_COUNT)
        .map(|k| spacing * pallas::Base::from(k))
        .chain(std::iter::once(-pallas::Base::one()))
        .collect();
    sentinels.sort();
    sentinels.dedup();

    if sentinels.len() % 2 == 0 {
        // Odd length is needed to form punctured ranges. `2` sorts between
        // `0` and `2^SENTINEL_EXPONENT`, so index 1 keeps the list ordered.
        debug_assert_eq!(sentinels[0], pallas::Base::zero(), "sentinel 0 must be first");
        sentinels.insert(1, pallas::Base::from(2u64));
    }
    sentinels
}
/// Builds the complete sorted nullifier list: the sentinels plus `extra_nfs`.
///
/// The merged list is sorted and deduplicated. If its length is even, one
/// padding value is inserted at its sorted position to make the length odd.
/// The padding is `2` if absent, otherwise the smallest of `1, 2, 3, ...`
/// not already present.
///
/// # Panics
///
/// Panics if no padding candidate is found (not expected for realistic inputs).
pub fn build_nullifier_list(extra_nfs: &[pallas::Base]) -> Vec<pallas::Base> {
    let mut merged = build_sentinel_list();
    merged.extend(extra_nfs.iter().copied());
    merged.sort();
    merged.dedup();

    if merged.len() % 2 != 0 {
        return merged;
    }

    // A failed binary search both proves the candidate is absent and yields
    // the slot that keeps `merged` sorted.
    let (slot, padding) = std::iter::once(2u64)
        .chain(1u64..)
        .map(pallas::Base::from)
        .find_map(|candidate| {
            merged
                .binary_search(&candidate)
                .err()
                .map(|slot| (slot, candidate))
        })
        .expect("small field-element padding candidate should exist");
    merged.insert(slot, padding);
    merged
}
/// Groups a sorted, odd-length nullifier list into overlapping punctured ranges.
///
/// Range `i` is `[nfs[2i], nfs[2i + 1], nfs[2i + 2]]`, so consecutive ranges
/// share an endpoint. A list of length `n` yields `(n - 1) / 2` ranges.
///
/// # Panics
///
/// Panics if fewer than 3 values are given, if the count is even, or if a
/// range is not strictly increasing (`lo < mid < hi`).
fn build_punctured_ranges_local(sorted_nfs: &[pallas::Base]) -> Vec<[pallas::Base; 3]> {
    let n = sorted_nfs.len();
    assert!(n >= 3, "need at least 3 sorted nullifiers, got {n}");
    assert!(n % 2 == 1, "sorted nullifier count must be odd (got {n})");

    let mut ranges = Vec::with_capacity((n - 1) / 2);
    // Windows starting at even offsets 0, 2, 4, ... are exactly the ranges.
    for (i, window) in sorted_nfs.windows(3).step_by(2).enumerate() {
        let (lo, mid, hi) = (window[0], window[1], window[2]);
        assert!(
            lo < mid && mid < hi,
            "punctured range {i} violates strict ordering: \
             nf_lo={lo:?}, nf_mid={mid:?}, nf_hi={hi:?}"
        );
        ranges.push([lo, mid, hi]);
    }
    ranges
}
/// Returns the index of the punctured range that strictly contains `value`.
///
/// `ranges` must be sorted by low bound, as produced by
/// `build_punctured_ranges_local`. A range `[lo, mid, hi]` covers `value`
/// when `lo < value < hi` and `value != mid`. The three bounds themselves
/// (including the puncture `mid`) are never covered.
///
/// Returns `None` when no range covers `value`, including for empty `ranges`.
///
/// Generic over any totally ordered key. Existing callers pass
/// `pallas::Base`, whose ordering is the canonical-integer order.
fn find_range_for_value<T: Ord>(ranges: &[[T; 3]], value: T) -> Option<usize> {
    // First range whose low bound is >= value. The only candidate is the one before it.
    let idx = ranges
        .partition_point(|[nf_lo, _, _]| *nf_lo < value)
        .checked_sub(1)?;
    let [nf_lo, nf_mid, nf_hi] = &ranges[idx];
    // Compare directly instead of via modular `value - lo >= hi - lo`. That
    // form wraps around mod p for a malformed range (hi <= lo) and would then
    // accept values lying above `hi`.
    let covered = *nf_lo < value && value < *nf_hi && value != *nf_mid;
    if covered {
        Some(idx)
    } else {
        None
    }
}
/// Test-fixture IMT whose populated leaves all sit in the leftmost 32-leaf subtree.
///
/// Punctured ranges built from `build_nullifier_list` occupy leaf positions
/// `0..leaves.len()` (at most 32). The remaining positions in that subtree
/// hold the empty-leaf hash, and every subtree above it is empty.
#[derive(Debug)]
pub struct SpacedLeafImtProvider {
    // Root of the full depth-`IMT_DEPTH` tree.
    root: pallas::Base,
    // Punctured ranges `[lo, mid, hi]`, one per populated leaf, in leaf order.
    leaves: Vec<[pallas::Base; 3]>,
    // Hash levels of the 32-leaf subtree: index 0 holds the 32 leaf hashes,
    // and each later level halves the count, ending in a single node.
    subtree_levels: Vec<Vec<pallas::Base>>,
}
impl Default for SpacedLeafImtProvider {
fn default() -> Self {
Self::new()
}
}
impl SpacedLeafImtProvider {
    /// Builds the fixture from the sentinel nullifiers alone.
    pub fn new() -> Self {
        Self::with_extra_nullifiers(&[])
    }

    /// Builds the fixture from the sentinel nullifiers plus `extra_nfs`.
    ///
    /// Steps:
    /// - The merged nullifier list is split into punctured ranges.
    /// - Each range is hashed with `poseidon_hash_3(lo, mid, hi)` into a
    ///   32-leaf, depth-5 subtree; unused slots hold the empty-leaf hash.
    /// - The subtree root is hashed, as the left child, against the
    ///   empty-subtree hashes of levels `5..IMT_DEPTH` to produce the full root.
    ///
    /// # Panics
    ///
    /// Panics if more than 32 punctured ranges result. Also panics wherever
    /// `build_punctured_ranges_local` does.
    pub fn with_extra_nullifiers(extra_nfs: &[pallas::Base]) -> Self {
        // Height and leaf count of the populated subtree at the tree's left edge.
        const SUBTREE_DEPTH: usize = 5;
        const SUBTREE_WIDTH: usize = 1 << SUBTREE_DEPTH;

        let sorted_nfs = build_nullifier_list(extra_nfs);
        let leaves = build_punctured_ranges_local(&sorted_nfs);
        assert!(
            leaves.len() <= SUBTREE_WIDTH,
            "spaced-leaf fixture supports at most 32 leaves, got {}",
            leaves.len()
        );

        let empty = empty_imt_hashes();
        // `empty[0]` is the empty-leaf hash, poseidon_hash_3(0, 0, 0).
        let mut leaf_hashes = vec![empty[0]; SUBTREE_WIDTH];
        for (slot, [lo, mid, hi]) in leaf_hashes.iter_mut().zip(leaves.iter()) {
            *slot = poseidon_hash_3(*lo, *mid, *hi);
        }

        let mut subtree_levels = Vec::with_capacity(SUBTREE_DEPTH + 1);
        subtree_levels.push(leaf_hashes);
        for _ in 0..SUBTREE_DEPTH {
            let parents: Vec<pallas::Base> = subtree_levels
                .last()
                .expect("subtree_levels always holds the leaf level")
                .chunks_exact(2)
                .map(|pair| poseidon_hash_2(pair[0], pair[1]))
                .collect();
            subtree_levels.push(parents);
        }

        // Above the subtree, the populated node is always the left child and
        // its sibling is the empty subtree of the same height.
        let root = empty[SUBTREE_DEPTH..IMT_DEPTH]
            .iter()
            .fold(subtree_levels[SUBTREE_DEPTH][0], |node, sibling| {
                poseidon_hash_2(node, *sibling)
            });

        Self {
            root,
            leaves,
            subtree_levels,
        }
    }
}
impl ImtProvider for SpacedLeafImtProvider {
    /// Returns the root computed at construction time.
    fn root(&self) -> pallas::Base {
        self.root
    }

    /// Returns the punctured range containing `nf` and its authentication path.
    ///
    /// Siblings for levels `0..5` come from the populated subtree. Higher
    /// siblings are empty-subtree hashes, because the populated subtree is
    /// always the leftmost one.
    ///
    /// # Errors
    ///
    /// Returns [`ImtError`] when no punctured range strictly contains `nf`.
    /// This covers values equal to any of a range's three values and values
    /// outside all ranges.
    fn non_membership_proof(&self, nf: pallas::Base) -> Result<ImtProofData, ImtError> {
        // Height of the populated subtree. Siblings below it are stored,
        // siblings above it are empty.
        const SUBTREE_DEPTH: usize = 5;

        let k = match find_range_for_value(&self.leaves, nf) {
            Some(k) => k,
            None => {
                return Err(ImtError(format!(
                    "nullifier {nf:?} not in any punctured range"
                )))
            }
        };

        let empty = empty_imt_hashes();
        let mut path = [pallas::Base::zero(); IMT_DEPTH];
        for (level, sibling) in path.iter_mut().enumerate() {
            *sibling = if level < SUBTREE_DEPTH {
                // The node at `level` has index `k >> level`; flipping the
                // low bit selects its sibling.
                self.subtree_levels[level][(k >> level) ^ 1]
            } else {
                empty[level]
            };
        }

        Ok(ImtProofData {
            root: self.root,
            nf_bounds: self.leaves[k],
            leaf_pos: k as u32,
            path,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ff::PrimeField;

    /// Frozen canonical byte representation of `derive_nullifier_domain(42)`.
    const NULLIFIER_DOMAIN_42: [u8; 32] = [
        202, 12, 215, 224, 168, 199, 68, 160, 148, 160, 237, 250, 131, 157, 181, 207, 158,
        105, 141, 50, 135, 245, 182, 83, 151, 198, 14, 254, 122, 79, 78, 23,
    ];

    /// Frozen canonical byte representation of `gov_null_hash(1, 2, 3)`.
    const GOV_NULL_HASH_1_2_3: [u8; 32] = [
        234, 252, 225, 20, 190, 170, 130, 80, 54, 152, 212, 172, 198, 24, 120, 139, 100,
        140, 198, 64, 152, 34, 38, 95, 158, 62, 234, 30, 198, 66, 171, 24,
    ];

    /// Decodes a frozen canonical byte representation into a field element.
    fn base_from_repr(bytes: [u8; 32]) -> pallas::Base {
        pallas::Base::from_repr(bytes).expect("frozen vector must be canonical")
    }

    /// Shorthand for a small field element.
    fn fe(value: u64) -> pallas::Base {
        pallas::Base::from(value)
    }

    #[test]
    fn derive_nullifier_domain_frozen_vector() {
        assert_eq!(
            derive_nullifier_domain(fe(42)),
            base_from_repr(NULLIFIER_DOMAIN_42)
        );
    }

    #[test]
    fn gov_null_hash_frozen_vector() {
        assert_eq!(
            gov_null_hash(fe(1), fe(2), fe(3)),
            base_from_repr(GOV_NULL_HASH_1_2_3)
        );
    }

    #[test]
    fn find_range_for_value_rejects_punctured_interval_boundaries() {
        let ranges = [[fe(10), fe(15), fe(20)], [fe(20), fe(25), fe(30)]];

        for rejected in [9u64, 10, 15, 20, 25, 30, 31] {
            assert_eq!(
                find_range_for_value(&ranges, fe(rejected)),
                None,
                "boundary value {rejected} must not produce an IMT proof"
            );
        }

        for (accepted, expected_idx) in [(11u64, 0usize), (19, 0), (21, 1), (29, 1)] {
            assert_eq!(
                find_range_for_value(&ranges, fe(accepted)),
                Some(expected_idx)
            );
        }
    }
}