1#[cfg(not(feature = "std"))]
2use alloc::vec::Vec;
3use core::mem::MaybeUninit;
4use core::slice;
5
6use plonky2_maybe_rayon::*;
7use serde::{Deserialize, Serialize};
8
9use crate::hash::hash_types::RichField;
10use crate::hash::merkle_proofs::MerkleProof;
11use crate::plonk::config::{GenericHashOut, Hasher};
12use crate::util::log2_strict;
13
/// A Merkle cap: the `2^cap_height` digests rooting the subtrees of a [`MerkleTree`], committed
/// to in place of a single root hash. A cap of height 0 holds exactly one digest, the tree root.
///
/// `#[serde(bound = "")]` suppresses the trait bounds serde would otherwise infer for `F` and
/// `H`; presumably `H::Hash` already carries the needed serde bounds through `Hasher` — confirm.
#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)]
#[serde(bound = "")]
pub struct MerkleCap<F: RichField, H: Hasher<F>>(pub Vec<H::Hash>);
20
21impl<F: RichField, H: Hasher<F>> Default for MerkleCap<F, H> {
22 fn default() -> Self {
23 Self(Vec::new())
24 }
25}
26
27impl<F: RichField, H: Hasher<F>> MerkleCap<F, H> {
28 pub fn len(&self) -> usize {
29 self.0.len()
30 }
31
32 pub fn is_empty(&self) -> bool {
33 self.len() == 0
34 }
35
36 pub fn height(&self) -> usize {
37 log2_strict(self.len())
38 }
39
40 pub fn flatten(&self) -> Vec<F> {
41 self.0.iter().flat_map(|&h| h.to_vec()).collect()
42 }
43}
44
/// A Merkle tree whose top `cap_height` layers are replaced by a [`MerkleCap`] of
/// `2^cap_height` subtree roots. The number of leaves is expected to be a power of two
/// (`MerkleTree::new` measures it with `log2_strict`).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MerkleTree<F: RichField, H: Hasher<F>> {
    /// The data in the leaves of the tree, one `Vec<F>` per leaf.
    pub leaves: Vec<Vec<F>>,

    /// Every digest strictly below the cap: `2 * (leaves.len() - cap.len())` hashes in total.
    ///
    /// The buffer is split into `cap.len()` equal chunks, one per subtree, in cap order. Within
    /// a subtree's chunk the layout is recursive: the left child's chunk, then the left child
    /// digest, then the right child digest, then the right child's chunk. Equivalently, the
    /// pairs of sibling digests are stored in in-order (left subtree, node, right subtree)
    /// order, which `merkle_tree_prove` relies on to find a sibling by index arithmetic.
    /// Leaf digests (`hash_or_noop` of each leaf) form the bottom-layer pairs, unless the cap is
    /// as tall as the tree, in which case this is empty and the leaf digests form the cap.
    pub digests: Vec<H::Hash>,

    /// The Merkle cap: the roots of the `cap.len()` subtrees.
    pub cap: MerkleCap<F, H>,
}
63
64impl<F: RichField, H: Hasher<F>> Default for MerkleTree<F, H> {
65 fn default() -> Self {
66 Self {
67 leaves: Vec::new(),
68 digests: Vec::new(),
69 cap: MerkleCap::default(),
70 }
71 }
72}
73
/// Returns the first `len` slots of `v`'s allocation as possibly-uninitialized values,
/// regardless of `v.len()`.
///
/// Intended use (as in `MerkleTree::new`): allocate with `Vec::with_capacity`, write every slot
/// of the returned slice, then call `Vec::set_len(len)`. Writing a slot with
/// `MaybeUninit::write` does not drop whatever value the slot previously held.
///
/// # Panics
/// Panics if `v.capacity() < len`.
pub(crate) fn capacity_up_to_mut<T>(v: &mut Vec<T>, len: usize) -> &mut [MaybeUninit<T>] {
    assert!(
        v.capacity() >= len,
        "requested {} slots but the vector only has capacity {}",
        len,
        v.capacity()
    );
    let v_ptr = v.as_mut_ptr().cast::<MaybeUninit<T>>();
    // SAFETY: `v_ptr` points to the start of `v`'s allocation, which is valid and properly
    // aligned for `v.capacity() >= len` elements (checked above); `MaybeUninit<T>` has the same
    // layout as `T`. The returned slice's lifetime is tied to the `&mut v` borrow, so while it
    // lives the allocation cannot be freed, reallocated or otherwise aliased. Slots may be
    // uninitialized, which `MaybeUninit` makes sound to expose.
    unsafe { slice::from_raw_parts_mut(v_ptr, len) }
}
85
86pub(crate) fn fill_subtree<F: RichField, H: Hasher<F>>(
87 digests_buf: &mut [MaybeUninit<H::Hash>],
88 leaves: &[Vec<F>],
89) -> H::Hash {
90 assert_eq!(leaves.len(), digests_buf.len() / 2 + 1);
91 if digests_buf.is_empty() {
92 H::hash_or_noop(&leaves[0])
93 } else {
94 let (left_digests_buf, right_digests_buf) = digests_buf.split_at_mut(digests_buf.len() / 2);
99 let (left_digest_mem, left_digests_buf) = left_digests_buf.split_last_mut().unwrap();
100 let (right_digest_mem, right_digests_buf) = right_digests_buf.split_first_mut().unwrap();
101 let (left_leaves, right_leaves) = leaves.split_at(leaves.len() / 2);
103
104 let (left_digest, right_digest) = plonky2_maybe_rayon::join(
105 || fill_subtree::<F, H>(left_digests_buf, left_leaves),
106 || fill_subtree::<F, H>(right_digests_buf, right_leaves),
107 );
108
109 left_digest_mem.write(left_digest);
110 right_digest_mem.write(right_digest);
111 H::two_to_one(left_digest, right_digest)
112 }
113}
114
115pub(crate) fn fill_digests_buf<F: RichField, H: Hasher<F>>(
116 digests_buf: &mut [MaybeUninit<H::Hash>],
117 cap_buf: &mut [MaybeUninit<H::Hash>],
118 leaves: &[Vec<F>],
119 cap_height: usize,
120) {
121 if digests_buf.is_empty() {
125 debug_assert_eq!(cap_buf.len(), leaves.len());
126 cap_buf
127 .par_iter_mut()
128 .zip(leaves)
129 .for_each(|(cap_buf, leaf)| {
130 cap_buf.write(H::hash_or_noop(leaf));
131 });
132 return;
133 }
134
135 let subtree_digests_len = digests_buf.len() >> cap_height;
136 let subtree_leaves_len = leaves.len() >> cap_height;
137 let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len);
138 let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len);
139 assert_eq!(digests_chunks.len(), cap_buf.len());
140 assert_eq!(digests_chunks.len(), leaves_chunks.len());
141 digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each(
142 |((subtree_digests, subtree_cap), subtree_leaves)| {
143 subtree_cap.write(fill_subtree::<F, H>(subtree_digests, subtree_leaves));
147 },
148 );
149}
150
151pub(crate) fn merkle_tree_prove<F: RichField, H: Hasher<F>>(
152 leaf_index: usize,
153 leaves_len: usize,
154 cap_height: usize,
155 digests: &[H::Hash],
156) -> Vec<H::Hash> {
157 let num_layers = log2_strict(leaves_len) - cap_height;
158 debug_assert_eq!(leaf_index >> (cap_height + num_layers), 0);
159
160 let digest_len = 2 * (leaves_len - (1 << cap_height));
161 assert_eq!(digest_len, digests.len());
162
163 let digest_tree: &[H::Hash] = {
164 let tree_index = leaf_index >> num_layers;
165 let tree_len = digest_len >> cap_height;
166 &digests[tree_len * tree_index..tree_len * (tree_index + 1)]
167 };
168
169 let mut pair_index = leaf_index & ((1 << num_layers) - 1);
171 (0..num_layers)
172 .map(|i| {
173 let parity = pair_index & 1;
174 pair_index >>= 1;
175
176 let siblings_index = (pair_index << (i + 1)) + (1 << i) - 1;
183 let sibling_index = 2 * siblings_index + (1 - parity);
187 digest_tree[sibling_index]
188 })
189 .collect()
190}
191
192impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
193 pub fn new(leaves: Vec<Vec<F>>, cap_height: usize) -> Self {
194 let log2_leaves_len = log2_strict(leaves.len());
195 assert!(
196 cap_height <= log2_leaves_len,
197 "cap_height={} should be at most log2(leaves.len())={}",
198 cap_height,
199 log2_leaves_len
200 );
201
202 let num_digests = 2 * (leaves.len() - (1 << cap_height));
203 let mut digests = Vec::with_capacity(num_digests);
204
205 let len_cap = 1 << cap_height;
206 let mut cap = Vec::with_capacity(len_cap);
207
208 let digests_buf = capacity_up_to_mut(&mut digests, num_digests);
209 let cap_buf = capacity_up_to_mut(&mut cap, len_cap);
210 fill_digests_buf::<F, H>(digests_buf, cap_buf, &leaves[..], cap_height);
211
212 unsafe {
213 digests.set_len(num_digests);
216 cap.set_len(len_cap);
217 }
218
219 Self {
220 leaves,
221 digests,
222 cap: MerkleCap(cap),
223 }
224 }
225
226 pub fn get(&self, i: usize) -> &[F] {
227 &self.leaves[i]
228 }
229
230 pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
232 let cap_height = log2_strict(self.cap.len());
233 let siblings =
234 merkle_tree_prove::<F, H>(leaf_index, self.leaves.len(), cap_height, &self.digests);
235
236 MerkleProof { siblings }
237 }
238}
239
#[cfg(test)]
pub(crate) mod tests {
    use anyhow::Result;

