pod_types/cryptography/
merkle_tree.rs

1use alloy_sol_types::SolValue;
2use itertools::Itertools;
3use serde::{Deserialize, Serialize};
4use std::{
5    collections::{HashMap, VecDeque},
6    ops::Deref,
7};
8
9use crate::cryptography::hash::{Hash, Hashable};
10
11#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
12pub enum MerkleError {
13    #[error("invalid index: {0}")]
14    InvalidIndex(usize),
15}
16
17#[derive(Debug, Clone, Default)]
18pub struct MerkleTree {
19    tree: Vec<Hash>,
20}
21
22#[derive(Debug)]
23pub struct StandardMerkleTree {
24    tree: MerkleTree,
25    indices: HashMap<Hash, usize>,
26}
27
28impl Deref for StandardMerkleTree {
29    type Target = MerkleTree;
30
31    fn deref(&self) -> &Self::Target {
32        &self.tree
33    }
34}
35
36#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
37pub struct MerkleProof {
38    pub path: Vec<Hash>,
39}
40
41#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
42pub struct MerkleMultiProof {
43    path: Vec<Hash>,
44    flags: Vec<bool>,
45}
46
47impl MerkleProof {
48    pub fn new(path: Vec<Hash>) -> Self {
49        MerkleProof { path }
50    }
51}
52
53fn hash_pair(left: Hash, right: Hash) -> Hash {
54    [left, right].concat().hash_custom()
55}
56
57fn commutative_hash_pair(left: Hash, right: Hash) -> Hash {
58    if left < right {
59        hash_pair(left, right)
60    } else {
61        hash_pair(right, left)
62    }
63}
64
/// Index of the left child of node `index` in the flat node array.
fn left_child_index(index: usize) -> usize {
    index * 2 + 1
}

/// Index of the right child of node `index`; always one past the left child.
fn right_child_index(index: usize) -> usize {
    left_child_index(index) + 1
}

/// Index of the parent of node `index`. Must not be called with the root (0).
fn parent_index(index: usize) -> usize {
    let above = index - 1;
    above / 2
}

/// Index of the other child sharing `index`'s parent. Must not be called with the root (0).
fn sibling_index(index: usize) -> usize {
    // Odd indices are left children, so their sibling sits immediately to the right.
    if index % 2 != 0 {
        index + 1
    } else {
        index - 1
    }
}

/// Whether `index` is a node of a `tree_len`-node tree that has no children.
fn is_leaf_index(tree_len: usize, index: usize) -> bool {
    // Check bounds first so the child computation never sees an out-of-range index.
    if index >= tree_len {
        return false;
    }
    left_child_index(index) >= tree_len
}
84
85impl Hashable for StandardMerkleTree {
86    fn hash_custom(&self) -> Hash {
87        self.root()
88    }
89}
90
91impl Hashable for MerkleTree {
92    fn hash_custom(&self) -> Hash {
93        self.root()
94    }
95}
96
97impl StandardMerkleTree {
98    pub fn hash_leaf(prefix: String, leaf: Hash) -> Hash {
99        (prefix, leaf).abi_encode_packed().hash_custom()
100    }
101
102    pub fn new(leaves: Vec<Hash>) -> Self {
103        let leaves_sorted = leaves.into_iter().sorted().collect::<Vec<_>>();
104
105        let tree = MerkleTree::new(&leaves_sorted);
106        let indices = leaves_sorted
107            .into_iter()
108            .enumerate()
109            .map(|(i, leaf)| (leaf, tree.length() - i - 1))
110            .collect::<HashMap<Hash, usize>>();
111
112        Self { tree, indices }
113    }
114
115    pub fn generate_proof(&self, leaf: Hash) -> Option<MerkleProof> {
116        self.indices.get(&leaf).map(|&tree_index| {
117            self.tree
118                .generate_proof(tree_index)
119                .expect("it's guaranteed that index is in the tree")
120        })
121    }
122
123    pub fn generate_multi_proof(&self, leaves: &[Hash]) -> Option<MerkleMultiProof> {
124        let mut indices = Vec::new();
125        for leaf in leaves {
126            if let Some(&tree_index) = self.indices.get(leaf) {
127                indices.push(tree_index);
128            } else {
129                return None;
130            }
131        }
132
133        self.tree.generate_multi_proof(&indices)
134    }
135
136    pub fn verify_proof(root: Hash, leaf: Hash, proof: MerkleProof) -> bool {
137        MerkleTree::verify_proof(root, leaf, proof)
138    }
139
140    pub fn verify_multi_proof(root: Hash, leaves: &[Hash], proof: MerkleMultiProof) -> bool {
141        MerkleTree::verify_multi_proof(root, leaves, proof)
142    }
143}
144
/// Joins two path segments with `.`, omitting the separator when either side is empty.
fn join_prefix(prefix: &str, sub: &str) -> String {
    if prefix.is_empty() {
        sub.to_owned()
    } else if sub.is_empty() {
        prefix.to_owned()
    } else {
        format!("{prefix}.{sub}")
    }
}

