1use std::{
2 collections::{BTreeMap, BTreeSet},
3 fmt::Debug,
4 num::NonZero,
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc, Mutex,
8 },
9};
10
11use hashbrown::HashSet;
12use lru::LruCache;
13use serde::{Deserialize, Serialize};
14use slop_air::BaseAir;
15use slop_algebra::AbstractField;
16use slop_basefold::FriConfig;
17use sp1_core_executor::MAX_PROGRAM_SIZE;
18use sp1_core_machine::{
19 bytes::columns::NUM_BYTE_PREPROCESSED_COLS, program::NUM_PROGRAM_PREPROCESSED_COLS,
20 range::columns::NUM_RANGE_PREPROCESSED_COLS, riscv::RiscvAir,
21};
22use sp1_hypercube::{
23 air::MachineAir,
24 log2_ceil_usize,
25 prover::{CoreProofShape, DefaultTraceGenerator, ProverSemaphore, TraceGenerator},
26 Chip, HashableKey, Machine, MachineShape, SP1PcsProofInner, SP1VerifyingKey,
27};
28use sp1_primitives::{
29 fri_params::{core_fri_config, CORE_LOG_BLOWUP},
30 SP1Field, SP1GlobalContext,
31};
32use sp1_prover_types::ArtifactClient;
33use sp1_recursion_circuit::{
34 dummy::{dummy_shard_proof, dummy_vk},
35 machine::{
36 SP1CompressWithVKeyWitnessValues, SP1MerkleProofWitnessValues, SP1NormalizeWitnessValues,
37 SP1ShapedWitnessValues,
38 },
39};
40use sp1_recursion_compiler::config::InnerConfig;
41use sp1_recursion_executor::{
42 shape::RecursionShape, RecursionAirEventCount, RecursionProgram, DIGEST_SIZE,
43};
44use sp1_recursion_machine::chips::{
45 alu_base::BaseAluChip,
46 alu_ext::ExtAluChip,
47 mem::{MemoryConstChip, MemoryVarChip},
48 poseidon2_helper::{
49 convert::ConvertChip, linear::Poseidon2LinearLayerChip, sbox::Poseidon2SBoxChip,
50 },
51 poseidon2_wide::Poseidon2WideChip,
52 prefix_sum_checks::PrefixSumChecksChip,
53 public_values::PublicValuesChip,
54 select::SelectChip,
55};
56use sp1_verifier::compressed::RECURSION_MAX_LOG_ROW_COUNT;
57use thiserror::Error;
58use tokio::task::JoinSet;
59
60use crate::{
61 components::{SP1ProverComponents, CORE_LOG_STACKING_HEIGHT},
62 recursion::{
63 compose_program_from_input, deferred_program_from_input, dummy_compose_input,
64 dummy_deferred_input, normalize_program_from_input, recursive_verifier,
65 shrink_program_from_input,
66 },
67 worker::{AirProverWorker, RecursionVkWorker},
68 CompressAir, CORE_MAX_LOG_ROW_COUNT,
69};
70
/// Default arity of the compress (compose) programs.
///
/// This is the only arity for which
/// [`SP1RecursionProofShape::compress_proof_shape_from_arity`] returns a shape, and the
/// arity used by [`dummy_vk_map`] when enumerating shapes.
pub const DEFAULT_ARITY: usize = 4;
72
/// Shape of the input to a normalize program: the shapes of the core shard proofs it consumes
/// plus the parameters used when building dummy proofs of those shapes.
///
/// Used as the key of [`SP1NormalizeCache`].
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct SP1NormalizeInputShape {
    /// One shape per core shard proof; `dummy_input` builds one dummy shard proof for each.
    pub proof_shapes: Vec<CoreProofShape<SP1Field, RiscvAir<SP1Field>>>,
    /// Maximum log row count forwarded to `dummy_shard_proof`.
    pub max_log_row_count: usize,
    /// Log blowup associated with this shape. NOTE(review): not read by `dummy_input`, which
    /// takes its FRI parameters from `core_fri_config()`.
    pub log_blowup: usize,
    /// Log2 of the stacking height; the proof areas are shifted right by this amount when
    /// building dummy proofs.
    pub log_stacking_height: usize,
}
82
/// Identifies one recursion program whose verifying key is built by `build_vk_map`.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug)]
pub enum SP1RecursionProgramShape {
    /// A normalize program for a single core shard proof of the given shape.
    Normalize(CoreProofShape<SP1Field, RiscvAir<SP1Field>>),
    /// A compose program with the given arity.
    Compose(usize),
    /// The deferred program.
    Deferred,
    /// The shrink program.
    Shrink,
}
94
/// The executor's element threshold plus one extra block of `2^CORE_LOG_STACKING_HEIGHT`
/// elements.
///
/// NOTE(review): the extra block presumably accounts for padding the trace area up to a
/// multiple of the stacking height — confirm against the executor.
const PADDED_ELEMENT_THRESHOLD: u64 =
    sp1_core_executor::ELEMENT_THRESHOLD + (1 << CORE_LOG_STACKING_HEIGHT);
