brine_tree/
tree.rs

1use bytemuck::{ Pod, Zeroable };
2use super::hash::{ Hash, Leaf, hashv };
3use super::error::{ BrineTreeError, ProgramResult };
4use super::utils::check_condition;
5
/// Fixed-depth, append-only Merkle tree with `N` levels above the leaves.
///
/// Only the current root, one cached node per level and the per-level
/// "empty subtree" hashes are stored; the leaves themselves are not kept.
/// The layout is `repr(C, align(8))` so the value can be viewed as raw bytes
/// (see the `Pod`/`Zeroable` impls that follow the struct).
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash, as returned by `get_root`.
    root: Hash,
    /// Per-level cache: slot `i` holds the node most recently written as a
    /// left child at level `i` by `try_add_leaf`. Starts equal to `zero_values`.
    filled_subtrees: [Hash; N],
    /// Per-level hash of an all-empty subtree: slot 0 is `hashv(seeds)`, and
    /// each following slot is the `NODE` hash of the previous slot with itself.
    zero_values: [Hash; N],
    /// Number of leaves appended so far, i.e. the index the next leaf will get.
    next_index: u64,
}
14
// SAFETY: NOTE(review) — every field is either `Hash`, an array of `Hash`, or
// `u64`, so the all-zero bit pattern is valid provided `Hash` is itself
// `Zeroable`. `Hash` is defined outside this file — TODO confirm.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
// SAFETY: NOTE(review) — the struct is `repr(C, align(8))` and `Copy`. `Pod`
// additionally requires that there are no padding bytes and that every bit
// pattern is valid; this presumably holds because `Hash` is a plain byte array
// whose size is a multiple of 8 — verify against the definition of `Hash`.
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19    pub const fn get_depth(&self) -> u8 {
20        N as u8
21    }
22
23    pub const fn get_size() -> usize {
24        core::mem::size_of::<Self>()
25    }
26
27    pub fn get_root(&self) -> Hash {
28        self.root
29    }
30
31    pub fn get_empty_leaf(&self) -> Leaf {
32        self.zero_values[0].as_leaf()
33    }
34
35    pub fn new(seeds: &[&[u8]]) -> Self {
36        let zeros = Self::calc_zeros(seeds);
37        Self {
38            next_index: 0,
39            root: zeros[N - 1],
40            filled_subtrees: zeros,
41            zero_values: zeros,
42        }
43    }
44
45    pub fn init(&mut self, seeds: &[&[u8]]) {
46        let zeros = Self::calc_zeros(seeds);
47        self.next_index = 0;
48        self.root = zeros[N - 1];
49        self.filled_subtrees = zeros;
50        self.zero_values = zeros;
51    }
52
53    pub fn get_leaf_count(&self) -> u64 {
54        self.next_index
55    }
56
57    pub fn get_capacity(&self) -> u64 {
58        1u64 << N
59    }
60
61    fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
62        let mut zeros: [Hash; N] = [Hash::default(); N];
63        let mut current = hashv(seeds);
64
65        for i in 0..N {
66            zeros[i] = current;
67            current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
68        }
69
70        zeros
71    }
72
73    pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
74        let leaf = Leaf::new(data);
75        self.try_add_leaf(leaf)
76    }
77
78    pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
79        check_condition(
80            self.next_index < (1u64 << N),
81            BrineTreeError::TreeFull,
82        )?;
83
84        let mut current_index = self.next_index;
85        let mut current_hash = Hash::from(leaf);
86        let mut left;
87        let mut right;
88
89        for i in 0..N {
90            if current_index % 2 == 0 {
91                left = current_hash;
92                right = self.zero_values[i];
93                self.filled_subtrees[i] = current_hash;
94            } else {
95                left = self.filled_subtrees[i];
96                right = current_hash;
97            }
98
99            current_hash = hash_left_right(left, right);
100            current_index /= 2;
101        }
102
103        self.root = current_hash;
104        self.next_index += 1;
105
106        Ok(())
107    }
108
109    pub fn try_remove<P>(
110        &mut self,
111        proof: &[P],
112        data: &[&[u8]],
113    ) -> ProgramResult
114    where
115        P: Into<Hash> + Copy,
116    {
117        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
118        let original_leaf = Leaf::new(data);
119        self.try_remove_leaf(&proof_hashes, original_leaf)
120    }
121
122    pub fn try_remove_leaf<P>(
123        &mut self,
124        proof: &[P],
125        leaf: Leaf,
126    ) -> ProgramResult
127    where
128        P: Into<Hash> + Copy,
129    {
130        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
131        self.check_length(&proof_hashes)?;
132        self.try_replace_leaf(&proof_hashes, leaf, self.get_empty_leaf())
133    }
134
135    pub fn try_replace<P>(
136        &mut self,
137        proof: &[P],
138        original_data: &[&[u8]],
139        new_data: &[&[u8]],
140    ) -> ProgramResult
141    where
142        P: Into<Hash> + Copy,
143    {
144        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
145        let original_leaf = Leaf::new(original_data);
146        let new_leaf = Leaf::new(new_data);
147        self.try_replace_leaf(&proof_hashes, original_leaf, new_leaf)
148    }
149
150    pub fn try_replace_leaf<P>(
151        &mut self,
152        proof: &[P],
153        original_leaf: Leaf,
154        new_leaf: Leaf,
155    ) -> ProgramResult
156    where
157        P: Into<Hash> + Copy,
158    {
159        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
160        self.check_length(&proof_hashes)?;
161        let original_path = compute_path(&proof_hashes, original_leaf);
162        let new_path = compute_path(&proof_hashes, new_leaf);
163        check_condition(
164            is_valid_path(&original_path, self.root),
165            BrineTreeError::InvalidProof,
166        )?;
167        for i in 0..N {
168            if original_path[i] == self.filled_subtrees[i] {
169                self.filled_subtrees[i] = new_path[i];
170            }
171        }
172        self.root = *new_path.last().unwrap();
173        Ok(())
174    }
175
176    pub fn contains<P>(
177        &self,
178        proof: &[P],
179        data: &[&[u8]],
180    ) -> bool
181    where
182        P: Into<Hash> + Copy,
183    {
184        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
185        let leaf = Leaf::new(data);
186        self.contains_leaf(&proof_hashes, leaf)
187    }
188
189    pub fn contains_leaf<P>(
190        &self,
191        proof: &[P],
192        leaf: Leaf,
193    ) -> bool
194    where
195        P: Into<Hash> + Copy,
196    {
197        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
198        if self.check_length(&proof_hashes).is_err() {
199            return false;
200        }
201        is_valid_leaf(&proof_hashes, self.root, leaf)
202    }
203
204    fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
205        check_condition(proof.len() == N, BrineTreeError::ProofLength)
206    }
207
208    fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
209        let mut res = Vec::with_capacity(pairs.len() / 2);
210
211        for i in (0..pairs.len()).step_by(2) {
212            let left = pairs[i];
213            let right = pairs[i + 1];
214
215            let hashed = hash_left_right(left, right);
216            res.push(hashed);
217        }
218
219        res
220    }
221
222    pub fn get_merkle_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
223        let mut layers = Vec::with_capacity(N);
224        let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
225
226        for i in 0..N {
227            if current_layer.len() % 2 != 0 {
228                current_layer.push(self.zero_values[i]);
229            }
230
231            layers.push(current_layer.clone());
232            current_layer = Self::hash_pairs(current_layer);
233        }
234
235        let mut proof = Vec::with_capacity(N);
236        let mut current_index = leaf_index;
237        let mut layer_index = 0;
238        let mut sibling;
239
240        for _ in 0..N {
241            if current_index % 2 == 0 {
242                sibling = layers[layer_index][current_index + 1];
243            } else {
244                sibling = layers[layer_index][current_index - 1];
245            }
246
247            proof.push(sibling);
248
249            current_index /= 2;
250            layer_index += 1;
251        }
252
253        proof
254    }
255}
256
257fn hash_left_right(left: Hash, right: Hash) -> Hash {
258    let combined;
259    if left.to_bytes() <= right.to_bytes() {
260        combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
261    } else {
262        combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
263    }
264
265    hashv(&combined)
266}
267
268fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
269    let mut computed_path = Vec::with_capacity(proof.len() + 1);
270    let mut computed_hash = Hash::from(leaf);
271
272    computed_path.push(computed_hash);
273
274    for proof_element in proof.iter() {
275        computed_hash = hash_left_right(computed_hash, *proof_element);
276        computed_path.push(computed_hash);
277    }
278
279    computed_path
280}
281
282fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
283    let computed_path = compute_path(proof, leaf);
284    is_valid_path(&computed_path, root)
285}
286
287fn is_valid_path(path: &[Hash], root: Hash) -> bool {
288    if path.is_empty() {
289        return false;
290    }
291
292    *path.last().unwrap() == root
293}
294
295/// Verifies that a given merkle root contains the leaf using the provided proof.
296pub fn verify<Root, Item, L>(
297    root: Root,
298    proof: &[Item],
299    leaf: L,
300) -> bool
301where
302    Root: Into<Hash>,
303    Item: Into<Hash> + Copy,
304    L: Into<Leaf>,
305{
306    let root_h: Hash = root.into();
307    let proof_hashes: Vec<Hash> = 
308        proof.iter()
309          .map(|&x| x.into())
310          .collect();
311
312    let leaf_h: Leaf = leaf.into();
313    let path = compute_path(&proof_hashes, leaf_h);
314    is_valid_path(&path, root_h)
315}
316
317
#[cfg(test)]
mod tests {
    use super::*;

