Skip to main content

microscope_memory/
merkle.rs

1//! Merkle tree for whole-index integrity verification.
2//!
3//! Leaf = SHA-256(block_data). Parent = SHA-256(left || right).
//! At every level with an odd number of nodes, the last node is paired with itself:
//! parent = SHA-256(node || node). It is not carried up unchanged.
5
6use sha2::{Digest, Sha256};
7
/// Merkle tree over an ordered list of data blocks, stored as a flat node array.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MerkleTree {
    /// All nodes: leaves first, then internal nodes level by level bottom-up. Root is last.
    pub nodes: Vec<[u8; 32]>,
    /// Number of leaves; the first `leaf_count` entries of `nodes` are the leaf hashes.
    pub leaf_count: usize,
    /// Root hash (a copy of the last entry of `nodes`).
    pub root: [u8; 32],
}
14
15impl MerkleTree {
16    /// Build Merkle tree from block data slices.
17    pub fn build(leaves: &[&[u8]]) -> Self {
18        assert!(!leaves.is_empty(), "cannot build Merkle tree from 0 leaves");
19
20        // Hash each leaf
21        let leaf_hashes: Vec<[u8; 32]> = leaves
22            .iter()
23            .map(|data| {
24                let mut h = Sha256::new();
25                h.update(data);
26                h.finalize().into()
27            })
28            .collect();
29
30        let leaf_count = leaf_hashes.len();
31
32        // Build tree bottom-up
33        // nodes layout: [leaf_0 .. leaf_n-1, internal_nodes..., root]
34        let mut nodes = leaf_hashes;
35
36        let mut level_start = 0;
37        let mut level_len = leaf_count;
38
39        while level_len > 1 {
40            let pairs = level_len / 2;
41            let odd = level_len % 2 == 1;
42
43            for p in 0..pairs {
44                let left = &nodes[level_start + p * 2];
45                let right = &nodes[level_start + p * 2 + 1];
46                let mut h = Sha256::new();
47                h.update(left);
48                h.update(right);
49                let parent: [u8; 32] = h.finalize().into();
50                nodes.push(parent);
51            }
52
53            // Odd node: hash with itself (promoted)
54            if odd {
55                let lone = &nodes[level_start + level_len - 1];
56                let mut h = Sha256::new();
57                h.update(lone);
58                h.update(lone);
59                let parent: [u8; 32] = h.finalize().into();
60                nodes.push(parent);
61            }
62
63            level_start += level_len;
64            level_len = pairs + if odd { 1 } else { 0 };
65        }
66
67        let root = *nodes.last().unwrap();
68        MerkleTree {
69            nodes,
70            leaf_count,
71            root,
72        }
73    }
74
75    /// Verify a single leaf at index against stored hash.
76    pub fn verify_leaf(&self, index: usize, data: &[u8]) -> bool {
77        if index >= self.leaf_count {
78            return false;
79        }
80        let mut h = Sha256::new();
81        h.update(data);
82        let computed: [u8; 32] = h.finalize().into();
83        self.nodes[index] == computed
84    }
85
86    /// Get Merkle proof path for a leaf.
87    /// Returns Vec of (sibling_hash, is_right) pairs from leaf to root.
88    pub fn proof(&self, index: usize) -> Vec<([u8; 32], bool)> {
89        if index >= self.leaf_count {
90            return vec![];
91        }
92
93        let mut path = Vec::new();
94        let mut level_start = 0;
95        let mut level_len = self.leaf_count;
96        let mut pos = index;
97
98        while level_len > 1 {
99            let sibling_pos = if pos.is_multiple_of(2) {
100                // We are left child, sibling is right
101                if pos + 1 < level_len {
102                    pos + 1
103                } else {
104                    pos // odd leaf, paired with itself
105                }
106            } else {
107                // We are right child, sibling is left
108                pos - 1
109            };
110
111            let is_right = pos % 2 == 1;
112            path.push((self.nodes[level_start + sibling_pos], is_right));
113
114            level_start += level_len;
115            level_len = level_len.div_ceil(2);
116            pos /= 2;
117        }
118
119        path
120    }
121
122    /// Verify a Merkle proof against a known root.
123    pub fn verify_proof(root: &[u8; 32], leaf_data: &[u8], proof: &[([u8; 32], bool)]) -> bool {
124        let mut h = Sha256::new();
125        h.update(leaf_data);
126        let mut current: [u8; 32] = h.finalize().into();
127
128        for &(ref sibling, is_right) in proof {
129            let mut h = Sha256::new();
130            if is_right {
131                // Current node is right child → sibling is left
132                h.update(sibling);
133                h.update(current);
134            } else {
135                // Current node is left child → sibling is right
136                h.update(current);
137                h.update(sibling);
138            }
139            current = h.finalize().into();
140        }
141
142        current == *root
143    }
144
145    /// Serialize tree to bytes.
146    /// Format: [u32 leaf_count][u32 node_count][nodes: 32 bytes each]
147    pub fn to_bytes(&self) -> Vec<u8> {
148        let mut buf = Vec::with_capacity(8 + self.nodes.len() * 32);
149        buf.extend_from_slice(&(self.leaf_count as u32).to_le_bytes());
150        buf.extend_from_slice(&(self.nodes.len() as u32).to_le_bytes());
151        for node in &self.nodes {
152            buf.extend_from_slice(node);
153        }
154        buf
155    }
156
157    /// Deserialize from bytes.
158    pub fn from_bytes(data: &[u8]) -> Option<Self> {
159        if data.len() < 8 {
160            return None;
161        }
162        let leaf_count = u32::from_le_bytes(data[0..4].try_into().ok()?) as usize;
163        let node_count = u32::from_le_bytes(data[4..8].try_into().ok()?) as usize;
164        if data.len() < 8 + node_count * 32 {
165            return None;
166        }
167        if leaf_count == 0 || node_count == 0 {
168            return None;
169        }
170
171        let mut nodes = Vec::with_capacity(node_count);
172        for i in 0..node_count {
173            let offset = 8 + i * 32;
174            let mut hash = [0u8; 32];
175            hash.copy_from_slice(&data[offset..offset + 32]);
176            nodes.push(hash);
177        }
178
179        let root = *nodes.last()?;
180        Some(MerkleTree {
181            nodes,
182            leaf_count,
183            root,
184        })
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// `count` distinct leaves; leaf `i` is `len` copies of the byte `i`.
    fn make_leaves(count: u8, len: usize) -> Vec<Vec<u8>> {
        (0..count).map(|i| vec![i; len]).collect()
    }

    /// Borrow each owned leaf as a slice, the shape `MerkleTree::build` takes.
    fn as_refs(leaves: &[Vec<u8>]) -> Vec<&[u8]> {
        leaves.iter().map(Vec::as_slice).collect()
    }

    #[test]
    fn test_build_single_leaf() {
        let data = b"hello";
        let tree = MerkleTree::build(&[data.as_slice()]);
        assert_eq!(tree.leaf_count, 1);
        assert_eq!(tree.nodes.len(), 1);
        assert_eq!(tree.root, tree.nodes[0]);
        assert!(tree.verify_leaf(0, data));
        assert!(!tree.verify_leaf(0, b"world"));
    }

    #[test]
    fn test_build_two_leaves() {
        let a = b"hello";
        let b = b"world";
        let tree = MerkleTree::build(&[a.as_slice(), b.as_slice()]);
        assert_eq!(tree.leaf_count, 2);
        assert!(tree.verify_leaf(0, a));
        assert!(tree.verify_leaf(1, b));
        assert!(!tree.verify_leaf(0, b));
    }

    #[test]
    fn test_build_odd_leaves() {
        let leaves = make_leaves(5, 10);
        let tree = MerkleTree::build(&as_refs(&leaves));
        assert_eq!(tree.leaf_count, 5);
        // Levels of 5, 3, 2 and 1 nodes.
        assert_eq!(tree.nodes.len(), 5 + 3 + 2 + 1);
        assert_eq!(tree.root, *tree.nodes.last().unwrap());
        for (i, leaf) in leaves.iter().enumerate() {
            assert!(tree.verify_leaf(i, leaf));
        }
    }

    #[test]
    fn test_out_of_range_index() {
        let leaves = make_leaves(3, 4);
        let tree = MerkleTree::build(&as_refs(&leaves));
        assert!(!tree.verify_leaf(3, &leaves[0]));
        assert!(tree.proof(3).is_empty());
    }

    #[test]
    fn test_proof_and_verify() {
        let leaves = make_leaves(8, 20);
        let tree = MerkleTree::build(&as_refs(&leaves));

        for (i, leaf) in leaves.iter().enumerate() {
            let proof = tree.proof(i);
            assert!(
                MerkleTree::verify_proof(&tree.root, leaf, &proof),
                "proof failed for leaf {}",
                i
            );
        }
    }

    #[test]
    fn test_proof_and_verify_non_power_of_two() {
        // Exercises the self-pairing of the lone node at odd-sized levels.
        for count in 1..=11u8 {
            let leaves = make_leaves(count, 7);
            let tree = MerkleTree::build(&as_refs(&leaves));
            for (i, leaf) in leaves.iter().enumerate() {
                let proof = tree.proof(i);
                assert!(
                    MerkleTree::verify_proof(&tree.root, leaf, &proof),
                    "proof failed for leaf {i} of {count}"
                );
                assert!(
                    !MerkleTree::verify_proof(&tree.root, b"tampered", &proof),
                    "tampered data verified for leaf {i} of {count}"
                );
            }
        }
    }

    #[test]
    fn test_single_leaf_proof_is_empty() {
        let data = b"only";
        let tree = MerkleTree::build(&[data.as_slice()]);
        let proof = tree.proof(0);
        assert!(proof.is_empty());
        assert!(MerkleTree::verify_proof(&tree.root, data, &proof));
    }

    #[test]
    fn test_proof_shape_odd_levels() {
        let leaves = make_leaves(5, 3);
        let tree = MerkleTree::build(&as_refs(&leaves));
        let proof = tree.proof(4);
        // 5 -> 3 -> 2 -> 1: one sibling per level below the root.
        assert_eq!(proof.len(), 3);
        // Leaf 4 is the lone node at the leaf level, so its first sibling is itself.
        assert_eq!(proof[0], (tree.nodes[4], false));
    }

    #[test]
    fn test_proof_fails_on_tamper() {
        let leaves = make_leaves(4, 15);
        let tree = MerkleTree::build(&as_refs(&leaves));

        let proof = tree.proof(0);
        let tampered = vec![99u8; 15];
        assert!(!MerkleTree::verify_proof(&tree.root, &tampered, &proof));
    }

    #[test]
    fn test_serialization_roundtrip() {
        let leaves = make_leaves(10, 30);
        let tree = MerkleTree::build(&as_refs(&leaves));

        let bytes = tree.to_bytes();
        assert_eq!(bytes.len(), 8 + tree.nodes.len() * 32);
        let restored = MerkleTree::from_bytes(&bytes).expect("deserialize");

        assert_eq!(restored.leaf_count, tree.leaf_count);
        assert_eq!(restored.root, tree.root);
        assert_eq!(restored.nodes, tree.nodes);
        for (i, leaf) in leaves.iter().enumerate() {
            assert!(restored.verify_leaf(i, leaf));
        }
    }

    #[test]
    fn test_from_bytes_rejects_malformed() {
        assert!(MerkleTree::from_bytes(&[]).is_none());
        assert!(MerkleTree::from_bytes(&[1, 0, 0, 0]).is_none());
        // Header with zero counts.
        assert!(MerkleTree::from_bytes(&[0u8; 8]).is_none());

        let leaves = make_leaves(4, 5);
        let bytes = MerkleTree::build(&as_refs(&leaves)).to_bytes();
        // Body one byte short of the declared node count.
        assert!(MerkleTree::from_bytes(&bytes[..bytes.len() - 1]).is_none());
    }

    #[test]
    fn test_root_changes_on_reorder() {
        let a = b"first";
        let b = b"second";
        let tree1 = MerkleTree::build(&[a.as_slice(), b.as_slice()]);
        let tree2 = MerkleTree::build(&[b.as_slice(), a.as_slice()]);
        assert_ne!(tree1.root, tree2.root, "reordering must change root");
    }
}