// plonky2/hash/merkle_tree.rs
// (rustdoc KaTeX header: ../../.cargo/katex-header.html)
1#[cfg(not(feature = "std"))]
2use alloc::vec::Vec;
3use core::mem::MaybeUninit;
4use core::slice;
5
6use plonky2_maybe_rayon::*;
7use serde::{Deserialize, Serialize};
8
9use crate::hash::hash_types::RichField;
10use crate::hash::merkle_proofs::MerkleProof;
11use crate::plonk::config::{GenericHashOut, Hasher};
12use crate::util::log2_strict;
13
/// The Merkle cap of height `h` of a Merkle tree is the `h`-th layer (from the root) of the tree.
/// It can be used in place of the root to verify Merkle paths, which are `h` elements shorter.
///
/// The wrapped `Vec` holds one hash per cap entry; its length is expected to be a power of two
/// (`height` calls `log2_strict` on it and will panic otherwise).
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
// TODO: Change H to GenericHashOut<F>, since this only cares about the hash, not the hasher.
pub struct MerkleCap<F: RichField, H: Hasher<F>>(pub Vec<H::Hash>);
20
21impl<F: RichField, H: Hasher<F>> Default for MerkleCap<F, H> {
22    fn default() -> Self {
23        Self(Vec::new())
24    }
25}
26
27impl<F: RichField, H: Hasher<F>> MerkleCap<F, H> {
28    pub fn len(&self) -> usize {
29        self.0.len()
30    }
31
32    pub fn is_empty(&self) -> bool {
33        self.len() == 0
34    }
35
36    pub fn height(&self) -> usize {
37        log2_strict(self.len())
38    }
39
40    pub fn flatten(&self) -> Vec<F> {
41        self.0.iter().flat_map(|&h| h.to_vec()).collect()
42    }
43}
44
/// A Merkle tree with its top `cap_height` layers replaced by a [`MerkleCap`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTree<F: RichField, H: Hasher<F>> {
    /// The data in the leaves of the Merkle tree.
    pub leaves: Vec<Vec<F>>,

    /// The digests in the tree. Consists of `cap.len()` sub-trees, each corresponding to one
    /// element in `cap`. Each subtree is contiguous and located at
    /// `digests[digests.len() / cap.len() * i..digests.len() / cap.len() * (i + 1)]`.
    /// Within each subtree, siblings are stored next to each other. The layout is,
    /// left_child_subtree || left_child_digest || right_child_digest || right_child_subtree, where
    /// left_child_digest and right_child_digest are H::Hash and left_child_subtree and
    /// right_child_subtree recurse. Observe that the digest of a node is stored by its _parent_.
    /// Consequently, the digests of the roots are not stored here (they can be found in `cap`).
    ///
    /// Total length is `2 * (leaves.len() - cap.len())` (see `MerkleTree::new`).
    pub digests: Vec<H::Hash>,

    /// The Merkle cap.
    pub cap: MerkleCap<F, H>,
}
63
64impl<F: RichField, H: Hasher<F>> Default for MerkleTree<F, H> {
65    fn default() -> Self {
66        Self {
67            leaves: Vec::new(),
68            digests: Vec::new(),
69            cap: MerkleCap::default(),
70        }
71    }
72}
73
/// Reinterprets the first `len` slots of `v`'s allocation — including uninitialized spare
/// capacity — as a mutable slice of `MaybeUninit<T>`.
///
/// Panics if `v.capacity() < len`. The caller is responsible for initializing the returned
/// slots before making them observable (e.g. via `Vec::set_len`).
pub(crate) fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
    assert!(v.capacity() >= len);
    let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
    unsafe {
        // SAFETY: `v_ptr` is a valid pointer to a buffer of length at least `len`. Upon return, the
        // lifetime will be bound to that of `v`. The underlying memory will not be deallocated as
        // we hold the sole mutable reference to `v`. The contents of the slice may be
        // uninitialized, but the `MaybeUninit` makes it safe.
        slice::from_raw_parts_mut(v_ptr, len)
    }
}
85
/// Recursively hashes a complete subtree over `leaves`, writing all interior digests into
/// `digests_buf` (in the layout documented on `MerkleTree::digests`) and returning the
/// subtree's root digest, which is *not* written into the buffer.
///
/// `digests_buf` must hold exactly `2 * (leaves.len() - 1)` slots (enforced by the assert
/// below); every slot is initialized on return. The two children are hashed in parallel
/// via `plonky2_maybe_rayon::join` when the `parallel` feature is enabled.
pub(crate) fn fill_subtree<F: RichField, H: Hasher<F>>(
    digests_buf: &mut [MaybeUninit<H::Hash>],
    leaves: &[Vec<F>],
) -> H::Hash {
    assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
    if digests_buf.is_empty() {
        // Base case: a single leaf; its (possibly no-op) hash is the subtree root.
        H::hash_or_noop(&leaves[0])
    } else {
        // Layout is: left recursive output || left child digest
        //             || right child digest || right recursive output.
        // Split `digests_buf` into the two recursive outputs (slices) and two child digests
        // (references).
        let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
        let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
        let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
        // Split `leaves` between both children.
        let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);

        let (left_digest, right_digest) = plonky2_maybe_rayon::join(
            || fill_subtree::<F, H>(left_digests_buf, left_leaves),
            || fill_subtree::<F, H>(right_digests_buf, right_leaves),
        );

        // The children's root digests are stored by this (parent) node.
        left_digest_mem.write(left_digest);
        right_digest_mem.write(right_digest);
        H::two_to_one(left_digest, right_digest)
    }
}
114
/// Hashes `leaves` into a Merkle tree with a cap at height `cap_height`, initializing every
/// slot of `digests_buf` (interior digests, one contiguous subtree per cap entry) and
/// `cap_buf` (the cap digests themselves).
///
/// Expects `digests_buf.len()` and `leaves.len()` to be divisible into `1 << cap_height`
/// equal chunks (enforced by the `chunks_exact` calls and asserts below). The subtrees are
/// processed in parallel when the `parallel` feature is enabled.
pub(crate) fn fill_digests_buf<F: RichField, H: Hasher<F>>(
    digests_buf: &mut [MaybeUninit<H::Hash>],
    cap_buf: &mut [MaybeUninit<H::Hash>],
    leaves: &[Vec<F>],
    cap_height: usize,
) {
    // Special case of a tree that's all cap. The usual case will panic because we'll try to split
    // an empty slice into chunks of `0`. (We would not need this if there was a way to split into
    // `blah` chunks as opposed to chunks _of_ `blah`.)
    if digests_buf.is_empty() {
        debug_assert_eq!(cap_buf.len(), leaves.len());
        cap_buf
            .par_iter_mut()
            .zip(leaves)
            .for_each(|(cap_buf, leaf)| {
                // Each cap entry is just the (possibly no-op) hash of its single leaf.
                cap_buf.write(H::hash_or_noop(leaf));
            });
        return;
    }

    // Per-subtree sizes; both divisions are exact by construction of the caller.
    let subtree_digests_len = digests_buf.len() >> cap_height;
    let subtree_leaves_len = leaves.len() >> cap_height;
    let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
    let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
    assert_eq!(digests_chunks.len(), cap_buf.len());
    assert_eq!(digests_chunks.len(), leaves_chunks.len());
    digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
        |((subtree_digests, subtree_cap), subtree_leaves)| {
            // We have `1 << cap_height` sub-trees, one for each entry in `cap`. They are totally
            // independent, so we schedule one task for each. `digests_buf` and `leaves` are split
            // into `1 << cap_height` slices, one for each sub-tree.
            subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
        },
    );
}
150
/// Builds the Merkle proof for `leaf_index`: the list of sibling digests from the leaf's
/// layer up to (but excluding) the cap, read from the flat `digests` layout produced by
/// `fill_digests_buf`.
///
/// Panics if `leaves_len` is not a power of two, or if `digests.len()` does not match
/// `2 * (leaves_len - 2^cap_height)`.
pub(crate) fn merkle_tree_prove<F: RichField, H: Hasher<F>>(
    leaf_index: usize,
    leaves_len: usize,
    cap_height: usize,
    digests: &[H::Hash],
) -> Vec<H::Hash> {
    let num_layers = log2_strict(leaves_len) - cap_height;
    debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);

    let digest_len = 2 * (leaves_len - (1 << cap_height));
    assert_eq!(digest_len, digests.len());

    // Narrow to the contiguous sub-tree (one per cap entry) containing this leaf.
    let digest_tree: &[H::Hash] = {
        let tree_index = leaf_index >> num_layers;
        let tree_len = digest_len >> cap_height;
        &digests[tree_len * tree_index..tree_len * (tree_index + 1)]
    };

    // Mask out high bits to get the index within the sub-tree.
    let mut pair_index = leaf_index & ((1 << num_layers) - 1);
    (0..num_layers)
        .map(|i| {
            let parity = pair_index & 1;
            pair_index >>= 1;

            // The layers' data is interleaved as follows:
            // [layer 0, layer 1, layer 0, layer 2, layer 0, layer 1, layer 0, layer 3, ...].
            // Each of the above is a pair of siblings.
            // `pair_index` is the index of the pair within layer `i`.
            // The index of that the pair within `digests` is
            // `pair_index * 2 ** (i + 1) + (2 ** i - 1)`.
            let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
            // We have an index for the _pair_, but we want the index of the _sibling_.
            // Double the pair index to get the index of the left sibling. Conditionally add `1`
            // if we are to retrieve the right sibling.
            let sibling_index = 2 * siblings_index + (1 - parity);
            digest_tree[sibling_index]
        })
        .collect()
}
191
192impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
193    pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
194        let log2_leaves_len = log2_strict(leaves.len());
195        assert!(
196            cap_height <= log2_leaves_len,
197            "cap_height={} should be at most log2(leaves.len())={}",
198            cap_height,
199            log2_leaves_len
200        );
201
202        let num_digests = 2 * (leaves.len() - (1 << cap_height));
203        let mut digests = Vec::with_capacity(num_digests);
204
205        let len_cap = 1 << cap_height;
206        let mut cap = Vec::with_capacity(len_cap);
207
208        let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
209        let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
210        fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
211
212        unsafe {
213            // SAFETY: `fill_digests_buf` and `cap` initialized the spare capacity up to
214            // `num_digests` and `len_cap`, resp.
215            digests.set_len(num_digests);
216            cap.set_len(len_cap);
217        }
218
219        Self {
220            leaves,
221            digests,
222            cap: MerkleCap(cap),
223        }
224    }
225
226    pub fn get(&self, i: usize) -> &[F] {
227        &self.leaves[i]
228    }
229
230    /// Create a Merkle proof from a leaf index.
231    pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
232        let cap_height = log2_strict(self.cap.len());
233        let siblings =
234            merkle_tree_prove::<F, H>(leaf_index, self.leaves.len(), cap_height, &self.digests);
235
236        MerkleProof { siblings }
237    }
238}
239
#[cfg(test)]
pub(crate) mod tests {
    use anyhow::Result;

