use crate::hash::MimcHasher;
#[cfg(not(feature = "std"))]
extern crate alloc;
#[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// Merkle inclusion proof for a single leaf.
///
/// `path[i]` is the sibling hash combined with the running hash at level `i`
/// (level 0 is next to the leaf). `indices[i]` gives the side of the running
/// hash at that level:
/// - `true`: the running hash is the *right* child and `path[i]` is the left.
/// - `false`: the opposite.
///
/// `path` and `indices` are expected to have the same length.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleProof {
    /// Leaf value whose inclusion is being proven.
    pub leaf: u128,
    /// Position of the leaf in the tree.
    ///
    /// NOTE(review): this field is not consulted by [`MerkleProof::verify`].
    /// Presumably its bits should agree with `indices`; confirm with callers.
    pub leaf_index: u32,
    /// Sibling hashes, ordered from the leaf level up toward the root.
    pub path: Vec<u128>,
    /// Per-level side flags; `true` means the running hash is the right child.
    pub indices: Vec<bool>,
}

impl MerkleProof {
    /// Builds a proof from its raw parts.
    ///
    /// No validation is performed here. A `path`/`indices` length mismatch is
    /// only rejected later by [`Self::verify`].
    pub fn new(leaf: u128, leaf_index: u32, path: Vec<u128>, indices: Vec<bool>) -> Self {
        Self {
            leaf,
            leaf_index,
            path,
            indices,
        }
    }

    /// Number of levels in the proof, i.e. the length of `path`.
    #[inline]
    pub fn depth(&self) -> usize {
        self.path.len()
    }

    /// Returns `true` if this proof hashes up to `root`.
    ///
    /// A proof whose `path` and `indices` differ in length is rejected
    /// outright instead of being silently truncated.
    pub fn verify(&self, root: u128, hasher: &MimcHasher) -> bool {
        self.path.len() == self.indices.len() && self.compute_root(hasher) == root
    }

    /// Recomputes the root implied by this proof.
    ///
    /// At each level, the running hash (starting at `leaf`) and the sibling
    /// are ordered into `(left, right)` using `indices`. They are then combined
    /// as `mimc_sponge((mimc_sponge(left) + right) mod p)`, where
    /// `p = hasher.field_prime()`.
    ///
    /// If `path` and `indices` differ in length, only their common prefix is
    /// used. Call [`Self::verify`] to reject such proofs.
    ///
    /// # Panics
    /// Panics if the path is non-empty and `hasher.field_prime()` returns 0,
    /// because the modular addition divides by the modulus.
    pub fn compute_root(&self, hasher: &MimcHasher) -> u128 {
        let field_size = hasher.field_prime();
        // Second argument forwarded to every `mimc_sponge` call.
        // NOTE(review): its meaning (key / round constant?) is defined by
        // MimcHasher; confirm.
        let key = 0_u128;
        self.path
            .iter()
            .zip(&self.indices)
            .fold(self.leaf, |current, (&sibling, &is_right)| {
                let (left, right) = if is_right {
                    (sibling, current)
                } else {
                    (current, sibling)
                };
                let absorbed = hasher.mimc_sponge(left, key, field_size);
                let mixed = Self::add_mod(absorbed, right, field_size);
                hasher.mimc_sponge(mixed, key, field_size)
            })
    }

    /// Computes `(a + b) mod modulus` without intermediate overflow.
    ///
    /// The naive `a.wrapping_add(b) % modulus` is wrong whenever `a + b`
    /// exceeds `u128::MAX`, e.g. for an unreduced sibling close to `u128::MAX`.
    /// In that case the wrapped sum is off by 2^128, which is not a multiple
    /// of `modulus`.
    #[inline]
    fn add_mod(a: u128, b: u128, modulus: u128) -> u128 {
        let a = a % modulus;
        let b = b % modulus;
        // Both operands are now < modulus, so `modulus - b` cannot underflow.
        // Also, `a + b >= modulus` holds exactly when `a >= modulus - b`.
        let gap = modulus - b;
        if a >= gap {
            a - gap
        } else {
            a + b
        }
    }

    /// The leaf value being proven.
    #[inline]
    pub fn leaf(&self) -> u128 {
        self.leaf
    }

    /// The stored leaf position (not checked by [`Self::verify`]).
    #[inline]
    pub fn leaf_index(&self) -> u32 {
        self.leaf_index
    }

