1use bytemuck::{ Pod, Zeroable };
2use super::hash::{ Hash, Leaf, hashv };
3use super::error::{ BrineTreeError, ProgramResult };
4use super::utils::check_condition;
5
6#[repr(C, align(8))]
7#[derive(Clone, Copy, PartialEq, Debug)]
8pub struct MerkleTree<const N: usize> {
9 root: Hash,
10 filled_subtrees: [Hash; N],
11 zero_values: [Hash; N],
12 next_index: u64,
13}
14
15unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
16unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19 pub const fn get_depth(&self) -> u8 {
20 N as u8
21 }
22
23 pub const fn get_size() -> usize {
24 core::mem::size_of::<Self>()
25 }
26
27 pub fn get_root(&self) -> Hash {
28 self.root
29 }
30
31 pub fn get_empty_leaf(&self) -> Leaf {
32 self.zero_values[0].as_leaf()
33 }
34
35 pub fn new(seeds: &[&[u8]]) -> Self {
36 let zeros = Self::calc_zeros(seeds);
37 Self {
38 next_index: 0,
39 root: zeros[N - 1],
40 filled_subtrees: zeros,
41 zero_values: zeros,
42 }
43 }
44
45 pub fn init(&mut self, seeds: &[&[u8]]) {
46 let zeros = Self::calc_zeros(seeds);
47 self.next_index = 0;
48 self.root = zeros[N - 1];
49 self.filled_subtrees = zeros;
50 self.zero_values = zeros;
51 }
52
53 fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
54 let mut zeros: [Hash; N] = [Hash::default(); N];
55 let mut current = hashv(seeds);
56
57 for i in 0..N {
58 zeros[i] = current;
59 current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
60 }
61
62 zeros
63 }
64
65 pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
66 let leaf = Leaf::new(data);
67 self.try_add_leaf(leaf)
68 }
69
70 pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
71 check_condition(
72 self.next_index < (1u64 << N),
73 BrineTreeError::TreeFull,
74 )?;
75
76 let mut current_index = self.next_index;
77 let mut current_hash = Hash::from(leaf);
78 let mut left;
79 let mut right;
80
81 for i in 0..N {
82 if current_index % 2 == 0 {
83 left = current_hash;
84 right = self.zero_values[i];
85 self.filled_subtrees[i] = current_hash;
86 } else {
87 left = self.filled_subtrees[i];
88 right = current_hash;
89 }
90
91 current_hash = hash_left_right(left, right);
92 current_index /= 2;
93 }
94
95 self.root = current_hash;
96 self.next_index += 1;
97
98 Ok(())
99 }
100
101 pub fn try_remove(&mut self, proof: &[Hash], data: &[&[u8]]) -> ProgramResult {
102 let leaf = Leaf::new(data);
103 self.try_remove_leaf(proof, leaf)
104 }
105
106 pub fn try_remove_leaf(&mut self, proof: &[Hash], leaf: Leaf) -> ProgramResult {
107 self.check_length(proof)?;
108 self.try_replace_leaf(proof, leaf, self.get_empty_leaf())
109 }
110
111 pub fn try_replace(&mut self, proof: &[Hash], original_data: &[&[u8]], new_data: &[&[u8]]) -> ProgramResult {
112 let original_leaf = Leaf::new(original_data);
113 let new_leaf = Leaf::new(new_data);
114 self.try_replace_leaf(proof, original_leaf, new_leaf)
115 }
116
117 pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Leaf, new_leaf: Leaf) -> ProgramResult {
118 self.check_length(proof)?;
119
120 let original_path = compute_path(proof, original_leaf);
121 let new_path = compute_path(proof, new_leaf);
122
123 check_condition(
124 is_valid_path(&original_path, self.root),
125 BrineTreeError::InvalidProof,
126 )?;
127
128 for i in 0..N {
129 if original_path[i] == self.filled_subtrees[i] {
130 self.filled_subtrees[i] = new_path[i];
131 }
132 }
133
134 self.root = *new_path.last().unwrap();
135
136 Ok(())
137 }
138
139 pub fn contains(&self, proof: &[Hash], data: &[&[u8]]) -> bool {
140 let leaf = Leaf::new(data);
141 self.contains_leaf(proof, leaf)
142 }
143
144 pub fn contains_leaf(&self, proof: &[Hash], leaf: Leaf) -> bool {
145 if let Err(_) = self.check_length(proof) {
146 return false;
147 }
148
149 let root = self.get_root();
150 is_valid_leaf(proof, root, leaf)
151 }
152
153 fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
154 check_condition(proof.len() == N, BrineTreeError::ProofLength)
155 }
156
157 #[cfg(not(feature = "solana"))]
158 fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
159 let mut res = Vec::with_capacity(pairs.len() / 2);
160
161 for i in (0..pairs.len()).step_by(2) {
162 let left = pairs[i];
163 let right = pairs[i + 1];
164
165 let hashed = hash_left_right(left, right);
166 res.push(hashed);
167 }
168
169 res
170 }
171
172 #[cfg(not(feature = "solana"))]
173 pub fn get_merkle_proof(&self, values: &[Leaf], index: usize) -> Vec<Hash> {
174 let mut layers = Vec::with_capacity(N);
175 let mut current_layer: Vec<Hash> = values.iter().map(|leaf| Hash::from(*leaf)).collect();
176
177 for i in 0..N {
178 if current_layer.len() % 2 != 0 {
179 current_layer.push(self.zero_values[i]);
180 }
181
182 layers.push(current_layer.clone());
183 current_layer = Self::hash_pairs(current_layer);
184 }
185
186 let mut proof = Vec::with_capacity(N);
187 let mut current_index = index;
188 let mut layer_index = 0;
189 let mut sibling;
190
191 for _ in 0..N {
192 if current_index % 2 == 0 {
193 sibling = layers[layer_index][current_index + 1];
194 } else {
195 sibling = layers[layer_index][current_index - 1];
196 }
197
198 proof.push(sibling);
199
200 current_index /= 2;
201 layer_index += 1;
202 }
203
204 proof
205 }
206}
207
208fn hash_left_right(left: Hash, right: Hash) -> Hash {
209 let combined;
210 if left.to_bytes() <= right.to_bytes() {
211 combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
212 } else {
213 combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
214 }
215
216 hashv(&combined)
217}
218
219fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
220 let mut computed_path = Vec::with_capacity(proof.len() + 1);
221 let mut computed_hash = Hash::from(leaf);
222
223 computed_path.push(computed_hash);
224
225 for proof_element in proof.iter() {
226 computed_hash = hash_left_right(computed_hash, *proof_element);
227 computed_path.push(computed_hash);
228 }
229
230 computed_path
231}
232
233fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
234 let computed_path = compute_path(proof, leaf);
235 is_valid_path(&computed_path, root)
236}
237
238fn is_valid_path(path: &[Hash], root: Hash) -> bool {
239 if path.is_empty() {
240 return false;
241 }
242
243 *path.last().unwrap() == root
244}
245
246pub fn verify<const N: usize>(root: Hash, proof: &[Hash], leaf: Leaf) -> bool {
248 if proof.len() != N {
249 return false;
250 }
251
252 let computed_path = compute_path(proof, leaf);
253 is_valid_path(&computed_path, root)
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    type TestTree = MerkleTree<3>;

    /// Seeds shared by every test tree.
    const SEEDS: &[&[u8]] = &[b"test"];

    #[test]
    fn test_create_tree() {
        let tree = TestTree::new(SEEDS);

        assert_eq!(tree.get_depth(), 3);
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let mut tree = TestTree::new(SEEDS);
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Leaf level: three values followed by five empty slots.
        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));
        let (d, e, f, g, h) = (empty, empty, empty, empty, empty);

        // Interior levels, bottom-up.
        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);
        let k = hash_left_right(e, f);
        let l = hash_left_right(g, h);
        let m = hash_left_right(i, j);
        let n = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(tree.get_root(), root);

        let val1_proof = vec![b, j, n];
        let val2_proof = vec![a, j, n];
        let val3_proof = vec![d, i, n];

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Every still-empty slot is provable as the empty leaf.
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Removing val_2 puts the empty leaf in its slot.
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        let i = hash_left_right(a, empty);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);
        assert_eq!(tree.get_root(), root);

        let val1_proof = vec![empty, j, n];
        let val3_proof = vec![d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // The next append lands in slot 3 and keeps the removal intact.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);
        assert_eq!(tree.get_root(), root);
    }

    #[test]
    fn test_proof() {
        let mut tree = TestTree::new(SEEDS);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let proofs: Vec<Vec<Hash>> = (0..leaves.len())
            .map(|index| tree.get_merkle_proof(&leaves, index))
            .collect();

        assert!(tree.contains(&proofs[0], &[b"val_1"]));
        assert!(tree.contains(&proofs[1], &[b"val_2"]));
        assert!(tree.contains(&proofs[2], &[b"val_3"]));

        // Proofs of the wrong length are rejected outright.
        let too_short = &proofs[0][..2];
        let too_long = [&proofs[0][..], &proofs[0][..]].concat();
        let empty_proof: Vec<Hash> = Vec::new();

        assert!(!tree.contains(too_short, &[b"val_1"]));
        assert!(!tree.contains(&too_long, &[b"val_1"]));
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let mut tree = TestTree::new(SEEDS);
        // `MerkleTree` is `Copy`, so this snapshots the freshly created state.
        let pristine = tree;

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        tree.init(SEEDS);

        assert_eq!(tree.get_root(), pristine.get_root());
        assert_eq!(tree.zero_values, pristine.zero_values);
        assert_eq!(tree.filled_subtrees, pristine.filled_subtrees);
        assert_eq!(tree.next_index, pristine.next_index);
    }

    #[test]
    fn test_tree_full() {
        let mut tree = TestTree::new(SEEDS);

        // Depth 3 holds exactly 2^3 = 8 leaves.
        for byte in 0u8..8 {
            assert!(tree.try_add(&[&[byte]]).is_ok());
        }

        let overflow = tree.try_add(&[b"extra"]);
        assert_eq!(overflow.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let mut tree = TestTree::new(SEEDS);
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_merkle_proof(&leaves, 0);

        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        let updated = [Leaf::new(&[b"new_val"]), leaves[1]];
        let proof_val2 = tree.get_merkle_proof(&updated, 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_verify() {
        let mut tree = TestTree::new(SEEDS);
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_merkle_proof(&leaves, 0);

        assert!(verify::<3>(tree.get_root(), &proof, leaves[0]));
    }
}
493}