1use super::{
2 error::{BrineTreeError, ProgramResult},
3 hash::{hashv, Hash, Leaf},
4 utils::check_condition,
5};
6use bytemuck::{Pod, Zeroable};
7
/// Fixed-depth, append-only incremental Merkle tree whose whole state lives in
/// inline fixed-size arrays, so it can be viewed as raw bytes through `bytemuck`
/// (see the `Pod`/`Zeroable` impls).
///
/// `N` is the depth: the tree accepts up to `2^N` leaves, and a root is obtained
/// by hashing `N` times upward from a leaf.
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    pub root: Hash,
    /// `filled_subtrees[i]` is the most recently written left-hand node at level
    /// `i` (level 0 = leaves). It is the left sibling used when the next append
    /// arrives as a right child at that level. Initialised to the zero values.
    pub filled_subtrees: [Hash; N],
    /// `zero_values[i]` is the hash of a completely empty subtree rooted at level
    /// `i`. `zero_values[0]` doubles as the "empty leaf" marker.
    pub zero_values: [Hash; N],
    /// Index of the next append, i.e. the number of leaves appended so far.
    /// Removals do not decrease it.
    pub next_index: u64,
}
16
// SAFETY: NOTE(review) — these manual impls (rather than derives) are only sound
// if every field is itself `Pod` and the `#[repr(C)]` layout has no padding.
// `root`, `filled_subtrees` and `zero_values` are built from `Hash`, and the
// trailing field is a `u64`. There is no padding when `Hash` is a `Pod` type
// whose size is a multiple of 8 bytes (e.g. a `[u8; 32]` wrapper) and whose
// alignment is at most 8. That is assumed here — confirm against the `hash`
// module before changing `Hash` or reordering fields.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
19
20impl<const N: usize> MerkleTree<N> {
21
22 pub fn new(seeds: &[&[u8]]) -> Self {
23 let zeros = Self::calc_zeros(seeds);
24 Self {
25 next_index: 0,
26 root: zeros[N - 1],
27 filled_subtrees: zeros,
28 zero_values: zeros,
29 }
30 }
31
32 pub const fn get_depth(&self) -> u8 {
33 N as u8
34 }
35
36 pub const fn get_size() -> usize {
37 core::mem::size_of::<Self>()
38 }
39
40 pub fn get_root(&self) -> Hash {
41 self.root
42 }
43
44 pub fn get_empty_leaf(&self) -> Leaf {
45 self.zero_values[0].as_leaf()
46 }
47
48 pub fn init(&mut self, seeds: &[&[u8]]) {
49 let zeros = Self::calc_zeros(seeds);
50 self.next_index = 0;
51 self.root = zeros[N - 1];
52 self.filled_subtrees = zeros;
53 self.zero_values = zeros;
54 }
55
56 pub fn get_leaf_count(&self) -> u64 {
58 self.next_index
59 }
60
61 pub fn get_capacity(&self) -> u64 {
63 1u64 << N
64 }
65
66 fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
68 let mut zeros: [Hash; N] = [Hash::default(); N];
69 let mut current = hashv(seeds);
70
71 for i in 0..N {
72 zeros[i] = current;
73 current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
74 }
75
76 zeros
77 }
78
79 pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
81 let leaf = Leaf::new(data);
82 self.try_add_leaf(leaf)
83 }
84
85 pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
87 check_condition(self.next_index < (1u64 << N), BrineTreeError::TreeFull)?;
88
89 let mut current_index = self.next_index;
90 let mut current_hash = Hash::from(leaf);
91 let mut left;
92 let mut right;
93
94 for i in 0..N {
95 if current_index % 2 == 0 {
96 left = current_hash;
97 right = self.zero_values[i];
98 self.filled_subtrees[i] = current_hash;
99 } else {
100 left = self.filled_subtrees[i];
101 right = current_hash;
102 }
103
104 current_hash = hash_left_right(left, right);
105 current_index /= 2;
106 }
107
108 self.root = current_hash;
109 self.next_index += 1;
110
111 Ok(())
112 }
113
114 pub fn try_remove<P>(&mut self, proof: &[P], data: &[&[u8]]) -> ProgramResult
116 where
117 P: Into<Hash> + Copy,
118 {
119 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
120 let original_leaf = Leaf::new(data);
121 self.try_remove_leaf(&proof_hashes, original_leaf)
122 }
123
124 pub fn try_remove_leaf<P>(&mut self, proof: &[P], leaf: Leaf) -> ProgramResult
126 where
127 P: Into<Hash> + Copy,
128 {
129 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
130 self.check_length(&proof_hashes)?;
131 self.try_replace_leaf(&proof_hashes, leaf, self.get_empty_leaf())
132 }
133
134 pub fn try_replace<P>(
136 &mut self,
137 proof: &[P],
138 original_data: &[&[u8]],
139 new_data: &[&[u8]],
140 ) -> ProgramResult
141 where
142 P: Into<Hash> + Copy,
143 {
144 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
145 let original_leaf = Leaf::new(original_data);
146 let new_leaf = Leaf::new(new_data);
147 self.try_replace_leaf(&proof_hashes, original_leaf, new_leaf)
148 }
149
150 pub fn try_replace_leaf<P>(
152 &mut self,
153 proof: &[P],
154 original_leaf: Leaf,
155 new_leaf: Leaf,
156 ) -> ProgramResult
157 where
158 P: Into<Hash> + Copy,
159 {
160 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
161 self.check_length(&proof_hashes)?;
162 let original_path = compute_path(&proof_hashes, original_leaf);
163 let new_path = compute_path(&proof_hashes, new_leaf);
164 check_condition(
165 is_valid_path(&original_path, self.root),
166 BrineTreeError::InvalidProof,
167 )?;
168 for i in 0..N {
169 if original_path[i] == self.filled_subtrees[i] {
170 self.filled_subtrees[i] = new_path[i];
171 }
172 }
173 self.root = *new_path.last().unwrap();
174 Ok(())
175 }
176
177 pub fn contains<P>(&self, proof: &[P], data: &[&[u8]]) -> bool
179 where
180 P: Into<Hash> + Copy,
181 {
182 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
183 let leaf = Leaf::new(data);
184 self.contains_leaf(&proof_hashes, leaf)
185 }
186
187 pub fn contains_leaf<P>(&self, proof: &[P], leaf: Leaf) -> bool
189 where
190 P: Into<Hash> + Copy,
191 {
192 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
193 if self.check_length(&proof_hashes).is_err() {
194 return false;
195 }
196 is_valid_leaf(&proof_hashes, self.root, leaf)
197 }
198
199 fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
201 check_condition(proof.len() == N, BrineTreeError::ProofLength)
202 }
203
204 pub fn get_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
206 get_merkle_proof(leaves, &self.zero_values, leaf_index, N)
207 }
208
209 pub fn get_layer_nodes(&self, leaves: &[Leaf], layer_number: usize) -> Vec<Hash> {
212 if layer_number > N {
213 return vec![];
214 }
215
216 let valid_leaves = leaves
217 .iter()
218 .take(self.next_index as usize)
219 .copied()
220 .collect::<Vec<Leaf>>();
221
222 let mut current_layer: Vec<Hash> =
223 valid_leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
224
225 if current_layer.is_empty() || layer_number == 0 {
226 return current_layer;
227 }
228
229 let mut current_level: usize = 0;
230 loop {
231 if current_layer.is_empty() {
232 break;
233 }
234 let mut next_layer = Vec::with_capacity(current_layer.len().div_ceil(2));
235 let mut i = 0;
236 while i < current_layer.len() {
237 if i + 1 < current_layer.len() {
238 let val = hash_left_right(current_layer[i], current_layer[i + 1]);
239 next_layer.push(val);
240 i += 2;
241 } else {
242 let val = hash_left_right(current_layer[i], self.zero_values[current_level]);
243 next_layer.push(val);
244 i += 1;
245 }
246 }
247 current_level += 1;
248 if current_level == layer_number {
249 return next_layer;
250 }
251 current_layer = next_layer;
252 }
253 vec![]
254 }
255}
256
257pub fn get_merkle_proof(
259 leaves: &[Leaf],
260 zero_values: &[Hash],
261 leaf_index: usize,
262 height: usize,
263) -> Vec<Hash> {
264 let mut layers = Vec::with_capacity(height);
265 let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
266
267 for i in 0..height {
268 if current_layer.len() % 2 != 0 {
269 current_layer.push(zero_values[i]);
270 }
271
272 layers.push(current_layer.clone());
273 current_layer = hash_pairs(current_layer);
274 }
275
276 let mut proof = Vec::with_capacity(height);
277 let mut current_index = leaf_index;
278 let mut layer_index = 0;
279
280 for _ in 0..height {
281 let sibling = if current_index % 2 == 0 {
282 layers[layer_index][current_index + 1]
283 } else {
284 layers[layer_index][current_index - 1]
285 };
286
287 proof.push(sibling);
288
289 current_index /= 2;
290 layer_index += 1;
291 }
292
293 proof
294}
295
296pub fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
298 let mut res = Vec::with_capacity(pairs.len() / 2);
299
300 for i in (0..pairs.len()).step_by(2) {
301 let left = pairs[i];
302 let right = pairs[i + 1];
303
304 let hashed = hash_left_right(left, right);
305 res.push(hashed);
306 }
307
308 res
309}
310
311pub fn hash_left_right(left: Hash, right: Hash) -> Hash {
313 let combined;
314 if left.to_bytes() <= right.to_bytes() {
315 combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
316 } else {
317 combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
318 }
319
320 hashv(&combined)
321}
322
323pub fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
325 let mut computed_path = Vec::with_capacity(proof.len() + 1);
326 let mut computed_hash = Hash::from(leaf);
327
328 computed_path.push(computed_hash);
329
330 for proof_element in proof.iter() {
331 computed_hash = hash_left_right(computed_hash, *proof_element);
332 computed_path.push(computed_hash);
333 }
334
335 computed_path
336}
337
338fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
339 let computed_path = compute_path(proof, leaf);
340 is_valid_path(&computed_path, root)
341}
342
343fn is_valid_path(path: &[Hash], root: Hash) -> bool {
344 if path.is_empty() {
345 return false;
346 }
347
348 *path.last().unwrap() == root
349}
350
351pub fn verify<Root, Item, L>(root: Root, proof: &[Item], leaf: L) -> bool
353where
354 Root: Into<Hash>,
355 Item: Into<Hash> + Copy,
356 L: Into<Leaf>,
357{
358 let root_h: Hash = root.into();
359 let proof_hashes: Vec<Hash> = proof.iter().map(|&x| x.into()).collect();
360
361 let leaf_h: Leaf = leaf.into();
362 let path = compute_path(&proof_hashes, leaf_h);
363 is_valid_path(&path, root_h)
364}
365
366
#[cfg(test)]
mod tests {
    use super::*;

