use ark_crypto_primitives::{Error, CRH};
use ark_ff::{to_bytes, PrimeField, ToBytes};
use ark_std::{
borrow::Borrow,
collections::{BTreeMap, BTreeSet},
io::{Result as IoResult, Write},
rc::Rc,
vec::Vec,
};
use core::convert::TryInto;
#[cfg(feature = "r1cs")]
pub mod constraints;
pub mod simple_merkle;
/// Compile-time configuration of a Merkle tree: its height and the two hash
/// functions used to build it.
pub trait Config: Clone + PartialEq {
/// Number of hashing steps between a leaf and the root. Leaves are stored
/// starting at node index `2^HEIGHT - 1`, and a membership path has `HEIGHT`
/// entries.
const HEIGHT: u8;
/// Hash used to combine a left and a right child into their parent node.
type H: CRH;
/// Hash used to turn leaf data into a leaf node.
type LeafH: CRH;
}
/// Digest type produced by the inner-node hash `P::H`.
type InnerNode<P> = <<P as Config>::H as CRH>::Output;
/// Digest type produced by the leaf hash `P::LeafH`.
type LeafNode<P> = <<P as Config>::LeafH as CRH>::Output;
/// Parameter type of the inner-node hash `P::H`.
type InnerParameters<P> = <<P as Config>::H as CRH>::Parameters;
/// Parameter type of the leaf hash `P::LeafH`.
type LeafParameters<P> = <<P as Config>::LeafH as CRH>::Parameters;
/// A node of the tree, holding the digest of either a leaf or an inner node.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum Node<P: Config> {
/// Output of the leaf hash `P::LeafH`.
Leaf(LeafNode<P>),
/// Output of the inner-node hash `P::H`.
Inner(InnerNode<P>),
}
impl<P: Config> Node<P> {
    /// Consumes the node and returns the wrapped inner-node digest.
    ///
    /// # Panics
    ///
    /// Panics with `"Not inner node!"` if the node is a [`Node::Leaf`].
    pub fn inner(self) -> InnerNode<P> {
        match self {
            Node::Inner(digest) => digest,
            Node::Leaf(_) => panic!("Not inner node!"),
        }
    }