    type TestTree = MerkleTree<3>;

    /// A fresh tree reports its depth and starts with the last zero value as root.
    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    /// Walks through add / remove / add and checks the root and cached nodes
    /// against a hand-built expected tree.
    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Tree structure:
        //
        //              root
        //            /     \
        //         m           n
        //       /   \       /   \
        //      i     j     k     l
        //     / \   / \   / \   / \
        //    a  b  c  d  e  f  g  h

        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        let d = empty;
        let e = empty;
        let f = empty;
        let g = empty;
        let h = empty;

        let i = hash_left_right(a, b);
        let j: Hash = hash_left_right(c, d);
        let k: Hash = hash_left_right(e, f);
        let l: Hash = hash_left_right(g, h);
        let m: Hash = hash_left_right(i, j);
        let n: Hash = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&a));

        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&a)); // Not a typo

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&c)); // Not a typo

        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        // Check filled leaves
        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Check empty leaves
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Remove val2 from the tree
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        // Update the expected tree structure
        let i = hash_left_right(a, empty);
        let m: Hash = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        // Check that val2 is no longer in the tree
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Insert val4 into the tree
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&c)); // Not a typo

        // Update the expected tree structure
        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }

    /// Proofs generated from the leaf list validate; wrong-length proofs do not.
    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let val2_proof = tree.get_merkle_proof(&leaves, 1);
        let val3_proof = tree.get_merkle_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Invalid Proof Length
        let invalid_proof_short = &val1_proof[..2]; // Shorter than depth
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat(); // Longer than depth

        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        // Empty Proof
        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    /// `init` restores a modified tree to the state `new` produced.
    #[test]
    fn test_init_and_reinit() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Store initial state
        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        // Add a leaf to modify the tree
        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Reinitialize
        tree.init(seeds);

        // Verify tree is reset to initial state
        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    /// Adding past capacity fails with `TreeFull`.
    #[test]
    fn test_tree_full() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Fill the tree (2^3 = 8 leaves)
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        // Try to add one more leaf
        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    /// Replacing a leaf swaps membership while leaving its sibling provable.
    #[test]
    fn test_replace_leaf() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Add initial leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        // Get proof for val_1
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Replace val_1 with new_val
        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // Verify new leaf is present
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        // Verify val_2 is still present
        let proof_val2 = tree.get_merkle_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    /// The free `verify` function accepts both crate types and raw byte arrays.
    #[test]
    fn test_verify() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Add initial leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        // Get proof for val_1
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Verify proof (typed)
        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));

        let a: [u8; 32] = tree.get_root().to_bytes();
        let b: [[u8; 32]; 3] = [
            proof[0].to_bytes(),
            proof[1].to_bytes(),
            proof[2].to_bytes(),
        ];
        let c: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();

        // Verify proof (generic)
        assert!(verify(a, &b, c));
    }
}