// brine_tree/tree.rs

1#![allow(unexpected_cfgs)] 
2
3use bytemuck::{Pod, Zeroable};
4use super::hash::Hash;
5use super::{utils, utils::check_condition};
6use super::error::{ProgramError, ProgramResult};
7
/// Fixed-depth Merkle tree with `2^N` leaf slots that keeps only the current root
/// plus `O(N)` bookkeeping (no leaves), so the whole value is `Pod` and can be cast
/// to/from raw bytes with `bytemuck`.
///
/// Leaves are appended with `try_insert`; existing leaves can be replaced or
/// removed by supplying a proof.
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug,)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    root: Hash,
    /// `filled_subtrees[i]`: the most recent node written as a *left* child at
    /// level `i` (level 0 = leaves). Used as the left sibling when the next insert
    /// lands on a right child at that level.
    filled_subtrees: [Hash; N],
    /// `zero_values[i]`: hash of an empty subtree at level `i`; `zero_values[0]` is
    /// the empty-leaf hash derived from the construction seeds.
    zero_values: [Hash; N],
    /// Index of the next leaf slot to fill (the number of successful inserts since
    /// construction / `init`).
    next_index: u64,
}

// SAFETY: NOTE(review) — soundness relies on facts not visible in this file:
// `Hash` must itself be `Pod` (every bit pattern valid, no padding, no pointers),
// and this `repr(C)` layout must contain no padding bytes. If `Hash` is a 32-byte,
// align-1 array (presumably — confirm), the fields sit at offsets 0, 32, 32 + 32N
// and 32 + 64N, so `next_index` is 8-aligned and the size 40 + 64N is a multiple
// of the 8-byte alignment: no padding.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
19
20impl<const N: usize> MerkleTree<N> {
21    pub const fn get_depth(&self) -> u8 {
22        N as u8
23    }
24
25    pub const fn get_size() -> usize {
26        std::mem::size_of::<Self>()
27    }
28
29    pub fn get_root(&self) -> Hash {
30        self.root
31    }
32
33    pub fn get_empty_leaf(&self) -> Hash {
34        self.zero_values[0]
35    }
36
37    pub fn new(seeds: &[&[u8]]) -> Self {
38        let zeros = Self::calc_zeros(seeds);
39        Self {
40            next_index: 0,
41            root: zeros[N - 1],
42            filled_subtrees: zeros,
43            zero_values: zeros,
44        }
45    }
46
47    pub fn init(&mut self, seeds: &[&[u8]]) {
48        let zeros = Self::calc_zeros(seeds);
49        self.next_index = 0;
50        self.root = zeros[N - 1];
51        self.filled_subtrees = zeros;
52        self.zero_values = zeros;
53    }
54
55    fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
56        let mut zeros: [Hash; N] = [Hash::default(); N];
57        let mut current = utils::hashv(seeds);
58
59        for i in 0..N {
60            zeros[i] = current;
61            current = utils::hashv(&[current.as_ref(), current.as_ref()]);
62        }
63
64        zeros
65    }
66
67    pub fn try_insert(&mut self, val: Hash) -> ProgramResult {
68        check_condition(
69            self.next_index < (1u64 << N),
70            "merkle tree is full",
71        )?;
72
73        let mut current_index = self.next_index;
74        let mut current_hash = MerkleTree::<N>::as_leaf(val);
75        let mut left;
76        let mut right;
77
78        for i in 0..N {
79            if current_index % 2 == 0 {
80                left = current_hash;
81                right = self.zero_values[i];
82                self.filled_subtrees[i] = current_hash;
83            } else {
84                left = self.filled_subtrees[i];
85                right = current_hash;
86            }
87
88            current_hash = Self::hash_left_right(left, right);
89            current_index /= 2;
90        }
91
92        self.root = current_hash;
93        self.next_index += 1;
94
95        Ok(())
96    }
97
98    pub fn try_remove(&mut self, proof: &[Hash], val: Hash) -> ProgramResult {
99        self.check_length(proof)?;
100
101        self.try_replace_leaf(proof, Self::as_leaf(val), self.get_empty_leaf())
102    }
103
104    pub fn try_replace(&mut self, proof: &[Hash], original_val: Hash, new_val: Hash) -> ProgramResult {
105        self.check_length(proof)?;
106
107        let original_leaf = Self::as_leaf(original_val);
108        let new_leaf = Self::as_leaf(new_val);
109
110        self.try_replace_leaf(proof, original_leaf, new_leaf)
111    }
112
113    pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Hash, new_leaf: Hash) -> ProgramResult {
114        self.check_length(proof)?;
115
116        let original_path = MerkleTree::<N>::compute_path(proof, original_leaf);
117        let new_path = MerkleTree::<N>::compute_path(proof, new_leaf);
118
119        check_condition(
120            MerkleTree::<N>::is_valid_path(&original_path, self.root),
121            "invalid proof for original leaf",
122        )?;
123
124        for i in 0..N {
125            if original_path[i] == self.filled_subtrees[i] {
126                self.filled_subtrees[i] = new_path[i];
127            }
128        }
129
130        self.root = *new_path.last().unwrap();
131
132        Ok(())
133    }
134
135    pub fn contains(&self, proof: &[Hash], val: Hash) -> bool {
136        if let Err(_) = self.check_length(proof) {
137            return false;
138        }
139
140        let leaf = Self::as_leaf(val);
141        self.contains_leaf(proof, leaf)
142    }
143
144    pub fn contains_leaf(&self, proof: &[Hash], leaf: Hash) -> bool {
145        if let Err(_) = self.check_length(proof) {
146            return false;
147        }
148
149        let root = self.get_root();
150        Self::is_valid_leaf(proof, root, leaf)
151    }
152
153    pub fn as_leaf(val: Hash) -> Hash {
154        utils::hash(val.as_ref())
155    }
156
157    pub fn hash_left_right(left: Hash, right: Hash) -> Hash {
158        let combined;
159        if left.to_bytes() <= right.to_bytes() {
160            combined = [left.as_ref(), right.as_ref()];
161        } else {
162            combined = [right.as_ref(), left.as_ref()];
163        }
164
165        utils::hashv(&combined)
166    }
167
168    pub fn compute_path(proof: &[Hash], leaf: Hash) -> Vec<Hash> {
169        let mut computed_path = Vec::with_capacity(proof.len() + 1);
170        let mut computed_hash = leaf;
171
172        computed_path.push(computed_hash);
173
174        for proof_element in proof.iter() {
175            computed_hash = Self::hash_left_right(computed_hash, *proof_element);
176            computed_path.push(computed_hash);
177        }
178
179        computed_path
180    }
181
182    pub fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Hash) -> bool {
183        let computed_path = Self::compute_path(proof, leaf);
184        Self::is_valid_path(&computed_path, root)
185    }
186
187    pub fn is_valid_path(path: &[Hash], root: Hash) -> bool {
188        if path.is_empty() {
189            return false;
190        }
191
192        *path.last().unwrap() == root
193    }
194
195    #[cfg(not(target_os = "solana"))]
196    fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
197        // A helper function that hashes all pairs of hashes into a new array of
198        // hashes.
199        let mut res = Vec::with_capacity(pairs.len() / 2);
200
201        for i in (0..pairs.len()).step_by(2) {
202            let left = pairs[i];
203            let right = pairs[i + 1];
204
205            let hashed = Self::hash_left_right(left, right);
206            res.push(hashed);
207        }
208
209        res
210    }
211
212    #[cfg(not(target_os = "solana"))]
213    pub fn get_merkle_proof(&self, values: &[Hash], index: usize) -> Vec<Hash> {
214        let mut layers = Vec::with_capacity(N);
215        let mut current_layer = values.to_vec();
216        for i in 0..N {
217            if current_layer.len() % 2 != 0 {
218                current_layer.push(self.zero_values[i]);
219            }
220
221            layers.push(current_layer.clone());
222            current_layer = Self::hash_pairs(current_layer);
223        }
224
225        // At this point we have all the layers of the merkle tree in an array
226        // of arrays. The next step is to find the siblings of the provided
227        // for_leaf all the way up the tree.
228
229        let mut proof = Vec::with_capacity(N);
230        let mut current_index = index;
231        let mut layer_index = 0;
232        let mut sibling;
233
234        for _ in 0..N {
235            if current_index % 2 == 0 {
236                sibling = layers[layer_index][current_index + 1];
237            } else {
238                sibling = layers[layer_index][current_index - 1];
239            }
240
241            proof.push(sibling);
242
243            current_index /= 2;
244            layer_index += 1;
245        }
246
247        proof
248    }
249
250    fn check_length(&self, proof: &[Hash]) -> Result<(), ProgramError> {
251        check_condition(
252            proof.len() == N,
253            "merkle proof length does not match tree depth",
254        )
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    type TestTree = MerkleTree<3>;

    /// Seeds shared by every test tree.
    const SEEDS: &[&[u8]] = &[b"test"];

    #[test]
    fn test_create_tree() {
        let tree = TestTree::new(SEEDS);

        assert_eq!(tree.get_depth(), 3);
        assert_eq!(tree.next_index, 0);
        assert_eq!(tree.filled_subtrees, tree.zero_values);
        // Current behaviour: a fresh tree's root is the deepest precomputed zero value.
        assert_eq!(tree.get_root(), tree.zero_values[2]);
    }

    #[test]
    fn test_insert_and_remove() {
        let mut tree = TestTree::new(SEEDS);
        let empty = tree.get_empty_leaf();

        let val1 = utils::hash(b"val_1");
        let val2 = utils::hash(b"val_2");
        let val3 = utils::hash(b"val_3");
        let val4 = utils::hash(b"val_4");

        // Tree structure:
        //
        //              root
        //            /     \
        //         m           n
        //       /   \       /   \
        //      i     j     k     l
        //     / \   / \   / \   / \
        //    a  b  c  d  e  f  g  h

        let a = TestTree::as_leaf(val1);
        let b = TestTree::as_leaf(val2);
        let c = TestTree::as_leaf(val3);
        let (d, e, f, g, h) = (empty, empty, empty, empty, empty);

        let i = TestTree::hash_left_right(a, b);
        let j = TestTree::hash_left_right(c, d);
        let k = TestTree::hash_left_right(e, f);
        let l = TestTree::hash_left_right(g, h);
        let m = TestTree::hash_left_right(i, j);
        let n = TestTree::hash_left_right(k, l);
        let root = TestTree::hash_left_right(m, n);

        assert!(tree.try_insert(val1).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_insert(val2).is_ok());
        // Index 1 is a right child, so the cached left node at level 0 is still `a`.
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_insert(val3).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(tree.get_root(), root);

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        // Filled leaves.
        assert!(tree.contains(&val1_proof, val1));
        assert!(tree.contains(&val2_proof, val2));
        assert!(tree.contains(&val3_proof, val3));

        // Empty leaves d..h.
        assert!(tree.contains_leaf(&[c, i, n], empty));
        assert!(tree.contains_leaf(&[f, l, m], empty));
        assert!(tree.contains_leaf(&[e, l, m], empty));
        assert!(tree.contains_leaf(&[h, k, m], empty));
        assert!(tree.contains_leaf(&[g, k, m], empty));

        // Remove val2; its slot becomes the empty leaf.
        assert!(tree.try_remove(&val2_proof, val2).is_ok());

        let i = TestTree::hash_left_right(a, empty);
        let m = TestTree::hash_left_right(i, j);
        let root = TestTree::hash_left_right(m, n);
        assert_eq!(tree.get_root(), root);

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, a));
        assert!(tree.contains_leaf(&val2_proof, empty));
        assert!(tree.contains_leaf(&val3_proof, c));

        // val2 is no longer in the tree.
        assert!(!tree.contains_leaf(&val2_proof, b));

        // Insert val4 at index 3, a right child, so the cached level-0 node stays `c`.
        assert!(tree.try_insert(val4).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        let d = TestTree::as_leaf(val4);
        let j = TestTree::hash_left_right(c, d);
        let m = TestTree::hash_left_right(i, j);
        let root = TestTree::hash_left_right(m, n);
        assert_eq!(tree.get_root(), root);
    }

    #[test]
    fn test_proof() {
        let mut tree = TestTree::new(SEEDS);

        let vals = [
            utils::hash(b"val_1"),
            utils::hash(b"val_2"),
            utils::hash(b"val_3"),
        ];
        let leaves = vals.map(TestTree::as_leaf);

        for &val in &vals {
            assert!(tree.try_insert(val).is_ok());
        }

        for (index, &val) in vals.iter().enumerate() {
            let proof = tree.get_merkle_proof(&leaves, index);
            assert!(tree.contains(&proof, val));
        }

        let val1 = vals[0];
        let val1_proof = tree.get_merkle_proof(&leaves, 0);

        // Proofs whose length differs from the depth are rejected.
        let too_short = &val1_proof[..2];
        let too_long = [&val1_proof[..], &val1_proof[..]].concat();
        assert!(!tree.contains(too_short, val1));
        assert!(!tree.contains(&too_long, val1));

        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, val1));
    }

    #[test]
    fn test_insert_until_full() {
        let mut tree = TestTree::new(SEEDS);

        // Depth 3 holds 2^3 = 8 leaves.
        for n in 0u8..8 {
            assert!(tree.try_insert(utils::hash(&[n])).is_ok());
        }

        assert!(tree.try_insert(utils::hash(b"one_too_many")).is_err());
        assert_eq!(tree.next_index, 8);
    }

    #[test]
    fn test_replace() {
        let val1 = utils::hash(b"val_1");
        let val2 = utils::hash(b"val_2");
        let val3 = utils::hash(b"val_3");

        let mut tree = TestTree::new(SEEDS);
        assert!(tree.try_insert(val1).is_ok());
        assert!(tree.try_insert(val2).is_ok());

        let leaves = [TestTree::as_leaf(val1), TestTree::as_leaf(val2)];
        let proof = tree.get_merkle_proof(&leaves, 0);

        // A proof for the wrong original value is rejected and leaves the root untouched.
        let root_before = tree.get_root();
        assert!(tree.try_replace(&proof, val3, val1).is_err());
        assert_eq!(tree.get_root(), root_before);

        assert!(tree.try_replace(&proof, val1, val3).is_ok());
        assert!(tree.contains(&proof, val3));
        assert!(!tree.contains(&proof, val1));

        // The root matches a tree built from the updated leaves directly.
        let mut expected = TestTree::new(SEEDS);
        assert!(expected.try_insert(val3).is_ok());
        assert!(expected.try_insert(val2).is_ok());
        assert_eq!(tree.get_root(), expected.get_root());
    }
}