p3_merkle_tree/
mmcs.rs

1//! A MerkleTreeMmcs is a generalization of the standard MerkleTree commitment scheme which supports
2//! committing to several matrices of different dimensions.
3//!
4//! Say we wish to commit to 2 matrices M and N with dimensions (8, i) and (2, j) respectively.
5//! Let H denote the hash function and C the compression function for our tree.
6//! Then MerkleTreeMmcs produces a commitment to M and N using the following tree structure:
7//!
8//! ```rust,ignore
9//! ///
10//! ///                                      root = c00 = C(c10, c11)                              
11//! ///                       /                                                \         
12//! ///         c10 = C(C(c20, c21), H(N[0]))                     c11 = C(C(c22, c23), H(N[1]))
13//! ///           /                      \                          /                      \     
14//! ///      c20 = C(L, R)            c21 = C(L, R)            c22 = C(L, R)            c23 = C(L, R)  
15//! ///   L/             \R        L/             \R        L/             \R        L/             \R     
16//! /// H(M[0])         H(M[1])  H(M[2])         H(M[3])  H(M[4])         H(M[5])  H(M[6])         H(M[7])
17//! ```
//! I.e. we start by building a standard Merkle tree over the rows of M and then inject the rows of N once we
//! reach the level with as many nodes as N has rows. A proof for the values of, say, `M[5]` and `N[1]` consists
//! of the siblings `H(M[4]), c23, c10`, listed from the leaf level up towards the root.
20//!
21
22use alloc::vec::Vec;
23use core::cmp::Reverse;
24use core::marker::PhantomData;
25
26use itertools::Itertools;
27use p3_commit::{BatchOpening, BatchOpeningRef, Mmcs};
28use p3_field::PackedValue;
29use p3_matrix::{Dimensions, Matrix};
30use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
31use p3_util::{log2_ceil_usize, log2_strict_usize};
32use serde::{Deserialize, Serialize};
33
34use crate::MerkleTree;
35use crate::MerkleTreeError::{
36    EmptyBatch, IncompatibleHeights, RootMismatch, WrongBatchSize, WrongHeight,
37};
38
/// A commitment scheme for batches of matrices, backed by a `MerkleTree`.
///
/// Type parameters:
/// - `P`: the type of a leaf value
/// - `PW`: the type of a single digest element
/// - `H`: the hasher applied to leaves
/// - `C`: the function compressing digests into one
/// - `DIGEST_ELEMS`: the number of elements making up a digest
#[derive(Clone, Copy, Debug)]
pub struct MerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    /// Leaf hasher.
    hash: H,
    /// Digest compression function.
    compress: C,
    /// Marker tying the otherwise unused `P` and `PW` parameters to this type.
    _phantom: PhantomData<(P, PW)>,
}
52
/// Errors reported when a batch opening fails to verify against a Merkle commitment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MerkleTreeError {
    /// The number of opened rows differs from the number of matrix dimensions supplied.
    WrongBatchSize,
    /// An opened row's length differs from the declared matrix width.
    ///
    /// Currently never returned by `verify_batch`, where the width check is disabled.
    WrongWidth,
    /// The number of sibling digests in the proof does not match the height of the tree.
    WrongHeight {
        /// Base-2 logarithm of the tallest matrix height after rounding up to a power of two.
        log_max_height: usize,
        /// Number of sibling digests actually present in the proof.
        num_siblings: usize,
    },
    /// Two matrices have different heights that round up to the same power of two.
    IncompatibleHeights,
    /// The root recomputed from the opening differs from the commitment.
    RootMismatch,
    /// No matrix dimensions were supplied.
    EmptyBatch,
}

