use ff::{Field, PrimeField};
use halo2curves::bn256::Fr;
use super::poseidon::poseidon_domain_pair;
/// Number of levels above the leaves; the tree addresses `2^FROZEN_IMT_DEPTH` leaf positions.
pub const FROZEN_IMT_DEPTH: usize = 20;
/// Domain tag passed to `poseidon_domain_pair` when hashing a leaf `(val, next_val)` (decimal 240).
pub const FROZEN_IMT_LEAF_DOMAIN: u64 = 0xF0;
/// Domain tag for internal nodes built from level-0 children; children at level `i` are
/// combined under `FROZEN_IMT_NODE_D0 + i` (decimal 256).
pub const FROZEN_IMT_NODE_D0: u64 = 0x100;
/// Hashes one tree leaf: `poseidon_domain_pair` over `(val, next_val)` under the
/// leaf domain tag [`FROZEN_IMT_LEAF_DOMAIN`].
#[inline]
pub fn frozen_imt_leaf(val: Fr, next_val: Fr) -> Fr {
    let domain = FROZEN_IMT_LEAF_DOMAIN;
    poseidon_domain_pair(domain, val, next_val)
}
/// Strict `a < b` on the integer values encoded by `to_repr()`.
///
/// The representation is treated as little-endian, so the byte strings are compared
/// lexicographically starting from the last (most significant) byte.
fn fr_lt(a: &Fr, b: &Fr) -> bool {
    let (lhs, rhs) = (a.to_repr(), b.to_repr());
    lhs.as_ref().iter().rev().lt(rhs.as_ref().iter().rev())
}
/// Decodes a big-endian 32-byte string into a field element.
///
/// The bytes are reversed into the little-endian layout handed to `Fr::from_repr`;
/// returns `None` whenever `from_repr` rejects the encoding.
pub fn fr_from_be_bytes(be: &[u8; 32]) -> Option<Fr> {
    let mut le = [0u8; 32];
    for (dst, &src) in le.iter_mut().zip(be.iter().rev()) {
        *dst = src;
    }
    Fr::from_repr(le.into()).into()
}
/// Returns the 32-byte `to_repr()` encoding of `fr` (the layout this file treats as little-endian).
pub fn fr_to_le_bytes(fr: Fr) -> [u8; 32] {
    let repr = fr.to_repr();
    repr.into()
}
/// Returns the big-endian 32-byte encoding of `fr`: [`fr_to_le_bytes`] with the byte order reversed.
pub fn fr_to_be_bytes(fr: Fr) -> [u8; 32] {
    let le = fr_to_le_bytes(fr);
    let mut out = [0u8; 32];
    for (dst, &src) in out.iter_mut().zip(le.iter().rev()) {
        *dst = src;
    }
    out
}
/// Formats the little-endian byte encoding of `fr` as lowercase hex with a `0x` prefix.
pub fn fr_to_le_hex(fr: Fr) -> String {
    let digits = hex::encode(fr_to_le_bytes(fr));
    let mut out = String::with_capacity(2 + digits.len());
    out.push_str("0x");
    out.push_str(&digits);
    out
}
/// One tree leaf: a stored value plus a link to the next-larger stored value.
#[derive(Clone, Copy, Debug)]
struct Leaf {
    /// Value held by this leaf (`Fr::ZERO` for the sentinel leaf at position 0).
    val: Fr,
    /// Next-larger value present in the tree, or `Fr::ZERO` when `val` is the largest.
    next_val: Fr,
}
/// Proof that a value is absent from a [`FrozenImt`].
///
/// The low leaf `(low_val, low_next_val)` brackets the queried value
/// (`low_val < v` and, unless `low_next_val` is zero, `v < low_next_val`).
/// `siblings` and `path_bits` authenticate that leaf against the tree root.
#[derive(Clone, Debug)]
pub struct FrozenNonMembershipWitness {
    /// Value of the bracketing (low) leaf.
    pub low_val: Fr,
    /// `next_val` of the low leaf; `Fr::ZERO` means the low leaf holds the largest value.
    pub low_next_val: Fr,
    /// Hash of the sibling subtree at each level, starting at the leaf level.
    pub siblings: [Fr; FROZEN_IMT_DEPTH],
    /// Bit `i` of the low leaf's position as `Fr::ZERO` / `Fr::ONE`; `ONE` means the
    /// running node is the right child at level `i`.
    pub path_bits: [Fr; FROZEN_IMT_DEPTH],
}
/// Append-only indexed Merkle tree of "frozen" field elements.
///
/// Leaves are stored in insertion order, and a leaf's index is its position in the
/// Merkle tree. Leaf 0 is the `(0, 0)` sentinel created by `FrozenImt::new`. Every leaf
/// also links to the next-larger value, forming the sorted list that non-membership
/// witnesses rely on.
#[derive(Clone, Debug)]
pub struct FrozenImt {
    /// Leaves in tree-position order; index 0 is always the sentinel.
    leaves: Vec<Leaf>,
}
impl FrozenImt {
pub fn new() -> Self {
Self { leaves: vec![Leaf { val: Fr::ZERO, next_val: Fr::ZERO }] }
}
pub fn from_frozen_values(values: &[Fr]) -> Self {
let mut t = Self::new();
for &v in values {
t.insert(v);
}
t
}
pub fn len(&self) -> usize {
self.leaves.len()
}
pub fn frozen_values(&self) -> Vec<Fr> {
self.leaves[1..].iter().map(|l| l.val).collect()
}
pub fn contains(&self, v: Fr) -> bool {
self.leaves.iter().any(|l| l.val == v)
}
pub fn insert(&mut self, v: Fr) -> bool {
if v == Fr::ZERO || self.contains(v) {
return false;
}
let pred = self.bracketing_index(v);
let new_leaf = Leaf { val: v, next_val: self.leaves[pred].next_val };
self.leaves[pred].next_val = v;
self.leaves.push(new_leaf);
true
}
fn bracketing_index(&self, v: Fr) -> usize {
for (i, l) in self.leaves.iter().enumerate() {
let above_low = fr_lt(&l.val, &v);
let below_high = l.next_val == Fr::ZERO || fr_lt(&v, &l.next_val);
if above_low && below_high {
return i;
}
}
0 }
#[inline]
fn leaf_hash(&self, i: usize) -> Fr {
frozen_imt_leaf(self.leaves[i].val, self.leaves[i].next_val)
}
fn empty_at(&self, level: usize) -> Fr {
let mut e = Fr::ZERO;
for i in 0..level {
e = poseidon_domain_pair(FROZEN_IMT_NODE_D0 + i as u64, e, e);
}
e
}
fn subtree_hash(&self, level: usize, idx: usize) -> Fr {
let start = idx << level;
if start >= self.leaves.len() {
return self.empty_at(level);
}
if level == 0 {
return self.leaf_hash(start);
}
let left = self.subtree_hash(level - 1, idx * 2);
let right = self.subtree_hash(level - 1, idx * 2 + 1);
poseidon_domain_pair(FROZEN_IMT_NODE_D0 + (level - 1) as u64, left, right)
}
pub fn root(&self) -> Fr {
self.subtree_hash(FROZEN_IMT_DEPTH, 0)
}
fn witness_at(&self, pos: usize) -> ([Fr; FROZEN_IMT_DEPTH], [Fr; FROZEN_IMT_DEPTH]) {
let mut siblings = [Fr::ZERO; FROZEN_IMT_DEPTH];
let mut path_bits = [Fr::ZERO; FROZEN_IMT_DEPTH];
for level in 0..FROZEN_IMT_DEPTH {
path_bits[level] = if (pos >> level) & 1 == 1 { Fr::ONE } else { Fr::ZERO };
siblings[level] = self.subtree_hash(level, (pos >> level) ^ 1);
}
(siblings, path_bits)
}
pub fn non_membership_witness(&self, cmx: Fr) -> Option<FrozenNonMembershipWitness> {
if self.contains(cmx) {
return None;
}
let pos = self.bracketing_index(cmx);
let low = self.leaves[pos];
let (siblings, path_bits) = self.witness_at(pos);
Some(FrozenNonMembershipWitness {
low_val: low.val,
low_next_val: low.next_val,
siblings,
path_bits,
})
}
}
impl Default for FrozenImt {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The empty tree's root must equal the perc20 reference constant (decimal).
    #[test]
    fn empty_root_matches_perc20_constant() {
        const DEC: &str =
            "9079151408671112139333676443195611613776084922747126087146403043120709007371";
        let expected = Fr::from_str_vartime(DEC).unwrap();
        assert_eq!(FrozenImt::new().root(), expected);
    }

