use crate::septic_curve::SepticCurve;
use crate::septic_digest::SepticDigest;
use crate::septic_extension::SepticExtension;
use crate::{air::InteractionScope, AirOpenedValues, ChipOpenedValues, ShardOpenedValues};
use core::fmt::Display;
use itertools::Itertools;
use p3_air::Air;
use p3_challenger::{CanObserve, FieldChallenger};
use p3_commit::{Pcs, PolynomialSpace};
use p3_field::{AbstractExtensionField, AbstractField, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_maybe_rayon::prelude::*;
use p3_uni_stark::SymbolicAirBuilder;
use p3_util::log2_strict_usize;
use serde::{de::DeserializeOwned, Serialize};
use std::{cmp::Reverse, error::Error, time::Instant};

use super::{
    quotient_values, Com, OpeningProof, StarkGenericConfig, StarkMachine, StarkProvingKey, Val,
    VerifierConstraintFolder,
};
use crate::{
    air::MachineAir, lookup::InteractionBuilder, opts::SP1CoreOpts, record::MachineRecord,
    Challenger, DebugConstraintBuilder, MachineChip, MachineProof, PackedChallenge, PcsProverData,
    ProverConstraintFolder, ShardCommitment, ShardMainData, ShardProof, StarkVerifyingKey,
};

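/// A prover for a [`StarkMachine`] over any [`MachineAir`].
///
/// Implementations choose the device-specific matrix, prover-data, and proving-key types; the
/// trait supplies the shared flow of generating traces, committing to shards, and opening proofs.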
pub trait MachineProver<SC: StarkGenericConfig, A: MachineAir<SC::Val>>:
    'static + Send + Sync
{
    /// The matrix type used to store traces on the proving device.
    type DeviceMatrix: Matrix<SC::Val>;

    /// The prover data produced by the polynomial commitment scheme on the device.
    type DeviceProverData;

    /// The proving key type stored on the device.
    type DeviceProvingKey: MachineProvingKey<SC>;

    /// The error type returned by the prover.
    type Error: Error + Send + Sync;

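    /// Creates a prover for the given machine.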
    fn new(machine: StarkMachine<SC, A>) -> Self;

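    /// Returns a reference to the machine this prover uses.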
    fn machine(&self) -> &StarkMachine<SC, A>;

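    /// Computes the proving and verifying keys for the given program.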
    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>);

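    /// Builds a device proving key for the program from its verifying key.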
    fn pk_from_vk(
        &self,
        program: &A::Program,
        vk: &StarkVerifyingKey<SC>,
    ) -> Self::DeviceProvingKey;

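    /// Copies a host proving key to the device.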
    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey;

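    /// Copies a device proving key back to the host.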
    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC>;

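    /// Generates the main trace for every chip used by the given record.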
    fn generate_traces(&self, record: &A::Record) -> Vec<(String, RowMajorMatrix<Val<SC>>)> {
        let shard_chips = self.shard_chips(record).collect::<Vec<_>>();

        let parent_span = tracing::debug_span!("generate traces for shard");
        parent_span.in_scope(|| {
            shard_chips
                .par_iter()
                .map(|chip| {
                    let chip_name = chip.name();
                    let begin = Instant::now();
                    let trace = chip.generate_trace(record, &mut A::Record::default());
                    tracing::debug!(
                        parent: &parent_span,
                        "generated trace for chip {} in {:?}",
                        chip_name,
                        begin.elapsed()
                    );
                    (chip_name, trace)
                })
                .collect::<Vec<_>>()
        })
    }

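    /// Commits to the main traces of a shard.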
    fn commit(
        &self,
        record: &A::Record,
        traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>;

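    /// Observes a main commitment and the shard's public values in the challenger.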
    fn observe(
        &self,
        challenger: &mut SC::Challenger,
        commitment: Com<SC>,
        public_values: &[SC::Val],
    ) {
        challenger.observe(commitment);

        challenger.observe_slice(public_values);
    }

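    /// Opens the committed main data for a shard and produces a [`ShardProof`].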
    fn open(
        &self,
        pk: &Self::DeviceProvingKey,
        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
        challenger: &mut SC::Challenger,
    ) -> Result<ShardProof<SC>, Self::Error>;

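    /// Proves a collection of records, producing one shard proof per record.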
    fn prove(
        &self,
        pk: &Self::DeviceProvingKey,
        records: Vec<A::Record>,
        challenger: &mut SC::Challenger,
        opts: <A::Record as MachineRecord>::Config,
    ) -> Result<MachineProof<SC>, Self::Error>
    where
        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>;

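    /// Returns the STARK configuration of the machine.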
    fn config(&self) -> &SC {
        self.machine().config()
    }

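    /// Returns the number of public-value elements.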
    fn num_pv_elts(&self) -> usize {
        self.machine().num_pv_elts()
    }

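    /// Returns the chips needed to prove the given record.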
    fn shard_chips<'a, 'b>(
        &'a self,
        record: &'b A::Record,
    ) -> impl Iterator<Item = &'b MachineChip<SC, A>>
    where
        'a: 'b,
        SC: 'b,
    {
        self.machine().shard_chips(record)
    }

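    /// Checks the constraints of the given records against the proving key, for debugging.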
    fn debug_constraints(
        &self,
        pk: &StarkProvingKey<SC>,
        records: Vec<A::Record>,
        challenger: &mut SC::Challenger,
    ) where
        SC::Val: PrimeField32,
        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
    {
        self.machine().debug_constraints(pk, records, challenger);
    }
}

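/// A device-agnostic view of a proving key: exposes the preprocessed commitment, the starting
/// program counter, the initial global cumulative sum, and how to observe them in a challenger.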
pub trait MachineProvingKey<SC: StarkGenericConfig>: Send + Sync {
    /// The commitment to the preprocessed traces.
    fn preprocessed_commit(&self) -> Com<SC>;

    /// The starting program counter.
    fn pc_start(&self) -> Val<SC>;

    /// The initial global cumulative sum.
    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>>;

    /// Observes the key's contents in the challenger.
    fn observe_into(&self, challenger: &mut Challenger<SC>);
}

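/// A [`MachineProver`] that runs entirely on the host CPU.
///
/// Illustrative usage (a sketch only: constructing the `machine`, `program`, `records`,
/// `challenger`, and `opts` values is elided and depends on the surrounding crates):
///
/// ```ignore
/// let prover = CpuProver::new(machine);
/// let (pk, vk) = prover.setup(&program);
/// let proof = prover.prove(&pk, records, &mut challenger, opts)?;
/// ```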
pub struct CpuProver<SC: StarkGenericConfig, A> {
    machine: StarkMachine<SC, A>,
}

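/// The error type returned by [`CpuProver`].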
#[derive(Debug, Clone, Copy)]
pub struct CpuProverError;

impl<SC, A> MachineProver<SC, A> for CpuProver<SC, A>
where
    SC: 'static + StarkGenericConfig + Send + Sync,
    A: MachineAir<SC::Val>
        + for<'a> Air<ProverConstraintFolder<'a, SC>>
        + Air<InteractionBuilder<Val<SC>>>
        + for<'a> Air<VerifierConstraintFolder<'a, SC>>
        + for<'a> Air<SymbolicAirBuilder<Val<SC>>>,
    A::Record: MachineRecord<Config = SP1CoreOpts>,
    SC::Val: PrimeField32,
    Com<SC>: Send + Sync,
    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
    OpeningProof<SC>: Send + Sync,
    SC::Challenger: Clone,
{
    type DeviceMatrix = RowMajorMatrix<Val<SC>>;
    type DeviceProverData = PcsProverData<SC>;
    type DeviceProvingKey = StarkProvingKey<SC>;
    type Error = CpuProverError;

    fn new(machine: StarkMachine<SC, A>) -> Self {
        Self { machine }
    }

    fn machine(&self) -> &StarkMachine<SC, A> {
        &self.machine
    }

    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>) {
        self.machine().setup(program)
    }

    fn pk_from_vk(
        &self,
        program: &A::Program,
        vk: &StarkVerifyingKey<SC>,
    ) -> Self::DeviceProvingKey {
        self.machine().setup_core(program, vk.initial_global_cumulative_sum).0
    }

    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey {
        pk.clone()
    }

    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC> {
        pk.clone()
    }

    fn commit(
        &self,
        record: &A::Record,
        mut named_traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData> {
        // Order the traces by decreasing height, breaking ties by chip name.
        named_traces.sort_by_key(|(name, trace)| (Reverse(trace.height()), name.clone()));

        let pcs = self.config().pcs();

        let domains_and_traces = named_traces
            .iter()
            .map(|(_, trace)| {
                let domain = pcs.natural_domain_for_degree(trace.height());
                (domain, trace.to_owned())
            })
            .collect::<Vec<_>>();

        // Commit to the main traces.
        let (main_commit, main_data) = pcs.commit(domains_and_traces);

        let chip_ordering =
            named_traces.iter().enumerate().map(|(i, (name, _))| (name.to_owned(), i)).collect();

        let traces = named_traces.into_iter().map(|(_, trace)| trace).collect::<Vec<_>>();

        ShardMainData {
            traces,
            main_commit,
            main_data,
            chip_ordering,
            public_values: record.public_values(),
        }
    }

    #[allow(clippy::too_many_lines)]
    #[allow(clippy::redundant_closure_for_method_calls)]
    #[allow(clippy::map_unwrap_or)]
    fn open(
        &self,
        pk: &StarkProvingKey<SC>,
        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
        challenger: &mut <SC as StarkGenericConfig>::Challenger,
    ) -> Result<ShardProof<SC>, Self::Error> {
        let chips = self.machine().shard_chips_ordered(&data.chip_ordering).collect::<Vec<_>>();
        let traces = data.traces;

        let config = self.machine().config();

        let degrees = traces.iter().map(|trace| trace.height()).collect::<Vec<_>>();

        let log_degrees =
            degrees.iter().map(|degree| log2_strict_usize(*degree)).collect::<Vec<_>>();

        let log_quotient_degrees =
            chips.iter().map(|chip| chip.log_quotient_degree()).collect::<Vec<_>>();

        let pcs = config.pcs();
        let trace_domains =
            degrees.iter().map(|degree| pcs.natural_domain_for_degree(*degree)).collect::<Vec<_>>();

        // Observe the public values and the main commitment.
        challenger.observe_slice(&data.public_values[0..self.num_pv_elts()]);
        challenger.observe(data.main_commit.clone());

        // Sample the challenges for the permutation argument.
        let mut local_permutation_challenges: Vec<SC::Challenge> = Vec::new();
        for _ in 0..2 {
            local_permutation_challenges.push(challenger.sample_ext_element());
        }

        let packed_perm_challenges = local_permutation_challenges
            .iter()
            .map(|c| PackedChallenge::<SC>::from_f(*c))
            .collect::<Vec<_>>();

        // Generate the permutation traces and the cumulative sums for each chip.
        let ((permutation_traces, prep_traces), (global_cumulative_sums, local_cumulative_sums)): (
            (Vec<_>, Vec<_>),
            (Vec<_>, Vec<_>),
        ) = tracing::debug_span!("generate permutation traces").in_scope(|| {
            chips
                .par_iter()
                .zip(traces.par_iter())
                .map(|(chip, main_trace)| {
                    let preprocessed_trace =
                        pk.chip_ordering.get(&chip.name()).map(|&index| &pk.traces[index]);
                    let (perm_trace, local_sum) = chip.generate_permutation_trace(
                        preprocessed_trace,
                        main_trace,
                        &local_permutation_challenges,
                    );
                    // For globally-scoped chips, the global cumulative sum is read from the last
                    // 14 elements of the main trace (7 limbs each for the x and y coordinates).
                    let global_sum = if chip.commit_scope() == InteractionScope::Local {
                        SepticDigest::<Val<SC>>::zero()
                    } else {
                        let main_trace_size = main_trace.height() * main_trace.width();
                        let last_row = &main_trace.values[main_trace_size - 14..main_trace_size];
                        SepticDigest(SepticCurve {
                            x: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i]),
                            y: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i + 7]),
                        })
                    };
                    ((perm_trace, preprocessed_trace), (global_sum, local_sum))
                })
                .unzip()
        });

        // Log the trace dimensions for each chip.
        for i in 0..chips.len() {
            let trace_width = traces[i].width();
            let trace_height = traces[i].height();
            let prep_width = prep_traces[i].map_or(0, |x| x.width());
            let permutation_width = permutation_traces[i].width();
            let total_width = trace_width
                + prep_width
                + permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D;
            tracing::debug!(
                "{:<15} | Main Cols = {:<5} | Pre Cols = {:<5} | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}",
                chips[i].name(),
                trace_width,
                prep_width,
                permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D,
                trace_height,
                total_width * trace_height,
            );
        }

        let domains_and_perm_traces =
            tracing::debug_span!("flatten permutation traces and collect domains").in_scope(|| {
                permutation_traces
                    .into_iter()
                    .zip(trace_domains.iter())
                    .map(|(perm_trace, domain)| {
                        let trace = perm_trace.flatten_to_base();
                        (*domain, trace.clone())
                    })
                    .collect::<Vec<_>>()
            });

        let pcs = config.pcs();

        let (permutation_commit, permutation_data) =
            tracing::debug_span!("commit to permutation traces")
                .in_scope(|| pcs.commit(domains_and_perm_traces));

        // Observe the permutation commitment and the cumulative sums.
        challenger.observe(permutation_commit.clone());
        for (local_sum, global_sum) in
            local_cumulative_sums.iter().zip(global_cumulative_sums.iter())
        {
            challenger.observe_slice(local_sum.as_base_slice());
            challenger.observe_slice(&global_sum.0.x.0);
            challenger.observe_slice(&global_sum.0.y.0);
        }

        // Compute one disjoint quotient domain per chip, sized by the trace degree and the chip's
        // quotient degree.
        let quotient_domains = trace_domains
            .iter()
            .zip_eq(log_degrees.iter())
            .zip_eq(log_quotient_degrees.iter())
            .map(|((domain, log_degree), log_quotient_degree)| {
                domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree))
            })
            .collect::<Vec<_>>();

        // Sample the constraint-folding challenge and compute the quotient values for each chip.
        let alpha: SC::Challenge = challenger.sample_ext_element();
        let parent_span = tracing::debug_span!("compute quotient values");
        let quotient_values = parent_span.in_scope(|| {
            quotient_domains
                .par_iter()
                .enumerate()
                .map(|(i, quotient_domain)| {
                    tracing::debug_span!(parent: &parent_span, "compute quotient values for domain")
                        .in_scope(|| {
                            let preprocessed_trace_on_quotient_domains =
                                pk.chip_ordering.get(&chips[i].name()).map(|&index| {
                                    pcs.get_evaluations_on_domain(&pk.data, index, *quotient_domain)
                                        .to_row_major_matrix()
                                });
                            let main_trace_on_quotient_domains = pcs
                                .get_evaluations_on_domain(&data.main_data, i, *quotient_domain)
                                .to_row_major_matrix();
                            let permutation_trace_on_quotient_domains = pcs
                                .get_evaluations_on_domain(&permutation_data, i, *quotient_domain)
                                .to_row_major_matrix();

                            let chip_num_constraints =
                                pk.constraints_map.get(&chips[i].name()).unwrap();

                            // Powers of alpha for this chip's constraints, highest power first.
                            let powers_of_alpha =
                                alpha.powers().take(*chip_num_constraints).collect::<Vec<_>>();
                            let mut powers_of_alpha_rev = powers_of_alpha.clone();
                            powers_of_alpha_rev.reverse();

                            quotient_values(
                                chips[i],
                                &local_cumulative_sums[i],
                                &global_cumulative_sums[i],
                                trace_domains[i],
                                *quotient_domain,
                                preprocessed_trace_on_quotient_domains,
                                main_trace_on_quotient_domains,
                                permutation_trace_on_quotient_domains,
                                &packed_perm_challenges,
                                &powers_of_alpha_rev,
                                &data.public_values,
                            )
                        })
                })
                .collect::<Vec<_>>()
        });

        // Split each chip's quotient values into chunks and pair them with their domains.
        let quotient_domains_and_chunks = quotient_domains
            .into_iter()
            .zip_eq(quotient_values)
            .zip_eq(log_quotient_degrees.iter())
            .flat_map(|((quotient_domain, quotient_values), log_quotient_degree)| {
                let quotient_degree = 1 << *log_quotient_degree;
                let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base();
                let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
                let qc_domains = quotient_domain.split_domains(quotient_degree);
                qc_domains.into_iter().zip_eq(quotient_chunks)
            })
            .collect::<Vec<_>>();

        let num_quotient_chunks = quotient_domains_and_chunks.len();
        assert_eq!(
            num_quotient_chunks,
            chips.iter().map(|c| 1 << c.log_quotient_degree()).sum::<usize>()
        );

        // Commit to the quotient chunks and observe the commitment.
        let (quotient_commit, quotient_data) = tracing::debug_span!("commit to quotient traces")
            .in_scope(|| pcs.commit(quotient_domains_and_chunks));
        challenger.observe(quotient_commit.clone());

        // Sample the opening point and compute the opening points for each round.
        let zeta: SC::Challenge = challenger.sample_ext_element();

        let preprocessed_opening_points =
            tracing::debug_span!("compute preprocessed opening points").in_scope(|| {
                pk.traces
                    .iter()
                    .zip(pk.local_only.iter())
                    .map(|(trace, local_only)| {
                        let domain = pcs.natural_domain_for_degree(trace.height());
                        if !local_only {
                            vec![zeta, domain.next_point(zeta).unwrap()]
                        } else {
                            vec![zeta]
                        }
                    })
                    .collect::<Vec<_>>()
            });

        let main_trace_opening_points = tracing::debug_span!("compute main trace opening points")
            .in_scope(|| {
                trace_domains
                    .iter()
                    .zip(chips.iter())
                    .map(|(domain, chip)| {
                        if !chip.local_only() {
                            vec![zeta, domain.next_point(zeta).unwrap()]
                        } else {
                            vec![zeta]
                        }
                    })
                    .collect::<Vec<_>>()
            });

        let permutation_trace_opening_points =
            tracing::debug_span!("compute permutation trace opening points").in_scope(|| {
                trace_domains
                    .iter()
                    .map(|domain| vec![zeta, domain.next_point(zeta).unwrap()])
                    .collect::<Vec<_>>()
            });

        let quotient_opening_points =
            (0..num_quotient_chunks).map(|_| vec![zeta]).collect::<Vec<_>>();

        // Open the preprocessed, main, permutation, and quotient commitments.
        let (openings, opening_proof) = tracing::debug_span!("open multi batches").in_scope(|| {
            pcs.open(
                vec![
                    (&pk.data, preprocessed_opening_points),
                    (&data.main_data, main_trace_opening_points.clone()),
                    (&permutation_data, permutation_trace_opening_points.clone()),
                    (&quotient_data, quotient_opening_points),
                ],
                challenger,
            )
        });

        // Unpack the opened values for each round.
        let [preprocessed_values, main_values, permutation_values, mut quotient_values] =
            openings.try_into().unwrap();
        assert!(main_values.len() == chips.len());
        let preprocessed_opened_values = preprocessed_values
            .into_iter()
            .zip(pk.local_only.iter())
            .map(|(op, local_only)| {
                if !local_only {
                    let [local, next] = op.try_into().unwrap();
                    AirOpenedValues { local, next }
                } else {
                    let [local] = op.try_into().unwrap();
                    let width = local.len();
                    AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
                }
            })
            .collect::<Vec<_>>();

        let main_opened_values = main_values
            .into_iter()
            .zip(chips.iter())
            .map(|(op, chip)| {
                if !chip.local_only() {
                    let [local, next] = op.try_into().unwrap();
                    AirOpenedValues { local, next }
                } else {
                    let [local] = op.try_into().unwrap();
                    let width = local.len();
                    AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
                }
            })
            .collect::<Vec<_>>();
        let permutation_opened_values = permutation_values
            .into_iter()
            .map(|op| {
                let [local, next] = op.try_into().unwrap();
                AirOpenedValues { local, next }
            })
            .collect::<Vec<_>>();
        let mut quotient_opened_values = Vec::with_capacity(log_quotient_degrees.len());
        for log_quotient_degree in log_quotient_degrees.iter() {
            let degree = 1 << *log_quotient_degree;
            let slice = quotient_values.drain(0..degree);
            quotient_opened_values.push(slice.map(|mut op| op.pop().unwrap()).collect::<Vec<_>>());
        }

        // Collect the opened values for each chip.
        let opened_values = main_opened_values
            .into_iter()
            .zip_eq(permutation_opened_values)
            .zip_eq(quotient_opened_values)
            .zip_eq(local_cumulative_sums)
            .zip_eq(global_cumulative_sums)
            .zip_eq(log_degrees.iter())
            .enumerate()
            .map(
                |(
                    i,
                    (
                        (
                            (((main, permutation), quotient), local_cumulative_sum),
                            global_cumulative_sum,
                        ),
                        log_degree,
                    ),
                )| {
                    let preprocessed = pk
                        .chip_ordering
                        .get(&chips[i].name())
                        .map(|&index| preprocessed_opened_values[index].clone())
                        .unwrap_or(AirOpenedValues { local: vec![], next: vec![] });
                    ChipOpenedValues {
                        preprocessed,
                        main,
                        permutation,
                        quotient,
                        global_cumulative_sum,
                        local_cumulative_sum,
                        log_degree: *log_degree,
                    }
                },
            )
            .collect::<Vec<_>>();

        Ok(ShardProof::<SC> {
            commitment: ShardCommitment {
                main_commit: data.main_commit.clone(),
                permutation_commit,
                quotient_commit,
            },
            opened_values: ShardOpenedValues { chips: opened_values },
            opening_proof,
            chip_ordering: data.chip_ordering,
            public_values: data.public_values,
        })
    }

    #[allow(clippy::needless_for_each)]
    fn prove(
        &self,
        pk: &StarkProvingKey<SC>,
        mut records: Vec<A::Record>,
        challenger: &mut SC::Challenger,
        opts: <A::Record as MachineRecord>::Config,
    ) -> Result<MachineProof<SC>, Self::Error>
    where
        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
    {
        // Generate the dependencies for the records.
        self.machine().generate_dependencies(&mut records, &opts, None);

        // Observe the proving key.
        pk.observe_into(challenger);

        // Prove each shard in parallel, each with its own copy of the challenger.
        let shard_proofs = tracing::info_span!("prove_shards").in_scope(|| {
            records
                .into_par_iter()
                .map(|record| {
                    let named_traces = self.generate_traces(&record);
                    let shard_data = self.commit(&record, named_traces);
                    self.open(pk, shard_data, &mut challenger.clone())
                })
                .collect::<Result<Vec<_>, _>>()
        })?;

        Ok(MachineProof { shard_proofs })
    }
}

impl<SC> MachineProvingKey<SC> for StarkProvingKey<SC>
where
    SC: 'static + StarkGenericConfig + Send + Sync,
    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
    Com<SC>: Send + Sync,
{
    fn preprocessed_commit(&self) -> Com<SC> {
        self.commit.clone()
    }

    fn pc_start(&self) -> Val<SC> {
        self.pc_start
    }

    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>> {
        self.initial_global_cumulative_sum
    }

    fn observe_into(&self, challenger: &mut Challenger<SC>) {
        challenger.observe(self.commit.clone());
        challenger.observe(self.pc_start);
        challenger.observe_slice(&self.initial_global_cumulative_sum.0.x.0);
        challenger.observe_slice(&self.initial_global_cumulative_sum.0.y.0);
        let zero = Val::<SC>::zero();
        challenger.observe(zero);
    }
}

impl Display for CpuProverError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "CpuProverError")
    }
}

impl Error for CpuProverError {}