nam_sparse_merkle_tree/merkle_proof.rs

1use crate::{
2    collections::{BTreeMap, VecDeque},
3    error::{Error, Result},
4    merge::{hash_leaf, merge},
5    traits::{Hasher, Value},
6    vec::Vec,
7    Key, H256, TREE_HEIGHT,
8};
9use core::convert::TryInto;
10
/// Half-open span of leaf indices (`start..end`) covered by a compiled
/// sub-program. `leaf_program` creates single-leaf spans and `merge_program`
/// only joins spans where the first one ends exactly where the second begins.
type Range = ::core::ops::Range<usize>;
12
13#[derive(Debug, Clone)]
14pub struct MerkleProof {
15    leaves_path: Vec<Vec<usize>>,
16    proof: Vec<(H256, usize)>,
17}
18
19impl MerkleProof {
20    /// Create MerkleProof
21    /// leaves_path: contains height of non-zero siblings
22    /// proof: contains merkle path for each leaves it's height
23    pub fn new(leaves_path: Vec<Vec<usize>>, proof: Vec<(H256, usize)>) -> Self {
24        MerkleProof { leaves_path, proof }
25    }
26
27    /// Destruct the structure, useful for serialization
28    pub fn take(self) -> (Vec<Vec<usize>>, Vec<(H256, usize)>) {
29        let MerkleProof { leaves_path, proof } = self;
30        (leaves_path, proof)
31    }
32
33    /// number of leaves required by this merkle proof
34    pub fn leaves_count(&self) -> usize {
35        self.leaves_path.len()
36    }
37
38    /// return the inner leaves_path vector
39    pub fn leaves_path(&self) -> &Vec<Vec<usize>> {
40        &self.leaves_path
41    }
42
43    /// return proof merkle path
44    pub fn proof(&self) -> &Vec<(H256, usize)> {
45        &self.proof
46    }
47
48    /// convert merkle proof into CompiledMerkleProof
49    pub fn compile<K, const N: usize>(
50        self,
51        mut leaves: Vec<(K, H256)>,
52    ) -> Result<CompiledMerkleProof>
53    where
54        K: Key<N>,
55    {
56        if leaves.is_empty() {
57            return Err(Error::EmptyKeys);
58        } else if leaves.len() != self.leaves_count() {
59            return Err(Error::IncorrectNumberOfLeaves {
60                expected: self.leaves_count(),
61                actual: leaves.len(),
62            });
63        }
64
65        let (leaves_path, proof) = self.take();
66        let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
67        let mut proof: VecDeque<_> = proof.into();
68
69        // sort leaves
70        leaves.sort_unstable_by_key(|(k, _v)| **k);
71        // tree_buf: (height, key) -> (key_index, node)
72        let mut tree_buf: BTreeMap<_, _> = leaves
73            .into_iter()
74            .enumerate()
75            .map(|(i, (k, _v))| ((0, *k), (i, leaf_program(i))))
76            .collect();
77        // rebuild the tree from bottom to top
78        while !tree_buf.is_empty() {
79            // pop_front from tree_buf, the API is unstable
80            let &(mut height, key) = tree_buf.keys().next().unwrap();
81            let (leaf_index, program) = tree_buf.remove(&(height, key)).unwrap();
82
83            if proof.is_empty() && tree_buf.is_empty() {
84                return Ok(CompiledMerkleProof(program.0));
85            } else if height == TREE_HEIGHT {
86                if !proof.is_empty() {
87                    return Err(Error::CorruptedProof);
88                }
89                return Ok(CompiledMerkleProof(program.0));
90            }
91
92            let mut sibling_key = key.parent_path(height);
93            if !key.get_bit(height) {
94                sibling_key.set_bit(height)
95            }
96
97            let (parent_key, parent_program, height) =
98                if Some(&(height, sibling_key)) == tree_buf.keys().next() {
99                    let (_leaf_index, sibling_program) = tree_buf
100                        .remove(&(height, sibling_key))
101                        .expect("pop sibling");
102                    let parent_key = key.parent_path(height);
103                    let parent_program = merge_program(&program, &sibling_program, height)?;
104                    (parent_key, parent_program, height)
105                } else {
106                    let merge_height = leaves_path[leaf_index]
107                        .front()
108                        .map(|h| *h as usize)
109                        .unwrap_or(height);
110                    if height != merge_height {
111                        debug_assert!(height < merge_height);
112                        let parent_key = key.copy_bits(merge_height..);
113                        // skip zeros
114                        tree_buf.insert((merge_height, parent_key), (leaf_index, program));
115                        continue;
116                    }
117                    let (proof, proof_height) = proof.pop_front().expect("pop proof");
118                    debug_assert_eq!(proof_height, leaves_path[leaf_index][0]);
119                    let proof_height = proof_height as usize;
120                    debug_assert!(height <= proof_height);
121                    if height < proof_height {
122                        height = proof_height;
123                    }
124
125                    let parent_key = key.parent_path(height);
126                    let parent_program = proof_program(&program, proof, height);
127                    (parent_key, parent_program, height)
128                };
129
130            leaves_path[leaf_index].pop_front();
131            tree_buf.insert((height + 1, parent_key), (leaf_index, parent_program));
132        }
133
134        Err(Error::CorruptedProof)
135    }
136
137    /// Compute root from proof
138    /// leaves: a vector of (key, value)
139    ///
140    /// return EmptyProof error when proof is empty
141    /// return CorruptedProof error when proof is invalid
142    pub fn compute_root<H: Hasher + Default, K, V, const N: usize>(
143        self,
144        mut leaves: Vec<(K, V)>,
145    ) -> Result<H256>
146    where
147        K: Key<N>,
148        V: Value,
149    {
150        if leaves.is_empty() {
151            return Err(Error::EmptyKeys);
152        } else if leaves.len() != self.leaves_count() {
153            return Err(Error::IncorrectNumberOfLeaves {
154                expected: self.leaves_count(),
155                actual: leaves.len(),
156            });
157        }
158
159        let (leaves_path, proof) = self.take();
160        let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
161        let mut proof: VecDeque<_> = proof.into();
162
163        // sort leaves
164        leaves.sort_unstable_by_key(|(k, _v)| **k);
165        // tree_buf: (height, key) -> (key_index, node)
166        let mut tree_buf: BTreeMap<_, _> = leaves
167            .into_iter()
168            .enumerate()
169            .map(|(i, (k, v))| ((0, *k), (i, hash_leaf::<H, K, V, N>(&k, &v))))
170            .collect();
171        // rebuild the tree from bottom to top
172        while !tree_buf.is_empty() {
173            // pop_front from tree_buf, the API is unstable
174            let (&(mut height, key), &(leaf_index, node)) = tree_buf.iter().next().unwrap();
175            tree_buf.remove(&(height, key));
176
177            if proof.is_empty() && tree_buf.is_empty() {
178                return Ok(node);
179            } else if height == 8 * N {
180                if !proof.is_empty() {
181                    return Err(Error::CorruptedProof);
182                }
183                return Ok(node);
184            }
185
186            let mut sibling_key = key.parent_path(height);
187            if !key.get_bit(height) {
188                sibling_key.set_bit(height)
189            }
190            let (sibling, sibling_height) =
191                if Some(&(height, sibling_key)) == tree_buf.keys().next() {
192                    let (_leaf_index, sibling) = tree_buf
193                        .remove(&(height, sibling_key))
194                        .expect("pop sibling");
195                    (sibling, height)
196                } else {
197                    let merge_height = leaves_path[leaf_index]
198                        .front()
199                        .map(|h| *h as usize)
200                        .unwrap_or(height);
201                    if height != merge_height {
202                        debug_assert!(height < merge_height);
203                        let parent_key = key.copy_bits(merge_height..);
204                        // skip zeros
205                        tree_buf.insert((merge_height, parent_key), (leaf_index, node));
206                        continue;
207                    }
208                    let (node, height) = proof.pop_front().expect("pop proof");
209                    debug_assert_eq!(height, leaves_path[leaf_index][0]);
210                    (node, height as usize)
211                };
212            debug_assert!(height <= sibling_height);
213            if height < sibling_height {
214                height = sibling_height;
215            }
216            // skip zero merkle path
217            let parent_key = key.parent_path(height);
218
219            let parent = if key.get_bit(height) {
220                merge::<H>(&sibling, &node)
221            } else {
222                merge::<H>(&node, &sibling)
223            };
224            leaves_path[leaf_index].pop_front();
225            tree_buf.insert((height + 1, parent_key), (leaf_index, parent));
226        }
227
228        Err(Error::CorruptedProof)
229    }
230
231    /// Verify merkle proof
232    /// see compute_root_from_proof
233    pub fn verify<H: Hasher + Default, K, V, const N: usize>(
234        self,
235        root: &H256,
236        leaves: Vec<(K, V)>,
237    ) -> Result<bool>
238    where
239        K: Key<N>,
240        V: Value
241    {
242        let calculated_root = self.compute_root::<H, K, V, N>(leaves)?;
243        Ok(&calculated_root == root)
244    }
245}
246
247fn leaf_program(leaf_index: usize) -> (Vec<u8>, Option<Range>) {
248    let program = vec![0x4C];
249    (
250        program,
251        Some(Range {
252            start: leaf_index,
253            end: leaf_index + 1,
254        }),
255    )
256}
257
258fn proof_program(
259    child: &(Vec<u8>, Option<Range>),
260    proof: H256,
261    height: usize,
262) -> (Vec<u8>, Option<Range>) {
263    let (child_program, child_range) = child;
264    let mut program = Vec::new();
265    let height = height as u64;
266    program.resize(41 + child_program.len(), 0x50);
267    program[..child_program.len()].copy_from_slice(child_program);
268    program[child_program.len() + 1..child_program.len() + 9]
269        .copy_from_slice(&height.to_be_bytes());
270    program[child_program.len() + 9..].copy_from_slice(proof.as_slice());
271    (program, child_range.clone())
272}
273
274fn merge_program(
275    a: &(Vec<u8>, Option<Range>),
276    b: &(Vec<u8>, Option<Range>),
277    height: usize,
278) -> Result<(Vec<u8>, Option<Range>)> {
279    let (a_program, a_range) = a;
280    let (b_program, b_range) = b;
281    let (a_comes_first, range) = if a_range.is_none() || b_range.is_none() {
282        let range = if a_range.is_none() { b_range } else { a_range }
283            .clone()
284            .unwrap();
285        (true, range)
286    } else {
287        let a_range = a_range.clone().unwrap();
288        let b_range = b_range.clone().unwrap();
289        if a_range.end == b_range.start {
290            (
291                true,
292                Range {
293                    start: a_range.start,
294                    end: b_range.end,
295                },
296            )
297        } else {
298            return Err(Error::NonMergableRange);
299        }
300    };
301    let mut program = Vec::new();
302    program.resize(9 + a_program.len() + b_program.len(), 0x48);
303    if a_comes_first {
304        program[..a_program.len()].copy_from_slice(a_program);
305        program[a_program.len()..a_program.len() + b_program.len()].copy_from_slice(b_program);
306    } else {
307        program[..b_program.len()].copy_from_slice(b_program);
308        program[b_program.len()..a_program.len() + b_program.len()].copy_from_slice(a_program);
309    }
310    let height = height as u64;
311    let height_pos = a_program.len() + b_program.len() + 1;
312    program[height_pos..height_pos + 8].copy_from_slice(&height.to_be_bytes());
313    Ok((program, Some(range)))
314}
315
/// A merkle proof compiled into a compact byte program, optimized for
/// verification.
///
/// The program is a sequence of opcodes:
/// * `L` (`0x4C`) pushes the next leaf.
/// * `P` (`0x50`) is followed by an 8-byte big-endian height and a 32-byte
///   sibling hash. It merges the stack top with that sibling.
/// * `H` (`0x48`) is followed by an 8-byte big-endian height. It merges the
///   top two stack entries.
#[derive(Clone, Debug)]
pub struct CompiledMerkleProof(pub Vec<u8>);
319
320impl CompiledMerkleProof {
321    pub fn compute_root<H: Hasher + Default, K, V, const N: usize>(
322        &self,
323        mut leaves: Vec<(K, V)>,
324    ) -> Result<H256>
325    where
326        K: Key<N>,
327        V: Value,
328    {
329        leaves.sort_unstable_by_key(|(k, _v)| **k);
330        let mut program_index = 0;
331        let mut leave_index = 0;
332        let mut stack = Vec::new();
333        while program_index < self.0.len() {
334            let code = self.0[program_index];
335            program_index += 1;
336            match code {
337                // L
338                0x4C => {
339                    if leave_index >= leaves.len() {
340                        return Err(Error::CorruptedStack);
341                    }
342                    let (k, v) = leaves[leave_index].clone();
343                    stack.push((*k, hash_leaf::<H, K, V, N>(&k, &v)));
344                    leave_index += 1;
345                }
346                // P
347                0x50 => {
348                    if stack.is_empty() {
349                        return Err(Error::CorruptedStack);
350                    }
351                    if program_index + 40 > self.0.len() {
352                        return Err(Error::CorruptedProof);
353                    }
354                    let height: [u8; 8] = self.0[program_index..program_index + 8]
355                        .try_into()
356                        .expect("8 bytes should fit in an 8 byte array");
357                    let height = u64::from_be_bytes(height) as usize;
358                    program_index += 8;
359                    let mut data = [0u8; 32];
360                    data.copy_from_slice(&self.0[program_index..program_index + 32]);
361                    program_index += 32;
362                    let proof = H256::from(data);
363                    let (key, value) = stack.pop().unwrap();
364                    let parent_key = key.parent_path(height);
365                    let parent = if key.get_bit(height) {
366                        merge::<H>(&proof, &value)
367                    } else {
368                        merge::<H>(&value, &proof)
369                    };
370                    stack.push((parent_key, parent));
371                }
372                // H
373                0x48 => {
374                    if stack.len() < 2 {
375                        return Err(Error::CorruptedStack);
376                    }
377                    if program_index >= self.0.len() {
378                        return Err(Error::CorruptedProof);
379                    }
380                    let height: [u8; 8] = self.0[program_index..program_index + 8]
381                        .try_into()
382                        .expect("8 bytes should fit in an 8 byte array");
383                    let height = u64::from_be_bytes(height) as usize;
384                    program_index += 8;
385                    let (key_b, value_b) = stack.pop().unwrap();
386                    let (key_a, value_a) = stack.pop().unwrap();
387                    let parent_key_a = key_a.copy_bits(height..);
388                    let parent_key_b = key_b.copy_bits(height..);
389                    let a_set = key_a.get_bit(height);
390                    let b_set = key_b.get_bit(height);
391                    let mut sibling_key_a = parent_key_a;
392                    if !a_set {
393                        sibling_key_a.set_bit(height);
394                    }
395                    // Test if a and b are siblings
396                    if !(sibling_key_a == parent_key_b && (a_set ^ b_set)) {
397                        return Err(Error::NonSiblings);
398                    }
399                    let parent = if key_a.get_bit(height) {
400                        merge::<H>(&value_b, &value_a)
401                    } else {
402                        merge::<H>(&value_a, &value_b)
403                    };
404                    stack.push((parent_key_a, parent));
405                }
406                _ => return Err(Error::InvalidCode(code)),
407            }
408        }
409        if stack.len() != 1 {
410            return Err(Error::CorruptedStack);
411        }
412        Ok(stack[0].1)
413    }
414
415    pub fn verify<H: Hasher + Default, K, V, const N: usize>(
416        &self,
417        root: &H256,
418        leaves: Vec<(K, V)>,
419    ) -> Result<bool>
420    where
421        K: Key<N>,
422        V: Value,
423    {
424        let calculated_root = self.compute_root::<H, K, V, N>(leaves)?;
425        Ok(&calculated_root == root)
426    }
427}