    use super::*;
    use crate::field::extension::Extendable;
    use crate::hash::merkle_proofs::verify_merkle_proof_to_cap;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// Generates `n` random leaves, each holding `k` random field elements.
    pub(crate) fn random_data<F: RichField>(n: usize, k: usize) -> Vec<Vec<F>> {
        let mut data = Vec::with_capacity(n);
        for _ in 0..n {
            data.push(F::rand_vec(k));
        }
        data
    }

    /// Builds a tree over `leaves` with the given cap height and checks that every
    /// leaf's proof verifies against the cap.
    fn verify_all_leaves<
        F: RichField + Extendable<D>,
        C: GenericConfig<D, F = F>,
        const D: usize,
    >(
        leaves: Vec<Vec<F>>,
        cap_height: usize,
    ) -> Result<()> {
        let tree = MerkleTree::<F, C::Hasher>::new(leaves.clone(), cap_height);
        for (index, leaf) in leaves.into_iter().enumerate() {
            let proof = tree.prove(index);
            verify_merkle_proof_to_cap(leaf, index, &tree.cap, &proof)?;
        }
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cap_height_too_big() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 8;
        // A cap taller than the tree itself must be rejected by `MerkleTree::new`.
        let cap_height = log_n + 1;

        let leaves = random_data::<F>(1 << log_n, 7);
        let _ = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
    }

    #[test]
    fn test_cap_height_eq_log2_len() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        // Degenerate case: the cap is the whole tree (one cap entry per leaf).
        let log_n = 8;
        let num_leaves = 1 << log_n;
        let leaves = random_data::<F>(num_leaves, 7);

        verify_all_leaves::<F, C, D>(leaves, log_n)?;

        Ok(())
    }

    #[test]
    fn test_merkle_trees() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 8;
        let num_leaves = 1 << log_n;
        let leaves = random_data::<F>(num_leaves, 7);

        verify_all_leaves::<F, C, D>(leaves, 1)?;

        Ok(())
    }
}