hash_based_signatures/
merkle_tree.rs1use crate::signature::HashType;
2use crate::utils::{get_least_significant_bits, hash};
3use data_encoding::HEXLOWER;
4use serde::{Deserialize, Serialize};
5use std::fmt::{Debug, Formatter};
6use std::marker::PhantomData;
7
8#[derive(Clone)]
20pub struct MerkleTree<T: Serialize> {
21 root_hash: [u8; 32],
22 root_node: Node<T>,
23 depth: usize,
24
25 phantom: PhantomData<T>,
27}
28
29#[derive(Clone)]
30enum Node<T: Serialize> {
31 Leaf(),
32 InternalNode(Box<MerkleTree<T>>, Box<MerkleTree<T>>),
33}
34
35#[derive(PartialEq, Serialize, Deserialize)]
39pub struct MerkleProof<T: Serialize> {
40 pub index: usize,
42 pub hash_chain: Vec<[u8; 32]>,
44
45 phantom: PhantomData<T>,
47}
48
49pub fn leaf_hash<T: Serialize>(data: &T) -> [u8; 32] {
54 let data = rmp_serde::to_vec(data).expect("Failed to serialize data");
55
56 let zero = [0u8];
60 let all_elements = [&data, &zero as &[u8]].concat();
61 hash(&all_elements)
62}
63
64pub fn internal_node_hash(left: &[u8; 32], right: &[u8; 32]) -> [u8; 32] {
66 let all_elements = [*left, *right].concat();
67 hash(&all_elements)
68}
69
70impl<T: Serialize + Debug> MerkleTree<T> {
71 pub fn new(elements: &[T]) -> MerkleTree<T> {
81 let depth = (elements.len() as f64).log2() as usize;
82
83 if 1 << depth != elements.len() {
84 panic!(
85 "Number of elements needs to be a power of 2, got {}",
86 elements.len()
87 )
88 }
89
90 let (root_node, root_hash) = if elements.len() == 1 {
91 let element_hash = leaf_hash(&elements[0]);
92 (Node::Leaf(), element_hash)
93 } else {
94 let mid = elements.len() / 2;
95 let elements_left = &elements[..mid];
96 let elements_right = &elements[mid..];
97 let left_tree = Box::new(MerkleTree::new(elements_left));
98 let right_tree = Box::new(MerkleTree::new(elements_right));
99
100 let root_hash = internal_node_hash(&left_tree.root_hash, &right_tree.root_hash);
101 let root_node = Node::InternalNode(left_tree, right_tree);
102
103 (root_node, root_hash)
104 };
105
106 MerkleTree {
107 root_hash,
108 root_node,
109 depth,
110 phantom: PhantomData,
111 }
112 }
113
114 pub fn get_root_hash(&self) -> &[u8; 32] {
116 &self.root_hash
117 }
118
119 pub fn get_proof(&self, i: usize) -> MerkleProof<T> {
121 assert!(i < 1 << self.depth);
122
123 match &self.root_node {
124 Node::Leaf() => MerkleProof {
125 index: i,
126 hash_chain: vec![],
127 phantom: PhantomData,
128 },
129 Node::InternalNode(left_tree, right_tree) => {
130 let mut proof = if i < 1 << (self.depth - 1) {
131 let mut proof = left_tree.get_proof(i);
133 proof.hash_chain.push(right_tree.root_hash);
134 proof
135 } else {
136 let mut proof = right_tree.get_proof(i - (1 << (self.depth - 1)));
138 proof.hash_chain.push(left_tree.root_hash);
139 proof
140 };
141 proof.index = i;
142 proof
143 }
144 }
145 }
146
147 fn representation_string(&self, indent: usize) -> String {
148 let mut result = String::new();
149 let indent_str = " ".repeat(indent).to_string();
150 result += &format!("{}{}\n", indent_str, HEXLOWER.encode(&self.root_hash));
151
152 match &self.root_node {
153 Node::Leaf() => {
154 result += &format!("{} Leaf\n", indent_str);
155 }
156 Node::InternalNode(left, right) => {
157 result += &left.representation_string(indent + 1);
158 result += &right.representation_string(indent + 1);
159 }
160 }
161
162 result
163 }
164}
165
166impl<T: Serialize + Debug> Debug for MerkleTree<T> {
167 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
168 write!(f, "{}", self.representation_string(0))
169 }
170}
171
172impl<T: Serialize> MerkleProof<T> {
173 pub fn verify(&self, root_hash: HashType, data: &T) -> bool {
178 let index_bits = get_least_significant_bits(self.index, self.hash_chain.len());
179 let mut expected_root_hash = leaf_hash(data);
180 for (hash, index_bit) in self.hash_chain.iter().zip(index_bits.iter().rev()) {
181 expected_root_hash = match index_bit {
182 false => internal_node_hash(&expected_root_hash, hash),
183 true => internal_node_hash(hash, &expected_root_hash),
184 }
185 }
186
187 expected_root_hash == root_hash
188 }
189}
190
191impl<T: Serialize> Debug for MerkleProof<T> {
192 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193 let mut representation = format!("Index: {}\nProof:\n", self.index);
194 for hash in self.hash_chain.iter() {
195 representation += &format!(" {}\n", HEXLOWER.encode(hash));
196 }
197 write!(f, "{}", representation)
198 }
199}
200
#[cfg(test)]
mod tests {
    use crate::merkle_tree::{MerkleProof, MerkleTree};
    use std::marker::PhantomData;

