hash_based_signatures/
merkle_tree.rs

1use crate::signature::HashType;
2use crate::utils::{get_least_significant_bits, hash};
3use data_encoding::HEXLOWER;
4use serde::{Deserialize, Serialize};
5use std::fmt::{Debug, Formatter};
6use std::marker::PhantomData;
7
8/// A Merkle tree.
9///
10/// # Examples
11/// ```
12/// use hash_based_signatures::merkle_tree::MerkleTree;
13///
14/// let elements: Vec<u8> = (0..128).collect();
15/// let tree = MerkleTree::new(&elements);
16/// let proof = tree.get_proof(17);
17/// assert!(proof.verify(*tree.get_root_hash(), &17));
18/// ```
19#[derive(Clone)]
20pub struct MerkleTree<T: Serialize> {
21    root_hash: [u8; 32],
22    root_node: Node<T>,
23    depth: usize,
24
25    /// Phantom to keep the information of the element type.
26    phantom: PhantomData<T>,
27}
28
29#[derive(Clone)]
30enum Node<T: Serialize> {
31    Leaf(),
32    InternalNode(Box<MerkleTree<T>>, Box<MerkleTree<T>>),
33}
34
35/// A proof that a given datum is at a given index.
36/// Note that the proof does not store the data itself, but it needs to be
37/// provided to `MerkleProof::verify()`.
38#[derive(PartialEq, Serialize, Deserialize)]
39pub struct MerkleProof<T: Serialize> {
40    /// The index of the datum for which this is the proof.
41    pub index: usize,
42    /// Hash chain leading up to the root node
43    pub hash_chain: Vec<[u8; 32]>,
44
45    /// Phantom to keep the information of the element type.
46    phantom: PhantomData<T>,
47}
48
49/// Hash function applied to leaves of the Merkle tree
50///
51/// # Panics
52/// Panics if the data can't be serialized.
53pub fn leaf_hash<T: Serialize>(data: &T) -> [u8; 32] {
54    let data = rmp_serde::to_vec(data).expect("Failed to serialize data");
55
56    // For leafs, we need to use a different hash function for security:
57    // https://crypto.stackexchange.com/questions/2106/what-is-the-purpose-of-using-different-hash-functions-for-the-leaves-and-interna
58    // So, we append a zero to all leaves before hashing them
59    let zero = [0u8];
60    let all_elements = [&data, &zero as &[u8]].concat();
61    hash(&all_elements)
62}
63
64/// Hash function applied to internal nodes of the Merkle tree
65pub fn internal_node_hash(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
66    let all_elements = [*left, *right].concat();
67    hash(&all_elements)
68}
69
70impl<T: Serialize + Debug> MerkleTree<T> {
71    /// Construct a new Merkle tree from a list of `elements`.
72    ///
73    /// A single element is of type `Vec<u8>`, so any complex data structure has
74    /// to be serialized to a variable-length byte array.
75    /// The tree owns the values.
76    ///
77    /// # Panics
78    ///
79    /// Panics if the number of elements is not a power of two or if the provided data can't be serialized.
80    pub fn new(elements: &[T]) -> MerkleTree<T> {
81        let depth = (elements.len() as f64).log2() as usize;
82
83        if 1 << depth != elements.len() {
84            panic!(
85                "Number of elements needs to be a power of 2, got {}",
86                elements.len()
87            )
88        }
89
90        let (root_node, root_hash) = if elements.len() == 1 {
91            let element_hash = leaf_hash(&elements[0]);
92            (Node::Leaf(), element_hash)
93        } else {
94            let mid = elements.len() / 2;
95            let elements_left = &elements[..mid];
96            let elements_right = &elements[mid..];
97            let left_tree = Box::new(MerkleTree::new(elements_left));
98            let right_tree = Box::new(MerkleTree::new(elements_right));
99
100            let root_hash = internal_node_hash(&left_tree.root_hash, &right_tree.root_hash);
101            let root_node = Node::InternalNode(left_tree, right_tree);
102
103            (root_node, root_hash)
104        };
105
106        MerkleTree {
107            root_hash,
108            root_node,
109            depth,
110            phantom: PhantomData,
111        }
112    }
113
114    /// Get the root hash of the tree.
115    pub fn get_root_hash(&self) -> &[u8; 32] {
116        &self.root_hash
117    }
118
119    /// Get a Merkle proof for a given index `i`.
120    pub fn get_proof(&self, i: usize) -> MerkleProof<T> {
121        assert!(i < 1 << self.depth);
122
123        match &self.root_node {
124            Node::Leaf() => MerkleProof {
125                index: i,
126                hash_chain: vec![],
127                phantom: PhantomData,
128            },
129            Node::InternalNode(left_tree, right_tree) => {
130                let mut proof = if i < 1 << (self.depth - 1) {
131                    // Element is in left child
132                    let mut proof = left_tree.get_proof(i);
133                    proof.hash_chain.push(right_tree.root_hash);
134                    proof
135                } else {
136                    // Element is in left child
137                    let mut proof = right_tree.get_proof(i - (1 << (self.depth - 1)));
138                    proof.hash_chain.push(left_tree.root_hash);
139                    proof
140                };
141                proof.index = i;
142                proof
143            }
144        }
145    }
146
147    fn representation_string(&self, indent: usize) -> String {
148        let mut result = String::new();
149        let indent_str = "  ".repeat(indent).to_string();
150        result += &format!("{}{}\n", indent_str, HEXLOWER.encode(&self.root_hash));
151
152        match &self.root_node {
153            Node::Leaf() => {
154                result += &format!("{}  Leaf\n", indent_str);
155            }
156            Node::InternalNode(left, right) => {
157                result += &left.representation_string(indent + 1);
158                result += &right.representation_string(indent + 1);
159            }
160        }
161
162        result
163    }
164}
165
166impl<T: Serialize + Debug> Debug for MerkleTree<T> {
167    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
168        write!(f, "{}", self.representation_string(0))
169    }
170}
171
172impl<T: Serialize> MerkleProof<T> {
173    /// Verifies that the given root hash can be reconstructed from the Merkle proof.
174    ///
175    /// # Panics
176    /// Panics if the data can't be serialized.
177    pub fn verify(&self, root_hash: HashType, data: &T) -> bool {
178        let index_bits = get_least_significant_bits(self.index, self.hash_chain.len());
179        let mut expected_root_hash = leaf_hash(data);
180        for (hash, index_bit) in self.hash_chain.iter().zip(index_bits.iter().rev()) {
181            expected_root_hash = match index_bit {
182                false => internal_node_hash(&expected_root_hash, hash),
183                true => internal_node_hash(hash, &expected_root_hash),
184            }
185        }
186
187        expected_root_hash == root_hash
188    }
189}
190
191impl<T: Serialize> Debug for MerkleProof<T> {
192    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193        let mut representation = format!("Index: {}\nProof:\n", self.index);
194        for hash in self.hash_chain.iter() {
195            representation += &format!("  {}\n", HEXLOWER.encode(hash));
196        }
197        write!(f, "{}", representation)
198    }
199}
200
#[cfg(test)]
mod tests {
    use crate::merkle_tree::{leaf_hash, MerkleProof, MerkleTree};
    use std::marker::PhantomData;