    /// Consumes the node and returns the wrapped leaf digest.
    ///
    /// # Panics
    ///
    /// Panics with `"Not leaf node!"` if the node is a [`Node::Inner`].
    pub fn leaf(self) -> LeafNode<P> {
        match self {
            Node::Leaf(digest) => digest,
            Node::Inner(_) => panic!("Not leaf node!"),
        }
    }
}
impl<P: Config> ToBytes for Node<P> {
fn write<W: Write>(&self, writer: W) -> IoResult<()> {
match self {
Self::Inner(inner) => inner.write(writer),
Self::Leaf(leaf) => leaf.write(writer),
}
}
}
/// A fixed-length Merkle authentication path of `N` sibling pairs, bundled
/// with the hash parameters needed to re-hash it.
#[derive(Clone)]
pub struct Path<P: Config, const N: usize> {
/// `(left, right)` child pairs, one per level, ordered from the leaf level
/// (index 0) up to the children of the root.
pub(crate) path: [(Node<P>, Node<P>); N],
/// Parameters for the leaf hash `P::LeafH`.
leaf_params: Rc<LeafParameters<P>>,
/// Parameters for the inner-node hash `P::H`.
inner_params: Rc<InnerParameters<P>>,
}
impl<P: Config + PartialEq, const N: usize> Path<P, N> {
pub fn check_membership<L: ToBytes>(
&self,
root_hash: &Node<P>,
leaf: &L,
) -> Result<bool, Error> {
let prev = self.root_hash(leaf)?;
Ok(root_hash == &prev)
}
pub fn get_index<L: ToBytes, F: PrimeField>(
&self,
root_hash: &Node<P>,
leaf: &L,
) -> Result<F, Error> {
if !self.check_membership(root_hash, leaf)? {
panic!("Leaf is not in the path");
}
let mut prev = hash_leaf::<P, L>(self.leaf_params.borrow(), leaf)?;
let mut index = F::zero();
let mut twopower = F::one();
for &(ref left_hash, ref right_hash) in &self.path {
if &prev != left_hash {
index += twopower;
}
twopower = twopower + twopower;
prev = hash_inner_node::<P>(self.inner_params.borrow(), left_hash, right_hash)?;
}
Ok(index)
}
pub fn root_hash<L: ToBytes>(&self, leaf: &L) -> Result<Node<P>, Error> {
if self.path.len() != P::HEIGHT as usize {
panic!("path.len != P::HEIGHT");
}
let claimed_leaf_hash = hash_leaf::<P, L>(self.leaf_params.borrow(), leaf)?;
if claimed_leaf_hash != self.path[0].0 && claimed_leaf_hash != self.path[0].1 {
panic!("Leaf is not on Path");
}
let mut prev = claimed_leaf_hash;
for &(ref left_hash, ref right_hash) in &self.path {
if &prev != left_hash && &prev != right_hash {
panic!("Path nodes are not consistent");
}
prev = hash_inner_node::<P>(self.inner_params.borrow(), left_hash, right_hash)?;
}
Ok(prev)
}
}
pub struct SparseMerkleTree<P: Config> {
pub tree: BTreeMap<u64, Node<P>>,
empty_hashes: Vec<Node<P>>,
leaf_params: Rc<<P::LeafH as CRH>::Parameters>,
inner_params: Rc<<P::H as CRH>::Parameters>,
}
impl<P: Config> SparseMerkleTree<P> {
    /// Creates a tree with no leaves set.
    ///
    /// # Panics
    ///
    /// Panics if the empty-subtree hashes cannot be computed.
    pub fn blank(inner_params: Rc<InnerParameters<P>>, leaf_params: Rc<LeafParameters<P>>) -> Self {
        let empty_hashes = gen_empty_hashes::<P>(leaf_params.borrow(), inner_params.borrow())
            .expect("hashing the empty subtrees failed");

        SparseMerkleTree {
            tree: BTreeMap::new(),
            empty_hashes,
            inner_params,
            leaf_params,
        }
    }

    /// Inserts `leaves` (keyed by leaf position) and rehashes every ancestor
    /// of an inserted leaf up to the root.
    ///
    /// # Errors
    ///
    /// Returns an error if a leaf position is not below `2^HEIGHT` (the tree
    /// is left untouched in that case) or if hashing fails.
    pub fn insert_batch<L: Default + ToBytes>(
        &mut self,
        leaves: &BTreeMap<u32, L>,
    ) -> Result<(), Error> {
        let capacity: u64 = 1u64 << P::HEIGHT;
        // `BTreeMap` keys iterate in ascending order, so the last key is the
        // largest. Validate before mutating anything.
        if let Some(&largest) = leaves.keys().next_back() {
            if u64::from(largest) >= capacity {
                return Err("leaf index is outside the tree (must be below 2^HEIGHT)".into());
            }
        }

        // Index of the first node on the leaf level.
        let last_level_index = capacity - 1;

        let mut level_idxs: BTreeSet<u64> = BTreeSet::new();
        for (i, leaf) in leaves {
            let true_index = last_level_index + u64::from(*i);
            let leaf_hash = hash_leaf::<P, _>(self.leaf_params.borrow(), leaf)?;
            self.tree.insert(true_index, leaf_hash);
            // With `HEIGHT == 0` the single leaf is the root and has no parent.
            if let Some(parent_index) = parent(true_index) {
                level_idxs.insert(parent_index);
            }
        }

        for level in 0..P::HEIGHT {
            // Children missing from the map are empty subtrees of this level.
            let empty_hash = &self.empty_hashes[usize::from(level)];
            let mut next_idxs: BTreeSet<u64> = BTreeSet::new();
            for i in level_idxs {
                let left = self.tree.get(&left_child(i)).unwrap_or(empty_hash);
                let right = self.tree.get(&right_child(i)).unwrap_or(empty_hash);
                // Finish hashing before inserting so the shared borrows of
                // `self.tree` have ended by the time it is mutated.
                let node = hash_inner_node::<P>(self.inner_params.borrow(), left, right)?;
                self.tree.insert(i, node);

                match parent(i) {
                    Some(parent_index) => {
                        next_idxs.insert(parent_index);
                    }
                    // The root has just been written; nothing is left above it.
                    None => break,
                }
            }
            level_idxs = next_idxs;
        }

        Ok(())
    }

