use crate::zk::error::ZKError;
use alloc::vec::Vec;
use soroban_sdk::{BytesN, Env};
/// Maximum tree depth accepted by the `from_leaves` constructors in this file.
///
/// A tree whose padded leaf layer would need more than this many levels above
/// the leaves is rejected with `ZKError::MaxDepthExceeded`, so a tree can hold
/// at most `2^MAX_DEPTH` leaves.
pub const MAX_DEPTH: u32 = 20;
/// A binary Merkle tree over 32-byte nodes that keeps every layer in memory.
///
/// Instances are created with `MerkleTree::from_leaves`.
#[derive(Debug, Clone)]
pub struct MerkleTree {
    /// Number of levels above the leaf layer (`0` for a single-leaf tree).
    depth: u32,
    /// Number of leaves supplied by the caller, excluding padding nodes.
    leaf_count: u32,
    /// All layers of the tree: `layers[0]` is the (padded) leaf layer and the
    /// last entry is the single-node root layer.
    layers: Vec<Vec<[u8; 32]>>,
}
/// An inclusion proof for one leaf of a `MerkleTree`.
///
/// `siblings` and `path_indices` are ordered from the leaf layer upwards and
/// are expected to have the same length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleProof {
    /// Sibling node at each level, starting at the leaf layer.
    pub siblings: Vec<[u8; 32]>,
    /// For each level, `true` when the node on the path is the right child
    /// (its sibling is hashed on the left), `false` when it is the left child.
    pub path_indices: Vec<bool>,
    /// The leaf-layer node the proof starts from, as stored in the tree.
    pub leaf: [u8; 32],
    /// Position of the leaf in the leaf layer.
    pub leaf_index: u32,
}
impl MerkleTree {
    /// Builds a SHA-256 Merkle tree over `leaves`.
    ///
    /// Every input is hashed with `hash_leaf`. The leaf layer is then padded
    /// with all-zero nodes up to the next power of two, and each parent node is
    /// `hash_pair(left, right)` of its two children.
    ///
    /// # Errors
    ///
    /// * `ZKError::EmptyTree` if `leaves` is empty.
    /// * `ZKError::MaxDepthExceeded` if more than `2^MAX_DEPTH` leaves are given.
    pub fn from_leaves(env: &Env, leaves: &[[u8; 32]]) -> Result<Self, ZKError> {
        if leaves.is_empty() {
            return Err(ZKError::EmptyTree);
        }
        // Reject oversized inputs before any u32 arithmetic so that neither the
        // narrowing cast nor the power-of-two rounding below can overflow.
        // `usize -> u64` is lossless on every supported target.
        if leaves.len() as u64 > 1u64 << MAX_DEPTH {
            return Err(ZKError::MaxDepthExceeded);
        }
        // Lossless: bounded by 2^MAX_DEPTH thanks to the guard above.
        let leaf_count = leaves.len() as u32;
        let size = leaf_count.next_power_of_two();
        // `size` is a power of two, so its log2 is the number of trailing zeros.
        let depth = size.trailing_zeros();

        let mut leaf_layer: Vec<[u8; 32]> = Vec::with_capacity(size as usize);
        leaf_layer.extend(leaves.iter().map(|leaf| hash_leaf(env, leaf)));
        // Padding nodes are raw zero bytes (they are not passed through `hash_leaf`).
        leaf_layer.resize(size as usize, [0u8; 32]);

        let mut layers: Vec<Vec<[u8; 32]>> = Vec::with_capacity(depth as usize + 1);
        layers.push(leaf_layer);
        // Each pass halves the previous layer; after `depth` passes one node is left.
        for level in 0..depth as usize {
            let parents: Vec<[u8; 32]> = layers[level]
                .chunks_exact(2)
                .map(|pair| hash_pair(env, &pair[0], &pair[1]))
                .collect();
            layers.push(parents);
        }

        Ok(Self {
            depth,
            leaf_count,
            layers,
        })
    }

