1use crate::{
2 basefold::RecursiveBasefoldProof,
3 challenger::CanObserveVariable,
4 jagged::{
5 JaggedPcsProofVariable, RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier,
6 },
7 logup_gkr::RecursiveLogUpGkrVerifier,
8 zerocheck::RecursiveVerifierConstraintFolder,
9 CircuitConfig, SP1FieldConfigVariable,
10};
11use slop_air::Air;
12use slop_algebra::AbstractField;
13use slop_challenger::IopCtx;
14use slop_commit::Rounds;
15use slop_multilinear::{Evaluations, MleEval};
16use slop_sumcheck::PartialSumcheckProof;
17
18use sp1_hypercube::{
19 air::MachineAir, septic_digest::SepticDigest, GenericVerifierPublicValuesConstraintFolder,
20 LogupGkrProof, Machine, ShardOpenedValues,
21};
22use sp1_primitives::{SP1ExtensionField, SP1Field};
23use sp1_recursion_compiler::{
24 circuit::CircuitV2Builder,
25 ir::{Builder, Felt, SymbolicExt},
26 prelude::{Ext, SymbolicFelt},
27};
28use sp1_recursion_executor::{DIGEST_SIZE, NUM_BITS};
29use std::collections::{BTreeMap, BTreeSet};
30
31#[allow(clippy::type_complexity)]
32pub struct ShardProofVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C> + Send + Sync> {
33 pub main_commitment: SC::DigestVariable,
35 pub opened_values: ShardOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
37 pub zerocheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
39 pub public_values: Vec<Felt<SP1Field>>,
41 pub logup_gkr_proof: LogupGkrProof<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
43 pub evaluation_proof: JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
45}
46
47pub struct MachineVerifyingKeyVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C>> {
48 pub pc_start: [Felt<SP1Field>; 3],
49 pub initial_global_cumulative_sum: SepticDigest<Felt<SP1Field>>,
51 pub preprocessed_commit: SC::DigestVariable,
53 pub enable_untrusted_programs: Felt<SP1Field>,
55}
56impl<C, SC> MachineVerifyingKeyVariable<C, SC>
57where
58 C: CircuitConfig,
59 SC: SP1FieldConfigVariable<C>,
60{
61 pub fn hash(&self, builder: &mut Builder<C>) -> SC::DigestVariable
65 where
66 SC::DigestVariable: IntoIterator<Item = Felt<SP1Field>>,
67 {
68 let num_inputs = DIGEST_SIZE + 3 + 14 + 1;
69 let mut inputs = Vec::with_capacity(num_inputs);
70 inputs.extend(self.preprocessed_commit);
71 inputs.extend(self.pc_start);
72 inputs.extend(self.initial_global_cumulative_sum.0.x.0);
73 inputs.extend(self.initial_global_cumulative_sum.0.y.0);
74 inputs.push(self.enable_untrusted_programs);
75
76 SC::hash(builder, &inputs)
77 }
78}
79
80pub struct RecursiveShardVerifier<
82 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
83 A: MachineAir<SP1Field>,
84 C: CircuitConfig,
85> {
86 pub machine: Machine<SP1Field, A>,
88 pub pcs_verifier: RecursiveJaggedPcsVerifier<GC, C>,
90 pub _phantom: std::marker::PhantomData<(GC, C, A)>,
91}
92
93impl<GC, C, A> RecursiveShardVerifier<GC, A, C>
94where
95 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
96 A: MachineAir<SP1Field>,
97 C: CircuitConfig,
98{
99 pub fn verify_shard(
100 &self,
101 builder: &mut Builder<C>,
102 vk: &MachineVerifyingKeyVariable<C, GC>,
103 proof: &ShardProofVariable<C, GC>,
104 challenger: &mut GC::FriChallengerVariable,
105 ) where
106 A: for<'b> Air<RecursiveVerifierConstraintFolder<'b>>,
107 {
108 let ShardProofVariable {
109 main_commitment,
110 opened_values,
111 evaluation_proof,
112 zerocheck_proof,
113 public_values,
114 logup_gkr_proof,
115 } = proof;
116
117 let heights = opened_values
119 .chips
120 .iter()
121 .map(|(name, x)| (name.clone(), x.degree.clone()))
122 .collect::<BTreeMap<_, _>>();
123 let mut height_felts_map: BTreeMap<String, Felt<SP1Field>> = BTreeMap::new();
124 let two = SymbolicFelt::from_canonical_u32(2);
125 for (name, height) in &heights {
126 let mut acc = SymbolicFelt::zero();
127 assert!(height.len() == self.pcs_verifier.max_log_row_count + 1);
129 height.iter().for_each(|x| {
130 acc = *x + two * acc;
131 });
132 height_felts_map.insert(name.clone(), builder.eval(acc));
133 }
134
135 challenger.observe_slice(builder, public_values.to_vec());
137
138 for value in public_values[self.machine.num_pv_elts()..].iter() {
139 builder.assert_felt_eq(value, GC::F::zero());
140 }
141
142 challenger.observe(builder, *main_commitment);
144 let num_chips: Felt<GC::F> = builder.eval(GC::F::from_canonical_usize(heights.len()));
145 challenger.observe(builder, num_chips);
147
148 for (name, height) in height_felts_map.iter() {
149 challenger.observe(builder, *height);
150 let mut inputs: Vec<Felt<GC::F>> = vec![];
151 inputs.push(builder.eval(GC::F::from_canonical_usize(name.len())));
152 for byte in name.as_bytes() {
153 inputs.push(builder.eval(GC::F::from_canonical_u8(*byte)));
154 }
155 challenger.observe_slice(builder, inputs);
156 }
157
158 let shard_chips = self
159 .machine
160 .chips()
161 .iter()
162 .filter(|chip| heights.contains_key(chip.name()))
163 .cloned()
164 .collect::<BTreeSet<_>>();
165
166 let degrees = opened_values.chips.values().map(|x| x.degree.clone()).collect::<Vec<_>>();
167
168 let max_log_row_count = self.pcs_verifier.max_log_row_count;
169
170 builder.cycle_tracker_v2_enter("verify-logup-gkr");
172 RecursiveLogUpGkrVerifier::<C, GC, A>::verify_logup_gkr(
173 builder,
174 &shard_chips,
175 °rees,
176 max_log_row_count,
177 logup_gkr_proof,
178 public_values,
179 challenger,
180 );
181 builder.cycle_tracker_v2_exit();
182
183 builder.cycle_tracker_v2_enter("verify-zerocheck");
185 self.verify_zerocheck(
186 builder,
187 &shard_chips,
188 opened_values,
189 &logup_gkr_proof.logup_evaluations,
190 zerocheck_proof,
191 public_values,
192 challenger,
193 );
194 builder.cycle_tracker_v2_exit();
195
196 let (preprocessed_openings_for_proof, main_openings_for_proof): (Vec<_>, Vec<_>) = proof
198 .opened_values
199 .chips
200 .values()
201 .map(|opening| (opening.preprocessed.clone(), opening.main.clone()))
202 .unzip();
203
204 let preprocessed_openings = preprocessed_openings_for_proof
205 .iter()
206 .map(|x| x.local.iter().as_slice())
207 .collect::<Vec<_>>();
208
209 let main_openings = main_openings_for_proof
210 .iter()
211 .map(|x| x.local.iter().copied().collect::<MleEval<_>>())
212 .collect::<Evaluations<_>>();
213
214 let filtered_preprocessed_openings = preprocessed_openings
215 .clone()
216 .into_iter()
217 .filter(|x| !x.is_empty())
218 .map(|x| x.iter().copied().collect::<MleEval<_>>())
219 .collect::<Evaluations<_>>();
220
221 let preprocessed_column_count = filtered_preprocessed_openings
222 .iter()
223 .map(|table_openings| table_openings.len())
224 .collect::<Vec<_>>();
225
226 let added_columns: Vec<usize> =
227 proof.evaluation_proof.column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
228
229 let unfiltered_preprocessed_column_count = preprocessed_openings
230 .iter()
231 .map(|table_openings| table_openings.len())
232 .chain(std::iter::once(added_columns[0] - 1))
233 .collect::<Vec<_>>();
234
235 let main_column_count =
236 main_openings.iter().map(|table_openings| table_openings.len()).collect::<Vec<_>>();
237
238 let unfiltered_main_column_count = main_openings
239 .iter()
240 .map(|table_openings| table_openings.len())
241 .chain(std::iter::once(added_columns[1] - 1))
242 .collect::<Vec<_>>();
243
244 let (commitments, column_counts, unfiltered_column_counts, openings) = (
245 vec![vk.preprocessed_commit, *main_commitment],
246 vec![preprocessed_column_count, main_column_count.clone()],
247 vec![unfiltered_preprocessed_column_count, unfiltered_main_column_count],
248 Rounds { rounds: vec![filtered_preprocessed_openings, main_openings] },
249 );
250
251 let machine_jagged_verifier =
252 RecursiveMachineJaggedPcsVerifier::new(&self.pcs_verifier, column_counts.clone());
253
254 let openings = openings
255 .into_iter()
256 .map(|round| {
257 round
258 .into_iter()
259 .flat_map(std::iter::IntoIterator::into_iter)
260 .collect::<MleEval<_>>()
261 })
262 .collect::<Vec<_>>();
263
264 builder.cycle_tracker_v2_enter("jagged-verifier");
265 let prefix_sum_felts = machine_jagged_verifier.verify_trusted_evaluations(
266 builder,
267 &commitments,
268 zerocheck_proof.point_and_eval.0.clone(),
269 &openings,
270 evaluation_proof,
271 challenger,
272 );
273 builder.cycle_tracker_v2_exit();
274
275 let row_count_felt: Felt<_> = builder
276 .constant(SP1Field::from_canonical_u32(1 << self.pcs_verifier.max_log_row_count));
277
278 let params: Vec<Vec<Felt<SP1Field>>> = unfiltered_column_counts
279 .iter()
280 .map(|round| {
281 round
282 .iter()
283 .copied()
284 .zip(height_felts_map.values().copied().chain(std::iter::once(row_count_felt)))
285 .flat_map(|(column_count, height)| {
286 std::iter::repeat_n(height, column_count).collect::<Vec<_>>()
287 })
288 .collect::<Vec<_>>()
289 })
290 .collect();
291
292 let preprocessed_count = params[0].len();
293 let params = params.into_iter().flatten().collect::<Vec<_>>();
294
295 builder.cycle_tracker_v2_enter("jagged - prefix-sum-checks");
296 let mut param_index = 0;
297 let skip_indices = [preprocessed_count];
302
303 prefix_sum_felts
304 .iter()
305 .zip(prefix_sum_felts.iter().skip(1))
306 .enumerate()
307 .filter(|(i, _)| !skip_indices.contains(i))
308 .for_each(|(_, (x, y))| {
309 let sum = *x + params[param_index];
310 builder.assert_felt_eq(sum, *y);
311 param_index += 1;
312 });
313
314 builder.assert_felt_eq(prefix_sum_felts[0], SP1Field::zero());
315
316 builder.assert_felt_eq(
318 prefix_sum_felts[skip_indices[0] + 1],
319 SP1Field::from_canonical_usize(
320 (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
321 * evaluation_proof.pcs_proof.batch_evaluations.rounds[0].num_polynomials(),
322 ),
323 );
324
325 let preprocessed_padding_col_height =
326 builder.eval(prefix_sum_felts[skip_indices[0] + 1] - prefix_sum_felts[skip_indices[0]]);
327 let preprocessed_padding_col_bit_decomp = C::num2bits(
328 builder,
329 preprocessed_padding_col_height,
330 self.pcs_verifier.max_log_row_count + 1,
331 );
332
333 let max_bit = preprocessed_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
344 let max_bit = C::bits2num(builder, vec![max_bit]);
345 let zero: Felt<_> = builder.constant(SP1Field::zero());
346 for bit in
347 preprocessed_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count)
348 {
349 let bit_felt = C::bits2num(builder, vec![*bit]);
350 builder.assert_felt_eq(max_bit * bit_felt, zero);
351 }
352 let num_cols = prefix_sum_felts.len();
353
354 let main_padding_col_height =
356 builder.eval(prefix_sum_felts[num_cols - 1] - prefix_sum_felts[num_cols - 2]);
357
358 let main_padding_col_bit_decomp = C::num2bits(builder, main_padding_col_height, NUM_BITS);
359
360 let max_bit = main_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
361 let max_bit = C::bits2num(builder, vec![max_bit]);
362 for bit in main_padding_col_bit_decomp.iter().skip(self.pcs_verifier.max_log_row_count + 1)
363 {
364 C::assert_bit_zero(builder, *bit);
365 }
366 for bit in main_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count) {
367 let bit_felt = C::bits2num(builder, vec![*bit]);
368 builder.assert_felt_eq(max_bit * bit_felt, zero);
369 }
370
371 let total_area_felt: Felt<_> = builder.constant(SP1Field::from_canonical_usize(
373 (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
374 * proof
375 .evaluation_proof
376 .pcs_proof
377 .batch_evaluations
378 .iter()
379 .map(|evaluations| evaluations.num_polynomials())
380 .sum::<usize>(),
381 ));
382
383 let mut acc = SymbolicFelt::zero();
385 proof.evaluation_proof.params.col_prefix_sums.iter().last().unwrap().iter().for_each(|x| {
387 acc = *x + two * acc;
388 });
389
390 builder.assert_felt_eq(acc, total_area_felt);
392
393 builder.cycle_tracker_v2_exit();
394 }
395}
396
/// [`GenericVerifierPublicValuesConstraintFolder`] specialized to the recursion circuit's types:
/// `SP1Field` / `SP1ExtensionField` as the fields, and `Felt`, `Ext` and `SymbolicExt` over them
/// as the in-circuit variable and expression types.
pub type RecursiveVerifierPublicValuesConstraintFolder<'a> =
    GenericVerifierPublicValuesConstraintFolder<
        'a,
        SP1Field,
        SP1ExtensionField,
        Felt<SP1Field>,
        Ext<SP1Field, SP1ExtensionField>,
        SymbolicExt<SP1Field, SP1ExtensionField>,
    >;
