use crate::error::{Error, Result};
use crate::hash::MimcHasher;
use crate::merkle::proof::MerkleProof;
use crate::merkle::ROOT_HISTORY_SIZE;
#[cfg(feature = "std")]
use std::collections::HashMap;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap as HashMap;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Append-only (incremental) Merkle tree over `u128` values that keeps a
/// bounded history of recent roots.
///
/// Leaves are inserted left to right and only the path from a new leaf to the
/// root is rehashed on insert. Every inserted leaf is also retained so that
/// inclusion proofs can be produced for any of them.
#[derive(Debug, Clone)]
pub struct MerkleTree {
    /// Depth of the tree: the number of hashing steps from a leaf to the root
    /// (the constructors accept 1..=32).
    levels: u8,
    /// Per level, the most recent node `insert` produced as a left child;
    /// seeded with that level's empty value (`zeros(level)`) at construction.
    filled_subtrees: HashMap<u8, u128>,
    /// Ring buffer of recent roots keyed by slot `0..ROOT_HISTORY_SIZE`.
    roots: HashMap<u8, u128>,
    /// Slot in `roots` that holds the latest root.
    current_root_index: u8,
    /// Index the next inserted leaf will receive (equals the number of leaves).
    next_index: u32,
    /// Hasher used for node hashes and for empty-subtree values.
    hasher: MimcHasher,
    /// Every inserted leaf, in insertion order.
    leaves: Vec<u128>,
}
impl MerkleTree {
    /// Creates an empty tree of depth `levels` using the default [`MimcHasher`].
    ///
    /// Equivalent to [`MerkleTree::with_hasher`] with `MimcHasher::default()`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidTreeConfig`] if `levels` is 0 or greater than 32.
    pub fn new(levels: u8) -> Result<Self> {
        Self::with_hasher(levels, MimcHasher::default())
    }

    /// Creates an empty tree of depth `levels` that hashes with `hasher`.
    ///
    /// Every level's filled-subtree slot is seeded with that level's empty
    /// value (see [`MerkleTree::zeros`]), and slot 0 of the root history holds
    /// the initial root.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidTreeConfig`] if `levels` is 0 or greater than 32.
    pub fn with_hasher(levels: u8, hasher: MimcHasher) -> Result<Self> {
        // `.into()` instead of `.to_string()`: `Into` is in the `core` prelude,
        // so this also compiles without `std`, where `ToString` is not in scope
        // (this module only imports `alloc::vec::Vec`).
        if levels == 0 {
            return Err(Error::InvalidTreeConfig(
                "Tree must have at least 1 level".into(),
            ));
        }
        if levels > 32 {
            return Err(Error::InvalidTreeConfig(
                "Tree depth cannot exceed 32 levels".into(),
            ));
        }
        let mut tree = MerkleTree {
            levels,
            filled_subtrees: HashMap::new(),
            roots: HashMap::new(),
            current_root_index: 0,
            next_index: 0,
            hasher,
            leaves: Vec::new(),
        };
        // Walk the zero chain once rather than calling `zeros(level)` for every
        // level, which would cost O(levels^2) hash evaluations.
        let mut zero = tree.zeros(0);
        for level in 0..levels {
            if level > 0 {
                zero = tree.zero_step(zero);
            }
            tree.filled_subtrees.insert(level, zero);
        }
        // `zero` is now `zeros(levels - 1)`.
        // NOTE(review): roots produced by `insert` sit at level `levels`, and
        // `get_node_at` treats an empty node at that level as `zeros(levels)`;
        // the initial root is one level lower. Kept for compatibility — confirm
        // this is intended.
        tree.roots.insert(0, zero);
        Ok(tree)
    }

    /// Depth of the tree (number of hashing steps from a leaf to the root).
    #[inline]
    pub fn levels(&self) -> u8 {
        self.levels
    }

    /// Number of leaf slots, `2^levels`.
    ///
    /// Where `usize` is 32 bits, a depth-32 tree's `2^32` does not fit; the
    /// result then saturates to `usize::MAX` instead of overflowing the shift.
    /// Because the leaf count is a `u32`, a depth-32 tree accepts at most
    /// `u32::MAX` leaves (see [`MerkleTree::insert`]).
    #[inline]
    pub fn capacity(&self) -> usize {
        1usize
            .checked_shl(u32::from(self.levels))
            .unwrap_or(usize::MAX)
    }

    /// Number of leaves inserted so far.
    #[inline]
    pub fn len(&self) -> u32 {
        self.next_index
    }

