1use alloc::vec::Vec;
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4
5use itertools::Itertools;
6use p3_commit::Mmcs;
7use p3_field::{PackedField, PackedValue};
8use p3_matrix::{Dimensions, Matrix};
9use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
10use p3_util::log2_ceil_usize;
11use serde::{Deserialize, Serialize};
12
13use crate::FieldMerkleTree;
14use crate::FieldMerkleTreeError::{RootMismatch, WrongBatchSize, WrongHeight};
15
/// A commitment scheme implementing `Mmcs` on top of [`FieldMerkleTree`].
///
/// - `P`: packed field type; `P::Scalar` is the type of the committed matrix entries.
/// - `PW`: packed type whose `PW::Value`s make up the words of a digest.
/// - `H`: hasher turning matrix rows into `[PW::Value; DIGEST_ELEMS]` digests.
/// - `C`: 2-to-1 compression function combining two digests into one.
#[derive(Copy, Clone, Debug)]
pub struct FieldMerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    /// Hasher applied to opened/committed rows.
    hash: H,
    /// Compression function merging a pair of digests into their parent.
    compress: C,
    /// Carries the `P` and `PW` type parameters, which no field stores directly.
    _phantom: PhantomData<(P, PW)>,
}
28
/// Errors returned when a batch opening fails to verify against a [`FieldMerkleTreeMmcs`]
/// commitment.
///
/// All payloads are plain integers, so the type is cheap to copy and can be compared
/// directly (e.g. `assert_eq!(err, FieldMerkleTreeError::RootMismatch)`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum FieldMerkleTreeError {
    /// The number of opened rows differs from the number of matrix dimensions given,
    /// or the batch is empty.
    WrongBatchSize,
    /// Reserved for an opened row whose length does not match its matrix width.
    /// NOTE(review): `verify_batch` in this module does not currently return it.
    WrongWidth,
    /// The proof does not contain one sibling digest per level of the tallest matrix.
    WrongHeight {
        /// Height of the tallest matrix in the batch.
        max_height: usize,
        /// Number of sibling digests actually present in the proof.
        num_siblings: usize,
    },
    /// The root recomputed from the opened rows and proof differs from the commitment.
    RootMismatch,
}
39
40impl<P, PW, H, C, const DIGEST_ELEMS: usize> FieldMerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
41 pub const fn new(hash: H, compress: C) -> Self {
42 Self {
43 hash,
44 compress,
45 _phantom: PhantomData,
46 }
47 }
48}
49
50impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Scalar>
51 for FieldMerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
52where
53 P: PackedField,
54 PW: PackedValue,
55 H: CryptographicHasher<P::Scalar, [PW::Value; DIGEST_ELEMS]>,
56 H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
57 H: Sync,
58 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
59 C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
60 C: Sync,
61 PW::Value: Eq,
62 [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
63{
64 type Commitment = Hash<P::Scalar, PW::Value, DIGEST_ELEMS>;
65 type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
66 type Error = FieldMerkleTreeError;
67 type ProverData<M> = FieldMerkleTree<P::Scalar, PW::Value, M, DIGEST_ELEMS>;
68
69 fn commit<M: Matrix<P::Scalar>>(
70 &self,
71 inputs: Vec<M>,
72 ) -> (Self::Commitment, Self::ProverData<M>) {
73 let tree = FieldMerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
74 let root = tree.root();
75 (root, tree)
76 }
77
78 fn open_batch<M: Matrix<P::Scalar>>(
79 &self,
80 index: usize,
81 prover_data: &FieldMerkleTree<P::Scalar, PW::Value, M, DIGEST_ELEMS>,
82 ) -> (Vec<Vec<P::Scalar>>, Vec<[PW::Value; DIGEST_ELEMS]>) {
83 let max_height = self.get_max_height(prover_data);
84 let log_max_height = log2_ceil_usize(max_height);
85
86 let openings = prover_data
87 .leaves
88 .iter()
89 .map(|matrix| {
90 let log2_height = log2_ceil_usize(matrix.height());
91 let bits_reduced = log_max_height - log2_height;
92 let reduced_index = index >> bits_reduced;
93 matrix.row(reduced_index).collect()
94 })
95 .collect_vec();
96
97 let proof: Vec<_> = (0..log_max_height)
98 .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
99 .collect();
100
101 (openings, proof)
102 }
103
104 fn get_matrices<'a, M: Matrix<P::Scalar>>(
105 &self,
106 prover_data: &'a Self::ProverData<M>,
107 ) -> Vec<&'a M> {
108 prover_data.leaves.iter().collect()
109 }
110
111 fn verify_batch(
112 &self,
113 commit: &Self::Commitment,
114 dimensions: &[Dimensions],
115 mut index: usize,
116 opened_values: &[Vec<P::Scalar>],
117 proof: &Self::Proof,
118 ) -> Result<(), Self::Error> {
119 if dimensions.len() != opened_values.len() {
121 return Err(WrongBatchSize);
122 }
123
124 let max_height = dimensions.iter().map(|dim| dim.height).max().unwrap();
133 let log_max_height = log2_ceil_usize(max_height);
134 if proof.len() != log_max_height {
135 return Err(WrongHeight {
136 max_height,
137 num_siblings: proof.len(),
138 });
139 }
140
141 let mut heights_tallest_first = dimensions
142 .iter()
143 .enumerate()
144 .sorted_by_key(|(_, dims)| Reverse(dims.height))
145 .peekable();
146
147 let mut curr_height_padded = heights_tallest_first
148 .peek()
149 .unwrap()
150 .1
151 .height
152 .next_power_of_two();
153
154 let mut root = self.hash.hash_iter_slices(
155 heights_tallest_first
156 .peeking_take_while(|(_, dims)| {
157 dims.height.next_power_of_two() == curr_height_padded
158 })
159 .map(|(i, _)| opened_values[i].as_slice()),
160 );
161
162 for &sibling in proof.iter() {
163 let (left, right) = if index & 1 == 0 {
164 (root, sibling)
165 } else {
166 (sibling, root)
167 };
168
169 root = self.compress.compress([left, right]);
170 index >>= 1;
171 curr_height_padded >>= 1;
172
173 let next_height = heights_tallest_first
174 .peek()
175 .map(|(_, dims)| dims.height)
176 .filter(|h| h.next_power_of_two() == curr_height_padded);
177 if let Some(next_height) = next_height {
178 let next_height_openings_digest = self.hash.hash_iter_slices(
179 heights_tallest_first
180 .peeking_take_while(|(_, dims)| dims.height == next_height)
181 .map(|(i, _)| opened_values[i].as_slice()),
182 );
183
184 root = self.compress.compress([root, next_height_openings_digest]);
185 }
186 }
187
188 if commit == &root {
189 Ok(())
190 } else {
191 Err(RootMismatch)
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear};
    use p3_commit::Mmcs;
    use p3_field::{AbstractField, Field};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::thread_rng;

    use super::{FieldMerkleTreeError, FieldMerkleTreeMmcs};

    type F = BabyBear;

    type Perm = Poseidon2<F, Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, 16, 7>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        FieldMerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    /// Samples a fresh random Poseidon2 permutation and derives from it the leaf hasher,
    /// the 2-to-1 compressor, and an MMCS built from clones of those two.
    fn setup() -> (MyHash, MyCompress, MyMmcs) {
        let perm = Perm::new_from_rng_128(
            Poseidon2ExternalMatrixGeneral,
            DiffusionMatrixBabyBear::default(),
            &mut thread_rng(),
        );
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        (hash, compress, mmcs)
    }

    #[test]
    fn commit_single_1x8() {
        let (hash, compress, mmcs) = setup();

        // v = [2, 1, 2, 2, 0, 0, 1, 0]
        let v = vec![
            F::two(),
            F::one(),
            F::two(),
            F::two(),
            F::zero(),
            F::zero(),
            F::one(),
            F::zero(),
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]),
                compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]),
            ]),
            compress.compress([
                compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]),
                compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]),
            ]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x2() {
        let (hash, compress, mmcs) = setup();

        // mat = [
        //   0 1
        //   2 1
        // ]
        let mat = RowMajorMatrix::new(vec![F::zero(), F::one(), F::two(), F::one()], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            hash.hash_slice(&[F::zero(), F::one()]),
            hash.hash_slice(&[F::two(), F::one()]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x3() {
        let (hash, compress, mmcs) = setup();
        let default_digest = [F::zero(); 8];

        // mat = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let mat = RowMajorMatrix::new(
            vec![F::zero(), F::one(), F::two(), F::one(), F::two(), F::two()],
            2,
        );

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            compress.compress([
                hash.hash_slice(&[F::zero(), F::one()]),
                hash.hash_slice(&[F::two(), F::one()]),
            ]),
            compress.compress([hash.hash_slice(&[F::two(), F::two()]), default_digest]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_mixed() {
        let (hash, compress, mmcs) = setup();
        let default_digest = [F::zero(); 8];

        // mat_1 = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let mat_1 = RowMajorMatrix::new(
            vec![F::zero(), F::one(), F::two(), F::one(), F::two(), F::two()],
            2,
        );
        // mat_2 = [
        //   1 2 1
        //   0 2 2
        // ]
        let mat_2 = RowMajorMatrix::new(
            vec![F::one(), F::two(), F::one(), F::zero(), F::two(), F::two()],
            3,
        );
        let dims = vec![mat_1.dimensions(), mat_2.dimensions()];

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = [
            hash.hash_slice(&[F::zero(), F::one()]),
            hash.hash_slice(&[F::two(), F::one()]),
            hash.hash_slice(&[F::two(), F::two()]),
        ];
        let mat_2_leaf_hashes = [
            hash.hash_slice(&[F::one(), F::two(), F::one()]),
            hash.hash_slice(&[F::zero(), F::two(), F::two()]),
        ];

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
                mat_2_leaf_hashes[0],
            ]),
            compress.compress([
                compress.compress([mat_1_leaf_hashes[2], default_digest]),
                mat_2_leaf_hashes[1],
            ]),
        ]);
        assert_eq!(commit, expected_result);

        let (opened_values, proof) = mmcs.open_batch(2, &prover_data);
        assert_eq!(
            opened_values,
            vec![
                vec![F::two(), F::two()],
                vec![F::zero(), F::two(), F::two()]
            ]
        );

        // The opening of mixed-height matrices must also verify against the commitment.
        mmcs.verify_batch(&commit, &dims, 2, &opened_values, &proof)
            .expect("expected verification to succeed");
    }

    #[test]
    fn commit_either_order() {
        let (_, _, mmcs) = setup();
        let mut rng = thread_rng();

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]);
        let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]);
        assert_eq!(commit_1_2, commit_2_1);
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let (_, _, mmcs) = setup();

        // Heights 8 and 7 both round up to 8 but differ, which committing must reject.
        let large_mat = RowMajorMatrix::new(
            [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(),
            1,
        );
        let small_mat =
            RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn verify_tampered_proof_fails() {
        let (_, _, mmcs) = setup();

        // 4 mats with 1000 rows, 8 columns
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 1));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 1,
        });
        let small_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2));
        let small_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 2,
        });

        let (commit, prover_data) = mmcs.commit(large_mats.chain(small_mats).collect_vec());

        // Flip one element of the first sibling digest; the recomputed root must differ.
        let (opened_values, mut proof) = mmcs.open_batch(3, &prover_data);
        proof[0][0] += F::one();
        mmcs.verify_batch(
            &commit,
            &large_mat_dims.chain(small_mat_dims).collect_vec(),
            3,
            &opened_values,
            &proof,
        )
        .expect_err("expected verification to fail");
    }

    #[test]
    fn size_gaps() {
        let (_, _, mmcs) = setup();

        // 4 mats with 1000 rows, 8 columns
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 1000, 8));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 1000,
            width: 8,
        });

        // 5 mats with 70 rows, 8 columns
        let medium_mats = (0..5).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 70, 8));
        let medium_mat_dims = (0..5).map(|_| Dimensions {
            height: 70,
            width: 8,
        });

        // 6 mats with 8 rows, 8 columns
        let small_mats = (0..6).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 8));
        let small_mat_dims = (0..6).map(|_| Dimensions {
            height: 8,
            width: 8,
        });

        let (commit, prover_data) = mmcs.commit(
            large_mats
                .chain(medium_mats)
                .chain(small_mats)
                .collect_vec(),
        );

        let (opened_values, proof) = mmcs.open_batch(6, &prover_data);
        mmcs.verify_batch(
            &commit,
            &large_mat_dims
                .chain(medium_mat_dims)
                .chain(small_mat_dims)
                .collect_vec(),
            6,
            &opened_values,
            &proof,
        )
        .expect("expected verification to succeed");
    }

    #[test]
    fn different_widths() {
        let (_, _, mmcs) = setup();

        // 10 mats with 32 rows where the ith mat has i + 1 cols
        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut thread_rng(), 32, i + 1))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .expect("expected verification to succeed");
    }

    #[test]
    fn verify_rejects_mismatched_batch_size() {
        let (_, _, mmcs) = setup();

        let mats = (0..2)
            .map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 4))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(5, &prover_data);

        // Drop one opened row so it no longer lines up with `dims`.
        let err = mmcs
            .verify_batch(&commit, &dims, 5, &opened_values[..1], &proof)
            .expect_err("expected verification to fail");
        assert!(matches!(err, FieldMerkleTreeError::WrongBatchSize));
    }

    #[test]
    fn verify_rejects_wrong_proof_length() {
        let (_, _, mmcs) = setup();

        let mats = vec![RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 4)];
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, mut proof) = mmcs.open_batch(5, &prover_data);

        // A height-8 tree needs 3 siblings; hand the verifier only 2.
        proof.pop();
        let err = mmcs
            .verify_batch(&commit, &dims, 5, &opened_values, &proof)
            .expect_err("expected verification to fail");
        assert!(matches!(
            err,
            FieldMerkleTreeError::WrongHeight {
                max_height: 8,
                num_siblings: 2
            }
        ));
    }
}
543}