use crate::ml_sumcheck::data_structures::{ListOfProductsOfPolynomials, PolynomialInfo};
use crate::ml_sumcheck::protocol::prover::{ProverMsg, ProverState};
use crate::ml_sumcheck::protocol::verifier::SubClaim;
use crate::ml_sumcheck::protocol::IPForMLSumcheck;
use crate::rng::{Blake2s512Rng, FeedableRNG};
use ark_ff::Field;
use ark_std::marker::PhantomData;
use ark_std::vec::Vec;
pub mod protocol;
pub mod data_structures;
#[cfg(test)]
mod test;
/// Non-interactive sumcheck prover and verifier for a `ListOfProductsOfPolynomials`
/// over the field `F`.
///
/// This is a stateless namespace type: all functionality lives in associated
/// functions, and the `PhantomData` only pins the field type `F`.
pub struct MLSumcheck<F: Field>(#[doc(hidden)] PhantomData<F>);
/// A sumcheck proof: the list of prover messages, one per round.
/// [`MLSumcheck::prove`] emits exactly one message per variable of the polynomial.
pub type Proof<F> = Vec<ProverMsg<F>>;
impl<F: Field> MLSumcheck<F> {
pub fn extract_sum(proof: &Proof<F>) -> F {
proof[0].evaluations[0] + proof[0].evaluations[1]
}
pub fn prove(polynomial: &ListOfProductsOfPolynomials<F>) -> Result<Proof<F>, crate::Error> {
let mut fs_rng = Blake2s512Rng::setup();
Self::prove_as_subprotocol(&mut fs_rng, polynomial).map(|r| r.0)
}
pub fn prove_as_subprotocol(
fs_rng: &mut impl FeedableRNG<Error = crate::Error>,
polynomial: &ListOfProductsOfPolynomials<F>,
) -> Result<(Proof<F>, ProverState<F>), crate::Error> {
fs_rng.feed(&polynomial.info())?;
let mut prover_state = IPForMLSumcheck::prover_init(polynomial);
let mut verifier_msg = None;
let mut prover_msgs = Vec::with_capacity(polynomial.num_variables);
for _ in 0..polynomial.num_variables {
let prover_msg = IPForMLSumcheck::prove_round(&mut prover_state, &verifier_msg);
fs_rng.feed(&prover_msg)?;
prover_msgs.push(prover_msg);
verifier_msg = Some(IPForMLSumcheck::sample_round(fs_rng));
}
Ok((prover_msgs, prover_state))
}
pub fn verify(
polynomial_info: &PolynomialInfo,
claimed_sum: F,
proof: &Proof<F>,
) -> Result<SubClaim<F>, crate::Error> {
let mut fs_rng = Blake2s512Rng::setup();
Self::verify_as_subprotocol(&mut fs_rng, polynomial_info, claimed_sum, proof)
}
pub fn verify_as_subprotocol(
fs_rng: &mut impl FeedableRNG<Error = crate::Error>,
polynomial_info: &PolynomialInfo,
claimed_sum: F,
proof: &Proof<F>,
) -> Result<SubClaim<F>, crate::Error> {
fs_rng.feed(polynomial_info)?;
let mut verifier_state = IPForMLSumcheck::verifier_init(polynomial_info);
for i in 0..polynomial_info.num_variables {
let prover_msg = proof.get(i).expect("proof is incomplete");
fs_rng.feed(prover_msg)?;
let _verifier_msg =
IPForMLSumcheck::verify_round((*prover_msg).clone(), &mut verifier_state, fs_rng);
}
IPForMLSumcheck::check_and_generate_subclaim(verifier_state, claimed_sum)
}
}