    /// Builds a depth-7 tree (128 leaves) whose `i`-th element is the vector `[i]`.
    fn merkle_tree() -> MerkleTree<Vec<u8>> {
        let elements: Vec<Vec<u8>> = (0u8..128).map(|x| vec![x]).collect();
        MerkleTree::new(&elements)
    }

    #[test]
    fn test_valid_proofs() {
        let tree = merkle_tree();
        let proof = tree.get_proof(43);

        assert!(proof.verify(*tree.get_root_hash(), &vec![43]));
    }

    #[test]
    fn test_every_leaf_has_a_valid_proof() {
        let tree = merkle_tree();
        for i in 0u8..128 {
            let proof = tree.get_proof(usize::from(i));
            assert_eq!(proof.index, usize::from(i));
            // One sibling hash per level of a 128-leaf tree.
            assert_eq!(proof.hash_chain.len(), 7);
            assert!(proof.verify(*tree.get_root_hash(), &vec![i]));
        }
    }

    #[test]
    fn test_proof_rejects_other_leaf_data() {
        let tree = merkle_tree();
        let proof = tree.get_proof(43);

        assert!(!proof.verify(*tree.get_root_hash(), &vec![42]));
    }

    #[test]
    fn test_invalid_proofs() {
        let tree = merkle_tree();
        let proof1 = tree.get_proof(43);
        let proof2 = tree.get_proof(123);

        let invalid_proof_wrong_index = MerkleProof {
            hash_chain: proof1.hash_chain.clone(),
            index: proof2.index,
            phantom: PhantomData,
        };
        assert!(!invalid_proof_wrong_index.verify(tree.root_hash, &vec![43]));

        let invalid_proof_wrong_hash_chain = MerkleProof {
            hash_chain: proof2.hash_chain.clone(),
            index: proof1.index,
            phantom: PhantomData,
        };
        assert!(!invalid_proof_wrong_hash_chain.verify(tree.root_hash, &vec![43]));

        let invalid_proof_wrong_data = MerkleProof {
            hash_chain: proof2.hash_chain.clone(),
            index: proof2.index,
            phantom: PhantomData,
        };
        assert!(!invalid_proof_wrong_data.verify(tree.root_hash, &vec![43]));
    }

    #[test]
    #[should_panic(expected = "power of 2")]
    fn test_rejects_non_power_of_two() {
        let elements: Vec<Vec<u8>> = (0u8..3).map(|x| vec![x]).collect();
        MerkleTree::new(&elements);
    }

    #[test]
    #[should_panic(expected = "power of 2")]
    fn test_rejects_empty_input() {
        let elements: Vec<Vec<u8>> = Vec::new();
        MerkleTree::new(&elements);
    }

    #[test]
    #[should_panic]
    fn test_get_proof_out_of_range_panics() {
        merkle_tree().get_proof(128);
    }

    #[test]
    fn test_works_with_complex_data() {
        let elements: Vec<(u32, u32, (u32,))> = (0..128).map(|x| (x, x + 1, (x + 2,))).collect();
        let tree = MerkleTree::new(&elements);
        let proof = tree.get_proof(43);
        assert!(proof.verify(*tree.get_root_hash(), &(43, 44, (45,))));
    }
}