#![allow(clippy::print_stdout)]
use clap::Parser;
use p3_baby_bear::BabyBear;
use p3_util::log2_ceil_usize;
use sp1_core_executor::{Executor, Program, RiscvAirId, SP1Context};
use sp1_core_machine::{
io::SP1Stdin, riscv::RiscvAir, shape::CoreShapeConfig, utils::setup_logger,
};
use sp1_stark::SP1CoreOpts;
// Command-line arguments for the shape-fixing test runner.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Space-separated list of program directories under `s3://sp1-testing-suite/`.
    #[arg(short, long, value_delimiter = ' ')]
    list: Vec<String>,
    /// Base-2 logarithm of the shard size (the shard size used is `2^shard_size`).
    #[arg(short, long)]
    shard_size: usize,
}
/// Executes `elf` with `stdin` and checks every produced record against the shape config.
///
/// Each core record is fixed to a shape from `shape_config`. For every record containing Cpu
/// events, the function checks that the fixed shape's Cpu height does not exceed
/// `opts.shard_size`.
///
/// Before executing, the program's preprocessed shape is fixed. The executor is restricted to
/// the maximal core shapes for `log2_ceil(opts.shard_size)`. Each record's raw core heights are
/// printed before its shape is fixed.
///
/// # Panics
///
/// Panics if any of the following happens:
/// - the ELF cannot be loaded or its preprocessed shape cannot be fixed;
/// - execution fails;
/// - a record's shape cannot be fixed;
/// - a record containing Cpu events ends up with a Cpu height larger than `opts.shard_size`.
fn test_shape_fixing(
    elf: &[u8],
    stdin: &SP1Stdin,
    opts: SP1CoreOpts,
    context: SP1Context,
    shape_config: &CoreShapeConfig<BabyBear>,
) {
    let mut program = Program::from(elf).expect("failed to load program from ELF");
    shape_config
        .fix_preprocessed_shape(&mut program)
        .expect("failed to fix the preprocessed shape");

    let mut executor = Executor::with_context(program, opts, context);
    executor.maximal_shapes = Some(
        shape_config.maximal_core_shapes(log2_ceil_usize(opts.shard_size)).into_iter().collect(),
    );

    // Feed the stdin buffers and any attached (proof, vkey) pairs to the executor.
    executor.write_vecs(&stdin.buffer);
    for (proof, vkey) in stdin.proofs.iter() {
        executor.write_proof(proof.clone(), vkey.clone());
    }

    loop {
        let (records, done) = executor.execute_record(true).expect("execution failed");
        for mut record in records {
            // The return value of `defer()` is intentionally discarded; presumably these are
            // split-off deferred events that this check does not need — confirm.
            let _ = record.defer();
            let heights = RiscvAir::<BabyBear>::core_heights(&record);
            println!("heights: {:?}", heights);

            shape_config.fix_shape(&mut record).expect("failed to fix the record shape");
            if record.contains_cpu() {
                let cpu_height = record
                    .shape
                    .expect("record has no shape after fix_shape")
                    .height(&RiscvAirId::Cpu)
                    .expect("fixed shape has no Cpu height");
                assert!(
                    cpu_height <= opts.shard_size,
                    "fixed Cpu height {cpu_height} exceeds shard size {}",
                    opts.shard_size
                );
            }
        }
        if done {
            break;
        }
    }
}
fn main() {
setup_logger();
let args = Args::parse();
let config = CoreShapeConfig::<BabyBear>::default();
let mut opts = SP1CoreOpts { shard_batch_size: 1, ..Default::default() };
opts.shard_size = 1 << args.shard_size;
let program_list = args.list;
for s3_path in program_list {
tracing::info!("download elf and input for {}", s3_path);
let status = std::process::Command::new("aws")
.args([
"s3",
"cp",
&format!("s3://sp1-testing-suite/{}/program.bin", s3_path),
"program.bin",
])
.status()
.expect("Failed to execute aws s3 cp command for program.bin");
if !status.success() {
panic!("Failed to download program.bin from S3");
}
let status = std::process::Command::new("aws")
.args([
"s3",
"cp",
&format!("s3://sp1-testing-suite/{}/stdin.bin", s3_path),
"stdin.bin",
])
.status()
.expect("Failed to execute aws s3 cp command for stdin.bin");
if !status.success() {
panic!("Failed to download stdin.bin from S3");
}
let elf = std::fs::read("program.bin").expect("failed to read program");
let stdin = std::fs::read("stdin.bin").expect("failed to read stdin");
let stdin: SP1Stdin = bincode::deserialize(&stdin).expect("failed to deserialize stdin");
let elf = elf.clone();
let stdin = stdin.clone();
let new_context = SP1Context::default();
test_shape_fixing(&elf, &stdin, opts, new_context, &config);
std::fs::remove_file("program.bin").expect("failed to remove program.bin");
std::fs::remove_file("stdin.bin").expect("failed to remove stdin.bin");
}
}