// brine_tree/tree.rs

1use bytemuck::{Pod, Zeroable};
2use super::{hash::Hash, hash::Leaf, hashv};
3use super::utils::check_condition;
4use super::error::{BrineTreeError, ProgramResult};
5
/// Fixed-depth Merkle tree of depth `N` that appends leaves incrementally (see
/// `try_add_leaf`) and supports proof-based replacement and removal of leaves.
///
/// The type is `#[repr(C, align(8))]` and implements `Pod`/`Zeroable`, so its fields form its
/// byte representation: changing field order or types changes that layout.
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    root: Hash,
    /// `filled_subtrees[i]` is the level-`i` node hash most recently recorded by
    /// `try_add_leaf` when that node was a left child; it is reused as the left sibling when
    /// the matching right child is inserted. Initialised to `zero_values`.
    filled_subtrees: [Hash; N],
    /// `zero_values[i]` is the hash of an empty subtree at level `i`: level 0 is
    /// `hashv(seeds)` and every higher level hashes the previous one with itself.
    zero_values: [Hash; N],
    /// Slot index the next `try_add_leaf` will fill, i.e. the number of leaves appended so far.
    next_index: u64,
}
14
// SAFETY: `MerkleTree` is `#[repr(C, align(8))]` and consists of `2 * N + 1` `Hash` values
// followed by a `u64`. These impls are sound only if `Hash` is itself plain bytes that are
// valid when all-zero and contain no padding (presumably a `[u8; 32]`-like newtype — confirm
// in the `hash` module), and if `size_of::<Hash>() * (2 * N + 1)` is a multiple of 8, so that
// no padding is inserted before `next_index` or at the end of the struct.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19    pub const fn get_depth(&self) -> u8 {
20        N as u8
21    }
22
23    pub const fn get_size() -> usize {
24        core::mem::size_of::<Self>()
25    }
26
27    pub fn get_root(&self) -> Hash {
28        self.root
29    }
30
31    pub fn get_empty_leaf(&self) -> Hash {
32        self.zero_values[0]
33    }
34
35    pub fn new(seeds: &[&[u8]]) -> Self {
36        let zeros = Self::calc_zeros(seeds);
37        Self {
38            next_index: 0,
39            root: zeros[N - 1],
40            filled_subtrees: zeros,
41            zero_values: zeros,
42        }
43    }
44
45    pub fn init(&mut self, seeds: &[&[u8]]) {
46        let zeros = Self::calc_zeros(seeds);
47        self.next_index = 0;
48        self.root = zeros[N - 1];
49        self.filled_subtrees = zeros;
50        self.zero_values = zeros;
51    }
52
53    fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
54        let mut zeros: [Hash; N] = [Hash::default(); N];
55        let mut current = hashv(seeds);
56
57        for i in 0..N {
58            zeros[i] = current;
59            current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
60        }
61
62        zeros
63    }
64
65    pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
66        let leaf = Leaf::new(data);
67        self.try_add_leaf(leaf)
68    }
69
70    pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
71        check_condition(
72            self.next_index < (1u64 << N),
73            BrineTreeError::TreeFull,
74        )?;
75
76        let mut current_index = self.next_index;
77        let mut current_hash = Hash::from(leaf);
78        let mut left;
79        let mut right;
80
81        for i in 0..N {
82            if current_index % 2 == 0 {
83                left = current_hash;
84                right = self.zero_values[i];
85                self.filled_subtrees[i] = current_hash;
86            } else {
87                left = self.filled_subtrees[i];
88                right = current_hash;
89            }
90
91            current_hash = Self::hash_left_right(left, right);
92            current_index /= 2;
93        }
94
95        self.root = current_hash;
96        self.next_index += 1;
97
98        Ok(())
99    }
100
101    pub fn try_remove(&mut self, proof: &[Hash], data: &[&[u8]]) -> ProgramResult {
102        let leaf = Leaf::new(data);
103        self.try_remove_leaf(proof, leaf)
104    }
105
106    pub fn try_remove_leaf(&mut self, proof: &[Hash], leaf: Leaf) -> ProgramResult {
107        self.check_length(proof)?;
108        self.try_replace_leaf(proof, Hash::from(leaf), self.get_empty_leaf())
109    }
110
111    pub fn try_replace(&mut self, proof: &[Hash], original_data: &[&[u8]], new_data: &[&[u8]]) -> ProgramResult {
112        let original_leaf = Leaf::new(original_data);
113        let new_leaf = Leaf::new(new_data);
114        self.try_replace_leaf(proof, Hash::from(original_leaf), Hash::from(new_leaf))
115    }
116
117    pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Hash, new_leaf: Hash) -> ProgramResult {
118        self.check_length(proof)?;
119
120        let original_path = MerkleTree::<N>::compute_path(proof, original_leaf);
121        let new_path = MerkleTree::<N>::compute_path(proof, new_leaf);
122
123        check_condition(
124            MerkleTree::<N>::is_valid_path(&original_path, self.root),
125            BrineTreeError::InvalidProof,
126        )?;
127
128        for i in 0..N {
129            if original_path[i] == self.filled_subtrees[i] {
130                self.filled_subtrees[i] = new_path[i];
131            }
132        }
133
134        self.root = *new_path.last().unwrap();
135
136        Ok(())
137    }
138
139    pub fn contains(&self, proof: &[Hash], data: &[&[u8]]) -> bool {
140        let leaf = Leaf::new(data);
141        self.contains_leaf(proof, leaf)
142    }
143
144    pub fn contains_leaf(&self, proof: &[Hash], leaf: Leaf) -> bool {
145        if let Err(_) = self.check_length(proof) {
146            return false;
147        }
148
149        let root = self.get_root();
150        Self::is_valid_leaf(proof, root, Hash::from(leaf))
151    }
152
153    pub fn hash_left_right(left: Hash, right: Hash) -> Hash {
154        let combined;
155        if left.to_bytes() <= right.to_bytes() {
156            combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
157        } else {
158            combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
159        }
160
161        hashv(&combined)
162    }
163
164    pub fn compute_path(proof: &[Hash], leaf: Hash) -> Vec<Hash> {
165        let mut computed_path = Vec::with_capacity(proof.len() + 1);
166        let mut computed_hash = leaf;
167
168        computed_path.push(computed_hash);
169
170        for proof_element in proof.iter() {
171            computed_hash = Self::hash_left_right(computed_hash, *proof_element);
172            computed_path.push(computed_hash);
173        }
174
175        computed_path
176    }
177
178    pub fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Hash) -> bool {
179        let computed_path = Self::compute_path(proof, leaf);
180        Self::is_valid_path(&computed_path, root)
181    }
182
183    pub fn is_valid_path(path: &[Hash], root: Hash) -> bool {
184        if path.is_empty() {
185            return false;
186        }
187
188        *path.last().unwrap() == root
189    }
190
191    #[cfg(not(feature = "solana"))]
192    fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
193        let mut res = Vec::with_capacity(pairs.len() / 2);
194
195        for i in (0..pairs.len()).step_by(2) {
196            let left = pairs[i];
197            let right = pairs[i + 1];
198
199            let hashed = Self::hash_left_right(left, right);
200            res.push(hashed);
201        }
202
203        res
204    }
205
206    #[cfg(not(feature = "solana"))]
207    pub fn get_merkle_proof(&self, values: &[Leaf], index: usize) -> Vec<Hash> {
208        let mut layers = Vec::with_capacity(N);
209        let mut current_layer: Vec<Hash> = values.iter().map(|leaf| Hash::from(*leaf)).collect();
210
211        for i in 0..N {
212            if current_layer.len() % 2 != 0 {
213                current_layer.push(self.zero_values[i]);
214            }
215
216            layers.push(current_layer.clone());
217            current_layer = Self::hash_pairs(current_layer);
218        }
219
220        let mut proof = Vec::with_capacity(N);
221        let mut current_index = index;
222        let mut layer_index = 0;
223        let mut sibling;
224
225        for _ in 0..N {
226            if current_index % 2 == 0 {
227                sibling = layers[layer_index][current_index + 1];
228            } else {
229                sibling = layers[layer_index][current_index - 1];
230            }
231
232            proof.push(sibling);
233
234            current_index /= 2;
235            layer_index += 1;
236        }
237
238        proof
239    }
240
241    fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
242        check_condition(proof.len() == N, BrineTreeError::ProofLength)
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth-3 tree (8 leaf slots) shared by every test below.
    type TestTree = MerkleTree<3>;

    /// Asserts that `zeros` follows the empty-subtree chain derived from `seeds`:
    /// `hashv(seeds)` first, then each entry is `hashv(["NODE", prev, prev])`.
    fn assert_zero_chain(zeros: &[Hash], seeds: &[&[u8]]) {
        let mut expected = hashv(seeds);
        for zero in zeros.iter() {
            assert_eq!(*zero, expected);
            expected = hashv(&[b"NODE".as_ref(), expected.as_ref(), expected.as_ref()]);
        }
    }

    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        // A fresh tree's root is the top-most zero value.
        assert_eq!(tree.get_root(), tree.zero_values[2]);
    }

    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Tree structure:
        //
        //              root
        //            /     \
        //         m           n
        //       /   \       /   \
        //      i     j     k     l
        //     / \   / \   / \   / \
        //    a  b  c  d  e  f  g  h

        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        // Slots d..h are still empty.
        let (d, e, f, g, h) = (empty, empty, empty, empty, empty);

        let i = TestTree::hash_left_right(a, b);
        let j = TestTree::hash_left_right(c, d);
        let k = TestTree::hash_left_right(e, f);
        let l = TestTree::hash_left_right(g, h);
        let m = TestTree::hash_left_right(i, j);
        let n = TestTree::hash_left_right(k, l);
        let root = TestTree::hash_left_right(m, n);

        // filled_subtrees[0] only changes when a leaf lands on an even (left) slot.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        // Check filled leaves
        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Check empty leaves
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Remove val2 from the tree
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        // Update the expected tree structure
        let i = TestTree::hash_left_right(a, empty);
        let m = TestTree::hash_left_right(i, j);
        let root = TestTree::hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        // Check that val2 is no longer in the tree
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Insert val4 into slot 3 (a right child), so filled_subtrees[0] stays `c`.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        // Update the expected tree structure
        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = TestTree::hash_left_right(c, d);
        let m = TestTree::hash_left_right(i, j);
        let root = TestTree::hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let val2_proof = tree.get_merkle_proof(&leaves, 1);
        let val3_proof = tree.get_merkle_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Proofs whose length differs from the depth are rejected.
        let invalid_proof_short = &val1_proof[..2];
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat();

        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        // Empty proof
        let empty_proof: [Hash; 0] = [];
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Snapshot the freshly created state.
        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        // Modify the tree, then reinitialize it.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        tree.init(seeds);

        // Reinitialization restores the fresh state exactly.
        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Fill all 2^3 = 8 slots.
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        // A ninth leaf is rejected.
        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Replace val_1 with new_val; the sibling path is unchanged, so the proof still applies.
        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        // val_2 is still present under a proof built from the updated leaves.
        let proof_val2 = tree.get_merkle_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_invalid_replace() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"])];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Replacing a leaf that is not in the tree fails.
        let result = tree.try_replace(&proof, &[b"wrong_val"], &[b"new_val"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::InvalidProof);

        // The original leaf is untouched.
        assert!(tree.contains(&proof, &[b"val_1"]));
    }

    #[test]
    fn test_zero_values_calculation() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_zero_chain(&tree.zero_values, seeds);
    }

    #[test]
    fn test_path_computation() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        let leaf = Hash::from(Leaf::new(&[b"val_1"]));
        let proof = [
            Hash::from(Leaf::new(&[b"val_2"])),
            tree.zero_values[1],
            tree.zero_values[2],
        ];

        let path = TestTree::compute_path(&proof, leaf);

        // The path holds the leaf plus one hash per proof element.
        assert_eq!(path.len(), 4);
        assert_eq!(path[0], leaf);
        assert_eq!(path[1], TestTree::hash_left_right(leaf, proof[0]));
        assert_eq!(path[2], TestTree::hash_left_right(path[1], proof[1]));
        assert_eq!(path[3], TestTree::hash_left_right(path[2], proof[2]));
    }

    #[test]
    fn test_invalid_path() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        // An empty path never matches.
        let empty_path: [Hash; 0] = [];
        assert!(!TestTree::is_valid_path(&empty_path, tree.get_root()));

        // A path ending anywhere but the given root does not match.
        let leaf = Hash::from(Leaf::new(&[b"val_1"]));
        let proof = [tree.zero_values[0], tree.zero_values[1], tree.zero_values[2]];
        let path = TestTree::compute_path(&proof, leaf);
        let wrong_root = Hash::default();
        assert!(!TestTree::is_valid_path(&path, wrong_root));
    }

    #[test]
    fn test_hash_left_right_ordering() {
        let left = Hash::from(Leaf::new(&[b"val_1"]));
        let right = Hash::from(Leaf::new(&[b"val_2"]));

        // Node hashing is symmetric because children are sorted first.
        let hash1 = TestTree::hash_left_right(left, right);
        let hash2 = TestTree::hash_left_right(right, left);
        assert_eq!(hash1, hash2);

        // The smaller hash (by bytes) goes first.
        let direct = hashv(&[
            b"NODE".as_ref(),
            left.to_bytes().min(right.to_bytes()).as_ref(),
            left.to_bytes().max(right.to_bytes()).as_ref(),
        ]);

        assert_eq!(hash1, direct);
    }

    #[test]
    fn test_partial_proof() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"])];
        let valid_proof = tree.get_merkle_proof(&leaves, 0);
        assert_eq!(valid_proof.len(), 3); // Matches the tree depth

        // A proof shorter than the depth is rejected everywhere.
        let partial_proof = &valid_proof[..2];

        assert!(!tree.contains(partial_proof, &[b"val_1"]));

        let remove_result = tree.try_remove(partial_proof, &[b"val_1"]);
        assert!(remove_result.is_err());
        assert_eq!(remove_result.unwrap_err(), BrineTreeError::ProofLength);

        let replace_result = tree.try_replace(partial_proof, &[b"val_1"], &[b"new_val"]);
        assert!(replace_result.is_err());
        assert_eq!(replace_result.unwrap_err(), BrineTreeError::ProofLength);

        // The original leaf is untouched.
        assert!(tree.contains(&valid_proof, &[b"val_1"]));
    }

    #[test]
    fn test_leaf_vs_node_attack() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        let val_1: &[&[u8]] = &[b"val_1"];
        let val_2: &[&[u8]] = &[b"val_2"];
        assert!(tree.try_add(val_1).is_ok());
        assert!(tree.try_add(val_2).is_ok());

        let leaves = [Leaf::new(val_1), Leaf::new(val_2)];
        let valid_proof = tree.get_merkle_proof(&leaves, 0);
        assert_eq!(valid_proof.len(), 3); // Matches the tree depth

        // Substitute a leaf hash where an internal node hash is expected.
        let malicious_hash = Hash::from(Leaf::new(&[b"malicious"]));
        let mut malicious_proof = valid_proof.clone();
        malicious_proof[1] = malicious_hash;

        assert!(!tree.contains(&malicious_proof, val_1));

        let replace_result = tree.try_replace(&malicious_proof, val_1, &[b"new_val"]);
        assert!(replace_result.is_err());
        assert_eq!(replace_result.unwrap_err(), BrineTreeError::InvalidProof);

        let remove_result = tree.try_remove(&malicious_proof, val_1);
        assert!(remove_result.is_err());
        assert_eq!(remove_result.unwrap_err(), BrineTreeError::InvalidProof);

        // The tree state is unchanged.
        assert!(tree.contains(&valid_proof, val_1));
        assert!(tree.contains(&tree.get_merkle_proof(&leaves, 1), val_2));
    }

    #[test]
    fn test_proof_with_duplicate_hashes() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"])];
        let valid_proof = tree.get_merkle_proof(&leaves, 0);
        assert_eq!(valid_proof.len(), 3);

        // A proof that repeats the first sibling at every level.
        let duplicate_proof = [valid_proof[0]; 3];

        assert!(!tree.contains(&duplicate_proof, &[b"val_1"]));

        let remove_result = tree.try_remove(&duplicate_proof, &[b"val_1"]);
        assert!(remove_result.is_err());
        assert_eq!(remove_result.unwrap_err(), BrineTreeError::InvalidProof);

        let replace_result = tree.try_replace(&duplicate_proof, &[b"val_1"], &[b"new_val"]);
        assert!(replace_result.is_err());
        assert_eq!(replace_result.unwrap_err(), BrineTreeError::InvalidProof);

        // The original leaf is untouched.
        assert!(tree.contains(&valid_proof, &[b"val_1"]));
    }

    #[test]
    fn test_proof_with_zero_hashes() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"])];
        let valid_proof = tree.get_merkle_proof(&leaves, 0);
        assert_eq!(valid_proof.len(), 3);

        // A proof made only of `Hash::default()` values.
        let zero_proof = [Hash::default(); 3];

        assert!(!tree.contains(&zero_proof, &[b"val_1"]));

        let remove_result = tree.try_remove(&zero_proof, &[b"val_1"]);
        assert!(remove_result.is_err());
        assert_eq!(remove_result.unwrap_err(), BrineTreeError::InvalidProof);

        let replace_result = tree.try_replace(&zero_proof, &[b"val_1"], &[b"new_val"]);
        assert!(replace_result.is_err());
        assert_eq!(replace_result.unwrap_err(), BrineTreeError::InvalidProof);

        // The original leaf is untouched.
        assert!(tree.contains(&valid_proof, &[b"val_1"]));
    }

    #[test]
    fn test_proof_exploit_hash_ordering() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let valid_proof = tree.get_merkle_proof(&leaves, 0);
        assert_eq!(valid_proof.len(), 3);

        // Sorted-pair hashing must not let proof elements be reordered across levels.
        let mut malicious_proof = valid_proof.clone();
        malicious_proof.swap(0, 1);

        assert!(!tree.contains(&malicious_proof, &[b"val_1"]));

        let remove_result = tree.try_remove(&malicious_proof, &[b"val_1"]);
        assert!(remove_result.is_err());
        assert_eq!(remove_result.unwrap_err(), BrineTreeError::InvalidProof);

        let replace_result = tree.try_replace(&malicious_proof, &[b"val_1"], &[b"new_val"]);
        assert!(replace_result.is_err());
        assert_eq!(replace_result.unwrap_err(), BrineTreeError::InvalidProof);

        // The original leaf is untouched.
        assert!(tree.contains(&valid_proof, &[b"val_1"]));
    }

    #[test]
    fn test_empty_seeds() {
        let seeds: &[&[u8]] = &[];
        let mut tree = TestTree::new(seeds);

        // An empty seed list still yields a well-formed zero chain.
        let zeros = tree.zero_values;
        assert_zero_chain(&zeros, seeds);

        assert_eq!(tree.get_root(), zeros[2]);
        assert_eq!(tree.filled_subtrees, zeros);
        assert_eq!(tree.next_index, 0);

        // The tree is usable afterwards.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        let proof = tree.get_merkle_proof(&[Leaf::new(&[b"val_1"])], 0);
        assert!(tree.contains(&proof, &[b"val_1"]));
    }

    #[test]
    fn test_malformed_seeds_empty_bytes() {
        let seeds: &[&[u8]] = &[b"", b""];
        let mut tree = TestTree::new(seeds);

        // Empty byte-string seeds still yield a well-formed zero chain.
        let zeros = tree.zero_values;
        assert_zero_chain(&zeros, seeds);

        assert_eq!(tree.get_root(), zeros[2]);
        assert_eq!(tree.filled_subtrees, zeros);
        assert_eq!(tree.next_index, 0);

        // The tree is usable afterwards.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        let proof = tree.get_merkle_proof(&[Leaf::new(&[b"val_1"])], 0);
        assert!(tree.contains(&proof, &[b"val_1"]));
    }

    #[test]
    fn test_reinit_with_empty_seeds() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Reinitialize with an empty seed list.
        let empty_seeds: &[&[u8]] = &[];
        tree.init(empty_seeds);

        let zeros = tree.zero_values;
        assert_zero_chain(&zeros, empty_seeds);

        assert_eq!(tree.get_root(), zeros[2]);
        assert_eq!(tree.filled_subtrees, zeros);
        assert_eq!(tree.next_index, 0);

        // The reinitialized tree is usable.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        let proof = tree.get_merkle_proof(&[Leaf::new(&[b"val_1"])], 0);
        assert!(tree.contains(&proof, &[b"val_1"]));
    }
}