    /// Returns the root node of the tree.
    ///
    /// # Panics
    ///
    /// Panics only if the tree has no layers, which `from_leaves` never produces.
    pub fn root(&self) -> [u8; 32] {
        self.layers
            .last()
            .expect("a MerkleTree always has at least the leaf layer")[0]
    }

    /// Returns the root as a Soroban `BytesN<32>` bound to `env`.
    pub fn root_bytes(&self, env: &Env) -> BytesN<32> {
        BytesN::from_array(env, &self.root())
    }

    /// Produces the inclusion proof for the leaf at `leaf_index`.
    ///
    /// The proof lists one sibling per level, from the leaf layer upwards.
    ///
    /// # Errors
    ///
    /// Returns `ZKError::LeafOutOfBounds` when `leaf_index` is not below the
    /// number of caller-supplied leaves (padding nodes cannot be proven).
    pub fn proof(&self, leaf_index: u32) -> Result<MerkleProof, ZKError> {
        if leaf_index >= self.leaf_count {
            return Err(ZKError::LeafOutOfBounds);
        }
        let levels = self.depth as usize;
        let mut siblings = Vec::with_capacity(levels);
        let mut path_indices = Vec::with_capacity(levels);
        let mut idx = leaf_index as usize;
        // Walk every layer below the root; the root layer has no sibling.
        for layer in self.layers.iter().take(levels) {
            // Flipping the lowest bit selects the other child of the same parent.
            siblings.push(layer[idx ^ 1]);
            path_indices.push(idx & 1 == 1);
            idx >>= 1;
        }
        Ok(MerkleProof {
            siblings,
            path_indices,
            leaf: self.layers[0][leaf_index as usize],
            leaf_index,
        })
    }

    /// Number of levels above the leaf layer.
    pub fn depth(&self) -> u32 {
        self.depth
    }

