use digest::{Digest, FixedOutputReset, Output, typenum::Unsigned};
/// Byte fed to the hasher ahead of the two child digests whenever an inner
/// (non-leaf) node is computed.
///
/// NOTE(review): presumably a domain-separation tag that keeps inner-node
/// hashes distinct from leaf hashes; confirm against whatever verifies these
/// proofs.
pub const INNER_NODE_PREFIX: u8 = 0x01;
/// Binary Merkle tree stored as one flat array of digests.
///
/// Layout of `nodes` for a tree built through `UnhashedMerkleTree::finalize`:
/// index 0 is an unused placeholder, index 1 is the root, node `i` has the
/// children `2 * i` and `2 * i + 1`, and the leaves occupy `len..2 * len`.
///
/// The derived `Default` produces a tree with no nodes and `len == 0`; such a
/// tree has no root.
#[derive(Debug, Clone, Default)]
pub struct MerkleTree<D: Digest> {
/// All nodes of the tree in the layout described on the type.
nodes: Box<[Output<D>]>,
/// Number of leaf slots, i.e. half the number of entries in `nodes`.
len: usize,
}
/// Intermediate state of a Merkle tree whose leaves have been placed but whose
/// inner nodes have not been hashed yet.
///
/// Created by `MerkleTree::new_unhashed` and turned into a [`MerkleTree`] by
/// `finalize`.
#[derive(Debug, Clone)]
pub struct UnhashedMerkleTree<D: Digest> {
/// Backing storage for the tree nodes; it becomes `MerkleTree::nodes` once
/// `finalize` has computed the inner nodes.
buffer: Vec<Output<D>>,
/// Number of leaf slots: the input leaf count rounded up to a power of two.
len: usize,
}
impl<D: Digest + FixedOutputReset> MerkleTree<D>
where
    Output<D>: Copy,
{
    /// Builds a fully hashed tree from already-hashed leaves.
    ///
    /// Shorthand for `Self::new_unhashed(data).finalize()`.
    ///
    /// # Panics
    ///
    /// Panics if `data` is empty.
    pub fn new(data: &[Output<D>]) -> Self {
        Self::new_unhashed(data).finalize()
    }

    /// Lays out the node array for `data` without hashing any inner node.
    ///
    /// The leaf count is rounded up to the next power of two (`len`). The
    /// returned buffer holds `2 * len` entries:
    ///
    /// * `0..len` — placeholders (`Output::default()`); indices `1..len` are
    ///   overwritten by [`UnhashedMerkleTree::finalize`], index 0 stays unused,
    /// * `len..len + data.len()` — the leaves, in input order,
    /// * the remainder — padding leaves equal to `Output::default()`.
    ///
    /// The buffer is fully initialised, so the returned value can be cloned,
    /// debug-printed and finalized using safe code only.
    ///
    /// # Panics
    ///
    /// Panics if `data` is empty.
    pub fn new_unhashed(data: &[Output<D>]) -> UnhashedMerkleTree<D> {
        let raw_len = data.len();
        assert_ne!(raw_len, 0, "Cannot create Merkle tree with zero leaves");
        let len = raw_len.next_power_of_two();

        let placeholder = Output::<D>::default();
        // Every slot is written exactly once: placeholders, leaves, padding.
        let mut nodes: Vec<Output<D>> = Vec::with_capacity(2 * len);
        nodes.resize(len, placeholder);
        nodes.extend_from_slice(data);
        nodes.resize(2 * len, placeholder);

        UnhashedMerkleTree { buffer: nodes, len }
    }

    /// Returns the root digest (node 1).
    ///
    /// For a single-leaf tree no inner node exists, so the root is the leaf
    /// itself.
    ///
    /// # Panics
    ///
    /// Panics if the tree has no nodes, which is the case for a
    /// `Default`-constructed tree.
    #[inline]
    pub fn root(&self) -> &Output<D> {
        // Checked access: a `Default` tree has an empty node array.
        self.nodes
            .get(1)
            .expect("Merkle tree has no nodes (default-constructed?), so it has no root")
    }

    /// Returns all leaf slots, including any `Output::default()` padding that
    /// was appended to reach a power-of-two leaf count.
    #[inline]
    pub fn leaves(&self) -> &[Output<D>] {
        &self.nodes[self.len..2 * self.len]
    }

    /// Returns `true` if `leaf` equals one of the leaf slots.
    ///
    /// Padding slots take part in the comparison, so on a padded tree this is
    /// also `true` for `Output::default()`.
    #[inline]
    pub fn contains(&self, leaf: &Output<D>) -> bool {
        self.leaves().contains(leaf)
    }

    /// Returns an iterator over the siblings on the path from `leaf` to the
    /// root, or `None` if no leaf slot equals `leaf`.
    ///
    /// The leaf is located with a linear scan; if it occurs more than once the
    /// first occurrence is used.
    pub fn get_proof_iter(&self, leaf: &Output<D>) -> Option<SiblingIter<'_, D>> {
        let leaf_index_in_slice = self.leaves().iter().position(|a| a == leaf)?;
        Some(SiblingIter {
            nodes: &self.nodes,
            current: self.len + leaf_index_in_slice,
        })
    }

    /// Serialises every node (placeholder at index 0 included) by
    /// concatenating the raw digests in array order.
    #[inline]
    pub fn to_raw_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(self.nodes.len() * D::OutputSize::USIZE);
        for node in self.nodes.iter() {
            bytes.extend_from_slice(node.as_slice());
        }
        bytes
    }

    /// Rebuilds a tree from the byte layout produced by [`Self::to_raw_bytes`].
    ///
    /// Only the shape is validated; the inner-node hashes are taken as-is and
    /// are NOT recomputed or verified against the leaves.
    ///
    /// # Panics
    ///
    /// Panics if
    /// * `bytes.len()` is not a multiple of the digest size,
    /// * the node count is not a power of two that is at least 2 (every tree
    ///   produced by `finalize` has `2 * len` nodes with `len` a power of two),
    /// * node 0 is not `Output::default()`.
    #[inline]
    pub fn from_raw_bytes(bytes: &[u8]) -> Self {
        let node_size = D::OutputSize::USIZE;
        assert!(
            bytes.len().is_multiple_of(node_size),
            "Invalid raw bytes length"
        );
        let node_count = bytes.len() / node_size;
        assert!(
            node_count >= 2 && node_count.is_power_of_two(),
            "Invalid node count {node_count}: expected a power of two that is at least 2"
        );

        let nodes: Box<[Output<D>]> = bytes
            .chunks_exact(node_size)
            .map(|chunk| *Output::<D>::from_slice(chunk))
            .collect();
        assert_eq!(
            nodes[0],
            Output::<D>::default(),
            "Node 0 is an unused slot and must hold the default value"
        );

        Self {
            nodes,
            len: node_count / 2,
        }
    }
}

