1use crate::errors::DecdsError;
2use blake3;
3use std::collections::VecDeque;
4
5pub struct MerkleTree {
8 root: blake3::Hash,
9 leaves: Vec<blake3::Hash>,
10}
11
12impl MerkleTree {
13 pub fn new(leaf_nodes: Vec<blake3::Hash>) -> Result<Self, DecdsError> {
24 if leaf_nodes.is_empty() {
25 return Err(DecdsError::NoLeafNodesToBuildMerkleTreeOn);
26 }
27
28 let mut zero_hash = blake3::Hash::from_bytes([0u8; 32]);
29 let mut current_level = VecDeque::from(leaf_nodes.clone());
30
31 while current_level.len() > 1 {
32 let mut parent_level = VecDeque::new();
33
34 while !current_level.is_empty() {
35 let left = unsafe { current_level.pop_front().unwrap_unchecked() };
36 let right = current_level.pop_front().unwrap_or(zero_hash);
37
38 let parent = Self::parent_hash(left.as_bytes(), right.as_bytes());
39 parent_level.push_back(parent);
40 }
41
42 zero_hash = blake3::Hasher::new().update(zero_hash.as_bytes()).update(zero_hash.as_bytes()).finalize();
43 current_level = parent_level;
44 }
45
46 Ok(MerkleTree {
47 root: unsafe { current_level.pop_front().unwrap_unchecked() },
48 leaves: leaf_nodes,
49 })
50 }
51
52 pub fn get_root_commitment(&self) -> blake3::Hash {
58 self.root
59 }
60
61 pub fn generate_proof(&self, leaf_index: usize) -> Result<Vec<blake3::Hash>, DecdsError> {
76 if leaf_index >= self.leaves.len() {
77 return Err(DecdsError::InvalidLeafNodeIndex(leaf_index, self.leaves.len()));
78 }
79
80 let num_leaf_nodes = self.leaves.len();
81 let proof_size = num_leaf_nodes.next_power_of_two().ilog2() as usize;
82
83 let mut proof = Vec::with_capacity(proof_size);
84
85 let mut current_level: VecDeque<blake3::Hash> = self.leaves.clone().into();
86 let mut current_index = leaf_index;
87
88 let mut zero_hash = blake3::Hash::from_bytes([0u8; 32]);
89
90 while current_level.len() > 1 {
91 let mut parent_level = VecDeque::new();
92 let mut i = 0;
93
94 while i < current_level.len() {
95 let left = current_level[i];
96 let right = *current_level.get(i + 1).unwrap_or(&zero_hash);
97 let parent = Self::parent_hash(left.as_bytes(), right.as_bytes());
98
99 if current_index == i {
100 proof.push(right);
101 } else if current_index == i + 1 {
102 proof.push(left);
103 }
104
105 parent_level.push_back(parent);
106 i += 2;
107 }
108
109 current_index /= 2;
110 current_level = parent_level;
111
112 zero_hash = Self::parent_hash(zero_hash.as_bytes(), zero_hash.as_bytes());
113 }
114
115 Ok(proof)
116 }
117
118 pub fn verify_proof(leaf_index: usize, leaf_node: blake3::Hash, proof: &[blake3::Hash], root_hash: blake3::Hash) -> bool {
132 let mut current_hash = leaf_node;
133 let mut current_index = leaf_index;
134
135 for sibling_hash in proof {
136 current_hash = if current_index & 1 == 0 {
137 Self::parent_hash(current_hash.as_bytes(), sibling_hash.as_bytes())
138 } else {
139 Self::parent_hash(sibling_hash.as_bytes(), current_hash.as_bytes())
140 };
141
142 current_index /= 2;
143 }
144
145 current_hash == root_hash
146 }
147
148 fn parent_hash(left: &[u8], right: &[u8]) -> blake3::Hash {
159 blake3::Hasher::new().update(left).update(right).finalize()
160 }
161}
162
#[cfg(test)]
pub mod tests {
    use crate::{errors::DecdsError, merkle_tree::MerkleTree};
    use rand::Rng;

    /// Produces `leaf_count` leaf hashes, each the BLAKE3 digest of `blake3::OUT_LEN` random bytes.
    fn generate_random_leaf_hashes<R: Rng + ?Sized>(leaf_count: usize, rng: &mut R) -> Vec<blake3::Hash> {
        let mut leaf_nodes = Vec::with_capacity(leaf_count);

        (0..leaf_count).for_each(|_| {
            let random_input = (0..blake3::OUT_LEN).map(|_| rng.random()).collect::<Vec<u8>>();
            leaf_nodes.push(blake3::hash(&random_input));
        });

        leaf_nodes
    }

