p3_merkle_tree/
mmcs.rs

1use alloc::vec::Vec;
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4
5use itertools::Itertools;
6use p3_commit::Mmcs;
7use p3_field::{PackedField, PackedValue};
8use p3_matrix::{Dimensions, Matrix};
9use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
10use p3_util::log2_ceil_usize;
11use serde::{Deserialize, Serialize};
12
13use crate::FieldMerkleTree;
14use crate::FieldMerkleTreeError::{RootMismatch, WrongBatchSize, WrongHeight};
15
/// A vector commitment scheme backed by a `FieldMerkleTree`.
///
/// Several matrices of (possibly) different heights are committed to under a single root;
/// shorter matrices are injected at the tree layer matching their height.
///
/// Generics:
/// - `P`: the packed field type. Committed matrices hold `P::Scalar` elements; the packed form
///   `P` is handed to `FieldMerkleTree::new` (presumably so several rows can be hashed at once).
/// - `PW`: the packed digest-word type. A digest is an array `[PW::Value; DIGEST_ELEMS]`.
/// - `H`: the leaf hasher, turning a row (or rows) of field elements into a digest.
/// - `C`: the digest compression function, combining two digests into one.
/// - `DIGEST_ELEMS`: the number of `PW::Value` words in one digest.
#[derive(Copy, Clone, Debug)]
pub struct FieldMerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    /// Hasher used for the leaves / injected matrix rows.
    hash: H,
    /// Two-to-one compression used to combine sibling digests.
    compress: C,
    /// `P` and `PW` only appear in trait bounds, so they are carried as a marker.
    _phantom: PhantomData<(P, PW)>,
}
28
/// Errors reported when verifying a batch opening against a `FieldMerkleTreeMmcs` commitment.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum FieldMerkleTreeError {
    /// The batch of openings has an invalid size (e.g. `opened_values` and `dimensions`
    /// differ in length).
    WrongBatchSize,
    /// An opened row's length does not match its declared width.
    /// (At present `verify_batch` does not perform this check.)
    WrongWidth,
    /// The proof has the wrong number of sibling digests for the tallest matrix.
    WrongHeight {
        /// The tallest height among the supplied dimensions.
        max_height: usize,
        /// The number of sibling digests actually present in the proof.
        num_siblings: usize,
    },
    /// The root recomputed from the openings and the proof differs from the commitment.
    RootMismatch,
}
39
40impl<P, PW, H, C, const DIGEST_ELEMS: usize> FieldMerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
41    pub const fn new(hash: H, compress: C) -> Self {
42        Self {
43            hash,
44            compress,
45            _phantom: PhantomData,
46        }
47    }
48}
49
50impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Scalar>
51    for FieldMerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
52where
53    P: PackedField,
54    PW: PackedValue,
55    H: CryptographicHasher<P::Scalar, [PW::Value; DIGEST_ELEMS]>,
56    H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
57    H: Sync,
58    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
59    C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
60    C: Sync,
61    PW::Value: Eq,
62    [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
63{
64    type Commitment = Hash<P::Scalar, PW::Value, DIGEST_ELEMS>;
65    type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
66    type Error = FieldMerkleTreeError;
67    type ProverData<M> = FieldMerkleTree<P::Scalar, PW::Value, M, DIGEST_ELEMS>;
68
69    fn commit<M: Matrix<P::Scalar>>(
70        &self,
71        inputs: Vec<M>,
72    ) -> (Self::Commitment, Self::ProverData<M>) {
73        let tree = FieldMerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
74        let root = tree.root();
75        (root, tree)
76    }
77
78    fn open_batch<M: Matrix<P::Scalar>>(
79        &self,
80        index: usize,
81        prover_data: &FieldMerkleTree<P::Scalar, PW::Value, M, DIGEST_ELEMS>,
82    ) -> (Vec<Vec<P::Scalar>>, Vec<[PW::Value; DIGEST_ELEMS]>) {
83        let max_height = self.get_max_height(prover_data);
84        let log_max_height = log2_ceil_usize(max_height);
85
86        let openings = prover_data
87            .leaves
88            .iter()
89            .map(|matrix| {
90                let log2_height = log2_ceil_usize(matrix.height());
91                let bits_reduced = log_max_height - log2_height;
92                let reduced_index = index >> bits_reduced;
93                matrix.row(reduced_index).collect()
94            })
95            .collect_vec();
96
97        let proof: Vec<_> = (0..log_max_height)
98            .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
99            .collect();
100
101        (openings, proof)
102    }
103
104    fn get_matrices<'a, M: Matrix<P::Scalar>>(
105        &self,
106        prover_data: &'a Self::ProverData<M>,
107    ) -> Vec<&'a M> {
108        prover_data.leaves.iter().collect()
109    }
110
111    fn verify_batch(
112        &self,
113        commit: &Self::Commitment,
114        dimensions: &[Dimensions],
115        mut index: usize,
116        opened_values: &[Vec<P::Scalar>],
117        proof: &Self::Proof,
118    ) -> Result<(), Self::Error> {
119        // Check that the openings have the correct shape.
120        if dimensions.len() != opened_values.len() {
121            return Err(WrongBatchSize);
122        }
123
124        // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width.
125        // for (dims, opened_vals) in dimensions.iter().zip(opened_values) {
126        //     if opened_vals.len() != dims.width {
127        //         return Err(WrongWidth);
128        //     }
129        // }
130
131        // TODO: Disabled for now, CirclePcs sometimes passes a height that's off by 1 bit.
132        let max_height = dimensions.iter().map(|dim| dim.height).max().unwrap();
133        let log_max_height = log2_ceil_usize(max_height);
134        if proof.len() != log_max_height {
135            return Err(WrongHeight {
136                max_height,
137                num_siblings: proof.len(),
138            });
139        }
140
141        let mut heights_tallest_first = dimensions
142            .iter()
143            .enumerate()
144            .sorted_by_key(|(_, dims)| Reverse(dims.height))
145            .peekable();
146
147        let mut curr_height_padded = heights_tallest_first
148            .peek()
149            .unwrap()
150            .1
151            .height
152            .next_power_of_two();
153
154        let mut root = self.hash.hash_iter_slices(
155            heights_tallest_first
156                .peeking_take_while(|(_, dims)| {
157                    dims.height.next_power_of_two() == curr_height_padded
158                })
159                .map(|(i, _)| opened_values[i].as_slice()),
160        );
161
162        for &sibling in proof.iter() {
163            let (left, right) = if index & 1 == 0 {
164                (root, sibling)
165            } else {
166                (sibling, root)
167            };
168
169            root = self.compress.compress([left, right]);
170            index >>= 1;
171            curr_height_padded >>= 1;
172
173            let next_height = heights_tallest_first
174                .peek()
175                .map(|(_, dims)| dims.height)
176                .filter(|h| h.next_power_of_two() == curr_height_padded);
177            if let Some(next_height) = next_height {
178                let next_height_openings_digest = self.hash.hash_iter_slices(
179                    heights_tallest_first
180                        .peeking_take_while(|(_, dims)| dims.height == next_height)
181                        .map(|(i, _)| opened_values[i].as_slice()),
182                );
183
184                root = self.compress.compress([root, next_height_openings_digest]);
185            }
186        }
187
188        if commit == &root {
189            Ok(())
190        } else {
191            Err(RootMismatch)
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear};
    use p3_commit::Mmcs;
    use p3_field::{AbstractField, Field};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::thread_rng;

    use super::{FieldMerkleTreeError, FieldMerkleTreeMmcs};

    type F = BabyBear;

    type Perm = Poseidon2<F, Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, 16, 7>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        FieldMerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    /// Samples a fresh Poseidon2 permutation and returns the leaf hasher, the compression
    /// function, and an MMCS built from clones of both, so tests can recompute roots by hand.
    fn setup() -> (MyHash, MyCompress, MyMmcs) {
        let perm = Perm::new_from_rng_128(
            Poseidon2ExternalMatrixGeneral,
            DiffusionMatrixBabyBear::default(),
            &mut thread_rng(),
        );
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        (hash, compress, mmcs)
    }

    #[test]
    fn commit_single_1x8() {
        let (hash, compress, mmcs) = setup();

        // v = [2, 1, 2, 2, 0, 0, 1, 0]
        let v = vec![
            F::two(),
            F::one(),
            F::two(),
            F::two(),
            F::zero(),
            F::zero(),
            F::one(),
            F::zero(),
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]),
                compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]),
            ]),
            compress.compress([
                compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]),
                compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]),
            ]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x2() {
        let (hash, compress, mmcs) = setup();

        // mat = [
        //   0 1
        //   2 1
        // ]
        let mat = RowMajorMatrix::new(vec![F::zero(), F::one(), F::two(), F::one()], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            hash.hash_slice(&[F::zero(), F::one()]),
            hash.hash_slice(&[F::two(), F::one()]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x3() {
        let (hash, compress, mmcs) = setup();
        let default_digest = [F::zero(); 8];

        // mat = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let mat = RowMajorMatrix::new(
            vec![F::zero(), F::one(), F::two(), F::one(), F::two(), F::two()],
            2,
        );

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            compress.compress([
                hash.hash_slice(&[F::zero(), F::one()]),
                hash.hash_slice(&[F::two(), F::one()]),
            ]),
            compress.compress([hash.hash_slice(&[F::two(), F::two()]), default_digest]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_mixed() {
        let (hash, compress, mmcs) = setup();
        let default_digest = [F::zero(); 8];

        // mat_1 = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let mat_1 = RowMajorMatrix::new(
            vec![F::zero(), F::one(), F::two(), F::one(), F::two(), F::two()],
            2,
        );
        // mat_2 = [
        //   1 2 1
        //   0 2 2
        // ]
        let mat_2 = RowMajorMatrix::new(
            vec![F::one(), F::two(), F::one(), F::zero(), F::two(), F::two()],
            3,
        );
        let dims = vec![mat_1.dimensions(), mat_2.dimensions()];

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = [
            hash.hash_slice(&[F::zero(), F::one()]),
            hash.hash_slice(&[F::two(), F::one()]),
            hash.hash_slice(&[F::two(), F::two()]),
        ];
        let mat_2_leaf_hashes = [
            hash.hash_slice(&[F::one(), F::two(), F::one()]),
            hash.hash_slice(&[F::zero(), F::two(), F::two()]),
        ];

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
                mat_2_leaf_hashes[0],
            ]),
            compress.compress([
                compress.compress([mat_1_leaf_hashes[2], default_digest]),
                mat_2_leaf_hashes[1],
            ]),
        ]);
        assert_eq!(commit, expected_result);

        let (opened_values, proof) = mmcs.open_batch(2, &prover_data);
        assert_eq!(
            opened_values,
            vec![
                vec![F::two(), F::two()],
                vec![F::zero(), F::two(), F::two()]
            ]
        );

        // the opening of the mixed-height batch must also verify
        mmcs.verify_batch(&commit, &dims, 2, &opened_values, &proof)
            .expect("expected verification to succeed");
    }

    #[test]
    fn commit_either_order() {
        let mut rng = thread_rng();
        let (_, _, mmcs) = setup();

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]);
        let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]);
        assert_eq!(commit_1_2, commit_2_1);
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let (_, _, mmcs) = setup();

        // attempt to commit to a mat with 8 rows and a mat with 7 rows. this should panic.
        let large_mat = RowMajorMatrix::new(
            [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(),
            1,
        );
        let small_mat =
            RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn verify_tampered_proof_fails() {
        let (_, _, mmcs) = setup();

        // 4 mats with 8 rows and 1 column, plus 4 mats with 8 rows and 2 columns
        let narrow_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 1));
        let narrow_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 1,
        });
        let wide_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2));
        let wide_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 2,
        });

        let (commit, prover_data) = mmcs.commit(narrow_mats.chain(wide_mats).collect_vec());

        // open the row at index 3 of each matrix, mess with the proof, and verify
        let (opened_values, mut proof) = mmcs.open_batch(3, &prover_data);
        proof[0][0] += F::one();
        mmcs.verify_batch(
            &commit,
            &narrow_mat_dims.chain(wide_mat_dims).collect_vec(),
            3,
            &opened_values,
            &proof,
        )
        .expect_err("expected verification to fail");
    }

    #[test]
    fn verify_tampered_opening_fails() {
        let (_, _, mmcs) = setup();

        // 2 mats with 16 rows, 4 columns
        let mats = (0..2)
            .map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 16, 4))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (mut opened_values, proof) = mmcs.open_batch(9, &prover_data);

        // change one opened value; the recomputed root must no longer match
        opened_values[1][2] += F::one();
        let err = mmcs
            .verify_batch(&commit, &dims, 9, &opened_values, &proof)
            .expect_err("expected verification to fail");
        assert!(matches!(err, FieldMerkleTreeError::RootMismatch));
    }

    #[test]
    fn verify_wrong_batch_size_fails() {
        let (_, _, mmcs) = setup();

        // 3 mats with 8 rows, 2 columns
        let mats = (0..3)
            .map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (mut opened_values, proof) = mmcs.open_batch(5, &prover_data);

        // drop one opening so it no longer lines up with `dims`
        opened_values.pop();
        let err = mmcs
            .verify_batch(&commit, &dims, 5, &opened_values, &proof)
            .expect_err("expected verification to fail");
        assert!(matches!(err, FieldMerkleTreeError::WrongBatchSize));
    }

    #[test]
    fn verify_truncated_proof_fails() {
        let (_, _, mmcs) = setup();

        // 2 mats with 8 rows, 3 columns: the proof needs log2(8) = 3 siblings
        let mats = (0..2)
            .map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 3))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, mut proof) = mmcs.open_batch(1, &prover_data);

        proof.pop();
        let err = mmcs
            .verify_batch(&commit, &dims, 1, &opened_values, &proof)
            .expect_err("expected verification to fail");
        assert!(matches!(
            err,
            FieldMerkleTreeError::WrongHeight {
                max_height: 8,
                num_siblings: 2
            }
        ));
    }

    #[test]
    fn size_gaps() {
        let (_, _, mmcs) = setup();

        // 4 mats with 1000 rows, 8 columns
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 1000, 8));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 1000,
            width: 8,
        });

        // 5 mats with 70 rows, 8 columns
        let medium_mats = (0..5).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 70, 8));
        let medium_mat_dims = (0..5).map(|_| Dimensions {
            height: 70,
            width: 8,
        });

        // 6 mats with 8 rows, 8 columns
        let small_mats = (0..6).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 8));
        let small_mat_dims = (0..6).map(|_| Dimensions {
            height: 8,
            width: 8,
        });

        let (commit, prover_data) = mmcs.commit(
            large_mats
                .chain(medium_mats)
                .chain(small_mats)
                .collect_vec(),
        );

        // open the row at index 6 (relative to the tallest matrix) and verify
        let (opened_values, proof) = mmcs.open_batch(6, &prover_data);
        mmcs.verify_batch(
            &commit,
            &large_mat_dims
                .chain(medium_mat_dims)
                .chain(small_mat_dims)
                .collect_vec(),
            6,
            &opened_values,
            &proof,
        )
        .expect("expected verification to succeed");
    }

    #[test]
    fn different_widths() {
        let (_, _, mmcs) = setup();

        // 10 mats with 32 rows where the ith mat has i + 1 cols
        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut thread_rng(), 32, i + 1))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .expect("expected verification to succeed");
    }
}