1use alloc::vec::Vec;
2use core::cell::RefCell;
3
4use itertools::Itertools;
5use p3_commit::{BatchOpening, BatchOpeningRef, Mmcs};
6use p3_field::PackedValue;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_matrix::stack::HorizontalPair;
9use p3_matrix::{Dimensions, Matrix};
10use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
11use p3_util::zip_eq::zip_eq;
12use rand::Rng;
13use rand::distr::{Distribution, StandardUniform};
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16
17use crate::{MerkleTree, MerkleTreeError, MerkleTreeMmcs};
18
/// A hiding multi-matrix commitment scheme layered on top of [`MerkleTreeMmcs`].
///
/// On commit, every input matrix is extended on the right with a `height x SALT_ELEMS` matrix
/// of random values sampled from `rng`. The salted matrices are then committed with the inner
/// (non-hiding) Merkle tree MMCS. Batch openings return the unsalted rows as the opened values
/// and carry the salts inside the proof.
///
/// # Type parameters
/// - `P`, `PW`: packed types for the committed values and for digest words, respectively.
/// - `H`: hasher handed to the inner MMCS; `C`: its two-to-one compression function.
/// - `R`: RNG used to sample salts.
/// - `DIGEST_ELEMS`: number of words in a digest.
/// - `SALT_ELEMS`: number of salt values appended to every row.
///
/// NOTE(review): the derived `Clone` also clones the RNG state, so a clone presumably draws the
/// same salt sequence as the original unless its RNG is replaced — confirm this is acceptable
/// for the hiding property.
#[derive(Clone, Debug)]
pub struct MerkleTreeHidingMmcs<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
{
    /// Non-hiding Merkle tree MMCS that commits to the salted matrices.
    inner: MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>,
    /// Salt RNG, wrapped in a `RefCell` so it can be advanced through `&self` when committing.
    rng: RefCell<R>,
}
44
45impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
46 MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
47{
48 pub const fn new(hash: H, compress: C, rng: R) -> Self {
49 let inner = MerkleTreeMmcs::new(hash, compress);
50 Self {
51 inner,
52 rng: RefCell::new(rng),
53 }
54 }
55}
56
57impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize> Mmcs<P::Value>
58 for MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
59where
60 P: PackedValue,
61 P::Value: Serialize + DeserializeOwned,
62 PW: PackedValue,
63 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
64 + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
65 + Sync,
66 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
67 + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
68 + Sync,
69 R: Rng + Clone,
70 PW::Value: Eq,
71 [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
72 StandardUniform: Distribution<P::Value>,
73{
74 type ProverData<M> =
75 MerkleTree<P::Value, PW::Value, HorizontalPair<M, RowMajorMatrix<P::Value>>, DIGEST_ELEMS>;
76 type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
77 type Proof = (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>);
79 type Error = MerkleTreeError;
80
81 fn commit<M: Matrix<P::Value>>(
82 &self,
83 inputs: Vec<M>,
84 ) -> (Self::Commitment, Self::ProverData<M>) {
85 let salted_inputs = inputs
86 .into_iter()
87 .map(|mat| {
88 let salts =
89 RowMajorMatrix::rand(&mut *self.rng.borrow_mut(), mat.height(), SALT_ELEMS);
90 HorizontalPair::new(mat, salts)
91 })
92 .collect();
93 self.inner.commit(salted_inputs)
94 }
95
96 fn open_batch<M: Matrix<P::Value>>(
97 &self,
98 index: usize,
99 prover_data: &Self::ProverData<M>,
100 ) -> BatchOpening<P::Value, Self> {
101 let (salted_openings, siblings) = self.inner.open_batch(index, prover_data).unpack();
102 let (openings, salts): (Vec<_>, Vec<_>) = salted_openings
103 .into_iter()
104 .map(|row| {
105 let (a, b) = row.split_at(row.len() - SALT_ELEMS);
106 (a.to_vec(), b.to_vec())
107 })
108 .unzip();
109 BatchOpening::new(openings, (salts, siblings))
110 }
111
112 fn get_matrices<'a, M: Matrix<P::Value>>(
113 &self,
114 prover_data: &'a Self::ProverData<M>,
115 ) -> Vec<&'a M> {
116 prover_data.leaves.iter().map(|mat| &mat.left).collect()
117 }
118
119 fn verify_batch(
120 &self,
121 commit: &Self::Commitment,
122 dimensions: &[Dimensions],
123 index: usize,
124 batch_opening: BatchOpeningRef<P::Value, Self>,
125 ) -> Result<(), Self::Error> {
126 let (opened_values, (salts, siblings)) = batch_opening.unpack();
127
128 let opened_salted_values = zip_eq(opened_values, salts, MerkleTreeError::WrongBatchSize)?
129 .map(|(opened, salt)| opened.iter().chain(salt.iter()).copied().collect_vec())
130 .collect_vec();
131
132 self.inner.verify_batch(
133 commit,
134 dimensions,
135 index,
136 BatchOpeningRef::new(&opened_salted_values, siblings),
137 )
138 }
139}
140
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::{BatchOpening, Mmcs};
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_matrix::Matrix;
    use p3_matrix::dense::RowMajorMatrix;
    use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::MerkleTreeHidingMmcs;
    use crate::MerkleTreeError;

    type F = BabyBear;
    const SALT_ELEMS: usize = 4;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs = MerkleTreeHidingMmcs<
        <F as Field>::Packing,
        <F as Field>::Packing,
        MyHash,
        MyCompress,
        SmallRng,
        8,
        SALT_ELEMS,
    >;

    /// Builds the MMCS under test: the permutation is sampled from `rng`, and `rng` is then
    /// handed to the MMCS as its salt source.
    fn build_mmcs(mut rng: SmallRng) -> MyMmcs {
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        MyMmcs::new(hash, compress, rng)
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let mmcs = build_mmcs(SmallRng::seed_from_u64(1));

        // Heights 8 and 7 cannot share a single binary Merkle tree.
        let large_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7, 8].map(F::from_u8).to_vec(), 1);
        let small_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn different_widths() -> Result<(), MerkleTreeError> {
        let mut rng = SmallRng::seed_from_u64(1);
        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut rng, 32, i + 1))
            .collect_vec();
        let mmcs = build_mmcs(rng);

        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let batch_proof = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, (&batch_proof).into())
    }

    #[test]
    fn tampered_salt_is_rejected() {
        let mut rng = SmallRng::seed_from_u64(2);
        let mats = (0..3)
            .map(|i| RowMajorMatrix::<F>::rand(&mut rng, 16, i + 2))
            .collect_vec();
        let mmcs = build_mmcs(rng);
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let honest = mmcs.open_batch(5, &prover_data);
        // Sanity check: the untouched opening must verify, so any failure below is due to tampering.
        assert!(mmcs.verify_batch(&commit, &dims, 5, (&honest).into()).is_ok());

        let (opened_values, (mut salts, siblings)) = honest.unpack();
        salts[0][0] += F::ONE;
        let tampered = BatchOpening::<F, MyMmcs>::new(opened_values, (salts, siblings));

        assert!(mmcs.verify_batch(&commit, &dims, 5, (&tampered).into()).is_err());
    }
}