Skip to main content

slop_stacked/
prover.rs

1use serde::{Deserialize, Serialize};
2use slop_algebra::{Field, TwoAdicField};
3use slop_alloc::{CpuBackend, ToHost};
4use slop_basefold_prover::{BasefoldProver, BasefoldProverData, BasefoldProverError};
5use slop_challenger::IopCtx;
6use slop_commit::{Message, Rounds};
7use slop_merkle_tree::ComputeTcsOpenings;
8use slop_multilinear::{Evaluations, Mle, MleEval, MultilinearPcsProver, Point, ToMle};
9use std::fmt::Debug;
10
11use crate::{interleave_multilinears_with_fixed_rate, StackedBasefoldProof};
12
13#[derive(Clone)]
14pub struct StackedPcsProver<P: ComputeTcsOpenings<GC, CpuBackend>, GC: IopCtx<F: TwoAdicField>> {
15    basefold_prover: BasefoldProver<GC, P>,
16    pub log_stacking_height: u32,
17    pub batch_size: usize,
18    _marker: std::marker::PhantomData<GC>,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct StackedBasefoldProverData<M, F, TcsProverData> {
23    pcs_batch_data: BasefoldProverData<F, TcsProverData>,
24    pub interleaved_mles: Message<M>,
25}
26
27impl<F: Field, PD> ToMle<F> for StackedBasefoldProverData<Mle<F>, F, PD> {
28    fn interleaved_mles(&self) -> Message<Mle<F, CpuBackend>> {
29        self.interleaved_mles.clone()
30    }
31}
32
33impl<GC, P> StackedPcsProver<P, GC>
34where
35    GC: IopCtx<F: TwoAdicField, EF: TwoAdicField>,
36    P: ComputeTcsOpenings<GC, CpuBackend>,
37{
38    pub const fn new(
39        basefold_prover: BasefoldProver<GC, P>,
40        log_stacking_height: u32,
41        batch_size: usize,
42    ) -> Self {
43        Self { basefold_prover, log_stacking_height, batch_size, _marker: std::marker::PhantomData }
44    }
45
46    pub fn round_batch_evaluations(
47        &self,
48        stacked_point: &Point<GC::EF>,
49        prover_data: &StackedBasefoldProverData<Mle<GC::F>, GC::F, P::ProverData>,
50    ) -> Evaluations<GC::EF> {
51        prover_data
52            .interleaved_mles
53            .iter()
54            .map(|mle| mle.eval_at(stacked_point))
55            .collect::<Evaluations<_, _>>()
56    }
57
58    #[allow(clippy::type_complexity)]
59    pub fn commit_multilinears(
60        &self,
61        multilinears: Message<Mle<GC::F>>,
62    ) -> Result<
63        (GC::Digest, StackedBasefoldProverData<Mle<GC::F>, GC::F, P::ProverData>, usize),
64        BasefoldProverError<P::ProverError>,
65    > {
66        // To commit to the batch of padded Mles, the underlying PCS prover commits to the dense
67        // representation of all of these Mles (i.e. a single "giga" Mle consisting of all the
68        // entries of all the individual Mles),
69        // padding the total area to the next multiple of the stacking height.
70        let next_multiple = multilinears
71            .iter()
72            .map(|mle| mle.num_non_zero_entries() * mle.num_polynomials())
73            .sum::<usize>()
74            .next_multiple_of(1 << self.log_stacking_height)
75            // Need to pad to at least one column.
76            .max(1 << self.log_stacking_height);
77
78        let num_added_vals = next_multiple
79            - multilinears
80                .iter()
81                .map(|mle| mle.num_non_zero_entries() * mle.num_polynomials())
82                .sum::<usize>();
83
84        let interleaved_mles = interleave_multilinears_with_fixed_rate(
85            self.batch_size,
86            multilinears,
87            self.log_stacking_height,
88        );
89        let (commit, pcs_batch_data) =
90            self.basefold_prover.commit_mles(interleaved_mles.clone())?;
91        let prover_data = StackedBasefoldProverData { pcs_batch_data, interleaved_mles };
92
93        Ok((commit, prover_data, num_added_vals))
94    }
95}
96
97impl<GC: IopCtx<F: TwoAdicField, EF: TwoAdicField>, P: ComputeTcsOpenings<GC, CpuBackend>>
98    MultilinearPcsProver<GC, StackedBasefoldProof<GC>> for StackedPcsProver<P, GC>
99{
100    type ProverData = StackedBasefoldProverData<Mle<GC::F>, GC::F, P::ProverData>;
101
102    type ProverError = BasefoldProverError<P::ProverError>;
103
104    fn commit_multilinear(
105        &self,
106        mles: Message<Mle<<GC as IopCtx>::F>>,
107    ) -> Result<(<GC as IopCtx>::Digest, Self::ProverData, usize), Self::ProverError> {
108        self.commit_multilinears(mles)
109    }
110
111    fn prove_trusted_evaluation(
112        &self,
113        eval_point: Point<<GC as IopCtx>::EF>,
114        _evaluation_claim: <GC as IopCtx>::EF,
115        prover_data: Rounds<Self::ProverData>,
116        challenger: &mut <GC as IopCtx>::Challenger,
117    ) -> Result<StackedBasefoldProof<GC>, Self::ProverError> {
118        let (_, stack_point) =
119            eval_point.split_at(eval_point.dimension() - self.log_stacking_height as usize);
120        let batch_evaluations: Rounds<_> = prover_data
121            .iter()
122            .map(|data| self.round_batch_evaluations(&stack_point, data))
123            .collect();
124
125        let mut host_batch_evaluations = Rounds::new();
126        for round_evals in batch_evaluations.iter() {
127            let mut host_round_evals = vec![];
128            for eval in round_evals.iter() {
129                let host_eval = eval.to_host().unwrap();
130                host_round_evals.extend(host_eval);
131            }
132            let host_round_evals = Evaluations::new(vec![host_round_evals.into()]);
133            host_batch_evaluations.push(host_round_evals);
134        }
135        let (pcs_prover_data, mle_rounds): (Rounds<_>, Rounds<_>) = prover_data
136            .into_iter()
137            .map(|data| (data.pcs_batch_data, data.interleaved_mles))
138            .unzip();
139
140        let (_, stack_point) =
141            eval_point.split_at(eval_point.dimension() - self.log_stacking_height as usize);
142
143        let pcs_proof = self.basefold_prover.prove_untrusted_evaluations(
144            stack_point,
145            mle_rounds,
146            batch_evaluations,
147            pcs_prover_data,
148            challenger,
149        )?;
150
151        let host_batch_evaluations = host_batch_evaluations
152            .into_iter()
153            .map(|round| round.into_iter().flatten().collect::<MleEval<_>>())
154            .collect::<Rounds<_>>();
155
156        Ok(StackedBasefoldProof {
157            basefold_proof: pcs_proof,
158            batch_evaluations: host_batch_evaluations,
159        })
160    }
161
162    fn log_max_padding_amount(&self) -> u32 {
163        self.log_stacking_height
164    }
165}
#[cfg(test)]
mod tests {
    use rand::thread_rng;
    use slop_algebra::extension::BinomialExtensionField;
    use slop_baby_bear::{baby_bear_poseidon2::BabyBearDegree4Duplex, BabyBear};
    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_basefold_prover::BasefoldProver;
    use slop_challenger::CanObserve;
    use slop_merkle_tree::Poseidon2BabyBear16Prover;
    use slop_tensor::Tensor;

    use crate::StackedPcsVerifier;

    use super::*;

    /// Commits to one round of multilinears of mixed shapes with the stacked prover.
    ///
    /// Checks that interpolating the per-MLE batch evaluations over the leading coordinates
    /// reproduces the evaluation of all the data concatenated into one MLE, then runs the
    /// prove/verify round trip.
    #[test]
    fn test_stacked_prover_with_fixed_rate_interleave() {
        type GC = BabyBearDegree4Duplex;
        type Prover = BasefoldProver<GC, Poseidon2BabyBear16Prover>;
        type EF = BinomialExtensionField<BabyBear, 4>;

        let log_stacking_height = 10;
        let batch_size = 10;

        // Per round: the (number of polynomials, log number of rows) of every multilinear.
        let shapes_per_round = [vec![(1 << 10, 10), (1 << 4, 11), (496, 11)]];

        // Unpadded number of field elements committed in each round.
        let raw_areas: Vec<usize> = shapes_per_round
            .iter()
            .map(|shapes| {
                shapes.iter().map(|&(width, log_height)| width << log_height).sum::<usize>()
            })
            .collect();

        let total_area: usize = raw_areas.iter().sum();
        let total_number_of_variables = total_area.next_power_of_two().ilog2();
        // The concatenated data must fill a hypercube exactly for the comparison below.
        assert_eq!(1 << total_number_of_variables, total_area);

        // The verifier is told each round's area padded to a multiple of the stacking height.
        let round_areas: Vec<usize> = raw_areas
            .iter()
            .map(|area| area.next_multiple_of(1 << log_stacking_height))
            .collect();

        let mut rng = thread_rng();
        let round_mles = shapes_per_round
            .iter()
            .map(|shapes| {
                shapes
                    .iter()
                    .map(|&(width, log_height)| Mle::<BabyBear>::rand(&mut rng, width, log_height))
                    .collect::<Message<_>>()
            })
            .collect::<Rounds<_>>();

        let pcs_verifier =
            BasefoldVerifier::<GC>::new(FriConfig::default_fri_config(), shapes_per_round.len());
        let pcs_prover = Prover::new(&pcs_verifier);

        let verifier = StackedPcsVerifier::new(pcs_verifier, log_stacking_height);
        let prover = StackedPcsProver::new(pcs_prover, log_stacking_height, batch_size);

        let mut prover_challenger = GC::default_challenger();
        let point = Point::<EF>::rand(&mut rng, total_number_of_variables);

        // Every multilinear of every round, each transposed before flattening, laid end to end
        // and viewed as a single one-polynomial MLE.
        let concatenated: Vec<BabyBear> = round_mles
            .iter()
            .flat_map(|mles| mles.iter())
            .flat_map(|mle| mle.guts().transpose().as_slice().to_vec())
            .collect();
        let concatenated =
            Mle::new(Tensor::from(concatenated).reshape([1 << total_number_of_variables, 1]));
        let expected_eval = concatenated.eval_at(&point)[0];

        let (batch_point, stack_point) =
            point.split_at(point.dimension() - log_stacking_height as usize);

        let mut commitments = vec![];
        let mut prover_data = Rounds::new();
        let mut batch_evaluations = Rounds::new();
        for mles in round_mles.iter() {
            let (commitment, data, _) = prover.commit_multilinears(mles.clone()).unwrap();
            prover_challenger.observe(commitment);
            commitments.push(commitment);
            batch_evaluations.push(prover.round_batch_evaluations(&stack_point, &data));
            prover_data.push(data);
        }

        // Interpolate the batch evaluations as a multilinear polynomial and verify that the
        // claimed evaluation matches the evaluation of the concatenated data.
        let batch_evaluations_mle =
            batch_evaluations.iter().flatten().flatten().cloned().collect::<Mle<_>>();
        let eval_claim = batch_evaluations_mle.eval_at(&batch_point)[0];
        assert_eq!(expected_eval, eval_claim);

        let proof = prover
            .prove_trusted_evaluation(
                point.clone(),
                eval_claim,
                prover_data,
                &mut prover_challenger,
            )
            .unwrap();

        // The verifier replays the transcript from a fresh challenger that observes the same
        // commitments in the same order.
        let mut verifier_challenger = GC::default_challenger();
        commitments.iter().for_each(|commitment| verifier_challenger.observe(*commitment));
        verifier
            .verify_trusted_evaluation(
                &commitments,
                &round_areas,
                &point,
                &proof,
                eval_claim,
                &mut verifier_challenger,
            )
            .unwrap();
    }
}