1use bytemuck::{ Pod, Zeroable };
2use super::hash::{ Hash, Leaf, hashv };
3use super::error::{ BrineTreeError, ProgramResult };
4use super::utils::check_condition;
5
/// Fixed-depth Merkle tree with `2^N` leaf slots.
///
/// Only the root, one cached node per level, and the per-level empty-subtree
/// hashes are stored; the leaves themselves are not kept. `#[repr(C, align(8))]`
/// fixes the field order so the struct can be reinterpreted as raw bytes
/// (it implements `bytemuck::Pod`).
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash of the tree.
    root: Hash,
    /// For each level, the most recent node written there as a *left* child by
    /// `try_add_leaf`; it becomes the left sibling when the next insertion
    /// reaches that level as a right child.
    filled_subtrees: [Hash; N],
    /// `zero_values[i]` is the hash of an all-empty subtree at level `i`;
    /// `zero_values[0]` is the empty leaf.
    zero_values: [Hash; N],
    /// Slot index the next `try_add_leaf` writes to (the number of appends so far).
    next_index: u64,
}
14
// SAFETY: `MerkleTree<N>` is `#[repr(C)]` and consists only of `Hash` values and
// one `u64`, so the all-zero bit pattern is valid for every field and (given the
// layout assumption below) the struct contains no padding bytes, as `Pod` requires.
// NOTE(review): this relies on `Hash` being a `Pod` byte array of 32 bytes with
// alignment 1 (its `to_bytes()` yields `[u8; 32]`), which makes
// `32 + 64 * N` bytes of hashes a multiple of 8 before `next_index` — confirm
// against the `Hash` definition in `hash.rs` if that type ever changes.
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
17
18impl<const N: usize> MerkleTree<N> {
19 pub const fn get_depth(&self) -> u8 {
20 N as u8
21 }
22
23 pub const fn get_size() -> usize {
24 core::mem::size_of::<Self>()
25 }
26
27 pub fn get_root(&self) -> Hash {
28 self.root
29 }
30
31 pub fn get_empty_leaf(&self) -> Leaf {
32 self.zero_values[0].as_leaf()
33 }
34
35 pub fn new(seeds: &[&[u8]]) -> Self {
36 let zeros = Self::calc_zeros(seeds);
37 Self {
38 next_index: 0,
39 root: zeros[N - 1],
40 filled_subtrees: zeros,
41 zero_values: zeros,
42 }
43 }
44
45 pub fn init(&mut self, seeds: &[&[u8]]) {
46 let zeros = Self::calc_zeros(seeds);
47 self.next_index = 0;
48 self.root = zeros[N - 1];
49 self.filled_subtrees = zeros;
50 self.zero_values = zeros;
51 }
52
53 fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
54 let mut zeros: [Hash; N] = [Hash::default(); N];
55 let mut current = hashv(seeds);
56
57 for i in 0..N {
58 zeros[i] = current;
59 current = hashv(&[b"NODE".as_ref(), current.as_ref(), current.as_ref()]);
60 }
61
62 zeros
63 }
64
65 pub fn try_add(&mut self, data: &[&[u8]]) -> ProgramResult {
66 let leaf = Leaf::new(data);
67 self.try_add_leaf(leaf)
68 }
69
70 pub fn try_add_leaf(&mut self, leaf: Leaf) -> ProgramResult {
71 check_condition(
72 self.next_index < (1u64 << N),
73 BrineTreeError::TreeFull,
74 )?;
75
76 let mut current_index = self.next_index;
77 let mut current_hash = Hash::from(leaf);
78 let mut left;
79 let mut right;
80
81 for i in 0..N {
82 if current_index % 2 == 0 {
83 left = current_hash;
84 right = self.zero_values[i];
85 self.filled_subtrees[i] = current_hash;
86 } else {
87 left = self.filled_subtrees[i];
88 right = current_hash;
89 }
90
91 current_hash = hash_left_right(left, right);
92 current_index /= 2;
93 }
94
95 self.root = current_hash;
96 self.next_index += 1;
97
98 Ok(())
99 }
100
101 pub fn try_remove(&mut self, proof: &[Hash], data: &[&[u8]]) -> ProgramResult {
102 let leaf = Leaf::new(data);
103 self.try_remove_leaf(proof, leaf)
104 }
105
106 pub fn try_remove_leaf(&mut self, proof: &[Hash], leaf: Leaf) -> ProgramResult {
107 self.check_length(proof)?;
108 self.try_replace_leaf(proof, leaf, self.get_empty_leaf())
109 }
110
111 pub fn try_replace(&mut self, proof: &[Hash], original_data: &[&[u8]], new_data: &[&[u8]]) -> ProgramResult {
112 let original_leaf = Leaf::new(original_data);
113 let new_leaf = Leaf::new(new_data);
114 self.try_replace_leaf(proof, original_leaf, new_leaf)
115 }
116
117 pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Leaf, new_leaf: Leaf) -> ProgramResult {
118 self.check_length(proof)?;
119
120 let original_path = compute_path(proof, original_leaf);
121 let new_path = compute_path(proof, new_leaf);
122
123 check_condition(
124 is_valid_path(&original_path, self.root),
125 BrineTreeError::InvalidProof,
126 )?;
127
128 for i in 0..N {
129 if original_path[i] == self.filled_subtrees[i] {
130 self.filled_subtrees[i] = new_path[i];
131 }
132 }
133
134 self.root = *new_path.last().unwrap();
135
136 Ok(())
137 }
138
139 pub fn contains(&self, proof: &[Hash], data: &[&[u8]]) -> bool {
140 let leaf = Leaf::new(data);
141 self.contains_leaf(proof, leaf)
142 }
143
144 pub fn contains_leaf(&self, proof: &[Hash], leaf: Leaf) -> bool {
145 if let Err(_) = self.check_length(proof) {
146 return false;
147 }
148
149 let root = self.get_root();
150 is_valid_leaf(proof, root, leaf)
151 }
152
153 fn check_length(&self, proof: &[Hash]) -> Result<(), BrineTreeError> {
154 check_condition(proof.len() == N, BrineTreeError::ProofLength)
155 }
156
157 fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
158 let mut res = Vec::with_capacity(pairs.len() / 2);
159
160 for i in (0..pairs.len()).step_by(2) {
161 let left = pairs[i];
162 let right = pairs[i + 1];
163
164 let hashed = hash_left_right(left, right);
165 res.push(hashed);
166 }
167
168 res
169 }
170
171 pub fn get_merkle_proof(&self, leaves: &[Leaf], leaf_index: usize) -> Vec<Hash> {
172 let mut layers = Vec::with_capacity(N);
173 let mut current_layer: Vec<Hash> = leaves.iter().map(|leaf| Hash::from(*leaf)).collect();
174
175 for i in 0..N {
176 if current_layer.len() % 2 != 0 {
177 current_layer.push(self.zero_values[i]);
178 }
179
180 layers.push(current_layer.clone());
181 current_layer = Self::hash_pairs(current_layer);
182 }
183
184 let mut proof = Vec::with_capacity(N);
185 let mut current_index = leaf_index;
186 let mut layer_index = 0;
187 let mut sibling;
188
189 for _ in 0..N {
190 if current_index % 2 == 0 {
191 sibling = layers[layer_index][current_index + 1];
192 } else {
193 sibling = layers[layer_index][current_index - 1];
194 }
195
196 proof.push(sibling);
197
198 current_index /= 2;
199 layer_index += 1;
200 }
201
202 proof
203 }
204}
205
206fn hash_left_right(left: Hash, right: Hash) -> Hash {
207 let combined;
208 if left.to_bytes() <= right.to_bytes() {
209 combined = [b"NODE".as_ref(), left.as_ref(), right.as_ref()];
210 } else {
211 combined = [b"NODE".as_ref(), right.as_ref(), left.as_ref()];
212 }
213
214 hashv(&combined)
215}
216
217fn compute_path(proof: &[Hash], leaf: Leaf) -> Vec<Hash> {
218 let mut computed_path = Vec::with_capacity(proof.len() + 1);
219 let mut computed_hash = Hash::from(leaf);
220
221 computed_path.push(computed_hash);
222
223 for proof_element in proof.iter() {
224 computed_hash = hash_left_right(computed_hash, *proof_element);
225 computed_path.push(computed_hash);
226 }
227
228 computed_path
229}
230
231fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Leaf) -> bool {
232 let computed_path = compute_path(proof, leaf);
233 is_valid_path(&computed_path, root)
234}
235
236fn is_valid_path(path: &[Hash], root: Hash) -> bool {
237 if path.is_empty() {
238 return false;
239 }
240
241 *path.last().unwrap() == root
242}
243
244pub fn verify<Root, Item, L>(
246 root: Root,
247 proof: &[Item],
248 leaf: L,
249) -> bool
250where
251 Root: Into<Hash>,
252 Item: Into<Hash> + Copy,
253 L: Into<Leaf>,
254{
255 let root_h: Hash = root.into();
256 let proof_hashes: Vec<Hash> =
257 proof.iter()
258 .map(|&x| x.into())
259 .collect();
260
261 let leaf_h: Leaf = leaf.into();
262 let path = compute_path(&proof_hashes, leaf_h);
263 is_valid_path(&path, root_h)
264}
265
266
#[cfg(test)]
mod tests {
    use super::*;

