winter_crypto/merkle/proofs.rs

// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

6use alloc::{collections::BTreeMap, vec::Vec};
7
8use utils::{ByteReader, Deserializable, DeserializationError, Serializable};
9
10use super::MerkleTreeOpening;
11use crate::{errors::MerkleTreeError, Hasher};
12
// BATCH MERKLE PROOF
// ================================================================================================

16/// Multiple Merkle proofs aggregated into a single proof.
17///
18/// The aggregation is done in a way which removes all duplicate internal nodes, and thus,
19/// it is possible to achieve non-negligible compression as compared to naively concatenating
20/// individual Merkle proofs. The algorithm is for aggregation is a variation of
21/// [Octopus](https://eprint.iacr.org/2017/933).
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub struct BatchMerkleProof<H: Hasher> {
24    /// Hashes of Merkle Tree proof values above the leaf layer
25    pub nodes: Vec<Vec<H::Digest>>,
26    /// Depth of the leaves
27    pub depth: u8,
28}
29
30impl<H: Hasher> BatchMerkleProof<H> {
31    /// Constructs a batch Merkle proof from collection of single Merkle proofs.
32    ///
33    /// # Panics
34    /// Panics if:
35    /// * No proofs have been provided (i.e., `proofs` is an empty slice).
36    /// * Number of proofs is not equal to the number of indexes.
37    /// * Not all proofs have the same length.
38    pub fn from_single_proofs(
39        proofs: &[MerkleTreeOpening<H>],
40        indexes: &[usize],
41    ) -> BatchMerkleProof<H> {
42        // TODO: optimize this to reduce amount of vector cloning.
43        assert!(!proofs.is_empty(), "at least one proof must be provided");
44        assert_eq!(proofs.len(), indexes.len(), "number of proofs must equal number of indexes");
45
46        let depth = proofs[0].1.len();
47
48        // sort indexes in ascending order, and also re-arrange proofs accordingly
49        let mut proof_map = BTreeMap::new();
50        for (&index, proof) in indexes.iter().zip(proofs.iter().cloned()) {
51            assert_eq!(depth, proof.1.len(), "not all proofs have the same length");
52            proof_map.insert(index, proof);
53        }
54        let indexes = proof_map.keys().cloned().collect::<Vec<_>>();
55        let proofs = proof_map.values().cloned().collect::<Vec<_>>();
56        proof_map.clear();
57
58        let mut leaves = vec![H::Digest::default(); indexes.len()];
59        let mut nodes: Vec<Vec<H::Digest>> = Vec::with_capacity(indexes.len());
60
61        // populate values and the first layer of proof nodes
62        let mut i = 0;
63        while i < indexes.len() {
64            leaves[i] = proofs[i].0;
65
66            if indexes.len() > i + 1 && are_siblings(indexes[i], indexes[i + 1]) {
67                leaves[i + 1] = proofs[i].1[0];
68                nodes.push(vec![]);
69                i += 1;
70            } else {
71                nodes.push(vec![proofs[i].1[0]]);
72            }
73            proof_map.insert(indexes[i] >> 1, proofs[i].clone());
74            i += 1;
75        }
76
77        // populate all remaining layers of proof nodes
78        for d in 1..depth {
79            let indexes = proof_map.keys().cloned().collect::<Vec<_>>();
80            let mut next_proof_map = BTreeMap::new();
81
82            let mut i = 0;
83            while i < indexes.len() {
84                let index = indexes[i];
85                let proof = proof_map.get(&index).unwrap();
86                if indexes.len() > i + 1 && are_siblings(index, indexes[i + 1]) {
87                    i += 1;
88                } else {
89                    nodes[i].push(proof.1[d]);
90                }
91                next_proof_map.insert(index >> 1, proof.clone());
92                i += 1;
93            }
94
95            core::mem::swap(&mut proof_map, &mut next_proof_map);
96        }
97
98        BatchMerkleProof { nodes, depth: (depth) as u8 }
99    }
100
101    /// Computes a node to which all Merkle proofs aggregated in this proof resolve.
102    ///
103    /// # Errors
104    /// Returns an error if:
105    /// * No indexes were provided (i.e., `indexes` is an empty slice).
106    /// * Any of the specified `indexes` is greater than or equal to the number of leaves in the
107    ///   tree for which this batch proof was generated.
108    /// * List of indexes contains duplicates.
109    /// * The proof does not resolve to a single root.
110    pub fn get_root(
111        &self,
112        indexes: &[usize],
113        leaves: &[H::Digest],
114    ) -> Result<H::Digest, MerkleTreeError> {
115        if indexes.is_empty() {
116            return Err(MerkleTreeError::TooFewLeafIndexes);
117        }
118
119        let mut buf = [H::Digest::default(); 2];
120        let mut v = BTreeMap::new();
121
122        // replace odd indexes, offset, and sort in ascending order
123        let index_map = super::map_indexes(indexes, self.depth as usize)?;
124        let indexes = super::normalize_indexes(indexes);
125        if indexes.len() != self.nodes.len() {
126            return Err(MerkleTreeError::InvalidProof);
127        }
128
129        // for each index use values to compute parent nodes
130        let offset = 2usize.pow(self.depth as u32);
131        let mut next_indexes: Vec<usize> = Vec::new();
132        let mut proof_pointers: Vec<usize> = Vec::with_capacity(indexes.len());
133        for (i, index) in indexes.into_iter().enumerate() {
134            // copy values of leaf sibling leaf nodes into the buffer
135            match index_map.get(&index) {
136                Some(&index1) => {
137                    if leaves.len() <= index1 {
138                        return Err(MerkleTreeError::InvalidProof);
139                    }
140                    buf[0] = leaves[index1];
141                    match index_map.get(&(index + 1)) {
142                        Some(&index2) => {
143                            if leaves.len() <= index2 {
144                                return Err(MerkleTreeError::InvalidProof);
145                            }
146                            buf[1] = leaves[index2];
147                            proof_pointers.push(0);
148                        },
149                        None => {
150                            if self.nodes[i].is_empty() {
151                                return Err(MerkleTreeError::InvalidProof);
152                            }
153                            buf[1] = self.nodes[i][0];
154                            proof_pointers.push(1);
155                        },
156                    }
157                },
158                None => {
159                    if self.nodes[i].is_empty() {
160                        return Err(MerkleTreeError::InvalidProof);
161                    }
162                    buf[0] = self.nodes[i][0];
163                    match index_map.get(&(index + 1)) {
164                        Some(&index2) => {
165                            if leaves.len() <= index2 {
166                                return Err(MerkleTreeError::InvalidProof);
167                            }
168                            buf[1] = leaves[index2];
169                        },
170                        None => return Err(MerkleTreeError::InvalidProof),
171                    }
172                    proof_pointers.push(1);
173                },
174            }
175
176            // hash sibling nodes into their parent
177            let parent = H::merge(&buf);
178
179            let parent_index = (offset + index) >> 1;
180            v.insert(parent_index, parent);
181            next_indexes.push(parent_index);
182        }
183
184        // iteratively move up, until we get to the root
185        for _ in 1..self.depth {
186            let indexes = next_indexes.clone();
187            next_indexes.truncate(0);
188
189            let mut i = 0;
190            while i < indexes.len() {
191                let node_index = indexes[i];
192                let sibling_index = node_index ^ 1;
193
194                // determine the sibling
195                let sibling: H::Digest;
196                if i + 1 < indexes.len() && indexes[i + 1] == sibling_index {
197                    sibling = match v.get(&sibling_index) {
198                        Some(sibling) => *sibling,
199                        None => return Err(MerkleTreeError::InvalidProof),
200                    };
201                    i += 1;
202                } else {
203                    let pointer = proof_pointers[i];
204                    if self.nodes[i].len() <= pointer {
205                        return Err(MerkleTreeError::InvalidProof);
206                    }
207                    sibling = self.nodes[i][pointer];
208                    proof_pointers[i] += 1;
209                }
210
211                // get the node from the map of hashed nodes
212                let node = match v.get(&node_index) {
213                    Some(node) => node,
214                    None => return Err(MerkleTreeError::InvalidProof),
215                };
216
217                // compute parent node from node and sibling
218                if node_index & 1 != 0 {
219                    buf[0] = sibling;
220                    buf[1] = *node;
221                } else {
222                    buf[0] = *node;
223                    buf[1] = sibling;
224                }
225                let parent = H::merge(&buf);
226
227                // add the parent node to the next set of nodes
228                let parent_index = node_index >> 1;
229                v.insert(parent_index, parent);
230                next_indexes.push(parent_index);
231
232                i += 1;
233            }
234        }
235        v.remove(&1).ok_or(MerkleTreeError::InvalidProof)
236    }
237
238    /// Computes the uncompressed individual Merkle proofs which aggregate to this batch proof.
239    ///
240    /// # Errors
241    /// Returns an error if:
242    /// * No indexes were provided (i.e., `indexes` is an empty slice).
243    /// * Number of provided indexes does not match the number of leaf nodes in the proof.
244    pub fn into_openings(
245        self,
246        leaves: &[H::Digest],
247        indexes: &[usize],
248    ) -> Result<Vec<MerkleTreeOpening<H>>, MerkleTreeError> {
249        if indexes.is_empty() {
250            return Err(MerkleTreeError::TooFewLeafIndexes);
251        }
252        if indexes.len() != leaves.len() {
253            return Err(MerkleTreeError::InvalidProof);
254        }
255
256        let mut partial_tree_map = BTreeMap::new();
257
258        for (&i, leaf) in indexes.iter().zip(leaves.iter()) {
259            partial_tree_map.insert(i + (1 << (self.depth)), *leaf);
260        }
261
262        let mut buf = [H::Digest::default(); 2];
263        let mut v = BTreeMap::new();
264
265        // replace odd indexes, offset, and sort in ascending order
266        let original_indexes = indexes;
267        let index_map = super::map_indexes(indexes, self.depth as usize)?;
268        let indexes = super::normalize_indexes(indexes);
269        if indexes.len() != self.nodes.len() {
270            return Err(MerkleTreeError::InvalidProof);
271        }
272
273        // for each index use values to compute parent nodes
274        let offset = 2usize.pow(self.depth as u32);
275        let mut next_indexes: Vec<usize> = Vec::new();
276        let mut proof_pointers: Vec<usize> = Vec::with_capacity(indexes.len());
277        for (i, index) in indexes.into_iter().enumerate() {
278            // copy values of leaf sibling leaf nodes into the buffer
279            match index_map.get(&index) {
280                Some(&index1) => {
281                    if leaves.len() <= index1 {
282                        return Err(MerkleTreeError::InvalidProof);
283                    }
284                    buf[0] = leaves[index1];
285                    match index_map.get(&(index + 1)) {
286                        Some(&index2) => {
287                            if leaves.len() <= index2 {
288                                return Err(MerkleTreeError::InvalidProof);
289                            }
290                            buf[1] = leaves[index2];
291                            proof_pointers.push(0);
292                        },
293                        None => {
294                            if self.nodes[i].is_empty() {
295                                return Err(MerkleTreeError::InvalidProof);
296                            }
297                            buf[1] = self.nodes[i][0];
298                            proof_pointers.push(1);
299                        },
300                    }
301                },
302                None => {
303                    if self.nodes[i].is_empty() {
304                        return Err(MerkleTreeError::InvalidProof);
305                    }
306                    buf[0] = self.nodes[i][0];
307                    match index_map.get(&(index + 1)) {
308                        Some(&index2) => {
309                            if leaves.len() <= index2 {
310                                return Err(MerkleTreeError::InvalidProof);
311                            }
312                            buf[1] = leaves[index2];
313                        },
314                        None => return Err(MerkleTreeError::InvalidProof),
315                    }
316                    proof_pointers.push(1);
317                },
318            }
319
320            // hash sibling nodes into their parent and add it to partial_tree
321            let parent = H::merge(&buf);
322            partial_tree_map.insert(offset + index, buf[0]);
323            partial_tree_map.insert((offset + index) ^ 1, buf[1]);
324            let parent_index = (offset + index) >> 1;
325            v.insert(parent_index, parent);
326            next_indexes.push(parent_index);
327            partial_tree_map.insert(parent_index, parent);
328        }
329
330        // iteratively move up, until we get to the root
331        for _ in 1..self.depth {
332            let indexes = next_indexes.clone();
333            next_indexes.clear();
334
335            let mut i = 0;
336            while i < indexes.len() {
337                let node_index = indexes[i];
338                let sibling_index = node_index ^ 1;
339
340                // determine the sibling
341                let sibling = if i + 1 < indexes.len() && indexes[i + 1] == sibling_index {
342                    i += 1;
343                    match v.get(&sibling_index) {
344                        Some(sibling) => *sibling,
345                        None => return Err(MerkleTreeError::InvalidProof),
346                    }
347                } else {
348                    let pointer = proof_pointers[i];
349                    if self.nodes[i].len() <= pointer {
350                        return Err(MerkleTreeError::InvalidProof);
351                    }
352                    proof_pointers[i] += 1;
353                    self.nodes[i][pointer]
354                };
355
356                // get the node from the map of hashed nodes
357                let node = match v.get(&node_index) {
358                    Some(node) => node,
359                    None => return Err(MerkleTreeError::InvalidProof),
360                };
361
362                // compute parent node from node and sibling
363                partial_tree_map.insert(node_index ^ 1, sibling);
364                let parent = if node_index & 1 != 0 {
365                    H::merge(&[sibling, *node])
366                } else {
367                    H::merge(&[*node, sibling])
368                };
369
370                // add the parent node to the next set of nodes and partial_tree
371                let parent_index = node_index >> 1;
372                v.insert(parent_index, parent);
373                next_indexes.push(parent_index);
374                partial_tree_map.insert(parent_index, parent);
375
376                i += 1;
377            }
378        }
379
380        original_indexes
381            .iter()
382            .map(|&i| get_proof::<H>(i, &partial_tree_map, self.depth as usize))
383            .collect()
384    }
385}
386
// SERIALIZATION / DESERIALIZATION
// --------------------------------------------------------------------------------------------

