1use bytemuck::{ Pod, Zeroable };
2use super::hash::{ Hash, Leaf, hashv };
3use super::error::{ BrineTreeError, ProgramResult };
4use super::utils::check_condition;
5
/// Fixed-depth Merkle tree that stores only the hashes needed to append leaves
/// and verify proofs; the leaves themselves are not kept.
///
/// `N` is the depth of the tree: both cached arrays hold one entry per level and
/// every insertion hashes exactly `N` times to reach the root.
///
/// The struct is `repr(C)`, so the field order below is part of its byte layout.
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    root: Hash,
    /// Per level, the most recent node written as a left child. It is read back
    /// as the left sibling when a later insertion lands on the right at that level.
    filled_subtrees: [Hash; N],
    /// Per level, the hash of an entirely empty subtree. Entry 0 is the empty
    /// leaf derived from the seeds; each further entry hashes the previous one
    /// with itself.
    zero_values: [Hash; N],
    /// Number of leaves appended so far, which is also the index of the next leaf.
    next_index: u64,
}
14
// SAFETY: the struct is `repr(C)` and consists solely of `Hash` values, arrays of
// `Hash`, and a `u64`, so an all-zero bit pattern is a valid value provided that
// holds for `Hash`.
// NOTE(review): `Hash` is defined in another module — confirm it is itself `Zeroable`.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
// SAFETY: the type is `Copy`, `repr(C)` and contains no references or pointers.
// NOTE(review): `Pod` additionally requires that there be no padding bytes and that
// `Hash` is `Pod`; this presumably holds if `Hash` is a plain byte array whose size is
// a multiple of 8 — confirm against the `Hash` definition.
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19 pub const fn get_depth(&self) -> u8 {
20 N as u8
21 }
22
23 pub const fn get_size() -> usize {
24 core::mem::size_of::<Self>()
25 }
26
27 pub fn get_root(&self) -> Hash {
28 self.root
29 }
30
31 pub fn get_empty_leaf(&self) -> Leaf {
32 self.zero_values[0].as_leaf()
33 }
34
35 pub fn new(seeds: &[&[u8]]) -> Self {
36 let zeros = Self::calc_zeros(seeds);
37 Self {
38 next_index: 0,
39 root: zeros[N - 1],
40 filled_subtrees: zeros,
41 zero_values: zeros,
42 }
43 }
44
45 pub fn init(&mut self, seeds: &[&[u8]]) {
46 let zeros = Self::calc_zeros(seeds);
47 self.next_index = 0;
48 self.root = zeros[N - 1];
49 self.filled_subtrees = zeros;
50 self.zero_values = zeros;
51 }
52
53 pub fn get_leaf_count(&self) -> u64 {
54 self.next_index
55 }
56
57 pub fn get_capacity(&self) -> u64 {
58 1u64 << N
59 }
60
61 fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
62 let mut zeros: [Hash; N] = [Hash::default(); N];
63 let mut current = hashv(seeds);
64
65 for i in 0..N {
66 zeros[i] = current;
67 current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
68 }
69
70 zeros
71 }
72
73 pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
74 let leaf = Leaf::new(data);
75 self.try_add_leaf(leaf)
76 }
77
78 pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
79 check_condition(
80 self.next_index < (1u64 << N),
81 BrineTreeError::TreeFull,
82 )?;
83
84 let mut current_index = self.next_index;
85 let mut current_hash = Hash::from(leaf);
86 let mut left;
87 let mut right;
88
89 for i in 0..N {
90 if current_index % 2 == 0 {
91 left = current_hash;
92 right = self.zero_values[i];
93 self.filled_subtrees[i] = current_hash;
94 } else {
95 left = self.filled_subtrees[i];
96 right = current_hash;
97 }
98
99 current_hash = hash_left_right(left, right);
100 current_index /= 2;
101 }
102
103 self.root = current_hash;
104 self.next_index += 1;
105
106 Ok(())
107 }
108
109 pub fn try_remove<P>(
110 &mut self,
111 proof: &[P],
112 data: &[&[u8]],
113 ) -> ProgramResult
114 where
115 P: Into<Hash> + Copy,
116 {
117 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
118 let original_leaf = Leaf::new(data);
119 self.try_remove_leaf(&proof_hashes, original_leaf)
120 }
121
122 pub fn try_remove_leaf<P>(
123 &mut self,
124 proof: &[P],
125 leaf: Leaf,
126 ) -> ProgramResult
127 where
128 P: Into<Hash> + Copy,
129 {
130 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
131 self.check_length(&proof_hashes)?;
132 self.try_replace_leaf(&proof_hashes, leaf, self.get_empty_leaf())
133 }
134
135 pub fn try_replace<P>(
136 &mut self,
137 proof: &[P],
138 original_data: &[&[u8]],
139 new_data: &[&[u8]],
140 ) -> ProgramResult
141 where
142 P: Into<Hash> + Copy,
143 {
144 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
145 let original_leaf = Leaf::new(original_data);
146 let new_leaf = Leaf::new(new_data);
147 self.try_replace_leaf(&proof_hashes, original_leaf, new_leaf)
148 }
149
150 pub fn try_replace_leaf<P>(
151 &mut self,
152 proof: &[P],
153 original_leaf: Leaf,
154 new_leaf: Leaf,
155 ) -> ProgramResult
156 where
157 P: Into<Hash> + Copy,
158 {
159 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
160 self.check_length(&proof_hashes)?;
161 let original_path = compute_path(&proof_hashes, original_leaf);
162 let new_path = compute_path(&proof_hashes, new_leaf);
163 check_condition(
164 is_valid_path(&original_path, self.root),
165 BrineTreeError::InvalidProof,
166 )?;
167 for i in 0..N {
168 if original_path[i] == self.filled_subtrees[i] {
169 self.filled_subtrees[i] = new_path[i];
170 }
171 }
172 self.root = *new_path.last().unwrap();
173 Ok(())
174 }
175
176 pub fn contains<P>(
177 &self,
178 proof: &[P],
179 data: &[&[u8]],
180 ) -> bool
181 where
182 P: Into<Hash> + Copy,
183 {
184 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
185 let leaf = Leaf::new(data);
186 self.contains_leaf(&proof_hashes, leaf)
187 }
188
189 pub fn contains_leaf<P>(
190 &self,
191 proof: &[P],
192 leaf: Leaf,
193 ) -> bool
194 where
195 P: Into<Hash> + Copy,
196 {
197 let proof_hashes: Vec<Hash> = proof.iter().map(|p| (*p).into()).collect();
198 if self.check_length(&proof_hashes).is_err() {
199 return false;
200 }
201 is_valid_leaf(&proof_hashes, self.root, leaf)
202 }
203
204 fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
205 check_condition(proof.len() == N, BrineTreeError::ProofLength)
206 }
207
208 fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
209 let mut res = Vec::with_capacity(pairs.len() / 2);
210
211 for i in (0..pairs.len()).step_by(2) {
212 let left = pairs[i];
213 let right = pairs[i + 1];
214
215 let hashed = hash_left_right(left, right);
216 res.push(hashed);
217 }
218
219 res
220 }
221
222 pub fn get_merkle_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
223 let mut layers = Vec::with_capacity(N);
224 let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
225
226 for i in 0..N {
227 if current_layer.len() % 2 != 0 {
228 current_layer.push(self.zero_values[i]);
229 }
230
231 layers.push(current_layer.clone());
232 current_layer = Self::hash_pairs(current_layer);
233 }
234
235 let mut proof = Vec::with_capacity(N);
236 let mut current_index = leaf_index;
237 let mut layer_index = 0;
238 let mut sibling;
239
240 for _ in 0..N {
241 if current_index % 2 == 0 {
242 sibling = layers[layer_index][current_index + 1];
243 } else {
244 sibling = layers[layer_index][current_index - 1];
245 }
246
247 proof.push(sibling);
248
249 current_index /= 2;
250 layer_index += 1;
251 }
252
253 proof
254 }
255}
256
257fn hash_left_right(left: Hash, right: Hash) -> Hash {
258 let combined;
259 if left.to_bytes() <= right.to_bytes() {
260 combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
261 } else {
262 combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
263 }
264
265 hashv(&combined)
266}
267
268fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
269 let mut computed_path = Vec::with_capacity(proof.len() + 1);
270 let mut computed_hash = Hash::from(leaf);
271
272 computed_path.push(computed_hash);
273
274 for proof_element in proof.iter() {
275 computed_hash = hash_left_right(computed_hash, *proof_element);
276 computed_path.push(computed_hash);
277 }
278
279 computed_path
280}
281
282fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
283 let computed_path = compute_path(proof, leaf);
284 is_valid_path(&computed_path, root)
285}
286
287fn is_valid_path(path: &[Hash], root: Hash) -> bool {
288 if path.is_empty() {
289 return false;
290 }
291
292 *path.last().unwrap() == root
293}
294
295pub fn verify<Root, Item, L>(
297 root: Root,
298 proof: &[Item],
299 leaf: L,
300) -> bool
301where
302 Root: Into<Hash>,
303 Item: Into<Hash> + Copy,
304 L: Into<Leaf>,
305{
306 let root_h: Hash = root.into();
307 let proof_hashes: Vec<Hash> =
308 proof.iter()
309 .map(|&x| x.into())
310 .collect();
311
312 let leaf_h: Leaf = leaf.into();
313 let path = compute_path(&proof_hashes, leaf_h);
314 is_valid_path(&path, root_h)
315}
316
317
#[cfg(test)]
mod tests {
    use super::*;

