chia_sdk_driver/
merkle_tree.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
use std::collections::HashMap;

use chia_protocol::Bytes32;
use clvmr::sha2::Sha256;

const HASH_TREE_PREFIX: &[u8] = &[2];
const HASH_LEAF_PREFIX: &[u8] = &[1];

#[derive(Debug, Clone)]
pub struct MerkleTree {
    pub root: Bytes32,
    pub proofs: HashMap<Bytes32, (u32, Vec<Bytes32>)>,
}

use std::fmt::Debug;

#[derive(Debug, Clone)]
pub enum BinaryTree<T> {
    Leaf(T),
    Node(Box<BinaryTree<T>>, Box<BinaryTree<T>>),
}

impl MerkleTree {
    pub fn new(leaves: &[Bytes32]) -> Self {
        if leaves.is_empty() {
            return Self {
                root: Bytes32::default(),
                proofs: HashMap::new(),
            };
        }

        let (root, proofs) = MerkleTree::build_merkle_tree(leaves);
        Self { root, proofs }
    }

    fn build_merkle_tree(leaves: &[Bytes32]) -> (Bytes32, HashMap<Bytes32, (u32, Vec<Bytes32>)>) {
        let binary_tree = MerkleTree::list_to_binary_tree(leaves);
        MerkleTree::build_merkle_tree_from_binary_tree(&binary_tree)
    }

    fn sha256(args: &[&[u8]]) -> Bytes32 {
        let mut hasher = Sha256::new();
        args.iter().for_each(|arg| hasher.update(arg));

        Bytes32::from(hasher.finalize())
    }

    fn list_to_binary_tree<T: Clone + Debug + Default>(objects: &[T]) -> BinaryTree<T> {
        let size = objects.len();
        if size == 0 {
            return BinaryTree::Leaf(T::default());
        }
        if size == 1 {
            return BinaryTree::Leaf(objects[0].clone());
        }
        let midpoint = (size + 1) >> 1;
        let first_half = &objects[..midpoint];
        let last_half = &objects[midpoint..];
        let left_tree = MerkleTree::list_to_binary_tree(first_half);
        let right_tree = MerkleTree::list_to_binary_tree(last_half);
        BinaryTree::Node(Box::new(left_tree), Box::new(right_tree))
    }

    fn build_merkle_tree_from_binary_tree(
        tuples: &BinaryTree<Bytes32>,
    ) -> (Bytes32, HashMap<Bytes32, (u32, Vec<Bytes32>)>) {
        match tuples {
            BinaryTree::Leaf(t) => {
                let hash = MerkleTree::sha256(&[HASH_LEAF_PREFIX, t]);
                let mut proof = HashMap::new();
                proof.insert(*t, (0, vec![]));
                (hash, proof)
            }
            BinaryTree::Node(left, right) => {
                let (left_root, left_proofs) = MerkleTree::build_merkle_tree_from_binary_tree(left);
                let (right_root, right_proofs) =
                    MerkleTree::build_merkle_tree_from_binary_tree(right);

                let new_root = MerkleTree::sha256(&[HASH_TREE_PREFIX, &left_root, &right_root]);
                let mut new_proofs = HashMap::new();

                for (name, (path, mut proof)) in left_proofs {
                    proof.push(right_root);
                    new_proofs.insert(name, (path, proof));
                }

                for (name, (path, mut proof)) in right_proofs {
                    let path = path | (1 << proof.len());
                    proof.push(left_root);
                    new_proofs.insert(name, (path, proof));
                }

                (new_root, new_proofs)
            }
        }
    }