406
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use slop_basefold::{BasefoldVerifier, FriConfig};
    use sp1_core_executor::{Program, SP1Context, SP1CoreOpts};
    use sp1_core_machine::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{prove_core, setup_logger},
    };
    use sp1_hypercube::{
        prover::{CpuShardProver, SP1InnerPcsProver, SimpleProver},
        MachineVerifier, SP1InnerPcs, ShardVerifier, NUM_SP1_COMMITMENTS,
    };
    use sp1_recursion_compiler::{
        circuit::{AsmCompiler, AsmConfig},
        config::InnerConfig,
    };
    use sp1_recursion_machine::test::run_recursion_test_machines;

    use crate::{
        basefold::{stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs},
        challenger::DuplexChallengerVariable,
        dummy::dummy_shard_proof,
        jagged::RecursiveJaggedEvalSumcheckConfig,
        witness::Witnessable,
    };

    use super::*;

    use sp1_primitives::{SP1Field, SP1GlobalContext};
    type GC = SP1GlobalContext;
    type C = InnerConfig;
    type A = RiscvAir<SP1Field>;

    /// End-to-end check: prove the Fibonacci program natively, build the recursive shard
    /// verifier circuit from a dummy proof of the same shape, then run the compiled circuit on
    /// the real proof as witness.
    #[tokio::test]
    async fn test_verify_shard() {
        setup_logger();
        let log_stacking_height = 21;
        let max_log_row_count = 22;
        let machine = RiscvAir::machine();
        let shard_verifier = ShardVerifier::from_basefold_parameters(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count,
            machine.clone(),
        );

        // Produce a real core proof for the Fibonacci ELF.
        let fibonacci_elf = test_artifacts::FIBONACCI_ELF;
        let program = Arc::new(Program::from(&fibonacci_elf).unwrap());
        let shard_prover =
            CpuShardProver::<SP1GlobalContext, SP1InnerPcs, SP1InnerPcsProver, _>::new(
                shard_verifier.clone(),
            );
        let prover = SimpleProver::new(shard_verifier.clone(), shard_prover);

        let (pk, vk) = prover.setup(program.clone()).await;
        // NOTE(review): the safety contract of `into_inner` is not visible from this file.
        let pk = unsafe { pk.into_inner() };
        let (proof, _) = prove_core(
            &prover,
            pk,
            program,
            SP1Stdin::default(),
            SP1CoreOpts::default(),
            SP1Context::default(),
        )
        .await
        .unwrap();

        // Native challenger that has observed the verifying key; the circuit challenger is
        // created from it further down.
        let mut initial_challenger = shard_verifier.jagged_pcs_verifier.challenger();
        vk.observe_into(&mut initial_challenger);

        // Check the proof natively before building the circuit.
        let machine_verifier = MachineVerifier::new(shard_verifier);
        machine_verifier.verify(&vk, &proof).unwrap();

        let real_shard_proof = proof.shard_proofs[0].clone();
        let proof_shape = machine_verifier.shape_from_proof(&real_shard_proof);

        // The circuit is built from a dummy proof with the real proof's shape.
        let placeholder_proof = dummy_shard_proof(
            proof_shape.shard_chips,
            max_log_row_count,
            FriConfig::default_fri_config(),
            log_stacking_height as usize,
            &[proof_shape.preprocessed_multiple, proof_shape.main_multiple],
            &[proof_shape.preprocessed_padding_cols, proof_shape.main_padding_cols],
        );

        // Read order (vk, then proof) must match the witness write order below.
        let mut builder = Builder::<C>::default();
        let vk_var = vk.read(&mut builder);
        let proof_var = placeholder_proof.read(&mut builder);

        // Assemble the recursive verifier stack: basefold -> stacked -> jagged -> shard.
        let basefold_verifier =
            BasefoldVerifier::<GC>::new(FriConfig::default_fri_config(), NUM_SP1_COMMITMENTS);
        let recursive_basefold = crate::basefold::RecursiveBasefoldVerifier::<C, GC> {
            fri_config: basefold_verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, GC>(PhantomData),
        };
        let recursive_stacked =
            RecursiveStackedPcsVerifier::new(recursive_basefold, log_stacking_height);

        let recursive_jagged = RecursiveJaggedPcsVerifier::<GC, C> {
            stacked_pcs_verifier: recursive_stacked,
            max_log_row_count,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<GC>(PhantomData),
        };

        let recursive_shard_verifier = RecursiveShardVerifier::<GC, A, C> {
            machine,
            pcs_verifier: recursive_jagged,
            _phantom: PhantomData,
        };

        let mut challenger_var =
            DuplexChallengerVariable::from_challenger(&mut builder, &initial_challenger);

        builder.cycle_tracker_v2_enter("verify-shard");
        recursive_shard_verifier.verify_shard(
            &mut builder,
            &vk_var,
            &proof_var,
            &mut challenger_var,
        );
        builder.cycle_tracker_v2_exit();

        // Compile the circuit and run it on the real proof.
        let root_block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let recursion_program = compiler.compile_inner(root_block).validate().unwrap();

        let mut witness = Vec::new();
        Witnessable::<AsmConfig>::write(&vk, &mut witness);
        Witnessable::<AsmConfig>::write(&real_shard_proof, &mut witness);

        run_recursion_test_machines(recursion_program.clone(), witness).await;
    }
}