restricted_sparse_merkle_tree/
merkle_proof.rs

1use crate::{
2    collections::{BTreeMap, VecDeque},
3    error::{Error, Result},
4    merge::{hash_leaf, merge},
5    traits::Hasher,
6    vec::Vec,
7    H256,
8};
9
/// Half-open range of leaf indices (positions in the key-sorted leaves)
/// covered by a compiled sub-program.
type Range = ::core::ops::Range<usize>;
11
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct MerkleProof {
14    leaves_path: Vec<Vec<u8>>,
15    proof: Vec<(H256, u8)>,
16}
17
18impl MerkleProof {
19    /// Create MerkleProof
20    /// leaves_path: contains height of non-zero siblings
21    /// proof: contains merkle path for each leaves it's height
22    pub fn new(leaves_path: Vec<Vec<u8>>, proof: Vec<(H256, u8)>) -> Self {
23        MerkleProof { leaves_path, proof }
24    }
25
26    /// Destruct the structure, useful for serialization
27    pub fn take(self) -> (Vec<Vec<u8>>, Vec<(H256, u8)>) {
28        let MerkleProof { leaves_path, proof } = self;
29        (leaves_path, proof)
30    }
31
32    /// number of leaves required by this merkle proof
33    pub fn leaves_count(&self) -> usize {
34        self.leaves_path.len()
35    }
36
37    /// return the inner leaves_path vector
38    pub fn leaves_path(&self) -> &Vec<Vec<u8>> {
39        &self.leaves_path
40    }
41
42    /// return proof merkle path
43    pub fn proof(&self) -> &Vec<(H256, u8)> {
44        &self.proof
45    }
46
47    /// convert merkle proof into CompiledMerkleProof
48    pub fn compile(self, mut leaves: Vec<(H256, H256)>) -> Result<CompiledMerkleProof> {
49        if leaves.is_empty() {
50            return Err(Error::EmptyKeys);
51        } else if leaves.len() != self.leaves_count() {
52            return Err(Error::IncorrectNumberOfLeaves {
53                expected: self.leaves_count(),
54                actual: leaves.len(),
55            });
56        }
57
58        let (leaves_path, proof) = self.take();
59        let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
60        let mut proof: VecDeque<_> = proof.into();
61
62        // sort leaves
63        leaves.sort_unstable_by_key(|(k, _v)| *k);
64        // tree_buf: (height, key) -> (key_index, node)
65        let mut tree_buf: BTreeMap<_, _> = leaves
66            .into_iter()
67            .enumerate()
68            .map(|(i, (k, _v))| ((0, k), (i, leaf_program(i))))
69            .collect();
70        // rebuild the tree from bottom to top
71        while !tree_buf.is_empty() {
72            // pop_front from tree_buf, the API is unstable
73            let &(mut height, key) = tree_buf.keys().next().unwrap();
74            let (leaf_index, program) = tree_buf.remove(&(height, key)).unwrap();
75
76            if proof.is_empty() && tree_buf.is_empty() {
77                return Ok(CompiledMerkleProof(program.0));
78            }
79
80            let mut sibling_key = key.parent_path(height);
81            if !key.get_bit(height) {
82                sibling_key.set_bit(height)
83            }
84
85            let (parent_key, parent_program, height) =
86                if Some(&(height, sibling_key)) == tree_buf.keys().next() {
87                    let (_leaf_index, sibling_program) = tree_buf
88                        .remove(&(height, sibling_key))
89                        .expect("pop sibling");
90                    let parent_key = key.parent_path(height);
91                    let parent_program = merge_program(&program, &sibling_program, height)?;
92                    (parent_key, parent_program, height)
93                } else {
94                    let merge_height = leaves_path[leaf_index].front().copied().unwrap_or(height);
95                    if height != merge_height {
96                        let parent_key = key.copy_bits(merge_height);
97                        // skip zeros
98                        tree_buf.insert((merge_height, parent_key), (leaf_index, program));
99                        continue;
100                    }
101                    let (proof, proof_height) = proof.pop_front().ok_or(Error::CorruptedProof)?;
102                    let proof_height = proof_height;
103                    if height < proof_height {
104                        height = proof_height;
105                    }
106
107                    let parent_key = key.parent_path(height);
108                    let parent_program = proof_program(&program, proof, height);
109                    (parent_key, parent_program, height)
110                };
111
112            if height == core::u8::MAX {
113                if proof.is_empty() {
114                    return Ok(CompiledMerkleProof(parent_program.0));
115                } else {
116                    return Err(Error::CorruptedProof);
117                }
118            }
119            leaves_path[leaf_index].pop_front();
120            tree_buf.insert((height + 1, parent_key), (leaf_index, parent_program));
121        }
122
123        Err(Error::CorruptedProof)
124    }
125
126    /// Compute root from proof
127    /// leaves: a vector of (key, value)
128    ///
129    /// return EmptyProof error when proof is empty
130    /// return CorruptedProof error when proof is invalid
131    pub fn compute_root<H: Hasher + Default>(self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
132        if leaves.is_empty() {
133            return Err(Error::EmptyKeys);
134        } else if leaves.len() != self.leaves_count() {
135            return Err(Error::IncorrectNumberOfLeaves {
136                expected: self.leaves_count(),
137                actual: leaves.len(),
138            });
139        }
140
141        let (leaves_path, proof) = self.take();
142        let mut leaves_path: Vec<VecDeque<_>> = leaves_path.into_iter().map(Into::into).collect();
143        let mut proof: VecDeque<_> = proof.into();
144
145        // Deny non-inclusion proof
146        for (_k, v) in &leaves {
147            if v.is_zero() {
148                return Err(Error::ForbidZeroValueLeaf);
149            }
150        }
151
152        // sort leaves
153        leaves.sort_unstable_by_key(|(k, _v)| *k);
154        // tree_buf: (height, key) -> (key_index, node)
155        let mut tree_buf: BTreeMap<_, _> = leaves
156            .into_iter()
157            .enumerate()
158            .map(|(i, (k, v))| ((0, k), (i, hash_leaf::<H>(&k, &v))))
159            .collect();
160        // rebuild the tree from bottom to top
161        while !tree_buf.is_empty() {
162            // pop_front from tree_buf, the API is unstable
163            let (&(mut height, key), &(leaf_index, node)) = tree_buf.iter().next().unwrap();
164            tree_buf.remove(&(height, key));
165
166            if proof.is_empty() && tree_buf.is_empty() {
167                return Ok(node);
168            }
169
170            let mut sibling_key = key.parent_path(height);
171            if !key.get_bit(height) {
172                sibling_key.set_bit(height)
173            }
174            let (sibling, sibling_height) =
175                if Some(&(height, sibling_key)) == tree_buf.keys().next() {
176                    let (_leaf_index, sibling) = tree_buf
177                        .remove(&(height, sibling_key))
178                        .expect("pop sibling");
179                    (sibling, height)
180                } else {
181                    let merge_height = leaves_path[leaf_index].front().copied().unwrap_or(height);
182                    if height != merge_height {
183                        let parent_key = key.copy_bits(merge_height);
184                        // skip zeros
185                        tree_buf.insert((merge_height, parent_key), (leaf_index, node));
186                        continue;
187                    }
188                    let (node, height) = proof.pop_front().ok_or(Error::CorruptedProof)?;
189                    (node, height)
190                };
191            if height < sibling_height {
192                height = sibling_height;
193            }
194            // skip zero merkle path
195            let parent_key = key.parent_path(height);
196
197            let parent = if key.get_bit(height) {
198                merge::<H>(&sibling, &node)
199            } else {
200                merge::<H>(&node, &sibling)
201            };
202
203            if height == core::u8::MAX {
204                if proof.is_empty() {
205                    return Ok(parent);
206                } else {
207                    return Err(Error::CorruptedProof);
208                }
209            } else {
210                leaves_path[leaf_index].pop_front();
211                tree_buf.insert((height + 1, parent_key), (leaf_index, parent));
212            }
213        }
214
215        Err(Error::CorruptedProof)
216    }
217
218    /// Verify merkle proof
219    /// see compute_root_from_proof
220    pub fn verify<H: Hasher + Default>(
221        self,
222        root: &H256,
223        leaves: Vec<(H256, H256)>,
224    ) -> Result<bool> {
225        let calculated_root = self.compute_root::<H>(leaves)?;
226        Ok(&calculated_root == root)
227    }
228}
229
230fn leaf_program(leaf_index: usize) -> (Vec<u8>, Option<Range>) {
231    let mut program = Vec::with_capacity(1);
232    program.push(0x4C);
233    (
234        program,
235        Some(Range {
236            start: leaf_index,
237            end: leaf_index + 1,
238        }),
239    )
240}
241
242fn proof_program(
243    child: &(Vec<u8>, Option<Range>),
244    proof: H256,
245    height: u8,
246) -> (Vec<u8>, Option<Range>) {
247    let (child_program, child_range) = child;
248    let mut program = Vec::new();
249    program.resize(34 + child_program.len(), 0x50);
250    program[..child_program.len()].copy_from_slice(child_program);
251    program[child_program.len() + 1] = height;
252    program[child_program.len() + 2..].copy_from_slice(proof.as_slice());
253    (program, child_range.clone())
254}
255
256fn merge_program(
257    a: &(Vec<u8>, Option<Range>),
258    b: &(Vec<u8>, Option<Range>),
259    height: u8,
260) -> Result<(Vec<u8>, Option<Range>)> {
261    let (a_program, a_range) = a;
262    let (b_program, b_range) = b;
263    let (a_comes_first, range) = if a_range.is_none() || b_range.is_none() {
264        let range = if a_range.is_none() { b_range } else { a_range }
265            .clone()
266            .unwrap();
267        (true, range)
268    } else {
269        let a_range = a_range.clone().unwrap();
270        let b_range = b_range.clone().unwrap();
271        if a_range.end == b_range.start {
272            (
273                true,
274                Range {
275                    start: a_range.start,
276                    end: b_range.end,
277                },
278            )
279        } else {
280            return Err(Error::NonMergableRange);
281        }
282    };
283    let mut program = Vec::new();
284    program.resize(2 + a_program.len() + b_program.len(), 0x48);
285    if a_comes_first {
286        program[..a_program.len()].copy_from_slice(a_program);
287        program[a_program.len()..a_program.len() + b_program.len()].copy_from_slice(b_program);
288    } else {
289        program[..b_program.len()].copy_from_slice(b_program);
290        program[b_program.len()..a_program.len() + b_program.len()].copy_from_slice(a_program);
291    }
292    program[a_program.len() + b_program.len() + 1] = height;
293    Ok((program, Some(range)))
294}
295
/// A merkle proof compiled into a compact byte program, optimized for verification.
///
/// The program is a sequence of one-byte opcodes executed on a stack:
/// `L` (0x4C) pushes the next leaf, `P` (0x50) merges the stack top with an
/// inline 32-byte sibling, and `H` (0x48) merges the top two stack entries.
#[derive(Clone, Debug)]
pub struct CompiledMerkleProof(pub Vec<u8>);
299
300impl CompiledMerkleProof {
301    pub fn compute_root<H: Hasher + Default>(&self, mut leaves: Vec<(H256, H256)>) -> Result<H256> {
302        leaves.sort_unstable_by_key(|(k, _v)| *k);
303        let mut program_index = 0;
304        let mut leave_index = 0;
305        let mut stack = Vec::new();
306        while program_index < self.0.len() {
307            let code = self.0[program_index];
308            program_index += 1;
309            match code {
310                // L
311                0x4C => {
312                    if leave_index >= leaves.len() {
313                        return Err(Error::CorruptedStack);
314                    }
315                    let (k, v) = leaves[leave_index];
316
317                    // Deny non-inclusion proof
318                    if v.is_zero() {
319                        return Err(Error::ForbidZeroValueLeaf);
320                    }
321
322                    stack.push((k, hash_leaf::<H>(&k, &v)));
323                    leave_index += 1;
324                }
325                // P
326                0x50 => {
327                    if stack.is_empty() {
328                        return Err(Error::CorruptedStack);
329                    }
330                    if program_index + 33 > self.0.len() {
331                        return Err(Error::CorruptedProof);
332                    }
333                    let height = self.0[program_index];
334                    program_index += 1;
335                    let mut data = [0u8; 32];
336                    data.copy_from_slice(&self.0[program_index..program_index + 32]);
337                    program_index += 32;
338                    let proof = H256::from(data);
339                    let (key, value) = stack.pop().unwrap();
340                    let parent_key = key.parent_path(height);
341                    let parent = if key.get_bit(height) {
342                        merge::<H>(&proof, &value)
343                    } else {
344                        merge::<H>(&value, &proof)
345                    };
346                    stack.push((parent_key, parent));
347                }
348                // H
349                0x48 => {
350                    if stack.len() < 2 {
351                        return Err(Error::CorruptedStack);
352                    }
353                    if program_index >= self.0.len() {
354                        return Err(Error::CorruptedProof);
355                    }
356                    let height = self.0[program_index];
357                    program_index += 1;
358                    let (key_b, value_b) = stack.pop().unwrap();
359                    let (key_a, value_a) = stack.pop().unwrap();
360                    let parent_key_a = key_a.copy_bits(height);
361                    let parent_key_b = key_b.copy_bits(height);
362                    let a_set = key_a.get_bit(height);
363                    let b_set = key_b.get_bit(height);
364                    let mut sibling_key_a = parent_key_a;
365                    if !a_set {
366                        sibling_key_a.set_bit(height);
367                    }
368                    // Test if a and b are siblings
369                    if !(sibling_key_a == parent_key_b && (a_set ^ b_set)) {
370                        return Err(Error::NonSiblings);
371                    }
372                    let parent = if key_a.get_bit(height) {
373                        merge::<H>(&value_b, &value_a)
374                    } else {
375                        merge::<H>(&value_a, &value_b)
376                    };
377                    stack.push((parent_key_a, parent));
378                }
379                _ => return Err(Error::InvalidCode(code)),
380            }
381        }
382        if stack.len() != 1 {
383            return Err(Error::CorruptedStack);
384        }
385        Ok(stack[0].1)
386    }
387
388    pub fn verify<H: Hasher + Default>(
389        &self,
390        root: &H256,
391        leaves: Vec<(H256, H256)>,
392    ) -> Result<bool> {
393        let calculated_root = self.compute_root::<H>(leaves)?;
394        Ok(&calculated_root == root)
395    }
396}
397
398impl Into<Vec<u8>> for CompiledMerkleProof {
399    fn into(self) -> Vec<u8> {
400        self.0
401    }
402}