    /// A witness for an interior non-member brackets it and reproduces the root.
    #[test]
    fn witness_reproduces_root_and_brackets() {
        let mut t = FrozenImt::new();
        for v in [50u64, 10, 99, 7] {
            assert!(t.insert(Fr::from(v)));
        }
        let root = t.root();
        let cmx = Fr::from(42u64);
        let w = t.non_membership_witness(cmx).expect("non-member");
        assert!(fr_lt(&w.low_val, &cmx));
        assert!(w.low_next_val == Fr::ZERO || fr_lt(&cmx, &w.low_next_val));
        assert_eq!(recompute_root(&w), root, "witness must reproduce rt_frozen");
    }

    /// A value above every frozen value is bracketed by the tail leaf (`next_val == 0`).
    #[test]
    fn witness_above_max_uses_tail_leaf() {
        let mut t = FrozenImt::new();
        for v in [10u64, 20] {
            assert!(t.insert(Fr::from(v)));
        }
        let w = t.non_membership_witness(Fr::from(30u64)).expect("non-member");
        assert_eq!(w.low_val, Fr::from(20u64));
        assert_eq!(w.low_next_val, Fr::ZERO);
        assert_eq!(recompute_root(&w), t.root());
    }

    #[test]
    fn frozen_value_has_no_witness() {
        let mut t = FrozenImt::new();
        t.insert(Fr::from(123u64));
        assert!(t.non_membership_witness(Fr::from(123u64)).is_none());
    }