    /// Number of leaves supplied at construction, excluding padding.
    pub fn leaf_count(&self) -> u32 {
        self.leaf_count
    }
}
/// Hashes a raw leaf value into a leaf-layer node.
///
/// The result is `SHA-256(0x00 || data)`. The leading `0x00` byte differs from
/// the `0x01` byte that `hash_pair` prepends for internal nodes.
fn hash_leaf(env: &Env, data: &[u8; 32]) -> [u8; 32] {
    const LEAF_TAG: u8 = 0x00;
    // Byte 0 is the tag; bytes 1..33 are overwritten with the leaf data.
    let mut preimage = [LEAF_TAG; 33];
    preimage[1..].copy_from_slice(data);
    env.crypto()
        .sha256(&soroban_sdk::Bytes::from_slice(env, &preimage))
        .to_array()
}
/// Hashes two child nodes into their parent node.
///
/// The result is `SHA-256(0x01 || left || right)`. The leading `0x01` byte
/// differs from the `0x00` byte that `hash_leaf` prepends for leaves.
fn hash_pair(env: &Env, left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
    const NODE_TAG: u8 = 0x01;
    let mut preimage = [0u8; 65];
    preimage[0] = NODE_TAG;
    // The 64 bytes after the tag hold `left` followed by `right`.
    let (left_half, right_half) = preimage[1..].split_at_mut(32);
    left_half.copy_from_slice(left);
    right_half.copy_from_slice(right);
    env.crypto()
        .sha256(&soroban_sdk::Bytes::from_slice(env, &preimage))
        .to_array()
}
/// Checks that `proof` leads from its leaf node to `expected_root`.
///
/// Starting from `proof.leaf` (used as-is, it is not re-hashed here), the node
/// is combined with each sibling via `hash_pair`. When the matching
/// `path_indices` entry is `true` the sibling goes on the left, otherwise on
/// the right. The proof is valid when the final node equals `expected_root`.
///
/// A proof with fewer `path_indices` than `siblings` is malformed and yields
/// `false` instead of panicking. Surplus `path_indices` entries are ignored.
/// `proof.leaf_index` is not consulted.
pub fn verify_proof(env: &Env, proof: &MerkleProof, expected_root: &[u8; 32]) -> bool {
    // Proofs may come from untrusted callers: never index past `path_indices`.
    if proof.path_indices.len() < proof.siblings.len() {
        return false;
    }
    let mut current = proof.leaf;
    for (sibling, &node_is_right) in proof.siblings.iter().zip(proof.path_indices.iter()) {
        current = if node_is_right {
            hash_pair(env, sibling, &current)
        } else {
            hash_pair(env, &current, sibling)
        };
    }
    current == *expected_root
}
/// A binary Merkle tree over `U256` nodes hashed with Poseidon2.
///
/// Only compiled when the `hazmat-crypto` feature is enabled.
#[cfg(feature = "hazmat-crypto")]
pub struct PoseidonMerkleTree {
    /// Number of levels above the leaf layer.
    depth: u32,
    /// Number of leaves supplied by the caller, excluding padding nodes.
    leaf_count: u32,
    /// All layers of the tree: `layers[0]` is the (padded) leaf layer and the
    /// last entry is the single-node root layer.
    layers: Vec<Vec<soroban_sdk::U256>>,
}
/// An inclusion proof for one leaf of a `PoseidonMerkleTree`.
///
/// Only compiled when the `hazmat-crypto` feature is enabled.
#[cfg(feature = "hazmat-crypto")]
pub struct PoseidonMerkleProof {
    /// Sibling node at each level, starting at the leaf layer.
    pub siblings: Vec<soroban_sdk::U256>,
    /// For each level, `true` when the node on the path is the right child
    /// (its sibling is hashed on the left), `false` when it is the left child.
    pub path_indices: Vec<bool>,
    /// The leaf-layer node the proof starts from, as stored in the tree.
    pub leaf: soroban_sdk::U256,
    /// Position of the leaf in the leaf layer.
    pub leaf_index: u32,
}
#[cfg(feature = "hazmat-crypto")]
impl PoseidonMerkleTree {
    /// Builds a Poseidon2 Merkle tree over `leaves`.
    ///
    /// Every input is hashed as `poseidon2_hash(leaf, 0)`. The leaf layer is
    /// then padded with zero-valued nodes up to the next power of two, and each
    /// parent node is `poseidon2_hash(left, right)` of its two children.
    ///
    /// # Errors
    ///
    /// * `ZKError::EmptyTree` if `leaves` is empty.
    /// * `ZKError::MaxDepthExceeded` if more than `2^MAX_DEPTH` leaves are given.
    pub fn from_leaves(
        env: &Env,
        params: &crate::zk::crypto::Poseidon2Params,
        leaves: &[soroban_sdk::U256],
    ) -> Result<Self, ZKError> {
        use crate::zk::crypto::poseidon2_hash;

        if leaves.is_empty() {
            return Err(ZKError::EmptyTree);
        }
        // Reject oversized inputs before any u32 arithmetic so that neither the
        // narrowing cast nor the power-of-two rounding below can overflow.
        // `usize -> u64` is lossless on every supported target.
        if leaves.len() as u64 > 1u64 << MAX_DEPTH {
            return Err(ZKError::MaxDepthExceeded);
        }
        // Lossless: bounded by 2^MAX_DEPTH thanks to the guard above.
        let leaf_count = leaves.len() as u32;
        let size = leaf_count.next_power_of_two();
        // `size` is a power of two, so its log2 is the number of trailing zeros.
        let depth = size.trailing_zeros();

        let zero = soroban_sdk::U256::from_u32(env, 0);
        let mut leaf_layer: Vec<soroban_sdk::U256> = Vec::with_capacity(size as usize);
        leaf_layer.extend(
            leaves
                .iter()
                .map(|leaf| poseidon2_hash(env, params, leaf, &zero)),
        );
        // Padding nodes are the plain zero value (they are not hashed).
        leaf_layer.resize(size as usize, zero);

        let mut layers: Vec<Vec<soroban_sdk::U256>> = Vec::with_capacity(depth as usize + 1);
        layers.push(leaf_layer);
        // Each pass halves the previous layer; after `depth` passes one node is left.
        for level in 0..depth as usize {
            let parents: Vec<soroban_sdk::U256> = layers[level]
                .chunks_exact(2)
                .map(|pair| poseidon2_hash(env, params, &pair[0], &pair[1]))
                .collect();
            layers.push(parents);
        }

        Ok(Self {
            depth,
            leaf_count,
            layers,
        })
    }