390impl<H: Hasher> Serializable for BatchMerkleProof<H> {
391    /// Writes all internal proof nodes into the provided target.
392    fn write_into<W: utils::ByteWriter>(&self, target: &mut W) {
393        target.write_u8(self.depth);
394        target.write_usize(self.nodes.len());
395
396        for nodes in self.nodes.iter() {
397            // record the number of nodes, and append all nodes to the proof buffer
398            nodes.write_into(target);
399        }
400    }
401}
402
403impl<H: Hasher> Deserializable for BatchMerkleProof<H> {
404    /// Parses internal nodes from the provided `source`, and constructs a batch Merkle proof
405    /// from these nodes.
406    ///
407    /// # Errors
408    /// Returns an error if:
409    /// * `source` could not be deserialized into a valid set of internal nodes.
410    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
411        let depth = source.read_u8()?;
412        let num_node_vectors = source.read_usize()?;
413
414        let mut nodes = Vec::with_capacity(num_node_vectors);
415        for _ in 0..num_node_vectors {
416            // read the digests and add them to the node vector
417            let digests = Vec::<_>::read_from(source)?;
418            nodes.push(digests);
419        }
420
421        Ok(BatchMerkleProof { nodes, depth })
422    }
423}
424
// HELPER FUNCTIONS
// ================================================================================================

