use crate::{
folding::{prover::SumFoldProverOutput, zerofold::ZeroFold, SumFold, SumFoldInstance},
polynomials::{simple_eval::SimpleEval, MultiPoint},
sumcheck::{
CommitType, Env, EvalKind, NoChallIdx, NoChallenges, ProverOutput, Sum, SumcheckFunction,
SumcheckProver, SumcheckVerifier, Var,
},
zerocheck::{CompactPowers, ZeroCheckIdx, ZeroCheckMles},
};
use ark_ff::Field;
use sponge::sponge::UnsafeSponge;
use std::fmt::Debug;
use transcript::{
instances::PolyEvalCheck, params::ParamResolver, protocols::Reduction, MessageGuard,
TranscriptBuilder, TranscriptGuard,
};
/// Sumcheck function for the full zerocheck: the inner constraint `a * b - c` multiplied by the
/// value stored at `ZeroCheckIdx::ZeroCheckChallenge`.
struct ZeroCheckWrapped;
/// Evaluations used by [`ZeroCheckWrapped`]: one zerocheck-challenge slot plus the three inner
/// values `a`, `b`, `c`.
type Evals<V> = ZeroCheckMles<V, SimpleEval<V, 3>>;
/// Evaluation kinds for [`ZeroCheckWrapped`].
///
/// The zerocheck-challenge slot is virtual (the verifier recomputes it from the compact powers).
/// The three inner evaluations `a`, `b`, `c` are committed instance polynomials.
const fn kinds() -> Evals<EvalKind> {
    ZeroCheckMles::new(
        EvalKind::Virtual,
        SimpleEval::new([EvalKind::Committed(CommitType::Instance); 3]),
    )
}
impl<F: Field> SumcheckFunction<F> for ZeroCheckWrapped {
    type Idx = ZeroCheckIdx<usize>;
    type Mles<V: Copy + Debug> = Evals<V>;
    type Challs = NoChallenges<F>;
    type ChallIdx = NoChallIdx;
    const KINDS: Self::Mles<EvalKind> = kinds();

    /// Applies `f` to the zerocheck-challenge slot and to each inner evaluation.
    fn map_evals<A, B, M>(evals: Self::Mles<A>, f: M) -> Self::Mles<B>
    where
        A: Copy + Debug,
        B: Copy + Debug,
        M: Fn(A) -> B,
    {
        // `&M` is itself `Fn(A) -> B`, so one borrowed mapper serves both levels.
        let lift = &f;
        evals.map(lift, |inner| inner.map(lift))
    }

    /// Computes `z * (a * b - c)`, where `a`, `b`, `c` are the inner values and `z` is the
    /// zerocheck-challenge slot.
    fn function<V: Var<F>, E: Env<F, V, Self::Idx, Self::ChallIdx>>(env: E) -> V {
        // Read in the order a, b, c, then z.
        let [a, b, c] = [0, 1, 2].map(|slot| env.get(ZeroCheckIdx::Inner(slot)));
        let z = env.get(ZeroCheckIdx::ZeroCheckChallenge);
        z * (a * b - c)
    }

    /// The symbolic form is the same expression as [`Self::function`].
    fn symbolic_function<V: Var<F>, E: Env<F, V, Self::Idx, Self::ChallIdx>>(
        &self,
        env: E,
    ) -> Option<V> {
        Some(<Self as SumcheckFunction<F>>::function(env))
    }
}
/// Inner constraint `a * b - c` over three virtual evaluations, without the zerocheck factor.
/// This is the function folded by `ZeroFold` in `test`.
struct ZeroCheckInner;
impl<F: Field> SumcheckFunction<F> for ZeroCheckInner {
    type Idx = usize;
    type Mles<V: Copy + Debug> = SimpleEval<V, 3>;
    type Challs = NoChallenges<F>;
    type ChallIdx = NoChallIdx;
    const KINDS: Self::Mles<EvalKind> = SimpleEval::new([EvalKind::Virtual; 3]);

    /// Applies `f` to each of the three evaluations.
    fn map_evals<A, B, M>(evals: Self::Mles<A>, f: M) -> Self::Mles<B>
    where
        A: Copy + Debug,
        B: Copy + Debug,
        M: Fn(A) -> B,
    {
        evals.map(f)
    }