    /// Returns the root node of the tree.
    ///
    /// # Panics
    ///
    /// Panics only if the tree has no layers, which `from_leaves` never produces.
    pub fn root(&self) -> soroban_sdk::U256 {
        self.layers
            .last()
            .expect("a PoseidonMerkleTree always has at least the leaf layer")[0]
            .clone()
    }

    /// Produces the inclusion proof for the leaf at `leaf_index`.
    ///
    /// The proof lists one sibling per level, from the leaf layer upwards.
    ///
    /// # Errors
    ///
    /// Returns `ZKError::LeafOutOfBounds` when `leaf_index` is not below the
    /// number of caller-supplied leaves (padding nodes cannot be proven).
    pub fn proof(&self, leaf_index: u32) -> Result<PoseidonMerkleProof, ZKError> {
        if leaf_index >= self.leaf_count {
            return Err(ZKError::LeafOutOfBounds);
        }
        let levels = self.depth as usize;
        let mut siblings = Vec::with_capacity(levels);
        let mut path_indices = Vec::with_capacity(levels);
        let mut idx = leaf_index as usize;
        // Walk every layer below the root; the root layer has no sibling.
        for layer in self.layers.iter().take(levels) {
            // Flipping the lowest bit selects the other child of the same parent.
            siblings.push(layer[idx ^ 1].clone());
            path_indices.push(idx & 1 == 1);
            idx >>= 1;
        }
        Ok(PoseidonMerkleProof {
            siblings,
            path_indices,
            leaf: self.layers[0][leaf_index as usize].clone(),
            leaf_index,
        })
    }

    /// Number of levels above the leaf layer.
    pub fn depth(&self) -> u32 {
        self.depth
    }