impl<D: Digest + FixedOutputReset> UnhashedMerkleTree<D>
where
    Output<D>: Copy,
{
    /// Computes every inner node bottom-up and returns the finished tree.
    ///
    /// Node `i` becomes `H(INNER_NODE_PREFIX || node[2 * i] || node[2 * i + 1])`.
    /// Indices are visited in descending order so both children are final
    /// before their parent is hashed.
    ///
    /// # Panics
    ///
    /// Panics if the buffer does not hold exactly `2 * len` nodes, which would
    /// mean the value was not produced by `MerkleTree::new_unhashed`.
    pub fn finalize(self) -> MerkleTree<D> {
        let Self {
            buffer: mut nodes,
            len,
        } = self;
        assert_eq!(
            nodes.len(),
            2 * len,
            "Unhashed Merkle tree buffer has an unexpected size"
        );

        let mut hasher = D::new();
        for i in (1..len).rev() {
            let left = nodes[2 * i];
            let right = nodes[2 * i + 1];
            Digest::update(&mut hasher, [INNER_NODE_PREFIX]);
            Digest::update(&mut hasher, left);
            Digest::update(&mut hasher, right);
            nodes[i] = hasher.finalize_reset();
        }

        MerkleTree {
            nodes: nodes.into_boxed_slice(),
            len,
        }
    }
}
/// Iterator over the sibling digests on the path from one leaf up to the root,
/// as returned by `MerkleTree::get_proof_iter`.
#[derive(Debug, Clone)]
pub struct SiblingIter<'a, D: Digest> {
/// The complete node array of the tree being walked.
nodes: &'a [Output<D>],
/// Index of the node whose sibling is yielded next; halved after each step.
current: usize,
}
/// Side of its parent on which the *current* path node sits (this is not the
/// side of the sibling yielded alongside it by `SiblingIter`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodePosition {
/// The current node has an even index, i.e. it is the left child; the parent
/// is hashed as `prefix || current || sibling`.
Left,
/// The current node has an odd index, i.e. it is the right child; the parent
/// is hashed as `prefix || sibling || current`.
Right,
}
impl<'a, D: Digest> Iterator for SiblingIter<'a, D> {
    type Item = (NodePosition, &'a Output<D>);

    /// Yields the side of the current node together with its sibling digest,
    /// then moves one level up. Ends once the root (index 1) is reached.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.current;
        if index < 2 {
            return None;
        }

        // Even indices are left children, odd indices are right children.
        let position = match index % 2 {
            0 => NodePosition::Left,
            _ => NodePosition::Right,
        };

        // SAFETY: relies on the invariant established where this iterator is
        // built (`MerkleTree::get_proof_iter`): `current` is a valid index into
        // `nodes` and `nodes` has an even length, so flipping the lowest bit
        // stays in bounds. Halving `current` below keeps it in bounds.
        let sibling = unsafe { self.nodes.get_unchecked(index ^ 1) };

        self.current = index / 2;
        Some((position, sibling))
    }