    type TestTree = MerkleTree<3>;

    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        // A freshly created tree's root is the last stored zero value.
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];
        let empty_leaf = empty.as_leaf();

        // Expected depth-3 tree after inserting three values (d..h are empty):
        //
        //                 root
        //           m               n
        //       i       j       k       l
        //     a   b   c   d   e   f   g   h
        let a = Hash::from(Leaf::new(&[b"val_1"]));
        let b = Hash::from(Leaf::new(&[b"val_2"]));
        let c = Hash::from(Leaf::new(&[b"val_3"]));
        let (d, e, f, g, h) = (empty, empty, empty, empty, empty);

        let i = hash_left_right(a, b);
        let j = hash_left_right(c, d);
        let k = hash_left_right(e, f);
        let l = hash_left_right(g, h);
        let m = hash_left_right(i, j);
        let n = hash_left_right(k, l);
        let root = hash_left_right(m, n);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_add(&[b"val_3"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Each remaining empty slot is provable with its own sibling path.
        assert!(tree.contains_leaf(&[c, i, n], empty_leaf));
        assert!(tree.contains_leaf(&[f, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[e, l, m], empty_leaf));
        assert!(tree.contains_leaf(&[h, k, m], empty_leaf));
        assert!(tree.contains_leaf(&[g, k, m], empty_leaf));

        // Removing val_2 turns its slot back into the empty leaf.
        assert!(tree.try_remove(&val2_proof, &[b"val_2"]).is_ok());

        let i = hash_left_right(a, empty);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);
        assert_eq!(root, tree.get_root());

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, Leaf::new(&[b"val_1"])));
        assert!(tree.contains_leaf(&val2_proof, empty_leaf));
        assert!(tree.contains_leaf(&val3_proof, Leaf::new(&[b"val_3"])));
        assert!(!tree.contains_leaf(&val2_proof, Leaf::new(&[b"val_2"])));

        // The next append fills slot 3 (d); slot 3 is a right child, so the
        // cached level-0 node stays `c`.
        assert!(tree.try_add(&[b"val_4"]).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        let d = Hash::from(Leaf::new(&[b"val_4"]));
        let j = hash_left_right(c, d);
        let m = hash_left_right(i, j);
        let root = hash_left_right(m, n);
        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        let leaves = [
            Leaf::new(&[b"val_1"]),
            Leaf::new(&[b"val_2"]),
            Leaf::new(&[b"val_3"]),
        ];

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());
        assert!(tree.try_add(&[b"val_3"]).is_ok());

        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let val2_proof = tree.get_merkle_proof(&leaves, 1);
        let val3_proof = tree.get_merkle_proof(&leaves, 2);

        assert!(tree.contains(&val1_proof, &[b"val_1"]));
        assert!(tree.contains(&val2_proof, &[b"val_2"]));
        assert!(tree.contains(&val3_proof, &[b"val_3"]));

        // Slot 3 is the empty leaf used to pad the bottom layer to even length.
        let padding_proof = tree.get_merkle_proof(&leaves, 3);
        assert!(tree.contains_leaf(&padding_proof, tree.get_empty_leaf()));

        // Proofs must contain exactly N hashes.
        let invalid_proof_short = &val1_proof[..2];
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat();
        let empty_proof: Vec<Hash> = Vec::new();

        assert!(!tree.contains(invalid_proof_short, &[b"val_1"]));
        assert!(!tree.contains(&invalid_proof_long, &[b"val_1"]));
        assert!(!tree.contains(&empty_proof, &[b"val_1"]));
    }

    #[test]
    fn test_init_and_reinit() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        let initial_root = tree.get_root();
        let initial_zeros = tree.zero_values;
        let initial_filled = tree.filled_subtrees;
        let initial_index = tree.next_index;

        assert!(tree.try_add(&[b"val_1"]).is_ok());

        // Re-initializing must restore exactly the freshly constructed state.
        tree.init(seeds);

        assert_eq!(tree.get_root(), initial_root);
        assert_eq!(tree.zero_values, initial_zeros);
        assert_eq!(tree.filled_subtrees, initial_filled);
        assert_eq!(tree.next_index, initial_index);
    }

    #[test]
    fn test_tree_full() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        // A depth-3 tree has 2^3 = 8 slots.
        for i in 0u8..8 {
            assert!(tree.try_add(&[&[i]]).is_ok());
        }

        let result = tree.try_add(&[b"extra"]);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), BrineTreeError::TreeFull);
    }

    #[test]
    fn test_replace_leaf() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_merkle_proof(&leaves, 0);

        assert!(tree.try_replace(&proof, &[b"val_1"], &[b"new_val"]).is_ok());

        // Sibling hashes are unchanged, so the same proof now proves the new value.
        assert!(tree.contains(&proof, &[b"new_val"]));
        assert!(!tree.contains(&proof, &[b"val_1"]));

        let proof_val2 = tree.get_merkle_proof(&[Leaf::new(&[b"new_val"]), leaves[1]], 1);
        assert!(tree.contains(&proof_val2, &[b"val_2"]));
    }

    #[test]
    fn test_verify() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        assert!(tree.try_add(&[b"val_1"]).is_ok());
        assert!(tree.try_add(&[b"val_2"]).is_ok());

        let leaves = [Leaf::new(&[b"val_1"]), Leaf::new(&[b"val_2"])];
        let proof = tree.get_merkle_proof(&leaves, 0);

        assert!(verify(tree.get_root(), &proof, Leaf::new(&[b"val_1"])));

        // `verify` also accepts raw byte arrays.
        let root_bytes: [u8; 32] = tree.get_root().to_bytes();
        let proof_bytes: [[u8; 32]; 3] = [
            proof[0].to_bytes(),
            proof[1].to_bytes(),
            proof[2].to_bytes(),
        ];
        let leaf_bytes: [u8; 32] = Leaf::new(&[b"val_1"]).to_bytes();

        assert!(verify(root_bytes, &proof_bytes, leaf_bytes));
    }
}