p3_merkle_tree/
hiding_mmcs.rs

1use alloc::vec::Vec;
2use core::cell::RefCell;
3
4use itertools::Itertools;
5use p3_commit::{BatchOpening, BatchOpeningRef, Mmcs};
6use p3_field::PackedValue;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_matrix::stack::HorizontalPair;
9use p3_matrix::{Dimensions, Matrix};
10use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
11use p3_util::zip_eq::zip_eq;
12use rand::Rng;
13use rand::distr::{Distribution, StandardUniform};
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16
17use crate::{MerkleTree, MerkleTreeError, MerkleTreeMmcs};
18
/// A vector commitment scheme backed by a `MerkleTree`.
///
/// This is similar to `MerkleTreeMmcs`, but each leaf is "salted" with random elements. This is
/// done to turn the Merkle tree into a hiding commitment. See e.g. Section 3 of
/// [Interactive Oracle Proofs](https://eprint.iacr.org/2016/116).
///
/// `SALT_ELEMS` should be set such that the product of `SALT_ELEMS` with the size of the value
/// (`P::Value`) is at least the target security parameter.
///
/// `R` should be an appropriately seeded cryptographically secure pseudorandom number generator
/// (CSPRNG). Something like `ThreadRng` may work, although it relies on the operating system to
/// provide sufficient entropy.
///
/// NOTE: the derived `Clone` also clones the RNG held in the `RefCell`, so a cloned MMCS starts
/// from the same RNG state as the original and will draw the same salt sequence. Avoid cloning
/// an instance that is still used for committing if salt reuse matters to you.
///
/// Generics:
/// - `P`: a leaf value
/// - `PW`: an element of a digest
/// - `H`: the leaf hasher
/// - `C`: the digest compression function
/// - `R`: a random number generator for blinding leaves
/// - `DIGEST_ELEMS`: the number of `PW` elements in a digest
/// - `SALT_ELEMS`: the number of random salt elements appended to every committed row
#[derive(Clone, Debug)]
pub struct MerkleTreeHidingMmcs<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
{
    /// The plain (non-hiding) Merkle tree MMCS that commits to the salted matrices.
    inner: MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>,
    /// Source of salts. Wrapped in a `RefCell` so it can be advanced from `&self` methods
    /// such as `commit`.
    rng: RefCell<R>,
}
44
45impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
46    MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
47{
48    pub const fn new(hash: H, compress: C, rng: R) -> Self {
49        let inner = MerkleTreeMmcs::new(hash, compress);
50        Self {
51            inner,
52            rng: RefCell::new(rng),
53        }
54    }
55}
56
57impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize> Mmcs<P::Value>
58    for MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
59where
60    P: PackedValue,
61    P::Value: Serialize + DeserializeOwned,
62    PW: PackedValue,
63    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
64        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
65        + Sync,
66    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
67        + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
68        + Sync,
69    R: Rng + Clone,
70    PW::Value: Eq,
71    [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
72    StandardUniform: Distribution<P::Value>,
73{
74    type ProverData<M> =
75        MerkleTree<P::Value, PW::Value, HorizontalPair<M, RowMajorMatrix<P::Value>>, DIGEST_ELEMS>;
76    type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
77    /// The first item is salts; the second is the usual Merkle proof (sibling digests).
78    type Proof = (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>);
79    type Error = MerkleTreeError;
80
81    fn commit<M: Matrix<P::Value>>(
82        &self,
83        inputs: Vec<M>,
84    ) -> (Self::Commitment, Self::ProverData<M>) {
85        let salted_inputs = inputs
86            .into_iter()
87            .map(|mat| {
88                let salts =
89                    RowMajorMatrix::rand(&mut *self.rng.borrow_mut(), mat.height(), SALT_ELEMS);
90                HorizontalPair::new(mat, salts)
91            })
92            .collect();
93        self.inner.commit(salted_inputs)
94    }
95
96    fn open_batch<M: Matrix<P::Value>>(
97        &self,
98        index: usize,
99        prover_data: &Self::ProverData<M>,
100    ) -> BatchOpening<P::Value, Self> {
101        let (salted_openings, siblings) = self.inner.open_batch(index, prover_data).unpack();
102        let (openings, salts): (Vec<_>, Vec<_>) = salted_openings
103            .into_iter()
104            .map(|row| {
105                let (a, b) = row.split_at(row.len() - SALT_ELEMS);
106                (a.to_vec(), b.to_vec())
107            })
108            .unzip();
109        BatchOpening::new(openings, (salts, siblings))
110    }
111
112    fn get_matrices<'a, M: Matrix<P::Value>>(
113        &self,
114        prover_data: &'a Self::ProverData<M>,
115    ) -> Vec<&'a M> {
116        prover_data.leaves.iter().map(|mat| &mat.left).collect()
117    }
118
119    fn verify_batch(
120        &self,
121        commit: &Self::Commitment,
122        dimensions: &[Dimensions],
123        index: usize,
124        batch_opening: BatchOpeningRef<P::Value, Self>,
125    ) -> Result<(), Self::Error> {
126        let (opened_values, (salts, siblings)) = batch_opening.unpack();
127
128        let opened_salted_values = zip_eq(opened_values, salts, MerkleTreeError::WrongBatchSize)?
129            .map(|(opened, salt)| opened.iter().chain(salt.iter()).copied().collect_vec())
130            .collect_vec();
131
132        self.inner.verify_batch(
133            commit,
134            dimensions,
135            index,
136            BatchOpeningRef::new(&opened_salted_values, siblings),
137        )
138    }
139}
140
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::{BatchOpening, Mmcs};
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_matrix::Matrix;
    use p3_matrix::dense::RowMajorMatrix;
    use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::MerkleTreeHidingMmcs;
    use crate::MerkleTreeError;

    type F = BabyBear;
    const SALT_ELEMS: usize = 4;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs = MerkleTreeHidingMmcs<
        <F as Field>::Packing,
        <F as Field>::Packing,
        MyHash,
        MyCompress,
        SmallRng,
        8,
        SALT_ELEMS,
    >;

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress, rng);

        // attempt to commit to a mat with 8 rows and a mat with 7 rows. this should panic.
        let large_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7, 8].map(F::from_u8).to_vec(), 1);
        let small_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn different_widths() -> Result<(), MerkleTreeError> {
        let mut rng = SmallRng::seed_from_u64(1);
        // 10 mats with 32 rows where the ith mat has i + 1 cols
        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut rng, 32, i + 1))
            .collect_vec();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress, rng);

        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let batch_proof = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, (&batch_proof).into())
    }

    /// Fresh salts are drawn on every commit, so committing to identical data twice must not
    /// produce the same commitment (that is the point of the hiding construction).
    #[test]
    fn same_data_gives_different_commitments() {
        let mut rng = SmallRng::seed_from_u64(2);
        let mat = RowMajorMatrix::<F>::rand(&mut rng, 16, 3);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress, rng);

        let (commit_a, _) = mmcs.commit(vec![mat.clone()]);
        let (commit_b, _) = mmcs.commit(vec![mat]);
        assert_ne!(commit_a, commit_b);
    }

    /// The salts are part of the committed leaves, so altering one must break verification.
    #[test]
    fn tampered_salt_is_rejected() {
        let mut rng = SmallRng::seed_from_u64(3);
        let mats = (0..3)
            .map(|i| RowMajorMatrix::<F>::rand(&mut rng, 16, i + 1))
            .collect_vec();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress, rng);

        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();
        let (commit, prover_data) = mmcs.commit(mats);
        let batch_opening = mmcs.open_batch(5, &prover_data);

        // Sanity check: the honest opening verifies.
        assert!(
            mmcs.verify_batch(&commit, &dims, 5, (&batch_opening).into())
                .is_ok()
        );

        let (opened_values, (mut salts, siblings)) = batch_opening.unpack();
        salts[0][0] += F::ONE;
        let tampered: BatchOpening<F, MyMmcs> =
            BatchOpening::new(opened_values, (salts, siblings));
        assert!(
            mmcs.verify_batch(&commit, &dims, 5, (&tampered).into())
                .is_err()
        );
    }
}