Skip to main content

irontide_core/
merkle.rs

//! Merkle hash tree for `BitTorrent` v2 (BEP 52).
//!
//! Uses a binary heap layout stored as a flat `Vec<Id32>`. Leaves are padded
//! to a power of two with zero hashes. Internal nodes are `SHA-256(left || right)`.
//!
//! The piece layer is the tree layer whose nodes correspond to piece-sized chunks
//! of the file, used for initial verification. Below that, individual 16 KiB blocks
//! can be verified via Merkle proof paths.
9
10use crate::hash::Id32;
11
12/// A complete binary Merkle tree stored in a flat array (1-indexed heap layout).
13///
14/// Index 1 is the root. For a node at index `i`, its left child is at `2*i`
15/// and its right child is at `2*i + 1`.
16#[derive(Debug, Clone)]
17pub struct MerkleTree {
18    /// 1-indexed array: nodes[0] is unused, nodes[1] is root.
19    nodes: Vec<Id32>,
20    /// Number of actual (non-padding) leaves.
21    leaf_count: usize,
22    /// Total leaves including padding (always a power of 2).
23    padded_leaf_count: usize,
24}
25
26impl MerkleTree {
27    /// Build a Merkle tree from leaf hashes.
28    ///
29    /// Pads to the next power of two with zero hashes, then builds bottom-up
30    /// by hashing pairs: `parent = SHA-256(left || right)`.
31    #[must_use]
32    pub fn from_leaves(leaves: &[Id32]) -> Self {
33        assert!(
34            !leaves.is_empty(),
35            "cannot build Merkle tree from empty leaves"
36        );
37
38        let leaf_count = leaves.len();
39        let padded = leaf_count.next_power_of_two();
40        let total_nodes = 2 * padded; // 1-indexed: indices 1..total_nodes-1
41
42        let mut nodes = vec![Id32::ZERO; total_nodes];
43
44        // Place leaves at the bottom layer (indices padded..2*padded-1)
45        for (i, leaf) in leaves.iter().enumerate() {
46            nodes[padded + i] = *leaf;
47        }
48        // Padding leaves remain ZERO
49
50        // Build bottom-up
51        for i in (1..padded).rev() {
52            let left = nodes[2 * i];
53            let right = nodes[2 * i + 1];
54            nodes[i] = hash_pair(left, right);
55        }
56
57        Self {
58            nodes,
59            leaf_count,
60            padded_leaf_count: padded,
61        }
62    }
63
64    /// The Merkle root hash.
65    #[must_use]
66    pub fn root(&self) -> Id32 {
67        self.nodes[1]
68    }
69
70    /// Number of actual (non-padding) leaves.
71    #[must_use]
72    pub fn leaf_count(&self) -> usize {
73        self.leaf_count
74    }
75
76    /// Tree depth (number of layers below the root). A single leaf has depth 0.
77    #[must_use]
78    pub fn depth(&self) -> usize {
79        self.padded_leaf_count.trailing_zeros() as usize
80    }
81
82    /// Get all hashes at a given depth (0 = root, `depth()` = leaves).
83    ///
84    /// Returns a slice of the internal array at that tree level.
85    #[must_use]
86    pub fn layer(&self, depth: usize) -> &[Id32] {
87        let layer_size = 1usize << depth;
88        let start = layer_size; // 1-indexed: layer at depth d starts at index 2^d
89        if start + layer_size > self.nodes.len() {
90            return &[];
91        }
92        &self.nodes[start..start + layer_size]
93    }
94
95    /// Get the piece layer — the tree layer whose nodes correspond to pieces.
96    ///
97    /// `blocks_per_piece` is `piece_length / 16384` (number of 16 KiB blocks per piece).
98    /// The piece layer is at depth `depth() - log2(blocks_per_piece)`.
99    #[must_use]
100    pub fn piece_layer(&self, blocks_per_piece: usize) -> &[Id32] {
101        assert!(
102            blocks_per_piece.is_power_of_two(),
103            "blocks_per_piece must be a power of 2"
104        );
105        let levels_up = blocks_per_piece.trailing_zeros() as usize;
106        let tree_depth = self.depth();
107        if levels_up > tree_depth {
108            // Piece is larger than entire file — root is the piece hash
109            return self.layer(0);
110        }
111        self.layer(tree_depth - levels_up)
112    }
113
114    /// Get the leaf hashes.
115    #[must_use]
116    pub fn leaves(&self) -> &[Id32] {
117        let start = self.padded_leaf_count;
118        &self.nodes[start..start + self.leaf_count]
119    }
120
121    /// Compute a Merkle root from a list of hashes (e.g., verify piece layer → root).
122    ///
123    /// Pads to power of two and builds a temporary tree. Useful for validating
124    /// that a received piece layer matches a known root.
125    #[must_use]
126    pub fn root_from_hashes(hashes: &[Id32]) -> Id32 {
127        if hashes.is_empty() {
128            return Id32::ZERO;
129        }
130        Self::from_leaves(hashes).root()
131    }
132
133    /// Extract the Merkle proof path for a specific leaf.
134    ///
135    /// Returns the sibling hashes from leaf to root needed to verify the leaf
136    /// against the root. Used by M34's hash request/response (BEP 52 wire
137    /// messages 21-23) for block-level verification.
138    #[must_use]
139    pub fn proof_path(&self, leaf_index: usize) -> Vec<Id32> {
140        assert!(
141            leaf_index < self.padded_leaf_count,
142            "leaf index out of range"
143        );
144        let mut path = Vec::with_capacity(self.depth());
145        let mut idx = self.padded_leaf_count + leaf_index;
146
147        while idx > 1 {
148            // Sibling is at idx ^ 1 (flip lowest bit)
149            let sibling = idx ^ 1;
150            path.push(self.nodes[sibling]);
151            idx /= 2; // move to parent
152        }
153
154        path
155    }
156
157    /// Verify a leaf hash against a root using a proof path.
158    ///
159    /// Recomputes the root from the leaf and its sibling hashes, then compares.
160    #[must_use]
161    pub fn verify_proof(root: Id32, leaf: Id32, leaf_index: usize, proof: &[Id32]) -> bool {
162        let mut hash = leaf;
163        let mut idx = leaf_index;
164
165        for sibling in proof {
166            hash = if idx.is_multiple_of(2) {
167                // We're the left child
168                hash_pair(hash, *sibling)
169            } else {
170                // We're the right child
171                hash_pair(*sibling, hash)
172            };
173            idx /= 2;
174        }
175
176        hash == root
177    }
178}
179
180/// Hash a pair of nodes: `SHA-256(left || right)`.
181fn hash_pair(left: Id32, right: Id32) -> Id32 {
182    crate::sha256(&[left.0, right.0].concat())
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic test leaf: the SHA-256 of a single byte.
    fn leaf(byte: u8) -> Id32 {
        crate::sha256(&[byte])
    }

    #[test]
    fn single_leaf() {
        let l = leaf(0x01);
        let tree = MerkleTree::from_leaves(&[l]);
        // One leaf is already a power of two, so no padding is added and the
        // leaf itself is the root.
        assert_eq!(tree.leaf_count(), 1);
        assert_eq!(tree.depth(), 0);
        assert_eq!(tree.root(), l);
    }

    #[test]
    fn two_leaves() {
        let l0 = leaf(0x01);
        let l1 = leaf(0x02);
        let tree = MerkleTree::from_leaves(&[l0, l1]);
        assert_eq!(tree.leaf_count(), 2);
        assert_eq!(tree.depth(), 1);
        assert_eq!(tree.root(), hash_pair(l0, l1));
        assert_eq!(tree.leaves(), &[l0, l1]);
    }

    #[test]
    fn three_leaves_padded() {
        let l0 = leaf(0x01);
        let l1 = leaf(0x02);
        let l2 = leaf(0x03);
        let tree = MerkleTree::from_leaves(&[l0, l1, l2]);
        assert_eq!(tree.leaf_count(), 3);
        assert_eq!(tree.depth(), 2); // padded to 4 leaves
        // Bottom layer: l0, l1, l2, ZERO
        let left = hash_pair(l0, l1);
        let right = hash_pair(l2, Id32::ZERO);
        assert_eq!(tree.root(), hash_pair(left, right));
    }

    #[test]
    fn leaves_excludes_padding() {
        let leaves: Vec<Id32> = (0..3).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);
        // The leaf layer is 4 wide, but only the 3 supplied hashes are exposed.
        assert_eq!(tree.layer(tree.depth()).len(), 4);
        assert_eq!(tree.leaves(), leaves.as_slice());
    }

    #[test]
    fn layer_extraction() {
        let leaves: Vec<Id32> = (0..4).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);
        assert_eq!(tree.depth(), 2);

        // Layer 0 = root (1 node)
        assert_eq!(tree.layer(0).len(), 1);
        assert_eq!(tree.layer(0)[0], tree.root());

        // Layer 1 = 2 intermediate nodes
        assert_eq!(tree.layer(1).len(), 2);

        // Layer 2 = 4 leaves
        assert_eq!(tree.layer(2).len(), 4);
        assert_eq!(tree.layer(2)[0], leaves[0]);
        assert_eq!(tree.layer(2)[3], leaves[3]);
    }

    #[test]
    fn layer_out_of_range_is_empty() {
        let leaves: Vec<Id32> = (0..4).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);
        // One level below the leaves does not exist.
        assert!(tree.layer(tree.depth() + 1).is_empty());
    }

    #[test]
    fn piece_layer_extraction() {
        // 8 block-level leaves, 2 blocks per piece → piece layer at depth 2 (4 nodes)
        let leaves: Vec<Id32> = (0..8).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);
        assert_eq!(tree.depth(), 3);

        let pieces = tree.piece_layer(2); // 2 blocks per piece
        assert_eq!(pieces.len(), 4);
        // Each piece hash should be hash of 2 consecutive block hashes
        assert_eq!(pieces[0], hash_pair(leaves[0], leaves[1]));
        assert_eq!(pieces[1], hash_pair(leaves[2], leaves[3]));
    }

    #[test]
    fn piece_layer_larger_than_tree_is_root() {
        // 2 blocks in the file but 4 blocks per piece → the root is the piece hash
        let leaves: Vec<Id32> = (0..2).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);
        assert_eq!(tree.piece_layer(4), &[tree.root()]);
    }

    #[test]
    fn root_from_piece_layer_round_trip() {
        let leaves: Vec<Id32> = (0..8).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);
        let pieces = tree.piece_layer(2);
        // Rebuilding from piece layer should give the same root
        let rebuilt_root = MerkleTree::root_from_hashes(pieces);
        assert_eq!(rebuilt_root, tree.root());
    }

    #[test]
    fn root_from_hashes_empty_is_zero() {
        assert_eq!(MerkleTree::root_from_hashes(&[]), Id32::ZERO);
    }

    #[test]
    fn proof_path_generation() {
        let leaves: Vec<Id32> = (0..4).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);

        let proof = tree.proof_path(0);
        assert_eq!(proof.len(), 2); // depth = 2, so 2 siblings
        // First sibling is leaf[1], second is hash(leaf[2], leaf[3])
        assert_eq!(proof[0], leaves[1]);
        assert_eq!(proof[1], hash_pair(leaves[2], leaves[3]));
    }

    #[test]
    fn proof_verification_success() {
        let leaves: Vec<Id32> = (0..4).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);

        for (i, &leaf_hash) in leaves.iter().enumerate() {
            let proof = tree.proof_path(i);
            assert!(
                MerkleTree::verify_proof(tree.root(), leaf_hash, i, &proof),
                "proof failed for leaf {i}"
            );
        }
    }

    #[test]
    fn proof_verification_padded_tree() {
        let leaves: Vec<Id32> = (0..3).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);

        for (i, &leaf_hash) in leaves.iter().enumerate() {
            let proof = tree.proof_path(i);
            assert!(
                MerkleTree::verify_proof(tree.root(), leaf_hash, i, &proof),
                "proof failed for leaf {i}"
            );
        }

        // The padding slot holds ZERO and is provable like any other leaf.
        let proof = tree.proof_path(3);
        assert!(MerkleTree::verify_proof(
            tree.root(),
            Id32::ZERO,
            3,
            &proof
        ));
    }

    #[test]
    fn proof_verification_failure() {
        let leaves: Vec<Id32> = (0..4).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);

        let proof = tree.proof_path(0);
        let wrong_leaf = leaf(0xFF);
        assert!(!MerkleTree::verify_proof(
            tree.root(),
            wrong_leaf,
            0,
            &proof
        ));
    }

    #[test]
    fn proof_verification_wrong_index() {
        let leaves: Vec<Id32> = (0..4).map(leaf).collect();
        let tree = MerkleTree::from_leaves(&leaves);

        // A valid proof for leaf 0 must not verify when replayed at index 1,
        // because the left/right order of the first hash is flipped.
        let proof = tree.proof_path(0);
        assert!(!MerkleTree::verify_proof(
            tree.root(),
            leaves[0],
            1,
            &proof
        ));
    }
}