    /// Builds a tree containing `leaves` (keyed by leaf position).
    ///
    /// # Errors
    ///
    /// Returns an error if a leaf position is not below `2^HEIGHT` or if
    /// hashing fails.
    ///
    /// # Panics
    ///
    /// Panics if `leaves` holds more entries than the tree has leaf slots.
    pub fn new<L: Default + ToBytes>(
        inner_params: Rc<InnerParameters<P>>,
        leaf_params: Rc<LeafParameters<P>>,
        leaves: &BTreeMap<u32, L>,
    ) -> Result<Self, Error> {
        let last_level_size = leaves.len().next_power_of_two();
        let tree_size = 2 * last_level_size - 1;
        let levels = tree_height(tree_size as u64);
        // For two or more leaf slots `tree_height` yields the number of node
        // levels (edges + 1) while `P::HEIGHT` counts edges, hence the `+ 1`.
        assert!(levels <= u32::from(P::HEIGHT) + 1);

        let empty_hashes = gen_empty_hashes::<P>(leaf_params.borrow(), inner_params.borrow())?;
        let mut smt = SparseMerkleTree {
            tree: BTreeMap::new(),
            empty_hashes,
            inner_params,
            leaf_params,
        };
        smt.insert_batch(leaves)?;

        Ok(smt)
    }

    /// Builds a tree whose first `leaves.len()` leaf positions hold `leaves`
    /// in order.
    ///
    /// # Errors
    ///
    /// See [`SparseMerkleTree::new`].
    ///
    /// # Panics
    ///
    /// See [`SparseMerkleTree::new`].
    pub fn new_sequential<L: Default + ToBytes + Clone>(
        inner_params: Rc<InnerParameters<P>>,
        leaf_params: Rc<LeafParameters<P>>,
        leaves: &[L],
    ) -> Result<Self, Error> {
        let pairs: BTreeMap<u32, L> = leaves
            .iter()
            .enumerate()
            .map(|(i, l)| (i as u32, l.clone()))
            .collect();

        Self::new(inner_params, leaf_params, &pairs)
    }

    /// Returns the root of the tree.
    ///
    /// A tree without any inserted leaf has no stored root; the precomputed
    /// hash of a completely empty tree is returned in that case.
    #[inline]
    pub fn root(&self) -> Node<P> {
        self.tree
            .get(&0)
            .unwrap_or(&self.empty_hashes[usize::from(P::HEIGHT)])
            .clone()
    }