    /// `true` if no leaf has been inserted yet.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.next_index == 0
    }

    /// The hasher used for node hashes and empty-subtree values.
    #[inline]
    pub fn hasher(&self) -> &MimcHasher {
        &self.hasher
    }

    /// The most recently recorded root (the initial root before any insert).
    ///
    /// Always `Some` for a tree built by [`MerkleTree::new`] or
    /// [`MerkleTree::with_hasher`], both of which record an initial root.
    pub fn root(&self) -> Option<u128> {
        self.roots.get(&self.current_root_index).copied()
    }

    /// Hashes two child nodes into their parent:
    /// `mimc_sponge(mimc_sponge(left) + right mod p)`. The sponge's second
    /// argument is fixed at `0` for both calls.
    ///
    /// NOTE(review): the addition is `wrapping_add` followed by reduction. When
    /// the intermediate sum exceeds `u128::MAX`, the result is not
    /// `(a + b) mod p`. That can happen if the field prime is above 2^127, or
    /// if `right` is a raw leaf, since leaves are not reduced on insert.
    /// Left unchanged because `MerkleProof::verify` (not visible here)
    /// presumably uses the identical formula — change both together.
    fn hash_left_right(&self, left: u128, right: u128) -> u128 {
        let p = self.hasher.field_prime();
        let c = 0_u128;
        let mixed = self.hasher.mimc_sponge(left, c, p);
        let absorbed = mixed.wrapping_add(right).wrapping_rem(p);
        self.hasher.mimc_sponge(absorbed, c, p)
    }

    /// Appends `leaf` at the next free index and records the new root.
    ///
    /// Only the path from the new leaf to the root is rehashed. Left siblings
    /// come from `filled_subtrees`. Right siblings are necessarily still empty,
    /// so they are the level's zero value. The new root goes into the next slot
    /// of the `ROOT_HISTORY_SIZE`-entry ring, overwriting the oldest root once
    /// the ring has wrapped.
    ///
    /// Returns the index the leaf was stored at.
    ///
    /// # Errors
    ///
    /// Returns [`Error::TreeFull`] when no free slot remains; the tree is left
    /// unchanged. For a depth-32 tree this happens at index `u32::MAX`, one slot
    /// before the nominal capacity, because the leaf count is a `u32`.
    pub fn insert(&mut self, leaf: u128) -> Result<u32> {
        let capacity = self.capacity();
        let inserted_index = self.next_index;
        // Validate both the nominal capacity and that the u32 counter can
        // advance *before* mutating anything. Otherwise `next_index + 1`
        // overflows after the last slot of a depth-32 tree.
        let next_index = match inserted_index.checked_add(1) {
            Some(next) if (inserted_index as usize) < capacity => next,
            _ => {
                return Err(Error::TreeFull {
                    capacity,
                    attempted_index: inserted_index as usize,
                });
            }
        };

        self.leaves.push(leaf);
        let mut index = inserted_index;
        let mut node = leaf;
        // Empty value of the current level, advanced one step per level
        // (same chain as `zeros`, without recomputing it from 0 each time).
        let mut zero = self.zeros(0);
        for level in 0..self.levels {
            if level > 0 {
                zero = self.zero_step(zero);
            }
            let (left, right) = if index % 2 == 0 {
                // Left child: remember it for when its right sibling is filled.
                self.filled_subtrees.insert(level, node);
                (node, zero)
            } else {
                let left = self.filled_subtrees.get(&level).copied().unwrap_or(zero);
                (left, node)
            };
            node = self.hash_left_right(left, right);
            index /= 2;
        }

        let new_root_index = (self.current_root_index + 1) % ROOT_HISTORY_SIZE;
        self.current_root_index = new_root_index;
        self.roots.insert(new_root_index, node);
        self.next_index = next_index;
        Ok(inserted_index)
    }

    /// Returns `true` if `root` is held in any slot of the root history.
    ///
    /// `0` is never considered a known root.
    pub fn is_known_root(&self, root: u128) -> bool {
        root != 0
            && (0..ROOT_HISTORY_SIZE)
                .filter_map(|slot| self.roots.get(&slot))
                .any(|&stored| stored == root)
    }

    /// Returns the latest root.
    ///
    /// # Panics
    ///
    /// Panics if no root is recorded, which cannot happen for a tree built by
    /// `new` or `with_hasher`.
    #[deprecated(since = "1.0.0", note = "Use root() instead")]
    pub fn get_last_root(&self) -> u128 {
        self.root().expect("Tree in invalid state: no root")
    }

    /// The value used for an empty node at `level` (level 0 is the leaf layer).
    ///
    /// `zeros(0)` is `0`, and each level up applies one
    /// `mimc_sponge(z, 0, field_prime)` step, so this costs `level` hash calls.
    ///
    /// NOTE(review): this is *not* `hash_left_right(zeros(l - 1), zeros(l - 1))`,
    /// i.e. an empty node is not the hash of its two empty children. `insert`
    /// and `prove` both use this definition, so they agree with each other.
    /// An implementation that derives empty nodes by hashing children will
    /// compute different roots, so confirm against any counterpart before
    /// relying on root compatibility.
    pub fn zeros(&self, level: u8) -> u128 {
        (0..level).fold(0u128, |zero, _| self.zero_step(zero))
    }

    /// One step of the empty-value chain: `zeros(l + 1) == zero_step(zeros(l))`.
    fn zero_step(&self, zero: u128) -> u128 {
        self.hasher.mimc_sponge(zero, 0, self.hasher.field_prime())
    }

    /// Builds an inclusion proof for the leaf at `leaf_index` in the tree's
    /// current state.
    ///
    /// `path[level]` is the sibling node at `level`. `indices[level]` is `true`
    /// when the path node at that level is a right child. Siblings are
    /// recomputed from the stored leaves on every call; there is no node cache.
    ///
    /// # Errors
    ///
    /// Returns [`Error::LeafIndexOutOfBounds`] if `leaf_index >= len()`.
    pub fn prove(&self, leaf_index: u32) -> Result<MerkleProof> {
        if leaf_index >= self.next_index {
            return Err(Error::LeafIndexOutOfBounds {
                index: leaf_index,
                tree_size: self.next_index,
            });
        }
        let leaf = self.leaves[leaf_index as usize];
        let depth = usize::from(self.levels);
        let mut path = Vec::with_capacity(depth);
        let mut indices = Vec::with_capacity(depth);
        let mut index = leaf_index;
        for level in 0..self.levels {
            let is_right = index % 2 == 1;
            // A right child's sibling is to its left, and vice versa.
            let sibling_index = if is_right { index - 1 } else { index + 1 };
            indices.push(is_right);
            path.push(self.get_node_at(level, sibling_index));
            index /= 2;
        }
        Ok(MerkleProof {
            leaf,
            leaf_index,
            path,
            indices,
        })
    }

    /// Recomputes the node at (`level`, `index`) from the stored leaves.
    ///
    /// A leaf slot not yet filled reads as `zeros(0)`. A node whose entire
    /// subtree lies at or beyond `next_index` short-circuits to `zeros(level)`,
    /// matching the empty values `insert` uses for right siblings.
    fn get_node_at(&self, level: u8, index: u32) -> u128 {
        if level == 0 {
            return self
                .leaves
                .get(index as usize)
                .copied()
                .unwrap_or_else(|| self.zeros(0));
        }
        // First leaf covered by this node; computed in u64 so the shift
        // cannot overflow.
        let first_leaf = u64::from(index) << level;
        if first_leaf >= u64::from(self.next_index) {
            return self.zeros(level);
        }
        let left = self.get_node_at(level - 1, 2 * index);
        let right = self.get_node_at(level - 1, 2 * index + 1);
        self.hash_left_right(left, right)
    }
}
// NOTE(review): this module is empty, so enabling the `borsh` feature adds no
// (de)serialization support for `MerkleTree` from this file. Either implement
// it here or remove the placeholder — TODO confirm where borsh impls are meant
// to live.
#[cfg(feature = "borsh")]
mod borsh_impl {
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that every inserted leaf has a proof that verifies against the
    /// tree's current root.
    fn assert_all_proofs_verify(tree: &MerkleTree) {
        let root = tree.root().unwrap();
        for i in 0..tree.len() {
            let proof = tree.prove(i).unwrap();
            assert!(proof.verify(root, tree.hasher()), "Proof failed for leaf {}", i);
        }
    }