    /// Sibling hashes, leaf level first.
    #[inline]
    pub fn path(&self) -> &[u128] {
        &self.path
    }

    /// Per-level side flags; `true` means the running hash is the right child.
    #[inline]
    pub fn indices(&self) -> &[bool] {
        &self.indices
    }
}
#[cfg(feature = "serde")]
mod serde_impl {
    use super::*;
    use serde::Serialize;

    /// Serializes a proof as a struct named `MerkleProof` with the fields
    /// `leaf`, `leaf_index`, `path` and `indices`, in that order.
    ///
    /// Only serialization is implemented in this module.
    impl Serialize for MerkleProof {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::SerializeStruct as _;

            // Must match the number of `serialize_field` calls below.
            const FIELD_COUNT: usize = 4;

            let mut fields = serializer.serialize_struct("MerkleProof", FIELD_COUNT)?;
            fields.serialize_field("leaf", &self.leaf)?;
            fields.serialize_field("leaf_index", &self.leaf_index)?;
            fields.serialize_field("path", &self.path)?;
            fields.serialize_field("indices", &self.indices)?;
            fields.end()
        }
    }
}
#[cfg(feature = "borsh")]
mod borsh_impl {
    use super::*;
    use borsh::{BorshDeserialize, BorshSerialize};

    // Borsh wire layout, in order:
    //   leaf (u128) | leaf_index (u32) | path (Vec<u128>) | indices (Vec<bool>)
    //
    // NOTE(review): these impls name `std::io` directly even though the crate
    // otherwise supports `no_std`. The `borsh` feature presumably also
    // requires `std`; confirm.

    impl BorshSerialize for MerkleProof {
        fn serialize<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
            BorshSerialize::serialize(&self.leaf, writer)?;
            BorshSerialize::serialize(&self.leaf_index, writer)?;
            BorshSerialize::serialize(&self.path, writer)?;
            BorshSerialize::serialize(&self.indices, writer)
        }
    }

    impl BorshDeserialize for MerkleProof {
        fn deserialize_reader<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
            // Struct-literal fields are evaluated in the order written, so the
            // reads below consume the stream in the same order it was written.
            Ok(MerkleProof {
                leaf: u128::deserialize_reader(reader)?,
                leaf_index: u32::deserialize_reader(reader)?,
                path: Vec::<u128>::deserialize_reader(reader)?,
                indices: Vec::<bool>::deserialize_reader(reader)?,
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hasher shared by every test (relies on `MimcHasher: Default`).
    fn hasher() -> MimcHasher {
        MimcHasher::default()
    }

    /// Depth-3 proof for leaf 12345 where the running hash is the left child
    /// at every level.
    fn left_leaning_proof() -> MerkleProof {
        MerkleProof::new(12345, 0, vec![1, 2, 3], vec![false; 3])
    }

    #[test]
    fn test_proof_new() {
        let proof = MerkleProof::new(12345, 0, vec![1, 2, 3], vec![false, true, false]);
        assert_eq!(
            (proof.leaf(), proof.leaf_index(), proof.depth()),
            (12345, 0, 3)
        );
    }

    #[test]
    fn test_proof_depth() {
        let five_levels = MerkleProof::new(0, 0, (1..=5).collect(), vec![false; 5]);
        assert_eq!(five_levels.depth(), 5);
    }

    #[test]
    fn test_proof_mismatched_lengths_fails_verify() {
        // Three siblings but only two side flags.
        let lopsided = MerkleProof {
            path: vec![1, 2, 3],
            indices: vec![false, true],
            leaf: 12345,
            leaf_index: 0,
        };
        assert!(!lopsided.verify(0, &hasher()));
    }

    #[test]
    fn test_compute_root_deterministic() {
        let (proof, h) = (left_leaning_proof(), hasher());
        let first = proof.compute_root(&h);
        assert_eq!(first, proof.compute_root(&h));
    }

    #[test]
    fn test_verify_wrong_root_fails() {
        let proof = left_leaning_proof();
        let h = hasher();
        let expected = proof.compute_root(&h);
        assert!(proof.verify(expected, &h), "proof must accept its own root");
        assert!(
            !proof.verify(expected + 1, &h),
            "proof must reject a different root"
        );
    }
}