brine_tree/
tree.rs

1use super::{
2    error::{BrineTreeError, ProgramResult},
3    hash::{hashv, Hash, Leaf},
4    utils::check_condition,
5};
6use bytemuck::{Pod, Zeroable};
7
/// Incremental Merkle tree stored as a flat `#[repr(C)]` record of hashes plus
/// an insertion counter.
///
/// `N` is the number of hashing rounds between a leaf and the root.
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Root hash as of the last successful insertion or replacement.
    pub root: Hash,
    /// Per level, the node most recently recorded as a left child by
    /// `try_add_leaf`; it is read back as the left sibling when an insertion
    /// lands on an odd index at that level.
    pub filled_subtrees: [Hash; N],
    /// Per level, the hash of an all-empty subtree; index 0 is the empty leaf.
    pub zero_values: [Hash; N],
    /// Index the next leaf will be inserted at; incremented on every
    /// successful `try_add_leaf`.
    pub next_index: u64,
}
16
// SAFETY (assumed — TODO confirm): requires the all-zero bit pattern to be a
// valid `Hash`; the only other field is a plain `u64`.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
// SAFETY (assumed — TODO confirm): relies on `Hash` itself being `Pod` and on
// the `#[repr(C)]` layout containing no padding bytes for any `N`.
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
19
20impl<const N: usize> MerkleTree<N> {
21
22    pub fn new(seeds: &[&[u8]]) -> Self {
23        let zeros = Self::calc_zeros(seeds);
24        Self {
25            next_index: 0,
26            root: zeros[N - 1],
27            filled_subtrees: zeros,
28            zero_values: zeros,
29        }
30    }
31
    /// Returns the tree depth `N` narrowed to `u8`.
    ///
    /// The `as` cast truncates, so a depth above 255 would wrap.
    pub const fn get_depth(&self) -> u8 {
        N as u8
    }
35
    /// Returns `size_of::<Self>()`, the in-memory size of this tree type in bytes.
    pub const fn get_size() -> usize {
        core::mem::size_of::<Self>()
    }
39
    /// Returns a copy of the currently stored root hash.
    pub fn get_root(&self) -> Hash {
        self.root
    }
43
    /// Returns the level-0 zero value converted to a `Leaf` with `as_leaf`.
    ///
    /// This is the leaf that `try_remove_leaf` writes in place of a removed one.
    pub fn get_empty_leaf(&self) -> Leaf {
        self.zero_values[0].as_leaf()
    }
47
48    pub fn init(&mut self, seeds: &[&[u8]]) {
49        let zeros = Self::calc_zeros(seeds);
50        self.next_index = 0;
51        self.root = zeros[N - 1];
52        self.filled_subtrees = zeros;
53        self.zero_values = zeros;
54    }
55
    /// Returns the number of leaf slots consumed so far, i.e. the index the
    /// next insertion will use.
    ///
    /// Removals overwrite a slot with the empty leaf and do not decrease this
    /// counter.
    pub fn get_leaf_count(&self) -> u64 {
        self.next_index
    }
60
    /// Returns the maximum number of leaves the tree can hold, `2^N`.
    ///
    /// The value is computed as `1u64 << N`, so `N` must be below 64 for the
    /// shift to be meaningful.
    pub fn get_capacity(&self) -> u64 {
        1u64 << N
    }
65
66    /// Calculates the zero values for the Merkle tree based on the provided seeds.
67    fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
68        let mut zeros: [Hash; N] = [Hash::default(); N];
69        let mut current = hashv(seeds);
70
71        for i in 0..N {
72            zeros[i] = current;
73            current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
74        }
75
76        zeros
77    }
78
79    /// Adds a data to the tree, creating a new leaf.
80    pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
81        let leaf = Leaf::new(data);
82        self.try_add_leaf(leaf)
83    }
84
85    /// Adds a leaf to the tree.
86    pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
87        check_condition(self.next_index < (1u64 << N), BrineTreeError::TreeFull)?;
88
89        let mut current_index = self.next_index;
90        let mut current_hash = Hash::from(leaf);
91        let mut left;
92        let mut right;
93
94        for i in 0..N {
95            if current_index % 2 == 0 {
96                left = current_hash;
97                right = self.zero_values[i];
98                self.filled_subtrees[i] = current_hash;
99            } else {
100                left = self.filled_subtrees[i];
101                right = current_hash;
102            }
103
104            current_hash = hash_left_right(left, right);
105            current_index /= 2;
106        }
107
108        self.root = current_hash;
109        self.next_index += 1;
110
111        Ok(())
112    }
113
114    /// Removes a leaf from the tree using the provided proof.
115    pub fn try_remove<P>(&mut self, proof: &[P], data: &[&[u8]]) -> ProgramResult
116    where
117        P: Into<Hash> + Copy,
118    {
119        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
120        let original_leaf = Leaf::new(data);
121        self.try_remove_leaf(&proof_hashes, original_leaf)
122    }
123
124    /// Removes a leaf from the tree using the provided proof.
125    pub fn try_remove_leaf<P>(&mut self, proof: &[P], leaf: Leaf) -> ProgramResult
126    where
127        P: Into<Hash> + Copy,
128    {
129        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
130        self.check_length(&proof_hashes)?;
131        self.try_replace_leaf(&proof_hashes, leaf, self.get_empty_leaf())
132    }
133
134    /// Replaces a leaf in the tree with new data using the provided proof.
135    pub fn try_replace<P>(
136        &mut self,
137        proof: &[P],
138        original_data: &[&[u8]],
139        new_data: &[&[u8]],
140    ) -> ProgramResult
141    where
142        P: Into<Hash> + Copy,
143    {
144        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
145        let original_leaf = Leaf::new(original_data);
146        let new_leaf = Leaf::new(new_data);
147        self.try_replace_leaf(&proof_hashes, original_leaf, new_leaf)
148    }
149
150    /// Replaces a leaf in the tree with a new leaf using the provided proof.
151    pub fn try_replace_leaf<P>(
152        &mut self,
153        proof: &[P],
154        original_leaf: Leaf,
155        new_leaf: Leaf,
156    ) -> ProgramResult
157    where
158        P: Into<Hash> + Copy,
159    {
160        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
161        self.check_length(&proof_hashes)?;
162        let original_path = compute_path(&proof_hashes, original_leaf);
163        let new_path = compute_path(&proof_hashes, new_leaf);
164        check_condition(
165            is_valid_path(&original_path, self.root),
166            BrineTreeError::InvalidProof,
167        )?;
168        for i in 0..N {
169            if original_path[i] == self.filled_subtrees[i] {
170                self.filled_subtrees[i] = new_path[i];
171            }
172        }
173        self.root = *new_path.last().unwrap();
174        Ok(())
175    }
176
177    /// Checks if the proof contains the specified data.
178    pub fn contains<P>(&self, proof: &[P], data: &[&[u8]]) -> bool
179    where
180        P: Into<Hash> + Copy,
181    {
182        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
183        let leaf = Leaf::new(data);
184        self.contains_leaf(&proof_hashes, leaf)
185    }
186
187    /// Checks if the proof contains the specified leaf.
188    pub fn contains_leaf<P>(&self, proof: &[P], leaf: Leaf) -> bool
189    where
190        P: Into<Hash> + Copy,
191    {
192        let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
193        if self.check_length(&proof_hashes).is_err() {
194            return false;
195        }
196        is_valid_leaf(&proof_hashes, self.root, leaf)
197    }
198
    /// Checks that `proof` holds exactly `N` hashes, one per tree level.
    ///
    /// Passes `BrineTreeError::ProofLength` to `check_condition` as the error
    /// for a mismatching length.
    fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
        check_condition(proof.len() == N, BrineTreeError::ProofLength)
    }