    /// Exact number of remaining items: one per level between the current
    /// node and the root, i.e. `floor(log2(current))`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.current.ilog2() as usize;
        (remaining, Some(remaining))
    }
}
impl<D: Digest> ExactSizeIterator for SiblingIter<'_, D> {
    /// Number of siblings still to be yielded; identical to the exact bound
    /// reported by `size_hint`.
    fn len(&self) -> usize {
        let (remaining, _) = self.size_hint();
        remaining
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::{B256, U256};
    use alloy_sol_types::SolValue;
    use sha2::Sha256;
    use sha3::Keccak256;

    /// Hashes `INNER_NODE_PREFIX || left || right` and leaves `hasher` reset
    /// so it can be reused for the next pair.
    fn hash_pair<D: Digest + FixedOutputReset>(
        hasher: &mut D,
        left: &Output<D>,
        right: &Output<D>,
    ) -> Output<D> {
        Digest::update(hasher, [INNER_NODE_PREFIX]);
        Digest::update(hasher, left);
        Digest::update(hasher, right);
        hasher.finalize_reset()
    }

    /// The root of a four-leaf tree matches a hand-computed root.
    #[test]
    fn basic() {
        test_merkle_tree::<Sha256>();
        test_merkle_tree::<Keccak256>();
    }

    /// Every leaf's sibling path hashes back up to the root.
    #[test]
    fn proof() {
        test_proof::<Sha256>();
        test_proof::<Keccak256>();
    }

    fn test_merkle_tree<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = ["leaf1", "leaf2", "leaf3", "leaf4"].map(|label| D::digest(label));
        let tree = MerkleTree::<D>::new(&leaves);

        // Recompute the two inner nodes and the root by hand.
        let mut hasher = D::new();
        let left_hash = hash_pair(&mut hasher, &leaves[0], &leaves[1]);
        let right_hash = hash_pair(&mut hasher, &leaves[2], &leaves[3]);
        let expected_root = hash_pair(&mut hasher, &left_hash, &right_hash);

        assert_eq!(tree.root().as_slice(), expected_root.as_slice());
    }

    fn test_proof<D: Digest + FixedOutputReset>()
    where
        Output<D>: Copy,
    {
        let leaves = ["apple", "banana", "cherry", "date"].map(|word| D::digest(word));
        let tree = MerkleTree::<D>::new(&leaves);

        for leaf in &leaves {
            let proof = tree
                .get_proof_iter(leaf)
                .expect("Leaf should be in the tree");

            // Fold the path upwards; `side` says where the running hash goes.
            let mut hasher = D::new();
            let computed_root = proof.fold(*leaf, |node, (side, sibling)| match side {
                NodePosition::Left => hash_pair(&mut hasher, &node, sibling),
                NodePosition::Right => hash_pair(&mut hasher, sibling, &node),
            });

            assert_eq!(computed_root.as_slice(), tree.root().as_slice());
        }
    }

    /// Prints Keccak-256 roots for trees of 1, 2, 4, ..., 1024 leaves as
    /// `bytes32(...)` lines. Ignored by default; run it explicitly to
    /// regenerate the values.
    #[ignore]
    #[test]
    fn generate_sol_test() {
        let leaves: Vec<_> = (0..1024)
            .map(|i| {
                let mut hasher = Keccak256::new();
                let value = U256::from(i).abi_encode_packed();
                hasher.update(&value);
                hasher.finalize()
            })
            .collect();

        for exponent in 0..=10u32 {
            let leaf_count = 2usize.pow(exponent);
            let tree = MerkleTree::<Keccak256>::new(&leaves[..leaf_count]);
            let root = B256::from_slice(tree.root());
            println!("bytes32({root}),");
        }
    }
}