brine_tree/
tree.rs

1use bytemuck::{ Pod, Zeroable };
2use super::hash::{ Hash, Leaf, hashv };
3use super::error::{ BrineTreeError, ProgramResult };
4use super::utils::check_condition;
5
/// Fixed-depth Merkle tree of depth `N` whose leaves are appended left to right.
///
/// Only the current root, one remembered node per level and the per-level
/// empty-subtree hashes are stored, so the struct's size does not depend on how
/// many leaves have been added. It is `#[repr(C)]` and implements
/// `Pod`/`Zeroable`, so it can be reinterpreted from raw bytes via `bytemuck`.
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    root: Hash,
    /// Per level, the last node written while it was a left child during an
    /// append; used as the left sibling when a later append is a right child.
    filled_subtrees: [Hash; N],
    /// `zero_values[i]` is the hash of an empty subtree at level `i`
    /// (`zero_values[0]` is the empty-leaf hash).
    zero_values: [Hash; N],
    /// Position at which the next appended leaf is placed.
    next_index: u64,
}

// SAFETY: NOTE(review) — sound only if `Hash` is itself `Pod` (plain bytes,
// every bit pattern valid) and the struct has no padding: `repr(C)` fixes the
// field order, and the field sizes must add up to a multiple of the 8-byte
// alignment. Presumably `Hash` wraps a `[u8; 32]` — confirm before relying on it.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19    pub const fn get_depth(&self) -> u8 {
20        N as u8
21    }
22
23    pub const fn get_size() -> usize {
24        core::mem::size_of::<Self>()
25    }
26
27    pub fn get_root(&self) -> Hash {
28        self.root
29    }
30
31    pub fn get_empty_leaf(&self) -> Leaf {
32        self.zero_values[0].as_leaf()
33    }
34
35    pub fn new(seeds: &[&[u8]]) -> Self {
36        let zeros = Self::calc_zeros(seeds);
37        Self {
38            next_index: 0,
39            root: zeros[N - 1],
40            filled_subtrees: zeros,
41            zero_values: zeros,
42        }
43    }
44
45    pub fn init(&mut self, seeds: &[&[u8]]) {
46        let zeros = Self::calc_zeros(seeds);
47        self.next_index = 0;
48        self.root = zeros[N - 1];
49        self.filled_subtrees = zeros;
50        self.zero_values = zeros;
51    }
52
53    fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
54        let mut zeros: [Hash; N] = [Hash::default(); N];
55        let mut current = hashv(seeds);
56
57        for i in 0..N {
58            zeros[i] = current;
59            current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
60        }
61
62        zeros
63    }
64
65    pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
66        let leaf = Leaf::new(data);
67        self.try_add_leaf(leaf)
68    }
69
70    pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
71        check_condition(
72            self.next_index < (1u64 << N),
73            BrineTreeError::TreeFull,
74        )?;
75
76        let mut current_index = self.next_index;
77        let mut current_hash = Hash::from(leaf);
78        let mut left;
79        let mut right;
80
81        for i in 0..N {
82            if current_index % 2 == 0 {
83                left = current_hash;
84                right = self.zero_values[i];
85                self.filled_subtrees[i] = current_hash;
86            } else {
87                left = self.filled_subtrees[i];
88                right = current_hash;
89            }
90
91            current_hash = hash_left_right(left, right);
92            current_index /= 2;
93        }
94
95        self.root = current_hash;
96        self.next_index += 1;
97
98        Ok(())
99    }
100
101    pub fn try_remove(&mut self, proof: &[Hash], data: &[&[u8]]) -> ProgramResult {
102        let leaf = Leaf::new(data);
103        self.try_remove_leaf(proof, leaf)
104    }
105
106    pub fn try_remove_leaf(&mut self, proof: &[Hash], leaf: Leaf) -> ProgramResult {
107        self.check_length(proof)?;
108        self.try_replace_leaf(proof, leaf, self.get_empty_leaf())
109    }
110
111    pub fn try_replace(&mut self, proof: &[Hash], original_data: &[&[u8]], new_data: &[&[u8]]) -> ProgramResult {
112        let original_leaf = Leaf::new(original_data);
113        let new_leaf = Leaf::new(new_data);
114        self.try_replace_leaf(proof, original_leaf, new_leaf)
115    }
116
117    pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Leaf, new_leaf: Leaf) -> ProgramResult {
118        self.check_length(proof)?;
119
120        let original_path = compute_path(proof, original_leaf);
121        let new_path = compute_path(proof, new_leaf);
122
123        check_condition(
124            is_valid_path(&original_path, self.root),
125            BrineTreeError::InvalidProof,
126        )?;
127
128        for i in 0..N {
129            if original_path[i] == self.filled_subtrees[i] {
130                self.filled_subtrees[i] = new_path[i];
131            }
132        }
133
134        self.root = *new_path.last().unwrap();
135
136        Ok(())
137    }
138
139    pub fn contains(&self, proof: &[Hash], data: &[&[u8]]) -> bool {
140        let leaf = Leaf::new(data);
141        self.contains_leaf(proof, leaf)
142    }
143
144    pub fn contains_leaf(&self, proof: &[Hash], leaf: Leaf) -> bool {
145        if let Err(_) = self.check_length(proof) {
146            return false;
147        }
148
149        let root = self.get_root();
150        is_valid_leaf(proof, root, leaf)
151    }
152
153    fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
154        check_condition(proof.len() == N, BrineTreeError::ProofLength)
155    }
156
157    fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
158        let mut res = Vec::with_capacity(pairs.len() / 2);
159
160        for i in (0..pairs.len()).step_by(2) {
161            let left = pairs[i];
162            let right = pairs[i + 1];
163
164            let hashed = hash_left_right(left, right);
165            res.push(hashed);
166        }
167
168        res
169    }
170
171    pub fn get_merkle_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
172        let mut layers = Vec::with_capacity(N);
173        let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
174
175        for i in 0..N {
176            if current_layer.len() % 2 != 0 {
177                current_layer.push(self.zero_values[i]);
178            }
179
180            layers.push(current_layer.clone());
181            current_layer = Self::hash_pairs(current_layer);
182        }
183
184        let mut proof = Vec::with_capacity(N);
185        let mut current_index = leaf_index;
186        let mut layer_index = 0;
187        let mut sibling;
188
189        for _ in 0..N {
190            if current_index % 2 == 0 {
191                sibling = layers[layer_index][current_index + 1];
192            } else {
193                sibling = layers[layer_index][current_index - 1];
194            }
195
196            proof.push(sibling);
197
198            current_index /= 2;
199            layer_index += 1;
200        }
201
202        proof
203    }
204}
205
206fn hash_left_right(left: Hash, right: Hash) -> Hash {
207    let combined;
208    if left.to_bytes() <= right.to_bytes() {
209        combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
210    } else {
211        combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
212    }
213
214    hashv(&combined)
215}
216
217fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
218    let mut computed_path = Vec::with_capacity(proof.len() + 1);
219    let mut computed_hash = Hash::from(leaf);
220
221    computed_path.push(computed_hash);
222
223    for proof_element in proof.iter() {
224        computed_hash = hash_left_right(computed_hash, *proof_element);
225        computed_path.push(computed_hash);
226    }
227
228    computed_path
229}
230
231fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
232    let computed_path = compute_path(proof, leaf);
233    is_valid_path(&computed_path, root)
234}
235
236fn is_valid_path(path: &[Hash], root: Hash) -> bool {
237    if path.is_empty() {
238        return false;
239    }
240
241    *path.last().unwrap() == root
242}
243
244/// Verifies that a given merkle root contains the leaf using the provided proof.
245pub fn verify<Root, Item, L>(
246    root: Root,
247    proof: &[Item],
248    leaf: L,
249) -> bool
250where
251    Root: Into<Hash>,
252    Item: Into<Hash> + Copy,
253    L: Into<Leaf>,
254{
255    let root_h: Hash = root.into();
256    let proof_hashes: Vec<Hash> = 
257        proof.iter()
258          .map(|&x| x.into())
259          .collect();
260
261    let leaf_h: Leaf = leaf.into();
262    let path = compute_path(&proof_hashes, leaf_h);
263    is_valid_path(&path, root_h)
264}
265
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth-3 tree (8 leaf slots): small enough to build the expected nodes by hand.
    type TestTree = MerkleTree<3>;

    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);
        // `Hash` and `Leaf` are `Copy`, so expected nodes are plain copies.
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Tree structure:
        //
        //              root
        //            /     \
        //         m           n
        //       /   \       /   \
        //      i     j     k     l
        //     / \   / \   / \   / \
        //    a  b  c  d  e  f  g  h

        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        let d = empty;
        let e = empty;
        let f = empty;
        let g = empty;
        let h = empty;

        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);
        let k = hash_left_right(e, f);
        let l = hash_left_right(g, h);
        let m = hash_left_right(i, j);
        let n = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        // Index 0 is a left child, so it is remembered at level 0.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        // Index 1 is a right child, so level 0 still remembers `a`.
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        // Index 2 is a left child again, so level 0 now remembers `c`.
        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        // Check filled leaves
        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Check empty leaves
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Remove val_2: its slot is overwritten with the empty leaf.
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        // Update the expected tree structure
        let i = hash_left_right(a, empty);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        // Check that val_2 is no longer in the tree
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Removal does not rewind `next_index`, so val_4 lands at index 3, a
        // right child: level 0 still remembers `c`.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        // Update the expected tree structure
        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let val2_proof = tree.get_merkle_proof(&leaves, 1);
        let val3_proof = tree.get_merkle_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Invalid proof length
        let invalid_proof_short = &val1_proof[..2]; // Shorter than depth
        let invalid_proof_long = val1_proof.repeat(2); // Longer than depth

        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        // Empty proof
        let empty_proof: [Hash; 0] = [];
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Store initial state
        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        // Add a leaf to modify the tree
        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Reinitialize
        tree.init(seeds);

        // Verify the tree is reset to its initial state
        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Fill the tree (2^3 = 8 leaves)
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        // Try to add one more leaf
        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Add initial leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        // Get proof for val_1
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Replace val_1 with new_val
        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // Verify the new leaf is present and the old one is gone
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        // Verify val_2 is still present
        let proof_val2 = tree.get_merkle_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_verify() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Add initial leaves
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        // Get proof for val_1
        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // Verify proof (typed)
        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));

        let root_bytes: [u8; 32] = tree.get_root().to_bytes();
        let proof_bytes: [[u8; 32]; 3] = [
            proof[0].to_bytes(),
            proof[1].to_bytes(),
            proof[2].to_bytes(),
        ];
        let leaf_bytes: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();

        // Verify proof (generic, raw bytes)
        assert!(verify(root_bytes, &proof_bytes, leaf_bytes));
    }
}