1use alloy_sol_types::SolValue;
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::{
5 collections::{HashMap, VecDeque},
6 ops::Deref,
7};
8
9use crate::cryptography::hash::{Hash, Hashable};
10
11#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
12pub enum MerkleError {
13 #[error("invalid index: {0}")]
14 InvalidIndex(usize),
15}
16
17#[derive(Debug, Clone, Default)]
18pub struct MerkleTree {
19 tree: Vec<Hash>,
20}
21
22#[derive(Debug)]
23pub struct StandardMerkleTree {
24 tree: MerkleTree,
25 indices: HashMap<Hash, usize>,
26}
27
28impl Deref for StandardMerkleTree {
29 type Target = MerkleTree;
30
31 fn deref(&self) -> &Self::Target {
32 &self.tree
33 }
34}
35
36#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
37pub struct MerkleProof {
38 pub path: Vec<Hash>,
39}
40
41#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
42pub struct MerkleMultiProof {
43 path: Vec<Hash>,
44 flags: Vec<bool>,
45}
46
47impl MerkleProof {
48 pub fn new(path: Vec<Hash>) -> Self {
49 MerkleProof { path }
50 }
51}
52
53fn hash_pair(left: Hash, right: Hash) -> Hash {
54 [left, right].concat().hash_custom()
55}
56
57fn commutative_hash_pair(left: Hash, right: Hash) -> Hash {
58 if left < right {
59 hash_pair(left, right)
60 } else {
61 hash_pair(right, left)
62 }
63}
64
/// Index of the left child of node `index` in the flat array layout.
fn left_child_index(index: usize) -> usize {
    index * 2 + 1
}

/// Index of the right child of node `index` in the flat array layout.
fn right_child_index(index: usize) -> usize {
    index * 2 + 2
}

/// Index of the parent of node `index`.
///
/// `index` must be non-zero: the root has no parent and `0 - 1` underflows.
fn parent_index(index: usize) -> usize {
    (index - 1) / 2
}

/// Index of the other child of the same parent.
///
/// `index` must be non-zero: the root has no sibling and `0 - 1` underflows.
fn sibling_index(index: usize) -> usize {
    match index % 2 {
        // Even indices are right children; their sibling sits just before them.
        0 => index - 1,
        // Odd indices are left children; their sibling sits just after them.
        _ => index + 1,
    }
}

/// Whether `index` is a node of a `tree_len`-node tree that has no children.
fn is_leaf_index(tree_len: usize, index: usize) -> bool {
    let in_bounds = index < tree_len;
    in_bounds && left_child_index(index) >= tree_len
}
84
85impl Hashable for StandardMerkleTree {
86 fn hash_custom(&self) -> Hash {
87 self.root()
88 }
89}
90
91impl Hashable for MerkleTree {
92 fn hash_custom(&self) -> Hash {
93 self.root()
94 }
95}
96
97impl StandardMerkleTree {
98 pub fn hash_leaf(prefix: String, leaf: Hash) -> Hash {
99 (prefix, leaf).abi_encode_packed().hash_custom()
100 }
101
102 pub fn new(leaves: Vec<Hash>) -> Self {
103 let leaves_sorted = leaves.into_iter().sorted().collect::<Vec<_>>();
104
105 let tree = MerkleTree::new(&leaves_sorted);
106 let indices = leaves_sorted
107 .into_iter()
108 .enumerate()
109 .map(|(i, leaf)| (leaf, tree.length() - i - 1))
110 .collect::<HashMap<Hash, usize>>();
111
112 Self { tree, indices }
113 }
114
115 pub fn generate_proof(&self, leaf: Hash) -> Option<MerkleProof> {
116 self.indices.get(&leaf).map(|&tree_index| {
117 self.tree
118 .generate_proof(tree_index)
119 .expect("it's guaranteed that index is in the tree")
120 })
121 }
122
123 pub fn generate_multi_proof(&self, leaves: &[Hash]) -> Option<MerkleMultiProof> {
124 let mut indices = Vec::new();
125 for leaf in leaves {
126 if let Some(&tree_index) = self.indices.get(leaf) {
127 indices.push(tree_index);
128 } else {
129 return None;
130 }
131 }
132
133 self.tree.generate_multi_proof(&indices)
134 }
135
136 pub fn verify_proof(root: Hash, leaf: Hash, proof: MerkleProof) -> bool {
137 MerkleTree::verify_proof(root, leaf, proof)
138 }
139
140 pub fn verify_multi_proof(root: Hash, leaves: &[Hash], proof: MerkleMultiProof) -> bool {
141 MerkleTree::verify_multi_proof(root, leaves, proof)
142 }
143}
144
/// Joins two path segments with `.`, omitting the separator when either is empty.
fn join_prefix(prefix: &str, sub: &str) -> String {
    if prefix.is_empty() {
        return sub.to_string();
    }
    if sub.is_empty() {
        return prefix.to_string();
    }
    format!("{prefix}.{sub}")
}

