use crate::ml_sumcheck::data_structures::PolynomialInfo;
use crate::ml_sumcheck::protocol::prover::ProverMsg;
use crate::ml_sumcheck::protocol::IPForMLSumcheck;
use ark_ff::Field;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::rand::RngCore;
use ark_std::vec::Vec;
/// Verifier's message for one round of the sumcheck protocol: a single challenge
/// produced by `IPForMLSumcheck::sample_round`.
#[derive(Clone, CanonicalSerialize, CanonicalDeserialize)]
pub struct VerifierMsg<F: Field> {
    /// Random field element sampled for this round.
    pub randomness: F,
}
/// State the verifier carries across rounds.
///
/// Created by `IPForMLSumcheck::verifier_init`, updated by `verify_round`, and
/// consumed by `check_and_generate_subclaim`.
pub struct VerifierState<F: Field> {
    /// Current round, 1-based: starts at 1 and is advanced by `verify_round` until it
    /// equals `nv`.
    round: usize,
    /// Number of variables, which is also the total number of rounds.
    nv: usize,
    /// Maximum number of multiplicands; every prover message must carry
    /// `max_multiplicands + 1` evaluations.
    max_multiplicands: usize,
    /// Set once the message for round `nv` has been recorded.
    finished: bool,
    /// Prover messages (round-polynomial evaluations), one entry per completed round.
    polynomials_received: Vec<Vec<F>>,
    /// Verifier challenges, one per completed round, in round order.
    randomness: Vec<F>,
}
/// Output of a successful `check_and_generate_subclaim`.
///
/// The original sum claim is reduced to the single claim "the polynomial evaluated at
/// `point` equals `expected_evaluation`".
/// NOTE(review): presumably the caller still has to check this evaluation against the
/// original polynomial — confirm at call sites.
pub struct SubClaim<F: Field> {
    /// The verifier's challenges, one per round, in round order.
    pub point: Vec<F>,
    /// The running claim after the final round, i.e. the last round polynomial
    /// evaluated at the last challenge.
    pub expected_evaluation: F,
}
impl<F: Field> IPForMLSumcheck<F> {
    /// Initialize the verifier state for a sumcheck over a polynomial described by
    /// `index_info`.
    ///
    /// The verifier starts at round 1 and expects `index_info.num_variables` rounds. Each
    /// prover message must carry `index_info.max_multiplicands + 1` evaluations; this is
    /// checked in [`Self::check_and_generate_subclaim`].
    pub fn verifier_init(index_info: &PolynomialInfo) -> VerifierState<F> {
        VerifierState {
            round: 1,
            nv: index_info.num_variables,
            max_multiplicands: index_info.max_multiplicands,
            finished: false,
            polynomials_received: Vec::with_capacity(index_info.num_variables),
            randomness: Vec::with_capacity(index_info.num_variables),
        }
    }

    /// Run one verifier round: record the prover's message and sample a fresh challenge.
    ///
    /// No consistency check happens here; all checks are deferred to
    /// [`Self::check_and_generate_subclaim`]. The challenge is drawn from `rng` alone.
    /// The state is marked finished once round `nv` has been recorded.
    ///
    /// Always returns `Some`.
    ///
    /// # Panics
    /// If the verifier has already finished all rounds.
    pub fn verify_round<R: RngCore>(
        prover_msg: ProverMsg<F>,
        verifier_state: &mut VerifierState<F>,
        rng: &mut R,
    ) -> Option<VerifierMsg<F>> {
        if verifier_state.finished {
            panic!("Incorrect verifier state: Verifier is already finished.");
        }
        let msg = Self::sample_round(rng);
        verifier_state.randomness.push(msg.randomness);
        verifier_state
            .polynomials_received
            .push(prover_msg.evaluations);
        if verifier_state.round == verifier_state.nv {
            verifier_state.finished = true;
        } else {
            verifier_state.round += 1;
        }
        Some(msg)
    }