    pub fn get_proof(&self, leaf: Bytes32) -> Option<(u32, Vec<Bytes32>)> {
        self.proofs.get(&leaf).cloned()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use rstest::rstest;

    #[rstest]
    #[case::no_leaves(&[],
           Bytes32::default(),
           vec![]
    )]
    #[case::one_leaf(&[Bytes32::from([1; 32])],
           Bytes32::from(hex!("ce041765675ad4d93378e20bd3a7d0d97ddcf3385fb6341581b21d4bc9e3e69e")),
           vec![(Bytes32::from([1; 32]), 0, vec![])]
    )]
    #[case::two_leaves(&[Bytes32::from([1; 32]), Bytes32::from([2; 32])],
           Bytes32::from(hex!("00f2e7e0bc3ee77f0b5aa330406f69bfbd5c2e3b8a4338dba49f64bb3f0247c4")),
           vec![
               (Bytes32::from([1; 32]), 0, vec![hex!("f1386fff8b06ac98d347997ff5d0abad3b977514b1b7cfe0689f45f3f1393497").into()]),
               (Bytes32::from([2; 32]), 1, vec![hex!("ce041765675ad4d93378e20bd3a7d0d97ddcf3385fb6341581b21d4bc9e3e69e").into()])
           ]
    )]
    #[case::three_leaves(&[Bytes32::from([1; 32]), Bytes32::from([2; 32]), Bytes32::from([3; 32])],
           Bytes32::from(hex!("adb439d3868b9273de8753e20a62a8e6d9ff6cfb43b189337a23df0690c7f55b")), 
           vec![
               (Bytes32::from([1; 32]), 0, vec![hex!("f1386fff8b06ac98d347997ff5d0abad3b977514b1b7cfe0689f45f3f1393497").into(), hex!("131c41585fc6b26c2cf8ea6fc61be03c3c4e3facb3f7e70ec69ea094b17dc3e1").into()]),
               (Bytes32::from([2; 32]), 1, vec![hex!("ce041765675ad4d93378e20bd3a7d0d97ddcf3385fb6341581b21d4bc9e3e69e").into(), hex!("131c41585fc6b26c2cf8ea6fc61be03c3c4e3facb3f7e70ec69ea094b17dc3e1").into()]),
               (Bytes32::from([3; 32]), 1, vec![hex!("00f2e7e0bc3ee77f0b5aa330406f69bfbd5c2e3b8a4338dba49f64bb3f0247c4").into()])
           ]
    )]
    #[case::seven_leaves(&[Bytes32::from([1; 32]), Bytes32::from([2; 32]), Bytes32::from([3; 32]), Bytes32::from([4; 32]), Bytes32::from([5; 32]), Bytes32::from([6; 32]), Bytes32::from([7; 32])],
           Bytes32::from(hex!("1c4b11429685dd0a516282981bb3e12c13596e846f67af1da080b9134cdea4c6")),
           vec![
               (Bytes32::from([1; 32]), 0, vec![hex!("f1386fff8b06ac98d347997ff5d0abad3b977514b1b7cfe0689f45f3f1393497").into(), hex!("1d85c3d5d2a5f093b49c79b2686ff698fb58d3ef4959b939ed6925dc65325499").into(), hex!("c80c9f4f69abfa70474c4d27d076ab32e23ff9bd1215fe87c6a0e6899a126d10").into()]),
               (Bytes32::from([2; 32]), 1, vec![hex!("ce041765675ad4d93378e20bd3a7d0d97ddcf3385fb6341581b21d4bc9e3e69e").into(), hex!("1d85c3d5d2a5f093b49c79b2686ff698fb58d3ef4959b939ed6925dc65325499").into(), hex!("c80c9f4f69abfa70474c4d27d076ab32e23ff9bd1215fe87c6a0e6899a126d10").into()]),
               (Bytes32::from([3; 32]), 2, vec![hex!("db1a2656e1809de78fb29dddf24a1c75fbf7c6dc1f1341f485457c713ce49fa0").into(), hex!("00f2e7e0bc3ee77f0b5aa330406f69bfbd5c2e3b8a4338dba49f64bb3f0247c4").into(), hex!("c80c9f4f69abfa70474c4d27d076ab32e23ff9bd1215fe87c6a0e6899a126d10").into()]),
               (Bytes32::from([4; 32]), 3, vec![hex!("131c41585fc6b26c2cf8ea6fc61be03c3c4e3facb3f7e70ec69ea094b17dc3e1").into(), hex!("00f2e7e0bc3ee77f0b5aa330406f69bfbd5c2e3b8a4338dba49f64bb3f0247c4").into(), hex!("c80c9f4f69abfa70474c4d27d076ab32e23ff9bd1215fe87c6a0e6899a126d10").into()]),
               (Bytes32::from([5; 32]), 4, vec![hex!("0684e189ecc12eb7472925a5b16ec60d10a476a59545452f58fcca994433a4f7").into(), hex!("d3907c0247e7e98b72338a00d87244248df71eb313589da290d45adfba44e6d2").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()]),
               (Bytes32::from([6; 32]), 5, vec![hex!("90cbc3c7c7634183ae482172520c1b8d85ee10f1ca0b4744fdbe7da2245141bb").into(), hex!("d3907c0247e7e98b72338a00d87244248df71eb313589da290d45adfba44e6d2").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()]),
               (Bytes32::from([7; 32]), 3, vec![hex!("3831644ba5da8ec5f16d32ef7c0a318cfec302245fac118321a5da9f43efbf94").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()])
           ]
    )]
    #[case::eight_leaves(&[Bytes32::from([1; 32]), Bytes32::from([2; 32]), Bytes32::from([3; 32]), Bytes32::from([4; 32]), Bytes32::from([5; 32]), Bytes32::from([6; 32]), Bytes32::from([7; 32]), Bytes32::from([8; 32])],
           Bytes32::from(hex!("3023a77c57dd4c0f84fe2d9b42252e483a9974482b6d4d5fbf0e3d405a46f436")),
           vec![
               (Bytes32::from([1; 32]), 0, vec![hex!("f1386fff8b06ac98d347997ff5d0abad3b977514b1b7cfe0689f45f3f1393497").into(), hex!("1d85c3d5d2a5f093b49c79b2686ff698fb58d3ef4959b939ed6925dc65325499").into(), hex!("eb06e593af742e80db1c2bef77f23c85ad87a8048bb1228037cd18d6b50f9042").into()]),
               (Bytes32::from([2; 32]), 1, vec![hex!("ce041765675ad4d93378e20bd3a7d0d97ddcf3385fb6341581b21d4bc9e3e69e").into(), hex!("1d85c3d5d2a5f093b49c79b2686ff698fb58d3ef4959b939ed6925dc65325499").into(), hex!("eb06e593af742e80db1c2bef77f23c85ad87a8048bb1228037cd18d6b50f9042").into()]),
               (Bytes32::from([3; 32]), 2, vec![hex!("db1a2656e1809de78fb29dddf24a1c75fbf7c6dc1f1341f485457c713ce49fa0").into(), hex!("00f2e7e0bc3ee77f0b5aa330406f69bfbd5c2e3b8a4338dba49f64bb3f0247c4").into(), hex!("eb06e593af742e80db1c2bef77f23c85ad87a8048bb1228037cd18d6b50f9042").into()]),
               (Bytes32::from([4; 32]), 3, vec![hex!("131c41585fc6b26c2cf8ea6fc61be03c3c4e3facb3f7e70ec69ea094b17dc3e1").into(), hex!("00f2e7e0bc3ee77f0b5aa330406f69bfbd5c2e3b8a4338dba49f64bb3f0247c4").into(), hex!("eb06e593af742e80db1c2bef77f23c85ad87a8048bb1228037cd18d6b50f9042").into()]),
               (Bytes32::from([5; 32]), 4, vec![hex!("0684e189ecc12eb7472925a5b16ec60d10a476a59545452f58fcca994433a4f7").into(), hex!("f76c002f93a1ba959ebe50568ba888a5d1871e2f804977e996bb6932f7eadf06").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()]),
               (Bytes32::from([6; 32]), 5, vec![hex!("90cbc3c7c7634183ae482172520c1b8d85ee10f1ca0b4744fdbe7da2245141bb").into(), hex!("f76c002f93a1ba959ebe50568ba888a5d1871e2f804977e996bb6932f7eadf06").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()]),
               (Bytes32::from([7; 32]), 6, vec![hex!("467d8acd80729c1fe2c497db207e7861b0fd9aab3552da7a2abb828a45f288cc").into(), hex!("3831644ba5da8ec5f16d32ef7c0a318cfec302245fac118321a5da9f43efbf94").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()]),
               (Bytes32::from([8; 32]), 7, vec![hex!("d3907c0247e7e98b72338a00d87244248df71eb313589da290d45adfba44e6d2").into(), hex!("3831644ba5da8ec5f16d32ef7c0a318cfec302245fac118321a5da9f43efbf94").into(), hex!("7eb919730e38f305365791a43adddeea0fc275371aac8c7b08983937beeb956f").into()])
           ]
    )]
    fn test_merkle_tree(
        #[case] leaves: &[Bytes32],
        #[case] expected_root: Bytes32,
        #[case] expected_proofs: Vec<(Bytes32, u32, Vec<Bytes32>)>,
    ) {
        let merkle_tree = MerkleTree::new(leaves);

        assert_eq!(merkle_tree.root, expected_root);

        for (leaf, path, proof) in expected_proofs {
            assert_eq!(merkle_tree.get_proof(leaf), Some((path, proof)));
        }
    }
}