1use crate::{SyncError, SyncResult};
8use serde::{Deserialize, Serialize};
9
/// A 32-byte digest. Within this module every hash is produced by BLAKE3
/// (see `MerkleTree::hash_data`).
pub type Hash = [u8; 32];
12
/// A node of the in-memory Merkle tree.
///
/// Each node stores its own hash, so it can be read (via `MerkleNode::hash`)
/// without recomputing anything from the subtree.
///
/// NOTE(review): serde is derived here, so variant and field order presumably
/// form part of the serialized representation — confirm with consumers before
/// reordering anything.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
enum MerkleNode {
    /// A leaf holding the hash of one data block.
    Leaf {
        /// BLAKE3 hash of the block's bytes.
        hash: Hash,
    },
    /// An interior node with exactly two children.
    Internal {
        /// Hash of the concatenation `left.hash() || right.hash()`.
        hash: Hash,
        /// Left child; covers the earlier leaves of this subtree.
        left: Box<MerkleNode>,
        /// Right child; covers the later leaves of this subtree.
        right: Box<MerkleNode>,
    },
}
31
32impl MerkleNode {
33 fn hash(&self) -> &Hash {
35 match self {
36 MerkleNode::Leaf { hash } => hash,
37 MerkleNode::Internal { hash, .. } => hash,
38 }
39 }
40}
41
/// A binary Merkle tree over an ordered sequence of data blocks, hashed with
/// BLAKE3.
///
/// Leaves are paired left to right, level by level; when a level has an odd
/// number of nodes, the last one is carried up to the next level unchanged.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MerkleTree {
    /// Root node, or `None` for an empty tree.
    root: Option<MerkleNode>,
    /// Number of data blocks (leaves) the tree was built from.
    leaf_count: usize,
}
69
70impl MerkleTree {
71 pub fn new() -> Self {
73 Self {
74 root: None,
75 leaf_count: 0,
76 }
77 }
78
79 pub fn from_data(data: Vec<Vec<u8>>) -> SyncResult<Self> {
89 if data.is_empty() {
90 return Ok(Self::new());
91 }
92
93 let leaf_count = data.len();
94 let leaves: Vec<MerkleNode> = data
95 .into_iter()
96 .map(|block| {
97 let hash = Self::hash_data(&block);
98 MerkleNode::Leaf { hash }
99 })
100 .collect();
101
102 let root = Self::build_tree(leaves)?;
103
104 Ok(Self {
105 root: Some(root),
106 leaf_count,
107 })
108 }
109
110 fn build_tree(mut nodes: Vec<MerkleNode>) -> SyncResult<MerkleNode> {
112 if nodes.is_empty() {
113 return Err(SyncError::MerkleVerificationFailed(
114 "Cannot build tree from empty nodes".to_string(),
115 ));
116 }
117
118 while nodes.len() > 1 {
119 let mut next_level = Vec::new();
120
121 for chunk in nodes.chunks(2) {
122 match chunk {
123 [left, right] => {
124 let combined = Self::combine_hashes(left.hash(), right.hash());
125 next_level.push(MerkleNode::Internal {
126 hash: combined,
127 left: Box::new(left.clone()),
128 right: Box::new(right.clone()),
129 });
130 }
131 [single] => {
132 next_level.push(single.clone());
134 }
135 _ => unreachable!(),
136 }
137 }
138
139 nodes = next_level;
140 }
141
142 nodes
143 .into_iter()
144 .next()
145 .ok_or_else(|| SyncError::MerkleVerificationFailed("Failed to build tree".to_string()))
146 }
147
148 fn hash_data(data: &[u8]) -> Hash {
150 let hash = blake3::hash(data);
151 *hash.as_bytes()
152 }
153
154 fn combine_hashes(left: &Hash, right: &Hash) -> Hash {
156 let mut combined = Vec::with_capacity(64);
157 combined.extend_from_slice(left);
158 combined.extend_from_slice(right);
159 Self::hash_data(&combined)
160 }
161
162 pub fn root_hash(&self) -> Option<&Hash> {
168 self.root.as_ref().map(|node| node.hash())
169 }
170
171 pub fn leaf_count(&self) -> usize {
173 self.leaf_count
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.root.is_none()
179 }
180
181 pub fn diff(&self, other: &MerkleTree) -> Vec<usize> {
191 let mut differences = Vec::new();
192
193 if let (Some(self_root), Some(other_root)) = (&self.root, &other.root) {
194 Self::diff_nodes(self_root, other_root, 0, &mut differences);
195 }
196
197 differences
198 }
199
200 fn diff_nodes(
202 self_node: &MerkleNode,
203 other_node: &MerkleNode,
204 index: usize,
205 differences: &mut Vec<usize>,
206 ) {
207 if self_node.hash() == other_node.hash() {
208 return;
210 }
211
212 match (self_node, other_node) {
213 (MerkleNode::Leaf { .. }, MerkleNode::Leaf { .. }) => {
214 differences.push(index);
215 }
216 (
217 MerkleNode::Internal {
218 left: l1,
219 right: r1,
220 ..
221 },
222 MerkleNode::Internal {
223 left: l2,
224 right: r2,
225 ..
226 },
227 ) => {
228 Self::diff_nodes(l1, l2, index * 2, differences);
229 Self::diff_nodes(r1, r2, index * 2 + 1, differences);
230 }
231 _ => {
232 differences.push(index);
234 }
235 }
236 }
237
238 pub fn verify(&self, data: &[Vec<u8>]) -> SyncResult<bool> {
248 if data.len() != self.leaf_count {
249 return Ok(false);
250 }
251
252 let verification_tree = Self::from_data(data.to_vec())?;
253
254 Ok(self.root_hash() == verification_tree.root_hash())
255 }
256
257 pub fn get_proof(&self, index: usize) -> SyncResult<Vec<Hash>> {
269 if index >= self.leaf_count {
270 return Err(SyncError::MerkleVerificationFailed(
271 "Index out of bounds".to_string(),
272 ));
273 }
274
275 let mut proof = Vec::new();
276
277 if let Some(root) = &self.root {
278 Self::collect_proof(root, index, &mut proof)?;
279 }
280
281 Ok(proof)
282 }
283
284 fn collect_proof(node: &MerkleNode, index: usize, proof: &mut Vec<Hash>) -> SyncResult<bool> {
286 match node {
287 MerkleNode::Leaf { .. } => Ok(true),
288 MerkleNode::Internal { left, right, .. } => {
289 let left_leaves = Self::count_leaves(left);
290
291 if index < left_leaves {
292 proof.push(*right.hash());
294 Self::collect_proof(left, index, proof)
295 } else {
296 proof.push(*left.hash());
298 Self::collect_proof(right, index - left_leaves, proof)
299 }
300 }
301 }
302 }
303
304 fn count_leaves(node: &MerkleNode) -> usize {
306 match node {
307 MerkleNode::Leaf { .. } => 1,
308 MerkleNode::Internal { left, right, .. } => {
309 Self::count_leaves(left) + Self::count_leaves(right)
310 }
311 }
312 }
313
314 pub fn verify_proof(&self, leaf_hash: &Hash, proof: &[Hash], index: usize) -> bool {
326 let mut current_hash = *leaf_hash;
327 let mut current_index = index;
328
329 for sibling_hash in proof.iter().rev() {
331 if current_index % 2 == 0 {
332 current_hash = Self::combine_hashes(¤t_hash, sibling_hash);
334 } else {
335 current_hash = Self::combine_hashes(sibling_hash, ¤t_hash);
337 }
338 current_index /= 2;
339 }
340
341 self.root_hash() == Some(¤t_hash)
342 }
343}
344
345impl Default for MerkleTree {
346 fn default() -> Self {
347 Self::new()
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` distinct blocks: `block1`, `block2`, ...
    fn blocks(n: usize) -> Vec<Vec<u8>> {
        (1..=n).map(|i| format!("block{}", i).into_bytes()).collect()
    }

    #[test]
    fn test_merkle_tree_creation() {
        let tree = MerkleTree::new();
        assert!(tree.is_empty());
        assert_eq!(tree.leaf_count(), 0);
        assert!(tree.root_hash().is_none());
    }

    #[test]
    fn test_merkle_tree_from_data() -> SyncResult<()> {
        let tree = MerkleTree::from_data(blocks(4))?;
        assert!(!tree.is_empty());
        assert_eq!(tree.leaf_count(), 4);
        assert!(tree.root_hash().is_some());

        // Same input always yields the same root.
        let again = MerkleTree::from_data(blocks(4))?;
        assert_eq!(tree.root_hash(), again.root_hash());

        // No input yields an empty tree.
        assert!(MerkleTree::from_data(Vec::new())?.is_empty());

        Ok(())
    }

    #[test]
    fn test_merkle_tree_verify() -> SyncResult<()> {
        let data = blocks(3);

        let tree = MerkleTree::from_data(data.clone())?;
        assert!(tree.verify(&data)?);

        // Changing one block breaks verification.
        let mut modified_data = data.clone();
        modified_data[1] = b"modified".to_vec();
        assert!(!tree.verify(&modified_data)?);

        // A different number of blocks never verifies.
        assert!(!tree.verify(&data[..2])?);

        Ok(())
    }

    #[test]
    fn test_merkle_tree_diff() -> SyncResult<()> {
        let data1 = blocks(4);
        let mut data2 = data1.clone();
        data2[1] = b"modified".to_vec();

        let tree1 = MerkleTree::from_data(data1.clone())?;
        let tree2 = MerkleTree::from_data(data2)?;

        // Exactly the modified block is reported.
        assert_eq!(tree1.diff(&tree2), vec![1]);

        // Identical trees have no differences.
        assert!(tree1.diff(&MerkleTree::from_data(data1)?).is_empty());

        Ok(())
    }

    #[test]
    fn test_merkle_tree_proof() -> SyncResult<()> {
        let data = blocks(4);
        let tree = MerkleTree::from_data(data.clone())?;

        // Every leaf of a perfect 4-leaf tree has a 2-entry proof that verifies.
        for (index, block) in data.iter().enumerate() {
            let leaf_hash = MerkleTree::hash_data(block);
            let proof = tree.get_proof(index)?;
            assert_eq!(proof.len(), 2);
            assert!(tree.verify_proof(&leaf_hash, &proof, index));
        }

        // A wrong leaf hash is rejected.
        let proof = tree.get_proof(0)?;
        assert!(!tree.verify_proof(&MerkleTree::hash_data(b"other"), &proof, 0));

        // An out-of-range index is an error.
        assert!(tree.get_proof(4).is_err());

        Ok(())
    }
}