/// Returns `true` if `left` and `right` are sibling nodes: `left` is even (a left child) and
/// `right` immediately follows it.
///
/// The comparison is written as `right == left + 1`, so `right == 0` cannot underflow.
/// `left + 1` is only evaluated for an even `left`, which can never be `usize::MAX`.
fn are_siblings(left: usize, right: usize) -> bool {
    left & 1 == 0 && right == left + 1
}
433
434/// Computes the Merkle proof from the computed (partial) tree.
435pub fn get_proof<H: Hasher>(
436    index: usize,
437    tree: &BTreeMap<usize, <H as Hasher>::Digest>,
438    depth: usize,
439) -> Result<MerkleTreeOpening<H>, MerkleTreeError> {
440    let mut index = index + (1 << depth);
441    let leaf = if let Some(leaf) = tree.get(&index) {
442        *leaf
443    } else {
444        return Err(MerkleTreeError::InvalidProof);
445    };
446
447    let mut proof = vec![];
448    while index > 1 {
449        let leaf = if let Some(leaf) = tree.get(&(index ^ 1)) {
450            *leaf
451        } else {
452            return Err(MerkleTreeError::InvalidProof);
453        };
454
455        proof.push(leaf);
456        index >>= 1;
457    }
458
459    Ok((leaf, proof))
460}