sp1_stark/
prover.rs

1use crate::septic_curve::SepticCurve;
2use crate::septic_digest::SepticDigest;
3use crate::septic_extension::SepticExtension;
4use crate::{air::InteractionScope, AirOpenedValues, ChipOpenedValues, ShardOpenedValues};
5use core::fmt::Display;
6use itertools::Itertools;
7use p3_air::Air;
8use p3_challenger::{CanObserve, FieldChallenger};
9use p3_commit::{Pcs, PolynomialSpace};
10use p3_field::{AbstractExtensionField, AbstractField, PrimeField32};
11use p3_matrix::{dense::RowMajorMatrix, Matrix};
12use p3_maybe_rayon::prelude::*;
13use p3_uni_stark::SymbolicAirBuilder;
14use p3_util::log2_strict_usize;
15use serde::{de::DeserializeOwned, Serialize};
16use std::{cmp::Reverse, error::Error, time::Instant};
17
18use super::{
19    quotient_values, Com, OpeningProof, StarkGenericConfig, StarkMachine, StarkProvingKey, Val,
20    VerifierConstraintFolder,
21};
22use crate::{
23    air::MachineAir, lookup::InteractionBuilder, opts::SP1CoreOpts, record::MachineRecord,
24    Challenger, DebugConstraintBuilder, MachineChip, MachineProof, PackedChallenge, PcsProverData,
25    ProverConstraintFolder, ShardCommitment, ShardMainData, ShardProof, StarkVerifyingKey,
26};
27
28/// An algorithmic & hardware independent prover implementation for any [`MachineAir`].
29pub trait MachineProver<SC: StarkGenericConfig, A: MachineAir<SC::Val>>:
30    'static + Send + Sync
31{
32    /// The type used to store the traces.
33    type DeviceMatrix: Matrix<SC::Val>;
34
35    /// The type used to store the polynomial commitment schemes data.
36    type DeviceProverData;
37
38    /// The type used to store the proving key.
39    type DeviceProvingKey: MachineProvingKey<SC>;
40
41    /// The type used for error handling.
42    type Error: Error + Send + Sync;
43
44    /// Create a new prover from a given machine.
45    fn new(machine: StarkMachine<SC, A>) -> Self;
46
47    /// A reference to the machine that this prover is using.
48    fn machine(&self) -> &StarkMachine<SC, A>;
49
50    /// Setup the preprocessed data into a proving and verifying key.
51    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>);
52
53    /// Setup the proving key given a verifying key. This is similar to `setup` but faster since
54    /// some computed information is already in the verifying key.
55    fn pk_from_vk(
56        &self,
57        program: &A::Program,
58        vk: &StarkVerifyingKey<SC>,
59    ) -> Self::DeviceProvingKey;
60
61    /// Copy the proving key from the host to the device.
62    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey;
63
64    /// Copy the proving key from the device to the host.
65    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC>;
66
67    /// Generate the main traces.
68    fn generate_traces(&self, record: &A::Record) -> Vec<(String, RowMajorMatrix<Val<SC>>)> {
69        let shard_chips = self.shard_chips(record).collect::<Vec<_>>();
70
71        // For each chip, generate the trace.
72        let parent_span = tracing::debug_span!("generate traces for shard");
73        parent_span.in_scope(|| {
74            shard_chips
75                .par_iter()
76                .map(|chip| {
77                    let chip_name = chip.name();
78                    let begin = Instant::now();
79                    let trace = chip.generate_trace(record, &mut A::Record::default());
80                    tracing::debug!(
81                        parent: &parent_span,
82                        "generated trace for chip {} in {:?}",
83                        chip_name,
84                        begin.elapsed()
85                    );
86                    (chip_name, trace)
87                })
88                .collect::<Vec<_>>()
89        })
90    }
91
92    /// Commit to the main traces.
93    fn commit(
94        &self,
95        record: &A::Record,
96        traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
97    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>;
98
99    /// Observe the main commitment and public values and update the challenger.
100    fn observe(
101        &self,
102        challenger: &mut SC::Challenger,
103        commitment: Com<SC>,
104        public_values: &[SC::Val],
105    ) {
106        // Observe the commitment.
107        challenger.observe(commitment);
108
109        // Observe the public values.
110        challenger.observe_slice(public_values);
111    }
112
113    /// Compute the openings of the traces.
114    fn open(
115        &self,
116        pk: &Self::DeviceProvingKey,
117        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
118        challenger: &mut SC::Challenger,
119    ) -> Result<ShardProof<SC>, Self::Error>;
120
121    /// Generate a proof for the given records.
122    fn prove(
123        &self,
124        pk: &Self::DeviceProvingKey,
125        records: Vec<A::Record>,
126        challenger: &mut SC::Challenger,
127        opts: <A::Record as MachineRecord>::Config,
128    ) -> Result<MachineProof<SC>, Self::Error>
129    where
130        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>;
131
132    /// The stark config for the machine.
133    fn config(&self) -> &SC {
134        self.machine().config()
135    }
136
137    /// The number of public values elements.
138    fn num_pv_elts(&self) -> usize {
139        self.machine().num_pv_elts()
140    }
141
142    /// The chips that will be necessary to prove this record.
143    fn shard_chips<'a, 'b>(
144        &'a self,
145        record: &'b A::Record,
146    ) -> impl Iterator<Item = &'b MachineChip<SC, A>>
147    where
148        'a: 'b,
149        SC: 'b,
150    {
151        self.machine().shard_chips(record)
152    }
153
154    /// Debug the constraints for the given inputs.
155    fn debug_constraints(
156        &self,
157        pk: &StarkProvingKey<SC>,
158        records: Vec<A::Record>,
159        challenger: &mut SC::Challenger,
160    ) where
161        SC::Val: PrimeField32,
162        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
163    {
164        self.machine().debug_constraints(pk, records, challenger);
165    }
166}
167
168/// A proving key for any [`MachineAir`] that is agnostic to hardware.
169pub trait MachineProvingKey<SC: StarkGenericConfig>: Send + Sync {
170    /// The main commitment.
171    fn preprocessed_commit(&self) -> Com<SC>;
172
173    /// The start pc.
174    fn pc_start(&self) -> Val<SC>;
175
176    /// The initial global cumulative sum.
177    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>>;
178
179    /// Observe itself in the challenger.
180    fn observe_into(&self, challenger: &mut Challenger<SC>);
181}
182
183/// A prover implementation based on x86 and ARM CPUs.
184pub struct CpuProver<SC: StarkGenericConfig, A> {
185    machine: StarkMachine<SC, A>,
186}
187
/// Error returned when a step of the [`CpuProver`] fails.
#[derive(Clone, Copy, Debug)]
pub struct CpuProverError;
191
192impl<SC, A> MachineProver<SC, A> for CpuProver<SC, A>
193where
194    SC: 'static + StarkGenericConfig + Send + Sync,
195    A: MachineAir<SC::Val>
196        + for<'a> Air<ProverConstraintFolder<'a, SC>>
197        + Air<InteractionBuilder<Val<SC>>>
198        + for<'a> Air<VerifierConstraintFolder<'a, SC>>
199        + for<'a> Air<SymbolicAirBuilder<Val<SC>>>,
200    A::Record: MachineRecord<Config = SP1CoreOpts>,
201    SC::Val: PrimeField32,
202    Com<SC>: Send + Sync,
203    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
204    OpeningProof<SC>: Send + Sync,
205    SC::Challenger: Clone,
206{
207    type DeviceMatrix = RowMajorMatrix<Val<SC>>;
208    type DeviceProverData = PcsProverData<SC>;
209    type DeviceProvingKey = StarkProvingKey<SC>;
210    type Error = CpuProverError;
211
212    fn new(machine: StarkMachine<SC, A>) -> Self {
213        Self { machine }
214    }
215
216    fn machine(&self) -> &StarkMachine<SC, A> {
217        &self.machine
218    }
219
220    fn setup(&self, program: &A::Program) -> (Self::DeviceProvingKey, StarkVerifyingKey<SC>) {
221        self.machine().setup(program)
222    }
223
224    fn pk_from_vk(
225        &self,
226        program: &A::Program,
227        vk: &StarkVerifyingKey<SC>,
228    ) -> Self::DeviceProvingKey {
229        self.machine().setup_core(program, vk.initial_global_cumulative_sum).0
230    }
231
232    fn pk_to_device(&self, pk: &StarkProvingKey<SC>) -> Self::DeviceProvingKey {
233        pk.clone()
234    }
235
236    fn pk_to_host(&self, pk: &Self::DeviceProvingKey) -> StarkProvingKey<SC> {
237        pk.clone()
238    }
239
240    fn commit(
241        &self,
242        record: &A::Record,
243        mut named_traces: Vec<(String, RowMajorMatrix<Val<SC>>)>,
244    ) -> ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData> {
245        // Order the chips and traces by trace size (biggest first), and get the ordering map.
246        named_traces.sort_by_key(|(name, trace)| (Reverse(trace.height()), name.clone()));
247
248        let pcs = self.config().pcs();
249
250        let domains_and_traces = named_traces
251            .iter()
252            .map(|(_, trace)| {
253                let domain = pcs.natural_domain_for_degree(trace.height());
254                (domain, trace.to_owned())
255            })
256            .collect::<Vec<_>>();
257
258        // Commit to the batch of traces.
259        let (main_commit, main_data) = pcs.commit(domains_and_traces);
260
261        // Get the chip ordering.
262        let chip_ordering =
263            named_traces.iter().enumerate().map(|(i, (name, _))| (name.to_owned(), i)).collect();
264
265        let traces = named_traces.into_iter().map(|(_, trace)| trace).collect::<Vec<_>>();
266
267        ShardMainData {
268            traces,
269            main_commit,
270            main_data,
271            chip_ordering,
272            public_values: record.public_values(),
273        }
274    }
275
276    /// Prove the program for the given shard and given a commitment to the main data.
277    #[allow(clippy::too_many_lines)]
278    #[allow(clippy::redundant_closure_for_method_calls)]
279    #[allow(clippy::map_unwrap_or)]
280    fn open(
281        &self,
282        pk: &StarkProvingKey<SC>,
283        data: ShardMainData<SC, Self::DeviceMatrix, Self::DeviceProverData>,
284        challenger: &mut <SC as StarkGenericConfig>::Challenger,
285    ) -> Result<ShardProof<SC>, Self::Error> {
286        let chips = self.machine().shard_chips_ordered(&data.chip_ordering).collect::<Vec<_>>();
287        let traces = data.traces;
288
289        let config = self.machine().config();
290
291        let degrees = traces.iter().map(|trace| trace.height()).collect::<Vec<_>>();
292
293        let log_degrees =
294            degrees.iter().map(|degree| log2_strict_usize(*degree)).collect::<Vec<_>>();
295
296        let log_quotient_degrees =
297            chips.iter().map(|chip| chip.log_quotient_degree()).collect::<Vec<_>>();
298
299        let pcs = config.pcs();
300        let trace_domains =
301            degrees.iter().map(|degree| pcs.natural_domain_for_degree(*degree)).collect::<Vec<_>>();
302
303        // Observe the public values and the main commitment.
304        challenger.observe_slice(&data.public_values[0..self.num_pv_elts()]);
305        challenger.observe(data.main_commit.clone());
306
307        // Obtain the challenges used for the local permutation argument.
308        let mut local_permutation_challenges: Vec<SC::Challenge> = Vec::new();
309        for _ in 0..2 {
310            local_permutation_challenges.push(challenger.sample_ext_element());
311        }
312
313        let packed_perm_challenges = local_permutation_challenges
314            .iter()
315            .map(|c| PackedChallenge::<SC>::from_f(*c))
316            .collect::<Vec<_>>();
317
318        // Generate the permutation traces.
319        let ((permutation_traces, prep_traces), (global_cumulative_sums, local_cumulative_sums)): (
320            (Vec<_>, Vec<_>),
321            (Vec<_>, Vec<_>),
322        ) = tracing::debug_span!("generate permutation traces").in_scope(|| {
323            chips
324                .par_iter()
325                .zip(traces.par_iter())
326                .map(|(chip, main_trace)| {
327                    let preprocessed_trace =
328                        pk.chip_ordering.get(&chip.name()).map(|&index| &pk.traces[index]);
329                    let (perm_trace, local_sum) = chip.generate_permutation_trace(
330                        preprocessed_trace,
331                        main_trace,
332                        &local_permutation_challenges,
333                    );
334                    let global_sum = if chip.commit_scope() == InteractionScope::Local {
335                        SepticDigest::<Val<SC>>::zero()
336                    } else {
337                        let main_trace_size = main_trace.height() * main_trace.width();
338                        let last_row = &main_trace.values[main_trace_size - 14..main_trace_size];
339                        SepticDigest(SepticCurve {
340                            x: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i]),
341                            y: SepticExtension::<Val<SC>>::from_base_fn(|i| last_row[i + 7]),
342                        })
343                    };
344                    ((perm_trace, preprocessed_trace), (global_sum, local_sum))
345                })
346                .unzip()
347        });
348
349        // Compute some statistics.
350        for i in 0..chips.len() {
351            let trace_width = traces[i].width();
352            let trace_height = traces[i].height();
353            let prep_width = prep_traces[i].map_or(0, |x| x.width());
354            let permutation_width = permutation_traces[i].width();
355            let total_width = trace_width
356                + prep_width
357                + permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D;
358            tracing::debug!(
359                "{:<15} | Main Cols = {:<5} | Pre Cols = {:<5}  | Perm Cols = {:<5} | Rows = {:<5} | Cells = {:<10}",
360                chips[i].name(),
361                trace_width,
362                prep_width,
363                permutation_width * <SC::Challenge as AbstractExtensionField<SC::Val>>::D,
364                trace_height,
365                total_width * trace_height,
366            );
367        }
368
369        let domains_and_perm_traces =
370            tracing::debug_span!("flatten permutation traces and collect domains").in_scope(|| {
371                permutation_traces
372                    .into_iter()
373                    .zip(trace_domains.iter())
374                    .map(|(perm_trace, domain)| {
375                        let trace = perm_trace.flatten_to_base();
376                        (*domain, trace.clone())
377                    })
378                    .collect::<Vec<_>>()
379            });
380
381        let pcs = config.pcs();
382
383        let (permutation_commit, permutation_data) =
384            tracing::debug_span!("commit to permutation traces")
385                .in_scope(|| pcs.commit(domains_and_perm_traces));
386
387        // Observe the permutation commitment and cumulative sums.
388        challenger.observe(permutation_commit.clone());
389        for (local_sum, global_sum) in
390            local_cumulative_sums.iter().zip(global_cumulative_sums.iter())
391        {
392            challenger.observe_slice(local_sum.as_base_slice());
393            challenger.observe_slice(&global_sum.0.x.0);
394            challenger.observe_slice(&global_sum.0.y.0);
395        }
396
397        // Compute the quotient polynomial for all chips.
398        let quotient_domains = trace_domains
399            .iter()
400            .zip_eq(log_degrees.iter())
401            .zip_eq(log_quotient_degrees.iter())
402            .map(|((domain, log_degree), log_quotient_degree)| {
403                domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree))
404            })
405            .collect::<Vec<_>>();
406
407        // Compute the quotient values.
408        let alpha: SC::Challenge = challenger.sample_ext_element::<SC::Challenge>();
409        let parent_span = tracing::debug_span!("compute quotient values");
410        let quotient_values = parent_span.in_scope(|| {
411            quotient_domains
412                .into_par_iter()
413                .enumerate()
414                .map(|(i, quotient_domain)| {
415                    tracing::debug_span!(parent: &parent_span, "compute quotient values for domain")
416                        .in_scope(|| {
417                            let preprocessed_trace_on_quotient_domains =
418                                pk.chip_ordering.get(&chips[i].name()).map(|&index| {
419                                    pcs.get_evaluations_on_domain(&pk.data, index, *quotient_domain)
420                                        .to_row_major_matrix()
421                                });
422                            let main_trace_on_quotient_domains = pcs
423                                .get_evaluations_on_domain(&data.main_data, i, *quotient_domain)
424                                .to_row_major_matrix();
425                            let permutation_trace_on_quotient_domains = pcs
426                                .get_evaluations_on_domain(&permutation_data, i, *quotient_domain)
427                                .to_row_major_matrix();
428
429                            let chip_num_constraints =
430                                pk.constraints_map.get(&chips[i].name()).unwrap();
431
432                            // Calculate powers of alpha for constraint evaluation:
433                            // 1. Generate sequence [α⁰, α¹, ..., α^(n-1)] where n = chip_num_constraints.
434                            // 2. Reverse to [α^(n-1), ..., α¹, α⁰] to align with Horner's method in the verifier.
435                            let powers_of_alpha =
436                                alpha.powers().take(*chip_num_constraints).collect::<Vec<_>>();
437                            let mut powers_of_alpha_rev = powers_of_alpha.clone();
438                            powers_of_alpha_rev.reverse();
439
440                            quotient_values(
441                                chips[i],
442                                &local_cumulative_sums[i],
443                                &global_cumulative_sums[i],
444                                trace_domains[i],
445                                *quotient_domain,
446                                preprocessed_trace_on_quotient_domains,
447                                main_trace_on_quotient_domains,
448                                permutation_trace_on_quotient_domains,
449                                &packed_perm_challenges,
450                                &powers_of_alpha_rev,
451                                &data.public_values,
452                            )
453                        })
454                })
455                .collect::<Vec<_>>()
456        });
457
458        // Split the quotient values and commit to them.
459        let quotient_domains_and_chunks = quotient_domains
460            .into_iter()
461            .zip_eq(quotient_values)
462            .zip_eq(log_quotient_degrees.iter())
463            .flat_map(|((quotient_domain, quotient_values), log_quotient_degree)| {
464                let quotient_degree = 1 << *log_quotient_degree;
465                let quotient_flat = RowMajorMatrix::new_col(quotient_values).flatten_to_base();
466                let quotient_chunks = quotient_domain.split_evals(quotient_degree, quotient_flat);
467                let qc_domains = quotient_domain.split_domains(quotient_degree);
468                qc_domains.into_iter().zip_eq(quotient_chunks)
469            })
470            .collect::<Vec<_>>();
471
472        let num_quotient_chunks = quotient_domains_and_chunks.len();
473        assert_eq!(
474            num_quotient_chunks,
475            chips.iter().map(|c| 1 << c.log_quotient_degree()).sum::<usize>()
476        );
477
478        let (quotient_commit, quotient_data) = tracing::debug_span!("commit to quotient traces")
479            .in_scope(|| pcs.commit(quotient_domains_and_chunks));
480        challenger.observe(quotient_commit.clone());
481
482        // Compute the quotient argument.
483        let zeta: SC::Challenge = challenger.sample_ext_element();
484
485        let preprocessed_opening_points =
486            tracing::debug_span!("compute preprocessed opening points").in_scope(|| {
487                pk.traces
488                    .iter()
489                    .zip(pk.local_only.iter())
490                    .map(|(trace, local_only)| {
491                        let domain = pcs.natural_domain_for_degree(trace.height());
492                        if !local_only {
493                            vec![zeta, domain.next_point(zeta).unwrap()]
494                        } else {
495                            vec![zeta]
496                        }
497                    })
498                    .collect::<Vec<_>>()
499            });
500
501        let main_trace_opening_points = tracing::debug_span!("compute main trace opening points")
502            .in_scope(|| {
503                trace_domains
504                    .iter()
505                    .zip(chips.iter())
506                    .map(|(domain, chip)| {
507                        if !chip.local_only() {
508                            vec![zeta, domain.next_point(zeta).unwrap()]
509                        } else {
510                            vec![zeta]
511                        }
512                    })
513                    .collect::<Vec<_>>()
514            });
515
516        let permutation_trace_opening_points =
517            tracing::debug_span!("compute permutation trace opening points").in_scope(|| {
518                trace_domains
519                    .iter()
520                    .map(|domain| vec![zeta, domain.next_point(zeta).unwrap()])
521                    .collect::<Vec<_>>()
522            });
523
524        // Compute quotient opening points, open every chunk at zeta.
525        let quotient_opening_points =
526            (0..num_quotient_chunks).map(|_| vec![zeta]).collect::<Vec<_>>();
527
528        let (openings, opening_proof) = tracing::debug_span!("open multi batches").in_scope(|| {
529            pcs.open(
530                vec![
531                    (&pk.data, preprocessed_opening_points),
532                    (&data.main_data, main_trace_opening_points.clone()),
533                    (&permutation_data, permutation_trace_opening_points.clone()),
534                    (&quotient_data, quotient_opening_points),
535                ],
536                challenger,
537            )
538        });
539
540        // Collect the opened values for each chip.
541        let [preprocessed_values, main_values, permutation_values, mut quotient_values] =
542            openings.try_into().unwrap();
543        assert!(main_values.len() == chips.len());
544        let preprocessed_opened_values = preprocessed_values
545            .into_iter()
546            .zip(pk.local_only.iter())
547            .map(|(op, local_only)| {
548                if !local_only {
549                    let [local, next] = op.try_into().unwrap();
550                    AirOpenedValues { local, next }
551                } else {
552                    let [local] = op.try_into().unwrap();
553                    let width = local.len();
554                    AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
555                }
556            })
557            .collect::<Vec<_>>();
558
559        let main_opened_values = main_values
560            .into_iter()
561            .zip(chips.iter())
562            .map(|(op, chip)| {
563                if !chip.local_only() {
564                    let [local, next] = op.try_into().unwrap();
565                    AirOpenedValues { local, next }
566                } else {
567                    let [local] = op.try_into().unwrap();
568                    let width = local.len();
569                    AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
570                }
571            })
572            .collect::<Vec<_>>();
573        let permutation_opened_values = permutation_values
574            .into_iter()
575            .map(|op| {
576                let [local, next] = op.try_into().unwrap();
577                AirOpenedValues { local, next }
578            })
579            .collect::<Vec<_>>();
580        let mut quotient_opened_values = Vec::with_capacity(log_quotient_degrees.len());
581        for log_quotient_degree in log_quotient_degrees.iter() {
582            let degree = 1 << *log_quotient_degree;
583            let slice = quotient_values.drain(0..degree);
584            quotient_opened_values.push(slice.map(|mut op| op.pop().unwrap()).collect::<Vec<_>>());
585        }
586
587        let opened_values = main_opened_values
588            .into_iter()
589            .zip_eq(permutation_opened_values)
590            .zip_eq(quotient_opened_values)
591            .zip_eq(local_cumulative_sums)
592            .zip_eq(global_cumulative_sums)
593            .zip_eq(log_degrees.iter())
594            .enumerate()
595            .map(
596                |(
597                    i,
598                    (
599                        (
600                            (((main, permutation), quotient), local_cumulative_sum),
601                            global_cumulative_sum,
602                        ),
603                        log_degree,
604                    ),
605                )| {
606                    let preprocessed = pk
607                        .chip_ordering
608                        .get(&chips[i].name())
609                        .map(|&index| preprocessed_opened_values[index].clone())
610                        .unwrap_or(AirOpenedValues { local: vec![], next: vec![] });
611                    ChipOpenedValues {
612                        preprocessed,
613                        main,
614                        permutation,
615                        quotient,
616                        global_cumulative_sum,
617                        local_cumulative_sum,
618                        log_degree: *log_degree,
619                    }
620                },
621            )
622            .collect::<Vec<_>>();
623
624        Ok(ShardProof::<SC> {
625            commitment: ShardCommitment {
626                main_commit: data.main_commit.clone(),
627                permutation_commit,
628                quotient_commit,
629            },
630            opened_values: ShardOpenedValues { chips: opened_values },
631            opening_proof,
632            chip_ordering: data.chip_ordering,
633            public_values: data.public_values,
634        })
635    }
636
637    /// Prove the execution record is valid.
638    ///
639    /// Given a proving key `pk` and a matching execution record `record`, this function generates
640    /// a STARK proof that the execution record is valid.
641    #[allow(clippy::needless_for_each)]
642    fn prove(
643        &self,
644        pk: &StarkProvingKey<SC>,
645        mut records: Vec<A::Record>,
646        challenger: &mut SC::Challenger,
647        opts: <A::Record as MachineRecord>::Config,
648    ) -> Result<MachineProof<SC>, Self::Error>
649    where
650        A: for<'a> Air<DebugConstraintBuilder<'a, Val<SC>, SC::Challenge>>,
651    {
652        // Generate dependencies.
653        self.machine().generate_dependencies(&mut records, &opts, None);
654
655        // Observe the preprocessed commitment.
656        pk.observe_into(challenger);
657
658        let shard_proofs = tracing::info_span!("prove_shards").in_scope(|| {
659            records
660                .into_par_iter()
661                .map(|record| {
662                    let named_traces = self.generate_traces(&record);
663                    let shard_data = self.commit(&record, named_traces);
664                    self.open(pk, shard_data, &mut challenger.clone())
665                })
666                .collect::<Result<Vec<_>, _>>()
667        })?;
668
669        Ok(MachineProof { shard_proofs })
670    }
671}
672
673impl<SC> MachineProvingKey<SC> for StarkProvingKey<SC>
674where
675    SC: 'static + StarkGenericConfig + Send + Sync,
676    PcsProverData<SC>: Send + Sync + Serialize + DeserializeOwned,
677    Com<SC>: Send + Sync,
678{
679    fn preprocessed_commit(&self) -> Com<SC> {
680        self.commit.clone()
681    }
682
683    fn pc_start(&self) -> Val<SC> {
684        self.pc_start
685    }
686
687    fn initial_global_cumulative_sum(&self) -> SepticDigest<Val<SC>> {
688        self.initial_global_cumulative_sum
689    }
690
691    fn observe_into(&self, challenger: &mut Challenger<SC>) {
692        challenger.observe(self.commit.clone());
693        challenger.observe(self.pc_start);
694        challenger.observe_slice(&self.initial_global_cumulative_sum.0.x.0);
695        challenger.observe_slice(&self.initial_global_cumulative_sum.0.y.0);
696        let zero = Val::<SC>::zero();
697        challenger.observe(zero);
698    }
699}
700
701impl Display for CpuProverError {
702    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
703        write!(f, "DefaultProverError")
704    }
705}
706
707impl Error for CpuProverError {}