use crate::zk::error::ZKError;
use soroban_sdk::{contracttype, BytesN, Env, Vec};
use super::tree::MerkleProof;
#[contracttype]
#[derive(Clone, Debug)]
pub struct OnChainMerkleProof {
pub siblings: Vec<BytesN<32>>,
pub path_bits: u32,
pub leaf: BytesN<32>,
pub leaf_index: u32,
pub depth: u32,
}
pub fn to_on_chain_proof(proof: &MerkleProof, env: &Env) -> OnChainMerkleProof {
let mut siblings: Vec<BytesN<32>> = Vec::new(env);
let mut path_bits: u32 = 0;
for (i, sibling) in proof.siblings.iter().enumerate() {
siblings.push_back(BytesN::from_array(env, sibling));
if proof.path_indices[i] {
path_bits |= 1 << i;
}
}
OnChainMerkleProof {
siblings,
path_bits,
leaf: BytesN::from_array(env, &proof.leaf),
leaf_index: proof.leaf_index,
depth: proof.siblings.len() as u32,
}
}
pub fn verify_inclusion(
env: &Env,
proof: &OnChainMerkleProof,
expected_root: &BytesN<32>,
) -> Result<bool, ZKError> {
if proof.siblings.len() != proof.depth {
return Err(ZKError::InvalidProofLength);
}
let mut current = proof.leaf.to_array();
for i in 0..proof.depth {
let sibling = proof.siblings.get(i).ok_or(ZKError::InvalidProofLength)?;
let sibling_arr = sibling.to_array();
let is_right = (proof.path_bits >> i) & 1 == 1;
if is_right {
current = hash_pair_raw(env, &sibling_arr, ¤t);
} else {
current = hash_pair_raw(env, ¤t, &sibling_arr);
}
}
Ok(current == expected_root.to_array())
}
fn hash_pair_raw(env: &Env, left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
let mut input = [0u8; 65];
input[0] = 0x01;
input[1..33].copy_from_slice(left);
input[33..65].copy_from_slice(right);
let bytes = soroban_sdk::Bytes::from_slice(env, &input);
let result = env.crypto().sha256(&bytes);
result.to_array()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::zk::merkle::tree::MerkleTree;
#[test]
fn test_to_on_chain_proof_and_verify() {
let env = Env::default();
let leaves = [[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]];
let tree = MerkleTree::from_leaves(&env, &leaves).unwrap();
let root = tree.root_bytes(&env);
for i in 0..4u32 {
let proof = tree.proof(i).unwrap();
let on_chain = to_on_chain_proof(&proof, &env);
let result = verify_inclusion(&env, &on_chain, &root).unwrap();
assert!(result, "Proof verification failed for leaf {}", i);
}
}
#[test]
fn test_verify_inclusion_wrong_root() {
let env = Env::default();
let leaves = [[1u8; 32], [2u8; 32]];
let tree = MerkleTree::from_leaves(&env, &leaves).unwrap();
let proof = tree.proof(0).unwrap();
let on_chain = to_on_chain_proof(&proof, &env);
let wrong_root = BytesN::from_array(&env, &[0xFFu8; 32]);
let result = verify_inclusion(&env, &on_chain, &wrong_root).unwrap();
assert!(!result);
}
#[test]
fn test_verify_inclusion_invalid_depth() {
let env = Env::default();
let proof = OnChainMerkleProof {
siblings: Vec::new(&env),
path_bits: 0,
leaf: BytesN::from_array(&env, &[1u8; 32]),
leaf_index: 0,
depth: 5, };
let root = BytesN::from_array(&env, &[0u8; 32]);
let result = verify_inclusion(&env, &proof, &root);
assert_eq!(result.unwrap_err(), ZKError::InvalidProofLength);
}
#[test]
fn test_on_chain_proof_path_bits_packing() {
let env = Env::default();
let leaves = [[1u8; 32], [2u8; 32], [3u8; 32], [4u8; 32]];
let tree = MerkleTree::from_leaves(&env, &leaves).unwrap();
let p0 = to_on_chain_proof(&tree.proof(0).unwrap(), &env);
assert_eq!(p0.path_bits, 0);
let p1 = to_on_chain_proof(&tree.proof(1).unwrap(), &env);
assert_eq!(p1.path_bits, 1);
let p2 = to_on_chain_proof(&tree.proof(2).unwrap(), &env);
assert_eq!(p2.path_bits, 2);
let p3 = to_on_chain_proof(&tree.proof(3).unwrap(), &env);
assert_eq!(p3.path_bits, 3);
}
}