use crate::{
errors::MerkleError,
merkle_tree::{MerklePath, MerkleTreeDigest},
traits::{MerkleParameters, CRH},
};
use snarkvm_utilities::ToBytes;
/// A Merkle tree stored as a flat array, together with the hashes needed to pad it up to
/// depth `P::DEPTH`.
///
/// Node `i` of `tree` has its children at `2 * i + 1` and `2 * i + 2`; the hashed leaves
/// occupy the suffix of `tree` that starts at `hashed_leaves_index`.
#[derive(Default)]
pub struct MerkleTree<P: MerkleParameters> {
    /// Parameters used for every leaf, inner-node and empty hash of this tree.
    parameters: P,
    /// All in-memory nodes; index 0 is the root of the in-memory (unpadded) tree.
    tree: Vec<MerkleTreeDigest<P>>,
    /// Index in `tree` of the first node on the leaf level.
    hashed_leaves_index: usize,
    /// `(hash, empty sibling)` pairs for the levels above the in-memory root, excluding the
    /// final root itself.
    padding_tree: Vec<(MerkleTreeDigest<P>, MerkleTreeDigest<P>)>,
    /// Root of the full, padded tree; `None` for a `Default`-constructed tree that was never built.
    root: Option<MerkleTreeDigest<P>>,
}
impl<P: MerkleParameters> MerkleTree<P> {
pub const DEPTH: u8 = P::DEPTH as u8;
pub fn new<L: ToBytes, I: ExactSizeIterator<Item = L>>(parameters: P, leaves: I) -> Result<Self, MerkleError> {
let new_time = start_timer!(|| "MerkleTree::new");
let last_level_size = leaves.len().next_power_of_two();
let tree_size = 2 * last_level_size - 1;
let tree_depth = tree_depth(tree_size);
if tree_depth > Self::DEPTH as usize {
return Err(MerkleError::InvalidTreeDepth(tree_depth, Self::DEPTH as usize));
}
let empty_hash = parameters.hash_empty()?;
let mut tree = vec![empty_hash.clone(); tree_size];
let mut index = 0;
let mut level_indices = Vec::with_capacity(tree_depth);
for _ in 0..=tree_depth {
level_indices.push(index);
index = left_child(index);
}
let hash_input_size_in_bytes = (P::H::INPUT_SIZE_BITS / 8) * 2;
let last_level_index = level_indices.pop().unwrap_or(0);
let mut buffer = vec![0u8; hash_input_size_in_bytes];
for (i, leaf) in leaves.enumerate() {
tree[last_level_index + i] = parameters.hash_leaf(&leaf, &mut buffer)?;
}
let mut upper_bound = last_level_index;
let mut buffer = vec![0u8; hash_input_size_in_bytes];
level_indices.reverse();
for &start_index in &level_indices {
for current_index in start_index..upper_bound {
let left_index = left_child(current_index);
let right_index = right_child(current_index);
tree[current_index] = parameters.hash_inner_node(&tree[left_index], &tree[right_index], &mut buffer)?;
}
upper_bound = start_index;
}
let mut current_depth = tree_depth;
let mut padding_tree = Vec::with_capacity((Self::DEPTH as usize).saturating_sub(current_depth + 1));
let mut current_hash = tree[0].clone();
while current_depth < Self::DEPTH as usize {
current_hash = parameters.hash_inner_node(¤t_hash, &empty_hash, &mut buffer)?;
if current_depth < Self::DEPTH as usize - 1 {
padding_tree.push((current_hash.clone(), empty_hash.clone()));
}
current_depth += 1;
}
let root_hash = current_hash;
end_timer!(new_time);
Ok(MerkleTree {
tree,
padding_tree,
hashed_leaves_index: last_level_index,
parameters,
root: Some(root_hash),
})
}
pub fn rebuild<L: ToBytes, I: ExactSizeIterator<Item = L>, J: ExactSizeIterator<Item = L>>(
&mut self,
old_leaves: I,
new_leaves: J,
) -> Result<(), MerkleError> {
let new_time = start_timer!(|| "MerkleTree::rebuild");
let last_level_size = (old_leaves.len() + new_leaves.len()).next_power_of_two();
let tree_size = 2 * last_level_size - 1;
let tree_depth = tree_depth(tree_size);
if tree_depth > Self::DEPTH as usize {
return Err(MerkleError::InvalidTreeDepth(tree_depth, Self::DEPTH as usize));
}
let empty_hash = self.parameters.hash_empty()?;
let mut tree = vec![empty_hash.clone(); tree_size];
let mut index = 0;
let mut level_indices = Vec::with_capacity(tree_depth);
for _ in 0..=tree_depth {
level_indices.push(index);
index = left_child(index);
}
let new_indices = (old_leaves.len()..old_leaves.len() + new_leaves.len()).collect::<Vec<_>>();
let hash_input_size_in_bytes = (P::H::INPUT_SIZE_BITS / 8) * 2;
let last_level_index = level_indices.pop().unwrap_or(0);
let mut buffer = vec![0u8; hash_input_size_in_bytes];
tree[last_level_index..][..old_leaves.len()].clone_from_slice(&self.hashed_leaves()[..old_leaves.len()]);
for (i, leaf) in new_leaves.enumerate() {
tree[last_level_index + old_leaves.len() + i] = self.parameters.hash_leaf(&leaf, &mut buffer)?;
}
let mut upper_bound = last_level_index;
let mut buffer = vec![0u8; hash_input_size_in_bytes];
level_indices.reverse();
for &start_index in &level_indices {
for current_index in start_index..upper_bound {
let left_index = left_child(current_index);
let right_index = right_child(current_index);
if new_indices.contains(¤t_index)
|| self.tree.get(left_index) != tree.get(left_index)
|| self.tree.get(right_index) != tree.get(right_index)
|| new_indices
.iter()
.any(|&idx| Ancestors(idx).into_iter().find(|&i| i == current_index).is_some())
{
tree[current_index] =
self.parameters
.hash_inner_node(&tree[left_index], &tree[right_index], &mut buffer)?;
} else {
tree[current_index] = self.tree[current_index].clone();
}
}
upper_bound = start_index;
}
let mut current_depth = tree_depth;
let mut current_hash = tree[0].clone();
let new_padding_tree = if current_hash == self.tree[0] {
current_hash =
self.parameters
.hash_inner_node(&self.padding_tree.last().unwrap().0, &empty_hash, &mut buffer)?;
None
} else {
let mut padding_tree = Vec::with_capacity((Self::DEPTH as usize).saturating_sub(current_depth + 1));
while current_depth < Self::DEPTH as usize {
current_hash = self
.parameters
.hash_inner_node(¤t_hash, &empty_hash, &mut buffer)?;
if current_depth < Self::DEPTH as usize - 1 {
padding_tree.push((current_hash.clone(), empty_hash.clone()));
}
current_depth += 1;
}
Some(padding_tree)
};
let root_hash = current_hash;
end_timer!(new_time);
self.root = Some(root_hash);
self.tree = tree;
self.hashed_leaves_index = last_level_index;
if let Some(padding_tree) = new_padding_tree {
self.padding_tree = padding_tree;
}
Ok(())
}
#[inline]
pub fn root(&self) -> <P::H as CRH>::Output {
self.root.clone().unwrap()
}
#[inline]
pub fn tree(&self) -> &[<P::H as CRH>::Output] {
&self.tree
}
#[inline]
pub fn hashed_leaves(&self) -> &[<P::H as CRH>::Output] {
&self.tree[self.hashed_leaves_index..]
}
pub fn generate_proof<L: ToBytes>(&self, index: usize, leaf: &L) -> Result<MerklePath<P>, MerkleError> {
let prove_time = start_timer!(|| "MerkleTree::generate_proof");
let mut path = vec![];
let hash_input_size_in_bytes = (P::H::INPUT_SIZE_BITS / 8) * 2;
let mut buffer = vec![0u8; hash_input_size_in_bytes];
let leaf_hash = self.parameters.hash_leaf(leaf, &mut buffer)?;
let tree_depth = tree_depth(self.tree.len());
let tree_index = convert_index_to_last_level(index, tree_depth);
if leaf_hash != self.tree[tree_index] {
return Err(MerkleError::IncorrectLeafIndex(tree_index));
}
let mut current_node = tree_index;
while !is_root(current_node) {
let sibling_node = sibling(current_node).unwrap();
let (curr_hash, sibling_hash) = (self.tree[current_node].clone(), self.tree[sibling_node].clone());
if is_left_child(current_node) {
path.push((curr_hash, sibling_hash));
} else {
path.push((sibling_hash, curr_hash));
}
current_node = parent(current_node).unwrap();
}
if path.len() > Self::DEPTH as usize {
return Err(MerkleError::InvalidPathLength(path.len(), Self::DEPTH as usize));
}
if path.len() != Self::DEPTH as usize {
let empty_hash = self.parameters.hash_empty()?;
path.push((self.tree[0].clone(), empty_hash));
for &(ref hash, ref sibling_hash) in &self.padding_tree {
path.push((hash.clone(), sibling_hash.clone()));
}
}
end_timer!(prove_time);
if path.len() != Self::DEPTH as usize {
Err(MerkleError::IncorrectPathLength(path.len()))
} else {
Ok(MerklePath {
parameters: self.parameters.clone(),
path,
})
}
}
}
/// Returns `floor(log2(tree_size))`, the depth of a complete binary tree with `tree_size`
/// nodes (a root-only tree has depth 0). A size of `0` maps to `0`.
///
/// This uses integer bit operations rather than `f64::log2`. Converting a large `2^k - 1` to
/// `f64` rounds it up to `2^k`, so the float version would return `k` instead of `k - 1`.
#[inline]
fn tree_depth(tree_size: usize) -> usize {
    if tree_size == 0 {
        0
    } else {
        (usize::BITS - 1 - tree_size.leading_zeros()) as usize
    }
}
/// Returns `true` when `index` addresses the root slot of the array-backed tree.
#[inline]
fn is_root(index: usize) -> bool {
    matches!(index, 0)
}
/// Array index of the left child of node `index`.
#[inline]
fn left_child(index: usize) -> usize {
    let doubled = index * 2;
    doubled + 1
}
/// Array index of the right child of node `index` (one past its left child).
#[inline]
fn right_child(index: usize) -> usize {
    2 * (index + 1)
}
/// Array index of the node sharing a parent with `index`, or `None` for the root.
#[inline]
fn sibling(index: usize) -> Option<usize> {
    match index {
        0 => None,
        left if is_left_child(left) => Some(left + 1),
        right => Some(right - 1),
    }
}
/// Returns `true` when `index` is a left child; left children sit at odd array indices.
#[inline]
fn is_left_child(index: usize) -> bool {
    (index & 1) == 1
}
/// Array index of the parent of node `index`, or `None` when `index` is the root.
#[inline]
fn parent(index: usize) -> Option<usize> {
    index.checked_sub(1).map(|shifted| shifted >> 1)
}
/// Maps leaf position `index` to its array index on the last level of a tree of depth
/// `tree_depth`. That level starts at array index `2^tree_depth - 1`.
#[inline]
fn convert_index_to_last_level(index: usize, tree_depth: usize) -> usize {
    let level_width = 1usize << tree_depth;
    index + level_width - 1
}
/// Iterator over the ancestors of an array-backed tree node.
///
/// Iteration starts at the node's parent and climbs until the root (index 0) has been yielded.
/// The starting index itself is never yielded.
pub struct Ancestors(usize);

impl Iterator for Ancestors {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let up = parent(self.0)?;
        self.0 = up;
        Some(up)
    }
}