    #[test]
    fn test_new_tree() {
        let tree = MerkleTree::new(20).unwrap();
        assert_eq!(tree.levels(), 20);
        assert_eq!(tree.capacity(), 1 << 20);
        assert_eq!(tree.len(), 0);
        assert!(tree.is_empty());
    }

    #[test]
    fn test_new_tree_invalid_levels() {
        assert!(matches!(MerkleTree::new(0), Err(Error::InvalidTreeConfig(_))));
        assert!(matches!(MerkleTree::new(33), Err(Error::InvalidTreeConfig(_))));
        assert!(MerkleTree::with_hasher(0, MimcHasher::default()).is_err());
        assert!(MerkleTree::with_hasher(33, MimcHasher::default()).is_err());
    }

    #[test]
    fn test_new_matches_with_default_hasher() {
        let mut a = MerkleTree::new(10).unwrap();
        let mut b = MerkleTree::with_hasher(10, MimcHasher::default()).unwrap();
        assert_eq!(a.root(), b.root());
        for &leaf in &[7u128, 8, 9] {
            a.insert(leaf).unwrap();
            b.insert(leaf).unwrap();
        }
        assert_eq!(a.root(), b.root());
    }

    #[test]
    fn test_insert_single() {
        let mut tree = MerkleTree::new(20).unwrap();
        let index = tree.insert(12345).unwrap();
        assert_eq!(index, 0);
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());
    }

    #[test]
    fn test_insert_multiple() {
        let mut tree = MerkleTree::new(20).unwrap();
        for i in 0..10u32 {
            let index = tree.insert(u128::from(i)).unwrap();
            assert_eq!(index, i);
        }
        assert_eq!(tree.len(), 10);
    }

    #[test]
    fn test_tree_full() {
        let mut tree = MerkleTree::new(2).unwrap();
        for i in 0..4u128 {
            tree.insert(i).unwrap();
        }
        let root_before = tree.root();
        let result = tree.insert(100);
        assert!(matches!(
            result,
            Err(Error::TreeFull {
                capacity: 4,
                attempted_index: 4
            })
        ));
        // A rejected insert must leave the tree untouched.
        assert_eq!(tree.len(), 4);
        assert_eq!(tree.root(), root_before);
    }

    #[test]
    fn test_root_changes_on_insert() {
        let mut tree = MerkleTree::new(20).unwrap();
        let root1 = tree.root().unwrap();
        tree.insert(12345).unwrap();
        let root2 = tree.root().unwrap();
        assert_ne!(root1, root2);
    }

    #[test]
    fn test_is_known_root() {
        let mut tree = MerkleTree::new(20).unwrap();
        let root1 = tree.root().unwrap();
        tree.insert(12345).unwrap();
        let root2 = tree.root().unwrap();
        assert!(tree.is_known_root(root1));
        assert!(tree.is_known_root(root2));
        assert!(!tree.is_known_root(99999));
        assert!(!tree.is_known_root(0));
    }

    #[test]
    fn test_root_history_evicts_oldest() {
        // 2^8 = 256 slots is enough for ROOT_HISTORY_SIZE (a u8) inserts.
        let mut tree = MerkleTree::new(8).unwrap();
        let initial_root = tree.root().unwrap();
        let mut recent = Vec::new();
        for i in 0..ROOT_HISTORY_SIZE {
            tree.insert(u128::from(i) + 1).unwrap();
            recent.push(tree.root().unwrap());
        }
        // The ROOT_HISTORY_SIZE-th insert wraps around and overwrites slot 0.
        assert!(!tree.is_known_root(initial_root));
        for root in recent {
            assert!(tree.is_known_root(root));
        }
    }

    #[test]
    fn test_zeros_computation() {
        let tree = MerkleTree::new(10).unwrap();
        let zero0 = tree.zeros(0);
        let zero1 = tree.zeros(1);
        assert_eq!(zero0, 0);
        assert_ne!(zero1, 0);
    }

    #[test]
    fn test_deterministic_roots() {
        let mut tree1 = MerkleTree::new(10).unwrap();
        let mut tree2 = MerkleTree::new(10).unwrap();
        tree1.insert(123).unwrap();
        tree1.insert(456).unwrap();
        tree2.insert(123).unwrap();
        tree2.insert(456).unwrap();
        assert_eq!(tree1.root(), tree2.root());
    }

    #[test]
    fn test_prove_valid_index() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        tree.insert(67890).unwrap();
        let proof = tree.prove(0).unwrap();
        assert_eq!(proof.leaf, 12345);
        assert_eq!(proof.leaf_index, 0);
        assert_eq!(proof.path.len(), 10);
        assert_eq!(proof.indices.len(), 10);
        // Leaf 0 is a left child whose sibling is leaf 1.
        assert!(!proof.indices[0]);
        assert_eq!(proof.path[0], 67890);
    }

    #[test]
    fn test_prove_invalid_index() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        let result = tree.prove(1);
        assert!(matches!(result, Err(Error::LeafIndexOutOfBounds { .. })));
    }

    #[test]
    fn test_proof_verifies() {
        let mut tree = MerkleTree::new(10).unwrap();
        tree.insert(12345).unwrap();
        tree.insert(67890).unwrap();
        tree.insert(11111).unwrap();
        assert_all_proofs_verify(&tree);
    }

    #[test]
    fn test_proofs_verify_at_every_fill_level() {
        // Small trees so every left/right combination on every level is hit,
        // up to and including a completely full tree.
        for levels in 1..=3u8 {
            let mut tree = MerkleTree::new(levels).unwrap();
            for leaf in 0..tree.capacity() as u128 {
                tree.insert(leaf * 1_000 + 7).unwrap();
                assert_all_proofs_verify(&tree);
            }
        }
    }
}