use field_cat::FieldBytes;
use sha2::{Digest, Sha256};
use crate::error::Error;
/// The 32-byte digest stored at the top of a [`MerkleTree`].
///
/// The wrapped array is private; read it through [`MerkleRoot::as_bytes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MerkleRoot([u8; 32]);

impl MerkleRoot {
    /// Borrows the raw 32 bytes of the root digest.
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; 32] {
        let Self(digest) = self;
        digest
    }
}
/// An opening for one leaf: the leaf's position plus the sibling digests
/// met while walking from that leaf up towards the root.
#[derive(Debug, Clone)]
pub struct MerkleProof {
    /// Position of the opened leaf.
    leaf_index: usize,
    /// Sibling digests, ordered from the leaf level upwards.
    siblings: Vec<[u8; 32]>,
}

impl MerkleProof {
    /// Position of the leaf this proof was produced for.
    #[must_use]
    pub fn leaf_index(&self) -> usize {
        let Self { leaf_index, .. } = self;
        *leaf_index
    }

    /// Sibling digests along the path, starting at the leaf level.
    #[must_use]
    pub fn siblings(&self) -> &[[u8; 32]] {
        self.siblings.as_slice()
    }
}
/// Binary hash tree stored as one flat vector of 32-byte digests.
#[derive(Debug, Clone)]
pub struct MerkleTree {
// Flat node storage of length `2 * n`, where `n` is the leaf count rounded
// up to a power of two (at least 1). As built by `from_field_values`:
// index 0 is left zeroed, index 1 holds the root, the children of node `i`
// sit at `2 * i` and `2 * i + 1`, and the leaf digests occupy `n..2 * n`.
nodes: Vec<[u8; 32]>,
// Number of levels between a leaf and the root, i.e. `log2(n)`; this is
// also the number of siblings in every proof produced by `open`.
depth: usize,
// Number of caller-supplied values; leaf slots at or beyond this position
// hold padding digests and cannot be opened.
leaf_count: usize,
}
/// Computes the digest of a real leaf:
/// `SHA-256("leaf:" || index as 8 little-endian bytes || value_bytes)`.
///
/// Binding the position into the digest means a value cannot be replayed at a
/// different index.
fn hash_leaf(index: usize, value_bytes: &[u8]) -> [u8; 32] {
    // `usize::to_le_bytes` yields 4 bytes on 32-bit targets and 8 on 64-bit
    // ones, which would make the same tree hash differently per platform.
    // Widening to `u64` first fixes the encoding at 8 bytes everywhere and is
    // byte-identical to the previous output on 64-bit targets. The fallback is
    // unreachable on targets whose `usize` is at most 64 bits wide.
    let index_bytes = u64::try_from(index).unwrap_or(u64::MAX).to_le_bytes();

    let mut hasher = Sha256::new();
    hasher.update(b"leaf:");
    hasher.update(index_bytes);
    hasher.update(value_bytes);
    hasher.finalize().into()
}
/// Computes the digest used for an unused leaf slot:
/// `SHA-256("padding:" || index as 8 little-endian bytes)`.
///
/// The distinct prefix keeps padding digests separate from real leaf digests.
fn hash_padding(index: usize) -> [u8; 32] {
    // `usize::to_le_bytes` yields 4 bytes on 32-bit targets and 8 on 64-bit
    // ones, which would make the same tree hash differently per platform.
    // Widening to `u64` first fixes the encoding at 8 bytes everywhere and is
    // byte-identical to the previous output on 64-bit targets. The fallback is
    // unreachable on targets whose `usize` is at most 64 bits wide.
    let index_bytes = u64::try_from(index).unwrap_or(u64::MAX).to_le_bytes();

    let mut hasher = Sha256::new();
    hasher.update(b"padding:");
    hasher.update(index_bytes);
    hasher.finalize().into()
}
/// Derives a parent digest from its two children: `SHA-256(left || right)`.
///
/// Argument order matters: swapping the children gives a different parent.
fn hash_pair(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    let mut state = Sha256::new();
    for child in [left, right] {
        state.update(child);
    }
    state.finalize().into()
}
/// Rounds `n` up to the nearest power of two, treating `0` like `1` so the
/// result is never smaller than one.
fn next_power_of_two(n: usize) -> usize {
    n.max(1).next_power_of_two()
}
impl MerkleTree {
#[must_use]
pub fn from_field_values<F: FieldBytes>(values: &[F]) -> Self {
let leaf_count = values.len();
let n = next_power_of_two(leaf_count);
let depth = usize::try_from(n.trailing_zeros()).unwrap_or(0);
let leaf_hashes: Vec<[u8; 32]> = (0..n)
.map(|i| {
if i < leaf_count {
hash_leaf(i, &values[i].to_le_bytes())
} else {
hash_padding(i)
}
})
.collect();
let nodes_len = 2 * n;
let zeroed: Vec<[u8; 32]> = (0..nodes_len).map(|_| [0u8; 32]).collect();
let with_leaves: Vec<[u8; 32]> = zeroed
.iter()
.enumerate()
.map(|(idx, zero)| {
if idx >= n && idx < 2 * n {
leaf_hashes[idx - n]
} else {
*zero
}
})
.collect();
let nodes = (1..=depth).fold(with_leaves, |acc, level_from_bottom| {
let start = n >> level_from_bottom;
let end = n >> (level_from_bottom - 1);
(0..acc.len())
.map(|idx| {
if idx >= start && idx < end {
hash_pair(&acc[idx * 2], &acc[idx * 2 + 1])
} else {
acc[idx]
}
})
.collect()
});
Self {
nodes,
depth,
leaf_count,
}
}
#[must_use]
pub fn root(&self) -> MerkleRoot {
if self.nodes.len() > 1 {
MerkleRoot(self.nodes[1])
} else {
MerkleRoot([0u8; 32])
}
}
#[must_use]
pub fn leaf_count(&self) -> usize {
self.leaf_count
}
pub fn open(&self, index: usize) -> Result<MerkleProof, Error> {
if index >= self.leaf_count {
Err(Error::LeafIndexOutOfBounds {
index,
leaf_count: self.leaf_count,
})
} else {
let n = 1usize << self.depth;
let siblings = (0..self.depth)
.scan(n + index, |pos, _| {
let sibling_pos = *pos ^ 1;
let sibling = self.nodes[sibling_pos];
*pos /= 2;
Some(sibling)
})
.collect();
Ok(MerkleProof {
leaf_index: index,
siblings,
})
}
}
#[must_use]
pub fn verify_opening<F: FieldBytes>(
root: &MerkleRoot,
index: usize,
value: &F,
proof: &MerkleProof,
) -> bool {
let leaf_hash = hash_leaf(index, &value.to_le_bytes());
let n = 1usize << proof.siblings.len();
let computed_root = proof
.siblings
.iter()
.enumerate()
.fold((leaf_hash, n + index), |(current, pos), (_, sibling)| {
let parent = if pos % 2 == 0 {
hash_pair(¤t, sibling)
} else {
hash_pair(sibling, ¤t)
};
(parent, pos / 2)
})
.0;
computed_root == root.0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::{BabyBear, F101};

    /// A tree built from one value opens and verifies at position 0.
    #[test]
    fn single_leaf_roundtrip() -> Result<(), Error> {
        let tree = MerkleTree::from_field_values(&[F101::new(42)]);
        let root = tree.root();
        let proof = tree.open(0)?;
        let accepted = MerkleTree::verify_opening(&root, 0, &F101::new(42), &proof);
        assert!(accepted);
        Ok(())
    }

    /// Both leaves of a two-value tree open and verify against the same root.
    #[test]
    fn two_leaf_roundtrip() -> Result<(), Error> {
        let values = [BabyBear::new(10), BabyBear::new(20)];
        let tree = MerkleTree::from_field_values(&values);
        let root = tree.root();

        let proof0 = tree.open(0)?;
        let proof1 = tree.open(1)?;

        let first_ok = MerkleTree::verify_opening(&root, 0, &BabyBear::new(10), &proof0);
        assert!(first_ok);
        let second_ok = MerkleTree::verify_opening(&root, 1, &BabyBear::new(20), &proof1);
        assert!(second_ok);
        Ok(())
    }

    /// Verification rejects a value that differs from the committed one.
    #[test]
    fn tampered_value_fails() -> Result<(), Error> {
        let tree = MerkleTree::from_field_values(&[F101::new(42)]);
        let root = tree.root();
        let proof = tree.open(0)?;
        let accepted = MerkleTree::verify_opening(&root, 0, &F101::new(99), &proof);
        assert!(!accepted);
        Ok(())
    }

    /// Opening a position at or past the leaf count is an error.
    #[test]
    fn out_of_bounds_open_fails() {
        let tree = MerkleTree::from_field_values(&[F101::new(1), F101::new(2)]);
        let attempt = tree.open(2);
        assert!(attempt.is_err());
    }

    /// Every leaf of a four-value tree opens and verifies.
    #[test]
    fn four_leaves() -> Result<(), Error> {
        let values = [
            BabyBear::new(1),
            BabyBear::new(2),
            BabyBear::new(3),
            BabyBear::new(4),
        ];
        let tree = MerkleTree::from_field_values(&values);
        for (i, value) in values.iter().enumerate() {
            let proof = tree.open(i)?;
            assert!(
                MerkleTree::verify_opening(&tree.root(), i, value, &proof),
                "failed at leaf {i}"
            );
        }
        Ok(())
    }
}