// sp1_prover/shapes.rs

1use std::{
2    collections::{BTreeMap, BTreeSet, HashSet},
3    fs::File,
4    hash::{DefaultHasher, Hash, Hasher},
5    panic::{catch_unwind, AssertUnwindSafe},
6    path::PathBuf,
7    sync::{Arc, Mutex},
8};
9
10use eyre::Result;
11use p3_baby_bear::BabyBear;
12use p3_field::AbstractField;
13use serde::{Deserialize, Serialize};
14use sp1_core_machine::shape::CoreShapeConfig;
15use sp1_recursion_circuit::machine::{
16    SP1CompressWithVKeyWitnessValues, SP1DeferredWitnessValues, SP1RecursionWitnessValues,
17};
18use sp1_recursion_core::{
19    shape::{RecursionShape, RecursionShapeConfig},
20    RecursionProgram,
21};
22use sp1_stark::{shape::OrderedShape, MachineProver, DIGEST_SIZE};
23use thiserror::Error;
24
25pub use sp1_recursion_circuit::machine::{
26    SP1CompressWithVkeyShape, SP1DeferredShape, SP1RecursionShape,
27};
28
29use crate::{components::SP1ProverComponents, CompressAir, HashableKey, SP1Prover, ShrinkAir};
30
/// The shape of a proof that is fed into one of the recursion programs.
///
/// - `Recursion`: a core shape taken from the `CoreShapeConfig`.
/// - `Compress`: a batch of compress-machine shapes, as produced by
///   `RecursionShapeConfig::get_all_shape_combinations(batch_size)`.
/// - `Deferred`: a single compress-machine shape.
/// - `Shrink`: a single compress-machine shape.
///
/// The derived `Ord` compares the variant first and then the contents. It therefore fixes the
/// iteration order of the `BTreeSet<SP1ProofShape>`s built in this file, and with it the
/// positional shape indices that `build_vk_map` accepts and reports. Reordering the variants
/// changes those indices.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum SP1ProofShape {
    Recursion(OrderedShape),
    Compress(Vec<OrderedShape>),
    Deferred(OrderedShape),
    Shrink(OrderedShape),
}
38
/// The shape information needed to compile a recursion program (see
/// `SP1Prover::program_from_shape`).
///
/// Built from an [`SP1ProofShape`] by `SP1CompressProgramShape::from_proof_shape`. `Compress`
/// and `Shrink` share the `SP1CompressWithVkeyShape` payload type; they differ in which program
/// is compiled from it.
///
/// The derived `Hash` backs `hash_u64`, so the variant order and payload layout affect those
/// hash values.
#[derive(Debug, Clone, Hash)]
pub enum SP1CompressProgramShape {
    Recursion(SP1RecursionShape),
    Compress(SP1CompressWithVkeyShape),
    Deferred(SP1DeferredShape),
    Shrink(SP1CompressWithVkeyShape),
}
46
47impl SP1CompressProgramShape {
48    pub fn hash_u64(&self) -> u64 {
49        let mut hasher = DefaultHasher::new();
50        Hash::hash(&self, &mut hasher);
51        hasher.finish()
52    }
53}
54
/// Errors returned by `build_vk_map_to_file`.
#[derive(Debug, Error)]
pub enum VkBuildError {
    /// Creating the build directory or the output file failed, or writing to it failed.
    #[error("IO error: {0}")]
    IO(#[from] std::io::Error),
    /// Serializing the vk map with `bincode` failed.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),
}
62
63pub fn check_shapes<C: SP1ProverComponents>(
64    reduce_batch_size: usize,
65    no_precompiles: bool,
66    num_compiler_workers: usize,
67    prover: &mut SP1Prover<C>,
68) -> bool {
69    let (shape_tx, shape_rx) =
70        std::sync::mpsc::sync_channel::<SP1CompressProgramShape>(num_compiler_workers);
71    let (panic_tx, panic_rx) = std::sync::mpsc::channel();
72    let core_shape_config = prover.core_shape_config.as_ref().expect("core shape config not found");
73    let recursion_shape_config =
74        prover.compress_shape_config.as_ref().expect("recursion shape config not found");
75
76    let shape_rx = Mutex::new(shape_rx);
77
78    let all_maximal_shapes = SP1ProofShape::generate_maximal_shapes(
79        core_shape_config,
80        recursion_shape_config,
81        reduce_batch_size,
82        no_precompiles,
83    )
84    .collect::<BTreeSet<SP1ProofShape>>();
85    let num_shapes = all_maximal_shapes.len();
86    tracing::debug!("number of shapes: {}", num_shapes);
87
88    // The Merkle tree height.
89    let height = num_shapes.next_power_of_two().ilog2() as usize;
90
91    // Empty the join program map so that we recompute the join program.
92    prover.join_programs_map.clear();
93
94    let compress_ok = std::thread::scope(|s| {
95        // Initialize compiler workers.
96        for _ in 0..num_compiler_workers {
97            let shape_rx = &shape_rx;
98            let prover = &prover;
99            let panic_tx = panic_tx.clone();
100            s.spawn(move || {
101                while let Ok(shape) = shape_rx.lock().unwrap().recv() {
102                    tracing::debug!("shape is {:?}", shape);
103                    let program = catch_unwind(AssertUnwindSafe(|| {
104                        // Try to build the recursion program from the given shape.
105                        prover.program_from_shape(shape.clone(), None)
106                    }));
107                    match program {
108                        Ok(_) => {}
109                        Err(e) => {
110                            tracing::warn!(
111                                "Program generation failed for shape {:?}, with error: {:?}",
112                                shape,
113                                e
114                            );
115                            panic_tx.send(true).unwrap();
116                        }
117                    }
118                }
119            });
120        }
121
122        // Generate shapes and send them to the compiler workers.
123        all_maximal_shapes.into_iter().for_each(|program_shape| {
124            shape_tx
125                .send(SP1CompressProgramShape::from_proof_shape(program_shape, height))
126                .unwrap();
127        });
128
129        drop(shape_tx);
130        drop(panic_tx);
131
132        // If the panic receiver has no panics, then the shape is correct.
133        panic_rx.iter().next().is_none()
134    });
135
136    compress_ok
137}
138
139pub fn build_vk_map<C: SP1ProverComponents + 'static>(
140    reduce_batch_size: usize,
141    dummy: bool,
142    num_compiler_workers: usize,
143    num_setup_workers: usize,
144    indices: Option<Vec<usize>>,
145) -> (BTreeSet<[BabyBear; DIGEST_SIZE]>, Vec<usize>, usize) {
146    // Setup the prover.
147    let mut prover = SP1Prover::<C>::new();
148    prover.vk_verification = !dummy;
149    if !dummy {
150        prover.join_programs_map.clear();
151    }
152    let prover = Arc::new(prover);
153
154    // Get the shape configs.
155    let core_shape_config = prover.core_shape_config.as_ref().expect("core shape config not found");
156    let recursion_shape_config =
157        prover.compress_shape_config.as_ref().expect("recursion shape config not found");
158
159    let (vk_set, panic_indices, height) = if dummy {
160        tracing::warn!("building a dummy vk map");
161        let dummy_set = SP1ProofShape::dummy_vk_map(
162            core_shape_config,
163            recursion_shape_config,
164            reduce_batch_size,
165        )
166        .into_keys()
167        .collect::<BTreeSet<_>>();
168        let height = dummy_set.len().next_power_of_two().ilog2() as usize;
169        (dummy_set, vec![], height)
170    } else {
171        tracing::debug!("building vk map");
172
173        // Setup the channels.
174        let (vk_tx, vk_rx) = std::sync::mpsc::channel();
175        let (shape_tx, shape_rx) =
176            std::sync::mpsc::sync_channel::<(usize, SP1CompressProgramShape)>(num_compiler_workers);
177        let (program_tx, program_rx) = std::sync::mpsc::sync_channel(num_setup_workers);
178        let (panic_tx, panic_rx) = std::sync::mpsc::channel();
179
180        // Setup the mutexes.
181        let shape_rx = Mutex::new(shape_rx);
182        let program_rx = Mutex::new(program_rx);
183
184        // Generate all the possible shape inputs we encounter in recursion. This may span lift,
185        // join, deferred, shrink, etc.
186        let indices_set = indices.map(|indices| indices.into_iter().collect::<HashSet<_>>());
187        let mut all_shapes = BTreeSet::new();
188        let start = std::time::Instant::now();
189        for shape in
190            SP1ProofShape::generate(core_shape_config, recursion_shape_config, reduce_batch_size)
191        {
192            all_shapes.insert(shape);
193        }
194
195        let num_shapes = all_shapes.len();
196        tracing::debug!("number of shapes: {} in {:?}", num_shapes, start.elapsed());
197
198        let height = num_shapes.next_power_of_two().ilog2() as usize;
199        let chunk_size = indices_set.as_ref().map(|indices| indices.len()).unwrap_or(num_shapes);
200
201        std::thread::scope(|s| {
202            // Initialize compiler workers.
203            for _ in 0..num_compiler_workers {
204                let program_tx = program_tx.clone();
205                let shape_rx = &shape_rx;
206                let prover = prover.clone();
207                let panic_tx = panic_tx.clone();
208                s.spawn(move || {
209                    while let Ok((i, shape)) = shape_rx.lock().unwrap().recv() {
210                        eprintln!("shape: {shape:?}");
211                        let is_shrink = matches!(shape, SP1CompressProgramShape::Shrink(_));
212                        let prover = prover.clone();
213                        let shape_clone = shape.clone();
214                        // Spawn on another thread to handle panics.
215                        let program_thread = std::thread::spawn(move || {
216                            prover.program_from_shape(shape_clone, None)
217                        });
218                        match program_thread.join() {
219                            Ok(program) => program_tx.send((i, program, is_shrink)).unwrap(),
220                            Err(e) => {
221                                tracing::warn!(
222                                    "Program generation failed for shape {} {:?}, with error: {:?}",
223                                    i,
224                                    shape,
225                                    e
226                                );
227                                panic_tx.send(i).unwrap();
228                            }
229                        }
230                    }
231                });
232            }
233
234            // Initialize setup workers.
235            for _ in 0..num_setup_workers {
236                let vk_tx = vk_tx.clone();
237                let program_rx = &program_rx;
238                let prover = &prover;
239                let panic_tx = panic_tx.clone();
240                s.spawn(move || {
241                    let mut done = 0;
242                    while let Ok((i, program, is_shrink)) = program_rx.lock().unwrap().recv() {
243                        let prover = prover.clone();
244                        let vk_thread = std::thread::spawn(move || {
245                            if is_shrink {
246                                prover.shrink_prover.setup(&program).1
247                            } else {
248                                prover.compress_prover.setup(&program).1
249                            }
250                        });
251                        let vk = tracing::debug_span!("setup for program {}", i)
252                            .in_scope(|| vk_thread.join());
253                        done += 1;
254
255                        if let Err(e) = vk {
256                            tracing::error!("failed to setup program {}: {:?}", i, e);
257                            panic_tx.send(i).unwrap();
258                            continue;
259                        }
260                        let vk = vk.unwrap();
261
262                        let vk_digest = vk.hash_babybear();
263                        tracing::debug!(
264                            "program {} = {:?}, {}% done",
265                            i,
266                            vk_digest,
267                            done * 100 / chunk_size
268                        );
269                        vk_tx.send(vk_digest).unwrap();
270                    }
271                });
272            }
273
274            // Generate shapes and send them to the compiler workers.
275            let subset_shapes = all_shapes
276                .into_iter()
277                .enumerate()
278                .filter(|(i, _)| indices_set.as_ref().map(|set| set.contains(i)).unwrap_or(true))
279                .collect::<Vec<_>>();
280
281            subset_shapes
282                .clone()
283                .into_iter()
284                .map(|(i, shape)| (i, SP1CompressProgramShape::from_proof_shape(shape, height)))
285                .for_each(|(i, program_shape)| {
286                    shape_tx.send((i, program_shape)).unwrap();
287                });
288
289            drop(shape_tx);
290            drop(program_tx);
291            drop(vk_tx);
292            drop(panic_tx);
293
294            let vk_set = vk_rx.iter().collect::<BTreeSet<_>>();
295
296            let panic_indices = panic_rx.iter().collect::<Vec<_>>();
297            for (i, shape) in subset_shapes {
298                if panic_indices.contains(&i) {
299                    tracing::debug!("panic shape {}: {:?}", i, shape);
300                }
301            }
302
303            (vk_set, panic_indices, height)
304        })
305    };
306    tracing::debug!("compress vks generated, number of keys: {}", vk_set.len());
307    (vk_set, panic_indices, height)
308}
309
310pub fn build_vk_map_to_file<C: SP1ProverComponents + 'static>(
311    build_dir: PathBuf,
312    reduce_batch_size: usize,
313    dummy: bool,
314    num_compiler_workers: usize,
315    num_setup_workers: usize,
316    range_start: Option<usize>,
317    range_end: Option<usize>,
318) -> Result<(), VkBuildError> {
319    // Create the build directory if it doesn't exist.
320    std::fs::create_dir_all(&build_dir)?;
321
322    // Build the vk map.
323    let (vk_set, _, _) = build_vk_map::<C>(
324        reduce_batch_size,
325        dummy,
326        num_compiler_workers,
327        num_setup_workers,
328        range_start.and_then(|start| range_end.map(|end| (start..end).collect())),
329    );
330
331    // Serialize the vk into an ordering.
332    let vk_map = vk_set.into_iter().enumerate().map(|(i, vk)| (vk, i)).collect::<BTreeMap<_, _>>();
333
334    // Create the file to store the vk map.
335    let mut file = if dummy {
336        File::create(build_dir.join("dummy_vk_map.bin"))?
337    } else {
338        File::create(build_dir.join("vk_map.bin"))?
339    };
340
341    Ok(bincode::serialize_into(&mut file, &vk_map)?)
342}
343
344impl SP1ProofShape {
345    pub fn generate<'a>(
346        core_shape_config: &'a CoreShapeConfig<BabyBear>,
347        recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
348        reduce_batch_size: usize,
349    ) -> impl Iterator<Item = Self> + 'a {
350        core_shape_config
351            .all_shapes()
352            .map(Self::Recursion)
353            .chain((1..=reduce_batch_size).flat_map(|batch_size| {
354                recursion_shape_config.get_all_shape_combinations(batch_size).map(Self::Compress)
355            }))
356            .chain(
357                recursion_shape_config
358                    .get_all_shape_combinations(1)
359                    .map(|mut x| Self::Deferred(x.pop().unwrap())),
360            )
361            .chain(
362                recursion_shape_config
363                    .get_all_shape_combinations(1)
364                    .map(|mut x| Self::Shrink(x.pop().unwrap())),
365            )
366    }
367
368    pub fn generate_compress_shapes(
369        recursion_shape_config: &'_ RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
370        reduce_batch_size: usize,
371    ) -> impl Iterator<Item = Vec<OrderedShape>> + '_ {
372        recursion_shape_config.get_all_shape_combinations(reduce_batch_size)
373    }
374
375    pub fn generate_maximal_shapes<'a>(
376        core_shape_config: &'a CoreShapeConfig<BabyBear>,
377        recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
378        reduce_batch_size: usize,
379        no_precompiles: bool,
380    ) -> impl Iterator<Item = Self> + 'a {
381        let core_shape_iter = if no_precompiles {
382            core_shape_config.maximal_core_shapes(21).into_iter()
383        } else {
384            core_shape_config.maximal_core_plus_precompile_shapes(21).into_iter()
385        };
386        core_shape_iter
387            .map(|core_shape| {
388                Self::Recursion(OrderedShape {
389                    inner: core_shape.into_iter().map(|(k, v)| (k.to_string(), v)).collect(),
390                })
391            })
392            .chain((1..=reduce_batch_size).flat_map(|batch_size| {
393                recursion_shape_config.get_all_shape_combinations(batch_size).map(Self::Compress)
394            }))
395            .chain(
396                recursion_shape_config
397                    .get_all_shape_combinations(1)
398                    .map(|mut x| Self::Deferred(x.pop().unwrap())),
399            )
400            .chain(
401                recursion_shape_config
402                    .get_all_shape_combinations(1)
403                    .map(|mut x| Self::Shrink(x.pop().unwrap())),
404            )
405    }
406
407    pub fn dummy_vk_map<'a>(
408        core_shape_config: &'a CoreShapeConfig<BabyBear>,
409        recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
410        reduce_batch_size: usize,
411    ) -> BTreeMap<[BabyBear; DIGEST_SIZE], usize> {
412        Self::generate(core_shape_config, recursion_shape_config, reduce_batch_size)
413            .enumerate()
414            .map(|(i, _)| ([BabyBear::from_canonical_usize(i); DIGEST_SIZE], i))
415            .collect()
416    }
417}
418
419impl SP1CompressProgramShape {
420    pub fn from_proof_shape(shape: SP1ProofShape, height: usize) -> Self {
421        match shape {
422            SP1ProofShape::Recursion(proof_shape) => Self::Recursion(proof_shape.into()),
423            SP1ProofShape::Deferred(proof_shape) => {
424                Self::Deferred(SP1DeferredShape::new(vec![proof_shape].into(), height))
425            }
426            SP1ProofShape::Compress(proof_shapes) => Self::Compress(SP1CompressWithVkeyShape {
427                compress_shape: proof_shapes.into(),
428                merkle_tree_height: height,
429            }),
430            SP1ProofShape::Shrink(proof_shape) => Self::Shrink(SP1CompressWithVkeyShape {
431                compress_shape: vec![proof_shape].into(),
432                merkle_tree_height: height,
433            }),
434        }
435    }
436}
437
438impl<C: SP1ProverComponents> SP1Prover<C> {
439    pub fn program_from_shape(
440        &self,
441        shape: SP1CompressProgramShape,
442        shrink_shape: Option<RecursionShape>,
443    ) -> Arc<RecursionProgram<BabyBear>> {
444        match shape {
445            SP1CompressProgramShape::Recursion(shape) => {
446                let input = SP1RecursionWitnessValues::dummy(self.core_prover.machine(), &shape);
447                self.recursion_program(&input)
448            }
449            SP1CompressProgramShape::Deferred(shape) => {
450                let input = SP1DeferredWitnessValues::dummy(self.compress_prover.machine(), &shape);
451                self.deferred_program(&input)
452            }
453            SP1CompressProgramShape::Compress(shape) => {
454                let input =
455                    SP1CompressWithVKeyWitnessValues::dummy(self.compress_prover.machine(), &shape);
456                self.compress_program(&input)
457            }
458            SP1CompressProgramShape::Shrink(shape) => {
459                let input =
460                    SP1CompressWithVKeyWitnessValues::dummy(self.compress_prover.machine(), &shape);
461                self.shrink_program(
462                    shrink_shape.unwrap_or_else(ShrinkAir::<BabyBear>::shrink_shape),
463                    &input,
464                )
465            }
466        }
467    }
468}
469
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// Collects every proof shape for the default configs with a reduce batch size of 2 and
    /// prints how many distinct shapes there are. Ignored by default.
    #[test]
    #[ignore]
    fn test_generate_all_shapes() {
        let (core_cfg, recursion_cfg) =
            (CoreShapeConfig::default(), RecursionShapeConfig::default());
        let distinct: BTreeSet<SP1ProofShape> =
            SP1ProofShape::generate(&core_cfg, &recursion_cfg, 2).collect();

        println!("Number of compress shapes: {}", distinct.len());
    }
}