1use super::{
2 error::{BrineTreeError, ProgramResult},
3 hash::{hashv, Hash, Leaf},
4 utils::check_condition,
5};
6use bytemuck::{Pod, Zeroable};
7
/// Fixed-depth, append-only Merkle tree with `N` levels of hashing between a
/// leaf and the root.
///
/// Only the data needed to append the next leaf is stored (one hash per
/// level); the leaves themselves are not kept here.
///
/// The field order is part of the `#[repr(C)]` layout this type exposes, so
/// it must not be rearranged.
#[repr(C)]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash of the tree.
    pub root: Hash,
    /// Per level, the hash last written for a left-hand node on the insertion
    /// path; read back when the matching right-hand node is inserted.
    pub filled_subtrees: [Hash; N],
    /// Per level, the hash of an empty subtree; index 0 is the empty leaf.
    pub zero_values: [Hash; N],
    /// Position the next appended leaf will occupy (also the number of
    /// successful appends so far).
    pub next_index: u64,
}
16
// SAFETY: NOTE(review) — `Hash` is declared outside this file. These impls are
// sound only if `Hash` is itself `Zeroable`/`Pod` and the `#[repr(C)]` layout of
// `MerkleTree<N>` contains no padding bytes for every `N` that is instantiated
// (i.e. the `u64` field lands on an 8-byte boundary). Confirm against the
// definition of `Hash`.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
19
20impl<const N: usize> MerkleTree<N> {
21
22 pub fn new(seeds: &[&[u8]]) -> Self {
23 let zeros = Self::calc_zeros(seeds);
24 Self {
25 next_index: 0,
26 root: zeros[N - 1],
27 filled_subtrees: zeros,
28 zero_values: zeros,
29 }
30 }
31
32 pub fn from_zeros(zeros: [Hash; N]) -> Self {
33 Self {
34 next_index: 0,
35 root: zeros[N - 1],
36 filled_subtrees: zeros,
37 zero_values: zeros,
38 }
39 }
40
41 pub const fn get_depth(&self) -> u8 {
42 N as u8
43 }
44
45 pub const fn get_size() -> usize {
46 core::mem::size_of::<Self>()
47 }
48
49 pub fn get_root(&self) -> Hash {
50 self.root
51 }
52
53 pub fn get_empty_leaf(&self) -> Leaf {
54 self.zero_values[0].as_leaf()
55 }
56
57 pub fn init(&mut self, seeds: &[&[u8]]) {
58 let zeros = Self::calc_zeros(seeds);
59 self.next_index = 0;
60 self.root = zeros[N - 1];
61 self.filled_subtrees = zeros;
62 self.zero_values = zeros;
63 }
64
65 pub fn get_leaf_count(&self) -> u64 {
67 self.next_index
68 }
69
70 pub fn get_capacity(&self) -> u64 {
72 1u64 << N
73 }
74
75 fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
77 let mut zeros: [Hash; N] = [Hash::default(); N];
78 let mut current = hashv(seeds);
79
80 for i in 0..N {
81 zeros[i] = current;
82 current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
83 }
84
85 zeros
86 }
87
88 pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
90 let leaf = Leaf::new(data);
91 self.try_add_leaf(leaf)
92 }
93
94 pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
96 check_condition(self.next_index < (1u64 << N), BrineTreeError::TreeFull)?;
97
98 let mut current_index = self.next_index;
99 let mut current_hash = Hash::from(leaf);
100 let mut left;
101 let mut right;
102
103 for i in 0..N {
104 if current_index % 2 == 0 {
105 left = current_hash;
106 right = self.zero_values[i];
107 self.filled_subtrees[i] = current_hash;
108 } else {
109 left = self.filled_subtrees[i];
110 right = current_hash;
111 }
112
113 current_hash = hash_left_right(left, right);
114 current_index /= 2;
115 }
116
117 self.root = current_hash;
118 self.next_index += 1;
119
120 Ok(())
121 }
122
123 pub fn try_remove<P>(&mut self, proof: &[P], data: &[&[u8]]) -> ProgramResult
125 where
126 P: Into<Hash> + Copy,
127 {
128 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
129 let original_leaf = Leaf::new(data);
130 self.try_remove_leaf(&proof_hashes, original_leaf)
131 }
132
133 pub fn try_remove_leaf<P>(&mut self, proof: &[P], leaf: Leaf) -> ProgramResult
135 where
136 P: Into<Hash> + Copy,
137 {
138 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
139 self.check_length(&proof_hashes)?;
140 self.try_replace_leaf(&proof_hashes, leaf, self.get_empty_leaf())
141 }
142
143 pub fn try_replace<P>(
145 &mut self,
146 proof: &[P],
147 original_data: &[&[u8]],
148 new_data: &[&[u8]],
149 ) -> ProgramResult
150 where
151 P: Into<Hash> + Copy,
152 {
153 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
154 let original_leaf = Leaf::new(original_data);
155 let new_leaf = Leaf::new(new_data);
156 self.try_replace_leaf(&proof_hashes, original_leaf, new_leaf)
157 }
158
159 pub fn try_replace_leaf<P>(
161 &mut self,
162 proof: &[P],
163 original_leaf: Leaf,
164 new_leaf: Leaf,
165 ) -> ProgramResult
166 where
167 P: Into<Hash> + Copy,
168 {
169 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
170 self.check_length(&proof_hashes)?;
171 let original_path = compute_path(&proof_hashes, original_leaf);
172 let new_path = compute_path(&proof_hashes, new_leaf);
173 check_condition(
174 is_valid_path(&original_path, self.root),
175 BrineTreeError::InvalidProof,
176 )?;
177 for i in 0..N {
178 if original_path[i] == self.filled_subtrees[i] {
179 self.filled_subtrees[i] = new_path[i];
180 }
181 }
182 self.root = *new_path.last().unwrap();
183 Ok(())
184 }
185
186 pub fn contains<P>(&self, proof: &[P], data: &[&[u8]]) -> bool
188 where
189 P: Into<Hash> + Copy,
190 {
191 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
192 let leaf = Leaf::new(data);
193 self.contains_leaf(&proof_hashes, leaf)
194 }
195
196 pub fn contains_leaf<P>(&self, proof: &[P], leaf: Leaf) -> bool
198 where
199 P: Into<Hash> + Copy,
200 {
201 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
202 if self.check_length(&proof_hashes).is_err() {
203 return false;
204 }
205 is_valid_leaf(&proof_hashes, self.root, leaf)
206 }
207
208 fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
210 check_condition(proof.len() == N, BrineTreeError::ProofLength)
211 }
212
213 pub fn get_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
215 get_merkle_proof(leaves, &self.zero_values, leaf_index, N)
216 }
217
218 pub fn get_layer_nodes(&self, leaves: &[Leaf], layer_number: usize) -> Vec<Hash> {
221 if layer_number > N {
222 return vec![];
223 }
224
225 let valid_leaves = leaves
226 .iter()
227 .take(self.next_index as usize)
228 .copied()
229 .collect::<Vec<Leaf>>();
230
231 let mut current_layer: Vec<Hash> =
232 valid_leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
233
234 if current_layer.is_empty() || layer_number == 0 {
235 return current_layer;
236 }
237
238 let mut current_level: usize = 0;
239 loop {
240 if current_layer.is_empty() {
241 break;
242 }
243 let mut next_layer = Vec::with_capacity(current_layer.len().div_ceil(2));
244 let mut i = 0;
245 while i < current_layer.len() {
246 if i + 1 < current_layer.len() {
247 let val = hash_left_right(current_layer[i], current_layer[i + 1]);
248 next_layer.push(val);
249 i += 2;
250 } else {
251 let val = hash_left_right(current_layer[i], self.zero_values[current_level]);
252 next_layer.push(val);
253 i += 1;
254 }
255 }
256 current_level += 1;
257 if current_level == layer_number {
258 return next_layer;
259 }
260 current_layer = next_layer;
261 }
262 vec![]
263 }
264}
265
266pub fn get_merkle_proof(
268 leaves: &[Leaf],
269 zero_values: &[Hash],
270 leaf_index: usize,
271 height: usize,
272) -> Vec<Hash> {
273 let mut layers = Vec::with_capacity(height);
274 let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
275
276 for i in 0..height {
277 if current_layer.len() % 2 != 0 {
278 current_layer.push(zero_values[i]);
279 }
280
281 layers.push(current_layer.clone());
282 current_layer = hash_pairs(current_layer);
283 }
284
285 let mut proof = Vec::with_capacity(height);
286 let mut current_index = leaf_index;
287 let mut layer_index = 0;
288
289 for _ in 0..height {
290 let sibling = if current_index % 2 == 0 {
291 layers[layer_index][current_index + 1]
292 } else {
293 layers[layer_index][current_index - 1]
294 };
295
296 proof.push(sibling);
297
298 current_index /= 2;
299 layer_index += 1;
300 }
301
302 proof
303}
304
305pub fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
307 let mut res = Vec::with_capacity(pairs.len() / 2);
308
309 for i in (0..pairs.len()).step_by(2) {
310 let left = pairs[i];
311 let right = pairs[i + 1];
312
313 let hashed = hash_left_right(left, right);
314 res.push(hashed);
315 }
316
317 res
318}
319
320pub fn hash_left_right(left: Hash, right: Hash) -> Hash {
322 let combined;
323 if left.to_bytes() <= right.to_bytes() {
324 combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
325 } else {
326 combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
327 }
328
329 hashv(&combined)
330}
331
332pub fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
334 let mut computed_path = Vec::with_capacity(proof.len() + 1);
335 let mut computed_hash = Hash::from(leaf);
336
337 computed_path.push(computed_hash);
338
339 for proof_element in proof.iter() {
340 computed_hash = hash_left_right(computed_hash, *proof_element);
341 computed_path.push(computed_hash);
342 }
343
344 computed_path
345}
346
347fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
348 let computed_path = compute_path(proof, leaf);
349 is_valid_path(&computed_path, root)
350}
351
352fn is_valid_path(path: &[Hash], root: Hash) -> bool {
353 if path.is_empty() {
354 return false;
355 }
356
357 *path.last().unwrap() == root
358}
359
360pub fn verify<Root, Item, L>(root: Root, proof: &[Item], leaf: L) -> bool
362where
363 Root: Into<Hash>,
364 Item: Into<Hash> + Copy,
365 L: Into<Leaf>,
366{
367 let root_h: Hash = root.into();
368 let proof_hashes: Vec<Hash> = proof.iter().map(|&x| x.into()).collect();
369
370 let leaf_h: Leaf = leaf.into();
371 let path = compute_path(&proof_hashes, leaf_h);
372 is_valid_path(&path, root_h)
373}
374
375
#[cfg(test)]
mod tests {
    use super::*;

