1#![allow(unexpected_cfgs)]
2
3use bytemuck::{Pod, Zeroable};
4use super::hash::Hash;
5use super::{utils, utils::check_condition};
6use super::error::{ProgramError, ProgramResult};
7
/// Fixed-depth incremental Merkle tree with `N` hashing levels (capacity `2^N` leaves).
///
/// Leaves are appended left to right by `try_insert`; existing leaves can be
/// replaced or cleared with a proof via `try_replace_leaf`. Only the insertion
/// frontier (`filled_subtrees`) is stored, not every node. The type is laid out as
/// plain old data (`repr(C)` + `Pod`), presumably so it can be viewed in place
/// inside a byte buffer — confirm with callers.
#[repr(C, align(8))]
#[derive(Clone, Copy, PartialEq, Debug,)]
pub struct MerkleTree<const N: usize> {
    /// Current root hash.
    root: Hash,
    /// Per level, the most recent left-child node written by `try_insert`; it is
    /// used as the left sibling when the next insertion lands on a right child.
    filled_subtrees: [Hash; N],
    /// Per level, the hash of an all-empty subtree of that height;
    /// `zero_values[0]` is the empty-leaf hash.
    zero_values: [Hash; N],
    /// Number of leaves appended so far, i.e. the position of the next insertion.
    next_index: u64,
}