    /// Returns `byte` with bit `bit_idx` toggled. `bit_idx` is expected to be below 8.
    pub fn flip_a_bit(byte: u8, bit_idx: usize) -> u8 {
        byte ^ (1u8 << bit_idx)
    }

    /// Flips one randomly chosen bit of one randomly chosen hash in `proof`.
    /// An empty proof is returned unchanged.
    fn flip_a_single_bit_in_merkle_proof<R: Rng + ?Sized>(mut proof: Vec<blake3::Hash>, rng: &mut R) -> Vec<blake3::Hash> {
        if proof.is_empty() {
            return proof;
        }

        let random_node_index = rng.random_range(0..proof.len());
        let random_byte_index = rng.random_range(0..blake3::OUT_LEN);
        let random_bit_index = rng.random_range(0..u8::BITS) as usize;

        let mut bytes = [0u8; blake3::OUT_LEN];
        bytes.copy_from_slice(proof[random_node_index].as_bytes());

        bytes[random_byte_index] = flip_a_bit(bytes[random_byte_index], random_bit_index);

        proof[random_node_index] = blake3::Hash::from_bytes(bytes);
        proof
    }

    /// For random trees, every honest proof verifies and every single-bit-tampered proof fails.
    #[test]
    fn prop_test_merkle_tree_operations() {
        const NUM_TEST_ITERATIONS: usize = 10;

        const MIN_LEAF_NODE_COUNT: usize = 1;
        const MAX_LEAF_NODE_COUNT: usize = 10_000;

        let mut rng = rand::rng();

        (0..NUM_TEST_ITERATIONS).for_each(|_| {
            let leaf_count = rng.random_range(MIN_LEAF_NODE_COUNT..=MAX_LEAF_NODE_COUNT);
            let leaf_nodes = generate_random_leaf_hashes(leaf_count, &mut rng);

            let merkle_tree = MerkleTree::new(leaf_nodes.clone()).expect("Must be able to build Merkle Tree");
            let root_hash = merkle_tree.get_root_commitment();

            leaf_nodes.iter().enumerate().for_each(|(leaf_index, &leaf_node)| {
                let mut merkle_proof = merkle_tree.generate_proof(leaf_index).expect("Must be able to generate Merkle Proof");

                let is_valid = MerkleTree::verify_proof(leaf_index, leaf_node, &merkle_proof, root_hash);
                assert!(is_valid);

                merkle_proof = flip_a_single_bit_in_merkle_proof(merkle_proof, &mut rng);

                let is_valid = MerkleTree::verify_proof(leaf_index, leaf_node, &merkle_proof, root_hash);
                assert!(!is_valid);
            });
        });
    }

    #[test]
    fn test_new_with_empty_leaf_nodes() {
        let leaf_nodes: Vec<blake3::Hash> = Vec::new();
        assert!(MerkleTree::new(leaf_nodes).is_err());
    }

    #[test]
    fn test_new_with_single_leaf_node() {
        let leaf_nodes = vec![blake3::hash(b"hello")];
        let merkle_tree = MerkleTree::new(leaf_nodes.clone()).expect("Must be able to build Merkle Tree");
        assert_eq!(merkle_tree.get_root_commitment(), leaf_nodes[0]);
    }

    #[test]
    fn test_new_with_two_leaf_nodes() {
        let leaf1 = blake3::hash(b"hello");
        let leaf2 = blake3::hash(b"world");
        let leaf_nodes = vec![leaf1, leaf2];

        let merkle_tree = MerkleTree::new(leaf_nodes.clone()).expect("Must be able to build Merkle Tree");
        let expected_root = MerkleTree::parent_hash(leaf1.as_bytes(), leaf2.as_bytes());

        assert_eq!(merkle_tree.get_root_commitment(), expected_root);
    }

    /// An odd trailing leaf is paired with the all-zero hash, both in the root and in its proof.
    #[test]
    fn test_new_with_three_leaf_nodes_pads_with_zero_hash() {
        let leaf1 = blake3::hash(b"one");
        let leaf2 = blake3::hash(b"two");
        let leaf3 = blake3::hash(b"three");
        let zero_hash = blake3::Hash::from_bytes([0u8; blake3::OUT_LEN]);

        let merkle_tree = MerkleTree::new(vec![leaf1, leaf2, leaf3]).expect("Must be able to build Merkle Tree");

        let left = MerkleTree::parent_hash(leaf1.as_bytes(), leaf2.as_bytes());
        let right = MerkleTree::parent_hash(leaf3.as_bytes(), zero_hash.as_bytes());
        let expected_root = MerkleTree::parent_hash(left.as_bytes(), right.as_bytes());
        assert_eq!(merkle_tree.get_root_commitment(), expected_root);

        let proof = merkle_tree.generate_proof(2).expect("Proof for last leaf failed");
        assert_eq!(proof, vec![zero_hash, left]);
        assert!(MerkleTree::verify_proof(2, leaf3, &proof, expected_root));
    }

