use alloc::{
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use core::slice;
mod proofs;
pub use proofs::BatchMerkleProof;
use crate::{Hasher, MerkleTreeError, VectorCommitment};
#[cfg(feature = "concurrent")]
pub mod concurrent;
#[cfg(test)]
mod tests;
/// A binary Merkle tree over a power-of-two number (at least 2) of leaf digests.
///
/// Internal nodes are stored in a flat vector in breadth-first order: `nodes[1]` is the root,
/// the parent of the node at position `i` is at `i >> 1`, and its sibling is at `i ^ 1`.
#[derive(Debug)]
pub struct MerkleTree<H: Hasher> {
    /// Internal nodes in breadth-first order; `nodes[1]` is the root and
    /// `nodes.len() == leaves.len()` (checked by the constructors).
    nodes: Vec<H::Digest>,
    /// Leaf digests in order; their count is a power of two and at least 2.
    leaves: Vec<H::Digest>,
}
/// Opening of a single leaf as produced by `MerkleTree::prove`: the leaf digest together with
/// its authentication path, ordered from the leaf's sibling up to a child of the root.
pub type MerkleTreeOpening<H> = (<H as Hasher>::Digest, Vec<<H as Hasher>::Digest>);
impl<H: Hasher> MerkleTree<H> {
    // CONSTRUCTORS
    // --------------------------------------------------------------------------------------------

    /// Builds a Merkle tree over `leaves`.
    ///
    /// With the `concurrent` feature enabled, trees with more than
    /// `concurrent::MIN_CONCURRENT_LEAVES` leaves are hashed by the concurrent builder.
    ///
    /// # Errors
    /// Returns `MerkleTreeError::TooFewLeaves` if fewer than two leaves are provided, and
    /// `MerkleTreeError::NumberOfLeavesNotPowerOfTwo` if the leaf count is not a power of two.
    pub fn new(leaves: Vec<H::Digest>) -> Result<Self, MerkleTreeError> {
        Self::check_leaf_count(leaves.len())?;

        #[cfg(not(feature = "concurrent"))]
        let nodes = build_merkle_nodes::<H>(&leaves);

        #[cfg(feature = "concurrent")]
        let nodes = if leaves.len() <= concurrent::MIN_CONCURRENT_LEAVES {
            build_merkle_nodes::<H>(&leaves)
        } else {
            concurrent::build_merkle_nodes::<H>(&leaves)
        };

        Ok(MerkleTree { nodes, leaves })
    }

    /// Reassembles a tree from already-computed internal `nodes` and their `leaves`, without
    /// hashing anything.
    ///
    /// The caller is responsible for `nodes` having actually been built from `leaves`
    /// (e.g. by [`build_merkle_nodes`]); this is not verified.
    ///
    /// # Errors
    /// Performs the same leaf-count checks as [`MerkleTree::new`].
    ///
    /// # Panics
    /// Panics if `nodes.len() != leaves.len()`.
    pub fn from_raw_parts(
        nodes: Vec<H::Digest>,
        leaves: Vec<H::Digest>,
    ) -> Result<Self, MerkleTreeError> {
        Self::check_leaf_count(leaves.len())?;
        assert_eq!(
            nodes.len(),
            leaves.len(),
            "a Merkle tree must have exactly as many node slots as leaves"
        );
        Ok(MerkleTree { nodes, leaves })
    }

    // PUBLIC ACCESSORS
    // --------------------------------------------------------------------------------------------

    /// Returns the root of the tree.
    pub fn root(&self) -> &H::Digest {
        &self.nodes[1]
    }

    /// Returns the depth of the tree, i.e. `log2` of the number of leaves.
    pub fn depth(&self) -> usize {
        self.leaves.len().ilog2() as usize
    }

    /// Returns the leaves of the tree in order.
    pub fn leaves(&self) -> &[H::Digest] {
        &self.leaves
    }

    // PROVING
    // --------------------------------------------------------------------------------------------

    /// Returns the leaf at `index` together with its authentication path.
    ///
    /// The path starts with the sibling leaf and continues with the sibling of every ancestor,
    /// ending with a child of the root; its length equals [`Self::depth`].
    ///
    /// # Errors
    /// Returns `MerkleTreeError::LeafIndexOutOfBounds` if `index` is not a valid leaf index.
    pub fn prove(&self, index: usize) -> Result<MerkleTreeOpening<H>, MerkleTreeError> {
        if index >= self.leaves.len() {
            return Err(MerkleTreeError::LeafIndexOutOfBounds(self.leaves.len(), index));
        }
        let leaf = self.leaves[index];
        let mut proof = Vec::with_capacity(self.depth());
        proof.push(self.leaves[index ^ 1]);

        // position of the leaf's parent in the breadth-first `nodes` array
        // (relies on `nodes.len() == leaves.len()`)
        let mut index = (index + self.nodes.len()) >> 1;
        while index > 1 {
            proof.push(self.nodes[index ^ 1]);
            index >>= 1;
        }

        Ok((leaf, proof))
    }

    /// Returns the leaves at `indexes` (in the order of `indexes`) together with a compact
    /// batch proof that omits nodes shared between the individual paths.
    ///
    /// # Errors
    /// Returns `MerkleTreeError::TooFewLeafIndexes` if `indexes` is empty, and the errors of
    /// `map_indexes` if an index is out of bounds or repeated.
    pub fn prove_batch(
        &self,
        indexes: &[usize],
    ) -> Result<(Vec<H::Digest>, BatchMerkleProof<H>), MerkleTreeError> {
        if indexes.is_empty() {
            return Err(MerkleTreeError::TooFewLeafIndexes);
        }

        let depth = self.depth();
        let index_map = map_indexes(indexes, depth)?;
        let indexes = normalize_indexes(indexes);
        let mut leaves = vec![H::Digest::default(); index_map.len()];
        let mut nodes: Vec<Vec<H::Digest>> = Vec::with_capacity(indexes.len());

        // For every requested leaf pair, record the requested leaves in `leaves` and put the
        // pair members that were not requested into the proof.
        let n = self.leaves.len();
        let mut next_indexes: Vec<usize> = Vec::with_capacity(indexes.len());
        for index in indexes {
            let missing: Vec<H::Digest> = (index..index + 2)
                .flat_map(|i| {
                    let v = self.leaves[i];
                    if let Some(idx) = index_map.get(&i) {
                        leaves[*idx] = v;
                        None
                    } else {
                        Some(v)
                    }
                })
                .collect();
            nodes.push(missing);

            // breadth-first position of the pair's parent
            next_indexes.push((index + n) >> 1);
        }

        // Walk up the tree; a sibling is only added to the proof when it is not itself on
        // another requested path (in which case the verifier can recompute it).
        for _ in 1..depth {
            let indexes = core::mem::take(&mut next_indexes);

            let mut i = 0;
            while i < indexes.len() {
                let sibling_index = indexes[i] ^ 1;
                if i + 1 < indexes.len() && indexes[i + 1] == sibling_index {
                    i += 1;
                } else {
                    nodes[i].push(self.nodes[sibling_index]);
                }

                next_indexes.push(sibling_index >> 1);
                i += 1;
            }
        }

        Ok((leaves, BatchMerkleProof { depth: depth as u8, nodes }))
    }

    // VERIFICATION
    // --------------------------------------------------------------------------------------------

    /// Checks that `proof` is a valid authentication path for `leaf` at position `index` in a
    /// tree with the given `root`.
    ///
    /// The tree depth is taken to be `proof.len()`.
    ///
    /// # Errors
    /// Returns `MerkleTreeError::InvalidProof` if the path is empty, too long to describe a
    /// tree addressable by `usize`, or does not hash to `root`. Returns
    /// `MerkleTreeError::LeafIndexOutOfBounds` if `index` does not fit a tree of depth
    /// `proof.len()`.
    pub fn verify(
        root: H::Digest,
        index: usize,
        leaf: H::Digest,
        proof: &[H::Digest],
    ) -> Result<(), MerkleTreeError> {
        // A tree has at least two leaves, so an empty path can never be valid; a path of
        // `usize::BITS` or more elements would overflow the leaf-count computation below.
        let depth = proof.len();
        if depth == 0 || depth >= usize::BITS as usize {
            return Err(MerkleTreeError::InvalidProof);
        }
        let num_leaves = 1usize << depth;
        // Without this check, higher bits of `index` would be silently ignored and a proof
        // for some in-range leaf would also be accepted for an out-of-range index.
        if index >= num_leaves {
            return Err(MerkleTreeError::LeafIndexOutOfBounds(num_leaves, index));
        }

        let mut v = if index & 1 == 0 {
            H::merge(&[leaf, proof[0]])
        } else {
            H::merge(&[proof[0], leaf])
        };

        // breadth-first position of the node `v` currently represents
        let mut node_index = (index + num_leaves) >> 1;
        for &sibling in &proof[1..] {
            v = if node_index & 1 == 0 {
                H::merge(&[v, sibling])
            } else {
                H::merge(&[sibling, v])
            };
            node_index >>= 1;
        }

        if v != root {
            return Err(MerkleTreeError::InvalidProof);
        }
        Ok(())
    }

    /// Checks that `proof` opens `leaves` at positions `indexes` in a tree with the given
    /// `root`.
    ///
    /// # Errors
    /// Returns `MerkleTreeError::InvalidProof` if the recomputed root differs from `root`, and
    /// propagates any error from `BatchMerkleProof::get_root`.
    pub fn verify_batch(
        root: &H::Digest,
        indexes: &[usize],
        leaves: &[H::Digest],
        proof: &BatchMerkleProof<H>,
    ) -> Result<(), MerkleTreeError> {
        if *root != proof.get_root(indexes, leaves)? {
            return Err(MerkleTreeError::InvalidProof);
        }
        Ok(())
    }

    // HELPERS
    // --------------------------------------------------------------------------------------------

    /// Validates that `num_leaves` is a power of two and at least 2.
    fn check_leaf_count(num_leaves: usize) -> Result<(), MerkleTreeError> {
        if num_leaves < 2 {
            return Err(MerkleTreeError::TooFewLeaves(2, num_leaves));
        }
        if !num_leaves.is_power_of_two() {
            return Err(MerkleTreeError::NumberOfLeavesNotPowerOfTwo(num_leaves));
        }
        Ok(())
    }
}
/// Computes the internal nodes of a Merkle tree built over `leaves`.
///
/// The returned vector has the same length as `leaves` and stores the tree in breadth-first
/// order:
/// - `nodes[1]` is the root.
/// - The children of `nodes[i]` are `nodes[2 * i]` and `nodes[2 * i + 1]`.
/// - The last `leaves.len() / 2` entries hash consecutive pairs of leaves.
/// - `nodes[0]` is unused and set to `H::Digest::default()`.
///
/// The number of leaves must be a power of two and at least 2 (as enforced by
/// `MerkleTree::new`). This is checked only with a debug assertion.
///
/// # Panics
/// Panics if `leaves` has fewer than two elements.
pub fn build_merkle_nodes<H: Hasher>(leaves: &[H::Digest]) -> Vec<H::Digest> {
    debug_assert!(
        leaves.len() >= 2 && leaves.len().is_power_of_two(),
        "number of leaves must be a power of two and at least 2"
    );
    let n = leaves.len() / 2;

    // SAFETY: `uninit_vector` presumably hands back a vector of length `2 * n` whose contents
    // are uninitialized. Every slot is written before it is read:
    // - slot 0 right below;
    // - slots `n..2n` by the first loop;
    // - slots `1..n` by the second loop, which runs from high to low indices so that both
    //   children (`2i`, `2i + 1` > `i`) are already written.
    // `H::Digest` is `Copy`, so overwriting an uninitialized slot runs no destructor.
    let mut nodes = unsafe { utils::uninit_vector::<H::Digest>(2 * n) };
    nodes[0] = H::Digest::default();

    // SAFETY: `[H::Digest; 2]` has the alignment of `H::Digest` and the size of two
    // consecutive digests. `n` pairs cover the first `2 * n <= leaves.len()` leaves, so the
    // view stays inside the `leaves` allocation. `leaves` is an immutable borrow and is not
    // written while `two_leaves` is alive.
    let two_leaves = unsafe { slice::from_raw_parts(leaves.as_ptr() as *const [H::Digest; 2], n) };
    for (node, pair) in nodes[n..].iter_mut().zip(two_leaves) {
        *node = H::merge(pair);
    }

    // Children are copied out by value instead of through a reinterpreted view of `nodes`:
    // writing `nodes[i]` takes a mutable borrow of the whole buffer, which would invalidate
    // any pointer-derived view still in use (undefined behavior under the aliasing model).
    for i in (1..n).rev() {
        nodes[i] = H::merge(&[nodes[2 * i], nodes[2 * i + 1]]);
    }

    nodes
}
/// Maps every requested leaf index to its position within `indexes`.
///
/// Fails with `LeafIndexOutOfBounds` on the first index that does not address a leaf of a
/// tree of depth `tree_depth`, and with `DuplicateLeafIndex` if any index appears more than
/// once.
fn map_indexes(
    indexes: &[usize],
    tree_depth: usize,
) -> Result<BTreeMap<usize, usize>, MerkleTreeError> {
    let leaf_count = 2usize.pow(tree_depth as u32);

    let mut position_of = BTreeMap::new();
    for (position, &leaf_index) in indexes.iter().enumerate() {
        if leaf_index >= leaf_count {
            return Err(MerkleTreeError::LeafIndexOutOfBounds(leaf_count, leaf_index));
        }
        position_of.insert(leaf_index, position);
    }

    // Repeated indexes collapse into a single map entry, so a size mismatch reveals them.
    if position_of.len() == indexes.len() {
        Ok(position_of)
    } else {
        Err(MerkleTreeError::DuplicateLeafIndex)
    }
}
/// Rounds every index down to the even (left) member of its leaf pair and returns the
/// distinct results in ascending order.
fn normalize_indexes(indexes: &[usize]) -> Vec<usize> {
    indexes
        .iter()
        .map(|&index| index & !1)
        .collect::<BTreeSet<usize>>()
        .into_iter()
        .collect()
}
/// Exposes a [`MerkleTree`] through the generic vector-commitment interface: the tree root is
/// the commitment and authentication paths are the proofs.
impl<H: Hasher> VectorCommitment<H> for MerkleTree<H> {
    type Options = ();
    type Proof = Vec<H::Digest>;
    type MultiProof = BatchMerkleProof<H>;
    type Error = MerkleTreeError;

    /// Builds a tree over `items`; Merkle trees take no construction options.
    fn with_options(items: Vec<H::Digest>, _options: Self::Options) -> Result<Self, Self::Error> {
        MerkleTree::new(items)
    }

    /// Returns a copy of the tree root.
    fn commitment(&self) -> H::Digest {
        let root = self.root();
        *root
    }

    /// Returns the number of committed items, `2^depth`.
    fn domain_len(&self) -> usize {
        let depth = self.depth();
        1 << depth
    }

    /// A single-item proof carries one digest per tree level, so it implies `2^len` items.
    fn get_proof_domain_len(proof: &Self::Proof) -> usize {
        let depth = proof.len();
        1 << depth
    }

    /// Returns the domain size implied by the depth recorded in a batch proof.
    fn get_multiproof_domain_len(proof: &Self::MultiProof) -> usize {
        let depth = proof.depth;
        1 << depth
    }

    /// Opens the item at `index`; delegates to `MerkleTree::prove`.
    fn open(&self, index: usize) -> Result<(H::Digest, Self::Proof), Self::Error> {
        let opening = self.prove(index);
        opening
    }

    /// Opens all items at `indexes`; delegates to `MerkleTree::prove_batch`.
    fn open_many(
        &self,
        indexes: &[usize],
    ) -> Result<(Vec<H::Digest>, Self::MultiProof), Self::Error> {
        let openings = self.prove_batch(indexes);
        openings
    }

    /// Verifies a single-item opening; delegates to the inherent `MerkleTree::verify`.
    fn verify(
        commitment: H::Digest,
        index: usize,
        item: H::Digest,
        proof: &Self::Proof,
    ) -> Result<(), Self::Error> {
        let path: &[H::Digest] = proof;
        MerkleTree::<H>::verify(commitment, index, item, path)
    }

    /// Verifies a batch opening; delegates to `MerkleTree::verify_batch`.
    fn verify_many(
        commitment: H::Digest,
        indexes: &[usize],
        items: &[H::Digest],
        proof: &Self::MultiProof,
    ) -> Result<(), Self::Error> {
        let root = &commitment;
        MerkleTree::<H>::verify_batch(root, indexes, items, proof)
    }
}