// SAFETY: NOTE(review) — these impls are only sound if every field is itself `Pod`
// and the `repr(C)` layout contains no padding. Both hold if `Hash` is a plain
// 32-byte array type: `root` plus `2 * N` hashes is then a multiple of 8 bytes, so
// `next_index: u64` starts 8-aligned and there is no trailing padding under
// `align(8)`. Confirm against `super::hash::Hash`. (Presumably written by hand
// because the struct is generic over `N`.)
unsafe impl<const N: usize> Zeroable for MerkleTree<N> {}
unsafe impl<const N: usize> Pod for MerkleTree<N> {}
19
20impl<const N: usize> MerkleTree<N> {
21 pub const fn get_depth(&self) -> u8 {
22 N as u8
23 }
24
25 pub const fn get_size() -> usize {
26 std::mem::size_of::<Self>()
27 }
28
29 pub fn get_root(&self) -> Hash {
30 self.root
31 }
32
33 pub fn get_empty_leaf(&self) -> Hash {
34 self.zero_values[0]
35 }
36
37 pub fn new(seeds: &[&[u8]]) -> Self {
38 let zeros = Self::calc_zeros(seeds);
39 Self {
40 next_index: 0,
41 root: zeros[N - 1],
42 filled_subtrees: zeros,
43 zero_values: zeros,
44 }
45 }
46
47 pub fn init(&mut self, seeds: &[&[u8]]) {
48 let zeros = Self::calc_zeros(seeds);
49 self.next_index = 0;
50 self.root = zeros[N - 1];
51 self.filled_subtrees = zeros;
52 self.zero_values = zeros;
53 }
54
55 fn calc_zeros(seeds: &[&[u8]]) -> [Hash; N] {
56 let mut zeros: [Hash; N] = [Hash::default(); N];
57 let mut current = utils::hashv(seeds);
58
59 for i in 0..N {
60 zeros[i] = current;
61 current = utils::hashv(&[current.as_ref(), current.as_ref()]);
62 }
63
64 zeros
65 }
66
67 pub fn try_insert(&mut self, val: Hash) -> ProgramResult {
68 check_condition(
69 self.next_index < (1u64 << N),
70 "merkle tree is full",
71 )?;
72
73 let mut current_index = self.next_index;
74 let mut current_hash = MerkleTree::<N>::as_leaf(val);
75 let mut left;
76 let mut right;
77
78 for i in 0..N {
79 if current_index % 2 == 0 {
80 left = current_hash;
81 right = self.zero_values[i];
82 self.filled_subtrees[i] = current_hash;
83 } else {
84 left = self.filled_subtrees[i];
85 right = current_hash;
86 }
87
88 current_hash = Self::hash_left_right(left, right);
89 current_index /= 2;
90 }
91
92 self.root = current_hash;
93 self.next_index += 1;
94
95 Ok(())
96 }
97
98 pub fn try_remove(&mut self, proof: &[Hash], val: Hash) -> ProgramResult {
99 self.check_length(proof)?;
100
101 self.try_replace_leaf(proof, Self::as_leaf(val), self.get_empty_leaf())
102 }
103
104 pub fn try_replace(&mut self, proof: &[Hash], original_val: Hash, new_val: Hash) -> ProgramResult {
105 self.check_length(proof)?;
106
107 let original_leaf = Self::as_leaf(original_val);
108 let new_leaf = Self::as_leaf(new_val);
109
110 self.try_replace_leaf(proof, original_leaf, new_leaf)
111 }
112
113 pub fn try_replace_leaf(&mut self, proof: &[Hash], original_leaf: Hash, new_leaf: Hash) -> ProgramResult {
114 self.check_length(proof)?;
115
116 let original_path = MerkleTree::<N>::compute_path(proof, original_leaf);
117 let new_path = MerkleTree::<N>::compute_path(proof, new_leaf);
118
119 check_condition(
120 MerkleTree::<N>::is_valid_path(&original_path, self.root),
121 "invalid proof for original leaf",
122 )?;
123
124 for i in 0..N {
125 if original_path[i] == self.filled_subtrees[i] {
126 self.filled_subtrees[i] = new_path[i];
127 }
128 }
129
130 self.root = *new_path.last().unwrap();
131
132 Ok(())
133 }
134
135 pub fn contains(&self, proof: &[Hash], val: Hash) -> bool {
136 if let Err(_) = self.check_length(proof) {
137 return false;
138 }
139
140 let leaf = Self::as_leaf(val);
141 self.contains_leaf(proof, leaf)
142 }
143
144 pub fn contains_leaf(&self, proof: &[Hash], leaf: Hash) -> bool {
145 if let Err(_) = self.check_length(proof) {
146 return false;
147 }
148
149 let root = self.get_root();
150 Self::is_valid_leaf(proof, root, leaf)
151 }
152
153 pub fn as_leaf(val: Hash) -> Hash {
154 utils::hash(val.as_ref())
155 }
156
157 pub fn hash_left_right(left: Hash, right: Hash) -> Hash {
158 let combined;
159 if left.to_bytes() <= right.to_bytes() {
160 combined = [left.as_ref(), right.as_ref()];
161 } else {
162 combined = [right.as_ref(), left.as_ref()];
163 }
164
165 utils::hashv(&combined)
166 }
167
168 pub fn compute_path(proof: &[Hash], leaf: Hash) -> Vec<Hash> {
169 let mut computed_path = Vec::with_capacity(proof.len() + 1);
170 let mut computed_hash = leaf;
171
172 computed_path.push(computed_hash);
173
174 for proof_element in proof.iter() {
175 computed_hash = Self::hash_left_right(computed_hash, *proof_element);
176 computed_path.push(computed_hash);
177 }
178
179 computed_path
180 }
181
182 pub fn is_valid_leaf(proof: &[Hash], root: Hash, leaf: Hash) -> bool {
183 let computed_path = Self::compute_path(proof, leaf);
184 Self::is_valid_path(&computed_path, root)
185 }
186
187 pub fn is_valid_path(path: &[Hash], root: Hash) -> bool {
188 if path.is_empty() {
189 return false;
190 }
191
192 *path.last().unwrap() == root
193 }
194
195 #[cfg(not(target_os = "solana"))]
196 fn hash_pairs(pairs: Vec<Hash>) -> Vec<Hash> {
197 let mut res = Vec::with_capacity(pairs.len() / 2);
200
201 for i in (0..pairs.len()).step_by(2) {
202 let left = pairs[i];
203 let right = pairs[i + 1];
204
205 let hashed = Self::hash_left_right(left, right);
206 res.push(hashed);
207 }
208
209 res
210 }
211
212 #[cfg(not(target_os = "solana"))]
213 pub fn get_merkle_proof(&self, values: &[Hash], index: usize) -> Vec<Hash> {
214 let mut layers = Vec::with_capacity(N);
215 let mut current_layer = values.to_vec();
216 for i in 0..N {
217 if current_layer.len() % 2 != 0 {
218 current_layer.push(self.zero_values[i]);
219 }
220
221 layers.push(current_layer.clone());
222 current_layer = Self::hash_pairs(current_layer);
223 }
224
225 let mut proof = Vec::with_capacity(N);
230 let mut current_index = index;
231 let mut layer_index = 0;
232 let mut sibling;
233
234 for _ in 0..N {
235 if current_index % 2 == 0 {
236 sibling = layers[layer_index][current_index + 1];
237 } else {
238 sibling = layers[layer_index][current_index - 1];
239 }
240
241 proof.push(sibling);
242
243 current_index /= 2;
244 layer_index += 1;
245 }
246
247 proof
248 }
249
250 fn check_length(&self, proof: &[Hash]) -> Result<(), ProgramError> {
251 check_condition(
252 proof.len() == N,
253 "merkle proof length does not match tree depth",
254 )
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    type TestTree = MerkleTree<3>;

    #[test]
    fn test_create_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let tree = TestTree::new(seeds);

        assert_eq!(tree.get_depth(), 3);
        // Pins the current initial-root convention: the last zero value.
        assert_eq!(tree.get_root(), *tree.zero_values.last().unwrap());
    }

    #[test]
    fn test_insert_and_remove() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);
        let empty = tree.zero_values[0];

        let val1 = utils::hash(b"val_1");
        let val2 = utils::hash(b"val_2");
        let val3 = utils::hash(b"val_3");
        let val4 = utils::hash(b"val_4");

        // Leaf level: three inserted values followed by five empty slots.
        let a = TestTree::as_leaf(val1);
        let b = TestTree::as_leaf(val2);
        let c = TestTree::as_leaf(val3);
        let (d, e, f, g, h) = (empty, empty, empty, empty, empty);

        let i = TestTree::hash_left_right(a, b);
        let j = TestTree::hash_left_right(c, d);
        let k = TestTree::hash_left_right(e, f);
        let l = TestTree::hash_left_right(g, h);
        let m = TestTree::hash_left_right(i, j);
        let n = TestTree::hash_left_right(k, l);
        let root = TestTree::hash_left_right(m, n);

        assert!(tree.try_insert(val1).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_insert(val2).is_ok());
        assert_eq!(tree.filled_subtrees[0], a);

        assert!(tree.try_insert(val3).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert_eq!(root, tree.get_root());

        let val1_proof = [b, j, n];
        let val2_proof = [a, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains(&val1_proof, val1));
        assert!(tree.contains(&val2_proof, val2));
        assert!(tree.contains(&val3_proof, val3));

        // Every not-yet-used slot is provable as the empty leaf.
        assert!(tree.contains_leaf(&[c, i, n], empty));
        assert!(tree.contains_leaf(&[f, l, m], empty));
        assert!(tree.contains_leaf(&[e, l, m], empty));
        assert!(tree.contains_leaf(&[h, k, m], empty));
        assert!(tree.contains_leaf(&[g, k, m], empty));

        assert!(tree.try_remove(&val2_proof, val2).is_ok());

        let i = TestTree::hash_left_right(a, empty);
        let m = TestTree::hash_left_right(i, j);
        let root = TestTree::hash_left_right(m, n);
        assert_eq!(root, tree.get_root());

        let val1_proof = [empty, j, n];
        let val3_proof = [d, i, n];

        assert!(tree.contains_leaf(&val1_proof, a));
        assert!(tree.contains_leaf(&val2_proof, empty));
        assert!(tree.contains_leaf(&val3_proof, c));
        assert!(!tree.contains_leaf(&val2_proof, b));

        // Removal does not rewind `next_index`, so val4 lands in slot 3.
        assert!(tree.try_insert(val4).is_ok());
        assert_eq!(tree.filled_subtrees[0], c);

        let d = TestTree::as_leaf(val4);
        let j = TestTree::hash_left_right(c, d);
        let m = TestTree::hash_left_right(i, j);
        let root = TestTree::hash_left_right(m, n);
        assert_eq!(root, tree.get_root());
    }

    #[test]
    fn test_replace() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);
        let empty = tree.get_empty_leaf();

        let val1 = utils::hash(b"val_1");
        let val2 = utils::hash(b"val_2");
        let val3 = utils::hash(b"val_3");
        let val5 = utils::hash(b"val_5");
        let replacement = utils::hash(b"replacement");

        for val in [val1, val2, val3] {
            assert!(tree.try_insert(val).is_ok());
        }

        let a = TestTree::as_leaf(val1);
        let c = TestTree::as_leaf(val3);
        let j = TestTree::hash_left_right(c, empty);
        let k = TestTree::hash_left_right(empty, empty);
        let n = TestTree::hash_left_right(k, k);
        let val2_proof = [a, j, n];

        // A proof that does not match the claimed original value is rejected
        // and leaves the tree untouched.
        let before = tree;
        assert!(tree.try_replace(&val2_proof, val1, replacement).is_err());
        assert_eq!(tree, before);

        assert!(tree.try_replace(&val2_proof, val2, replacement).is_ok());
        let b = TestTree::as_leaf(replacement);
        let i = TestTree::hash_left_right(a, b);
        let m = TestTree::hash_left_right(i, j);
        assert_eq!(tree.get_root(), TestTree::hash_left_right(m, n));
        assert_eq!(tree.filled_subtrees[1], i);
        assert_eq!(tree.filled_subtrees[2], m);
        assert!(tree.contains(&val2_proof, replacement));
        assert!(!tree.contains(&val2_proof, val2));

        // Later insertions build on the updated frontier.
        assert!(tree.try_insert(val5).is_ok());
        let j = TestTree::hash_left_right(c, TestTree::as_leaf(val5));
        let m = TestTree::hash_left_right(i, j);
        assert_eq!(tree.get_root(), TestTree::hash_left_right(m, n));
    }

    #[test]
    fn test_rejects_wrong_proof_length_and_full_tree() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);
        let val = utils::hash(b"val");
        assert!(tree.try_insert(val).is_ok());

        let short_proof = [tree.get_empty_leaf(); 2];
        assert!(tree.try_remove(&short_proof, val).is_err());
        assert!(tree.try_replace(&short_proof, val, val).is_err());

        for idx in 1u8..8 {
            assert!(tree.try_insert(utils::hash(&[idx])).is_ok());
        }
        assert!(tree.try_insert(utils::hash(b"overflow")).is_err());
    }

    #[test]
    fn test_proof() {
        let seeds: &[&[u8]] = &[b"test"];
        let mut tree = TestTree::new(seeds);

        let vals = [
            utils::hash(b"val_1"),
            utils::hash(b"val_2"),
            utils::hash(b"val_3"),
        ];
        let leaves = vals.map(TestTree::as_leaf);

        for val in vals {
            assert!(tree.try_insert(val).is_ok());
        }

        for (index, val) in vals.iter().enumerate() {
            let proof = tree.get_merkle_proof(&leaves, index);
            assert!(tree.contains(&proof, *val));
        }

        let val1 = vals[0];
        let val1_proof = tree.get_merkle_proof(&leaves, 0);
        let invalid_proof_short = &val1_proof[..2];
        let invalid_proof_long = [&val1_proof[..], &val1_proof[..]].concat();
        assert!(!tree.contains(invalid_proof_short, val1));
        assert!(!tree.contains(&invalid_proof_long, val1));

        let empty_proof: Vec<Hash> = Vec::new();
        assert!(!tree.contains(&empty_proof, val1));
    }
}