/// Appends a collection position to a path: `("items", 2)` -> `"items[2]"`,
/// `("", 2)` -> `"[2]"`.
pub fn index_prefix(prefix: &str, index: usize) -> String {
    format!("{prefix}[{index}]")
}
161
162fn apply_prefix_to_leaf(prefix: &str, (sub_prefix, leaf): (String, Hash)) -> Hash {
163 StandardMerkleTree::hash_leaf(join_prefix(prefix, &sub_prefix), leaf)
164}
165
166fn apply_prefix_to_leaves(prefix: &str, leaves: Vec<(String, Hash)>) -> Vec<Hash> {
167 leaves
168 .into_iter()
169 .map(|leaf| apply_prefix_to_leaf(prefix, leaf))
170 .collect()
171}
172pub struct MerkleBuilder {
173 leaves: Vec<(String, Hash)>,
174}
175
176impl Default for MerkleBuilder {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
182impl MerkleBuilder {
183 pub fn new() -> Self {
184 Self { leaves: Vec::new() }
185 }
186
187 pub fn add_field(&mut self, name: impl Into<String>, hash: Hash) {
188 self.leaves.push((name.into(), hash));
189 }
190
191 pub fn add_merkleizable(&mut self, prefix: &str, item: &impl Merkleizable) {
192 for (sub_field, hash) in item.leaves() {
193 self.leaves.push((join_prefix(prefix, &sub_field), hash));
194 }
195 }
196
197 pub fn add_slice<T: Merkleizable>(&mut self, prefix: &str, items: &[T]) {
198 for (index, item) in items.iter().enumerate() {
199 self.add_merkleizable(&index_prefix(prefix, index), item);
200 }
201 }
202
203 pub fn build(self) -> Vec<(String, Hash)> {
204 self.leaves
205 }
206}
207
208pub trait ToLeaf {
209 fn to_leaf(&self) -> (String, Hash);
210}
211
212pub trait Merkleizable {
213 fn append_leaves(&self, builder: &mut MerkleBuilder);
214
215 fn leaves(&self) -> Vec<(String, Hash)> {
216 let mut builder = MerkleBuilder::new();
217 self.append_leaves(&mut builder);
218 builder.build()
219 }
220
221 fn to_merkle_tree(&self) -> StandardMerkleTree {
222 let leaves = self
223 .leaves()
224 .into_iter()
225 .map(|(path, leaf)| StandardMerkleTree::hash_leaf(path, leaf))
226 .collect::<Vec<_>>();
227
228 StandardMerkleTree::new(leaves)
229 }
230
231 fn generate_proof<T: ToLeaf>(&self, prefix: &str, item: &T) -> Option<MerkleProof> {
233 let leaf = apply_prefix_to_leaf(prefix, item.to_leaf());
234 self.to_merkle_tree().generate_proof(leaf)
235 }
236
237 fn generate_proofs<T: Merkleizable>(
239 &self,
240 prefix: &str,
241 items: &[T],
242 ) -> Vec<Option<MerkleProof>> {
243 let leaves = apply_prefix_to_leaves(prefix, items.leaves());
244 let tree = self.to_merkle_tree();
245 leaves
246 .into_iter()
247 .map(|leaf| tree.generate_proof(leaf))
248 .collect()
249 }
250
251 fn generate_multi_proof<T: Merkleizable>(
253 &self,
254 prefix: &str,
255 item: &T,
256 ) -> Option<(Vec<Hash>, MerkleMultiProof)> {
257 let leaves = apply_prefix_to_leaves(prefix, item.leaves());
258 Some((
259 leaves.clone(),
260 self.to_merkle_tree().generate_multi_proof(&leaves)?,
261 ))
262 }
263
264 fn generate_multi_proofs<T: Merkleizable>(
266 &self,
267 prefix: &str,
268 items: &[T],
269 ) -> Option<(Vec<Hash>, MerkleMultiProof)> {
270 let leaves = items
271 .iter()
272 .enumerate()
273 .flat_map(|(index, item)| {
274 apply_prefix_to_leaves(&index_prefix(prefix, index), item.leaves())
275 })
276 .collect::<Vec<_>>();
277 Some((
278 leaves.clone(),
279 self.to_merkle_tree().generate_multi_proof(&leaves)?,
280 ))
281 }
282}
283
284impl Merkleizable for Hash {
285 fn append_leaves(&self, builder: &mut MerkleBuilder) {
286 builder.add_field("", *self);
287 }
288}
289
290impl ToLeaf for Hash {
291 fn to_leaf(&self) -> (String, Hash) {
292 ("".to_string(), *self)
293 }
294}
295
296impl<T: Merkleizable> Merkleizable for &[T] {
297 fn append_leaves(&self, builder: &mut MerkleBuilder) {
298 builder.add_slice("", self);
299 }
300}
301
302impl<T: Merkleizable> Merkleizable for Vec<T> {
303 fn append_leaves(&self, builder: &mut MerkleBuilder) {
304 self.as_slice().append_leaves(builder);
305 }
306}
307
308impl MerkleTree {
309 pub fn new(leaves: &[Hash]) -> Self {
310 if leaves.is_empty() {
311 return MerkleTree {
313 tree: vec![Hash::default()],
314 };
315 }
316 let leaves_len = leaves.len();
317 let tree_len = 2 * leaves_len - 1;
318 let mut tree = vec![Hash::default(); tree_len];
319
320 for (i, leaf) in leaves.iter().enumerate() {
321 tree[tree_len - 1 - i] = *leaf;
322 }
323
324 for i in (0..tree_len - leaves_len).rev() {
325 let left_leaf = tree[left_child_index(i)];
326 let right_leaf = tree[right_child_index(i)];
327 tree[i] = commutative_hash_pair(left_leaf, right_leaf);
328 }
329
330 Self { tree }
331 }
332
333 pub fn root(&self) -> Hash {
334 self.tree[0]
335 }
336
337 pub fn length(&self) -> usize {
338 self.tree.len()
339 }
340
341 pub fn generate_proof(&self, index: usize) -> Result<MerkleProof, MerkleError> {
342 let tree_len = self.tree.len();
343 if !is_leaf_index(tree_len, index) {
344 return Err(MerkleError::InvalidIndex(index));
345 }
346
347 let mut path = Vec::new();
348 let mut current = index;
349 while current > 0 {
350 let sibling = sibling_index(current);
351 if sibling < tree_len {
352 path.push(self.tree[sibling]);
353 }
354
355 current = parent_index(current);
356 }
357
358 Ok(MerkleProof::new(path))
359 }
360
361 pub fn generate_multi_proof(&self, indices: &[usize]) -> Option<MerkleMultiProof> {
362 let tree_len = self.tree.len();
363 if indices.iter().any(|&i| !is_leaf_index(tree_len, i)) {
364 return None;
365 }
366
367 let sorted_indices = indices
368 .iter()
369 .cloned()
370 .sorted_by(|a, b| b.cmp(a))
371 .unique()
372 .collect::<Vec<_>>();
373
374 let mut stack = VecDeque::from(sorted_indices);
375 let mut path = Vec::new();
376 let mut flags = Vec::new();
377
378 while let Some(j) = stack.pop_front() {
379 if j == 0 {
380 break;
381 }
382
383 let s = sibling_index(j);
384 let p = parent_index(j);
385
386 match stack.front() {
387 Some(&next) if next == s => {
388 flags.push(true);
389 stack.pop_front();
390 }
391 _ => {
392 flags.push(false);
393 path.push(self.tree[s]);
394 }
395 }
396
397 stack.push_back(p);
398 }
399
400 if indices.is_empty() {
401 path.push(self.tree[0]);
402 }
403
404 Some(MerkleMultiProof { path, flags })
405 }
406
407 pub fn verify_proof(root: Hash, leaf: Hash, proof: MerkleProof) -> bool {
408 root == proof.path.into_iter().fold(leaf, commutative_hash_pair)
409 }
410
411 pub fn verify_multi_proof(root: Hash, leaves: &[Hash], proof: MerkleMultiProof) -> bool {
412 let path_len = proof.path.len();
413 if path_len < proof.flags.iter().filter(|&&f| !f).count() {
414 tracing::debug!("invalid multiproof: too few path hashes");
415 return false;
416 }
417
418 if leaves.len() + path_len != proof.flags.len() + 1 {
419 tracing::debug!("invalid multiproof: invalid total hashes");
420 return false;
421 }
422
423 let mut stack = leaves.iter().cloned().sorted().collect::<Vec<Hash>>();
426
427 let mut path = proof.path.to_vec();
428
429 for flag in proof.flags {
430 let a = stack.remove(0);
431 let b = if flag {
432 stack.remove(0)
433 } else {
434 path.remove(0)
435 };
436
437 stack.push(commutative_hash_pair(a, b));
438 }
439
440 let reconstructed_root = match (stack.len(), path.len()) {
441 (1, 0) => stack.remove(0),
442 (0, 1) => path.remove(0),
443 _ => panic!("invalid multiproof: invalid total hashes"),
444 };
445
446 root == reconstructed_root
447 }
448}
449
#[cfg(test)]
mod test {
    use super::*;
    use alloy_sol_types::SolValue;

    /// Three standard leaves labelled "0", "1", "2" holding the values 1, 2, 3.
    fn sample_leaves() -> Vec<Hash> {
        (1u32..=3)
            .map(|value| {
                let label = (value - 1).to_string();
                StandardMerkleTree::hash_leaf(label, value.abi_encode().hash_custom())
            })
            .collect()
    }

    #[test]
    pub fn test_standard_tree_proof() {
        let leaves = sample_leaves();
        let tree = StandardMerkleTree::new(leaves.clone());
        let leaf = leaves[1];
        let proof = tree.generate_proof(leaf).expect("leaf is in the tree");
        assert!(MerkleTree::verify_proof(tree.root(), leaf, proof));
    }

    #[test]
    pub fn test_standard_tree_multi_proof() {
        let leaves = sample_leaves();
        let tree = StandardMerkleTree::new(leaves.clone());
        let proof = tree
            .generate_multi_proof(&leaves)
            .expect("all leaves are in the tree");
        assert!(MerkleTree::verify_multi_proof(tree.root(), &leaves, proof));
    }
}