    /// Check every recorded prover message against `asserted_sum` and, if all are
    /// consistent, produce the final [`SubClaim`].
    ///
    /// The running claim starts as `expected = asserted_sum`. Each round polynomial `p`
    /// is given by its evaluations at `0, 1, ..., max_multiplicands`. The round must
    /// satisfy `p(0) + p(1) == expected`, after which `expected = p(r)` for that round's
    /// challenge `r`.
    ///
    /// # Errors
    /// Returns `crate::Error::Reject` if a round message has the wrong number of
    /// evaluations, or if it is inconsistent with the running claim.
    ///
    /// # Panics
    /// If the verifier has not finished all rounds. This is a misuse of the API, not a
    /// property of the prover's messages.
    pub fn check_and_generate_subclaim(
        verifier_state: VerifierState<F>,
        asserted_sum: F,
    ) -> Result<SubClaim<F>, crate::Error> {
        if !verifier_state.finished {
            panic!("Verifier has not finished.");
        }
        // Internal invariant: `verify_round` records exactly one message per round and
        // only sets `finished` after round `nv`, so this cannot fire for states built
        // through this API.
        if verifier_state.polynomials_received.len() != verifier_state.nv {
            panic!("insufficient rounds");
        }
        let mut expected = asserted_sum;
        // `randomness` is pushed alongside `polynomials_received` in `verify_round`,
        // so both have one entry per round.
        for (evaluations, &challenge) in verifier_state
            .polynomials_received
            .iter()
            .zip(verifier_state.randomness.iter())
        {
            // The evaluations come from the (untrusted) prover: a malformed message must
            // be rejected, not crash the verifier.
            if evaluations.len() != verifier_state.max_multiplicands + 1 {
                return Err(crate::Error::Reject(Some(
                    "incorrect number of evaluations".into(),
                )));
            }
            // NOTE(review): if `max_multiplicands == 0` the length check above admits a
            // single evaluation and indexing `[1]` panics — confirm that configuration
            // is unsupported.
            let p0 = evaluations[0];
            let p1 = evaluations[1];
            if p0 + p1 != expected {
                return Err(crate::Error::Reject(Some(
                    "Prover message is not consistent with the claim.".into(),
                )));
            }
            expected = interpolate_uni_poly(evaluations, challenge);
        }
        Ok(SubClaim {
            point: verifier_state.randomness,
            expected_evaluation: expected,
        })
    }