    /// Padding above the leaf level is the hash of the previous level's padding with itself.
    #[test]
    fn test_new_with_five_leaf_nodes_uses_level_specific_padding() {
        let parent = |left: blake3::Hash, right: blake3::Hash| MerkleTree::parent_hash(left.as_bytes(), right.as_bytes());

        let leaves = generate_random_leaf_hashes(5, &mut rand::rng());
        let zero0 = blake3::Hash::from_bytes([0u8; blake3::OUT_LEN]);
        let zero1 = parent(zero0, zero0);

        let p0 = parent(leaves[0], leaves[1]);
        let p1 = parent(leaves[2], leaves[3]);
        let p2 = parent(leaves[4], zero0);
        let q0 = parent(p0, p1);
        let q1 = parent(p2, zero1);
        let expected_root = parent(q0, q1);

        let merkle_tree = MerkleTree::new(leaves.clone()).expect("Must be able to build Merkle Tree");
        assert_eq!(merkle_tree.get_root_commitment(), expected_root);

        let proof = merkle_tree.generate_proof(4).expect("Proof for last leaf failed");
        assert_eq!(proof, vec![zero0, zero1, q0]);
        assert!(MerkleTree::verify_proof(4, leaves[4], &proof, expected_root));
    }

    #[test]
    fn test_generate_proof_out_of_bounds() {
        let num_leaves = 5;
        let leaf_nodes = generate_random_leaf_hashes(num_leaves, &mut rand::rng());
        let merkle_tree = MerkleTree::new(leaf_nodes).expect("Must be able to build Merkle Tree");

        assert_eq!(merkle_tree.generate_proof(5), Err(DecdsError::InvalidLeafNodeIndex(5, num_leaves)));
        assert_eq!(merkle_tree.generate_proof(100), Err(DecdsError::InvalidLeafNodeIndex(100, num_leaves)));
    }

    #[test]
    fn test_generate_proof_single_leaf_node() {
        let leaf_node = blake3::hash(b"single");
        let leaf_nodes = vec![leaf_node];
        let merkle_tree = MerkleTree::new(leaf_nodes).expect("Must be able to build Merkle Tree");

        let proof = merkle_tree.generate_proof(0).expect("Proof generation failed");
        assert!(proof.is_empty());
    }

    #[test]
    fn test_verify_proof_single_leaf_node() {
        let leaf_node = blake3::hash(b"single_leaf");
        let leaf_nodes = vec![leaf_node];
        let merkle_tree = MerkleTree::new(leaf_nodes).expect("Must be able to build Merkle Tree");
        let root_hash = merkle_tree.get_root_commitment();

        let proof = merkle_tree.generate_proof(0).expect("Proof generation failed");
        assert!(proof.is_empty());

        let is_valid = MerkleTree::verify_proof(0, leaf_node, &proof, root_hash);
        assert!(is_valid);

        let tampered_leaf = blake3::hash(b"tampered");
        let is_valid_tampered = MerkleTree::verify_proof(0, tampered_leaf, &proof, root_hash);
        assert!(!is_valid_tampered);
    }

    #[test]
    fn test_generate_and_verify_proof_for_two_leaf_nodes() {
        let leaf1 = blake3::hash(b"first");
        let leaf2 = blake3::hash(b"second");
        let leaf_nodes = vec![leaf1, leaf2];
        let merkle_tree = MerkleTree::new(leaf_nodes.clone()).expect("Must be able to build Merkle Tree");
        let root_hash = merkle_tree.get_root_commitment();

        let proof1 = merkle_tree.generate_proof(0).expect("Proof for leaf1 failed");
        assert_eq!(proof1.len(), 1);
        assert_eq!(proof1[0], leaf2);
        assert!(MerkleTree::verify_proof(0, leaf1, &proof1, root_hash));

        let proof2 = merkle_tree.generate_proof(1).expect("Proof for leaf2 failed");
        assert_eq!(proof2.len(), 1);
        assert_eq!(proof2[0], leaf1);
        assert!(MerkleTree::verify_proof(1, leaf2, &proof2, root_hash));

        let tampered_proof1 = vec![blake3::hash(b"fake_sibling")];
        assert!(!MerkleTree::verify_proof(0, leaf1, &tampered_proof1, root_hash));

        let tampered_leaf1 = blake3::hash(b"tampered_first");
        assert!(!MerkleTree::verify_proof(0, tampered_leaf1, &proof1, root_hash));
    }
}
325}