    /// Builds the authentication path for the leaf at position `index`.
    ///
    /// Nodes that were never inserted are represented by the empty-subtree
    /// hash of their level.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not below `2^HEIGHT` or if `N` differs from the
    /// number of levels between the leaf and the root.
    pub fn generate_membership_proof<const N: usize>(&self, index: u64) -> Path<P, N> {
        assert!(
            index < (1u64 << P::HEIGHT),
            "leaf index is outside the tree (must be below 2^HEIGHT)"
        );

        let mut path = Vec::with_capacity(N);
        let mut current_node = convert_index_to_last_level::<P>(index);
        let mut level = 0usize;

        // Only the root has neither a sibling nor a parent, so this walks from
        // the leaf up to (but excluding) the root.
        while let (Some(sibling_node), Some(parent_node)) =
            (sibling(current_node), parent(current_node))
        {
            let empty_hash = &self.empty_hashes[level];
            let current = self.tree.get(&current_node).unwrap_or(empty_hash).clone();
            let neighbour = self.tree.get(&sibling_node).unwrap_or(empty_hash).clone();

            // Pairs are always stored as (left, right).
            if is_left_child(current_node) {
                path.push((current, neighbour));
            } else {
                path.push((neighbour, current));
            }

            current_node = parent_node;
            level += 1;
        }

        Path {
            path: path
                .try_into()
                .unwrap_or_else(|v: Vec<(Node<P>, Node<P>)>| {
                    panic!("Expected a Vec of length {} but it was {}", N, v.len())
                }),
            inner_params: Rc::clone(&self.inner_params),
            leaf_params: Rc::clone(&self.leaf_params),
        }
    }
}
/// Returns the ceiling of the base-2 logarithm of `number`, with `0` mapped
/// to `0`.
///
/// Computed directly on `u64` so the argument is never narrowed to `usize`,
/// which would drop the high bits on 32-bit targets.
#[inline]
fn log2(number: u64) -> u32 {
    if number <= 1 {
        0
    } else {
        // For n >= 2, ceil(log2(n)) is the bit length of n - 1.
        64 - (number - 1).leading_zeros()
    }
}
/// Returns `log2(tree_size)`, where `tree_size` is a node count.
///
/// `SparseMerkleTree::new` uses the result to bound the requested tree
/// against `P::HEIGHT`.
#[inline]
fn tree_height(tree_size: u64) -> u32 {
log2(tree_size)
}
/// Returns `true` if `index` is the root of the tree, which is node 0.
#[inline]
fn is_root(index: u64) -> bool {
    matches!(index, 0)
}
/// Returns the index of the left child of node `index`.
#[inline]
fn left_child(index: u64) -> u64 {
    let doubled = index * 2;
    doubled + 1
}
/// Returns the index of the right child of node `index`.
#[inline]
fn right_child(index: u64) -> u64 {
    (index + 1) * 2
}
/// Returns the index of the node sharing a parent with `index`, or `None`
/// for the root, which has no sibling.
#[inline]
fn sibling(index: u64) -> Option<u64> {
    match index {
        0 => None,
        left if is_left_child(left) => Some(left + 1),
        right => Some(right - 1),
    }
}
/// Returns `true` if `index` is a left child; left children have odd indices.
#[inline]
fn is_left_child(index: u64) -> bool {
    index & 1 == 1
}
/// Returns the index of the parent of `index`, or `None` for the root.
#[inline]
fn parent(index: u64) -> Option<u64> {
    // `checked_sub` yields `None` exactly for the root (index 0).
    index.checked_sub(1).map(|below| below >> 1)
}
/// Converts a leaf position (0-based, left to right) into the tree index of
/// that leaf; the leaf level starts at index `2^HEIGHT - 1`.
#[inline]
fn convert_index_to_last_level<P: Config>(index: u64) -> u64 {
    let leaf_slots = 1u64 << P::HEIGHT;
    index + leaf_slots - 1
}
/// Hashes the serialized `left` and `right` children with `P::H` and wraps
/// the digest as an inner node.
///
/// # Errors
///
/// Returns an error if serializing the children or evaluating the hash fails.
pub(crate) fn hash_inner_node<P: Config>(
    parameters: &<P::H as CRH>::Parameters,
    left: &Node<P>,
    right: &Node<P>,
) -> Result<Node<P>, Error> {
    let preimage = to_bytes![left, right]?;
    <P::H as CRH>::evaluate(parameters, &preimage).map(Node::Inner)
}
/// Hashes the serialized `leaf` with `P::LeafH` and wraps the digest as a
/// leaf node.
///
/// # Errors
///
/// Returns an error if serializing the leaf or evaluating the hash fails.
fn hash_leaf<P: Config, L: ToBytes>(
    parameters: &<P::LeafH as CRH>::Parameters,
    leaf: &L,
) -> Result<Node<P>, Error> {
    let leaf_bytes = to_bytes![leaf]?;
    <P::LeafH as CRH>::evaluate(parameters, &leaf_bytes).map(Node::Leaf)
}
/// Computes the leaf node that stands for an empty leaf: the `P::LeafH` hash
/// of `INPUT_SIZE_BITS / 8` zero bytes.
///
/// # Errors
///
/// Returns an error if evaluating the hash fails.
fn hash_empty<P: Config>(parameters: &<P::LeafH as CRH>::Parameters) -> Result<Node<P>, Error> {
    let zero_input = vec![0u8; <P::LeafH as CRH>::INPUT_SIZE_BITS / 8];
    let digest = <P::LeafH as CRH>::evaluate(parameters, &zero_input)?;
    Ok(Node::Leaf(digest))
}
/// Computes the hash of an empty subtree for every level of the tree.
///
/// The result has `P::HEIGHT + 1` entries: entry 0 is the empty leaf and
/// entry `k` is the inner hash of two copies of entry `k - 1`, so the last
/// entry is the root of a completely empty tree.
///
/// # Errors
///
/// Returns an error if hashing fails.
pub fn gen_empty_hashes<P: Config>(
    leaf_params: &LeafParameters<P>,
    inner_params: &InnerParameters<P>,
) -> Result<Vec<Node<P>>, Error> {
    // One entry per level, from the leaves (0) to the root (HEIGHT).
    let levels = P::HEIGHT as usize + 1;
    let mut empty_hashes = Vec::with_capacity(levels);
    empty_hashes.push(hash_empty::<P>(leaf_params)?);

    for level in 1..levels {
        let below = &empty_hashes[level - 1];
        let empty_hash = hash_inner_node::<P>(inner_params, below, below)?;
        empty_hashes.push(empty_hash);
    }

    Ok(empty_hashes)
}
#[cfg(test)]
mod test {
    use super::{gen_empty_hashes, hash_inner_node, hash_leaf, Config, SparseMerkleTree};
    use crate::mimc::Rounds as MiMCRounds;
    use crate::poseidon::CRH as PoseidonCRH;
    use ark_bls12_381::Fq;
    use ark_crypto_primitives::crh::CRH;
    use ark_ed_on_bn254::Fq as Bn254Fq;
    use ark_ff::{ToBytes, UniformRand};
    use ark_std::{borrow::Borrow, collections::BTreeMap, rc::Rc, test_rng};
    use arkworks_utils::{
        mimc::MiMCParameters,
        utils::common::{setup_params_x5_3, Curve},
    };