    // Depth-3 tree: 8 leaf slots, small enough to derive expected hashes by hand.
    type TestTree = MerkleTree<3>;

    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        // A fresh tree's root is the last stored zero value (level N - 1).
        assert_eq!(tree.get_root(), tree.zero_values.last().unwrap().clone());
    }

    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);
        let empty = *tree.zero_values.first().unwrap();
        let empty_leaf = empty.as_leaf();

        // Expected tree after appending val_1..val_3 (unused slots are `empty`):
        //
        //                 root
        //          m               n
        //      i       j       k       l
        //    a   b   c   d   e   f   g   h
        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        let d = empty;
        let e = empty;
        let f = empty;
        let g = empty;
        let h = empty;

        let i = hash_left_right(a, b);
        let j: Hash = hash_left_right(c, d);
        let k: Hash = hash_left_right(e, f);
        let l: Hash = hash_left_right(g, h);
        let m: Hash = hash_left_right(i, j);
        let n: Hash = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&a));

        // Index 1 is a right child, so the cached left leaf is still `a`.
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&a));
        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&c));
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        // Proofs list one sibling per level, starting at the leaf level.
        let val1_proof = vec![b, j, n];
        let val2_proof = vec![a, j, n];
        let val3_proof = vec![d, i, n];

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Unused slots prove as the empty leaf.
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Removing val_2 replaces slot 1 with the empty leaf.
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        let i = hash_left_right(a, empty);
        let m: Hash = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = vec![empty, j, n];
        let val3_proof = vec![d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Removal does not free slot 1: the next append lands in slot 3, next to val_3.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert!(tree.filled_subtrees[0].eq(&c));
        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];

        let mut tree = TestTree::new(seeds);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        // Proofs generated off-tree from the leaf list must verify on-tree.
        let val1_proof = tree.get_proof(&leaves, 0);
        let val2_proof = tree.get_proof(&leaves, 1);
        let val3_proof = tree.get_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Proofs whose length is not N are rejected.
        let invalid_proof_short = &val1_proof[..2];
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat();
        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Snapshot the freshly constructed state.
        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Re-initialising with the same seeds must restore that exact state.
        tree.init(seeds);

        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // Fill all 2^3 = 8 slots.
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_proof(&leaves, 0);

        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // The siblings of slot 0 are unchanged, so the same proof now proves new_val.
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        // val_2's proof must be rebuilt from the updated leaf list.
        let proof_val2 = tree.get_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_verify() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_proof(&leaves, 0);

        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));

        // `verify` also accepts raw 32-byte arrays for root, proof and leaf.
        let a: [u8; 32] = tree.get_root().to_bytes();
        let b: [[u8; 32]; 3] = [
            proof[0].to_bytes(),
            proof[1].to_bytes(),
            proof[2].to_bytes(),
        ];
        let c: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();

        assert!(verify(a, &b, c));
    }

    #[test]
    fn test_get_layer_nodes() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
            Leaf::new(&[b"val_4"]),
        ];

        // Nothing appended yet: only the first next_index leaves count.
        assert_eq!(tree.get_layer_nodes(&leaves, 0), vec![]);
        assert_eq!(tree.get_layer_nodes(&leaves, 1), vec![]);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let a = Hash::from(leaves[0]);
        let b = Hash::from(leaves[1]);
        let c = Hash::from(leaves[2]);
        let d = empty;
        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);

        // Leaves beyond next_index (val_4) are ignored.
        let layer_0 = tree.get_layer_nodes(&leaves, 0);
        assert_eq!(layer_0, vec![a, b, c]);

        // The odd trailing node pairs with the level-0 empty value.
        let layer_1 = tree.get_layer_nodes(&leaves, 1);
        assert_eq!(layer_1, vec![i, j]);

        let layer_2 = tree.get_layer_nodes(&leaves, 2);
        let m = hash_left_right(i, j);
        assert_eq!(layer_2, vec![m]);

        // Level N is the root.
        let layer_3 = tree.get_layer_nodes(&leaves, 3);
        assert_eq!(layer_3, vec![tree.get_root()]);

        // Levels above N are empty.
        let layer_4 = tree.get_layer_nodes(&leaves, 4);
        assert_eq!(layer_4, vec![]);

        assert!(tree.try_add(&[b"val_4"]).is_ok());
        let d = Hash::from(leaves[3]);
        let j = hash_left_right(c, d);

        let layer_0 = tree.get_layer_nodes(&leaves, 0);
        assert_eq!(layer_0, vec![a, b, c, d]);

        let layer_1 = tree.get_layer_nodes(&leaves, 1);
        assert_eq!(layer_1, vec![i, j]);
    }
}