1use std::{
2 collections::{BTreeMap, BTreeSet, HashSet},
3 fs::File,
4 hash::{DefaultHasher, Hash, Hasher},
5 panic::{catch_unwind, AssertUnwindSafe},
6 path::PathBuf,
7 sync::{Arc, Mutex},
8};
9
10use eyre::Result;
11use p3_baby_bear::BabyBear;
12use p3_field::AbstractField;
13use serde::{Deserialize, Serialize};
14use sp1_core_machine::shape::CoreShapeConfig;
15use sp1_recursion_circuit::machine::{
16 SP1CompressWithVKeyWitnessValues, SP1DeferredWitnessValues, SP1RecursionWitnessValues,
17};
18use sp1_recursion_core::{
19 shape::{RecursionShape, RecursionShapeConfig},
20 RecursionProgram,
21};
22use sp1_stark::{shape::OrderedShape, MachineProver, DIGEST_SIZE};
23use thiserror::Error;
24
25pub use sp1_recursion_circuit::machine::{
26 SP1CompressWithVkeyShape, SP1DeferredShape, SP1RecursionShape,
27};
28
29use crate::{components::SP1ProverComponents, CompressAir, HashableKey, SP1Prover, ShrinkAir};
30
/// The shape of a proof consumed by one of the recursion-layer programs.
///
/// Converted into the matching program shape by `SP1CompressProgramShape::from_proof_shape`.
// Variant order is significant: the derived `Ord` fixes the iteration order of the
// `BTreeSet<SP1ProofShape>` collections built in this module, and `build_vk_map` selects
// shapes by their position in that order. Index-based serde formats also encode variants by
// position.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub enum SP1ProofShape {
    /// A core shape from `CoreShapeConfig`; compiled into a recursion program.
    Recursion(OrderedShape),
    /// A combination of compress shapes as produced by
    /// `RecursionShapeConfig::get_all_shape_combinations`; compiled into a compress program.
    Compress(Vec<OrderedShape>),
    /// A single compress-config shape; compiled into the deferred program.
    Deferred(OrderedShape),
    /// A single compress-config shape; compiled into the shrink program.
    Shrink(OrderedShape),
}
38
/// The shape a recursion program is compiled for, as consumed by
/// `SP1Prover::program_from_shape`.
///
/// Built from an `SP1ProofShape` by `SP1CompressProgramShape::from_proof_shape`.
// Variant order and contents feed the derived `Hash` and therefore `hash_u64`; reordering the
// variants changes those hash values.
#[derive(Debug, Clone, Hash)]
pub enum SP1CompressProgramShape {
    /// Compiled with `recursion_program` from a dummy witness for the core machine.
    Recursion(SP1RecursionShape),
    /// Compiled with `compress_program`; carries the compress shapes and Merkle tree height.
    Compress(SP1CompressWithVkeyShape),
    /// Compiled with `deferred_program`.
    Deferred(SP1DeferredShape),
    /// Compiled with `shrink_program`; built from a one-element compress shape.
    Shrink(SP1CompressWithVkeyShape),
}
46
47impl SP1CompressProgramShape {
48 pub fn hash_u64(&self) -> u64 {
49 let mut hasher = DefaultHasher::new();
50 Hash::hash(&self, &mut hasher);
51 hasher.finish()
52 }
53}
54
55#[derive(Debug, Error)]
56pub enum VkBuildError {
57 #[error("IO error: {0}")]
58 IO(#[from] std::io::Error),
59 #[error("Serialization error: {0}")]
60 Bincode(#[from] bincode::Error),
61}
62
63pub fn check_shapes<C: SP1ProverComponents>(
64 reduce_batch_size: usize,
65 no_precompiles: bool,
66 num_compiler_workers: usize,
67 prover: &mut SP1Prover<C>,
68) -> bool {
69 let (shape_tx, shape_rx) =
70 std::sync::mpsc::sync_channel::<SP1CompressProgramShape>(num_compiler_workers);
71 let (panic_tx, panic_rx) = std::sync::mpsc::channel();
72 let core_shape_config = prover.core_shape_config.as_ref().expect("core shape config not found");
73 let recursion_shape_config =
74 prover.compress_shape_config.as_ref().expect("recursion shape config not found");
75
76 let shape_rx = Mutex::new(shape_rx);
77
78 let all_maximal_shapes = SP1ProofShape::generate_maximal_shapes(
79 core_shape_config,
80 recursion_shape_config,
81 reduce_batch_size,
82 no_precompiles,
83 )
84 .collect::<BTreeSet<SP1ProofShape>>();
85 let num_shapes = all_maximal_shapes.len();
86 tracing::debug!("number of shapes: {}", num_shapes);
87
88 let height = num_shapes.next_power_of_two().ilog2() as usize;
90
91 prover.join_programs_map.clear();
93
94 let compress_ok = std::thread::scope(|s| {
95 for _ in 0..num_compiler_workers {
97 let shape_rx = &shape_rx;
98 let prover = &prover;
99 let panic_tx = panic_tx.clone();
100 s.spawn(move || {
101 while let Ok(shape) = shape_rx.lock().unwrap().recv() {
102 tracing::debug!("shape is {:?}", shape);
103 let program = catch_unwind(AssertUnwindSafe(|| {
104 prover.program_from_shape(shape.clone(), None)
106 }));
107 match program {
108 Ok(_) => {}
109 Err(e) => {
110 tracing::warn!(
111 "Program generation failed for shape {:?}, with error: {:?}",
112 shape,
113 e
114 );
115 panic_tx.send(true).unwrap();
116 }
117 }
118 }
119 });
120 }
121
122 all_maximal_shapes.into_iter().for_each(|program_shape| {
124 shape_tx
125 .send(SP1CompressProgramShape::from_proof_shape(program_shape, height))
126 .unwrap();
127 });
128
129 drop(shape_tx);
130 drop(panic_tx);
131
132 panic_rx.iter().next().is_none()
134 });
135
136 compress_ok
137}
138
139pub fn build_vk_map<C: SP1ProverComponents + 'static>(
140 reduce_batch_size: usize,
141 dummy: bool,
142 num_compiler_workers: usize,
143 num_setup_workers: usize,
144 indices: Option<Vec<usize>>,
145) -> (BTreeSet<[BabyBear; DIGEST_SIZE]>, Vec<usize>, usize) {
146 let mut prover = SP1Prover::<C>::new();
148 prover.vk_verification = !dummy;
149 if !dummy {
150 prover.join_programs_map.clear();
151 }
152 let prover = Arc::new(prover);
153
154 let core_shape_config = prover.core_shape_config.as_ref().expect("core shape config not found");
156 let recursion_shape_config =
157 prover.compress_shape_config.as_ref().expect("recursion shape config not found");
158
159 let (vk_set, panic_indices, height) = if dummy {
160 tracing::warn!("building a dummy vk map");
161 let dummy_set = SP1ProofShape::dummy_vk_map(
162 core_shape_config,
163 recursion_shape_config,
164 reduce_batch_size,
165 )
166 .into_keys()
167 .collect::<BTreeSet<_>>();
168 let height = dummy_set.len().next_power_of_two().ilog2() as usize;
169 (dummy_set, vec![], height)
170 } else {
171 tracing::debug!("building vk map");
172
173 let (vk_tx, vk_rx) = std::sync::mpsc::channel();
175 let (shape_tx, shape_rx) =
176 std::sync::mpsc::sync_channel::<(usize, SP1CompressProgramShape)>(num_compiler_workers);
177 let (program_tx, program_rx) = std::sync::mpsc::sync_channel(num_setup_workers);
178 let (panic_tx, panic_rx) = std::sync::mpsc::channel();
179
180 let shape_rx = Mutex::new(shape_rx);
182 let program_rx = Mutex::new(program_rx);
183
184 let indices_set = indices.map(|indices| indices.into_iter().collect::<HashSet<_>>());
187 let mut all_shapes = BTreeSet::new();
188 let start = std::time::Instant::now();
189 for shape in
190 SP1ProofShape::generate(core_shape_config, recursion_shape_config, reduce_batch_size)
191 {
192 all_shapes.insert(shape);
193 }
194
195 let num_shapes = all_shapes.len();
196 tracing::debug!("number of shapes: {} in {:?}", num_shapes, start.elapsed());
197
198 let height = num_shapes.next_power_of_two().ilog2() as usize;
199 let chunk_size = indices_set.as_ref().map(|indices| indices.len()).unwrap_or(num_shapes);
200
201 std::thread::scope(|s| {
202 for _ in 0..num_compiler_workers {
204 let program_tx = program_tx.clone();
205 let shape_rx = &shape_rx;
206 let prover = prover.clone();
207 let panic_tx = panic_tx.clone();
208 s.spawn(move || {
209 while let Ok((i, shape)) = shape_rx.lock().unwrap().recv() {
210 eprintln!("shape: {shape:?}");
211 let is_shrink = matches!(shape, SP1CompressProgramShape::Shrink(_));
212 let prover = prover.clone();
213 let shape_clone = shape.clone();
214 let program_thread = std::thread::spawn(move || {
216 prover.program_from_shape(shape_clone, None)
217 });
218 match program_thread.join() {
219 Ok(program) => program_tx.send((i, program, is_shrink)).unwrap(),
220 Err(e) => {
221 tracing::warn!(
222 "Program generation failed for shape {} {:?}, with error: {:?}",
223 i,
224 shape,
225 e
226 );
227 panic_tx.send(i).unwrap();
228 }
229 }
230 }
231 });
232 }
233
234 for _ in 0..num_setup_workers {
236 let vk_tx = vk_tx.clone();
237 let program_rx = &program_rx;
238 let prover = &prover;
239 let panic_tx = panic_tx.clone();
240 s.spawn(move || {
241 let mut done = 0;
242 while let Ok((i, program, is_shrink)) = program_rx.lock().unwrap().recv() {
243 let prover = prover.clone();
244 let vk_thread = std::thread::spawn(move || {
245 if is_shrink {
246 prover.shrink_prover.setup(&program).1
247 } else {
248 prover.compress_prover.setup(&program).1
249 }
250 });
251 let vk = tracing::debug_span!("setup for program {}", i)
252 .in_scope(|| vk_thread.join());
253 done += 1;
254
255 if let Err(e) = vk {
256 tracing::error!("failed to setup program {}: {:?}", i, e);
257 panic_tx.send(i).unwrap();
258 continue;
259 }
260 let vk = vk.unwrap();
261
262 let vk_digest = vk.hash_babybear();
263 tracing::debug!(
264 "program {} = {:?}, {}% done",
265 i,
266 vk_digest,
267 done * 100 / chunk_size
268 );
269 vk_tx.send(vk_digest).unwrap();
270 }
271 });
272 }
273
274 let subset_shapes = all_shapes
276 .into_iter()
277 .enumerate()
278 .filter(|(i, _)| indices_set.as_ref().map(|set| set.contains(i)).unwrap_or(true))
279 .collect::<Vec<_>>();
280
281 subset_shapes
282 .clone()
283 .into_iter()
284 .map(|(i, shape)| (i, SP1CompressProgramShape::from_proof_shape(shape, height)))
285 .for_each(|(i, program_shape)| {
286 shape_tx.send((i, program_shape)).unwrap();
287 });
288
289 drop(shape_tx);
290 drop(program_tx);
291 drop(vk_tx);
292 drop(panic_tx);
293
294 let vk_set = vk_rx.iter().collect::<BTreeSet<_>>();
295
296 let panic_indices = panic_rx.iter().collect::<Vec<_>>();
297 for (i, shape) in subset_shapes {
298 if panic_indices.contains(&i) {
299 tracing::debug!("panic shape {}: {:?}", i, shape);
300 }
301 }
302
303 (vk_set, panic_indices, height)
304 })
305 };
306 tracing::debug!("compress vks generated, number of keys: {}", vk_set.len());
307 (vk_set, panic_indices, height)
308}
309
310pub fn build_vk_map_to_file<C: SP1ProverComponents + 'static>(
311 build_dir: PathBuf,
312 reduce_batch_size: usize,
313 dummy: bool,
314 num_compiler_workers: usize,
315 num_setup_workers: usize,
316 range_start: Option<usize>,
317 range_end: Option<usize>,
318) -> Result<(), VkBuildError> {
319 std::fs::create_dir_all(&build_dir)?;
321
322 let (vk_set, _, _) = build_vk_map::<C>(
324 reduce_batch_size,
325 dummy,
326 num_compiler_workers,
327 num_setup_workers,
328 range_start.and_then(|start| range_end.map(|end| (start..end).collect())),
329 );
330
331 let vk_map = vk_set.into_iter().enumerate().map(|(i, vk)| (vk, i)).collect::<BTreeMap<_, _>>();
333
334 let mut file = if dummy {
336 File::create(build_dir.join("dummy_vk_map.bin"))?
337 } else {
338 File::create(build_dir.join("vk_map.bin"))?
339 };
340
341 Ok(bincode::serialize_into(&mut file, &vk_map)?)
342}
343
344impl SP1ProofShape {
345 pub fn generate<'a>(
346 core_shape_config: &'a CoreShapeConfig<BabyBear>,
347 recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
348 reduce_batch_size: usize,
349 ) -> impl Iterator<Item = Self> + 'a {
350 core_shape_config
351 .all_shapes()
352 .map(Self::Recursion)
353 .chain((1..=reduce_batch_size).flat_map(|batch_size| {
354 recursion_shape_config.get_all_shape_combinations(batch_size).map(Self::Compress)
355 }))
356 .chain(
357 recursion_shape_config
358 .get_all_shape_combinations(1)
359 .map(|mut x| Self::Deferred(x.pop().unwrap())),
360 )
361 .chain(
362 recursion_shape_config
363 .get_all_shape_combinations(1)
364 .map(|mut x| Self::Shrink(x.pop().unwrap())),
365 )
366 }
367
368 pub fn generate_compress_shapes(
369 recursion_shape_config: &'_ RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
370 reduce_batch_size: usize,
371 ) -> impl Iterator<Item = Vec<OrderedShape>> + '_ {
372 recursion_shape_config.get_all_shape_combinations(reduce_batch_size)
373 }
374
375 pub fn generate_maximal_shapes<'a>(
376 core_shape_config: &'a CoreShapeConfig<BabyBear>,
377 recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
378 reduce_batch_size: usize,
379 no_precompiles: bool,
380 ) -> impl Iterator<Item = Self> + 'a {
381 let core_shape_iter = if no_precompiles {
382 core_shape_config.maximal_core_shapes(21).into_iter()
383 } else {
384 core_shape_config.maximal_core_plus_precompile_shapes(21).into_iter()
385 };
386 core_shape_iter
387 .map(|core_shape| {
388 Self::Recursion(OrderedShape {
389 inner: core_shape.into_iter().map(|(k, v)| (k.to_string(), v)).collect(),
390 })
391 })
392 .chain((1..=reduce_batch_size).flat_map(|batch_size| {
393 recursion_shape_config.get_all_shape_combinations(batch_size).map(Self::Compress)
394 }))
395 .chain(
396 recursion_shape_config
397 .get_all_shape_combinations(1)
398 .map(|mut x| Self::Deferred(x.pop().unwrap())),
399 )
400 .chain(
401 recursion_shape_config
402 .get_all_shape_combinations(1)
403 .map(|mut x| Self::Shrink(x.pop().unwrap())),
404 )
405 }
406
407 pub fn dummy_vk_map<'a>(
408 core_shape_config: &'a CoreShapeConfig<BabyBear>,
409 recursion_shape_config: &'a RecursionShapeConfig<BabyBear, CompressAir<BabyBear>>,
410 reduce_batch_size: usize,
411 ) -> BTreeMap<[BabyBear; DIGEST_SIZE], usize> {
412 Self::generate(core_shape_config, recursion_shape_config, reduce_batch_size)
413 .enumerate()
414 .map(|(i, _)| ([BabyBear::from_canonical_usize(i); DIGEST_SIZE], i))
415 .collect()
416 }
417}
418
419impl SP1CompressProgramShape {
420 pub fn from_proof_shape(shape: SP1ProofShape, height: usize) -> Self {
421 match shape {
422 SP1ProofShape::Recursion(proof_shape) => Self::Recursion(proof_shape.into()),
423 SP1ProofShape::Deferred(proof_shape) => {
424 Self::Deferred(SP1DeferredShape::new(vec![proof_shape].into(), height))
425 }
426 SP1ProofShape::Compress(proof_shapes) => Self::Compress(SP1CompressWithVkeyShape {
427 compress_shape: proof_shapes.into(),
428 merkle_tree_height: height,
429 }),
430 SP1ProofShape::Shrink(proof_shape) => Self::Shrink(SP1CompressWithVkeyShape {
431 compress_shape: vec![proof_shape].into(),
432 merkle_tree_height: height,
433 }),
434 }
435 }
436}
437
438impl<C: SP1ProverComponents> SP1Prover<C> {
439 pub fn program_from_shape(
440 &self,
441 shape: SP1CompressProgramShape,
442 shrink_shape: Option<RecursionShape>,
443 ) -> Arc<RecursionProgram<BabyBear>> {
444 match shape {
445 SP1CompressProgramShape::Recursion(shape) => {
446 let input = SP1RecursionWitnessValues::dummy(self.core_prover.machine(), &shape);
447 self.recursion_program(&input)
448 }
449 SP1CompressProgramShape::Deferred(shape) => {
450 let input = SP1DeferredWitnessValues::dummy(self.compress_prover.machine(), &shape);
451 self.deferred_program(&input)
452 }
453 SP1CompressProgramShape::Compress(shape) => {
454 let input =
455 SP1CompressWithVKeyWitnessValues::dummy(self.compress_prover.machine(), &shape);
456 self.compress_program(&input)
457 }
458 SP1CompressProgramShape::Shrink(shape) => {
459 let input =
460 SP1CompressWithVKeyWitnessValues::dummy(self.compress_prover.machine(), &shape);
461 self.shrink_program(
462 shrink_shape.unwrap_or_else(ShrinkAir::<BabyBear>::shrink_shape),
463 &input,
464 )
465 }
466 }
467 }
468}
469
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;

    /// Enumerates every proof shape for the default configs with a reduce batch size of 2 and
    /// prints how many distinct shapes there are. Ignored by default.
    #[test]
    #[ignore]
    fn test_generate_all_shapes() {
        let core_config = CoreShapeConfig::default();
        let compress_config = RecursionShapeConfig::default();
        let reduce_batch_size = 2;

        let mut distinct = BTreeSet::new();
        distinct.extend(SP1ProofShape::generate(&core_config, &compress_config, reduce_batch_size));

        println!("Number of compress shapes: {}", distinct.len());
    }
}