1use alloc::vec::Vec;
23use core::cmp::Reverse;
24use core::marker::PhantomData;
25
26use itertools::Itertools;
27use p3_commit::{BatchOpening, BatchOpeningRef, Mmcs};
28use p3_field::PackedValue;
29use p3_matrix::{Dimensions, Matrix};
30use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
31use p3_util::{log2_ceil_usize, log2_strict_usize};
32use serde::{Deserialize, Serialize};
33
34use crate::MerkleTree;
35use crate::MerkleTreeError::{
36 EmptyBatch, IncompatibleHeights, RootMismatch, WrongBatchSize, WrongHeight,
37};
38
/// A matrix commitment scheme built on a [`MerkleTree`], able to commit to a batch of
/// matrices of possibly different heights under a single root.
///
/// Type parameters:
/// - `P`: packed type of the committed matrix elements; `P::Value` is the scalar element type.
/// - `PW`: packed type of digest words; a digest is `[PW::Value; DIGEST_ELEMS]`.
/// - `H`: hasher that maps opened rows (sequences of `P::Value`) to digests.
/// - `C`: 2-to-1 compression function that combines two child digests into a parent digest.
/// - `DIGEST_ELEMS`: number of `PW::Value` words in one digest.
#[derive(Copy, Clone, Debug)]
pub struct MerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    // Hasher applied to matrix rows to produce digests.
    hash: H,
    // Compression function used to merge sibling digests while walking toward the root.
    compress: C,
    // `P` and `PW` appear only in trait bounds / associated types, not in stored data.
    _phantom: PhantomData<(P, PW)>,
}
52
/// Reasons a batch opening can fail to verify against a [`MerkleTreeMmcs`] commitment.
///
/// The type is `Copy` and comparable, so callers can match on, store, or `assert_eq!` errors
/// directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MerkleTreeError {
    /// The number of opened rows differs from the number of matrix dimensions supplied.
    WrongBatchSize,
    /// An opened row's length does not match its matrix's declared width.
    ///
    /// NOTE(review): `verify_batch` in this module never returns this variant (there is no
    /// width check there) — confirm whether it is still needed.
    WrongWidth,
    /// The opening proof holds a different number of sibling digests than the tree depth
    /// implied by the tallest matrix.
    WrongHeight {
        /// Base-2 log of the tallest matrix height rounded up to a power of two; this is the
        /// number of siblings the proof was expected to contain.
        log_max_height: usize,
        /// Number of sibling digests actually present in the proof.
        num_siblings: usize,
    },
    /// Two matrices have different heights that round up to the same power of two.
    IncompatibleHeights,
    /// The root recomputed from the opening does not equal the commitment.
    RootMismatch,
    /// No matrix dimensions were supplied, so there is nothing to verify.
    EmptyBatch,
}
65
66impl<P, PW, H, C, const DIGEST_ELEMS: usize> MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
67 pub const fn new(hash: H, compress: C) -> Self {
68 Self {
69 hash,
70 compress,
71 _phantom: PhantomData,
72 }
73 }
74}
75
76impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Value>
77 for MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
78where
79 P: PackedValue,
80 PW: PackedValue,
81 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
82 + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
83 + Sync,
84 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
85 + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
86 + Sync,
87 PW::Value: Eq,
88 [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
89{
90 type ProverData<M> = MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>;
91 type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
92 type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
93 type Error = MerkleTreeError;
94
95 fn commit<M: Matrix<P::Value>>(
96 &self,
97 inputs: Vec<M>,
98 ) -> (Self::Commitment, Self::ProverData<M>) {
99 let tree = MerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
100 let root = tree.root();
101 (root, tree)
102 }
103
104 fn open_batch<M: Matrix<P::Value>>(
112 &self,
113 index: usize,
114 prover_data: &MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>,
115 ) -> BatchOpening<P::Value, Self> {
116 let max_height = self.get_max_height(prover_data);
117 let log_max_height = log2_ceil_usize(max_height);
118
119 let openings = prover_data
121 .leaves
122 .iter()
123 .map(|matrix| {
124 let log2_height = log2_ceil_usize(matrix.height());
125 let bits_reduced = log_max_height - log2_height;
126 let reduced_index = index >> bits_reduced;
127 matrix.row(reduced_index).unwrap().into_iter().collect()
128 })
129 .collect_vec();
130
131 let proof = (0..log_max_height)
133 .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
134 .collect();
135
136 BatchOpening::new(openings, proof)
137 }
138
139 fn get_matrices<'a, M: Matrix<P::Value>>(
140 &self,
141 prover_data: &'a Self::ProverData<M>,
142 ) -> Vec<&'a M> {
143 prover_data.leaves.iter().collect()
144 }
145
146 fn verify_batch(
159 &self,
160 commit: &Self::Commitment,
161 dimensions: &[Dimensions],
162 mut index: usize,
163 batch_proof: BatchOpeningRef<P::Value, Self>,
164 ) -> Result<(), Self::Error> {
165 let (opened_values, opening_proof) = batch_proof.unpack();
166 if dimensions.len() != opened_values.len() {
168 return Err(WrongBatchSize);
169 }
170
171 let mut heights_tallest_first = dimensions
179 .iter()
180 .enumerate()
181 .sorted_by_key(|(_, dims)| Reverse(dims.height))
182 .peekable();
183
184 if !heights_tallest_first
186 .clone()
187 .map(|(_, dims)| dims.height)
188 .tuple_windows()
189 .all(|(curr, next)| {
190 curr == next || curr.next_power_of_two() != next.next_power_of_two()
191 })
192 {
193 return Err(IncompatibleHeights);
194 }
195
196 let mut curr_height_padded = match heights_tallest_first.peek() {
202 Some((_, dims)) => {
203 let max_height = dims.height.next_power_of_two();
204 let log_max_height = log2_strict_usize(max_height);
205 if opening_proof.len() != log_max_height {
206 return Err(WrongHeight {
207 log_max_height,
208 num_siblings: opening_proof.len(),
209 });
210 }
211 max_height
212 }
213 None => return Err(EmptyBatch),
214 };
215
216 let mut root = self.hash.hash_iter_slices(
218 heights_tallest_first
219 .peeking_take_while(|(_, dims)| {
220 dims.height.next_power_of_two() == curr_height_padded
221 })
222 .map(|(i, _)| opened_values[i].as_slice()),
223 );
224
225 for &sibling in opening_proof {
226 let (left, right) = if index & 1 == 0 {
228 (root, sibling)
229 } else {
230 (sibling, root)
231 };
232
233 root = self.compress.compress([left, right]);
235 index >>= 1;
236 curr_height_padded >>= 1;
237
238 let next_height = heights_tallest_first
240 .peek()
241 .map(|(_, dims)| dims.height)
242 .filter(|h| h.next_power_of_two() == curr_height_padded);
243 if let Some(next_height) = next_height {
244 let next_height_openings_digest = self.hash.hash_iter_slices(
246 heights_tallest_first
247 .peeking_take_while(|(_, dims)| dims.height == next_height)
248 .map(|(i, _)| opened_values[i].as_slice()),
249 );
250
251 root = self.compress.compress([root, next_height_openings_digest]);
252 }
253 }
254
255 if commit == &root {
257 Ok(())
258 } else {
259 Err(RootMismatch)
260 }
261 }
262}
263
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::{MerkleTreeError, MerkleTreeMmcs};

    type F = BabyBear;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        MerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    /// Builds the hasher, compression function and MMCS shared by every test, all backed by
    /// one Poseidon2 permutation sampled from a fixed-seed RNG.
    ///
    /// The RNG is returned too, so a test can keep drawing random matrices from the same
    /// stream the permutation was sampled from.
    fn setup() -> (SmallRng, MyHash, MyCompress, MyMmcs) {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        (rng, hash, compress, mmcs)
    }

    #[test]
    fn commit_single_1x8() {
        let (_, hash, compress, mmcs) = setup();

        let v = vec![
            F::TWO,
            F::ONE,
            F::TWO,
            F::TWO,
            F::ZERO,
            F::ZERO,
            F::ONE,
            F::ZERO,
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]),
                compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]),
            ]),
            compress.compress([
                compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]),
                compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]),
            ]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_8x1() {
        let (mut rng, hash, _, mmcs) = setup();

        let mat = RowMajorMatrix::<F>::rand(&mut rng, 1, 8);
        let (commit, _) = mmcs.commit(vec![mat.clone()]);

        let expected_result = hash.hash_iter(mat.vertically_packed_row(0));
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x2() {
        let (_, hash, compress, mmcs) = setup();

        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x3() {
        let (_, hash, compress, mmcs) = setup();
        let default_digest = [F::ZERO; 8];

        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE, F::TWO, F::TWO], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            compress.compress([
                hash.hash_slice(&[F::ZERO, F::ONE]),
                hash.hash_slice(&[F::TWO, F::ONE]),
            ]),
            compress.compress([hash.hash_slice(&[F::TWO, F::TWO]), default_digest]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_mixed() {
        let (_, hash, compress, mmcs) = setup();
        let default_digest = [F::ZERO; 8];

        let mat_1 = RowMajorMatrix::new(
            vec![
                F::ZERO,
                F::ONE,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
            ],
            2,
        );
        let mat_2 = RowMajorMatrix::new(
            vec![
                F::ONE,
                F::TWO,
                F::ONE,
                F::ZERO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::ONE,
            ],
            3,
        );

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = [
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
        ];
        let mat_2_leaf_hashes = [
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
            hash.hash_slice(&[F::ZERO, F::TWO, F::TWO]),
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
        ];

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
                    mat_2_leaf_hashes[0],
                ]),
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[2], mat_1_leaf_hashes[3]]),
                    mat_2_leaf_hashes[1],
                ]),
            ]),
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[4], default_digest]),
                    mat_2_leaf_hashes[2],
                ]),
                default_digest,
            ]),
        ]);

        assert_eq!(commit, expected_result);

        let (opened_values, _) = mmcs.open_batch(2, &prover_data).unpack();
        assert_eq!(
            opened_values,
            vec![vec![F::TWO, F::TWO], vec![F::ZERO, F::TWO, F::TWO]]
        );
    }

    #[test]
    fn commit_either_order() {
        let (mut rng, _, _, mmcs) = setup();

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]);
        let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]);
        assert_eq!(commit_1_2, commit_2_1);
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let (_, _, _, mmcs) = setup();

        // Heights 8 and 7 both pad to 8 but differ, which the tree cannot represent.
        let large_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7, 8].map(F::from_u8).to_vec(), 1);
        let small_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn verify_tampered_proof_fails() {
        let (mut rng, _, _, mmcs) = setup();

        // All matrices have height 8; they differ only in width.
        let mut mats = (0..4)
            .map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 1))
            .collect_vec();
        let narrow_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 1,
        });
        mats.extend((0..4).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 2)));
        let wide_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 2,
        });

        let (commit, prover_data) = mmcs.commit(mats);

        let mut batch_opening = mmcs.open_batch(3, &prover_data);
        batch_opening.opening_proof[0][0] += F::ONE;
        mmcs.verify_batch(
            &commit,
            &narrow_mat_dims.chain(wide_mat_dims).collect_vec(),
            3,
            (&batch_opening).into(),
        )
        .expect_err("expected verification to fail");
    }

    #[test]
    fn verify_reports_specific_errors() {
        let (mut rng, _, _, mmcs) = setup();

        let mats = (0..2)
            .map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 2))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();
        let (commit, prover_data) = mmcs.commit(mats);
        let index = 5;

        // Sanity check: the untouched opening verifies.
        let batch_opening = mmcs.open_batch(index, &prover_data);
        mmcs.verify_batch(&commit, &dims, index, (&batch_opening).into())
            .expect("expected verification to succeed");

        // Fewer dimensions than opened rows.
        assert!(matches!(
            mmcs.verify_batch(&commit, &dims[..1], index, (&batch_opening).into()),
            Err(MerkleTreeError::WrongBatchSize)
        ));

        // Different heights that pad to the same power of two.
        let incompatible_dims = [
            Dimensions {
                height: 8,
                width: 2,
            },
            Dimensions {
                height: 7,
                width: 2,
            },
        ];
        assert!(matches!(
            mmcs.verify_batch(&commit, &incompatible_dims, index, (&batch_opening).into()),
            Err(MerkleTreeError::IncompatibleHeights)
        ));

        // Verifying at a different index than the proof was generated for.
        assert!(matches!(
            mmcs.verify_batch(&commit, &dims, index ^ 1, (&batch_opening).into()),
            Err(MerkleTreeError::RootMismatch)
        ));

        // A proof that is missing one sibling.
        let mut short_proof = mmcs.open_batch(index, &prover_data);
        short_proof.opening_proof.pop();
        assert!(matches!(
            mmcs.verify_batch(&commit, &dims, index, (&short_proof).into()),
            Err(MerkleTreeError::WrongHeight {
                log_max_height: 3,
                num_siblings: 2,
            })
        ));

        // A tampered opened value.
        let mut tampered = mmcs.open_batch(index, &prover_data);
        tampered.opened_values[0][0] += F::ONE;
        assert!(matches!(
            mmcs.verify_batch(&commit, &dims, index, (&tampered).into()),
            Err(MerkleTreeError::RootMismatch)
        ));

        // No matrices at all.
        let mut empty = mmcs.open_batch(index, &prover_data);
        empty.opened_values.clear();
        assert!(matches!(
            mmcs.verify_batch(&commit, &[], index, (&empty).into()),
            Err(MerkleTreeError::EmptyBatch)
        ));
    }

    #[test]
    fn size_gaps() {
        let (mut rng, _, _, mmcs) = setup();

        let mut mats = (0..4)
            .map(|_| RowMajorMatrix::<F>::rand(&mut rng, 1000, 8))
            .collect_vec();
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 1000,
            width: 8,
        });

        mats.extend((0..5).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 70, 8)));
        let medium_mat_dims = (0..5).map(|_| Dimensions {
            height: 70,
            width: 8,
        });

        mats.extend((0..6).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 8)));
        let small_mat_dims = (0..6).map(|_| Dimensions {
            height: 8,
            width: 8,
        });

        mats.extend((0..7).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 1, 8)));
        let tiny_mat_dims = (0..7).map(|_| Dimensions {
            height: 1,
            width: 8,
        });

        let (commit, prover_data) = mmcs.commit(mats);

        let batch_opening = mmcs.open_batch(6, &prover_data);
        mmcs.verify_batch(
            &commit,
            &large_mat_dims
                .chain(medium_mat_dims)
                .chain(small_mat_dims)
                .chain(tiny_mat_dims)
                .collect_vec(),
            6,
            (&batch_opening).into(),
        )
        .expect("expected verification to succeed");
    }

    #[test]
    fn different_widths() {
        let (mut rng, _, _, mmcs) = setup();

        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut rng, 32, i + 1))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let batch_opening = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, (&batch_opening).into())
            .expect("expected verification to succeed");
    }
}