    /// Number of leaves supplied at construction, excluding padding.
    pub fn leaf_count(&self) -> u32 {
        self.leaf_count
    }
}
/// Checks that `proof` leads from its leaf node to `expected_root` using Poseidon2.
///
/// Starting from `proof.leaf` (used as-is, it is not re-hashed here), the node
/// is combined with each sibling via `poseidon2_hash`. When the matching
/// `path_indices` entry is `true` the sibling goes on the left, otherwise on
/// the right. The proof is valid when the final node equals `expected_root`.
///
/// A proof with fewer `path_indices` than `siblings` is malformed and yields
/// `false` instead of panicking. Surplus `path_indices` entries are ignored.
/// `proof.leaf_index` is not consulted.
#[cfg(feature = "hazmat-crypto")]
pub fn verify_poseidon_proof(
    env: &Env,
    params: &crate::zk::crypto::Poseidon2Params,
    proof: &PoseidonMerkleProof,
    expected_root: &soroban_sdk::U256,
) -> bool {
    use crate::zk::crypto::poseidon2_hash;

    // Proofs may come from untrusted callers: never index past `path_indices`.
    if proof.path_indices.len() < proof.siblings.len() {
        return false;
    }
    let mut current = proof.leaf.clone();
    for (sibling, &node_is_right) in proof.siblings.iter().zip(proof.path_indices.iter()) {
        current = if node_is_right {
            poseidon2_hash(env, params, sibling, &current)
        } else {
            poseidon2_hash(env, params, &current, sibling)
        };
    }
    current == *expected_root
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `count` leaves where leaf `i` (0-based) is 32 copies of byte `i + 1`.
    fn numbered_leaves(count: u8) -> Vec<[u8; 32]> {
        (1..=count).map(|byte| [byte; 32]).collect()
    }

    /// Asserts that the proof of every leaf index below `count` verifies
    /// against the tree's own root.
    fn assert_all_proofs_verify(env: &Env, tree: &MerkleTree, count: u32) {
        let root = tree.root();
        for index in 0..count {
            let proof = tree.proof(index).unwrap();
            assert!(verify_proof(env, &proof, &root));
        }
    }

    // An empty leaf slice must be rejected with `EmptyTree`.
    #[test]
    fn test_empty_leaves_error() {
        let env = Env::default();
        if let Err(e) = MerkleTree::from_leaves(&env, &[]) {
            assert_eq!(e, ZKError::EmptyTree);
        } else {
            panic!("Expected EmptyTree error");
        }
    }

    // A single leaf gives a depth-0 tree whose root is the hashed leaf.
    #[test]
    fn test_single_leaf() {
        let env = Env::default();
        let only_leaf = [1u8; 32];
        let tree = MerkleTree::from_leaves(&env, &[only_leaf]).unwrap();
        assert_eq!(tree.depth(), 0);
        assert_eq!(tree.leaf_count(), 1);
        assert_eq!(tree.root(), hash_leaf(&env, &only_leaf));
    }

    // Two leaves give a depth-1 tree whose root is the pair hash of both hashed leaves.
    #[test]
    fn test_two_leaves() {
        let env = Env::default();
        let inputs = numbered_leaves(2);
        let tree = MerkleTree::from_leaves(&env, &inputs).unwrap();
        assert_eq!(tree.depth(), 1);
        assert_eq!(tree.leaf_count(), 2);
        let left = hash_leaf(&env, &inputs[0]);
        let right = hash_leaf(&env, &inputs[1]);
        assert_eq!(tree.root(), hash_pair(&env, &left, &right));
    }

    #[test]
    fn test_four_leaves() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(4)).unwrap();
        assert_eq!(tree.depth(), 2);
        assert_eq!(tree.leaf_count(), 4);
    }

    // Three leaves are padded to four, so the depth is 2 while the count stays 3.
    #[test]
    fn test_three_leaves_padded() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(3)).unwrap();
        assert_eq!(tree.depth(), 2);
        assert_eq!(tree.leaf_count(), 3);
    }

    #[test]
    fn test_proof_generation() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(4)).unwrap();
        let first = tree.proof(0).unwrap();
        assert_eq!(first.leaf_index, 0);
        assert_eq!(first.siblings.len(), 2);
        assert_eq!(first.path_indices.len(), 2);
    }

    // Asking for a leaf index equal to the leaf count must fail.
    #[test]
    fn test_proof_out_of_bounds() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(2)).unwrap();
        if let Err(e) = tree.proof(2) {
            assert_eq!(e, ZKError::LeafOutOfBounds);
        } else {
            panic!("Expected LeafOutOfBounds error");
        }
    }

    #[test]
    fn test_proof_verification_valid() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(4)).unwrap();
        assert_all_proofs_verify(&env, &tree, 4);
    }

    #[test]
    fn test_proof_verification_wrong_root() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(4)).unwrap();
        let proof = tree.proof(0).unwrap();
        let bogus_root = [0xFFu8; 32];
        assert!(!verify_proof(&env, &proof, &bogus_root));
    }

    #[test]
    fn test_root_bytes() {
        let env = Env::default();
        let tree = MerkleTree::from_leaves(&env, &numbered_leaves(2)).unwrap();
        assert_eq!(tree.root_bytes(&env).to_array(), tree.root());
    }

    // Eight leaves (bytes 0..8) give a depth-3 tree where every proof verifies.
    #[test]
    fn test_eight_leaves() {
        let env = Env::default();
        let inputs: Vec<[u8; 32]> = (0..8u8).map(|byte| [byte; 32]).collect();
        let tree = MerkleTree::from_leaves(&env, &inputs).unwrap();
        assert_eq!(tree.depth(), 3);
        assert_eq!(tree.leaf_count(), 8);
        assert_all_proofs_verify(&env, &tree, 8);
    }
}