    /// A 128-leaf tree (depth 7) whose leaf `i` is the one-byte vector `[i]`.
    fn merkle_tree() -> MerkleTree<Vec<u8>> {
        let elements: Vec<Vec<u8>> = (0u8..128).map(|x| vec![x]).collect();
        MerkleTree::new(&elements)
    }

    #[test]
    fn test_valid_proofs() {
        let tree = merkle_tree();
        let proof = tree.get_proof(43);

        assert!(proof.verify(*tree.get_root_hash(), &vec![43]));
    }

    #[test]
    fn test_every_index_has_a_valid_proof() {
        let tree = merkle_tree();
        for i in 0u8..128 {
            let proof = tree.get_proof(i as usize);
            assert_eq!(proof.index, i as usize);
            assert_eq!(proof.hash_chain.len(), 7);
            assert!(proof.verify(*tree.get_root_hash(), &vec![i]));
        }
    }

    #[test]
    fn test_invalid_proofs() {
        let tree = merkle_tree();
        let proof1 = tree.get_proof(43);
        let proof2 = tree.get_proof(123);

        let invalid_proof_wrong_index = MerkleProof {
            hash_chain: proof1.hash_chain.clone(),
            index: proof2.index,
            phantom: PhantomData,
        };
        assert!(!invalid_proof_wrong_index.verify(tree.root_hash, &vec![43]));

        let invalid_proof_wrong_hash_chain = MerkleProof {
            hash_chain: proof2.hash_chain.clone(),
            index: proof1.index,
            phantom: PhantomData,
        };
        assert!(!invalid_proof_wrong_hash_chain.verify(tree.root_hash, &vec![43]));

        let invalid_proof_wrong_data = MerkleProof {
            hash_chain: proof2.hash_chain.clone(),
            index: proof2.index,
            phantom: PhantomData,
        };
        assert!(!invalid_proof_wrong_data.verify(tree.root_hash, &vec![43]));
    }

    #[test]
    fn test_single_element_tree() {
        let tree = MerkleTree::new(&[vec![7u8]]);
        assert_eq!(*tree.get_root_hash(), leaf_hash(&vec![7u8]));
        assert!(tree.get_proof(0).hash_chain.is_empty());
    }

    #[test]
    #[should_panic(expected = "power of 2")]
    fn test_rejects_non_power_of_two() {
        MerkleTree::new(&[vec![0u8], vec![1u8], vec![2u8]]);
    }

    #[test]
    #[should_panic(expected = "power of 2")]
    fn test_rejects_empty_input() {
        MerkleTree::<Vec<u8>>::new(&[]);
    }

    #[test]
    #[should_panic]
    fn test_get_proof_rejects_out_of_range_index() {
        merkle_tree().get_proof(128);
    }

    #[test]
    fn test_works_with_complex_data() {
        let elements: Vec<(u32, u32, (u32,))> = (0..128).map(|x| (x, x + 1, (x + 2,))).collect();
        let tree = MerkleTree::new(&elements);
        let proof = tree.get_proof(43);
        assert!(proof.verify(*tree.get_root_hash(), &(43, 44, (45,))));
    }
}