1use crate::hash::{Hash, hash};
29use serde::{Deserialize, Serialize};
30use thiserror::Error;
31
32#[derive(Debug, Error)]
34pub enum MerkleError {
35 #[error("Invalid leaf index: {0}")]
36 InvalidLeafIndex(usize),
37
38 #[error("Empty tree")]
39 EmptyTree,
40
41 #[error("Proof verification failed")]
42 VerificationFailed,
43
44 #[error("Invalid proof length")]
45 InvalidProofLength,
46
47 #[error("Tree size mismatch")]
48 TreeSizeMismatch,
49}
50
51pub type MerkleResult<T> = Result<T, MerkleError>;
52
53#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct MerkleTree {
59 levels: Vec<Vec<Hash>>,
62}
63
64impl MerkleTree {
65 pub fn from_leaves(leaves: &[Vec<u8>]) -> Self {
73 assert!(!leaves.is_empty(), "Cannot create tree from empty leaves");
74
75 let leaf_hashes: Vec<Hash> = leaves.iter().map(|leaf| hash(leaf)).collect();
77
78 Self::from_leaf_hashes(&leaf_hashes)
79 }
80
81 pub fn from_leaf_hashes(leaf_hashes: &[Hash]) -> Self {
83 assert!(
84 !leaf_hashes.is_empty(),
85 "Cannot create tree from empty leaves"
86 );
87
88 let mut levels = vec![leaf_hashes.to_vec()];
89 let mut current_level = leaf_hashes.to_vec();
90
91 while current_level.len() > 1 {
93 let mut next_level = Vec::new();
94
95 for i in (0..current_level.len()).step_by(2) {
96 let left = ¤t_level[i];
97 let right = if i + 1 < current_level.len() {
98 ¤t_level[i + 1]
99 } else {
100 left
102 };
103
104 let mut data = Vec::with_capacity(64);
105 data.extend_from_slice(left);
106 data.extend_from_slice(right);
107 next_level.push(hash(&data));
108 }
109
110 levels.push(next_level.clone());
111 current_level = next_level;
112 }
113
114 Self { levels }
115 }
116
117 pub fn root(&self) -> &Hash {
119 &self.levels.last().unwrap()[0]
120 }
121
122 pub fn leaf_count(&self) -> usize {
124 self.levels[0].len()
125 }
126
127 pub fn generate_proof(&self, leaf_index: usize) -> MerkleResult<MerkleProof> {
135 if leaf_index >= self.leaf_count() {
136 return Err(MerkleError::InvalidLeafIndex(leaf_index));
137 }
138
139 let mut proof_hashes = Vec::new();
140 let mut proof_positions = Vec::new(); let mut index = leaf_index;
142
143 for level in &self.levels[..self.levels.len() - 1] {
145 if index % 2 == 0 {
146 let sibling_index = index + 1;
148 if sibling_index < level.len() {
149 proof_hashes.push(level[sibling_index]);
151 proof_positions.push(true);
152 } else {
153 proof_hashes.push(level[index]);
155 proof_positions.push(true);
156 }
157 } else {
158 let sibling_index = index - 1;
160 proof_hashes.push(level[sibling_index]);
161 proof_positions.push(false);
162 }
163
164 index /= 2;
165 }
166
167 Ok(MerkleProof {
168 hashes: proof_hashes,
169 positions: proof_positions,
170 leaf_index,
171 })
172 }
173
174 pub fn verify_leaf(&self, leaf_data: &[u8], leaf_index: usize) -> bool {
183 if leaf_index >= self.leaf_count() {
184 return false;
185 }
186
187 let leaf_hash = hash(leaf_data);
188 self.levels[0][leaf_index] == leaf_hash
189 }
190}
191
192#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct MerkleProof {
195 hashes: Vec<Hash>,
197 positions: Vec<bool>,
199 leaf_index: usize,
201}
202
203impl MerkleProof {
204 pub fn verify(&self, root: &Hash, leaf_data: &[u8], leaf_index: usize) -> bool {
214 if self.leaf_index != leaf_index {
215 return false;
216 }
217
218 let mut current_hash = hash(leaf_data);
219
220 for (sibling_hash, is_left) in self.hashes.iter().zip(&self.positions) {
221 let mut data = Vec::with_capacity(64);
222
223 if *is_left {
224 data.extend_from_slice(¤t_hash);
226 data.extend_from_slice(sibling_hash);
227 } else {
228 data.extend_from_slice(sibling_hash);
230 data.extend_from_slice(¤t_hash);
231 }
232
233 current_hash = hash(&data);
234 }
235
236 ¤t_hash == root
237 }
238
239 pub fn leaf_index(&self) -> usize {
241 self.leaf_index
242 }
243
244 pub fn depth(&self) -> usize {
246 self.hashes.len()
247 }
248
249 pub fn to_bytes(&self) -> Vec<u8> {
251 crate::codec::encode(self).expect("serialization should not fail")
252 }
253
254 pub fn from_bytes(bytes: &[u8]) -> MerkleResult<Self> {
256 crate::codec::decode(bytes).map_err(|_| MerkleError::InvalidProofLength)
257 }
258}
259
260#[derive(Debug, Clone)]
265pub struct MultiProof {
266 hashes: Vec<Hash>,
268 instructions: Vec<ProofInstruction>,
270}
271
/// One step of a `MultiProof` verification program. Every step pushes exactly one hash onto
/// the evaluation stack.
#[derive(Clone, Debug)]
// Only consumed by `MultiProof::verify`, and nothing in this file builds instructions yet.
#[allow(dead_code)]
enum ProofInstruction {
    /// Push `hashes[idx]` from the proof.
    UseProofHash(usize),
    /// Push the hash of the data in `leaves[idx]` passed to verification.
    UseLeafHash(usize),
    /// Push `hash(stack[left_idx] || stack[right_idx])`; both operands stay on the stack.
    Combine { left_idx: usize, right_idx: usize },
}
282
283impl MultiProof {
284 #[allow(dead_code)]
293 pub fn verify(&self, root: &Hash, leaves: &[(usize, &[u8])]) -> bool {
294 let mut stack = Vec::new();
295
296 for instruction in &self.instructions {
297 match instruction {
298 ProofInstruction::UseProofHash(idx) => {
299 stack.push(self.hashes[*idx]);
300 }
301 ProofInstruction::UseLeafHash(idx) => {
302 let leaf_hash = hash(leaves[*idx].1);
303 stack.push(leaf_hash);
304 }
305 ProofInstruction::Combine {
306 left_idx,
307 right_idx,
308 } => {
309 let left = stack[*left_idx];
310 let right = stack[*right_idx];
311
312 let mut data = Vec::with_capacity(64);
313 data.extend_from_slice(&left);
314 data.extend_from_slice(&right);
315
316 stack.push(hash(&data));
317 }
318 }
319 }
320
321 stack.last() == Some(root)
322 }
323}
324
325#[derive(Debug)]
330pub struct IncrementalMerkleBuilder {
331 leaf_hashes: Vec<Hash>,
333}
334
335impl Default for IncrementalMerkleBuilder {
336 fn default() -> Self {
337 Self::new()
338 }
339}
340
341impl IncrementalMerkleBuilder {
342 pub fn new() -> Self {
344 Self {
345 leaf_hashes: Vec::new(),
346 }
347 }
348
349 pub fn add_leaf(&mut self, data: &[u8]) {
351 self.leaf_hashes.push(hash(data));
352 }
353
354 pub fn add_leaf_hash(&mut self, leaf_hash: Hash) {
356 self.leaf_hashes.push(leaf_hash);
357 }
358
359 pub fn leaf_count(&self) -> usize {
361 self.leaf_hashes.len()
362 }
363
364 pub fn finalize(self) -> MerkleResult<MerkleTree> {
366 if self.leaf_hashes.is_empty() {
367 return Err(MerkleError::EmptyTree);
368 }
369
370 Ok(MerkleTree::from_leaf_hashes(&self.leaf_hashes))
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` distinct leaves named `chunk1`, `chunk2`, ... `chunk{count}`.
    fn make_chunks(count: usize) -> Vec<Vec<u8>> {
        (1..=count)
            .map(|i| format!("chunk{}", i).into_bytes())
            .collect()
    }

    #[test]
    fn test_merkle_tree_basic() {
        let chunks = make_chunks(4);

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 4);

        let root = tree.root();
        assert_ne!(root, &[0u8; 32]);
    }

    #[test]
    #[should_panic(expected = "Cannot create tree from empty leaves")]
    fn test_empty_leaves_panics() {
        let _ = MerkleTree::from_leaves(&[]);
    }

    #[test]
    fn test_merkle_proof_generation() {
        let chunks = make_chunks(3);

        let tree = MerkleTree::from_leaves(&chunks);

        for i in 0..chunks.len() {
            assert!(tree.generate_proof(i).is_ok());
        }

        // The first index past the end, and any larger one, must be rejected with the index.
        assert!(matches!(
            tree.generate_proof(3),
            Err(MerkleError::InvalidLeafIndex(3))
        ));
        assert!(matches!(
            tree.generate_proof(10),
            Err(MerkleError::InvalidLeafIndex(10))
        ));
    }

    #[test]
    fn test_merkle_proof_verification() {
        let chunks = make_chunks(4);

        let tree = MerkleTree::from_leaves(&chunks);
        let root = tree.root();

        for (i, chunk) in chunks.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();
            assert!(proof.verify(root, chunk, i));
        }
    }

    #[test]
    fn test_merkle_proof_invalid() {
        let chunks = make_chunks(3);

        let tree = MerkleTree::from_leaves(&chunks);
        let root = tree.root();

        let proof = tree.generate_proof(0).unwrap();

        // Wrong leaf data is rejected.
        assert!(!proof.verify(root, b"wrong", 0));

        // Correct data claimed at a different index is rejected.
        assert!(!proof.verify(root, &chunks[0], 1));

        // Correct data and index checked against a different tree's root are rejected.
        let other = MerkleTree::from_leaves(&make_chunks(4));
        assert!(!proof.verify(other.root(), &chunks[0], 0));
    }

    #[test]
    fn test_verify_leaf() {
        let chunks = make_chunks(3);

        let tree = MerkleTree::from_leaves(&chunks);

        assert!(tree.verify_leaf(b"chunk1", 0));
        assert!(tree.verify_leaf(b"chunk2", 1));
        assert!(tree.verify_leaf(b"chunk3", 2));

        assert!(!tree.verify_leaf(b"chunk1", 1));
        assert!(!tree.verify_leaf(b"wrong", 0));
        // Out-of-range indices report false instead of panicking.
        assert!(!tree.verify_leaf(b"chunk1", 3));
    }

    #[test]
    fn test_incremental_builder() {
        let chunks = make_chunks(3);

        let mut builder = IncrementalMerkleBuilder::new();
        for chunk in &chunks {
            builder.add_leaf(chunk);
        }
        assert_eq!(builder.leaf_count(), 3);

        let tree = builder.finalize().unwrap();
        assert_eq!(tree.leaf_count(), 3);

        let expected_tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.root(), expected_tree.root());
    }

    #[test]
    fn test_incremental_builder_empty() {
        let builder = IncrementalMerkleBuilder::default();
        assert_eq!(builder.leaf_count(), 0);
        assert!(matches!(builder.finalize(), Err(MerkleError::EmptyTree)));
    }

    #[test]
    fn test_single_leaf() {
        let chunks = vec![b"single".to_vec()];
        let tree = MerkleTree::from_leaves(&chunks);

        assert_eq!(tree.leaf_count(), 1);
        // With one leaf, the root is the leaf hash itself.
        assert_eq!(tree.root(), &hash(b"single"));

        let proof = tree.generate_proof(0).unwrap();
        assert_eq!(proof.depth(), 0);
        assert!(proof.verify(tree.root(), b"single", 0));
    }

    #[test]
    fn test_proof_serialization() {
        let chunks = make_chunks(3);

        let tree = MerkleTree::from_leaves(&chunks);
        let proof = tree.generate_proof(1).unwrap();

        let bytes = proof.to_bytes();
        let deserialized = MerkleProof::from_bytes(&bytes).unwrap();

        assert_eq!(proof.leaf_index(), deserialized.leaf_index());
        assert_eq!(proof.depth(), deserialized.depth());

        let root = tree.root();
        assert!(deserialized.verify(root, &chunks[1], 1));
    }

    #[test]
    fn test_large_tree() {
        let chunks: Vec<Vec<u8>> = (0..1000)
            .map(|i| format!("chunk{}", i).into_bytes())
            .collect();

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 1000);

        for i in [0, 100, 500, 999] {
            let proof = tree.generate_proof(i).unwrap();
            // ceil(log2(1000)) = 10 levels sit below the root.
            assert_eq!(proof.depth(), 10);
            assert!(proof.verify(tree.root(), &chunks[i], i));
        }
    }

    #[test]
    fn test_odd_number_of_leaves() {
        let chunks = make_chunks(5);

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 5);

        for (i, chunk) in chunks.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();
            // Level sizes are 5 -> 3 -> 2 -> 1, so every proof has three siblings.
            assert_eq!(proof.depth(), 3);
            assert!(proof.verify(tree.root(), chunk, i));
        }
    }

    #[test]
    fn test_two_leaves() {
        let chunks = make_chunks(2);

        let tree = MerkleTree::from_leaves(&chunks);
        assert_eq!(tree.leaf_count(), 2);

        for (i, chunk) in chunks.iter().enumerate() {
            let proof = tree.generate_proof(i).unwrap();
            assert_eq!(proof.depth(), 1);
            assert!(proof.verify(tree.root(), chunk, i));
        }
    }
}