brine_tree/
tree.rs

1use bytemuck::{ Pod, Zeroable };
2use super::hash::{ Hash, Leaf, hashv };
3use super::error::{ BrineTreeError, ProgramResult };
4use super::utils::check_condition;
5
/// Fixed-depth incremental Merkle tree with room for `2^N` leaves.
///
/// Only the root, the per-level append frontier and the per-level zero values
/// are stored; the leaves themselves are not. The type implements `Pod`, so it
/// can be viewed as / copied from raw bytes with `bytemuck`.
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    root: Hash,
    /// For each level, the most recent node appended there as a left child;
    /// it is paired with the next right child appended at that level.
    filled_subtrees: [Hash; N],
    /// For each level `i`, the root of an all-empty subtree of height `i`
    /// (`zero_values[0]` is the empty leaf).
    zero_values: [Hash; N],
    /// Position the next appended leaf will occupy (equals the number of appends).
    next_index: u64,
}

// SAFETY: `Zeroable`/`Pod` require that every field is valid for any bit
// pattern (including all zeros) and that the `repr(C)` layout has no padding
// bytes. That holds only if `Hash` is itself plain bytes (`Pod`) and
// `size_of::<Hash>()` is a multiple of 8, so `next_index` needs no padding.
// NOTE(review): those properties of `Hash` are defined in `hash.rs` — confirm there.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19    pub const fn get_depth(&self) -> u8 {
20        N as u8
21    }
22
23    pub const fn get_size() -> usize {
24        core::mem::size_of::<Self>()
25    }
26
27    pub fn get_root(&self) -> Hash {
28        self.root
29    }
30
31    pub fn get_empty_leaf(&self) -> Leaf {
32        self.zero_values[0].as_leaf()
33    }
34
35    pub fn new(seeds: &[&[u8]]) -> Self {
36        let zeros = Self::calc_zeros(seeds);
37        Self {
38            next_index: 0,
39            root: zeros[N - 1],
40            filled_subtrees: zeros,
41            zero_values: zeros,
42        }
43    }
44
45    pub fn init(&mut self, seeds: &[&[u8]]) {
46        let zeros = Self::calc_zeros(seeds);
47        self.next_index = 0;
48        self.root = zeros[N - 1];
49        self.filled_subtrees = zeros;
50        self.zero_values = zeros;
51    }
52
53    fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
54        let mut zeros: [Hash; N] = [Hash::default(); N];
55        let mut current = hashv(seeds);
56
57        for i in 0..N {
58            zeros[i] = current;
59            current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
60        }
61
62        zeros
63    }
64
65    pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
66        let leaf = Leaf::new(data);
67        self.try_add_leaf(leaf)
68    }
69
70    pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
71        check_condition(
72            self.next_index < (1u64 << N),
73            BrineTreeError::TreeFull,
74        )?;
75
76        let mut current_index = self.next_index;
77        let mut current_hash = Hash::from(leaf);
78        let mut left;
79        let mut right;
80
81        for i in 0..N {
82            if current_index % 2 == 0 {
83                left = current_hash;
84                right = self.zero_values[i];
85                self.filled_subtrees[i] = current_hash;
86            } else {
87                left = self.filled_subtrees[i];
88                right = current_hash;
89            }
90
91            current_hash = hash_left_right(left, right);
92            current_index /= 2;
93        }
94
95        self.root = current_hash;
96        self.next_index += 1;
97
98        Ok(())
99    }
100
101    pub fn try_remove(&mut self, proof: &[Hash], data: &[&[u8]]) -> ProgramResult {
102        let leaf = Leaf::new(data);
103        self.try_remove_leaf(proof, leaf)
104    }
105
106    pub fn try_remove_leaf(&mut self, proof: &[Hash], leaf: Leaf) -> ProgramResult {
107        self.check_length(proof)?;
108        self.try_replace_leaf(proof, leaf, self.get_empty_leaf())
109    }
110
111    pub fn try_replace(&mut self, proof: &[Hash], original_data: &[&[u8]], new_data: &[&[u8]]) -> ProgramResult {
112        let original_leaf = Leaf::new(original_data);
113        let new_leaf = Leaf::new(new_data);
114        self.try_replace_leaf(proof, original_leaf, new_leaf)
115    }
116
117    pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Leaf, new_leaf: Leaf) -> ProgramResult {
118        self.check_length(proof)?;
119
120        let original_path = compute_path(proof, original_leaf);
121        let new_path = compute_path(proof, new_leaf);
122
123        check_condition(
124            is_valid_path(&original_path, self.root),
125            BrineTreeError::InvalidProof,
126        )?;
127
128        for i in 0..N {
129            if original_path[i] == self.filled_subtrees[i] {
130                self.filled_subtrees[i] = new_path[i];
131            }
132        }
133
134        self.root = *new_path.last().unwrap();
135
136        Ok(())
137    }
138
139    pub fn contains(&self, proof: &[Hash], data: &[&[u8]]) -> bool {
140        let leaf = Leaf::new(data);
141        self.contains_leaf(proof, leaf)
142    }
143
144    pub fn contains_leaf(&self, proof: &[Hash], leaf: Leaf) -> bool {
145        if let Err(_) = self.check_length(proof) {
146            return false;
147        }
148
149        let root = self.get_root();
150        is_valid_leaf(proof, root, leaf)
151    }
152
153    fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
154        check_condition(proof.len() == N, BrineTreeError::ProofLength)
155    }
156
157    #[cfg(not(feature = "solana"))]
158    fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
159        let mut res = Vec::with_capacity(pairs.len() / 2);
160
161        for i in (0..pairs.len()).step_by(2) {
162            let left = pairs[i];
163            let right = pairs[i + 1];
164
165            let hashed = hash_left_right(left, right);
166            res.push(hashed);
167        }
168
169        res
170    }
171
172    #[cfg(not(feature = "solana"))]
173    pub fn get_merkle_proof(&self, values: &[Leaf], index: usize) -> Vec<Hash> {
174        let mut layers = Vec::with_capacity(N);
175        let mut current_layer: Vec<Hash> = values.iter().map(|leaf| Hash::from(*leaf)).collect();
176
177        for i in 0..N {
178            if current_layer.len() % 2 != 0 {
179                current_layer.push(self.zero_values[i]);
180            }
181
182            layers.push(current_layer.clone());
183            current_layer = Self::hash_pairs(current_layer);
184        }
185
186        let mut proof = Vec::with_capacity(N);
187        let mut current_index = index;
188        let mut layer_index = 0;
189        let mut sibling;
190
191        for _ in 0..N {
192            if current_index % 2 == 0 {
193                sibling = layers[layer_index][current_index + 1];
194            } else {
195                sibling = layers[layer_index][current_index - 1];
196            }
197
198            proof.push(sibling);
199
200            current_index /= 2;
201            layer_index += 1;
202        }
203
204        proof
205    }
206}
207
208fn hash_left_right(left: Hash, right: Hash) -> Hash {
209    let combined;
210    if left.to_bytes() <= right.to_bytes() {
211        combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
212    } else {
213        combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
214    }
215
216    hashv(&combined)
217}
218
219fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
220    let mut computed_path = Vec::with_capacity(proof.len() + 1);
221    let mut computed_hash = Hash::from(leaf);
222
223    computed_path.push(computed_hash);
224
225    for proof_element in proof.iter() {
226        computed_hash = hash_left_right(computed_hash, *proof_element);
227        computed_path.push(computed_hash);
228    }
229
230    computed_path
231}
232
233fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
234    let computed_path = compute_path(proof, leaf);
235    is_valid_path(&computed_path, root)
236}
237
238fn is_valid_path(path: &[Hash], root: Hash) -> bool {
239    if path.is_empty() {
240        return false;
241    }
242
243    *path.last().unwrap() == root
244}
245
246/// Verifies that a given merkle root contains the leaf using the provided proof.
247pub fn verify<const N: usize>(root: Hash, proof: &[Hash], leaf: Leaf) -> bool {
248    if proof.len() != N {
249        return false;
250    }
251
252    let computed_path = compute_path(proof, leaf);
253    is_valid_path(&computed_path, root)
254}
255
#[cfg(test)]
mod tests {
    //! Unit tests for `MerkleTree`, using a depth-3 (8-leaf) tree.
    //! `Hash` and `Leaf` are `Copy`, so values are passed without `.clone()`.
    use super::*;