    type TestTree = MerkleTree<3>;

    /// Seeds shared by every test tree.
    const SEEDS: &[&[u8]] = &[b"test"];

    /// Hash of the leaf built from a single data slice.
    fn leaf_hash(data: &[u8]) -> Hash {
        Hash::from(Leaf::new(&[data]))
    }

    #[test]
    fn test_create_tree() {
        let tree = TestTree::new(SEEDS);

        assert_eq!(tree.get_depth(), 3);
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let mut tree = TestTree::new(SEEDS);
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Expected shape after three insertions into the depth-3 tree
        // (d..h are still empty leaves):
        //
        //   leaves:   a  b    c  d    e  f    g  h
        //   level 1:    i       j       k       l
        //   level 2:        m               n
        //   level 3:              root
        let a = leaf_hash(b"val_1");
        let b = leaf_hash(b"val_2");
        let c = leaf_hash(b"val_3");
        let (d, e, f, g, h) = (empty, empty, empty, empty, empty);

        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);
        let k = hash_left_right(e, f);
        let l = hash_left_right(g, h);
        let m = hash_left_right(i, j);
        let n = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        // A right-hand insertion must not overwrite the cached left node.
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(tree.get_root(), root);

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Every still-empty slot (indices 3..=7) is provable as the empty leaf.
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Removing val_2 turns slot 1 into an empty leaf.
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        let i_removed = hash_left_right(a, empty);
        let m_removed = hash_left_right(i_removed, j);
        let root_removed = hash_left_right(m_removed, n);