impl core::fmt::Display for MerkleTreeError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::WrongBatchSize => {
                f.write_str("number of opened rows does not match number of matrix dimensions")
            }
            Self::WrongWidth => f.write_str("opened row length does not match matrix width"),
            Self::WrongHeight {
                log_max_height,
                num_siblings,
            } => write!(
                f,
                "expected {log_max_height} sibling digests in the proof, found {num_siblings}"
            ),
            Self::IncompatibleHeights => f.write_str(
                "matrix heights that round up to the same power of two must be equal",
            ),
            Self::RootMismatch => f.write_str("recomputed root does not match the commitment"),
            Self::EmptyBatch => f.write_str("no matrices in the batch"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar containers.
impl core::error::Error for MerkleTreeError {}
65
66impl<P, PW, H, C, const DIGEST_ELEMS: usize> MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
67    pub const fn new(hash: H, compress: C) -> Self {
68        Self {
69            hash,
70            compress,
71            _phantom: PhantomData,
72        }
73    }
74}
75
76impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Value>
77    for MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
78where
79    P: PackedValue,
80    PW: PackedValue,
81    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
82        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
83        + Sync,
84    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
85        + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
86        + Sync,
87    PW::Value: Eq,
88    [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
89{
90    type ProverData<M> = MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>;
91    type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
92    type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
93    type Error = MerkleTreeError;
94
95    fn commit<M: Matrix<P::Value>>(
96        &self,
97        inputs: Vec<M>,
98    ) -> (Self::Commitment, Self::ProverData<M>) {
99        let tree = MerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
100        let root = tree.root();
101        (root, tree)
102    }
103
104    /// Opens a batch of rows from committed matrices.
105    ///
106    /// Returns `(openings, proof)` where `openings` is a vector whose `i`th element is
107    /// the `j`th row of the ith matrix `M[i]`, with
108    ///     `j == index >> (log2_ceil(max_height) - log2_ceil(M[i].height))`
109    /// and `proof` is the vector of sibling Merkle tree nodes allowing the verifier to
110    /// reconstruct the committed root.
111    fn open_batch<M: Matrix<P::Value>>(
112        &self,
113        index: usize,
114        prover_data: &MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>,
115    ) -> BatchOpening<P::Value, Self> {
116        let max_height = self.get_max_height(prover_data);
117        let log_max_height = log2_ceil_usize(max_height);
118
119        // Get the matrix rows encountered along the path from the root to the given leaf index.
120        let openings = prover_data
121            .leaves
122            .iter()
123            .map(|matrix| {
124                let log2_height = log2_ceil_usize(matrix.height());
125                let bits_reduced = log_max_height - log2_height;
126                let reduced_index = index >> bits_reduced;
127                matrix.row(reduced_index).unwrap().into_iter().collect()
128            })
129            .collect_vec();
130
131        // Get all the siblings nodes corresponding to the path from the root to the given leaf index.
132        let proof = (0..log_max_height)
133            .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
134            .collect();
135
136        BatchOpening::new(openings, proof)
137    }
138
139    fn get_matrices<'a, M: Matrix<P::Value>>(
140        &self,
141        prover_data: &'a Self::ProverData<M>,
142    ) -> Vec<&'a M> {
143        prover_data.leaves.iter().collect()
144    }
145
146    /// Verifies an opened batch of rows with respect to a given commitment.
147    ///
148    /// - `commit`: The merkle root of the tree.
149    /// - `dimensions`: A vector of the dimensions of the matrices committed to.
150    /// - `index`: The index of a leaf in the tree.
151    /// - `opened_values`: A vector of matrix rows. Assume that the tallest matrix committed
152    ///   to has height `2^n >= M_tall.height() > 2^{n - 1}` and the `j`th matrix has height
153    ///   `2^m >= Mj.height() > 2^{m - 1}`. Then `j`'th value of opened values must be the row `Mj[index >> (m - n)]`.
154    /// - `proof`: A vector of sibling nodes. The `i`th element should be the node at level `i`
155    ///   with index `(index << i) ^ 1`.
156    ///
157    /// Returns nothing if the verification is successful, otherwise returns an error.
158    fn verify_batch(
159        &self,
160        commit: &Self::Commitment,
161        dimensions: &[Dimensions],
162        mut index: usize,
163        batch_proof: BatchOpeningRef<P::Value, Self>,
164    ) -> Result<(), Self::Error> {
165        let (opened_values, opening_proof) = batch_proof.unpack();
166        // Check that the openings have the correct shape.
167        if dimensions.len() != opened_values.len() {
168            return Err(WrongBatchSize);
169        }
170
171        // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width.
172        // for (dims, opened_vals) in zip_eq(dimensions.iter(), opened_values) {
173        //     if opened_vals.len() != dims.width {
174        //         return Err(WrongWidth);
175        //     }
176        // }
177
178        let mut heights_tallest_first = dimensions
179            .iter()
180            .enumerate()
181            .sorted_by_key(|(_, dims)| Reverse(dims.height))
182            .peekable();
183
184        // Matrix heights that round up to the same power of two must be equal
185        if !heights_tallest_first
186            .clone()
187            .map(|(_, dims)| dims.height)
188            .tuple_windows()
189            .all(|(curr, next)| {
190                curr == next || curr.next_power_of_two() != next.next_power_of_two()
191            })
192        {
193            return Err(IncompatibleHeights);
194        }
195
196        // Get the initial height padded to a power of two. As heights_tallest_first is sorted,
197        // the initial height will be the maximum height.
198        // Returns an error if either:
199        //              1. proof.len() != log_max_height
200        //              2. heights_tallest_first is empty.
201        let mut curr_height_padded = match heights_tallest_first.peek() {
202            Some((_, dims)) => {
203                let max_height = dims.height.next_power_of_two();
204                let log_max_height = log2_strict_usize(max_height);
205                if opening_proof.len() != log_max_height {
206                    return Err(WrongHeight {
207                        log_max_height,
208                        num_siblings: opening_proof.len(),
209                    });
210                }
211                max_height
212            }
213            None => return Err(EmptyBatch),
214        };
215
216        // Hash all matrix openings at the current height.
217        let mut root = self.hash.hash_iter_slices(
218            heights_tallest_first
219                .peeking_take_while(|(_, dims)| {
220                    dims.height.next_power_of_two() == curr_height_padded
221                })
222                .map(|(i, _)| opened_values[i].as_slice()),
223        );
224
225        for &sibling in opening_proof {
226            // The last bit of index informs us whether the current node is on the left or right.
227            let (left, right) = if index & 1 == 0 {
228                (root, sibling)
229            } else {
230                (sibling, root)
231            };
232
233            // Combine the current node with the sibling node to get the parent node.
234            root = self.compress.compress([left, right]);
235            index >>= 1;
236            curr_height_padded >>= 1;
237
238            // Check if there are any new matrix rows to inject at the next height.
239            let next_height = heights_tallest_first
240                .peek()
241                .map(|(_, dims)| dims.height)
242                .filter(|h| h.next_power_of_two() == curr_height_padded);
243            if let Some(next_height) = next_height {
244                // If there are new matrix rows, hash the rows together and then combine with the current root.
245                let next_height_openings_digest = self.hash.hash_iter_slices(
246                    heights_tallest_first
247                        .peeking_take_while(|(_, dims)| dims.height == next_height)
248                        .map(|(i, _)| opened_values[i].as_slice()),
249                );
250
251                root = self.compress.compress([root, next_height_openings_digest]);
252            }
253        }
254
255        // The computed root should equal the committed one.
256        if commit == &root {
257            Ok(())
258        } else {
259            Err(RootMismatch)
260        }
261    }
262}
263
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::MerkleTreeMmcs;

    type F = BabyBear;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        MerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    /// Builds the deterministic setup shared by all tests.
    ///
    /// Returns the seeded RNG (already advanced by drawing the permutation), the leaf hasher,
    /// the compression function and an MMCS built from clones of the latter two.
    fn fixture() -> (SmallRng, MyHash, MyCompress, MyMmcs) {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        (rng, hash, compress, mmcs)
    }

    /// Committing to a single column of 8 values yields a plain binary Merkle tree.
    #[test]
    fn commit_single_1x8() {
        let (_, hash, compress, mmcs) = fixture();

        // v = [2, 1, 2, 2, 0, 0, 1, 0]
        let v = vec![
            F::TWO,
            F::ONE,
            F::TWO,
            F::TWO,
            F::ZERO,
            F::ZERO,
            F::ONE,
            F::ZERO,
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        // Rebuild the tree by hand, one layer at a time.
        let leaf = |i: usize| hash.hash_item(v[i]);
        let pair = |i: usize| compress.compress([leaf(i), leaf(i + 1)]);
        let left_half = compress.compress([pair(0), pair(2)]);
        let right_half = compress.compress([pair(4), pair(6)]);
        let expected_result = compress.compress([left_half, right_half]);

        assert_eq!(commit, expected_result);
    }

    /// Committing to a single row needs no compression: the root is the hash of that row.
    #[test]
    fn commit_single_8x1() {
        let (mut rng, hash, _, mmcs) = fixture();

        let mat = RowMajorMatrix::<F>::rand(&mut rng, 1, 8);
        let (commit, _) = mmcs.commit(vec![mat.clone()]);

        let expected_result = hash.hash_iter(mat.vertically_packed_row(0));
        assert_eq!(commit, expected_result);
    }

    /// A matrix with two rows commits to the compression of its two row hashes.
    #[test]
    fn commit_single_2x2() {
        let (_, hash, compress, mmcs) = fixture();

        // mat = [
        //   0 1
        //   2 1
        // ]
        let top = [F::ZERO, F::ONE];
        let bottom = [F::TWO, F::ONE];
        let mat = RowMajorMatrix::new(vec![top[0], top[1], bottom[0], bottom[1]], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result =
            compress.compress([hash.hash_slice(&top), hash.hash_slice(&bottom)]);
        assert_eq!(commit, expected_result);
    }

    /// A height that is not a power of two is padded with the all-zero digest.
    #[test]
    fn commit_single_2x3() {
        let (_, hash, compress, mmcs) = fixture();
        let default_digest = [F::ZERO; 8];

        // mat = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let row_0 = [F::ZERO, F::ONE];
        let row_1 = [F::TWO, F::ONE];
        let row_2 = [F::TWO, F::TWO];
        let mat = RowMajorMatrix::new(
            vec![row_0[0], row_0[1], row_1[0], row_1[1], row_2[0], row_2[1]],
            2,
        );

        let (commit, _) = mmcs.commit(vec![mat]);

        let left = compress.compress([hash.hash_slice(&row_0), hash.hash_slice(&row_1)]);
        let right = compress.compress([hash.hash_slice(&row_2), default_digest]);
        let expected_result = compress.compress([left, right]);
        assert_eq!(commit, expected_result);
    }

    /// Two matrices of different heights: the shorter one is injected one level above the
    /// leaves of the taller one.
    #[test]
    fn commit_mixed() {
        let (_, hash, compress, mmcs) = fixture();
        let default_digest = [F::ZERO; 8];

        // mat_1 = [
        //   0 1
        //   2 1
        //   2 2
        //   2 1
        //   2 2
        // ]
        let mat_1_rows = [
            [F::ZERO, F::ONE],
            [F::TWO, F::ONE],
            [F::TWO, F::TWO],
            [F::TWO, F::ONE],
            [F::TWO, F::TWO],
        ];
        // mat_2 = [
        //   1 2 1
        //   0 2 2
        //   1 2 1
        // ]
        let mat_2_rows = [
            [F::ONE, F::TWO, F::ONE],
            [F::ZERO, F::TWO, F::TWO],
            [F::ONE, F::TWO, F::ONE],
        ];
        let mat_1 = RowMajorMatrix::new(mat_1_rows.iter().flatten().copied().collect_vec(), 2);
        let mat_2 = RowMajorMatrix::new(mat_2_rows.iter().flatten().copied().collect_vec(), 3);

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = mat_1_rows.map(|row| hash.hash_slice(&row));
        let mat_2_leaf_hashes = mat_2_rows.map(|row| hash.hash_slice(&row));

        // Leaf level: pair up the rows of mat_1, padding the missing sixth leaf.
        let leaf_pairs = [
            compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
            compress.compress([mat_1_leaf_hashes[2], mat_1_leaf_hashes[3]]),
            compress.compress([mat_1_leaf_hashes[4], default_digest]),
        ];
        // Next level: mix in the rows of mat_2.
        let injected = [
            compress.compress([leaf_pairs[0], mat_2_leaf_hashes[0]]),
            compress.compress([leaf_pairs[1], mat_2_leaf_hashes[1]]),
            compress.compress([leaf_pairs[2], mat_2_leaf_hashes[2]]),
        ];
        // Remaining levels: ordinary compression, again padding the missing node.
        let expected_result = compress.compress([
            compress.compress([injected[0], injected[1]]),
            compress.compress([injected[2], default_digest]),
        ]);

        assert_eq!(commit, expected_result);

        // Leaf index 2 maps to row 2 of mat_1 and row 1 of mat_2.
        let (opened_values, _) = mmcs.open_batch(2, &prover_data).unpack();
        assert_eq!(
            opened_values,
            vec![mat_1_rows[2].to_vec(), mat_2_rows[1].to_vec()]
        );
    }

    /// The commitment does not depend on the order in which the matrices are supplied.
    #[test]
    fn commit_either_order() {
        let (mut rng, _, _, mmcs) = fixture();

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let forward = vec![input_1.clone(), input_2.clone()];
        let backward = vec![input_2, input_1];

        let (commit_1_2, _) = mmcs.commit(forward);
        let (commit_2_1, _) = mmcs.commit(backward);
        assert_eq!(commit_1_2, commit_2_1);
    }

    /// Heights 8 and 7 round up to the same power of two but differ, so committing panics.
    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let (_, _, _, mmcs) = fixture();

        let large_mat = RowMajorMatrix::new((1..=8).map(F::from_u8).collect_vec(), 1);
        let small_mat = RowMajorMatrix::new((1..=7).map(F::from_u8).collect_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    /// Altering a sibling digest in the proof must make verification fail.
    #[test]
    fn verify_tampered_proof_fails() {
        let (mut rng, _, _, mmcs) = fixture();

        // Four matrices of height 8 and width 1, followed by four of height 8 and width 2.
        let mut mats = vec![];
        let mut dims = vec![];
        for width in [1, 2] {
            for _ in 0..4 {
                mats.push(RowMajorMatrix::<F>::rand(&mut rng, 8, width));
                dims.push(Dimensions { width, height: 8 });
            }
        }

        let (commit, prover_data) = mmcs.commit(mats);

        // Open leaf index 3, corrupt the first sibling, then verify.
        let mut batch_opening = mmcs.open_batch(3, &prover_data);
        batch_opening.opening_proof[0][0] += F::ONE;
        mmcs.verify_batch(&commit, &dims, 3, (&batch_opening).into())
            .expect_err("expected verification to fail");
    }

    /// Verification succeeds when the matrix heights are several tree levels apart.
    #[test]
    fn size_gaps() {
        let (mut rng, _, _, mmcs) = fixture();

        // (number of matrices, height) per group; every matrix has 8 columns.
        let groups = [(4, 1000), (5, 70), (6, 8), (7, 1)];
        let mut mats = vec![];
        let mut dims = vec![];
        for (count, height) in groups {
            for _ in 0..count {
                mats.push(RowMajorMatrix::<F>::rand(&mut rng, height, 8));
                dims.push(Dimensions { width: 8, height });
            }
        }

        let (commit, prover_data) = mmcs.commit(mats);

        // Open leaf index 6 and verify.
        let batch_opening = mmcs.open_batch(6, &prover_data);
        mmcs.verify_batch(&commit, &dims, 6, (&batch_opening).into())
            .expect("expected verification to succeed");
    }

    /// Matrices sharing a height but not a width can be opened and verified together.
    #[test]
    fn different_widths() {
        let (mut rng, _, _, mmcs) = fixture();

        // Ten matrices with 32 rows each and widths 1 through 10.
        let mut mats = vec![];
        for width in 1..=10 {
            mats.push(RowMajorMatrix::<F>::rand(&mut rng, 32, width));
        }
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let batch_opening = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, (&batch_opening).into())
            .expect("expected verification to succeed");
    }
}