#![allow(clippy::needless_range_loop)]
use crate::crh::TwoToOneCRHScheme;
use crate::{crh::CRHScheme, Error};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::borrow::Borrow;
use ark_std::hash::Hash;
use ark_std::vec::Vec;
#[cfg(test)]
mod tests;
#[cfg(feature = "r1cs")]
pub mod constraints;
/// Converts a digest produced by one hashing layer into the input type consumed by the next.
///
/// `From` is the digest type being converted (note: this generic parameter shadows the
/// prelude `From` trait inside this definition) and `To` is the possibly-unsized type the
/// next layer accepts. The conversion yields an owned `TargetType` that can be borrowed as `To`.
pub trait DigestConverter<From, To: ?Sized> {
    /// Owned result of the conversion, borrowable as `To`.
    type TargetType: Borrow<To>;
    /// Converts `item`, returning an `Error` if the conversion fails.
    fn convert(item: From) -> Result<Self::TargetType, Error>;
}
/// Digest converter whose conversion hands its input back unchanged.
pub struct IdentityDigestConverter<T> {
    // Never read; presumably only present to tie the converter type to the digest type `T`.
    _prev_layer_digest: T,
}
impl<T> DigestConverter<T, T> for IdentityDigestConverter<T> {
type TargetType = T;
fn convert(item: T) -> Result<T, Error> {
Ok(item)
}
}
/// Digest converter that turns a serializable digest into its uncompressed byte encoding.
pub struct ByteDigestConverter<T: CanonicalSerialize> {
    // Never read; presumably only present to tie the converter type to the digest type `T`.
    _prev_layer_digest: T,
}
impl<T: CanonicalSerialize> DigestConverter<T, [u8]> for ByteDigestConverter<T> {
    type TargetType = Vec<u8>;

    /// Serializes `item` in uncompressed form; any error reported by the serialization
    /// macro is propagated (converted into `Error`).
    fn convert(item: T) -> Result<Self::TargetType, Error> {
        let encoded: Vec<u8> = crate::to_uncompressed_bytes!(item)?;
        Ok(encoded)
    }
}
/// Compile-time configuration of a Merkle tree: the leaf type, the digest types of the leaf
/// and inner levels, the hashes that produce them, and the converter bridging the two levels.
pub trait Config {
    /// Data stored at each leaf; may be unsized.
    type Leaf: ?Sized;

    /// Output of [`Self::LeafHash`]: the digest kept for every leaf.
    type LeafDigest: CanonicalSerialize
        + CanonicalDeserialize
        + Clone
        + Default
        + Eq
        + Hash
        + core::fmt::Debug;

    /// Hash applied to each leaf value.
    type LeafHash: CRHScheme<Input = Self::Leaf, Output = Self::LeafDigest>;

    /// Output of [`Self::TwoToOneHash`]: the digest kept for every inner node.
    type InnerDigest: CanonicalSerialize
        + CanonicalDeserialize
        + Clone
        + Default
        + Eq
        + Hash
        + core::fmt::Debug;

    /// Hash that combines a left and a right child into their parent's digest.
    type TwoToOneHash: TwoToOneCRHScheme<Output = Self::InnerDigest>;