/// Appends an array index to a path: `prefix[index]`, or `[index]` for an empty prefix.
pub fn index_prefix(prefix: &str, index: usize) -> String {
    // An empty `prefix` naturally renders as just `[index]`, so no special case is needed.
    format!("{prefix}[{index}]")
}
161
162fn apply_prefix_to_leaf(prefix: &str, (sub_prefix, leaf): (String, Hash)) -> Hash {
163    StandardMerkleTree::hash_leaf(join_prefix(prefix, &sub_prefix), leaf)
164}
165
166fn apply_prefix_to_leaves(prefix: &str, leaves: Vec<(String, Hash)>) -> Vec<Hash> {
167    leaves
168        .into_iter()
169        .map(|leaf| apply_prefix_to_leaf(prefix, leaf))
170        .collect()
171}
172pub struct MerkleBuilder {
173    leaves: Vec<(String, Hash)>,
174}
175
176impl Default for MerkleBuilder {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
182impl MerkleBuilder {
183    pub fn new() -> Self {
184        Self { leaves: Vec::new() }
185    }
186
187    pub fn add_field(&mut self, name: impl Into<String>, hash: Hash) {
188        self.leaves.push((name.into(), hash));
189    }
190
191    pub fn add_merkleizable(&mut self, prefix: &str, item: &impl Merkleizable) {
192        for (sub_field, hash) in item.leaves() {
193            self.leaves.push((join_prefix(prefix, &sub_field), hash));
194        }
195    }
196
197    pub fn add_slice<T: Merkleizable>(&mut self, prefix: &str, items: &[T]) {
198        for (index, item) in items.iter().enumerate() {
199            self.add_merkleizable(&index_prefix(prefix, index), item);
200        }
201    }
202
203    pub fn build(self) -> Vec<(String, Hash)> {
204        self.leaves
205    }
206}
207
208pub trait ToLeaf {
209    fn to_leaf(&self) -> (String, Hash);
210}
211
212pub trait Merkleizable {
213    fn append_leaves(&self, builder: &mut MerkleBuilder);
214
215    fn leaves(&self) -> Vec<(String, Hash)> {
216        let mut builder = MerkleBuilder::new();
217        self.append_leaves(&mut builder);
218        builder.build()
219    }
220
221    fn to_merkle_tree(&self) -> StandardMerkleTree {
222        let leaves = self
223            .leaves()
224            .into_iter()
225            .map(|(path, leaf)| StandardMerkleTree::hash_leaf(path, leaf))
226            .collect::<Vec<_>>();
227
228        StandardMerkleTree::new(leaves)
229    }
230
231    // Generate a proof for the given item.
232    fn generate_proof<T: ToLeaf>(&self, prefix: &str, item: &T) -> Option<MerkleProof> {
233        let leaf = apply_prefix_to_leaf(prefix, item.to_leaf());
234        self.to_merkle_tree().generate_proof(leaf)
235    }
236
237    // Generate proofs for the given items.
238    fn generate_proofs<T: Merkleizable>(
239        &self,
240        prefix: &str,
241        items: &[T],
242    ) -> Vec<Option<MerkleProof>> {
243        let leaves = apply_prefix_to_leaves(prefix, items.leaves());
244        let tree = self.to_merkle_tree();
245        leaves
246            .into_iter()
247            .map(|leaf| tree.generate_proof(leaf))
248            .collect()
249    }
250
251    // Generates a multiproof for all children of a given item.
252    fn generate_multi_proof<T: Merkleizable>(
253        &self,
254        prefix: &str,
255        item: &T,
256    ) -> Option<(Vec<Hash>, MerkleMultiProof)> {
257        let leaves = apply_prefix_to_leaves(prefix, item.leaves());
258        Some((
259            leaves.clone(),
260            self.to_merkle_tree().generate_multi_proof(&leaves)?,
261        ))
262    }
263
264    // Generates multiproofs for all children of the given items.
265    fn generate_multi_proofs<T: Merkleizable>(
266        &self,
267        prefix: &str,
268        items: &[T],
269    ) -> Option<(Vec<Hash>, MerkleMultiProof)> {
270        let leaves = items
271            .iter()
272            .enumerate()
273            .flat_map(|(index, item)| {
274                apply_prefix_to_leaves(&index_prefix(prefix, index), item.leaves())
275            })
276            .collect::<Vec<_>>();
277        Some((
278            leaves.clone(),
279            self.to_merkle_tree().generate_multi_proof(&leaves)?,
280        ))
281    }
282}
283
284impl Merkleizable for Hash {
285    fn append_leaves(&self, builder: &mut MerkleBuilder) {
286        builder.add_field("", *self);
287    }
288}
289
290impl ToLeaf for Hash {
291    fn to_leaf(&self) -> (String, Hash) {
292        ("".to_string(), *self)
293    }
294}
295
296impl<T: Merkleizable> Merkleizable for &[T] {
297    fn append_leaves(&self, builder: &mut MerkleBuilder) {
298        builder.add_slice("", self);
299    }
300}
301
302impl<T: Merkleizable> Merkleizable for Vec<T> {
303    fn append_leaves(&self, builder: &mut MerkleBuilder) {
304        self.as_slice().append_leaves(builder);
305    }
306}
307
308impl MerkleTree {
309    pub fn new(leaves: &[Hash]) -> Self {
310        if leaves.is_empty() {
311            // TODO: right approach?
312            return MerkleTree {
313                tree: vec![Hash::default()],
314            };
315        }
316        let leaves_len = leaves.len();
317        let tree_len = 2 * leaves_len - 1;
318        let mut tree = vec![Hash::default(); tree_len];
319
320        for (i, leaf) in leaves.iter().enumerate() {
321            tree[tree_len - 1 - i] = *leaf;
322        }
323
324        for i in (0..tree_len - leaves_len).rev() {
325            let left_leaf = tree[left_child_index(i)];
326            let right_leaf = tree[right_child_index(i)];
327            tree[i] = commutative_hash_pair(left_leaf, right_leaf);
328        }
329
330        Self { tree }
331    }
332
333    pub fn root(&self) -> Hash {
334        self.tree[0]
335    }
336
337    pub fn length(&self) -> usize {
338        self.tree.len()
339    }
340
341    pub fn generate_proof(&self, index: usize) -> Result<MerkleProof, MerkleError> {
342        let tree_len = self.tree.len();
343        if !is_leaf_index(tree_len, index) {
344            return Err(MerkleError::InvalidIndex(index));
345        }
346
347        let mut path = Vec::new();
348        let mut current = index;
349        while current > 0 {
350            let sibling = sibling_index(current);
351            if sibling < tree_len {
352                path.push(self.tree[sibling]);
353            }
354
355            current = parent_index(current);
356        }
357
358        Ok(MerkleProof::new(path))
359    }
360
361    pub fn generate_multi_proof(&self, indices: &[usize]) -> Option<MerkleMultiProof> {
362        let tree_len = self.tree.len();
363        if indices.iter().any(|&i| !is_leaf_index(tree_len, i)) {
364            return None;
365        }
366
367        let sorted_indices = indices
368            .iter()
369            .cloned()
370            .sorted_by(|a, b| b.cmp(a))
371            .unique()
372            .collect::<Vec<_>>();
373
374        let mut stack = VecDeque::from(sorted_indices);
375        let mut path = Vec::new();
376        let mut flags = Vec::new();
377
378        while let Some(j) = stack.pop_front() {
379            if j == 0 {
380                break;
381            }
382
383            let s = sibling_index(j);
384            let p = parent_index(j);
385
386            match stack.front() {
387                Some(&next) if next == s => {
388                    flags.push(true);
389                    stack.pop_front();
390                }
391                _ => {
392                    flags.push(false);
393                    path.push(self.tree[s]);
394                }
395            }
396
397            stack.push_back(p);
398        }
399
400        if indices.is_empty() {
401            path.push(self.tree[0]);
402        }
403
404        Some(MerkleMultiProof { path, flags })
405    }
406
407    pub fn verify_proof(root: Hash, leaf: Hash, proof: MerkleProof) -> bool {
408        root == proof.path.into_iter().fold(leaf, commutative_hash_pair)
409    }
410
411    pub fn verify_multi_proof(root: Hash, leaves: &[Hash], proof: MerkleMultiProof) -> bool {
412        let path_len = proof.path.len();
413        if path_len < proof.flags.iter().filter(|&&f| !f).count() {
414            tracing::debug!("invalid multiproof: too few path hashes");
415            return false;
416        }
417
418        if leaves.len() + path_len != proof.flags.len() + 1 {
419            tracing::debug!("invalid multiproof: invalid total hashes");
420            return false;
421        }
422
423        // This is a deviation from OpenZeppelin's implementation,
424        // which expects leaves to be given in sorted order.
425        let mut stack = leaves.iter().cloned().sorted().collect::<Vec<Hash>>();
426
427        let mut path = proof.path.to_vec();
428
429        for flag in proof.flags {
430            let a = stack.remove(0);
431            let b = if flag {
432                stack.remove(0)
433            } else {
434                path.remove(0)
435            };
436
437            stack.push(commutative_hash_pair(a, b));
438        }
439
440        let reconstructed_root = match (stack.len(), path.len()) {
441            (1, 0) => stack.remove(0),
442            (0, 1) => path.remove(0),
443            _ => panic!("invalid multiproof: invalid total hashes"),
444        };
445
446        root == reconstructed_root
447    }
448}
449
#[cfg(test)]
mod test {
    use super::*;
    use alloy_sol_types::SolValue;

    /// Three leaves `hash_leaf("<i>", hash(abi_encode(i + 1)))` for `i` in `0..3`.
    fn sample_leaves() -> Vec<Hash> {
        (0u32..3)
            .map(|i| {
                StandardMerkleTree::hash_leaf(i.to_string(), (i + 1).abi_encode().hash_custom())
            })
            .collect()
    }

    #[test]
    pub fn test_standard_tree_proof() {
        let leaves = sample_leaves();
        let tree = StandardMerkleTree::new(leaves.clone());
        let leaf = leaves[1];
        let proof = tree.generate_proof(leaf).unwrap();
        assert!(MerkleTree::verify_proof(tree.root(), leaf, proof));
    }

    #[test]
    pub fn test_standard_tree_multi_proof() {
        let leaves = sample_leaves();
        let tree = StandardMerkleTree::new(leaves.clone());
        let proof = tree.generate_multi_proof(&leaves).unwrap();
        assert!(MerkleTree::verify_multi_proof(tree.root(), &leaves, proof));
    }
}