        assert_eq!(tree.get_root(), root_removed);

        let val1_proof_removed = [empty, j, n];
        let val3_proof_removed = [d, i_removed, n];

        assert!(tree.contains_leaf(&val1_proof_removed, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof_removed, Leaf::new(&[b"val_3"])));

        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // Appending after a removal continues at the next index (slot 3) and must
        // pair with the hashes that the removal updated.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        let val4 = leaf_hash(b"val_4");
        let j_added = hash_left_right(c, val4);
        let m_added = hash_left_right(i_removed, j_added);
        let root_added = hash_left_right(m_added, n);

        assert_eq!(tree.get_root(), root_added);
    }

    #[test]
    fn test_proof() {
        let mut tree = TestTree::new(SEEDS);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let val2_proof = tree.get_merkle_proof(&leaves, 1);
        let val3_proof = tree.get_merkle_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Proofs of the wrong length are rejected outright.
        let invalid_proof_short = &val1_proof[..2];
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat();
        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));

        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let mut tree = TestTree::new(SEEDS);

        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Re-initializing must discard the insertion and restore the fresh state.
        tree.init(SEEDS);

        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let mut tree = TestTree::new(SEEDS);

        // A depth-3 tree holds exactly 2^3 leaves.
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let mut tree = TestTree::new(SEEDS);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // The siblings of slot 0 did not change, so the same proof now proves the
        // replacement and no longer proves the original value.
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        let proof_val2 = tree.get_merkle_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_replace_rejects_bad_input() {
        let mut tree = TestTree::new(SEEDS);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);
        let root_before = tree.get_root();

        // A leaf that is not in the tree cannot be replaced.
        let wrong_leaf = tree.try_replace(&proof, &[b"missing"], &[b"new_val"]);
        assert_eq!(wrong_leaf.unwrap_err(), BrineTreeError::InvalidProof);

        // A proof with the wrong number of entries is rejected before hashing.
        let short_proof = tree.try_replace(&proof[..2], &[b"val_1"], &[b"new_val"]);
        assert_eq!(short_proof.unwrap_err(), BrineTreeError::ProofLength);

        // Failed replacements must leave the tree untouched.
        assert_eq!(tree.get_root(), root_before);
    }

    #[test]
    fn test_verify() {
        let mut tree = TestTree::new(SEEDS);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
        ];
        let proof = tree.get_merkle_proof(&leaves, 0);

        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));

        // `verify` also accepts raw 32-byte arrays for the root, proof and leaf.
        let root_bytes: [u8; 32] = tree.get_root().to_bytes();
        let proof_bytes: [[u8; 32]; 3] = [
            proof[0].to_bytes(),
            proof[1].to_bytes(),
            proof[2].to_bytes(),
        ];
        let leaf_bytes: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();

        assert!(verify(root_bytes, &proof_bytes, leaf_bytes));
    }
}