1use std::marker::PhantomData;
2
3#[cfg(feature = "solana")]
4use anchor_lang::prelude::*;
5
6use bytemuck::{Pod, Zeroable};
7use config::MerkleTreeConfig;
8use hasher::{Hash, Hasher};
9
10pub mod config;
11pub mod constants;
12pub mod hasher;
13
/// Size in bytes of a single leaf accepted by the tree's `hash` and `insert` methods.
pub const DATA_LEN: usize = 32;

/// Size in bytes of one node hash / root stored in the tree.
pub const HASH_LEN: usize = 32;

/// Largest tree height accepted when constructing or initialising a tree; also
/// the number of slots in the `filled_subtrees` array.
pub const MAX_HEIGHT: usize = 18;

/// Number of slots in the ring buffer of recorded roots.
pub const MERKLE_TREE_HISTORY_SIZE: usize = 20;
18
/// Identifier of a hash function, stored in `MerkleTree::hash_function`.
///
/// Only compiled with the `solana` feature. In this file the value is merely
/// stored by `MerkleTree::init`; the hashing itself goes through the `H: Hasher`
/// type parameter.
///
/// NOTE(review): the variant order presumably determines the serialized
/// discriminant produced by the Anchor derives, so variants should not be
/// reordered — confirm before changing.
#[cfg(feature = "solana")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, AnchorSerialize, AnchorDeserialize)]
pub enum HashFunction {
    /// SHA-256.
    Sha256,
    /// Poseidon.
    Poseidon,
}
25
26#[cfg_attr(feature = "solana", derive(AnchorSerialize, AnchorDeserialize))]
29#[derive(PartialEq, Eq, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct MerkleTree<H, C>
32where
33 H: Hasher,
34 C: MerkleTreeConfig,
35{
36 pub height: u64,
38 pub filled_subtrees: [[u8; 32]; 18],
40 pub roots: [[u8; 32]; 20],
43 pub next_index: u64,
45 pub current_root_index: u64,
47
48 #[cfg(feature = "solana")]
50 pub hash_function: HashFunction,
51
52 hasher: PhantomData<H>,
53 config: PhantomData<C>,
54}
55
56impl<H, C> MerkleTree<H, C>
57where
58 H: Hasher,
59 C: MerkleTreeConfig,
60{
61 fn check_height(height: usize) {
62 assert!(height > 0);
63 assert!(height <= MAX_HEIGHT);
64 }
65
66 fn new_filled_subtrees(height: usize) -> [[u8; HASH_LEN]; MAX_HEIGHT] {
67 let mut filled_subtrees = [[0; HASH_LEN]; MAX_HEIGHT];
68
69 for i in 0..height {
70 filled_subtrees[i] = C::ZERO_BYTES[i];
71 }
72
73 filled_subtrees
74 }
75
76 fn new_roots(height: usize) -> [[u8; HASH_LEN]; MERKLE_TREE_HISTORY_SIZE] {
77 let mut roots = [[0; HASH_LEN]; MERKLE_TREE_HISTORY_SIZE];
78 roots[0] = C::ZERO_BYTES[height - 1];
79
80 roots
81 }
82
83 #[cfg(not(feature = "solana"))]
85 pub fn new(height: usize, #[cfg(feature = "solana")] hash_function: HashFunction) -> Self {
86 Self::check_height(height);
87
88 let filled_subtrees = Self::new_filled_subtrees(height);
89 let roots = Self::new_roots(height);
90
91 MerkleTree {
92 height: height as u64,
93 filled_subtrees,
94 roots,
95 next_index: 0,
96 current_root_index: 0,
97 #[cfg(feature = "solana")]
98 hash_function,
99 hasher: PhantomData,
100 config: PhantomData,
101 }
102 }
103
104 #[cfg(feature = "solana")]
107 pub fn init(&mut self, height: usize, hash_function: HashFunction) {
108 Self::check_height(height);
109
110 self.height = height as u64;
111 self.filled_subtrees = Self::new_filled_subtrees(height);
112 self.roots = Self::new_roots(height);
113 self.hash_function = hash_function;
114 }
115
116 pub fn hash(&mut self, leaf1: [u8; DATA_LEN], leaf2: [u8; DATA_LEN]) -> Hash {
117 H::hashv(&[&leaf1, &leaf2])
118 }
119
120 pub fn insert(&mut self, leaf1: [u8; DATA_LEN], leaf2: [u8; DATA_LEN]) {
121 assert_ne!(self.next_index, 2u64.pow(self.height as u32));
123
124 let mut current_index = self.next_index / 2;
125 let mut current_level_hash = self.hash(leaf1, leaf2);
126
127 for i in 1..self.height as usize {
128 let (left, right) = if current_index % 2 == 0 {
129 self.filled_subtrees[i] = current_level_hash;
130 (current_level_hash, C::ZERO_BYTES[i])
131 } else {
132 (self.filled_subtrees[i], current_level_hash)
133 };
134
135 current_index /= 2;
136 current_level_hash = self.hash(left, right);
137 }
138
139 self.current_root_index = (self.current_root_index + 1) % MERKLE_TREE_HISTORY_SIZE as u64;
140 self.roots[self.current_root_index as usize] = current_level_hash;
141 self.next_index += 2;
142 }
143
144 pub fn is_known_root(&self, root: [u8; HASH_LEN]) -> bool {
145 for i in (0..(self.current_root_index as usize + 1)).rev() {
146 if self.roots[i] == root {
147 return true;
148 }
149 }
150 return false;
151 }
152
153 pub fn last_root(&self) -> [u8; HASH_LEN] {
154 self.roots[self.current_root_index as usize]
155 }
156}
157
158unsafe impl<H, C> Pod for MerkleTree<H, C>
169where
170 H: Hasher + Copy + 'static,
171 C: MerkleTreeConfig + Copy + 'static,
172{
173}
174
175unsafe impl<H, C> Zeroable for MerkleTree<H, C>
187where
188 H: Hasher,
189 C: MerkleTreeConfig,
190{
191}
192
/// Anchor `Owner` implementation: the tree is owned by the program id declared
/// in its configuration type `C`.
#[cfg(feature = "solana")]
impl<H: Hasher, C: MerkleTreeConfig> Owner for MerkleTree<H, C> {
    /// Returns `C::PROGRAM_ID`.
    fn owner() -> Pubkey {
        C::PROGRAM_ID
    }
}