    type SMTCRH = PoseidonCRH<Fq>;

    /// Height-3 tree hashed with Poseidon over the BLS12-381 base field.
    #[derive(Clone, Debug, Eq, PartialEq)]
    struct SMTConfig;
    impl Config for SMTConfig {
        type H = SMTCRH;
        type LeafH = SMTCRH;

        const HEIGHT: u8 = 3;
    }

    #[derive(Default, Clone)]
    struct MiMCRounds220_2;
    impl crate::mimc::Rounds for MiMCRounds220_2 {
        const ROUNDS: usize = 220;
        const WIDTH: usize = 2;
    }

    type MiMC220_2 = crate::mimc::CRH<Bn254Fq, MiMCRounds220_2>;

    /// Height-3 tree hashed with MiMC-220 over the field aliased as `Bn254Fq`.
    #[derive(Clone, Debug, Eq, PartialEq)]
    struct MiMCSMTConfig;
    impl Config for MiMCSMTConfig {
        type H = MiMC220_2;
        type LeafH = MiMC220_2;

        const HEIGHT: u8 = 3;
    }

    /// Poseidon parameters shared by the leaf and inner hashes in these tests.
    fn poseidon_params() -> Rc<<SMTCRH as CRH>::Parameters> {
        Rc::new(setup_params_x5_3(Curve::Bls381))
    }

