mt_rs/
proof.rs

1//! Merkle tree proof and verification implementation
2
3use crate::{
4    hasher::Hasher,
5    node::{Node, NodeChildType},
6};
7use rayon::prelude::*;
8
/// Represents a single step in a Merkle proof path.
///
/// A step carries the hash of the sibling met at one tree level together with
/// the side that sibling sits on. The side decides the concatenation order
/// used when the parent hash is recomputed during verification.
#[derive(Debug, Clone)]
pub struct ProofNode {
    /// The hash value of the sibling node.
    pub hash: String,
    /// Whether this sibling is the left or the right child of the shared parent.
    ///
    /// `Left` means the sibling hash is placed before the running hash when
    /// the two are combined; `Right` means it is placed after it.
    pub child_type: NodeChildType,
}
17
18/// A Merkle proof containing the path from a leaf to the root.
19#[derive(Debug)]
20pub struct MerkleProof {
21    /// The sequence of sibling hashes needed to reconstruct the path to root.
22    pub path: Vec<ProofNode>,
23    /// The index of the leaf node this proof corresponds.
24    pub leaf_index: usize,
25}
26
/// Generates Merkle proofs for leaves and checks data against such proofs.
pub trait Proofer {
    /// Generates a Merkle proof for the data at the specified index.
    ///
    /// # Arguments
    ///
    /// * `index` - The index of the leaf node to generate a proof for.
    ///
    /// # Returns
    ///
    /// `Some(MerkleProof)` if the index is valid, `None` otherwise.
    fn generate(&self, index: usize) -> Option<MerkleProof>;

    /// Verifies that a piece of data exists in the tree using a Merkle proof.
    ///
    /// # Arguments
    ///
    /// * `proof` - The Merkle proof.
    /// * `data` - The original data to verify.
    /// * `root_hash` - The expected root hash of the tree.
    ///
    /// # Returns
    ///
    /// `true` if the proof is valid and the data exists in the tree, `false` otherwise.
    fn verify<T>(&self, proof: &MerkleProof, data: T, root_hash: &str) -> bool
    where
        T: AsRef<[u8]>;
}
54
/// [`Proofer`] implementation that keeps every level of the tree in memory.
///
/// Instances are built with `DefaultProofer::new`, which derives all upper
/// levels from the supplied leaves.
pub struct DefaultProofer<H: Hasher> {
    /// Hasher used to combine sibling hashes and to hash data being verified.
    hasher: H,
    /// Tree levels, starting with the leaves at index 0; each following entry
    /// holds the parents of the entry before it.
    levels: Vec<Vec<Node>>,
}
59
60impl<H> DefaultProofer<H>
61where
62    H: Hasher,
63{
64    pub fn new(hasher: H, leaves: Vec<Node>) -> Self {
65        let mut levels = Vec::new();
66        levels.push(leaves.clone());
67
68        let mut current_level = leaves;
69        while current_level.len() > 1 {
70            if current_level.len() % 2 != 0 {
71                current_level.push(current_level.last().unwrap().clone());
72            }
73            let next_level: Vec<Node> = current_level
74                .par_chunks(2)
75                .map(|pair| {
76                    let (left, right) = (&pair[0], &pair[1]);
77                    let combined = [left.hash().as_bytes(), right.hash().as_bytes()].concat();
78                    let hash = hasher.hash(&combined);
79                    Node::new_internal(hash, left.clone(), right.clone())
80                })
81                .collect();
82
83            levels.push(next_level.clone());
84            current_level = next_level;
85        }
86
87        Self { hasher, levels }
88    }
89
90    pub fn verify_hash(&self, proof: &MerkleProof, hash: String, root_hash: &str) -> bool {
91        let mut current_hash = hash;
92        // Walk up the tree using the proof path
93        for proof_node in &proof.path {
94            let combined: String = match proof_node.child_type {
95                NodeChildType::Left => format!("{}{}", proof_node.hash, current_hash),
96                NodeChildType::Right => format!("{}{}", current_hash, proof_node.hash),
97            };
98            current_hash = self.hasher.hash(combined.as_bytes());
99        }
100
101        // Check if the computed root matches the expected root
102        current_hash == root_hash
103    }
104}
105
106impl<H> Proofer for DefaultProofer<H>
107where
108    H: Hasher,
109{
110    fn generate(&self, index: usize) -> Option<MerkleProof> {
111        if index >= self.levels[0].len() {
112            return None;
113        }
114
115        let mut path = Vec::new();
116        let mut current_index = index;
117
118        for level in &self.levels[..self.levels.len() - 1] {
119            // Flip the last bit and ensures that it never goes out-of-bounds
120            let sibling_index = (current_index ^ 1).min(level.len() - 1);
121
122            let sibling = &level[sibling_index];
123
124            let child_type = if sibling_index < current_index {
125                NodeChildType::Left
126            } else {
127                NodeChildType::Right
128            };
129
130            path.push(ProofNode {
131                hash: sibling.hash().to_string(),
132                child_type,
133            });
134
135            current_index >>= 1;
136        }
137
138        Some(MerkleProof {
139            path,
140            leaf_index: index,
141        })
142    }
143
144    fn verify<T>(&self, proof: &MerkleProof, data: T, root_hash: &str) -> bool
145    where
146        T: AsRef<[u8]>,
147    {
148        // Start with the hash of the data
149        let hash: String = self.hasher.hash(data.as_ref());
150        self.verify_hash(proof, hash, root_hash)
151    }
152}
153
#[cfg(test)]
mod tests {
    use crate::{hasher::*, merkletree::MerkleTree};

    use super::*;

    /// Every leaf of a four-item tree built with the dummy hasher yields a
    /// proof that verifies against the tree root.
    #[test]
    fn test_proof_generation_and_verification_dummy() {
        let hasher = DummyHasher;
        let items = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), items.clone());
        let proofer = DefaultProofer::new(hasher, tree.leaves());

        items.iter().enumerate().for_each(|(position, item)| {
            let proof = proofer.generate(position).unwrap();

            assert!(proofer.verify(&proof, item, tree.root().hash()));
        });
    }

    /// Same round trip as above, using the SHA-256 hasher.
    #[test]
    fn test_proof_generation_and_verification_sha256() {
        let hasher = SHA256Hasher::new();
        let items = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), items.clone());
        let proofer = DefaultProofer::new(hasher, tree.leaves().clone());

        items.iter().enumerate().for_each(|(position, item)| {
            let proof = proofer.generate(position).unwrap();

            assert!(proofer.verify(&proof, item, tree.root().hash()));
        });
    }

    /// A proof generated for leaf 0 accepts only the data stored at leaf 0.
    #[test]
    fn test_proof_not_valid() {
        let hasher = SHA256Hasher::new();
        let items = vec!["a", "b", "c", "d"];
        let tree = MerkleTree::new(hasher.clone(), items.clone());
        let proofer = DefaultProofer::new(hasher, tree.leaves().clone());

        let first_leaf_proof = proofer.generate(0).unwrap();

        assert!(proofer.verify(&first_leaf_proof, b"a", tree.root().hash()));

        // Other leaves' data ("b".."d") and data absent from the tree ("e")
        // must all be rejected.
        for wrong in [b"b", b"c", b"d", b"e"].iter() {
            assert!(!proofer.verify(&first_leaf_proof, *wrong, tree.root().hash()));
        }
    }
}