    /// Depth-3 tree: 8 leaf positions.
    type TestTree = MerkleTree<3>;

    /// Seed material shared by every test.
    const SEEDS: &[&[u8]] = &[b"test"];

    #[test]
    fn test_create_tree() {
        let tree = TestTree::new(SEEDS);

        assert_eq!(tree.get_depth(), 3);
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let mut tree = TestTree::new(SEEDS);
        let empty = *tree.zero_values.first().unwrap();
        let empty_leaf = empty.as_leaf();

        // Expected tree; positions d..h start out empty.
        //
        //               root
        //            /        \
        //          m            n
        //        /   \        /   \
        //       i     j      k     l
        //      / \   / \    / \   / \
        //     a   b c   d  e   f g   h
        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));

        let d = empty;
        let e = empty;
        let f = empty;
        let g = empty;
        let h = empty;

        let i = hash_left_right(a, b);
        let j: Hash = hash_left_right(c, d);
        let k: Hash = hash_left_right(e, f);
        let l: Hash = hash_left_right(g, h);
        let m: Hash = hash_left_right(i, j);
        let n: Hash = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        // A right-hand insert does not touch the cached left node.
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = vec![b, j, n];
        let val2_proof = vec![a, j, n];
        let val3_proof = vec![d, i, n];

        // Populated leaves are provable.
        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Empty positions d..h are provable as the empty leaf.
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Remove val_2: position b becomes the empty leaf.
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        let i = hash_left_right(a, empty);
        let m: Hash = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());

        let val1_proof = vec![empty, j, n];
        let val3_proof = vec![d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));

        // The removed value is no longer provable.
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // The next insert lands in position d, not in the vacated position b.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);

        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_proof() {
        let mut tree = TestTree::new(SEEDS);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_proof(&leaves, 0);
        let val2_proof = tree.get_proof(&leaves, 1);
        let val3_proof = tree.get_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Proofs of the wrong length must be rejected.
        let invalid_proof_short = &val1_proof[..2];
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat();
        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let mut tree = TestTree::new(SEEDS);

        // Snapshot the pristine state.
        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        // Mutate, then reset.
        assert!(tree.try_add(&[b"val_1"]).is_ok());
        tree.init(SEEDS);

        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let mut tree = TestTree::new(SEEDS);

        // Fill every position (2^3 = 8 for the depth-3 test tree).
        assert_eq!(tree.get_capacity(), 8);
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        // One more must fail with TreeFull; `unwrap_err` also fails the test
        // if the call unexpectedly succeeds.
        let result = tree.try_add(&[b"extra"]);
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let mut tree = TestTree::new(SEEDS);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_proof(&leaves, 0);

        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // Same position, same siblings: the proof now matches the new value only.
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        // The untouched neighbour is still provable with a refreshed proof.
        let proof_val2 = tree.get_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_verify() {
        let mut tree = TestTree::new(SEEDS);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_proof(&leaves, 0);

        // Typed inputs.
        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));

        // Raw byte-array inputs.
        let a: [u8; 32] = tree.get_root().to_bytes();
        let b: [[u8; 32]; 3] = [
            proof[0].to_bytes(),
            proof[1].to_bytes(),
            proof[2].to_bytes(),
        ];
        let c: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();

        assert!(verify(a, &b, c));
    }

    #[test]
    fn test_get_layer_nodes() {
        let mut tree = TestTree::new(SEEDS);
        let empty = tree.zero_values[0];

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
            Leaf::new(&[b"val_4"]),
        ];

        // Nothing inserted yet: every layer is empty, whatever `leaves` holds.
        assert_eq!(tree.get_layer_nodes(&leaves, 0), vec![]);
        assert_eq!(tree.get_layer_nodes(&leaves, 1), vec![]);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let a = Hash::from(leaves[0]);
        let b = Hash::from(leaves[1]);
        let c = Hash::from(leaves[2]);
        let d = empty;
        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);

        // Only the three inserted leaves are considered.
        let layer_0 = tree.get_layer_nodes(&leaves, 0);
        assert_eq!(layer_0, vec![a, b, c]);

        let layer_1 = tree.get_layer_nodes(&leaves, 1);
        assert_eq!(layer_1, vec![i, j]);

        let layer_2 = tree.get_layer_nodes(&leaves, 2);
        let m = hash_left_right(i, j);
        assert_eq!(layer_2, vec![m]);

        let layer_3 = tree.get_layer_nodes(&leaves, 3);
        assert_eq!(layer_3, vec![tree.get_root()]);

        // Past the root there is nothing.
        let layer_4 = tree.get_layer_nodes(&leaves, 4);
        assert_eq!(layer_4, vec![]);

        assert!(tree.try_add(&[b"val_4"]).is_ok());
        let d = Hash::from(leaves[3]);
        let j = hash_left_right(c, d);

        let layer_0 = tree.get_layer_nodes(&leaves, 0);
        assert_eq!(layer_0, vec![a, b, c, d]);

        let layer_1 = tree.get_layer_nodes(&leaves, 1);
        assert_eq!(layer_1, vec![i, j]);
    }
}