use alloc::vec;
use alloc::vec::Vec;
use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
use p3_challenger::DuplexChallenger;
use p3_commit::ExtensionMmcs;
use p3_dft::Radix2DFTSmallBatch;
use p3_field::Field;
use p3_field::extension::BinomialExtensionField;
use p3_merkle_tree::MerkleTreeMmcs;
use p3_multilinear_util::poly::Poly;
use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
use p3_zk_codes::reed_solomon::ReedSolomonZkEncoding;
use rand::rngs::SmallRng;
use rand::{RngExt, SeedableRng};
use crate::layout::{PrefixProver, SuffixProver, Table, TableShape};
use crate::strategy::VariableOrder;
use crate::zk::{ZkLayout, ZkProver, ZkSumcheckData, ZkVerifier};
/// Base field for every table evaluation fed to the prover in these helpers.
pub type F = BabyBear;
/// Degree-4 binomial extension of [`F`]; virtual evaluations and sumcheck randomness
/// (see `ProverRun`) are elements of this field.
pub type EF = BinomialExtensionField<F, 4>;
/// Width-16 Poseidon2 permutation; it backs the Merkle hash, the Merkle compression
/// and both transcript challengers.
pub type Perm = Poseidon2BabyBear<16>;
/// Padding-free sponge over [`Perm`], used as the Merkle leaf hash in `make_setup`.
pub type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
/// Truncated-permutation compression over [`Perm`], used to combine Merkle nodes in `make_setup`.
pub type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
/// Duplex challenger over [`Perm`]; the prover and verifier each get their own instance.
pub type MyChallenger = DuplexChallenger<F, Perm, 16, 8>;
/// Packed representation of [`F`] (`Field::Packing`) used by the Merkle tree.
pub type PackedF = <F as Field>::Packing;
/// Merkle-tree MMCS over packed base-field values.
/// NOTE(review): the trailing const parameters `2, 8` are presumably tree arity and digest
/// length — confirm against `MerkleTreeMmcs`.
pub type BaseMmcs = MerkleTreeMmcs<PackedF, PackedF, MyHash, MyCompress, 2, 8>;
/// [`BaseMmcs`] lifted to commit to [`EF`] values; this is the MMCS whose commitment type
/// is stored as the mask commitment.
pub type MyMmcs = ExtensionMmcs<F, EF, BaseMmcs>;
/// Radix-2 DFT over [`EF`] that backs the Reed–Solomon encoding.
pub type MyDft = Radix2DFTSmallBatch<EF>;
/// Reed–Solomon ZK encoding over [`EF`], parameterised by [`MyDft`].
pub type MyEnc = ReedSolomonZkEncoding<EF, MyDft>;
/// First argument to `MyEnc::new`; it is also added to `ell_zk` before the codeword length
/// is rounded up to a power of two.
/// NOTE(review): presumably the number of random masking symbols per codeword — confirm.
pub const T: usize = 2;
/// Builds the deterministic primitives shared by prover and verifier.
///
/// The Poseidon2 permutation is sampled from a `SmallRng` seeded with `seed`. That same
/// permutation drives the Merkle hash and compression behind the returned extension-field
/// MMCS. The Reed–Solomon ZK encoding is created for `ell_zk` message symbols plus [`T`]
/// extra symbols, and its codeword length is rounded up to the next power of two.
pub fn make_setup(seed: u64, ell_zk: usize) -> (Perm, MyMmcs, MyEnc) {
    let perm = Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(seed));

    // NOTE(review): the trailing `0` is presumably the Merkle cap height — confirm.
    let mmcs: MyMmcs = ExtensionMmcs::new(BaseMmcs::new(
        MyHash::new(perm.clone()),
        MyCompress::new(perm.clone()),
        0,
    ));

    // Smallest power of two that fits the message plus the `T` extra symbols.
    let codeword_len = (ell_zk + T).next_power_of_two();
    let encoding = MyEnc::new(T, ell_zk, codeword_len, MyDft::default());

    (perm, mmcs, encoding)
}
/// Wraps one multilinear evaluation vector in a ZK prover of layout `L` and builds the
/// matching verifier.
///
/// `evals` must have a power-of-two length, which `log2_strict_usize` enforces. Its log2
/// is the variable count and is returned as the third tuple element. The prover sees a
/// single table holding a single polynomial. The verifier is created for one
/// `TableShape::new(num_vars, 1)`, using the prefix or suffix constructor chosen by
/// `L::strategy().variable_order`.
// The tuple spells out the fully generic prover/verifier types, which trips clippy's
// type-complexity lint; a dedicated alias would add nothing for a test helper.
#[allow(clippy::type_complexity)]
pub fn build_prover_verifier<L>(
    evals: Vec<F>,
    folding_factor: usize,
    encoding: MyEnc,
    mmcs: MyMmcs,
) -> (ZkProver<F, EF, MyEnc, MyMmcs, L>, ZkVerifier<F, EF>, usize)
where
    L: ZkLayout<F, EF>,
{
    let num_vars = p3_util::log2_strict_usize(evals.len());

    // Prover side: the evaluations become one polynomial inside one table.
    let tables = vec![Table::new(vec![Poly::new(evals)])];
    let layout = L::from_witness(L::new_witness(tables, folding_factor));
    let prover = ZkProver::new(layout, encoding, mmcs);

    // Verifier side: a matching single-table shape (the `1` mirrors the single polynomial).
    let shapes = [TableShape::new(num_vars, 1)];
    let verifier = match L::strategy().variable_order {
        VariableOrder::Prefix => ZkVerifier::<F, EF>::new_prefix(&shapes),
        VariableOrder::Suffix => ZkVerifier::<F, EF>::new_suffix(&shapes),
    };

    (prover, verifier, num_vars)
}
/// Everything captured from an honest prover run that the verifier needs in order to
/// replay the sumcheck later (see `replay_verifier`).
pub struct ProverRun {
/// Verifier that has already registered every concrete claim and virtual evaluation,
/// but has not yet run its sumcheck.
pub verifier: ZkVerifier<F, EF>,
/// Verifier transcript state after it absorbed the same claims the prover produced.
pub verifier_challenger: MyChallenger,
/// Sumcheck data that was passed mutably to the prover's `into_sumcheck`.
pub zk_data: ZkSumcheckData<F, EF>,
/// Commitment taken from the first component of the prover handoff's `mask_oracle`.
pub mask_commitment: <MyMmcs as p3_commit::Mmcs<EF>>::Commitment,
/// Randomness reported by the prover handoff; the verifier must reproduce it exactly.
pub prover_randomness: p3_multilinear_util::point::Point<EF>,
/// Values returned by each `add_virtual_eval` call on the prover, in call order.
pub virtual_evals: Vec<EF>,
}
#[allow(clippy::too_many_arguments)]
pub fn run_prover(
binding: VariableOrder,
n_vars: usize,
folding_factor: usize,
ell_zk: usize,
num_concrete: usize,
num_virtual: usize,
pow_bits: usize,
seed: u64,
) -> ProverRun {
let (perm, mmcs, encoding) = make_setup(seed, ell_zk);
let mut data_rng = SmallRng::seed_from_u64(seed.wrapping_add(1));
let evals: Vec<F> = (0..(1usize << n_vars)).map(|_| data_rng.random()).collect();
let prover_challenger = MyChallenger::new(perm.clone());
let verifier_challenger = MyChallenger::new(perm);
let zk_data = ZkSumcheckData::<F, EF>::default();
let prover_rng = SmallRng::seed_from_u64(seed.wrapping_add(2));
match binding {
VariableOrder::Prefix => drive_prover_run::<PrefixProver<F, EF>>(
evals,
folding_factor,
encoding,
mmcs,
num_concrete,
num_virtual,
pow_bits,
prover_challenger,
verifier_challenger,
zk_data,
prover_rng,
),
VariableOrder::Suffix => drive_prover_run::<SuffixProver<F, EF>>(
evals,
folding_factor,
encoding,
mmcs,
num_concrete,
num_virtual,
pow_bits,
prover_challenger,
verifier_challenger,
zk_data,
prover_rng,
),
}
}
/// Feeds claims into a fresh prover/verifier pair of layout `L`, then runs the prover's
/// sumcheck.
///
/// The steps are:
/// 1. `num_concrete` times, call `prover.eval(0, &[0], ..)` and hand the openings to
///    `verifier.add_claim(0, &[0], ..)`. NOTE(review): the `0` / `&[0]` arguments are
///    presumably table and column indices — confirm.
/// 2. `num_virtual` times, take a prover virtual evaluation and mirror it into the
///    verifier.
/// 3. Consume the prover with `into_sumcheck`.
///
/// The verifier's own sumcheck is not run here; it is returned inside [`ProverRun`].
// Threads every piece of prepared state through explicitly so `run_prover` can dispatch
// on the layout type; grouping the arguments would only add boilerplate.
#[allow(clippy::too_many_arguments)]
fn drive_prover_run<L>(
    evals: Vec<F>,
    folding_factor: usize,
    encoding: MyEnc,
    mmcs: MyMmcs,
    num_concrete: usize,
    num_virtual: usize,
    pow_bits: usize,
    mut prover_challenger: MyChallenger,
    mut verifier_challenger: MyChallenger,
    mut zk_data: ZkSumcheckData<F, EF>,
    mut prover_rng: SmallRng,
) -> ProverRun
where
    L: ZkLayout<F, EF>,
{
    let (mut prover, mut verifier, _) =
        build_prover_verifier::<L>(evals, folding_factor, encoding, mmcs);

    // Concrete claims: the prover opens, and the verifier absorbs the very same openings.
    (0..num_concrete).for_each(|_| {
        let openings = prover.eval(0, &[0], &mut prover_challenger);
        verifier.add_claim(0, &[0], &openings, &mut verifier_challenger);
    });

    // Virtual claims: each prover-side value is mirrored into the verifier and recorded.
    let virtual_evals: Vec<EF> = (0..num_virtual)
        .map(|_| {
            let value = prover.add_virtual_eval(&mut prover_challenger);
            verifier.add_virtual_eval(value, &mut verifier_challenger);
            value
        })
        .collect();

    let handoff = prover.into_sumcheck(
        &mut zk_data,
        pow_bits,
        &mut prover_challenger,
        &mut prover_rng,
    );

    ProverRun {
        mask_commitment: handoff.mask_oracle.0.clone(),
        prover_randomness: handoff.randomness,
        verifier,
        verifier_challenger,
        zk_data,
        virtual_evals,
    }
}
/// Runs the verifier's sumcheck against a captured [`ProverRun`].
///
/// Consumes `run` and calls the verifier's `into_sumcheck` with the recorded sumcheck
/// data and mask commitment. `ell_zk`, `folding_factor` and `pow_bits` are passed through
/// unchanged; they are expected to be the values the prover used. On success, returns the
/// randomness from the verifier's handoff.
///
/// # Errors
///
/// Returns `"verifier rejected honest prover output"` if `into_sumcheck` fails. The
/// underlying error value is discarded.
pub fn replay_verifier(
    mut run: ProverRun,
    ell_zk: usize,
    folding_factor: usize,
    pow_bits: usize,
) -> Result<p3_multilinear_util::point::Point<EF>, &'static str> {
    let outcome = run.verifier.into_sumcheck::<MyMmcs, _>(
        &run.zk_data,
        &run.mask_commitment,
        ell_zk,
        folding_factor,
        pow_bits,
        &mut run.verifier_challenger,
    );

    match outcome {
        Ok(handoff) => Ok(handoff.randomness),
        Err(_) => Err("verifier rejected honest prover output"),
    }
}
/// End-to-end check: run the honest prover, replay the verifier, and require that both
/// sides arrive at the same sumcheck randomness.
///
/// Both sides use a fixed `pow_bits` of 4.
///
/// # Errors
///
/// Propagates the error from `replay_verifier` if the verifier rejects. Returns
/// `"prover/verifier disagreed on sumcheck randomness"` if the two randomness points differ.
pub fn run_roundtrip(
    binding: VariableOrder,
    n_vars: usize,
    folding_factor: usize,
    ell_zk: usize,
    num_concrete: usize,
    num_virtual: usize,
    seed: u64,
) -> Result<(), &'static str> {
    const POW_BITS: usize = 4;

    let run = run_prover(
        binding,
        n_vars,
        folding_factor,
        ell_zk,
        num_concrete,
        num_virtual,
        POW_BITS,
        seed,
    );

    // Snapshot the prover's point before `run` is handed over (and consumed) by the replay.
    let prover_point: Vec<EF> = run.prover_randomness.iter().copied().collect();
    let verifier_point = replay_verifier(run, ell_zk, folding_factor, POW_BITS)?;

    if verifier_point
        .iter()
        .copied()
        .eq(prover_point.iter().copied())
    {
        Ok(())
    } else {
        Err("prover/verifier disagreed on sumcheck randomness")
    }
}