sp1_recursion_circuit/zerocheck.rs
1use std::{collections::BTreeSet, ops::Deref};
2
3use crate::{
4 challenger::{CanObserveVariable, FieldChallengerVariable},
5 shard::RecursiveShardVerifier,
6 sumcheck::verify_sumcheck,
7 symbolic::IntoSymbolic,
8 CircuitConfig, SP1FieldConfigVariable,
9};
10use itertools::Itertools;
11use slop_air::{Air, BaseAir};
12use slop_algebra::AbstractField;
13use slop_challenger::IopCtx;
14use slop_matrix::dense::RowMajorMatrixView;
15use slop_multilinear::{full_geq, Mle, Point};
16use slop_sumcheck::PartialSumcheckProof;
17use sp1_hypercube::{
18 air::MachineAir, Chip, ChipOpenedValues, GenericVerifierConstraintFolder, LogUpEvaluations,
19 OpeningShapeError, ShardOpenedValues,
20};
21use sp1_primitives::{SP1ExtensionField, SP1Field};
22use sp1_recursion_compiler::{
23 ir::Felt,
24 prelude::{Builder, Ext, SymbolicExt},
25};
26
27pub type RecursiveVerifierConstraintFolder<'a> = GenericVerifierConstraintFolder<
28 'a,
29 SP1Field,
30 SP1ExtensionField,
31 Felt<SP1Field>,
32 Ext<SP1Field, SP1ExtensionField>,
33 SymbolicExt<SP1Field, SP1ExtensionField>,
34>;
35
36#[allow(clippy::type_complexity)]
37pub fn eval_constraints<C: CircuitConfig, SC: SP1FieldConfigVariable<C>, A>(
38 builder: &mut Builder<C>,
39 chip: &Chip<SP1Field, A>,
40 opening: &ChipOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
41 alpha: Ext<SP1Field, SP1ExtensionField>,
42 public_values: &[Felt<SP1Field>],
43) -> Ext<SP1Field, SP1ExtensionField>
44where
45 A: MachineAir<SP1Field> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
46{
47 let mut folder = RecursiveVerifierConstraintFolder {
48 preprocessed: RowMajorMatrixView::new_row(&opening.preprocessed.local),
49 main: RowMajorMatrixView::new_row(&opening.main.local),
50 public_values,
51 alpha,
52 accumulator: SymbolicExt::zero(),
53 _marker: std::marker::PhantomData,
54 };
55
56 chip.eval(&mut folder);
57 builder.eval(folder.accumulator)
58}
59
60/// Compute the padded row adjustment for a chip.
61pub fn compute_padded_row_adjustment<C: CircuitConfig, A>(
62 builder: &mut Builder<C>,
63 chip: &Chip<SP1Field, A>,
64 alpha: Ext<SP1Field, SP1ExtensionField>,
65 public_values: &[Felt<SP1Field>],
66) -> Ext<SP1Field, SP1ExtensionField>
67where
68 A: MachineAir<SP1Field> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
69{
70 let zero = builder.constant(SP1ExtensionField::zero());
71 let dummy_preprocessed_trace = vec![zero; chip.preprocessed_width()];
72 let dummy_main_trace = vec![zero; chip.width()];
73
74 let mut folder = RecursiveVerifierConstraintFolder {
75 preprocessed: RowMajorMatrixView::new_row(&dummy_preprocessed_trace),
76 main: RowMajorMatrixView::new_row(&dummy_main_trace),
77 alpha,
78 accumulator: SymbolicExt::zero(),
79 public_values,
80 _marker: std::marker::PhantomData,
81 };
82
83 chip.eval(&mut folder);
84 builder.eval(folder.accumulator)
85}
86
87#[allow(clippy::type_complexity)]
88pub fn verify_opening_shape<C: CircuitConfig, A>(
89 chip: &Chip<SP1Field, A>,
90 opening: &ChipOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
91) -> Result<(), OpeningShapeError>
92where
93 A: MachineAir<SP1Field> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
94{
95 // Verify that the preprocessed width matches the expected value for the chip.
96 if opening.preprocessed.local.len() != chip.preprocessed_width() {
97 return Err(OpeningShapeError::PreprocessedWidthMismatch(
98 chip.preprocessed_width(),
99 opening.preprocessed.local.len(),
100 ));
101 }
102
103 // Verify that the main width matches the expected value for the chip.
104 if opening.main.local.len() != chip.width() {
105 return Err(OpeningShapeError::MainWidthMismatch(chip.width(), opening.main.local.len()));
106 }
107
108 Ok(())
109}
110
111impl<GC, C, A> RecursiveShardVerifier<GC, A, C>
112where
113 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
114 C: CircuitConfig,
115 A: MachineAir<SP1Field>,
116{
117 #[allow(clippy::too_many_arguments)]
118 #[allow(clippy::type_complexity)]
119 pub fn verify_zerocheck(
120 &self,
121 builder: &mut Builder<C>,
122 shard_chips: &BTreeSet<Chip<SP1Field, A>>,
123 opened_values: &ShardOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
124 gkr_evaluations: &LogUpEvaluations<Ext<SP1Field, SP1ExtensionField>>,
125 zerocheck_proof: &PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
126 public_values: &[Felt<SP1Field>],
127 challenger: &mut GC::FriChallengerVariable,
128 ) where
129 A: for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
130 {
131 let zero: Ext<SP1Field, SP1ExtensionField> = builder.constant(SP1ExtensionField::zero());
132 let one: Ext<SP1Field, SP1ExtensionField> = builder.constant(SP1ExtensionField::one());
133 let mut rlc_eval: Ext<SP1Field, SP1ExtensionField> = zero;
134
135 let alpha = challenger.sample_ext(builder);
136 let gkr_batch_open_challenge: SymbolicExt<SP1Field, SP1ExtensionField> =
137 challenger.sample_ext(builder).into();
138 let lambda = challenger.sample_ext(builder);
139
140 // Get the value of eq(zeta, sumcheck's reduced point).
141 let point_symbolic =
142 <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(
143 &zerocheck_proof.point_and_eval.0,
144 );
145
146 let gkr_evaluations_point = IntoSymbolic::<C>::as_symbolic(&gkr_evaluations.point);
147
148 let zerocheck_eq_val = Mle::full_lagrange_eval(&gkr_evaluations_point, &point_symbolic);
149
150 let max_elements = shard_chips
151 .iter()
152 .map(|chip| chip.width() + chip.preprocessed_width())
153 .max()
154 .unwrap_or(0);
155
156 let gkr_batch_open_challenge_powers =
157 gkr_batch_open_challenge.powers().skip(1).take(max_elements).collect::<Vec<_>>();
158
159 for (chip, openings) in shard_chips.iter().zip_eq(opened_values.chips.values()) {
160 // Verify the shape of the opening arguments matches the expected values.
161 verify_opening_shape::<C, A>(chip, openings).unwrap();
162
163 let dimension = zerocheck_proof.point_and_eval.0.dimension();
164
165 assert_eq!(dimension, self.pcs_verifier.max_log_row_count);
166
167 let mut proof_point_extended = point_symbolic.clone();
168 proof_point_extended.add_dimension(zero.into());
169 let degree_symbolic_ext: Point<SymbolicExt<SP1Field, SP1ExtensionField>> =
170 openings.degree.iter().map(|x| SymbolicExt::from(*x)).collect::<Point<_>>();
171 degree_symbolic_ext.iter().enumerate().for_each(|(i, x)| {
172 builder.assert_ext_eq(*x * (*x - one), zero);
173 if i >= 1 {
174 builder.assert_ext_eq(*x * *degree_symbolic_ext.first().unwrap(), zero);
175 }
176 });
177 let geq_val = full_geq(°ree_symbolic_ext, &proof_point_extended);
178
179 let padded_row_adjustment =
180 compute_padded_row_adjustment(builder, chip, alpha, public_values);
181
182 let constraint_eval =
183 eval_constraints::<C, GC, A>(builder, chip, openings, alpha, public_values)
184 - padded_row_adjustment * geq_val;
185
186 let openings_batch = openings
187 .main
188 .local
189 .iter()
190 .chain(openings.preprocessed.local.iter())
191 .copied()
192 .zip(
193 gkr_batch_open_challenge_powers
194 .iter()
195 .take(openings.main.local.len() + openings.preprocessed.local.len())
196 .copied(),
197 )
198 .map(|(opening, power)| opening * power)
199 .sum::<SymbolicExt<SP1Field, SP1ExtensionField>>();
200
201 rlc_eval = builder
202 .eval(rlc_eval * lambda + zerocheck_eq_val * (constraint_eval + openings_batch));
203 }
204
205 builder.assert_ext_eq(rlc_eval, zerocheck_proof.point_and_eval.1);
206
207 let zerocheck_sum_modifications_from_gkr = gkr_evaluations
208 .chip_openings
209 .values()
210 .map(|chip_evaluation| {
211 chip_evaluation
212 .main_trace_evaluations
213 .deref()
214 .iter()
215 .copied()
216 .chain(
217 chip_evaluation
218 .preprocessed_trace_evaluations
219 .as_ref()
220 .iter()
221 .flat_map(|&evals| evals.deref().iter().copied()),
222 )
223 .zip(gkr_batch_open_challenge_powers.iter().copied())
224 .map(|(opening, power)| opening * power)
225 .sum::<SymbolicExt<SP1Field, SP1ExtensionField>>()
226 })
227 .collect::<Vec<_>>();
228
229 let zerocheck_sum_modification: SymbolicExt<SP1Field, SP1ExtensionField> =
230 zerocheck_sum_modifications_from_gkr
231 .iter()
232 .fold(zero.into(), |acc, modification| lambda * acc + *modification);
233
234 // Verify that the rlc claim is zero.
235 builder.assert_ext_eq(zerocheck_proof.claimed_sum, zerocheck_sum_modification);
236
237 // Verify the zerocheck proof.
238 verify_sumcheck::<C, GC>(builder, challenger, zerocheck_proof);
239
240 // Observe the openings
241 let len_felt: Felt<_> = builder.constant(SP1Field::from_canonical_usize(shard_chips.len()));
242 challenger.observe(builder, len_felt);
243 for opening in opened_values.chips.values() {
244 challenger
245 .observe_variable_length_extension_slice(builder, &opening.preprocessed.local);
246 challenger.observe_variable_length_extension_slice(builder, &opening.main.local);
247 }
248 }
249}
250
251// TODO: Add tests back.
252// #[cfg(test)]
253// mod tests {
254// use std::{marker::PhantomData, sync::Arc};
255
256// use slop_algebra::extension::BinomialExtensionField;
257// use sp1_primitives::SP1DiffusionMatrix;
258// use slop_basefold::{BasefoldVerifier, SP1BasefoldConfig};
259// use slop_jagged::SP1InnerPcs;
260// use sp1_hypercube::inner_perm;
261// use sp1_core_executor::{Program, SP1Context};
262// use sp1_core_machine::{io::SP1Stdin, riscv::RiscvAir, utils::prove_core};
263// use sp1_recursion_compiler::{
264// circuit::{AsmCompiler, AsmConfig},
265// config::InnerConfig,
266// };
267// use sp1_recursion_executor::Runtime;
268// use sp1_hypercube::{prover::CpuProver, SP1CoreOpts, ShardVerifier};
269
270// use crate::{
271// basefold::{stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs},
272// challenger::DuplexChallengerVariable,
273// jagged::{
274// RecursiveJaggedConfigImpl, RecursiveJaggedEvalSumcheckConfig,
275// RecursiveJaggedPcsVerifier,
276// },
277// witness::Witnessable,
278// };
279
280// use super::*;
281
282// use sp1_primitives::SP1Field;
283// type F = SP1Field;
284// type SC = SP1InnerPcs;
285// type JC = RecursiveJaggedConfigImpl<
286// C,
287// SC,
288// RecursiveBasefoldVerifier<RecursiveBasefoldConfigImpl<C, SC>>,
289// >;
290// type C = InnerConfig;
291// type EF = BinomialExtensionField<SP1Field, 4>;
292// type A = RiscvAir<SP1Field>;
293
294// #[tokio::test]
295// async fn test_zerocheck() {
296// let program = Program::from(test_artifacts::FIBONACCI_ELF).unwrap();
297// let log_blowup = 1;
298// let log_stacking_height = 21;
299// let max_log_row_count = 21;
300// let machine = RiscvAir::machine();
301// let verifier = ShardVerifier::from_basefold_parameters(
302// log_blowup,
303// log_stacking_height,
304// max_log_row_count,
305// machine.clone(),
306// );
307// let prover = CpuProver::new(verifier.clone());
308
309// let (pk, _) = prover.setup(Arc::new(program.clone())).await;
310
311// let challenger = verifier.pcs_verifier.challenger();
312
313// let (proof, _) = prove_core(
314// Arc::new(prover),
315// Arc::new(pk),
316// Arc::new(program.clone()),
317// &SP1Stdin::new(),
318// SP1CoreOpts::default(),
319// SP1Context::default(),
320// challenger,
321// )
322// .await
323// .unwrap();
324
325// let shard_proof = proof.shard_proofs[0].clone();
326// let challenger_state = shard_proof.testing_data.challenger_state.clone();
327
328// let mut builder = Builder::<C>::default();
329
330// let mut challenger_variable =
331// DuplexChallengerVariable::from_challenger(&mut builder, &challenger_state);
332
333// let shard_proof_variable = shard_proof.read(&mut builder);
334
335// let gkr_points_variable = shard_proof.testing_data.gkr_points.read(&mut builder);
336// let gkr_column_openings_variable = shard_proof
337// .gkr_proofs
338// .iter()
339// .map(|gkr_proof| {
340// let (main_openings, preprocessed_openings) = &gkr_proof.column_openings;
341// let main_openings_variable = main_openings.read(&mut builder);
342// let preprocessed_openings_variable: MleEval<Ext<_, _>> = preprocessed_openings
343// .as_ref()
344// .map(MleEval::to_vec)
345// .unwrap_or_default()
346// .read(&mut builder)
347// .into();
348// (main_openings_variable, preprocessed_openings_variable)
349// })
350// .collect::<Vec<_>>();
351
352// let verifier = BasefoldVerifier::<SP1BasefoldConfig>::new(log_blowup);
353// let recursive_verifier = RecursiveBasefoldVerifier::<RecursiveBasefoldConfigImpl<C, SC>>
354// { fri_config: verifier.fri_config,
355// tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
356// };
357// let recursive_verifier =
358// RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);
359
360// let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<
361// SC,
362// C,
363// RecursiveJaggedConfigImpl<
364// C,
365// SC,
366// RecursiveBasefoldVerifier<RecursiveBasefoldConfigImpl<C, SC>>,
367// >,
368// > { stacked_pcs_verifier: recursive_verifier, max_log_row_count, jagged_evaluator:
369// > RecursiveJaggedEvalSumcheckConfig::<SP1InnerPcs>(PhantomData),
370// };
371
372// let stark_verifier = StarkVerifier::<A, SC, C, JC> {
373// machine,
374// pcs_verifier: recursive_jagged_verifier,
375// _phantom: std::marker::PhantomData,
376// };
377
378// stark_verifier.verify_zerocheck(
379// &mut builder,
380// &mut challenger_variable,
381// &shard_proof_variable.opened_values,
382// &shard_proof_variable.zerocheck_proof,
383// &gkr_points_variable,
384// &gkr_column_openings_variable,
385// &shard_proof_variable.public_values,
386// );
387
388// let mut witness_stream = Vec::new();
389// Witnessable::<AsmConfig<F, EF>>::write(&shard_proof, &mut witness_stream);
390// Witnessable::<AsmConfig<F, EF>>::write(
391// &shard_proof.testing_data.gkr_points,
392// &mut witness_stream,
393// );
394// shard_proof.gkr_proofs.iter().for_each(|gkr_proof| {
395// let (main_openings, preprocessed_openings) = &gkr_proof.column_openings;
396// Witnessable::<AsmConfig<F, EF>>::write(main_openings, &mut witness_stream);
397// let preprocessed_openings_unwrapped: MleEval<_> =
398// preprocessed_openings.as_ref().map(MleEval::to_vec).unwrap_or_default().into();
399// Witnessable::<AsmConfig<F, EF>>::write(
400// &preprocessed_openings_unwrapped,
401// &mut witness_stream,
402// );
403// });
404
405// let block = builder.into_root_block();
406// let mut compiler = AsmCompiler::<AsmConfig<F, EF>>::default();
407// let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
408// let mut executor =
409// Runtime::<F, EF, SP1DiffusionMatrix>::new(program.clone(), inner_perm());
410// executor.witness_stream = witness_stream.into();
411// executor.run().unwrap();
412
413// // Test for a bad zerocheck proof.
414// let mut invalid_shard_proof = shard_proof.clone();
415// invalid_shard_proof.zerocheck_proof.univariate_polys[0].coefficients[0] += EF::one();
416// let mut witness_stream = Vec::new();
417// Witnessable::<AsmConfig<F, EF>>::write(&invalid_shard_proof, &mut witness_stream);
418// Witnessable::<AsmConfig<F, EF>>::write(
419// &invalid_shard_proof.testing_data.gkr_points,
420// &mut witness_stream,
421// );
422// invalid_shard_proof.gkr_proofs.iter().for_each(|gkr_proof| {
423// let (main_openings, preprocessed_openings) = &gkr_proof.column_openings;
424// Witnessable::<AsmConfig<F, EF>>::write(main_openings, &mut witness_stream);
425// let preprocessed_openings_unwrapped: MleEval<_> =
426// preprocessed_openings.as_ref().map(MleEval::to_vec).unwrap_or_default().into();
427// Witnessable::<AsmConfig<F, EF>>::write(
428// &preprocessed_openings_unwrapped,
429// &mut witness_stream,
430// );
431// });
432// let mut executor = Runtime::<F, EF, SP1DiffusionMatrix>::new(program,
433// inner_perm()); executor.witness_stream = witness_stream.into();
434// executor.run().expect_err("invalid proof should not be verified");
435// }
436// }