    /// MiMC parameters shared by the leaf and inner hashes in these tests.
    fn mimc_params() -> Rc<<MiMC220_2 as CRH>::Parameters> {
        Rc::new(MiMCParameters::<Bn254Fq>::new(
            Bn254Fq::from(0),
            MiMCRounds220_2::ROUNDS,
            MiMCRounds220_2::WIDTH,
            MiMCRounds220_2::WIDTH,
            arkworks_utils::utils::get_rounds_mimc_220(),
        ))
    }

    /// Builds a tree holding `leaves` at positions `0..leaves.len()`.
    fn create_merkle_tree<L: Default + ToBytes + Copy, C: Config>(
        inner_params: Rc<<C::H as CRH>::Parameters>,
        leaf_params: Rc<<C::LeafH as CRH>::Parameters>,
        leaves: &[L],
    ) -> SparseMerkleTree<C> {
        let pairs: BTreeMap<u32, L> = leaves
            .iter()
            .copied()
            .enumerate()
            .map(|(i, leaf)| (i as u32, leaf))
            .collect();

        SparseMerkleTree::<C>::new(inner_params, leaf_params, &pairs).unwrap()
    }

    #[test]
    fn should_create_tree_poseidon() {
        let rng = &mut test_rng();
        let inner_params = poseidon_params();
        let leaf_params = inner_params.clone();

        let leaves = vec![Fq::rand(rng), Fq::rand(rng), Fq::rand(rng)];
        let smt = create_merkle_tree::<_, SMTConfig>(
            inner_params.clone(),
            leaf_params.clone(),
            &leaves,
        );
        let root = smt.root();

        // Rebuild the expected root by hand: three leaves, one empty leaf and
        // one empty level-2 subtree.
        let empty_hashes =
            gen_empty_hashes::<SMTConfig>(leaf_params.borrow(), inner_params.borrow()).unwrap();
        let hash1 = hash_leaf::<SMTConfig, _>(leaf_params.borrow(), &leaves[0]).unwrap();
        let hash2 = hash_leaf::<SMTConfig, _>(leaf_params.borrow(), &leaves[1]).unwrap();
        let hash3 = hash_leaf::<SMTConfig, _>(leaf_params.borrow(), &leaves[2]).unwrap();

        let hash12 = hash_inner_node::<SMTConfig>(inner_params.borrow(), &hash1, &hash2).unwrap();
        let hash34 =
            hash_inner_node::<SMTConfig>(inner_params.borrow(), &hash3, &empty_hashes[0]).unwrap();
        let hash1234 =
            hash_inner_node::<SMTConfig>(inner_params.borrow(), &hash12, &hash34).unwrap();
        let calc_root =
            hash_inner_node::<SMTConfig>(inner_params.borrow(), &hash1234, &empty_hashes[2])
                .unwrap();

        assert_eq!(root, calc_root);
    }

    #[test]
    fn should_generate_and_validate_proof_poseidon() {
        let rng = &mut test_rng();
        let inner_params = poseidon_params();
        let leaf_params = inner_params.clone();

        let leaves = vec![Fq::rand(rng), Fq::rand(rng), Fq::rand(rng)];
        let smt = create_merkle_tree::<_, SMTConfig>(inner_params, leaf_params, &leaves);

        let proof = smt.generate_membership_proof::<{ SMTConfig::HEIGHT as usize }>(0);
        let res = proof.check_membership(&smt.root(), &leaves[0]).unwrap();

        assert!(res);
    }