    /// Turns a leaf digest into the input type of [`Self::TwoToOneHash`], needed when the
    /// bottom-most inner nodes are computed from pairs of leaf digests.
    type LeafInnerDigestConverter: DigestConverter<
        Self::LeafDigest,
        <Self::TwoToOneHash as TwoToOneCRHScheme>::Input,
    >;
}
/// Parameters of the two-to-one (inner node) hash configured by `P`.
pub type TwoToOneParam<P> = <<P as Config>::TwoToOneHash as TwoToOneCRHScheme>::Parameters;
/// Parameters of the leaf hash configured by `P`.
pub type LeafParam<P> = <<P as Config>::LeafHash as CRHScheme>::Parameters;
/// Authentication path showing that a leaf is part of a Merkle tree.
///
/// Built by `MerkleTree::generate_proof` and checked by `Path::verify`.
/// NOTE(review): the derived (de)serialization presumably follows field declaration order —
/// do not reorder fields without checking wire compatibility.
#[derive(Derivative, CanonicalSerialize, CanonicalDeserialize)]
#[derivative(
    Clone(bound = "P: Config"),
    Debug(bound = "P: Config"),
    Default(bound = "P: Config")
)]
pub struct Path<P: Config> {
    /// Digest of the leaf paired with the proven leaf on the bottom level.
    pub leaf_sibling_hash: P::LeafDigest,
    /// Sibling digests of the leaf's inner ancestors (root excluded), stored top to bottom:
    /// element 0 is next to the root, the last element is the sibling of the leaf's parent.
    pub auth_path: Vec<P::InnerDigest>,
    /// Position of the proven leaf among the leaves, counting from 0 on the left.
    pub leaf_index: usize,
}
impl<P: Config> Path<P> {
    /// Yields one bit of `leaf_index` per level covered by the path (`auth_path.len() + 1`
    /// bits), highest bit first. A `true` bit marks a level where the path node is the right
    /// child, matching the parity test in `select_left_right_child`.
    // NOTE(review): not called in this file; presumably used by child modules
    // (tests/constraints) — confirm before removing the `allow`.
    #[allow(unused)]
    fn position_list(&'_ self) -> impl '_ + Iterator<Item = bool> {
        let levels = self.auth_path.len() + 1;
        (0..levels)
            .rev()
            .map(move |bit| ((self.leaf_index >> bit) & 1) == 1)
    }
}
impl<P: Config> Path<P> {
    /// Checks that this path authenticates `leaf` at position `self.leaf_index` under
    /// `root_hash`.
    ///
    /// The leaf is hashed with the leaf hash and paired with `leaf_sibling_hash`. The pair is
    /// converted into the two-to-one hash's input type and hashed. The result is then folded
    /// upward through `auth_path`, which is stored top to bottom and therefore walked in
    /// reverse.
    ///
    /// Returns `Ok(false)` if the recomputed root differs from `root_hash`. It also returns
    /// `Ok(false)` if `leaf_index` does not fit in a tree of the depth implied by `auth_path`.
    /// Without that check, the index bits above that depth would be silently ignored. A proof
    /// for leaf `i` would then also be accepted for index `i + k * 2^(auth_path.len() + 1)`.
    ///
    /// # Errors
    /// Propagates any error from the leaf hash, the two-to-one hash, or the digest converter.
    pub fn verify<L: Borrow<P::Leaf>>(
        &self,
        leaf_hash_params: &LeafParam<P>,
        two_to_one_params: &TwoToOneParam<P>,
        root_hash: &P::InnerDigest,
        leaf: L,
    ) -> Result<bool, crate::Error> {
        // The leaf sibling plus `auth_path.len()` inner siblings cover
        // `auth_path.len() + 1` levels, i.e. a tree with 2^(auth_path.len() + 1) leaves.
        // Guard the shift so absurdly deep paths cannot trigger a shift-overflow panic.
        let levels = self.auth_path.len() + 1;
        let index_in_range = levels >= usize::BITS as usize || (self.leaf_index >> levels) == 0;
        if !index_in_range {
            return Ok(false);
        }

        // Bottom level: hash the claimed leaf and combine it with its sibling.
        let claimed_leaf_hash = P::LeafHash::evaluate(leaf_hash_params, leaf)?;
        let (left_child, right_child) =
            select_left_right_child(self.leaf_index, &claimed_leaf_hash, &self.leaf_sibling_hash)?;
        // Leaf digests must be converted before the two-to-one hash can consume them.
        let left_child = P::LeafInnerDigestConverter::convert(left_child)?;
        let right_child = P::LeafInnerDigestConverter::convert(right_child)?;
        let mut curr_path_node =
            P::TwoToOneHash::evaluate(two_to_one_params, left_child, right_child)?;

        // Walk up the inner levels; each right shift exposes the position bit of the next level.
        let mut index = self.leaf_index >> 1;
        for sibling_hash in self.auth_path.iter().rev() {
            let (left, right) = select_left_right_child(index, &curr_path_node, sibling_hash)?;
            curr_path_node = P::TwoToOneHash::compress(two_to_one_params, &left, &right)?;
            index >>= 1;
        }

        Ok(&curr_path_node == root_hash)
    }
}
/// Orders `computed_hash` and `sibling_hash` as a `(left, right)` pair of owned clones.
///
/// The lowest bit of `index` decides the order: an even index puts `computed_hash` on the
/// left, an odd index puts it on the right. This function never returns `Err`; the `Result`
/// presumably keeps call sites uniform with the surrounding fallible hashing code.
fn select_left_right_child<L: Clone>(
    index: usize,
    computed_hash: &L,
    sibling_hash: &L,
) -> Result<(L, L), crate::Error> {
    let (left, right) = if index % 2 == 0 {
        (computed_hash, sibling_hash)
    } else {
        (sibling_hash, computed_hash)
    };
    Ok((left.clone(), right.clone()))
}
/// Merkle tree over `P::Leaf` values that keeps every leaf digest and every inner digest.
///
/// Nodes are numbered breadth-first with the root at 0 and the children of node `i` at
/// `2i + 1` and `2i + 2`. Inner nodes are stored in that numbering in `non_leaf_nodes`. The
/// leaf digests live separately in `leaf_nodes`, where leaf `j` corresponds to node
/// `j + 2^(height - 1) - 1`.
#[derive(Derivative)]
#[derivative(Clone(bound = "P: Config"))]
pub struct MerkleTree<P: Config> {
    /// Digests of all inner nodes, root first.
    non_leaf_nodes: Vec<P::InnerDigest>,
    /// Digests of the leaves, left to right.
    leaf_nodes: Vec<P::LeafDigest>,
    /// Copy of the two-to-one hash parameters, reused when leaves are updated.
    two_to_one_hash_param: TwoToOneParam<P>,
    /// Copy of the leaf hash parameters, reused when leaves are updated.
    leaf_hash_param: LeafParam<P>,
    /// Number of levels including the leaf level and the root (`log2(#leaves) + 1`).
    height: usize,
}
impl<P: Config> MerkleTree<P> {
pub fn blank(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
height: usize,
) -> Result<Self, crate::Error> {
let leaves_digest = vec![P::LeafDigest::default(); 1 << (height - 1)];
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digest)
}
pub fn new<L: Borrow<P::Leaf>>(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
leaves: impl IntoIterator<Item = L>,
) -> Result<Self, crate::Error> {
let mut leaves_digests = Vec::new();
for leaf in leaves.into_iter() {
leaves_digests.push(P::LeafHash::evaluate(leaf_hash_param, leaf)?)
}
Self::new_with_leaf_digest(leaf_hash_param, two_to_one_hash_param, leaves_digests)
}
pub fn new_with_leaf_digest(
leaf_hash_param: &LeafParam<P>,
two_to_one_hash_param: &TwoToOneParam<P>,
leaves_digest: Vec<P::LeafDigest>,
) -> Result<Self, crate::Error> {
let leaf_nodes_size = leaves_digest.len();
assert!(
leaf_nodes_size.is_power_of_two() && leaf_nodes_size > 1,
"`leaves.len() should be power of two and greater than one"
);
let non_leaf_nodes_size = leaf_nodes_size - 1;
let tree_height = tree_height(leaf_nodes_size);
let hash_of_empty: P::InnerDigest = P::InnerDigest::default();
let mut non_leaf_nodes: Vec<P::InnerDigest> = (0..non_leaf_nodes_size)
.map(|_| hash_of_empty.clone())
.collect();
let mut index = 0;
let mut level_indices = Vec::with_capacity(tree_height - 1);
for _ in 0..(tree_height - 1) {
level_indices.push(index);
index = left_child(index);
}
{
let start_index = level_indices.pop().unwrap();
let upper_bound = left_child(start_index);
for current_index in start_index..upper_bound {
let left_leaf_index = left_child(current_index) - upper_bound;
let right_leaf_index = right_child(current_index) - upper_bound;
non_leaf_nodes[current_index] = P::TwoToOneHash::evaluate(
&two_to_one_hash_param,
P::LeafInnerDigestConverter::convert(leaves_digest[left_leaf_index].clone())?,
P::LeafInnerDigestConverter::convert(leaves_digest[right_leaf_index].clone())?,
)?
}
}
level_indices.reverse();
for &start_index in &level_indices {
let upper_bound = left_child(start_index);
for current_index in start_index..upper_bound {
let left_index = left_child(current_index);
let right_index = right_child(current_index);
non_leaf_nodes[current_index] = P::TwoToOneHash::compress(
&two_to_one_hash_param,
non_leaf_nodes[left_index].clone(),
non_leaf_nodes[right_index].clone(),
)?
}
}
Ok(MerkleTree {
leaf_nodes: leaves_digest,
non_leaf_nodes,
height: tree_height,
leaf_hash_param: leaf_hash_param.clone(),
two_to_one_hash_param: two_to_one_hash_param.clone(),
})
}
pub fn root(&self) -> P::InnerDigest {
self.non_leaf_nodes[0].clone()
}
pub fn height(&self) -> usize {
self.height
}
pub fn generate_proof(&self, index: usize) -> Result<Path<P>, crate::Error> {
let tree_height = tree_height(self.leaf_nodes.len());
let leaf_index_in_tree = convert_index_to_last_level(index, tree_height);
let leaf_sibling_hash = if index & 1 == 0 {
self.leaf_nodes[index + 1].clone()
} else {
self.leaf_nodes[index - 1].clone()
};
let mut path = Vec::with_capacity(tree_height - 2);
let mut current_node = parent(leaf_index_in_tree).unwrap();
while !is_root(current_node) {
let sibling_node = sibling(current_node).unwrap();
path.push(self.non_leaf_nodes[sibling_node].clone());
current_node = parent(current_node).unwrap();
}
debug_assert_eq!(path.len(), tree_height - 2);
path.reverse();
Ok(Path {
leaf_index: index,
auth_path: path,
leaf_sibling_hash,
})
}
fn updated_path<T: Borrow<P::Leaf>>(
&self,
index: usize,
new_leaf: T,
) -> Result<(P::LeafDigest, Vec<P::InnerDigest>), crate::Error> {
let new_leaf_hash: P::LeafDigest = P::LeafHash::evaluate(&self.leaf_hash_param, new_leaf)?;
let (leaf_left, leaf_right) = if index & 1 == 0 {
(&new_leaf_hash, &self.leaf_nodes[index + 1])
} else {
(&self.leaf_nodes[index - 1], &new_leaf_hash)
};
let mut path_bottom_to_top = Vec::with_capacity(self.height - 1);
{
path_bottom_to_top.push(P::TwoToOneHash::evaluate(
&self.two_to_one_hash_param,
P::LeafInnerDigestConverter::convert(leaf_left.clone())?,
P::LeafInnerDigestConverter::convert(leaf_right.clone())?,
)?);
}
let leaf_index_in_tree = convert_index_to_last_level(index, self.height);
let mut prev_index = parent(leaf_index_in_tree).unwrap();
while !is_root(prev_index) {
let (left_child, right_child) = if is_left_child(prev_index) {
(
path_bottom_to_top.last().unwrap(),
&self.non_leaf_nodes[sibling(prev_index).unwrap()],
)
} else {
(
&self.non_leaf_nodes[sibling(prev_index).unwrap()],
path_bottom_to_top.last().unwrap(),
)
};
let evaluated =
P::TwoToOneHash::compress(&self.two_to_one_hash_param, left_child, right_child)?;
path_bottom_to_top.push(evaluated);
prev_index = parent(prev_index).unwrap();
}
debug_assert_eq!(path_bottom_to_top.len(), self.height - 1);
let path_top_to_bottom: Vec<_> = path_bottom_to_top.into_iter().rev().collect();
Ok((new_leaf_hash, path_top_to_bottom))
}
pub fn update(&mut self, index: usize, new_leaf: &P::Leaf) -> Result<(), crate::Error> {
assert!(index < self.leaf_nodes.len(), "index out of range");
let (updated_leaf_hash, mut updated_path) = self.updated_path(index, new_leaf)?;
self.leaf_nodes[index] = updated_leaf_hash;
let mut curr_index = convert_index_to_last_level(index, self.height);
for _ in 0..self.height - 1 {
curr_index = parent(curr_index).unwrap();
self.non_leaf_nodes[curr_index] = updated_path.pop().unwrap();
}
Ok(())
}
pub fn check_update<T: Borrow<P::Leaf>>(
&mut self,
index: usize,
new_leaf: &P::Leaf,
asserted_new_root: &P::InnerDigest,
) -> Result<bool, crate::Error> {
let new_leaf = new_leaf.borrow();
assert!(index < self.leaf_nodes.len(), "index out of range");
let (updated_leaf_hash, mut updated_path) = self.updated_path(index, new_leaf)?;
if &updated_path[0] != asserted_new_root {
return Ok(false);
}
self.leaf_nodes[index] = updated_leaf_hash;
let mut curr_index = convert_index_to_last_level(index, self.height);
for _ in 0..self.height - 1 {
curr_index = parent(curr_index).unwrap();
self.non_leaf_nodes[curr_index] = updated_path.pop().unwrap();
}
Ok(true)
}
}
/// Number of levels (leaf level and root included) of a tree with `num_leaves` leaves:
/// `ceil(log2(num_leaves)) + 1`, with a single leaf counting as one level.
#[inline]
fn tree_height(num_leaves: usize) -> usize {
    match num_leaves {
        1 => 1,
        n => ark_std::log2(n) as usize + 1,
    }
}
/// Returns `true` when `index` is the root, which is node 0 in breadth-first numbering.
#[inline]
fn is_root(index: usize) -> bool {
    matches!(index, 0)
}
/// Breadth-first index of the left child of node `index`.
#[inline]
fn left_child(index: usize) -> usize {
    let doubled = index * 2;
    doubled + 1
}
/// Breadth-first index of the right child of node `index`.
#[inline]
fn right_child(index: usize) -> usize {
    let doubled = index * 2;
    doubled + 2
}
/// Breadth-first index of the node sharing a parent with `index`, or `None` for the root.
#[inline]
fn sibling(index: usize) -> Option<usize> {
    if index == 0 {
        return None;
    }
    // Left children sit at odd indices with their right sibling immediately after them.
    let is_left = index % 2 == 1;
    Some(if is_left { index + 1 } else { index - 1 })
}
/// Returns `true` when node `index` is a left child; in breadth-first numbering those are
/// exactly the odd indices.
#[inline]
fn is_left_child(index: usize) -> bool {
    (index & 1) == 1
}
/// Breadth-first index of the parent of node `index`, or `None` for the root.
#[inline]
fn parent(index: usize) -> Option<usize> {
    index.checked_sub(1).map(|prev| prev >> 1)
}
/// Maps leaf position `index` to its breadth-first node index in a tree of `tree_height`
/// levels. The leaves follow the `2^(tree_height - 1) - 1` inner nodes.
#[inline]
fn convert_index_to_last_level(index: usize, tree_height: usize) -> usize {
    let leaf_count = 1 << (tree_height - 1);
    index + leaf_count - 1
}