203
    /// Returns a Merkle proof for the leaf at `leaf_index`.
    ///
    /// The tree does not store its leaves, so the caller supplies them in
    /// `leaves`; the proof is rebuilt from that slice by `get_merkle_proof`
    /// using this tree's zero values and depth `N`.
    /// NOTE(review): presumably `leaves` must list the tree's current leaves in
    /// insertion order for the proof to match `root` — confirm with callers.
    pub fn get_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
        get_merkle_proof(leaves, &self.zero_values, leaf_index, N)
    }
208
209    /// Hashes up to `layer_number` and returns only the non-empty nodes
210    /// on that layer.
211    pub fn get_layer_nodes(&self, leaves: &[Leaf], layer_number: usize) -> Vec<Hash> {
212        if layer_number > N {
213            return vec![];
214        }
215
216        let valid_leaves = leaves
217            .iter()
218            .take(self.next_index as usize)
219            .copied()
220            .collect::<Vec<Leaf>>();
221
222        let mut current_layer: Vec<Hash> =
223            valid_leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
224
225        if current_layer.is_empty() || layer_number == 0 {
226            return current_layer;
227        }
228
229        let mut current_level: usize = 0;
230        loop {
231            if current_layer.is_empty() {
232                break;
233            }
234            let mut next_layer = Vec::with_capacity(current_layer.len().div_ceil(2));
235            let mut i = 0;
236            while i < current_layer.len() {
237                if i + 1 < current_layer.len() {
238                    let val = hash_left_right(current_layer[i], current_layer[i + 1]);
239                    next_layer.push(val);
240                    i += 2;
241                } else {
242                    let val = hash_left_right(current_layer[i], self.zero_values[current_level]);
243                    next_layer.push(val);
244                    i += 1;
245                }
246            }
247            current_level += 1;
248            if current_level == layer_number {
249                return next_layer;
250            }
251            current_layer = next_layer;
252        }
253        vec![]
254    }
255}
256
257/// Returns a Merkle proof for a specific leaf in the tree.
258pub fn get_merkle_proof(
259    leaves: &[Leaf],
260    zero_values: &[Hash],
261    leaf_index: usize,
262    height: usize,
263) -> Vec<Hash> {
264    let mut layers = Vec::with_capacity(height);
265    let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
266
267    for i in 0..height {
268        if current_layer.len() % 2 != 0 {
269            current_layer.push(zero_values[i]);
270        }
271
272        layers.push(current_layer.clone());
273        current_layer = hash_pairs(current_layer);
274    }
275
276    let mut proof = Vec::with_capacity(height);
277    let mut current_index = leaf_index;
278    let mut layer_index = 0;
279
280    for _ in 0..height {
281        let sibling = if current_index % 2 == 0 {
282            layers[layer_index][current_index + 1]
283        } else {
284            layers[layer_index][current_index - 1]
285        };
286
287        proof.push(sibling);
288
289        current_index /= 2;
290        layer_index += 1;
291    }
292
293    proof
294}
295
296/// Hashes pairs of hashes together, returning a new vector of hashes.
297pub fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
298    let mut res = Vec::with_capacity(pairs.len() / 2);
299
300    for i in (0..pairs.len()).step_by(2) {
301        let left = pairs[i];
302        let right = pairs[i + 1];
303
304        let hashed = hash_left_right(left, right);
305        res.push(hashed);
306    }
307
308    res
309}
310
311/// Hashes two hashes together, ensuring a consistent order.
312pub fn hash_left_right(left: Hash, right: Hash) -> Hash {
313    let combined;
314    if left.to_bytes() <= right.to_bytes() {
315        combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
316    } else {
317        combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
318    }
319
320    hashv(&combined)
321}
322
323/// Computes the path from the leaf to the root using the provided proof.
324pub fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
325    let mut computed_path = Vec::with_capacity(proof.len() + 1);
326    let mut computed_hash = Hash::from(leaf);
327
328    computed_path.push(computed_hash);
329
330    for proof_element in proof.iter() {
331        computed_hash = hash_left_right(computed_hash, *proof_element);
332        computed_path.push(computed_hash);
333    }
334
335    computed_path
336}
337
338fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
339    let computed_path = compute_path(proof, leaf);
340    is_valid_path(&computed_path, root)
341}
342
343fn is_valid_path(path: &[Hash], root: Hash) -> bool {
344    if path.is_empty() {
345        return false;
346    }
347
348    *path.last().unwrap() == root
349}
350
351/// Verifies that a given merkle root contains the leaf using the provided proof.
352pub fn verify<Root, Item, L>(root: Root, proof: &[Item], leaf: L) -> bool
353where
354    Root: Into<Hash>,
355    Item: Into<Hash> + Copy,
356    L: Into<Leaf>,
357{
358    let root_h: Hash = root.into();
359    let proof_hashes: Vec<Hash> = proof.iter().map(|&x| x.into()).collect();
360
361    let leaf_h: Leaf = leaf.into();
362    let path = compute_path(&proof_hashes, leaf_h);
363    is_valid_path(&path, root_h)
364}
365
366
367#[cfg(test)]
368mod tests {
369    use super::*;
370
371    type TestTree = MerkleTree<3>;
372
373    #[test]
374    fn test_create_tree() {
375        let seeds: &[&[u8]] = &[b"test"];
376        let tree = TestTree::new(seeds);
377
378        assert_eq!(tree.get_depth(), 3);
379        assert_eq!(tree.get_root(), tree.zero_values.last().unwrap().clone());
380    }
381
    /// Inserts three leaves, checks the stored left siblings, the root and
    /// hand-built proofs for filled and empty slots, then removes one leaf and
    /// inserts a fourth, re-deriving the expected root by hand after each step.
    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);
        let empty = *tree.zero_values.first().unwrap();
        let empty_leaf = empty.as_leaf();

        // Tree structure:
        //
        //              root
        //            /     \
        //         m           n
        //       /   \       /   \
        //      i     j     k     l
        //     / \   / \   / \   / \
        //    a  b  c  d  e  f  g  h

        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        // Slots d..h have not been written yet, so they hold the empty leaf.
        let d = empty;
        let e = empty;
        let f = empty;
        let g = empty;
        let h = empty;

        let i = hash_left_right(a, b);
        let j: Hash = hash_left_right(c, d);
        let k: Hash = hash_left_right(e, f);
        let l: Hash = hash_left_right(g, h);
        let m: Hash = hash_left_right(i, j);
        let n: Hash = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&a));

        assert!(tree.try_add(&[b"val_2"]).is_ok());
        // Not a typo: val_2 is a right child, so the stored left sibling stays `a`.
        assert!(tree.filled_subtrees[0].eq(&a));

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        // Not a typo: val_3 is a left child and becomes the stored left sibling.
        assert!(tree.filled_subtrees[0].eq(&c));

        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        // Proofs list the sibling at each level, from the leaf level upwards.
        let val1_proof = vec![b, j, n];
        let val2_proof = vec![a, j, n];
        let val3_proof = vec![d, i, n];

        // Check filled leaves
        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Check empty leaves
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Remove val2 from the tree
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        // Update the expected tree structure
        let i = hash_left_right(a, empty);
        let m: Hash = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = vec![empty, j, n];
        let val3_proof = vec![d, i, n];

        // The old val2 proof is still valid for its slot, which now holds the
        // empty leaf.
        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        // Check that val2 is no longer in the tree
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Insert val4 into the tree
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        // Not a typo: val_4 is a right child, so the stored left sibling stays `c`.
        assert!(tree.filled_subtrees[0].eq(&c));

        // Update the expected tree structure
        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }
480
481    #[test]
482    fn test_proof() {
483        let seeds: &[&[u8]] = &[b"test"];
484
485        let mut tree = TestTree::new(seeds);
486
487        let leaves = [
488            Leaf::new(&[b"val_1"]),
489            Leaf::new(&[b"val_2"]),
490            Leaf::new(&[b"val_3"]),
491        ];
492
493        assert!(tree.try_add(&[b"val_1"]).is_ok());
494        assert!(tree.try_add(&[b"val_2"]).is_ok());
495        assert!(tree.try_add(&[b"val_3"]).is_ok());
496
497        let val1_proof = tree.get_proof(&leaves, 0);
498        let val2_proof = tree.get_proof(&leaves, 1);
499        let val3_proof = tree.get_proof(&leaves, 2);
500
501        assert!(tree.contains(&val1_proof, &[b"val_1"]));
502        assert!(tree.contains(&val2_proof, &[b"val_2"]));
503        assert!(tree.contains(&val3_proof, &[b"val_3"]));
504
505        // Invalid Proof Length
506        let invalid_proof_short = &val1_proof[..2]; // Shorter than depth
507        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat(); // Longer than depth
508
509        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
510        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));
511
512        // Empty Proof
513        let empty_proof: Vec<Hash> = Vec::new();
514        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
515    }
516
517    #[test]
518    fn test_init_and_reinit() {
519        let seeds: &[&[u8]] = &[b"test"];
520        let mut tree = TestTree::new(seeds);
521
522        // Store initial state
523        let initial_root = tree.get_root();
524        let initial_zeros = tree.zero_values;
525        let initial_filled = tree.filled_subtrees;
526        let initial_index = tree.next_index;
527
528        // Add a leaf to modify the tree
529        assert!(tree.try_add(&[b"val_1"]).is_ok());
530
531        // Reinitialize
532        tree.init(seeds);
533
534        // Verify tree is reset to initial state
535        assert_eq!(tree.get_root(), initial_root);
536        assert_eq!(tree.zero_values, initial_zeros);
537        assert_eq!(tree.filled_subtrees, initial_filled);
538        assert_eq!(tree.next_index, initial_index);
539    }
540
541    #[test]
542    fn test_tree_full() {
543        let seeds: &[&[u8]] = &[b"test"];
544        let mut tree = TestTree::new(seeds);
545
546        // Fill the tree (2^3 = 8 leaves)
547        for i in 0u8..8 {
548            assert!(tree.try_add(&[&[i]]).is_ok());
549        }
550
551        // Try to add one more leaf
552        let result = tree.try_add(&[b"extra"]);
553        assert!(result.is_err());
554        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
555    }
556
557    #[test]
558    fn test_replace_leaf() {
559        let seeds: &[&[u8]] = &[b"test"];
560        let mut tree = TestTree::new(seeds);
561
562        // Add initial leaves
563        assert!(tree.try_add(&[b"val_1"]).is_ok());
564        assert!(tree.try_add(&[b"val_2"]).is_ok());
565
566        // Get proof for val_1
567        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
568        let proof = tree.get_proof(&leaves, 0);
569
570        // Replace val_1 with new_val
571        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());
572
573        // Verify new leaf is present
574        assert!(tree.contains(&proof, &[b"new_val"]));
575        assert!(!tree.contains(&proof, &[b"val_1"]));
576
577        // Verify val_2 is still present
578        let proof_val2 = tree.get_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
579        assert!(tree.contains(&proof_val2, &[b"val_2"]));
580    }
581
582    #[test]
583    fn test_verify() {
584        let seeds: &[&[u8]] = &[b"test"];
585        let mut tree = TestTree::new(seeds);
586
587        // Add initial leaves
588        assert!(tree.try_add(&[b"val_1"]).is_ok());
589        assert!(tree.try_add(&[b"val_2"]).is_ok());
590
591        // Get proof for val_1
592        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
593        let proof = tree.get_proof(&leaves, 0);
594
595        // Verify proof (typed)
596        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));
597
598        let a: [u8; 32] = tree.get_root().to_bytes();
599        let b: [[u8; 32]; 3] = [
600            proof[0].to_bytes(),
601            proof[1].to_bytes(),
602            proof[2].to_bytes(),
603        ];
604        let c: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();
605
606        // Verify proof (generic)
607        assert!(verify(a, &b, c));
608    }
609
    /// Checks `get_layer_nodes` on an empty tree, on a tree with three leaves
    /// (every layer up to the root plus an out-of-range layer) and again after
    /// a fourth leaf completes the second pair.
    #[test]
    fn test_get_layer_nodes() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];

        // Define leaves
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
            Leaf::new(&[b"val_4"]),
        ];

        // Test empty tree: nothing has been inserted, so the supplied leaves
        // are ignored.
        assert_eq!(tree.get_layer_nodes(&leaves, 0), vec![]);
        assert_eq!(tree.get_layer_nodes(&leaves, 1), vec![]);

        // Add 3 leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        // Expected tree structure (`0` marks an empty node):
        //       root
        //      /    \
        //     m      0
        //    / \    / \
        //   i   j  0   0
        //  / \ / \ / \/ \
        // a  b c d 0 0 0 0

        let a = Hash::from(leaves[0]);
        let b = Hash::from(leaves[1]);
        let c = Hash::from(leaves[2]);
        let d = empty;
        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);

        // Test layer 0 (leaf layer): only the three inserted leaves count.
        let layer_0 = tree.get_layer_nodes(&leaves, 0);
        assert_eq!(layer_0, vec![a, b, c]);

        // Test layer 1
        let layer_1 = tree.get_layer_nodes(&leaves, 1);
        assert_eq!(layer_1, vec![i, j]);

        // Test layer 2
        let layer_2 = tree.get_layer_nodes(&leaves, 2);
        let m = hash_left_right(i, j);
        assert_eq!(layer_2, vec![m]);

        // Test layer 3 (root)
        let layer_3 = tree.get_layer_nodes(&leaves, 3);
        assert_eq!(layer_3, vec![tree.get_root()]);

        // Test invalid layer (beyond the depth of 3)
        let layer_4 = tree.get_layer_nodes(&leaves, 4);
        assert_eq!(layer_4, vec![]);

        // Add one more leaf to fill a node pair
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        let d = Hash::from(leaves[3]);
        let j = hash_left_right(c, d);

        // Test layer 0 with 4 leaves
        let layer_0 = tree.get_layer_nodes(&leaves, 0);
        assert_eq!(layer_0, vec![a, b, c, d]);

        // Test layer 1 with updated j
        let layer_1 = tree.get_layer_nodes(&leaves, 1);
        assert_eq!(layer_1, vec![i, j]);
    }
683}