    type TestTree = MerkleTree<3>;

    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        // Pins the current empty-tree root: the top stored zero value.
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Tree structure:
        //
        //              root
        //            /     \
        //         m           n
        //       /   \       /   \
        //      i     j     k     l
        //     / \   / \   / \   / \
        //    a  b  c  d  e  f  g  h

        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        let d = empty;
        let e = empty;
        let f = empty;
        let g = empty;
        let h = empty;

        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);
        let k = hash_left_right(e, f);
        let l = hash_left_right(g, h);
        let m = hash_left_right(i, j);
        let n = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_2"]).is_ok());
        // Index 1 is a right child, so level 0 still remembers `a`.
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        // Check filled leaves
        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Check empty leaves
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Remove val2 from the tree
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        // Update the expected tree structure
        let i = hash_left_right(a, empty);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        // Check that val2 is no longer in the tree
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Insert val4 into the tree
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        // Index 3 is a right child, so level 0 still remembers `c`.
        assert_eq!(tree.filled_subtrees[0], c);

        // Update the expected tree structure
        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let val2_proof = tree.get_merkle_proof(&leaves, 1);
        let val3_proof = tree.get_merkle_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Invalid Proof Length
        let invalid_proof_short = &val1_proof[..2]; // Shorter than depth
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat(); // Longer than depth

        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        // Empty Proof
        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Store initial state
        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        // Add a leaf to modify the tree
        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Reinitialize
        tree.init(seeds);

        // Verify tree is reset to initial state
        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Fill the tree (2^3 = 8 leaves)
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        // Try to add one more leaf
        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Add initial leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        // Get proof for val_1
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Replace val_1 with new_val
        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // Verify new leaf is present
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        // Verify val_2 is still present
        let proof_val2 = tree.get_merkle_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_verify() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Add initial leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        // Get proof for val_1
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Verify the proof
        assert!(verify::<3>(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));
    }
}