    /// Sample the verifier's challenge for one round: a uniformly random field element
    /// drawn from `rng`.
    #[inline]
    pub fn sample_round<R: RngCore>(rng: &mut R) -> VerifierMsg<F> {
        VerifierMsg {
            randomness: F::rand(rng),
        }
    }
}
/// Evaluate, at `eval_at`, the unique polynomial of degree `< p_i.len()` whose value at
/// the integer point `j` is `p_i[j]`, for `j = 0, 1, ..., p_i.len() - 1`.
///
/// Uses the Lagrange form over the nodes `0..len`:
/// `p(x) = sum_i p_i[i] * prod / ((x - i) * d_i)`, where `prod = prod_j (x - j)` and
/// `d_i = prod_{j != i} (i - j) = (-1)^(len-1-i) * i! * (len-1-i)!`.
///
/// If `eval_at` equals a node `j`, `p_i[j]` is returned directly; this also avoids a
/// division by zero. An empty `p_i` interpolates to the zero polynomial, so `F::zero()`
/// is returned.
///
/// For `len <= 20` the integer parts of `d_i` are tracked in `i64`/`u64`, and for
/// `len <= 33` in `i128`/`u128`; at those sizes they cannot overflow. Larger inputs
/// track the denominators directly in the field.
pub(crate) fn interpolate_uni_poly<F: Field>(p_i: &[F], eval_at: F) -> F {
    let len = p_i.len();
    // No points: the interpolant is the zero polynomial (empty Lagrange sum). Without
    // this guard, `len - 1` below would underflow.
    if len == 0 {
        return F::zero();
    }

    // Build evals[j] = eval_at - j and prod = prod_{j=0}^{len-1} (eval_at - j), returning
    // early if eval_at hits one of the nodes.
    let mut evals = Vec::with_capacity(len);
    let mut prod = eval_at;
    evals.push(eval_at);
    let mut check = F::zero();
    for i in 1..len {
        if eval_at == check {
            return p_i[i - 1];
        }
        check += F::one();
        let tmp = eval_at - check;
        evals.push(tmp);
        prod *= tmp;
    }
    if eval_at == check {
        return p_i[len - 1];
    }

    // Walk i from len-1 down to 0, writing d_i = (len-1)! * signed_tail / head with
    //   signed_tail = (-1)^(len-1-i) * (len-1-i)!   (starts at 1 for i = len-1)
    //   head        = (len-1)! / i!                  (starts at 1 for i = len-1)
    // so each term contributes p_i[i] * prod * head / ((len-1)! * signed_tail * (x - i)).
    let mut res = F::zero();
    if len <= 20 {
        // |signed_tail|, head <= 19! < 2^63: fits in i64 / u64.
        let last_denom = F::from(u64_factorial(len - 1));
        let mut signed_tail = 1i64;
        let mut head = 1u64;
        for i in (0..len).rev() {
            let magnitude = F::from(signed_tail.unsigned_abs());
            let signed_tail_f = if signed_tail < 0 { -magnitude } else { magnitude };
            res += p_i[i] * prod * F::from(head) / (last_denom * signed_tail_f * evals[i]);
            if i != 0 {
                signed_tail *= -(len as i64 - i as i64);
                head *= i as u64;
            }
        }
    } else if len <= 33 {
        // |signed_tail|, head <= 32! < 2^127: fits in i128 / u128.
        let last_denom = F::from(u128_factorial(len - 1));
        let mut signed_tail = 1i128;
        let mut head = 1u128;
        for i in (0..len).rev() {
            let magnitude = F::from(signed_tail.unsigned_abs());
            let signed_tail_f = if signed_tail < 0 { -magnitude } else { magnitude };
            res += p_i[i] * prod * F::from(head) / (last_denom * signed_tail_f * evals[i]);
            if i != 0 {
                signed_tail *= -(len as i128 - i as i128);
                head *= i as u128;
            }
        }
    } else {
        // Too large for native integers: keep (len-1)! * signed_tail and head in the field.
        let mut denom_up = field_factorial::<F>(len - 1);
        let mut denom_down = F::one();
        for i in (0..len).rev() {
            res += p_i[i] * prod * denom_down / (denom_up * evals[i]);
            if i != 0 {
                denom_up *= -F::from((len - i) as u64);
                denom_down *= F::from(i as u64);
            }
        }
    }
    res
}
/// Compute `a!` as a field element, multiplying `1 * 2 * ... * a` in the field.
///
/// The empty product (`a == 0`) is `F::one()`.
#[inline]
fn field_factorial<F: Field>(a: usize) -> F {
    (1..=a as u64).fold(F::one(), |acc, k| acc * F::from(k))
}
/// Compute `a!` as a `u128`; the empty product (`a == 0`) is `1`.
///
/// Overflows for `a > 34`. With overflow checks enabled this panics, exactly like the
/// multiplication it replaces.
#[inline]
fn u128_factorial(a: usize) -> u128 {
    (1..=a as u128).product()
}
/// Compute `a!` as a `u64`; the empty product (`a == 0`) is `1`.
///
/// Overflows for `a > 20`. With overflow checks enabled this panics, exactly like the
/// multiplication it replaces.
#[inline]
fn u64_factorial(a: usize) -> u64 {
    (1..=a as u64).product()
}
#[cfg(test)]
mod test {
    use crate::ml_sumcheck::protocol::verifier::interpolate_uni_poly;
    use ark_poly::univariate::DensePolynomial;
    use ark_poly::DenseUVPolynomial;
    use ark_poly::Polynomial;
    use ark_std::vec::Vec;
    use ark_std::UniformRand;
    type F = ark_test_curves::bls12_381::Fr;

    /// Check `interpolate_uni_poly` against a random polynomial of degree
    /// `num_points - 1`, evaluated at the nodes `0..num_points`.
    ///
    /// Interpolation must reproduce the polynomial at a random point, which exercises
    /// the Lagrange sum. It must also reproduce it at every node, which exercises the
    /// early-return path.
    fn check_random_poly<R: ark_std::rand::Rng>(num_points: usize, prng: &mut R) {
        let poly = DensePolynomial::<F>::rand(num_points - 1, prng);
        let evals = (0..num_points as u64)
            .map(|i| poly.evaluate(&F::from(i)))
            .collect::<Vec<F>>();
        let query = F::rand(prng);
        assert_eq!(poly.evaluate(&query), interpolate_uni_poly(&evals, query));
        for (i, &expected) in evals.iter().enumerate() {
            assert_eq!(interpolate_uni_poly(&evals, F::from(i as u64)), expected);
        }
    }

    #[test]
    fn test_interpolation() {
        let mut prng = ark_std::test_rng();
        // Cover every branch: tiny inputs, both sides of the i64/u64 (len <= 20) and
        // i128/u128 (len <= 33) thresholds, and the field-arithmetic fallback.
        for num_points in [1, 2, 3, 20, 21, 33, 34, 64] {
            check_random_poly(num_points, &mut prng);
        }
        let evals = vec![0, 1, 4, 9]
            .into_iter()
            .map(|i| F::from(i))
            .collect::<Vec<F>>();
        assert_eq!(interpolate_uni_poly(&evals, F::from(3)), F::from(9));
    }
}