    /// Zero is held by the sentinel leaf, so it can never be proven absent.
    #[test]
    fn zero_has_no_witness() {
        assert!(FrozenImt::new().non_membership_witness(Fr::ZERO).is_none());
    }

    #[test]
    fn insert_rejects_zero_and_duplicates() {
        let mut t = FrozenImt::new();
        assert!(!t.insert(Fr::ZERO));
        assert!(t.insert(Fr::from(5u64)));
        assert!(!t.insert(Fr::from(5u64)));
        assert_eq!(t.len(), 2);
        assert_eq!(t.frozen_values(), vec![Fr::from(5u64)]);
    }

    #[test]
    fn rebuild_from_values_reproduces_root() {
        let mut t = FrozenImt::new();
        for v in [3u64, 1, 4, 1, 5, 9, 2, 6] {
            t.insert(Fr::from(v));
        }
        let rebuilt = FrozenImt::from_frozen_values(&t.frozen_values());
        assert_eq!(rebuilt.frozen_values(), t.frozen_values());
        assert_eq!(rebuilt.root(), t.root());
    }

    /// Byte helpers agree on endianness, round-trip, and reject non-canonical input.
    #[test]
    fn byte_helpers_round_trip_and_reject_non_canonical() {
        let x = Fr::from(0x0102u64);
        let be = fr_to_be_bytes(x);
        assert_eq!(&be[30..], &[0x01u8, 0x02]);
        assert!(be[..30].iter().all(|&b| b == 0));
        assert_eq!(fr_from_be_bytes(&be), Some(x));
        // 2^256 - 1 exceeds the field modulus, so it is not a canonical encoding.
        assert_eq!(fr_from_be_bytes(&[0xff; 32]), None);
        assert_eq!(fr_to_le_hex(Fr::ONE), format!("0x01{}", "00".repeat(31)));
    }

    /// Folds a witness back up to a root, mirroring the in-circuit check.
    ///
    /// For `bit == 1` the running node becomes the right child (left = sibling);
    /// for `bit == 0` it stays the left child.
    fn recompute_root(w: &FrozenNonMembershipWitness) -> Fr {
        let mut level = frozen_imt_leaf(w.low_val, w.low_next_val);
        for i in 0..FROZEN_IMT_DEPTH {
            let bit = w.path_bits[i];
            let diff = level - w.siblings[i];
            let left = level - bit * diff;
            let right = w.siblings[i] + bit * diff;
            level = poseidon_domain_pair(FROZEN_IMT_NODE_D0 + i as u64, left, right);
        }
        level
    }
}