    /// Computes `a * b - c` from slots 0, 1 and 2.
    fn function<V: Var<F>, E: Env<F, V, Self::Idx, Self::ChallIdx>>(env: E) -> V {
        let [a, b, c] = [0, 1, 2].map(|slot| env.get(slot));
        a * b - c
    }

    /// The symbolic form is the same expression as [`Self::function`].
    fn symbolic_function<V: Var<F>, E: Env<F, V, Self::Idx, Self::ChallIdx>>(
        &self,
        env: E,
    ) -> Option<V> {
        Some(<Self as SumcheckFunction<F>>::function(env))
    }
}
/// Number of variables used by every sumcheck/fold in these tests; witnesses have `1 << VARS` rows.
const VARS: usize = 5;
/// A zerocheck instance (its compact powers) paired with its witness evaluations.
#[derive(Clone)]
struct InstanceWitness<F: Field> {
    /// One entry per row (`1 << VARS` of them). The zerocheck slot holds the powers evaluated
    /// at that row, and the inner slots hold `a`, `b`, `c`.
    witness: Vec<Evals<F>>,
    /// Compact form of the powers; expanded with `eval_over_domain` / `point_eval`.
    powers: CompactPowers<F>,
}
/// Builds an instance/witness pair satisfying `a * b == c` on every row.
///
/// `elems` is consumed as one challenge for the compact powers (scaled by 3), followed by
/// `(a, b)` pairs, one per row.
///
/// # Panics
/// Panics if `elems` has fewer than `(1 << VARS) * 2 + 1` elements.
fn sample_instance_witness<F: Field>(elems: Vec<F>) -> InstanceWitness<F> {
    assert!(elems.len() > (1 << VARS) * 2);
    let mut source = elems.into_iter();
    let challenge = source.next().unwrap();
    let powers = CompactPowers::new(challenge, VARS) * F::from(3u8);
    let mut domain = powers.clone().eval_over_domain().into_iter();
    let witness = (0..(1 << VARS))
        .map(|_| {
            let a = source.next().unwrap();
            let b = source.next().unwrap();
            let z = domain.next().unwrap();
            // `c = a * b` makes the inner constraint vanish on this row.
            Evals::new(z, SimpleEval::new([a, b, a * b]))
        })
        .collect();
    InstanceWitness { witness, powers }
}
/// Proves and verifies the zerocheck sumcheck for `pair`, claiming `sum`.
fn check_pair<F: Field>(pair: InstanceWitness<F>, sum: F) {
    prove_and_verify(pair.powers, pair.witness, sum)
}
/// End-to-end zerofold test.
///
/// It samples two valid zerocheck instances and checks each one alone. It then folds them with
/// `ZeroFold` and verifies the fold proof. Finally it checks that the folded instance (with the
/// folded powers re-attached) still passes the zerocheck sumcheck for the verified folded sum.
fn test<F: Field>(random_elements: Vec<F>) {
    // One powers challenge plus an `(a, b)` pair per row.
    const PER_INSTANCE: usize = (1 << VARS) * 2 + 1;
    let mut source = random_elements.into_iter();
    let left = sample_instance_witness::<F>(source.by_ref().take(PER_INSTANCE).collect());
    let right = sample_instance_witness::<F>(source.by_ref().take(PER_INSTANCE).collect());

    // Each input must be a valid zerocheck instance before folding.
    check_pair(left.clone(), F::zero());
    check_pair(right.clone(), F::zero());

    let zerofold: ZeroFold<F, ZeroCheckInner> = ZeroFold::new(ZeroCheckInner, VARS);
    let (folded_inner, folded_instance, folder) = {
        let transcript_desc = TranscriptBuilder::new(VARS, ParamResolver::new())
            .add_reduction_patter::<F, SumFold<F, _>>(zerofold.sumfold_key())
            .finish::<F, UnsafeSponge<F>>();

        // Prover side: fold the inner witnesses (zerocheck slot stripped).
        let start_sums = Some(SumFoldInstance::new([F::zero(), F::zero()]));
        let left_inner = left.witness.iter().map(|e| *e.inner()).collect();
        let right_inner = right.witness.iter().map(|e| *e.inner()).collect::<Vec<_>>();
        let both_powers = [left.powers.clone(), right.powers.clone()];
        let mut prover_transcript = transcript_desc.instanciate();
        let SumFoldProverOutput {
            instance,
            folded_witness,
            proof,
            folder,
            sum: claimed_sum,
        } = zerofold.fold_zerocheck(
            left_inner,
            right_inner.as_slice(),
            start_sums,
            both_powers,
            NoChallenges::default(),
            &mut prover_transcript,
        );
        prover_transcript.finish_unchecked();

        // Verifier side: replay the fold proof on a fresh transcript.
        let mut verifier_transcript = transcript_desc.instanciate();
        let guard = TranscriptGuard::new(&mut verifier_transcript, proof);
        let (verified, _) = SumFold::verify_reduction(
            zerofold.sumfold_key(),
            MessageGuard::new(instance),
            guard,
        )
        .unwrap();
        assert_eq!(claimed_sum, verified.0);
        verifier_transcript.finish_unchecked();
        (folded_witness, verified, folder)
    };

    // Fold the compact powers the same way and re-attach them to the folded inner witness.
    let powers = folder.fold_powers(left.powers, right.powers);
    let witness = folded_inner
        .into_iter()
        .zip(powers.eval_over_domain().into_iter())
        .map(|(inner, z)| ZeroCheckMles::new(z, inner))
        .collect();
    check_pair(InstanceWitness { witness, powers }, folded_instance.0);
}
/// Runs [`test`] over the Vesta scalar field with a deterministic, seeded RNG.
#[test]
fn fold_zerocheck() {
    use ark_ff::UniformRand;
    use ark_vesta::Fr;
    use rand::{rngs::StdRng, SeedableRng};
    use std::iter::repeat_with;
    // `test` samples two instances. Each needs one powers challenge plus an `(a, b)` pair per
    // row, i.e. `(1 << VARS) * 2 + 1` elements.
    const PER_INSTANCE: usize = (1 << VARS) * 2 + 1;
    let mut rng = StdRng::seed_from_u64(0);
    let elems = repeat_with(|| Fr::rand(&mut rng))
        .take(2 * PER_INSTANCE)
        .collect();
    test(elems);
}
/// Proves `mle` with the zerocheck sumcheck and verifies the proof against the claimed `sum`.
///
/// After verification it recomputes the virtual powers evaluation at the sumcheck point and
/// checks the final evaluation claim.
///
/// # Panics
/// Panics if proving, transcript finalisation on the prover side, verification, or the final
/// evaluation check fails.
pub fn prove_and_verify<F: Field>(powers: CompactPowers<F>, mle: Vec<Evals<F>>, sum: F) {
    let prover = SumcheckProver::<F, ZeroCheckWrapped>::new_symbolic(VARS, &ZeroCheckWrapped);
    let verifier = SumcheckVerifier::<F, ZeroCheckWrapped>::new_symbolic(ZeroCheckWrapped, VARS);
    let transcript_desc = TranscriptBuilder::new(VARS, ParamResolver::new())
        .add_reduction_patter::<F, SumcheckVerifier<F, ZeroCheckWrapped>>(&verifier)
        .finish::<F, UnsafeSponge<F>>();

    // Prover: the transcript must be fully consumed (checked `finish`).
    let mut prover_transcript = transcript_desc.instanciate();
    let output = prover
        .prove_zerocheck(
            powers.clone(),
            &mut prover_transcript,
            mle,
            &NoChallenges::default(),
        )
        .unwrap();
    prover_transcript.finish().unwrap();
    let ProverOutput { proof, evals, .. } = output;

    // Verifier: finish the transcript before unwrapping so a failed verification reports its
    // own error.
    let mut verifier_transcript = transcript_desc.instanciate();
    let verdict = SumcheckVerifier::verify_reduction(
        &verifier,
        MessageGuard::new(Sum(sum)),
        verifier_transcript.guard(proof),
    );
    verifier_transcript.finish_unchecked();
    let PolyEvalCheck { vars, eval } = verdict.unwrap();

    // The zerocheck slot is virtual, so the verifier evaluates the compact powers itself.
    let point = MultiPoint::new(vars);
    let powers_at_point = powers.point_eval(&point);
    let evals_at_r = ZeroCheckMles::new(powers_at_point, *evals.inner());
    assert!(verifier.check_evals_at_r(evals_at_r, eval, &NoChallenges::default()));
}