97
/// Errors that wrap I/O and bincode failures for verifying-key building.
#[derive(Debug, Error)]
pub enum VkBuildError {
    /// An underlying I/O operation failed.
    #[error("IO error: {0}")]
    IO(#[from] std::io::Error),
    /// Bincode (de)serialization failed.
    #[error("Serialization error: {0}")]
    Bincode(#[from] bincode::Error),
}
105
106impl SP1NormalizeInputShape {
107 pub fn dummy_input(
108 &self,
109 vk: SP1VerifyingKey,
110 ) -> SP1NormalizeWitnessValues<SP1GlobalContext, SP1PcsProofInner> {
111 let shard_proofs = self
112 .proof_shapes
113 .iter()
114 .map(|core_shape| {
115 dummy_shard_proof(
116 core_shape.shard_chips.clone(),
117 self.max_log_row_count,
118 core_fri_config(),
119 self.log_stacking_height,
120 &[
121 core_shape.preprocessed_area >> self.log_stacking_height,
122 core_shape.main_area >> self.log_stacking_height,
123 ],
124 &[core_shape.preprocessed_padding_cols, core_shape.main_padding_cols],
125 )
126 })
127 .collect::<Vec<_>>();
128
129 SP1NormalizeWitnessValues {
130 vk: vk.vk,
131 shard_proofs,
132 is_complete: false,
133 vk_root: [SP1Field::zero(); DIGEST_SIZE],
134 reconstruct_deferred_digest: [SP1Field::zero(); 8],
135 num_deferred_proofs: SP1Field::zero(),
136 }
137 }
138}
139
/// An LRU cache of normalize programs keyed by their input shape, with hit statistics.
pub struct SP1NormalizeCache {
    /// The mutex-protected LRU map from input shape to program.
    lru: Arc<Mutex<LruCache<SP1NormalizeInputShape, Arc<RecursionProgram<SP1Field>>>>>,
    /// Number of `get` calls made so far.
    total_calls: AtomicUsize,
    /// Number of `get` calls that found a cached program.
    hits: AtomicUsize,
}
145
146impl SP1NormalizeCache {
147 pub fn new(size: usize) -> Self {
148 let size = NonZero::new(size).expect("size must be non-zero");
149 let lru = LruCache::new(size);
150 let lru = Arc::new(Mutex::new(lru));
151 Self { lru, total_calls: AtomicUsize::new(0), hits: AtomicUsize::new(0) }
152 }
153
154 pub fn get(&self, shape: &SP1NormalizeInputShape) -> Option<Arc<RecursionProgram<SP1Field>>> {
155 self.total_calls.fetch_add(1, Ordering::Relaxed);
156 if let Some(program) = self.lru.lock().unwrap().get(shape).cloned() {
157 self.hits.fetch_add(1, Ordering::Relaxed);
158 Some(program)
159 } else {
160 None
161 }
162 }
163
164 pub fn push(&self, shape: SP1NormalizeInputShape, program: Arc<RecursionProgram<SP1Field>>) {
165 self.lru.lock().unwrap().push(shape, program);
166 }
167
168 pub fn stats(&self) -> (usize, usize, f64) {
169 (
170 self.total_calls.load(Ordering::Relaxed),
171 self.hits.load(Ordering::Relaxed),
172 self.hits.load(Ordering::Relaxed) as f64
173 / self.total_calls.load(Ordering::Relaxed) as f64,
174 )
175 }
176}
177
/// The shape of a recursion proof: a thin, serializable wrapper around the
/// [`RecursionShape`] that stores the height of each recursion chip.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SP1RecursionProofShape {
    /// Per-chip heights of the recursion machine.
    pub shape: RecursionShape<SP1Field>,
}
182
183impl Default for SP1RecursionProofShape {
184 fn default() -> Self {
185 Self::compress_proof_shape_from_arity(DEFAULT_ARITY).unwrap()
186 }
187}
188
189impl SP1RecursionProofShape {
190 pub fn compress_proof_shape_from_arity(arity: usize) -> Option<Self> {
191 match arity {
192 DEFAULT_ARITY => {
193 let file = include_bytes!("../compress_shape.json");
194 serde_json::from_slice(file).ok().or_else(|| {
195 tracing::warn!("Failed to load compress_shape.json, using default shape.");
196 Some(SP1RecursionProofShape {
199 shape: [
200 (
201 CompressAir::<SP1Field>::MemoryConst(MemoryConstChip::default()),
202 600_000usize.next_multiple_of(32),
203 ),
204 (
205 CompressAir::<SP1Field>::MemoryVar(MemoryVarChip::default()),
206 500_000usize.next_multiple_of(32),
207 ),
208 (
209 CompressAir::<SP1Field>::BaseAlu(BaseAluChip),
210 500_000usize.next_multiple_of(32),
211 ),
212 (
213 CompressAir::<SP1Field>::ExtAlu(ExtAluChip),
214 850_000usize.next_multiple_of(32),
215 ),
216 (
217 CompressAir::<SP1Field>::Poseidon2Wide(Poseidon2WideChip),
218 150_448usize.next_multiple_of(32),
219 ),
220 (
221 CompressAir::<SP1Field>::PrefixSumChecks(PrefixSumChecksChip),
222 275_440usize.next_multiple_of(32),
223 ),
224 (
225 CompressAir::<SP1Field>::Select(SelectChip),
226 800_000usize.next_multiple_of(32),
227 ),
228 (CompressAir::<SP1Field>::PublicValues(PublicValuesChip), 16usize),
229 ]
230 .into_iter()
231 .collect(),
232 })
233 })
234 }
235 _ => None,
236 }
237 }
238
239 pub fn dummy_input(
240 &self,
241 arity: usize,
242 height: usize,
243 chips: BTreeSet<Chip<SP1Field, CompressAir<SP1Field>>>,
244 max_log_row_count: usize,
245 fri_config: FriConfig<SP1Field>,
246 log_stacking_height: usize,
247 ) -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
248 let dummy_vk = dummy_vk();
249
250 let preprocessed_multiple = chips
251 .iter()
252 .map(|chip| self.shape.height(chip).unwrap() * chip.preprocessed_width())
253 .sum::<usize>()
254 .div_ceil(1 << log_stacking_height);
255
256 let main_multiple = chips
257 .iter()
258 .map(|chip| self.shape.height(chip).unwrap() * chip.width())
259 .sum::<usize>()
260 .div_ceil(1 << log_stacking_height);
261
262 let preprocessed_padding_cols = ((preprocessed_multiple * (1 << log_stacking_height))
263 - chips
264 .iter()
265 .map(|chip| self.shape.height(chip).unwrap() * chip.preprocessed_width())
266 .sum::<usize>())
267 .div_ceil(1 << max_log_row_count)
268 .max(1);
269
270 let main_padding_cols = ((main_multiple * (1 << log_stacking_height))
271 - chips
272 .iter()
273 .map(|chip| self.shape.height(chip).unwrap() * chip.width())
274 .sum::<usize>())
275 .div_ceil(1 << max_log_row_count)
276 .max(1);
277
278 let dummy_proof = dummy_shard_proof(
279 chips,
280 max_log_row_count,
281 fri_config,
282 log_stacking_height,
283 &[preprocessed_multiple, main_multiple],
284 &[preprocessed_padding_cols, main_padding_cols],
285 );
286
287 let vks_and_proofs =
288 (0..arity).map(|_| (dummy_vk.clone(), dummy_proof.clone())).collect::<Vec<_>>();
289
290 SP1CompressWithVKeyWitnessValues {
291 compress_val: SP1ShapedWitnessValues { vks_and_proofs, is_complete: false },
292 merkle_val: SP1MerkleProofWitnessValues::dummy(arity, height),
293 }
294 }
295
296 pub async fn check_compatibility(
297 &self,
298 program: Arc<RecursionProgram<SP1Field>>,
299 machine: Machine<SP1Field, CompressAir<SP1Field>>,
300 ) -> bool {
301 let trace_generator = DefaultTraceGenerator::new(machine);
303 let setup_permits = ProverSemaphore::new(1);
304 let preprocessed_traces = trace_generator
305 .generate_preprocessed_traces(program, RECURSION_MAX_LOG_ROW_COUNT, setup_permits)
306 .await;
307
308 let mut is_compatible = true;
309 for (chip, trace) in preprocessed_traces.preprocessed_traces.into_iter() {
310 let real_height = trace.num_real_entries();
311 let expected_height = self.shape.height_of_name(&chip).unwrap();
312 if real_height > expected_height {
313 tracing::warn!(
314 "program is incompatible with shape: {} > {} for chip {}",
315 real_height,
316 expected_height,
317 chip
318 );
319 is_compatible = false;
320 }
321 }
322 is_compatible
323 }
324
325 #[allow(dead_code)]
326 async fn max_arity<C: SP1ProverComponents>(
327 &self,
328 vk_verification: bool,
329 height: usize,
330 ) -> usize {
331 let mut arity = 0;
332 let compress_verifier = C::compress_verifier();
333 let recursive_compress_verifier =
334 recursive_verifier::<_, _, InnerConfig>(compress_verifier.shard_verifier());
335 for possible_arity in 1.. {
336 let input = dummy_compose_input(&compress_verifier, self, possible_arity, height);
337 let program =
338 compose_program_from_input(&recursive_compress_verifier, vk_verification, &input);
339 let program = Arc::new(program);
340 let is_compatible =
341 self.check_compatibility(program, compress_verifier.machine().clone()).await;
342 if !is_compatible {
343 break;
344 }
345 arity = possible_arity;
346 }
347 arity
348 }
349}
350
351pub async fn build_vk_map<A: ArtifactClient, C: SP1ProverComponents + 'static>(
352 dummy: bool,
353 num_compiler_workers: usize,
354 num_setup_workers: usize,
355 indices: Option<Vec<usize>>,
356 max_arity: usize,
357 merkle_tree_height: usize,
358 vk_worker: Arc<RecursionVkWorker<C>>,
359) -> (BTreeSet<[SP1Field; DIGEST_SIZE]>, Vec<usize>) {
360 let recursion_permits = vk_worker.recursion_permits.clone();
361 let recursion_prover = vk_worker.recursion_prover.clone();
362 let shrink_prover = vk_worker.shrink_prover.clone();
363 if dummy {
364 let dummy_set = dummy_vk_map::<C>().into_keys().collect();
365 return (dummy_set, vec![]);
366 }
367
368 let (vk_tx, mut vk_rx) =
370 tokio::sync::mpsc::unbounded_channel::<(usize, [SP1Field; DIGEST_SIZE])>();
371 let (shape_tx, shape_rx) =
372 tokio::sync::mpsc::channel::<(usize, SP1RecursionProgramShape)>(num_compiler_workers);
373 let (program_tx, program_rx) = tokio::sync::mpsc::channel(num_setup_workers);
374 let (panic_tx, mut panic_rx) = tokio::sync::mpsc::unbounded_channel();
375
376 let shape_rx = Arc::new(tokio::sync::Mutex::new(shape_rx));
378 let program_rx = Arc::new(tokio::sync::Mutex::new(program_rx));
379
380 let all_shapes = create_all_input_shapes(C::core_verifier().machine().shape(), max_arity);
383
384 let num_shapes = all_shapes.len();
385
386 let height = if indices.is_some() { merkle_tree_height } else { log2_ceil_usize(num_shapes) };
387
388 let indices_set = indices.map(|indices| indices.into_iter().collect::<HashSet<_>>());
389
390 let vk_map_size = indices_set.as_ref().map(|indices| indices.len()).unwrap_or(num_shapes);
391
392 let mut set = JoinSet::new();
393
394 for _ in 0..num_compiler_workers {
396 let program_tx = program_tx.clone();
397 let shape_rx = shape_rx.clone();
398 let panic_tx = panic_tx.clone();
399 set.spawn(async move {
400 while let Some((i, shape)) = shape_rx.lock().await.recv().await {
401 let compress_verifier = C::compress_verifier();
403 let recursive_compress_verifier =
404 recursive_verifier::<_, _, InnerConfig>(compress_verifier.shard_verifier());
405 let program_thread = tokio::spawn(async move {
407 let reduce_shape =
408 SP1RecursionProofShape::compress_proof_shape_from_arity(max_arity);
409 match shape {
410 SP1RecursionProgramShape::Normalize(shape_clone) => {
411 let normalize_shape = SP1NormalizeInputShape {
412 proof_shapes: vec![shape_clone],
413 max_log_row_count: CORE_MAX_LOG_ROW_COUNT,
414 log_blowup: CORE_LOG_BLOWUP,
415 log_stacking_height: CORE_LOG_STACKING_HEIGHT as usize,
416 };
417 let dummy_vk = dummy_vk();
418 let core_verifier = C::core_verifier();
419 let recursive_core_verifier = recursive_verifier::<_, _, InnerConfig>(
420 core_verifier.shard_verifier(),
421 );
422 let witness =
423 normalize_shape.dummy_input(SP1VerifyingKey { vk: dummy_vk });
424 let mut program =
425 normalize_program_from_input(&recursive_core_verifier, &witness);
426 program.shape =
427 Some(reduce_shape.clone().expect("max arity not supported").shape);
428 (Arc::new(program), false)
429 }
430 SP1RecursionProgramShape::Compose(arity) => {
431 let dummy_input = dummy_compose_input(
432 &compress_verifier,
433 &SP1RecursionProofShape::compress_proof_shape_from_arity(max_arity)
434 .expect("max arity not supported"),
435 arity,
436 height,
437 );
438
439 let mut program = compose_program_from_input(
440 &recursive_compress_verifier,
441 true,
442 &dummy_input,
443 );
444 program.shape =
445 Some(reduce_shape.clone().expect("max arity not supported").shape);
446 (Arc::new(program), false)
447 }
448 SP1RecursionProgramShape::Deferred => {
449 let dummy_input = dummy_deferred_input(
450 &C::compress_verifier(),
451 &reduce_shape.clone().expect("max arity not supported"),
452 height,
453 );
454 let mut program = deferred_program_from_input(
455 &recursive_compress_verifier,
456 true,
457 &dummy_input,
458 );
459
460 program.shape =
461 Some(reduce_shape.clone().expect("max arity not supported").shape);
462
463 (Arc::new(program), false)
464 }
465 SP1RecursionProgramShape::Shrink => {
466 let dummy_input = dummy_compose_input(
467 &C::compress_verifier(),
468 &reduce_shape.clone().expect("max arity not supported"),
469 1,
470 height,
471 );
472 let program = shrink_program_from_input(
473 &recursive_compress_verifier,
474 true,
475 &dummy_input,
476 );
477
478 (Arc::new(program), true)
479 }
480 }
481 });
482 match program_thread.await {
483 Ok((program, is_shrink)) => {
484 program_tx.send((i, program, is_shrink)).await.unwrap()
485 }
486 Err(e) => {
487 tracing::warn!(
488 "Program generation failed for shape {}, with error: {:?}",
489 i,
490 e
491 );
492 panic_tx.send(i).unwrap();
493 }
494 }
495 }
496 });
497 }
498
499 let recursion_prover = recursion_prover.clone();
500 for _ in 0..num_setup_workers {
502 let vk_tx = vk_tx.clone();
503 let program_rx = program_rx.clone();
504 let prover = recursion_prover.clone();
505 let recursion_permits = recursion_permits.clone();
506 let shrink_prover = shrink_prover.clone();
507 set.spawn(async move {
508 let mut done = 0;
509 while let Some((i, program, is_shrink)) = program_rx.lock().await.recv().await {
510 let prover = prover.clone();
511 let shrink_prover = shrink_prover.clone();
512 let recursion_permits = recursion_permits.clone();
513 let vk_thread = tokio::spawn(async move {
514 if is_shrink {
515 shrink_prover.setup(program).await
516 } else {
517 prover.setup(program, recursion_permits.clone()).await.1
518 }
519 });
520 let vk = vk_thread.await.unwrap();
521 done += 1;
522
523 let vk_digest = vk.hash_koalabear();
524
525 tracing::info!(
526 "program {} = {:?}, {}% done",
527 i,
528 vk_digest,
529 done * 100 / vk_map_size
530 );
531 vk_tx.send((i, vk_digest)).unwrap();
532 }
533 });
534 }
535
536 let subset_shapes = all_shapes
538 .into_iter()
539 .enumerate()
540 .filter(|(i, _)| indices_set.as_ref().map(|set| set.contains(i)).unwrap_or(true))
541 .collect::<Vec<_>>();
542
543 for (i, shape) in subset_shapes.iter() {
544 shape_tx.send((*i, shape.clone())).await.unwrap();
545 }
546
547 drop(shape_tx);
548 drop(program_tx);
549 drop(vk_tx);
550 drop(panic_tx);
551
552 set.join_all().await;
553
554 let mut vk_map = BTreeMap::new();
555 while let Some((i, vk)) = vk_rx.recv().await {
556 vk_map.insert(i, vk);
557 }
558
559 let mut panic_indices = vec![];
560 while let Some(i) = panic_rx.recv().await {
561 panic_indices.push(i);
562 }
563 for (i, shape) in subset_shapes {
564 if panic_indices.contains(&i) {
565 tracing::info!("panic shape {}: {:?}", i, shape);
566 }
567 }
568
569 let vk_set: BTreeSet<[SP1Field; DIGEST_SIZE]> = vk_map.into_values().collect();
571
572 (vk_set, panic_indices)
573}
574
575fn max_main_multiple_for_preprocessed_multiple(preprocessed_multiple: usize) -> usize {
576 (PADDED_ELEMENT_THRESHOLD - preprocessed_multiple as u64 * (1 << CORE_LOG_STACKING_HEIGHT))
577 .div_ceil(1 << CORE_LOG_STACKING_HEIGHT as u64) as usize
578}
579
580pub fn create_all_input_shapes(
581 core_shape: &MachineShape<SP1Field, RiscvAir<SP1Field>>,
582 max_arity: usize,
583) -> Vec<SP1RecursionProgramShape> {
584 let (max_preprocessed_multiple, _, capacity) = normalize_program_parameter_space();
585 let max_num_padding_cols =
586 ((1 << CORE_LOG_STACKING_HEIGHT) as usize).div_ceil(1 << CORE_MAX_LOG_ROW_COUNT);
587
588 let mut result: Vec<SP1RecursionProgramShape> = Vec::with_capacity(capacity);
589 for preprocessed_multiple in 1..=max_preprocessed_multiple {
590 for main_multiple in 1..=max_main_multiple_for_preprocessed_multiple(preprocessed_multiple)
591 {
592 for main_padding_cols in 1..=max_num_padding_cols {
593 for preprocessed_padding_cols in 1..=max_num_padding_cols {
594 for cluster in &core_shape.chip_clusters {
595 result.push(SP1RecursionProgramShape::Normalize(CoreProofShape {
596 shard_chips: cluster.clone(),
597 preprocessed_area: preprocessed_multiple << CORE_LOG_STACKING_HEIGHT,
598 main_area: main_multiple << CORE_LOG_STACKING_HEIGHT,
599 preprocessed_padding_cols,
600 main_padding_cols,
601 }));
602 }
603 }
604 }
605 }
606 }
607
608 for arity in 1..=max_arity {
610 result.push(SP1RecursionProgramShape::Compose(arity));
611 }
612
613 result.push(SP1RecursionProgramShape::Deferred);
615 result.push(SP1RecursionProgramShape::Shrink);
617 result
618}
619
620pub fn normalize_program_parameter_space() -> (usize, usize, usize) {
621 let max_preprocessed_multiple = (MAX_PROGRAM_SIZE * NUM_PROGRAM_PREPROCESSED_COLS
622 + (1 << 17) * NUM_RANGE_PREPROCESSED_COLS
623 + (1 << 16) * NUM_BYTE_PREPROCESSED_COLS)
624 .div_ceil(1 << CORE_LOG_STACKING_HEIGHT);
625 let max_main_multiple =
626 (PADDED_ELEMENT_THRESHOLD).div_ceil(1 << CORE_LOG_STACKING_HEIGHT) as usize;
627
628 let num_shapes = (0..=max_preprocessed_multiple)
629 .map(max_main_multiple_for_preprocessed_multiple)
630 .sum::<usize>();
631
632 (max_preprocessed_multiple, max_main_multiple, num_shapes)
633}
634
635pub fn dummy_vk_map<C: SP1ProverComponents>() -> BTreeMap<[SP1Field; DIGEST_SIZE], usize> {
636 create_all_input_shapes(C::core_verifier().machine().shape(), DEFAULT_ARITY)
637 .iter()
638 .enumerate()
639 .map(|(i, _)| ([SP1Field::from_canonical_usize(i); DIGEST_SIZE], i))
640 .collect()
641}
642
643pub fn max_count(a: RecursionAirEventCount, b: RecursionAirEventCount) -> RecursionAirEventCount {
644 use std::cmp::max;
645 RecursionAirEventCount {
646 mem_const_events: max(a.mem_const_events, b.mem_const_events),
647 mem_var_events: max(a.mem_var_events, b.mem_var_events),
648 base_alu_events: max(a.base_alu_events, b.base_alu_events),
649 ext_alu_events: max(a.ext_alu_events, b.ext_alu_events),
650 ext_felt_conversion_events: max(a.ext_felt_conversion_events, b.ext_felt_conversion_events),
651 poseidon2_wide_events: max(a.poseidon2_wide_events, b.poseidon2_wide_events),
652 poseidon2_linear_layer_events: max(
653 a.poseidon2_linear_layer_events,
654 b.poseidon2_linear_layer_events,
655 ),
656 poseidon2_sbox_events: max(a.poseidon2_sbox_events, b.poseidon2_sbox_events),
657 select_events: max(a.select_events, b.select_events),
658 prefix_sum_checks_events: max(a.prefix_sum_checks_events, b.prefix_sum_checks_events),
659 commit_pv_hash_events: max(a.commit_pv_hash_events, b.commit_pv_hash_events),
660 }
661}
662
663pub fn create_test_shape(
664 cluster: &BTreeSet<Chip<SP1Field, RiscvAir<SP1Field>>>,
665) -> SP1NormalizeInputShape {
666 let preprocessed_multiple = (MAX_PROGRAM_SIZE * NUM_PROGRAM_PREPROCESSED_COLS
667 + (1 << 17) * NUM_RANGE_PREPROCESSED_COLS
668 + (1 << 16) * NUM_BYTE_PREPROCESSED_COLS)
669 .div_ceil(1 << CORE_LOG_STACKING_HEIGHT);
670 let main_multiple = (PADDED_ELEMENT_THRESHOLD).div_ceil(1 << CORE_LOG_STACKING_HEIGHT) as usize;
671 let num_padding_cols =
672 ((1 << CORE_LOG_STACKING_HEIGHT) as usize).div_ceil(1 << CORE_MAX_LOG_ROW_COUNT);
673 SP1NormalizeInputShape {
674 proof_shapes: vec![CoreProofShape {
675 shard_chips: cluster.clone(),
676 preprocessed_area: preprocessed_multiple << CORE_LOG_STACKING_HEIGHT,
677 main_area: main_multiple << CORE_LOG_STACKING_HEIGHT,
678 preprocessed_padding_cols: num_padding_cols,
679 main_padding_cols: num_padding_cols,
680 }],
681 max_log_row_count: CORE_MAX_LOG_ROW_COUNT,
682 log_stacking_height: CORE_LOG_STACKING_HEIGHT as usize,
683 log_blowup: CORE_LOG_BLOWUP,
684 }
685}
686
687pub fn build_recursion_count_from_shape(
688 shape: &RecursionShape<SP1Field>,
689) -> RecursionAirEventCount {
690 RecursionAirEventCount {
691 mem_const_events: shape
692 .height(&CompressAir::<SP1Field>::MemoryConst(MemoryConstChip::default()))
693 .unwrap(),
694 mem_var_events: shape
695 .height(&CompressAir::<SP1Field>::MemoryVar(MemoryVarChip::<SP1Field, 2>::default()))
696 .unwrap()
697 * 2,
698 base_alu_events: shape.height(&CompressAir::<SP1Field>::BaseAlu(BaseAluChip)).unwrap(),
699 ext_alu_events: shape.height(&CompressAir::<SP1Field>::ExtAlu(ExtAluChip)).unwrap(),
700 ext_felt_conversion_events: shape
701 .height(&CompressAir::<SP1Field>::ExtFeltConvert(ConvertChip))
702 .unwrap_or(0),
703 poseidon2_wide_events: shape
704 .height(&CompressAir::<SP1Field>::Poseidon2Wide(Poseidon2WideChip))
705 .unwrap_or(0),
706 poseidon2_linear_layer_events: shape
707 .height(&CompressAir::<SP1Field>::Poseidon2LinearLayer(Poseidon2LinearLayerChip))
708 .unwrap_or(0),
709 poseidon2_sbox_events: shape
710 .height(&CompressAir::<SP1Field>::Poseidon2SBox(Poseidon2SBoxChip))
711 .unwrap_or(0),
712 select_events: shape.height(&CompressAir::<SP1Field>::Select(SelectChip)).unwrap(),
713 prefix_sum_checks_events: shape
714 .height(&CompressAir::<SP1Field>::PrefixSumChecks(PrefixSumChecksChip))
715 .unwrap_or(0),
716 commit_pv_hash_events: shape
717 .height(&CompressAir::<SP1Field>::PublicValues(PublicValuesChip))
718 .unwrap(),
719 }
720}
721
722pub fn build_shape_from_recursion_air_event_count(
723 event_count: &RecursionAirEventCount,
724) -> SP1RecursionProofShape {
725 let &RecursionAirEventCount {
726 mem_const_events,
727 mem_var_events,
728 base_alu_events,
729 ext_alu_events,
730 poseidon2_wide_events,
731 select_events,
732 prefix_sum_checks_events,
733 commit_pv_hash_events,
734 ..
735 } = event_count;
736 let chips_and_heights = vec![
737 (CompressAir::<SP1Field>::MemoryConst(MemoryConstChip::default()), mem_const_events),
738 (
739 CompressAir::<SP1Field>::MemoryVar(MemoryVarChip::<SP1Field, 2>::default()),
740 mem_var_events / 2,
741 ),
742 (CompressAir::<SP1Field>::BaseAlu(BaseAluChip), base_alu_events),
743 (CompressAir::<SP1Field>::ExtAlu(ExtAluChip), ext_alu_events),
744 (CompressAir::<SP1Field>::Poseidon2Wide(Poseidon2WideChip), poseidon2_wide_events),
745 (CompressAir::<SP1Field>::Select(SelectChip), select_events),
746 (CompressAir::<SP1Field>::PrefixSumChecks(PrefixSumChecksChip), prefix_sum_checks_events),
747 (CompressAir::<SP1Field>::PublicValues(PublicValuesChip), commit_pv_hash_events),
748 ];
749 SP1RecursionProofShape { shape: chips_and_heights.into_iter().collect() }
750}
751
#[cfg(test)]
mod tests {
    use anyhow::Context;

