chie_crypto/
merkle.rs

1//! Merkle tree implementation for efficient content verification.
2//!
3//! This module provides Merkle trees for:
4//! - Efficient chunk integrity verification
5//! - Proving a chunk is part of larger content
6//! - Incremental verification as chunks are received
7//! - Supporting partial downloads with proof of correctness
8//!
9//! # Example
10//!
11//! ```
12//! use chie_crypto::merkle::{MerkleTree, MerkleProof};
13//!
14//! // Build tree from chunks
15//! let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];
16//! let tree = MerkleTree::from_leaves(&chunks);
17//!
18//! // Get root hash
19//! let root = tree.root();
20//!
21//! // Generate proof for chunk 1
22//! let proof = tree.generate_proof(1).unwrap();
23//!
24//! // Verify the proof
25//! assert!(proof.verify(root, &chunks[1], 1));
26//! ```
27
28use crate::hash::{Hash, hash};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31
32/// Merkle tree error types.
33#[derive(Debug, Error)]
34pub enum MerkleError {
35    #[error("Invalid leaf index: {0}")]
36    InvalidLeafIndex(usize),
37
38    #[error("Empty tree")]
39    EmptyTree,
40
41    #[error("Proof verification failed")]
42    VerificationFailed,
43
44    #[error("Invalid proof length")]
45    InvalidProofLength,
46
47    #[error("Tree size mismatch")]
48    TreeSizeMismatch,
49}
50
51pub type MerkleResult<T> = Result<T, MerkleError>;
52
53/// A Merkle tree for efficient content verification.
54///
55/// The tree is built from leaf nodes (content chunks) and allows
56/// generating proofs that a specific chunk is part of the content.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct MerkleTree {
59    /// Tree levels, stored from leaves to root.
60    /// levels[0] contains leaf hashes, levels[last] contains root.
61    levels: Vec<Vec<Hash>>,
62}
63
64impl MerkleTree {
65    /// Create a Merkle tree from leaf data.
66    ///
67    /// # Arguments
68    /// * `leaves` - The leaf data (e.g., content chunks)
69    ///
70    /// # Returns
71    /// A new Merkle tree.
72    pub fn from_leaves(leaves: &[Vec<u8>]) -> Self {
73        assert!(!leaves.is_empty(), "Cannot create tree from empty leaves");
74
75        // Hash all leaves
76        let leaf_hashes: Vec<Hash> = leaves.iter().map(|leaf| hash(leaf)).collect();
77
78        Self::from_leaf_hashes(&leaf_hashes)
79    }
80
81    /// Create a Merkle tree from pre-hashed leaves.
82    pub fn from_leaf_hashes(leaf_hashes: &[Hash]) -> Self {
83        assert!(
84            !leaf_hashes.is_empty(),
85            "Cannot create tree from empty leaves"
86        );
87
88        let mut levels = vec![leaf_hashes.to_vec()];
89        let mut current_level = leaf_hashes.to_vec();
90
91        // Build tree bottom-up
92        while current_level.len() > 1 {
93            let mut next_level = Vec::new();
94
95            for i in (0..current_level.len()).step_by(2) {
96                let left = &current_level[i];
97                let right = if i + 1 < current_level.len() {
98                    &current_level[i + 1]
99                } else {
100                    // Odd number of nodes, duplicate the last one
101                    left
102                };
103
104                let mut data = Vec::with_capacity(64);
105                data.extend_from_slice(left);
106                data.extend_from_slice(right);
107                next_level.push(hash(&data));
108            }
109
110            levels.push(next_level.clone());
111            current_level = next_level;
112        }
113
114        Self { levels }
115    }
116
117    /// Get the root hash of the tree.
118    pub fn root(&self) -> &Hash {
119        &self.levels.last().unwrap()[0]
120    }
121
122    /// Get the number of leaves in the tree.
123    pub fn leaf_count(&self) -> usize {
124        self.levels[0].len()
125    }
126
127    /// Generate a Merkle proof for a specific leaf.
128    ///
129    /// # Arguments
130    /// * `leaf_index` - Index of the leaf to generate proof for
131    ///
132    /// # Returns
133    /// A Merkle proof that can be used to verify the leaf.
134    pub fn generate_proof(&self, leaf_index: usize) -> MerkleResult<MerkleProof> {
135        if leaf_index >= self.leaf_count() {
136            return Err(MerkleError::InvalidLeafIndex(leaf_index));
137        }
138
139        let mut proof_hashes = Vec::new();
140        let mut proof_positions = Vec::new(); // true = left, false = right
141        let mut index = leaf_index;
142
143        // Traverse up the tree, collecting sibling hashes
144        for level in &self.levels[..self.levels.len() - 1] {
145            if index % 2 == 0 {
146                // Current node is on the left
147                let sibling_index = index + 1;
148                if sibling_index < level.len() {
149                    // Normal case: sibling exists
150                    proof_hashes.push(level[sibling_index]);
151                    proof_positions.push(true);
152                } else {
153                    // Odd case: we're the last node, duplicate ourselves
154                    proof_hashes.push(level[index]);
155                    proof_positions.push(true);
156                }
157            } else {
158                // Current node is on the right, sibling is on the left
159                let sibling_index = index - 1;
160                proof_hashes.push(level[sibling_index]);
161                proof_positions.push(false);
162            }
163
164            index /= 2;
165        }
166
167        Ok(MerkleProof {
168            hashes: proof_hashes,
169            positions: proof_positions,
170            leaf_index,
171        })
172    }
173
174    /// Verify that a leaf with given data exists at the specified index.
175    ///
176    /// # Arguments
177    /// * `leaf_data` - The leaf data to verify
178    /// * `leaf_index` - The claimed index of the leaf
179    ///
180    /// # Returns
181    /// `true` if the leaf exists at the index, `false` otherwise.
182    pub fn verify_leaf(&self, leaf_data: &[u8], leaf_index: usize) -> bool {
183        if leaf_index >= self.leaf_count() {
184            return false;
185        }
186
187        let leaf_hash = hash(leaf_data);
188        self.levels[0][leaf_index] == leaf_hash
189    }
190}
191
192/// A Merkle proof that a specific leaf is part of a Merkle tree.
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct MerkleProof {
195    /// Sibling hashes along the path from leaf to root.
196    hashes: Vec<Hash>,
197    /// Positions indicating whether the node is on left (true) or right (false).
198    positions: Vec<bool>,
199    /// Index of the leaf this proof is for.
200    leaf_index: usize,
201}
202
203impl MerkleProof {
204    /// Verify this proof against a root hash and leaf data.
205    ///
206    /// # Arguments
207    /// * `root` - The expected root hash of the tree
208    /// * `leaf_data` - The leaf data to verify
209    /// * `leaf_index` - The claimed index of the leaf
210    ///
211    /// # Returns
212    /// `true` if the proof is valid, `false` otherwise.
213    pub fn verify(&self, root: &Hash, leaf_data: &[u8], leaf_index: usize) -> bool {
214        if self.leaf_index != leaf_index {
215            return false;
216        }
217
218        let mut current_hash = hash(leaf_data);
219
220        for (sibling_hash, is_left) in self.hashes.iter().zip(&self.positions) {
221            let mut data = Vec::with_capacity(64);
222
223            if *is_left {
224                // Current node is on the left
225                data.extend_from_slice(&current_hash);
226                data.extend_from_slice(sibling_hash);
227            } else {
228                // Current node is on the right
229                data.extend_from_slice(sibling_hash);
230                data.extend_from_slice(&current_hash);
231            }
232
233            current_hash = hash(&data);
234        }
235
236        &current_hash == root
237    }
238
239    /// Get the leaf index this proof is for.
240    pub fn leaf_index(&self) -> usize {
241        self.leaf_index
242    }
243
244    /// Get the number of hashes in the proof (proof depth).
245    pub fn depth(&self) -> usize {
246        self.hashes.len()
247    }
248
249    /// Serialize the proof to bytes.
250    pub fn to_bytes(&self) -> Vec<u8> {
251        crate::codec::encode(self).expect("serialization should not fail")
252    }
253
254    /// Deserialize a proof from bytes.
255    pub fn from_bytes(bytes: &[u8]) -> MerkleResult<Self> {
256        crate::codec::decode(bytes).map_err(|_| MerkleError::InvalidProofLength)
257    }
258}
259
260/// Multi-proof for verifying multiple leaves at once.
261///
262/// This is more efficient than individual proofs when verifying
263/// multiple chunks from the same content.
264#[derive(Debug, Clone)]
265pub struct MultiProof {
266    /// Minimal set of hashes needed to verify all leaves.
267    hashes: Vec<Hash>,
268    /// Instructions for combining hashes.
269    instructions: Vec<ProofInstruction>,
270}
271
272#[derive(Debug, Clone)]
273#[allow(dead_code)]
274enum ProofInstruction {
275    /// Use a hash from the proof at the given index.
276    UseProofHash(usize),
277    /// Use a leaf hash at the given index.
278    UseLeafHash(usize),
279    /// Combine two previous results.
280    Combine { left_idx: usize, right_idx: usize },
281}
282
283impl MultiProof {
284    /// Verify multiple leaves at once.
285    ///
286    /// # Arguments
287    /// * `root` - The expected root hash
288    /// * `leaves` - The leaf data to verify (index, data) pairs
289    ///
290    /// # Returns
291    /// `true` if all leaves are valid, `false` otherwise.
292    #[allow(dead_code)]
293    pub fn verify(&self, root: &Hash, leaves: &[(usize, &[u8])]) -> bool {
294        let mut stack = Vec::new();
295
296        for instruction in &self.instructions {
297            match instruction {
298                ProofInstruction::UseProofHash(idx) => {
299                    stack.push(self.hashes[*idx]);
300                }
301                ProofInstruction::UseLeafHash(idx) => {
302                    let leaf_hash = hash(leaves[*idx].1);
303                    stack.push(leaf_hash);
304                }
305                ProofInstruction::Combine {
306                    left_idx,
307                    right_idx,
308                } => {
309                    let left = stack[*left_idx];
310                    let right = stack[*right_idx];
311
312                    let mut data = Vec::with_capacity(64);
313                    data.extend_from_slice(&left);
314                    data.extend_from_slice(&right);
315
316                    stack.push(hash(&data));
317                }
318            }
319        }
320
321        stack.last() == Some(root)
322    }
323}
324
325/// Incremental Merkle tree builder for streaming content.
326///
327/// This allows building a Merkle tree as chunks arrive,
328/// without needing all chunks in memory at once.
329#[derive(Debug)]
330pub struct IncrementalMerkleBuilder {
331    /// Accumulated leaf hashes.
332    leaf_hashes: Vec<Hash>,
333}
334
335impl Default for IncrementalMerkleBuilder {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341impl IncrementalMerkleBuilder {
342    /// Create a new incremental builder.
343    pub fn new() -> Self {
344        Self {
345            leaf_hashes: Vec::new(),
346        }
347    }
348
349    /// Add a leaf to the tree.
350    pub fn add_leaf(&mut self, data: &[u8]) {
351        self.leaf_hashes.push(hash(data));
352    }
353
354    /// Add a pre-hashed leaf.
355    pub fn add_leaf_hash(&mut self, leaf_hash: Hash) {
356        self.leaf_hashes.push(leaf_hash);
357    }
358
359    /// Get the current number of leaves.
360    pub fn leaf_count(&self) -> usize {
361        self.leaf_hashes.len()
362    }
363
364    /// Finalize the tree.
365    pub fn finalize(self) -> MerkleResult<MerkleTree> {
366        if self.leaf_hashes.is_empty() {
367            return Err(MerkleError::EmptyTree);
368        }
369
370        Ok(MerkleTree::from_leaf_hashes(&self.leaf_hashes))
371    }
372}
373
// Unit tests covering tree construction, proof generation and verification,
// proof serialization, the incremental builder, and error paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_merkle_tree_basic() {
        let chunks = vec![
            b"chunk1".to_vec(),
            b"chunk2".to_vec(),
            b"chunk3".to_vec(),
            b"chunk4".to_vec(),
        ];

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 4);

        let root = tree.root();
        assert_ne!(root, &[0u8; 32]);
    }

    #[test]
    fn test_merkle_proof_generation() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];

        let tree = MerkleTree::from_leaves(&chunks);

        for i in 0..chunks.len() {
            let proof = tree.generate_proof(i);
            assert!(proof.is_ok());
        }

        let invalid_proof = tree.generate_proof(10);
        assert!(invalid_proof.is_err());
    }

    #[test]
    fn test_merkle_proof_verification() {
        let chunks = vec![
            b"chunk1".to_vec(),
            b"chunk2".to_vec(),
            b"chunk3".to_vec(),
            b"chunk4".to_vec(),
        ];

        let tree = MerkleTree::from_leaves(&chunks);
        let root = tree.root();

        for (i, chunk) in chunks.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();
            assert!(proof.verify(root, chunk, i));
        }
    }

    #[test]
    fn test_merkle_proof_invalid() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];

        let tree = MerkleTree::from_leaves(&chunks);
        let root = tree.root();

        let proof = tree.generate_proof(0).unwrap();

        // Wrong data
        assert!(!proof.verify(root, b"wrong", 0));

        // Wrong index
        assert!(!proof.verify(root, &chunks[0], 1));
    }

    #[test]
    fn test_proof_rejected_against_other_root() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec()];
        let other = vec![b"chunk1".to_vec(), b"other".to_vec()];

        let tree = MerkleTree::from_leaves(&chunks);
        let other_tree = MerkleTree::from_leaves(&other);
        let proof = tree.generate_proof(0).unwrap();

        assert!(proof.verify(tree.root(), &chunks[0], 0));
        assert!(!proof.verify(other_tree.root(), &chunks[0], 0));
    }

    #[test]
    fn test_verify_leaf() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];

        let tree = MerkleTree::from_leaves(&chunks);

        assert!(tree.verify_leaf(b"chunk1", 0));
        assert!(tree.verify_leaf(b"chunk2", 1));
        assert!(tree.verify_leaf(b"chunk3", 2));

        assert!(!tree.verify_leaf(b"chunk1", 1));
        assert!(!tree.verify_leaf(b"wrong", 0));
    }

    #[test]
    fn test_out_of_range_leaf() {
        let tree = MerkleTree::from_leaves(&[b"only".to_vec()]);

        assert!(!tree.verify_leaf(b"only", 1));
        assert!(matches!(
            tree.generate_proof(1),
            Err(MerkleError::InvalidLeafIndex(1))
        ));
    }

    #[test]
    fn test_incremental_builder() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];

        let mut builder = IncrementalMerkleBuilder::new();
        for chunk in &chunks {
            builder.add_leaf(chunk);
        }

        let tree = builder.finalize().unwrap();
        assert_eq!(tree.leaf_count(), 3);

        let expected_tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.root(), expected_tree.root());
    }

    #[test]
    fn test_incremental_builder_empty_is_error() {
        let builder = IncrementalMerkleBuilder::default();
        assert_eq!(builder.leaf_count(), 0);
        assert!(matches!(builder.finalize(), Err(MerkleError::EmptyTree)));
    }

    #[test]
    #[should_panic(expected = "Cannot create tree from empty leaves")]
    fn test_from_leaves_empty_panics() {
        let empty: Vec<Vec<u8>> = Vec::new();
        let _ = MerkleTree::from_leaves(&empty);
    }

    #[test]
    fn test_from_leaf_hashes_matches_from_leaves() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];
        let hashes: Vec<Hash> = chunks.iter().map(|c| hash(c)).collect();

        assert_eq!(
            MerkleTree::from_leaf_hashes(&hashes).root(),
            MerkleTree::from_leaves(&chunks).root()
        );
    }

    #[test]
    fn test_leaf_order_affects_root() {
        let forward = vec![b"a".to_vec(), b"b".to_vec()];
        let reversed = vec![b"b".to_vec(), b"a".to_vec()];

        assert_ne!(
            MerkleTree::from_leaves(&forward).root(),
            MerkleTree::from_leaves(&reversed).root()
        );
    }

    #[test]
    fn test_single_leaf() {
        let chunks = vec![b"single".to_vec()];
        let tree = MerkleTree::from_leaves(&chunks);

        assert_eq!(tree.leaf_count(), 1);

        let proof = tree.generate_proof(0).unwrap();
        assert!(proof.verify(tree.root(), b"single", 0));
    }

    #[test]
    fn test_proof_serialization() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec(), b"chunk3".to_vec()];

        let tree = MerkleTree::from_leaves(&chunks);
        let proof = tree.generate_proof(1).unwrap();

        let bytes = proof.to_bytes();
        let deserialized = MerkleProof::from_bytes(&bytes).unwrap();

        assert_eq!(proof.leaf_index(), deserialized.leaf_index());
        assert_eq!(proof.depth(), deserialized.depth());

        let root = tree.root();
        assert!(deserialized.verify(root, &chunks[1], 1));
    }

    #[test]
    fn test_proof_depth() {
        // 5 leaves -> levels of 5, 3, 2, 1 nodes -> 3 sibling hashes per proof.
        let chunks: Vec<Vec<u8>> = (0..5u8).map(|i| vec![i]).collect();
        let tree = MerkleTree::from_leaves(&chunks);

        for i in 0..chunks.len() {
            assert_eq!(tree.generate_proof(i).unwrap().depth(), 3);
        }
    }

    #[test]
    fn test_large_tree() {
        let chunks: Vec<Vec<u8>> = (0..1000)
            .map(|i| format!("chunk{}", i).into_bytes())
            .collect();

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 1000);

        // Verify random chunks
        for i in [0, 100, 500, 999] {
            let proof = tree.generate_proof(i).unwrap();
            assert!(proof.verify(tree.root(), &chunks[i], i));
        }
    }

    #[test]
    fn test_odd_number_of_leaves() {
        let chunks = vec![
            b"chunk1".to_vec(),
            b"chunk2".to_vec(),
            b"chunk3".to_vec(),
            b"chunk4".to_vec(),
            b"chunk5".to_vec(),
        ];

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 5);

        for (i, chunk) in chunks.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();
            assert!(proof.verify(tree.root(), chunk, i));
        }
    }

    #[test]
    fn test_two_leaves() {
        let chunks = vec![b"chunk1".to_vec(), b"chunk2".to_vec()];

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 2);

        for (i, chunk) in chunks.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();
            assert!(proof.verify(tree.root(), chunk, i));
        }
    }
}