use super::test_cases::sumcheck_test_cases;
use crate::{
base::{
polynomial::CompositePolynomial,
proof::Transcript as _,
scalar::{test_scalar::TestScalar, MontScalar, Scalar},
},
proof_primitive::{
inner_product::curve_25519_scalar::Curve25519Scalar,
sumcheck::{ProverState, SumcheckProof},
},
};
use alloc::rc::Rc;
use ark_std::UniformRand;
use merlin::Transcript;
use num_traits::{One, Zero};
/// Proves and verifies the hypercube sum of a single one-variable multilinear polynomial, then
/// checks how verification behaves with a perturbed transcript, a wrong claimed sum, and a
/// tampered proof.
#[test]
fn test_create_verify_proof() {
    let num_vars = 1;
    let mut evaluation_point: [Curve25519Scalar; 1] = [Curve25519Scalar::zero(); 1];
    let mut poly = CompositePolynomial::new(num_vars);
    let a_vec: [Curve25519Scalar; 2] = [
        Curve25519Scalar::from(123u64),
        Curve25519Scalar::from(456u64),
    ];
    let fa = Rc::new(a_vec.to_vec());
    poly.add_product([fa], Curve25519Scalar::from(1u64));
    // The polynomial's values are 123 and 456, so the correct claimed sum is 123 + 456 = 579.
    let mut transcript = Transcript::new(b"sumchecktest");
    let mut proof = SumcheckProof::create(
        &mut transcript,
        &mut evaluation_point,
        ProverState::create(&poly),
    );

    // Honest verification: same transcript label, correct sum.
    let mut transcript = Transcript::new(b"sumchecktest");
    let subclaim = proof
        .verify_without_evaluation(
            &mut transcript,
            poly.num_variables,
            &Curve25519Scalar::from(579u64),
        )
        .expect("verify failed");
    assert_eq!(subclaim.evaluation_point, evaluation_point);
    assert_eq!(
        poly.evaluate(&evaluation_point),
        subclaim.expected_evaluation
    );

    // A transcript with extra prior data must lead to a different challenge point.
    let mut transcript = Transcript::new(b"sumchecktest");
    transcript.extend_serialize_as_le(&123u64);
    let subclaim = proof
        .verify_without_evaluation(
            &mut transcript,
            poly.num_variables,
            &Curve25519Scalar::from(579u64),
        )
        .expect("verify failed");
    assert_ne!(subclaim.evaluation_point, evaluation_point);

    // A wrong claimed sum must be rejected.
    let mut transcript = Transcript::new(b"sumchecktest");
    let subclaim = proof.verify_without_evaluation(
        &mut transcript,
        poly.num_variables,
        &Curve25519Scalar::from(123u64),
    );
    assert!(subclaim.is_err());

    // A tampered proof must be rejected. Start from a fresh transcript so that the only
    // difference from the honest run is the tampering, not transcript state left behind by
    // the previous (failed) verification call.
    proof.coefficients[0] += Curve25519Scalar::from(3u64);
    let mut transcript = Transcript::new(b"sumchecktest");
    let subclaim = proof.verify_without_evaluation(
        &mut transcript,
        poly.num_variables,
        &Curve25519Scalar::from(579u64),
    );
    assert!(subclaim.is_err());
}
/// Draws `num_multiplicands` random columns of length `2^nv`.
///
/// Returns the columns, each wrapped in an `Rc`, together with the sum over all rows of the
/// row-wise product of the columns. Values are drawn row by row, left to right across the
/// columns, so the output is reproducible for a given seeded `rng`.
fn random_product(
    nv: usize,
    num_multiplicands: usize,
    rng: &mut ark_std::rand::rngs::StdRng,
) -> (Vec<Rc<Vec<Curve25519Scalar>>>, Curve25519Scalar) {
    let num_rows: usize = 1 << nv;
    let mut columns: Vec<Vec<Curve25519Scalar>> = (0..num_multiplicands)
        .map(|_| Vec::with_capacity(num_rows))
        .collect();
    let mut total = Curve25519Scalar::zero();
    for _ in 0..num_rows {
        // Fill one row across every column while accumulating that row's product.
        let row_product = columns
            .iter_mut()
            .fold(Curve25519Scalar::one(), |acc, column| {
                let value = Curve25519Scalar::rand(rng);
                column.push(value);
                acc * value
            });
        total += row_product;
    }
    let shared_columns = columns.into_iter().map(Rc::new).collect();
    (shared_columns, total)
}
/// Builds a random composite polynomial in `nv` variables from `num_products` random products.
///
/// For each product, the number of multiplicands is drawn from the half-open range
/// `num_multiplicands_range.0..num_multiplicands_range.1`, and a random coefficient is drawn.
/// Returns the polynomial and its sum over the boolean hypercube.
fn random_polynomial(
    nv: usize,
    num_multiplicands_range: (usize, usize),
    num_products: usize,
    rng: &mut ark_std::rand::rngs::StdRng,
) -> (CompositePolynomial<Curve25519Scalar>, Curve25519Scalar) {
    use ark_std::rand::Rng;
    let (min_multiplicands, max_multiplicands) = num_multiplicands_range;
    let mut polynomial = CompositePolynomial::new(nv);
    let mut hypercube_sum = Curve25519Scalar::zero();
    for _ in 0..num_products {
        // The RNG draw order (count, then multiplicands, then coefficient) determines the output.
        let count = rng.gen_range(min_multiplicands..max_multiplicands);
        let (multiplicands, multiplicand_sum) = random_product(nv, count, rng);
        let weight = Curve25519Scalar::rand(rng);
        polynomial.add_product(multiplicands.into_iter(), weight);
        hypercube_sum += multiplicand_sum * weight;
    }
    (polynomial, hypercube_sum)
}
/// Generates a deterministic random polynomial, proves its hypercube sum, and checks that the
/// verifier accepts the proof, reproduces the prover's evaluation point, and derives the
/// polynomial's true evaluation at that point.
fn test_polynomial(nv: usize, num_multiplicands_range: (usize, usize), num_products: usize) {
    // A fixed all-zero seed makes the generated polynomial identical on every run.
    let mut rng: ark_std::rand::rngs::StdRng = ark_std::rand::SeedableRng::from_seed([0u8; 32]);
    let (poly, asserted_sum) =
        random_polynomial(nv, num_multiplicands_range, num_products, &mut rng);

    // Prover and verifier each start from an identically-labelled transcript.
    let mut prover_transcript = Transcript::new(b"sumchecktest");
    let mut evaluation_point = vec![Curve25519Scalar::zero(); poly.num_variables];
    let proof = SumcheckProof::create(
        &mut prover_transcript,
        &mut evaluation_point,
        ProverState::create(&poly),
    );

    let mut verifier_transcript = Transcript::new(b"sumchecktest");
    let subclaim = proof
        .verify_without_evaluation(&mut verifier_transcript, poly.num_variables, &asserted_sum)
        .expect("verify failed");
    assert_eq!(subclaim.evaluation_point, evaluation_point);
    let actual_evaluation = poly.evaluate(&evaluation_point);
    assert_eq!(actual_evaluation, subclaim.expected_evaluation);
}
/// Sumcheck round trip for a 1-variable polynomial built from 5 random products, each having
/// between 4 and 12 multiplicands (range `4..13`).
#[test]
fn test_trivial_polynomial() {
    test_polynomial(1, (4, 13), 5);
}
/// Sumcheck round trip for a 7-variable polynomial built from 5 random products, each having
/// between 4 and 8 multiplicands (range `4..9`).
#[test]
fn test_normal_polynomial() {
    test_polynomial(7, (4, 9), 5);
}
/// For every generated sumcheck test case, checks that an honest proof verifies. Then checks
/// that verification is not fooled by a different transcript, a wrong claimed sum, or a
/// tampered proof.
#[test]
fn we_can_verify_many_random_test_cases() {
    let mut rng = ark_std::test_rng();
    // Every prover/verifier run starts from a brand-new transcript with the same label.
    let fresh_transcript = || Transcript::new(b"sumchecktest");
    for test_case in sumcheck_test_cases::<TestScalar>(&mut rng) {
        let num_vars = test_case.num_vars;
        let mut evaluation_point = vec![MontScalar::default(); num_vars];
        let mut prover_transcript = fresh_transcript();
        let proof = SumcheckProof::create(
            &mut prover_transcript,
            &mut evaluation_point,
            ProverState::create(&test_case.polynomial),
        );

        // Honest verification must succeed and agree with the prover.
        let mut honest_transcript = fresh_transcript();
        let subclaim = proof
            .verify_without_evaluation(&mut honest_transcript, num_vars, &test_case.sum)
            .expect("verification should succeed with the correct setup");
        assert_eq!(
            subclaim.evaluation_point, evaluation_point,
            "the prover's evaluation point should match the verifier's"
        );
        assert_eq!(
            test_case.polynomial.evaluate(&evaluation_point),
            subclaim.expected_evaluation,
            "the claimed evaluation should match the actual evaluation"
        );

        // Extra transcript data: either rejection or a different challenge point is acceptable.
        let mut perturbed_transcript = fresh_transcript();
        perturbed_transcript.extend_serialize_as_le(&123u64);
        let perturbed_result =
            proof.verify_without_evaluation(&mut perturbed_transcript, num_vars, &test_case.sum);
        if let Ok(perturbed_subclaim) = perturbed_result {
            assert_ne!(
                perturbed_subclaim.evaluation_point, evaluation_point,
                "either verification should fail or we should have a different evaluation point with a different transcript"
            );
        }

        // A wrong claimed sum must be rejected.
        let wrong_sum = test_case.sum + TestScalar::ONE;
        let mut wrong_sum_transcript = fresh_transcript();
        let wrong_sum_result =
            proof.verify_without_evaluation(&mut wrong_sum_transcript, num_vars, &wrong_sum);
        assert!(
            wrong_sum_result.is_err(),
            "verification should fail when the sum is wrong"
        );

        // A tampered proof must be rejected.
        let mut tampered_proof = proof;
        tampered_proof.coefficients[0] += TestScalar::ONE;
        let mut tampered_transcript = fresh_transcript();
        let tampered_result =
            tampered_proof.verify_without_evaluation(&mut tampered_transcript, num_vars, &test_case.sum);
        assert!(
            tampered_result.is_err(),
            "verification should fail when the proof is modified"
        );
    }
}
/// Proves and verifies a 3-variable, degree-2 sumcheck instance over `BNScalar` with a Keccak
/// transcript.
///
/// The polynomial is `a * b - c`, where column `c` is defined as the pointwise product of `a`
/// and `b`, so the claimed sum is zero.
#[test]
fn we_can_generate_and_verify_a_simple_sumcheck_proof() {
    use crate::{base::proof::Keccak256Transcript, proof_primitive::hyperkzg::BNScalar};

    // Prover and verifier must seed their transcripts with identical data.
    let seeded_transcript = || {
        let mut transcript = Keccak256Transcript::new();
        transcript.extend_as_be([
            0x0123_4567_89AB_CDEF_0123_4567_89AB_CDEF_u128,
            0x0123_4567_89AB_CDEF_0123_4567_89AB_CDEF_u128,
        ]);
        transcript
    };

    let num_vars = 3;
    let degree = 2;
    let products = vec![(1.into(), vec![0, 1]), (-BNScalar::ONE, vec![2])];
    let columns = vec![
        (101..=108).map(Into::into).collect(),
        (201..=208).map(Into::into).collect(),
        (101..=108)
            .zip(201..=208)
            .map(|(a, b)| a * b)
            .map(Into::into)
            .collect(),
    ];
    let state = ProverState::new(products, columns, num_vars, degree);

    let mut evaluation_point = vec![MontScalar::default(); num_vars];
    let mut prover_transcript = seeded_transcript();
    let proof =
        SumcheckProof::<BNScalar>::create(&mut prover_transcript, &mut evaluation_point, state);

    let mut verifier_transcript = seeded_transcript();
    let subclaim = proof
        .verify_without_evaluation(&mut verifier_transcript, num_vars, &BNScalar::ZERO)
        .unwrap();
    assert_eq!(subclaim.evaluation_point, evaluation_point);
}