    use crate::{
        recursion::{
            compose_program_from_input, deferred_program_from_input, dummy_compose_input,
            dummy_deferred_input, normalize_program_from_input, recursive_verifier,
        },
        worker::SP1LightNode,
        CpuSP1ProverComponents,
    };
    #[cfg(feature = "experimental")]
    use sp1_core_executor::SP1Context;

    use sp1_core_machine::utils::setup_logger;
    use sp1_recursion_compiler::config::InnerConfig;
    use sp1_recursion_executor::RecursionAirEventCount;

    use super::*;

    /// Logs the maximum arity supported by the default compress shape.
    #[tokio::test]
    #[ignore = "should be invoked specifically"]
    async fn test_max_arity() {
        setup_logger();
        let client = SP1LightNode::new().await;

        let vk_verification = client.inner().vk_verification();
        let allowed_vk_height = client.inner().allowed_vk_height();

        let reduce_shape = SP1RecursionProofShape::compress_proof_shape_from_arity(DEFAULT_ARITY)
            .expect("default arity shape should be valid");

        let arity = reduce_shape
            .max_arity::<CpuSP1ProverComponents>(vk_verification, allowed_vk_height)
            .await;

        tracing::info!("arity: {}", arity);
    }

    /// Ways in which the checked-in compress shape can be invalid.
    #[derive(Debug, Error)]
    enum ShapeError {
        #[error("Expected arity to be {DEFAULT_ARITY}, found {_0}")]
        WrongArity(usize),
        #[error(
            "Expected the arity {DEFAULT_ARITY} shape to be large enough
            to accommodate all core shard proof shapes."
        )]
        CoreShapesTooLarge,
        #[error("Expected height of chip {_0} to be a multiple of 32")]
        ChipHeightNotMultipleOf32(String),
        #[error("Expected the shape to be minimal")]
        ShapeNotMinimal,
        #[error("Public values chip height is not 16")]
        PublicValuesChipHeightNot16,
    }

    /// Checks that the default compress shape supports exactly `DEFAULT_ARITY`, is large enough
    /// for every normalize program and the deferred program, uses heights that are multiples
    /// of 32 (16 for the public-values chip), and is minimal.
    #[tokio::test]
    async fn test_core_shape_fit() -> Result<(), anyhow::Error> {
        setup_logger();
        let elf = test_artifacts::FIBONACCI_ELF;
        let client = SP1LightNode::new().await;
        let vk = client.setup(&elf).await?;

        let context =
            "Shape is not valid. To fix: From sp1-wip directory, run `cargo test --release -p sp1-prover --features experimental -- test_find_recursion_shape --include-ignored`";

        let machine = RiscvAir::<SP1Field>::machine();
        let chip_clusters = &machine.shape().chip_clusters;
        let mut max_cluster_count = RecursionAirEventCount::default();

        let core_verifier = CpuSP1ProverComponents::core_verifier();
        let recursive_core_verifier =
            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(core_verifier.shard_verifier());

        // Largest event counts over the normalize programs of all chip clusters.
        for cluster in chip_clusters {
            let shape = create_test_shape(cluster);
            let program = normalize_program_from_input(
                &recursive_core_verifier,
                &shape.dummy_input(vk.clone()),
            );
            max_cluster_count = max_count(max_cluster_count, program.event_counts);
        }

        let vk_verification = client.inner().vk_verification();
        let allowed_vk_height = client.inner().allowed_vk_height();

        tracing::info!("max_cluster_count: {:?}", max_cluster_count);

        let reduce_shape =
            SP1RecursionProofShape::compress_proof_shape_from_arity(DEFAULT_ARITY).unwrap();
        let arity = reduce_shape
            .max_arity::<CpuSP1ProverComponents>(vk_verification, allowed_vk_height)
            .await;
        if arity != DEFAULT_ARITY {
            return Err(ShapeError::WrongArity(arity)).context(context);
        }

        // The deferred program must fit as well.
        {
            let compress_verifier = CpuSP1ProverComponents::compress_verifier();
            let recursive_compress_verifier = recursive_verifier::<SP1GlobalContext, _, InnerConfig>(
                compress_verifier.shard_verifier(),
            );
            let deferred_input =
                dummy_deferred_input(&compress_verifier, &reduce_shape, allowed_vk_height);
            let deferred_program = deferred_program_from_input(
                &recursive_compress_verifier,
                vk_verification,
                &deferred_input,
            );
            let deferred_count = deferred_program.event_counts;
            tracing::info!("deferred_count: {:?}", deferred_count);
            max_cluster_count = max_count(max_cluster_count, deferred_count);
        }

        let arity_4_count = build_recursion_count_from_shape(&reduce_shape.shape);
        let combined_count = max_count(max_cluster_count, arity_4_count);

        let max_cluster_shape = build_shape_from_recursion_air_event_count(&max_cluster_count);
        // If taking the maximum changed anything, some program needs more than the shape has.
        if combined_count != arity_4_count {
            return Err(ShapeError::CoreShapesTooLarge).context(context);
        }

        for (chip, height) in (&reduce_shape.shape).into_iter() {
            if chip != "PublicValues" {
                if !height.is_multiple_of(32) {
                    return Err(ShapeError::ChipHeightNotMultipleOf32(chip.clone()))
                        .context(context);
                }
                let mut new_reduce_shape = reduce_shape.clone();

                // Minimality: shrinking any chip by 32 rows must either reduce the supported
                // arity or drop below what the normalize/deferred programs need.
                new_reduce_shape.shape.insert_with_name(chip, height - 32);

                if !(new_reduce_shape
                    .max_arity::<CpuSP1ProverComponents>(vk_verification, allowed_vk_height)
                    .await
                    < DEFAULT_ARITY
                    || new_reduce_shape.shape.height_of_name(chip).unwrap()
                        < max_cluster_shape
                            .shape
                            .height_of_name(chip)
                            .unwrap()
                            .next_multiple_of(32))
                {
                    return Err(ShapeError::ShapeNotMinimal).context(context);
                }
            } else if *height != 16 {
                return Err(ShapeError::PublicValuesChipHeightNot16).context(context);
            }
        }
        Ok(())
    }

    #[cfg(feature = "experimental")]
    use serial_test::serial;

    /// Builds the vks for the shapes of a real core proof (plus the trailing shapes of the
    /// enumeration), reloads a prover with the resulting vk map and checks that a compressed
    /// proof verifies.
    #[tokio::test]
    #[serial]
    #[cfg(feature = "experimental")]
    async fn test_build_vk_map() {
        use std::fs::File;

        use either::Either;

        use sp1_core_machine::io::SP1Stdin;
        use sp1_prover_types::network_base_types::ProofMode;
        use sp1_verifier::SP1Proof;

        use crate::worker::{cpu_worker_builder, SP1LocalNodeBuilder};

        setup_logger();

        let temp_dir = std::env::temp_dir();
        let vk_map_path = temp_dir.join("vk_map.bin");

        // Best-effort cleanup of a file left behind by a previous run.
        let _ = std::fs::remove_file(&vk_map_path);

        let node = SP1LocalNodeBuilder::from_worker_client_builder(cpu_worker_builder())
            .build()
            .await
            .unwrap();

        let elf = test_artifacts::FIBONACCI_ELF;

        let vk = node.setup(&elf).await.expect("Failed to setup");

        let proof = node
            .prove_with_mode(&elf, SP1Stdin::default(), SP1Context::default(), ProofMode::Core)
            .await
            .expect("Failed to prove");

        let shapes = create_all_input_shapes(
            CpuSP1ProverComponents::core_verifier().shard_verifier().machine().shape(),
            DEFAULT_ARITY,
        );

        let mut shape_indices = vec![];

        let core_proof = match proof.proof {
            SP1Proof::Core(proof) => proof,
            _ => panic!("Expected core proof"),
        };

        // Locate the shape of every shard proof in the full enumeration.
        for proof in &core_proof {
            let shape = SP1RecursionProgramShape::Normalize(
                CpuSP1ProverComponents::core_verifier().shape_from_proof(proof),
            );

            shape_indices.push(shapes.iter().position(|s| s == &shape).unwrap());
        }

        let shape_indices =
            shape_indices.into_iter().chain(shapes.len() - 12..shapes.len()).collect::<Vec<_>>();

        let result = node.build_vks(Some(Either::Left(shape_indices)), 4).await.unwrap();

        let mut file = File::create(vk_map_path.clone()).unwrap();

        bincode::serialize_into(&mut file, &result.vk_map).unwrap();

        let node = SP1LocalNodeBuilder::from_worker_client_builder(
            cpu_worker_builder().with_vk_map_path(vk_map_path.to_str().unwrap().to_string()),
        )
        .build()
        .await
        .unwrap();

        tracing::info!("Rebuilt prover with vk map.");

        let proof = node
            .prove_with_mode(
                &elf,
                SP1Stdin::default(),
                SP1Context::default(),
                ProofMode::Compressed,
            )
            .await
            .expect("Failed to prove");

        node.verify(&vk, &proof.proof).unwrap();

        std::fs::remove_file(vk_map_path).unwrap();
    }

    /// Searches for a compress shape that fits the arity-`DEFAULT_ARITY` compose program and
    /// writes it to `compress_shape.json` in the current directory.
    #[tokio::test]
    #[ignore = "should be invoked for shape tuning"]
    async fn test_find_recursion_shape() {
        setup_logger();
        let elf = test_artifacts::FIBONACCI_ELF;
        let client = SP1LightNode::new().await;
        let vk = client.setup(&elf).await.unwrap();

        let machine = RiscvAir::<SP1Field>::machine();
        let chip_clusters = &machine.shape().chip_clusters;
        let allowed_vk_height = client.inner().allowed_vk_height();
        let vk_verification = client.inner().vk_verification();

        let verifier = CpuSP1ProverComponents::compress_verifier();
        let dummy_input =
            |current_shape: &SP1RecursionProofShape| -> SP1CompressWithVKeyWitnessValues<SP1PcsProofInner> {
                dummy_compose_input(&verifier, current_shape, DEFAULT_ARITY, allowed_vk_height)
            };
        let core_verifier = CpuSP1ProverComponents::core_verifier();
        let recursive_core_verifier =
            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(core_verifier.shard_verifier());

        let recursive_compress_verifier =
            recursive_verifier::<SP1GlobalContext, _, InnerConfig>(verifier.shard_verifier());
        let compose_program =
            |input: &SP1CompressWithVKeyWitnessValues<SP1PcsProofInner>| -> Arc<RecursionProgram<SP1Field>> {
                Arc::new(compose_program_from_input(
                    &recursive_compress_verifier,
                    vk_verification,
                    input,
                ))
            };

        let mut max_cluster_count = RecursionAirEventCount::default();

        // Start from the smallest shape that fits every normalize program.
        for cluster in chip_clusters {
            let shape = create_test_shape(cluster);
            let program = normalize_program_from_input(
                &recursive_core_verifier,
                &shape.dummy_input(vk.clone()),
            );
            max_cluster_count = max_count(max_cluster_count, program.event_counts);
        }

        let mut current_shape = build_shape_from_recursion_air_event_count(&max_cluster_count);
        let trace_generator =
            DefaultTraceGenerator::new(CompressAir::<SP1Field>::compress_machine());
        // Grow the shape until the compose program built for it fits into it.
        loop {
            let input = dummy_input(&current_shape);
            let program = compose_program(&input);
            let setup_permits = ProverSemaphore::new(1);
            let preprocessed_traces = trace_generator
                .generate_preprocessed_traces(program, RECURSION_MAX_LOG_ROW_COUNT, setup_permits)
                .await;

            // Chips whose real height exceeds the current shape, with the height they need.
            let updated_key_values = preprocessed_traces
                .preprocessed_traces
                .into_iter()
                .filter_map(|(chip, trace)| {
                    let real_height = trace.num_real_entries();
                    let expected_height = current_shape.shape.height_of_name(&chip).unwrap();

                    if real_height > expected_height {
                        tracing::warn!(
                            "Insufficient height for chip {}: expected {}, got {}",
                            chip,
                            expected_height,
                            real_height
                        );
                        Some((chip, real_height))
                    } else {
                        None
                    }
                })
                .collect::<Vec<_>>();

            if updated_key_values.is_empty() {
                break;
            }
            for (chip, real_height) in updated_key_values {
                current_shape.shape.insert_with_name(&chip, real_height);
            }
        }

        // Round every height except the public-values chip up to a multiple of 32.
        let shape = SP1RecursionProofShape {
            shape: RecursionShape::new(
                current_shape
                    .shape
                    .into_iter()
                    .map(|(chip, height)| {
                        let new_height = if chip == "PublicValues" {
                            height
                        } else {
                            height.next_multiple_of(32)
                        };
                        (chip, new_height)
                    })
                    .collect(),
            ),
        };

        let mut file = std::fs::File::create("compress_shape.json").unwrap();
        serde_json::to_writer_pretty(&mut file, &shape).unwrap();
    }
}