#![allow(clippy::print_stdout)]
use std::collections::{BTreeMap, HashMap};
use std::panic::{catch_unwind, AssertUnwindSafe};

use clap::Parser;
use p3_baby_bear::BabyBear;
use sp1_core_machine::utils::setup_logger;
use sp1_prover::{
    components::CpuProverComponents,
    shapes::{check_shapes, SP1ProofShape},
    SP1Prover, ShrinkAir, REDUCE_BATCH_SIZE,
};
use sp1_recursion_core::shape::RecursionShapeConfig;
use sp1_stark::{shape::OrderedShape, MachineProver};
/// Command-line options for the recursion shape search.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Disable verifying-key verification in the prover.
    #[arg(short, long, default_value_t = false)]
    dummy: bool,
    /// Batch size passed to `check_shapes`.
    #[arg(short, long, default_value_t = REDUCE_BATCH_SIZE)]
    recursion_batch_size: usize,
    /// Number of compiler workers passed to `check_shapes`.
    #[arg(short, long, default_value_t = 1)]
    num_compiler_workers: usize,
    /// Number of setup workers.
    // Deliberately no `short`: the derived `-n` would collide with
    // `num_compiler_workers`, and clap rejects duplicate short flags
    // (debug builds panic inside `Args::parse`).
    // NOTE(review): this option is not read by `main` in this file.
    #[arg(long, default_value_t = 1)]
    num_setup_workers: usize,
    /// Optional start index.
    // NOTE(review): not read by `main` in this file.
    #[arg(short, long)]
    start: Option<usize>,
    /// Optional end index.
    // NOTE(review): not read by `main` in this file.
    #[arg(short, long)]
    end: Option<usize>,
}
fn main() {
setup_logger();
let args = Args::parse();
let mut prover = SP1Prover::<CpuProverComponents>::new();
prover.vk_verification = !args.dummy;
prover.join_programs_map.clear();
let compress_shape_config =
prover.compress_shape_config.as_ref().expect("recursion shape config not found");
let candidate = compress_shape_config.union_config_with_extra_room().first().unwrap().clone();
prover.compress_shape_config = Some(RecursionShapeConfig::from_hash_map(&candidate));
assert!(
check_shapes(args.recursion_batch_size, false, args.num_compiler_workers, &mut prover,)
);
let mut answer = candidate.clone();
for (key, value) in candidate.iter() {
if key != "PublicValues" {
let mut done = false;
let mut new_val = *value;
while !done {
new_val -= 1;
answer.insert(key.clone(), new_val);
prover.compress_shape_config = Some(RecursionShapeConfig::from_hash_map(&answer));
done = !check_shapes(
args.recursion_batch_size,
false,
args.num_compiler_workers,
&mut prover,
);
}
answer.insert(key.clone(), new_val + 1);
}
}
let mut no_precompile_answer = answer.clone();
for (key, value) in answer.iter() {
if key != "PublicValues" {
let mut done = false;
let mut new_val = *value;
while !done {
new_val -= 1;
no_precompile_answer.insert(key.clone(), new_val);
prover.compress_shape_config =
Some(RecursionShapeConfig::from_hash_map(&no_precompile_answer));
done = !check_shapes(
args.recursion_batch_size,
true,
args.num_compiler_workers,
&mut prover,
);
}
no_precompile_answer.insert(key.clone(), new_val + 1);
}
}
let mut shrink_shape = ShrinkAir::<BabyBear>::shrink_shape().clone_into_hash_map();
assert!({
prover.compress_shape_config = Some(RecursionShapeConfig::from_hash_map(&answer));
catch_unwind(AssertUnwindSafe(|| {
prover.shrink_prover.setup(&prover.program_from_shape(
sp1_prover::shapes::SP1CompressProgramShape::from_proof_shape(
SP1ProofShape::Shrink(OrderedShape {
inner: answer.clone().into_iter().collect::<Vec<_>>(),
}),
5,
),
Some(shrink_shape.clone().into()),
))
}))
.is_ok()
});
for (key, value) in shrink_shape.clone().iter() {
if key != "PublicValues" {
let mut done = false;
let mut new_val = *value + 1;
while !done {
new_val -= 1;
shrink_shape.insert(key.clone(), new_val);
prover.compress_shape_config = Some(RecursionShapeConfig::from_hash_map(&answer));
done = catch_unwind(AssertUnwindSafe(|| {
prover.shrink_prover.setup(&prover.program_from_shape(
sp1_prover::shapes::SP1CompressProgramShape::from_proof_shape(
SP1ProofShape::Shrink(OrderedShape {
inner: answer.clone().into_iter().collect::<Vec<_>>(),
}),
5,
),
Some(shrink_shape.clone().into()),
))
}))
.is_err();
}
shrink_shape.insert(key.clone(), new_val + 1);
}
}
println!("Final compress shape: {answer:?}");
println!("Final compress shape with no precompiles: {no_precompile_answer:?}");
println!("Final shrink shape: {shrink_shape:?}");
}