    #[test]
    fn should_find_the_index_poseidon() {
        let rng = &mut test_rng();
        let inner_params = poseidon_params();
        let leaf_params = inner_params.clone();

        let index = 2;
        let leaves = vec![Fq::rand(rng), Fq::rand(rng), Fq::rand(rng)];
        let smt = create_merkle_tree::<_, SMTConfig>(inner_params, leaf_params, &leaves);

        let proof = smt.generate_membership_proof::<{ SMTConfig::HEIGHT as usize }>(index);
        let res: Fq = proof
            .get_index(&smt.root(), &leaves[index as usize])
            .unwrap();
        let desired_res = Fq::from(index);

        assert_eq!(res, desired_res)
    }

    #[test]
    fn should_create_tree_mimc() {
        let rng = &mut test_rng();
        let inner_params = mimc_params();
        let leaf_params = inner_params.clone();

        let leaves = vec![Bn254Fq::rand(rng), Bn254Fq::rand(rng), Bn254Fq::rand(rng)];
        let smt = create_merkle_tree::<_, MiMCSMTConfig>(
            inner_params.clone(),
            leaf_params.clone(),
            &leaves,
        );
        let root = smt.root();

        // Rebuild the expected root by hand: three leaves, one empty leaf and
        // one empty level-2 subtree.
        let empty_hashes =
            gen_empty_hashes::<MiMCSMTConfig>(leaf_params.borrow(), inner_params.borrow()).unwrap();
        let hash1 = hash_leaf::<MiMCSMTConfig, _>(leaf_params.borrow(), &leaves[0]).unwrap();
        let hash2 = hash_leaf::<MiMCSMTConfig, _>(leaf_params.borrow(), &leaves[1]).unwrap();
        let hash3 = hash_leaf::<MiMCSMTConfig, _>(leaf_params.borrow(), &leaves[2]).unwrap();

        let hash12 =
            hash_inner_node::<MiMCSMTConfig>(inner_params.borrow(), &hash1, &hash2).unwrap();
        let hash34 =
            hash_inner_node::<MiMCSMTConfig>(inner_params.borrow(), &hash3, &empty_hashes[0])
                .unwrap();
        let hash1234 =
            hash_inner_node::<MiMCSMTConfig>(inner_params.borrow(), &hash12, &hash34).unwrap();
        let calc_root =
            hash_inner_node::<MiMCSMTConfig>(inner_params.borrow(), &hash1234, &empty_hashes[2])
                .unwrap();

        assert_eq!(root, calc_root);
    }

    #[test]
    fn should_generate_and_validate_proof_mimc() {
        let rng = &mut test_rng();
        let inner_params = mimc_params();
        let leaf_params = inner_params.clone();

        let leaves = vec![Bn254Fq::rand(rng), Bn254Fq::rand(rng), Bn254Fq::rand(rng)];
        let smt = create_merkle_tree::<_, MiMCSMTConfig>(inner_params, leaf_params, &leaves);

        let proof = smt.generate_membership_proof::<{ MiMCSMTConfig::HEIGHT as usize }>(0);
        let res = proof.check_membership(&smt.root(), &leaves[0]).unwrap();

        assert!(res);
    }

    #[test]
    fn should_find_the_index_mimc() {
        let rng = &mut test_rng();
        let inner_params = mimc_params();
        let leaf_params = inner_params.clone();

        let index = 2;
        let leaves = vec![Bn254Fq::rand(rng), Bn254Fq::rand(rng), Bn254Fq::rand(rng)];
        let smt = create_merkle_tree::<_, MiMCSMTConfig>(inner_params, leaf_params, &leaves);

        let proof = smt.generate_membership_proof::<{ MiMCSMTConfig::HEIGHT as usize }>(index);
        // The recovered index lives in the same field as the MiMC leaves.
        let res: Bn254Fq = proof
            .get_index(&smt.root(), &leaves[index as usize])
            .unwrap();
        let desired_res = Bn254Fq::from(index);

        assert_eq!(res, desired_res)
    }
}