Skip to main content

sp1_hypercube/prover/
simple.rs

1//! Simplified prover for test and development use.
2
3use slop_air::BaseAir;
4use slop_algebra::PrimeField32;
5use slop_challenger::IopCtx;
6use std::{collections::BTreeMap, collections::BTreeSet, sync::Arc};
7
8use crate::{
9    air::MachineAir,
10    prover::{shard::AirProver, CoreProofShape, PcsProof, ProvingKey},
11    MachineVerifier, MachineVerifierConfigError, MachineVerifyingKey, ShardContext, ShardProof,
12    ShardVerifier,
13};
14
15use super::{PreprocessedData, ProverSemaphore};
16
17/// Given a record, compute the shape of the resulting shard proof.
18///
19/// This is a standalone function that can be used outside of `SimpleProver`.
20pub fn shape_from_record<GC: IopCtx, SC: ShardContext<GC>>(
21    verifier: &MachineVerifier<GC, SC>,
22    record: &<<SC as ShardContext<GC>>::Air as MachineAir<GC::F>>::Record,
23) -> Option<CoreProofShape<GC::F, SC::Air>> {
24    let log_stacking_height = verifier.log_stacking_height() as usize;
25    let max_log_row_count = verifier.max_log_row_count();
26    let airs = verifier.machine().chips();
27    let shard_chips: BTreeSet<_> =
28        airs.iter().filter(|air| air.included(record)).cloned().collect();
29    let preprocessed_multiple = shard_chips
30        .iter()
31        .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
32        .sum::<usize>()
33        .div_ceil(1 << log_stacking_height);
34    let main_multiple = shard_chips
35        .iter()
36        .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
37        .sum::<usize>()
38        .div_ceil(1 << log_stacking_height);
39
40    let main_padding_cols = (main_multiple * (1 << log_stacking_height)
41        - shard_chips
42            .iter()
43            .map(|air| air.width() * air.num_rows(record).unwrap_or_default())
44            .sum::<usize>())
45    .div_ceil(1 << max_log_row_count)
46    .max(1);
47
48    let preprocessed_padding_cols = (preprocessed_multiple * (1 << log_stacking_height)
49        - shard_chips
50            .iter()
51            .map(|air| air.preprocessed_width() * air.num_rows(record).unwrap_or_default())
52            .sum::<usize>())
53    .div_ceil(1 << max_log_row_count)
54    .max(1);
55
56    let shard_chips = verifier.machine().smallest_cluster(&shard_chips).cloned()?;
57    Some(CoreProofShape {
58        shard_chips,
59        preprocessed_multiple,
60        main_multiple,
61        preprocessed_padding_cols,
62        main_padding_cols,
63    })
64}
65
66/// Create a single-permit semaphore for simple prover operations.
67fn single_permit() -> ProverSemaphore {
68    ProverSemaphore::new(1)
69}
70
71/// The type of program this prover can make proofs for.
72pub type Program<GC, SC> =
73    <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Program;
74
75/// The execution record for this prover.
76pub type Record<GC, SC> = <<SC as ShardContext<GC>>::Air as MachineAir<<GC as IopCtx>::F>>::Record;
77
78/// A prover that proves traces sequentially using a single `AirProver`.
79///
80/// Prioritizes simplicity over performance - suitable for tests and development.
81pub struct SimpleProver<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> {
82    /// The underlying prover.
83    prover: Arc<C>,
84    /// The verifier.
85    verifier: MachineVerifier<GC, SC>,
86}
87
88impl<GC: IopCtx, SC: ShardContext<GC>, C: AirProver<GC, SC>> SimpleProver<GC, SC, C> {
89    /// Create a new simple prover.
90    #[must_use]
91    pub fn new(shard_verifier: ShardVerifier<GC, SC>, prover: C) -> Self {
92        Self { prover: Arc::new(prover), verifier: MachineVerifier::new(shard_verifier) }
93    }
94
95    /// Verify a machine proof.
96    pub fn verify(
97        &self,
98        vk: &MachineVerifyingKey<GC>,
99        proof: &crate::MachineProof<GC, PcsProof<GC, SC>>,
100    ) -> Result<(), MachineVerifierConfigError<GC, SC::Config>>
101    where
102        GC::F: PrimeField32,
103    {
104        self.verifier.verify(vk, proof)
105    }
106
107    /// Get the verifier.
108    #[must_use]
109    #[inline]
110    pub fn verifier(&self) -> &MachineVerifier<GC, SC> {
111        &self.verifier
112    }
113
114    /// Get a new challenger.
115    #[must_use]
116    #[inline]
117    pub fn challenger(&self) -> GC::Challenger {
118        self.verifier.challenger()
119    }
120
121    /// Get the machine.
122    #[must_use]
123    #[inline]
124    pub fn machine(&self) -> &crate::Machine<GC::F, SC::Air> {
125        self.verifier.machine()
126    }
127
128    /// Get the maximum log row count.
129    #[must_use]
130    pub fn max_log_row_count(&self) -> usize {
131        self.verifier.max_log_row_count()
132    }
133
134    /// Get the log stacking height.
135    #[must_use]
136    pub fn log_stacking_height(&self) -> u32 {
137        self.verifier.log_stacking_height()
138    }
139
140    /// Given a record, compute the shape of the resulting shard proof.
141    pub fn shape_from_record(
142        &self,
143        record: &Record<GC, SC>,
144    ) -> Option<CoreProofShape<GC::F, SC::Air>> {
145        shape_from_record(&self.verifier, record)
146    }
147
148    /// Setup the prover for a given program.
149    #[inline]
150    #[must_use]
151    #[tracing::instrument(skip_all, name = "simple_setup")]
152    pub async fn setup(
153        &self,
154        program: Arc<Program<GC, SC>>,
155    ) -> (PreprocessedData<ProvingKey<GC, SC, C>>, MachineVerifyingKey<GC>) {
156        self.prover.setup(program, single_permit()).await
157    }
158
159    /// Prove a shard with a given proving key.
160    #[inline]
161    #[must_use]
162    #[tracing::instrument(skip_all, name = "simple_prove_shard")]
163    pub async fn prove_shard(
164        &self,
165        pk: Arc<ProvingKey<GC, SC, C>>,
166        record: Record<GC, SC>,
167    ) -> ShardProof<GC, PcsProof<GC, SC>> {
168        let (proof, _) = self.prover.prove_shard_with_pk(pk, record, single_permit()).await;
169
170        proof
171    }
172
173    /// Setup and prove a shard in one call.
174    #[inline]
175    #[must_use]
176    #[allow(clippy::type_complexity)]
177    #[tracing::instrument(skip_all, name = "simple_setup_and_prove_shard")]
178    pub async fn setup_and_prove_shard(
179        &self,
180        program: Arc<Program<GC, SC>>,
181        vk: Option<MachineVerifyingKey<GC>>,
182        record: Record<GC, SC>,
183    ) -> (MachineVerifyingKey<GC>, ShardProof<GC, PcsProof<GC, SC>>) {
184        let (vk, proof, _) =
185            self.prover.setup_and_prove_shard(program, record, vk, single_permit()).await;
186
187        (vk, proof)
188    }
189
190    /// Get the preprocessed table heights from the proving key.
191    pub async fn preprocessed_table_heights(
192        &self,
193        pk: Arc<ProvingKey<GC, SC, C>>,
194    ) -> BTreeMap<String, usize> {
195        C::preprocessed_table_heights(pk).await
196    }
197}