../../.cargo/katex-header.html

plonky2/hash/
path_compression.rs

1#[cfg(not(feature = "std"))]
2use alloc::{vec, vec::Vec};
3
4use hashbrown::HashMap;
5use num::Integer;
6
7use crate::hash::hash_types::RichField;
8use crate::hash::merkle_proofs::MerkleProof;
9use crate::plonk::config::Hasher;
10
11/// Compress multiple Merkle proofs on the same tree by removing redundancy in the Merkle paths.
12pub(crate) fn compress_merkle_proofs<F: RichField, H: Hasher<F>>(
13    cap_height: usize,
14    indices: &[usize],
15    proofs: &[MerkleProof<F, H>],
16) -> Vec<MerkleProof<F, H>> {
17    assert!(!proofs.is_empty());
18    let height = cap_height + proofs[0].siblings.len();
19    let num_leaves = 1 << height;
20    let mut compressed_proofs = Vec::with_capacity(proofs.len());
21    // Holds the known nodes in the tree at a given time. The root is at index 1.
22    // Valid indices are 1 through n, and each element at index `i` has
23    // children at indices `2i` and `2i +1` its parent at index `floor(i ∕ 2)`.
24    let mut known = vec![false; 2 * num_leaves];
25    for &i in indices {
26        // The path from a leaf to the cap is known.
27        for j in 0..(height - cap_height) {
28            known[(i + num_leaves) >> j] = true;
29        }
30    }
31    // For each proof collect all the unknown proof elements.
32    for (&i, p) in indices.iter().zip(proofs) {
33        let mut compressed_proof = MerkleProof {
34            siblings: Vec::new(),
35        };
36        let mut index = i + num_leaves;
37        for &sibling in &p.siblings {
38            let sibling_index = index ^ 1;
39            if !known[sibling_index] {
40                // If the sibling is not yet known, add it to the proof and set it to known.
41                compressed_proof.siblings.push(sibling);
42                known[sibling_index] = true;
43            }
44            // Go up the tree and set the parent to known.
45            index >>= 1;
46            known[index] = true;
47        }
48        compressed_proofs.push(compressed_proof);
49    }
50
51    compressed_proofs
52}
53
54/// Decompress compressed Merkle proofs.
55/// Note: The data and indices must be in the same order as in `compress_merkle_proofs`.
56pub(crate) fn decompress_merkle_proofs<F: RichField, H: Hasher<F>>(
57    leaves_data: &[Vec<F>],
58    leaves_indices: &[usize],
59    compressed_proofs: &[MerkleProof<F, H>],
60    height: usize,
61    cap_height: usize,
62) -> Vec<MerkleProof<F, H>> {
63    let num_leaves = 1 << height;
64    let compressed_proofs = compressed_proofs.to_vec();
65    let mut decompressed_proofs = Vec::with_capacity(compressed_proofs.len());
66    // Holds the already seen nodes in the tree along with their value.
67    let mut seen = HashMap::new();
68
69    for (&i, v) in leaves_indices.iter().zip(leaves_data) {
70        // Observe the leaves.
71        seen.insert(i + num_leaves, H::hash_or_noop(v));
72    }
73
74    // Iterators over the siblings.
75    let mut siblings = compressed_proofs
76        .iter()
77        .map(|p| p.siblings.iter())
78        .collect::<Vec<_>>();
79    // Fill the `seen` map from the bottom of the tree to the cap.
80    for layer_height in 0..height - cap_height {
81        for (&i, p) in leaves_indices.iter().zip(siblings.iter_mut()) {
82            let index = (i + num_leaves) >> layer_height;
83            let current_hash = seen[&index];
84            let sibling_index = index ^ 1;
85            let sibling_hash = *seen
86                .entry(sibling_index)
87                .or_insert_with(|| *p.next().unwrap());
88            let parent_hash = if index.is_even() {
89                H::two_to_one(current_hash, sibling_hash)
90            } else {
91                H::two_to_one(sibling_hash, current_hash)
92            };
93            seen.insert(index >> 1, parent_hash);
94        }
95    }
96    // For every index, go up the tree by querying `seen` to get node values.
97    for &i in leaves_indices {
98        let mut decompressed_proof = MerkleProof {
99            siblings: Vec::new(),
100        };
101        let mut index = i + num_leaves;
102        for _ in 0..height - cap_height {
103            let sibling_index = index ^ 1;
104            let h = seen[&sibling_index];
105            decompressed_proof.siblings.push(h);
106            index >>= 1;
107        }
108
109        decompressed_proofs.push(decompressed_proof);
110    }
111
112    decompressed_proofs
113}
114
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;
    use rand::Rng;

    use super::*;
    use crate::field::types::Sample;
    use crate::hash::merkle_tree::MerkleTree;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// Opens a random batch of leaves and checks three things:
    /// - the batch is compressed;
    /// - decompressing it against the opened leaves yields the original proofs exactly;
    /// - with `std`, the serialized sizes are reported.
    #[test]
    fn test_path_compression() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;
        type H = <C as GenericConfig<D>>::Hasher;

        let tree_height = 10;
        let cap_height = 3;
        let leaf_count = 1 << tree_height;
        let leaves: Vec<Vec<F>> = (0..leaf_count).map(|_| vec![F::rand()]).collect();
        let tree = MerkleTree::<F, H>::new(leaves.clone(), cap_height);

        let mut rng = OsRng;
        let num_queries = rng.gen_range(1..=leaf_count);
        let queried: Vec<usize> = (0..num_queries)
            .map(|_| rng.gen_range(0..leaf_count))
            .collect();
        let full_proofs: Vec<_> = queried.iter().map(|&leaf| tree.prove(leaf)).collect();
        let opened_leaves: Vec<Vec<F>> = queried.iter().map(|&leaf| leaves[leaf].clone()).collect();

        let compressed = compress_merkle_proofs(cap_height, &queried, &full_proofs);
        let decompressed = decompress_merkle_proofs(
            &opened_leaves,
            &queried,
            &compressed,
            tree_height,
            cap_height,
        );
        assert_eq!(full_proofs, decompressed);

        #[cfg(feature = "std")]
        {
            let compressed_bytes = serde_cbor::to_vec(&compressed).unwrap();
            println!("Compressed proof length: {} bytes", compressed_bytes.len());
            let full_bytes = serde_cbor::to_vec(&full_proofs).unwrap();
            println!("Proof length: {} bytes", full_bytes.len());
        }
    }
}