mod helpers;
pub use helpers::*;
mod path;
pub use path::*;
#[cfg(test)]
mod tests;
use snarkvm_console_types::prelude::*;
use aleo_std::prelude::*;
use std::ops::Range;
/// A Merkle tree in which every internal node has `ARITY` children, padded up to a fixed `DEPTH`.
///
/// The populated part of the tree is stored in one flat vector (`tree`): internal nodes first,
/// with index 0 as their root, followed by the leaf hashes. The separately stored `root` also
/// accounts for any padding levels between the populated tree and `DEPTH`.
#[derive(Clone)]
pub struct KaryMerkleTree<LH, PH, const DEPTH: u8, const ARITY: u8>
where
    LH: LeafHash<Hash = PH::Hash>,
    PH: PathHash,
{
    /// Hasher applied to individual leaves.
    leaf_hasher: LH,
    /// Hasher applied to groups of `ARITY` child hashes.
    path_hasher: PH,
    /// Final root hash, including any padding levels up to `DEPTH`.
    root: PH::Hash,
    /// Flattened node hashes: internal nodes first (index 0 first), then the leaf hashes.
    tree: Vec<PH::Hash>,
    /// Hash used for absent leaves and for padding siblings.
    empty_hash: PH::Hash,
    /// Number of leaves the tree was built from.
    number_of_leaves: usize,
}
/// Returns the smallest power of `n` (starting from `n^0 = 1`) that is greater than or equal
/// to `base`.
///
/// Returns `None` when `n <= 1`, because the powers would never grow. Also returns `None` when
/// every power of `n` that fits in a `usize` is still smaller than `base`.
fn checked_next_power_of_n(base: usize, n: usize) -> Option<usize> {
    if n <= 1 {
        return None;
    }
    // Enumerate 1, n, n^2, ... until the next multiplication would overflow,
    // and take the first power that reaches `base`.
    std::iter::successors(Some(1usize), |&power| power.checked_mul(n)).find(|&power| power >= base)
}
impl<LH: LeafHash<Hash = PH::Hash>, PH: PathHash, const DEPTH: u8, const ARITY: u8>
KaryMerkleTree<LH, PH, DEPTH, ARITY>
{
/// Builds a `DEPTH`-deep, `ARITY`-ary Merkle tree over `leaves`.
///
/// Layout of the flat `tree` vector:
/// - Internal nodes occupy `0..num_nodes`, with index 0 as their root.
/// - Leaf hashes start at `num_nodes`.
/// - The children of node `i` are `i * ARITY + 1 .. i * ARITY + 1 + ARITY`.
///
/// If the populated tree is shallower than `DEPTH`, its root is re-hashed once per missing
/// level. Each time, the root is the left-most input and the other `ARITY - 1` inputs are the
/// empty hash.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `DEPTH` is 0 or greater than 64.
/// - `ARITY` is less than 2.
/// - `ARITY^DEPTH` overflows a `u128`.
/// - Computing the leaf capacity overflows a `usize`.
/// - The leaves would need more than `DEPTH` levels.
/// - Any hasher call fails.
#[inline]
pub fn new(leaf_hasher: &LH, path_hasher: &PH, leaves: &[LH::Leaf]) -> Result<Self> {
let timer = timer!("MerkleTree::new");
// Validate the const parameters before doing any hashing.
ensure!(DEPTH > 0, "Merkle tree depth must be greater than 0");
ensure!(DEPTH <= 64u8, "Merkle tree depth must be less than or equal to 64");
ensure!(ARITY > 1, "Merkle tree arity must be greater than 1");
ensure!((ARITY as u128).checked_pow(DEPTH as u32).is_some(), "Merkle tree size overflowed");
// Leaf capacity of the populated tree: the smallest power of ARITY that is >= leaves.len()
// (this is 1 when there are no leaves).
let Some(max_leaves) = checked_next_power_of_n(leaves.len(), ARITY as usize) else {
bail!("Integer overflow when computing the maximum number of leaves in the Merkle tree");
};
// Number of internal nodes in a complete ARITY-ary tree with `max_leaves` leaves:
// 1 + ARITY + ... + ARITY^(d-1) = (max_leaves - 1) / (ARITY - 1).
let num_nodes = (max_leaves - 1) / (ARITY as usize - 1);
let tree_size = max_leaves + num_nodes;
// Number of populated levels above the leaves; errors if it would exceed DEPTH.
let tree_depth = tree_depth::<DEPTH, ARITY>(tree_size)?;
// Extra levels to hash on top of the populated root to reach DEPTH.
let padding_depth = DEPTH - tree_depth;
let empty_hash = path_hasher.hash_empty::<ARITY>()?;
let arity = ARITY as usize;
// Only allocate up to the last leaf group. The leaf count is rounded up to a multiple of
// ARITY, so every allocated parent has all ARITY children stored. At least one slot is kept
// so that index 0 (the root) always exists.
let all_nodes_are_full = leaves.len() % arity == 0;
let minimum_tree_size = std::cmp::max(
1,
num_nodes + leaves.len() + if all_nodes_are_full { 0 } else { arity - leaves.len() % arity },
);
// Every slot starts as the empty hash. This also fills the padding leaves of the last group.
let mut tree = vec![empty_hash; minimum_tree_size];
// Assumes `hash_leaves` returns exactly one hash per leaf; `clone_from_slice` panics on a
// length mismatch.
tree[num_nodes..num_nodes + leaves.len()].clone_from_slice(&leaf_hasher.hash_leaves(leaves)?);
lap!(timer, "Hashed {} leaves", leaves.len());
// Hash bottom-up, one level per iteration, until the level starting at index 0 is done.
let mut start_index = num_nodes;
// `start` is the first index of the level above the level that begins at `start_index`.
while let Some(start) = parent::<ARITY>(start_index) {
// The first child of this level's first node is the first index of the level below.
// So `start..end` spans exactly the current level.
let end = child_indexes::<ARITY>(start).next().ok_or_else(|| anyhow!("Missing left-most child"))?;
// Collect the child slices of nodes whose left-most child lies inside `tree`.
// Child indexes grow with the node index, so scanning stops at the first node whose
// children were never allocated. Every later node on this level is childless too.
let child_nodes = (start..end)
.take_while(|&i| child_indexes::<ARITY>(i).next().and_then(|idx| tree.get(idx)).is_some())
.map(|i| &tree[child_indexes::<ARITY>(i)])
.collect::<Vec<_>>();
let num_full_nodes = child_nodes.len();
let hashes = path_hasher.hash_all_children(&child_nodes)?;
tree[start..][..num_full_nodes].clone_from_slice(&hashes);
// Childless nodes get the hash of ARITY empty hashes. Above the lowest internal level,
// every child index is < num_nodes and therefore inside `tree`. So this branch can only
// fire for the parents of the leaves.
if start + num_full_nodes < end {
let empty_node_hash = path_hasher.hash_children(&vec![empty_hash; arity])?;
for node in tree.iter_mut().take(end).skip(start + num_full_nodes) {
*node = empty_node_hash;
}
}
start_index = start;
}
lap!(timer, "Hashed {} levels", tree_depth);
// Extend the populated root up to DEPTH. At each padding level the current root is the
// left-most input and the remaining ARITY - 1 inputs are the empty hash.
let mut root_hash = tree[0];
for _ in 0..padding_depth {
let mut input = Vec::with_capacity(ARITY as usize);
input.push(root_hash);
input.resize(ARITY as usize, empty_hash);
root_hash = path_hasher.hash_children(&input)?;
}
lap!(timer, "Hashed {} padding levels", padding_depth);
finish!(timer);
Ok(Self {
leaf_hasher: leaf_hasher.clone(),
path_hasher: path_hasher.clone(),
root: root_hash,
tree,
empty_hash,
number_of_leaves: leaves.len(),
})
}
/// Returns the Merkle path for `leaf` stored at `leaf_index`.
///
/// The path has `DEPTH` entries. Each entry holds the hashes of the node's `ARITY - 1`
/// siblings, from the leaf level upward. Entries for the padding levels above the populated
/// tree consist of the empty hash, mirroring how `new` pads the root.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `leaf_index` is out of bounds.
/// - `leaf` does not hash to the stored leaf hash.
/// - Hashing fails or the start-index computation overflows.
/// - The path cannot be converted into a `KaryMerklePath`.
#[inline]
pub fn prove(&self, leaf_index: usize, leaf: &LH::Leaf) -> Result<KaryMerklePath<PH, DEPTH, ARITY>> {
ensure!(leaf_index < self.number_of_leaves, "The given Merkle leaf index is out of bounds");
let leaf_hash = self.leaf_hasher.hash_leaf(leaf)?;
// Recompute where the leaf hashes begin (the same formula as `num_nodes` in `new`).
let start = match checked_next_power_of_n(self.number_of_leaves, ARITY as usize) {
Some(num_leaves) => (num_leaves - 1) / (ARITY as usize - 1),
None => bail!("Integer overflow when computing the Merkle tree start index"),
};
let mut index = start + leaf_index;
ensure!(index < self.tree.len(), "The given Merkle leaf index is out of bounds");
ensure!(self.tree[index] == leaf_hash, "The given Merkle leaf does not match the one in the Merkle tree");
// Walk from the leaf toward index 0, recording the sibling hashes at each level.
let mut path = Vec::with_capacity(DEPTH as usize);
for _ in 0..DEPTH {
// `siblings` returns `None` only for index 0. In that case nothing is recorded and the
// remaining iterations are no-ops.
if let Some(siblings) = siblings::<ARITY>(index) {
let sibling_hashes = siblings.map(|index| self.tree[index]).collect::<Vec<_>>();
path.push(sibling_hashes);
// NOTE(review): `siblings` returned `Some`, so `index > 0` here and `parent` always
// returns `Some`. The `break` arm appears unreachable.
match parent::<ARITY>(index) {
Some(parent) => index = parent,
None => break,
}
}
}
// Pad the path to DEPTH levels with ARITY - 1 empty hashes per missing level.
if path.len() != DEPTH as usize {
let empty_hashes = (0..ARITY.saturating_sub(1)).map(|_| self.empty_hash).collect::<Vec<_>>();
path.resize(DEPTH as usize, empty_hashes);
}
KaryMerklePath::try_from((leaf_index as u64, path))
}
/// Checks `path` for `leaf` against the caller-supplied `root`.
///
/// Delegates to `KaryMerklePath::verify` with this tree's leaf and path hashers. `root` is
/// not compared with `self.root()`.
pub fn verify(&self, path: &KaryMerklePath<PH, DEPTH, ARITY>, root: &PH::Hash, leaf: &LH::Leaf) -> bool {
path.verify(&self.leaf_hasher, &self.path_hasher, root, leaf)
}
/// Returns the root hash, including any padding levels up to `DEPTH`.
pub const fn root(&self) -> &PH::Hash {
&self.root
}
/// Returns the flat node-hash vector: internal nodes first, then the leaf hashes, then any
/// empty-hash padding of the last leaf group.
pub fn tree(&self) -> &[PH::Hash] {
&self.tree
}
/// Returns the hash used for empty leaves and padding siblings.
pub const fn empty_hash(&self) -> &PH::Hash {
&self.empty_hash
}
/// Returns the number of leaves the tree was built from.
pub const fn number_of_leaves(&self) -> usize {
self.number_of_leaves
}
}
/// Returns the depth of a tree with `tree_size` nodes, i.e. `floor(log_ARITY(tree_size))`.
///
/// The logarithm is computed with exact integer arithmetic. Dividing floating-point natural
/// logarithms can round just below an integer at exact powers (e.g. `ln(1000) / ln(10)`
/// evaluates to `2.9999999999999996`). It can also round up for sizes just below a large
/// power, which would mis-report the depth.
///
/// A `tree_size` of 0 or 1 yields a depth of 0.
///
/// # Errors
/// Returns an error in any of these cases:
/// - `tree_size` does not fit in a `u64`, or is not less than `2^52`.
/// - `ARITY` is less than 2 (the logarithm is undefined).
/// - The resulting depth exceeds `DEPTH`.
#[inline]
fn tree_depth<const DEPTH: u8, const ARITY: u8>(tree_size: usize) -> Result<u8> {
    let tree_size = u64::try_from(tree_size)?;
    // Kept as a sanity bound on the tree size.
    ensure!(tree_size < 4503599627370496_u64, "Tree size must be less than 2^52");
    // A base below 2 has no meaningful logarithm, and the loop below would never terminate.
    ensure!(ARITY > 1, "Merkle tree arity must be greater than 1");

    // floor(log_ARITY(tree_size)): count how many times `tree_size` can be divided by ARITY
    // before it drops below ARITY. With tree_size < 2^52 this is at most 51, so it fits in a u8.
    let arity = u64::from(ARITY);
    let mut remaining = tree_size;
    let mut tree_depth = 0u8;
    while remaining >= arity {
        remaining /= arity;
        tree_depth += 1;
    }

    match tree_depth <= DEPTH {
        true => Ok(tree_depth),
        false => bail!("Merkle tree cannot exceed depth {DEPTH}: attempted to reach depth {tree_depth}"),
    }
}
/// Returns the index range of the `ARITY` children of node `index` in the flat tree layout.
///
/// The children start at `index * ARITY + 1`.
fn child_indexes<const ARITY: u8>(index: usize) -> Range<usize> {
    let arity = usize::from(ARITY);
    let first_child = arity * index + 1;
    Range { start: first_child, end: first_child + arity }
}
/// Returns the indexes of the nodes that share a parent with `index`, excluding `index` itself.
///
/// Returns `None` for the root (index 0), which has no siblings.
#[inline]
fn siblings<const ARITY: u8>(index: usize) -> Option<impl Iterator<Item = usize>> {
    // The root has no parent, and therefore no siblings.
    if index == 0 {
        return None;
    }
    let arity = usize::from(ARITY);
    // Children of one parent occupy an aligned run of ARITY slots starting at `parent * ARITY + 1`.
    let first_sibling = (index - 1) / arity * arity + 1;
    Some((first_sibling..first_sibling + arity).filter(move |&candidate| candidate != index))
}
/// Returns `true` if `index` is the root slot (index 0) of the flat tree layout.
#[inline]
const fn is_root(index: usize) -> bool {
    matches!(index, 0)
}
/// Returns the index of the parent of node `index`, or `None` for the root (index 0).
#[inline]
const fn parent<const ARITY: u8>(index: usize) -> Option<usize> {
    match index {
        0 => None,
        child => Some((child - 1) / ARITY as usize),
    }
}