    use super::*;
    use crate::field::extension::Extendable;
    use crate::hash::merkle_proofs::verify_merkle_proof_to_cap;
    use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};

    /// Returns `n` random leaves, each holding `k` field elements.
    pub(crate) fn random_data<F: RichField>(n: usize, k: usize) -> Vec<Vec<F>> {
        (0..n).map(|_| F::rand_vec(k)).collect()
    }

    /// Builds a tree over `leaves` with the given cap height and checks that the proof for
    /// every leaf verifies against the tree's cap.
    fn verify_all_leaves<
        F: RichField + Extendable<D>,
        C: GenericConfig<D, F = F>,
        const D: usize,
    >(
        leaves: Vec<Vec<F>>,
        cap_height: usize,
    ) -> Result<()> {
        let tree = MerkleTree::<F, C::Hasher>::new(leaves.clone(), cap_height);
        for (i, leaf) in leaves.into_iter().enumerate() {
            let proof = tree.prove(i);
            verify_merkle_proof_to_cap(leaf, i, &tree.cap, &proof)?;
        }
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_cap_height_too_big() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 8;
        let cap_height = log_n + 1;
        let leaves = random_data::<F>(1 << log_n, 7);

        let _ = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, cap_height);
    }

    #[test]
    fn test_cap_height_eq_log2_len() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 8;
        let n = 1 << log_n;
        let leaves = random_data::<F>(n, 7);

        verify_all_leaves::<F, C, D>(leaves, log_n)?;

        Ok(())
    }

    #[test]
    fn test_merkle_trees() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 8;
        let n = 1 << log_n;
        let leaves = random_data::<F>(n, 7);

        verify_all_leaves::<F, C, D>(leaves, 1)?;

        Ok(())
    }

    #[test]
    fn test_cap_height_zero() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        // A cap of height 0 is a single root digest covering the whole tree.
        let log_n = 8;
        let leaves = random_data::<F>(1 << log_n, 7);

        verify_all_leaves::<F, C, D>(leaves, 0)
    }

    #[test]
    fn test_single_leaf() -> Result<()> {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        // One leaf: no digests below the cap and an empty proof.
        let leaves = random_data::<F>(1, 7);

        verify_all_leaves::<F, C, D>(leaves, 0)
    }

    #[test]
    #[should_panic]
    fn test_prove_out_of_range() {
        const D: usize = 2;
        type C = PoseidonGoldilocksConfig;
        type F = <C as GenericConfig<D>>::F;

        let log_n = 4;
        let leaves = random_data::<F>(1 << log_n, 7);
        let tree = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new(leaves, 1);

        let _ = tree.prove(1 << log_n);
    }
}