use alloc::vec::Vec;
use core::marker::PhantomData;
use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
use p3_commit::Mmcs;
use p3_field::{ExtensionField, Field, HornerIter, TwoAdicField, dot_product};
use p3_matrix::Matrix;
use p3_multilinear_util::point::Point;
use p3_zk_codes::{ZkEncoding, ZkEncodingWithRandomness};
use rand::Rng;
use super::common::{observe_masks_and_mu_tilde, sample_masks};
use super::layout::ZkLayout;
use super::round::{PlainPiece, RoundContext, RoundState, round_poly_to_wire};
use crate::extrapolate_01inf;
use crate::lagrange::lagrange_weights_01inf_multi;
use crate::layout::{PrefixProver, SuffixProver};
use crate::strategy::SumcheckProver;
use crate::svo::calculate_accumulators_batch;
use crate::zk::data::{ZkSumcheckData, ZkSumcheckHandoff};
/// Zero-knowledge wrapper around a sumcheck layout prover `L`.
///
/// Bundles the underlying layout prover with the mask `encoding` and the
/// `mmcs` that `into_sumcheck` hands to `sample_masks` when it produces the
/// masking polynomials and their oracle. `F` and `EF` are tracked purely at the
/// type level through `_marker`.
pub struct ZkProver<F: Field, EF: ExtensionField<F>, Enc: ZkEncoding<EF>, M: Mmcs<EF>, L> {
    /// Underlying layout prover (e.g. `PrefixProver` or `SuffixProver`).
    inner: L,
    /// Encoding applied to the mask messages.
    encoding: Enc,
    /// Mixed-matrix commitment scheme used for the encoded masks.
    mmcs: M,
    /// Binds `F` and `EF` to the type without storing values of them.
    _marker: PhantomData<(F, EF)>,
}

/// [`ZkProver`] wrapping a [`PrefixProver`] layout.
pub type ZkPrefixProver<F, EF, Enc, M> =
    ZkProver<F, EF, Enc, M, PrefixProver<F, EF>>;

/// [`ZkProver`] wrapping a [`SuffixProver`] layout.
pub type ZkSuffixProver<F, EF, Enc, M> =
    ZkProver<F, EF, Enc, M, SuffixProver<F, EF>>;
// Methods that need a two-adic base field, an encoding exposing its randomness
// (`ZkEncodingWithRandomness`), and a ZK-capable layout (`L: ZkLayout`).
impl<F, EF, Enc, M, L> ZkProver<F, EF, Enc, M, L>
where
F: TwoAdicField,
EF: ExtensionField<F>,
Enc: ZkEncodingWithRandomness<EF>,
M: Mmcs<EF>,
L: ZkLayout<F, EF>,
{
/// Wraps the layout prover `inner` together with the mask `encoding` and the
/// `mmcs` that [`Self::into_sumcheck`] passes to `sample_masks`.
pub const fn new(inner: L, encoding: Enc, mmcs: M) -> Self {
Self {
inner,
encoding,
mmcs,
_marker: PhantomData,
}
}
/// Number of sumcheck rounds that [`Self::into_sumcheck`] runs in
/// zero-knowledge mode (forwarded from the inner layout).
pub fn folding(&self) -> usize {
self.inner.folding()
}
/// Total number of variables, forwarded from the inner layout.
///
/// Must be at least [`Self::folding`], otherwise [`Self::into_sumcheck`] panics.
pub fn num_variables(&self) -> usize {
self.inner.num_variables()
}
/// Forwards to the inner layout's `eval` and returns its result unchanged.
///
/// NOTE(review): `polys` presumably selects polynomials within table
/// `table_idx`; the exact semantics are defined by `ZkLayout::eval`. Confirm.
pub fn eval<Ch>(&mut self, table_idx: usize, polys: &[usize], challenger: &mut Ch) -> Vec<EF>
where
Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
{
self.inner.eval(table_idx, polys, challenger)
}
/// Forwards to the inner layout's `add_virtual_eval` and returns its value.
///
/// NOTE(review): presumably this registers one of the claims later returned
/// by `virtual_claims()` in [`Self::into_sumcheck`]. Confirm in `ZkLayout`.
pub fn add_virtual_eval<Ch>(&mut self, challenger: &mut Ch) -> EF
where
Ch: FieldChallenger<F> + GrindingChallenger<Witness = F>,
{
self.inner.add_virtual_eval(challenger)
}
/// Consumes the prover and runs the first `k = self.folding()` sumcheck rounds
/// in zero-knowledge mode. The remaining variables are handed off to a plain
/// [`SumcheckProver`].
///
/// Challenger interactions happen in this order, and the verifier must replay it:
/// 1. sample the batching challenge `alpha`;
/// 2. `sample_masks` (which receives the challenger), then
///    `observe_masks_and_mu_tilde`;
/// 3. sample `eps` (the final residual claim is `eps * plain_sum`);
/// 4. for each of the `k` rounds: observe the wire form of the round
///    polynomial, grind `pow_bits` of proof-of-work if `pow_bits > 0`, and
///    sample the round challenge `gamma_j`.
///
/// Every round pushes one entry onto `zk_data.round_coefficients`. Each round
/// also pushes one entry onto `zk_data.pow_witnesses` when `pow_bits > 0`.
/// `zk_data` is also passed to `observe_masks_and_mu_tilde`.
///
/// Returns the residual prover (claimed sum `eps * plain_sum`), the round
/// challenges as a [`Point`], `eps`, and the masks with their randomness and
/// oracle.
///
/// # Panics
/// - if `char(F) == 2` (`F::TWO == F::ZERO`);
/// - if `encoding.message_len() < 3`;
/// - if `folding() == 0` or `folding() > num_variables()`;
/// - in debug builds, if the residual product polynomial's dot product differs
///   from `eps * plain_sum`.
// NOTE(review): the lint is suppressed for the long body. Presumably it stays
// monolithic so the transcript order reads top to bottom. Confirm.
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip_all)]
pub fn into_sumcheck<R, Ch>(
self,
zk_data: &mut ZkSumcheckData<F, EF>,
pow_bits: usize,
challenger: &mut Ch,
rng: &mut R,
) -> ZkSumcheckHandoff<F, EF, M>
where
EF: TwoAdicField,
Enc::Codeword: Matrix<EF>,
R: Rng,
Ch: FieldChallenger<F> + GrindingChallenger<Witness = F> + CanObserve<M::Commitment>,
{
let k = self.inner.folding();
let ell_zk = self.encoding.message_len();
let n_vars = self.inner.num_variables();
assert!(F::TWO != F::ZERO, "Lemma 6.4 requires char(F) != 2");
assert!(
ell_zk >= 3,
"mask degree ell_zk - 1 must cover the degree-2 plain piece (ell_zk >= 3)",
);
assert!(k >= 1, "sumcheck requires at least one round");
assert!(
k <= n_vars,
"folding_factor must be <= poly.num_variables()",
);
// Batching challenge. Claims are weighted by alpha^0, alpha^1, ...: concrete
// claims take the first `n_concrete` powers, virtual claims the rest.
let alpha: EF = challenger.sample_algebra_element();
let n_concrete: usize = self.inner.concrete_claims().map(|claim| claim.len()).sum();
let n_virtual = self.inner.virtual_claims().len();
let all_alphas: Vec<EF> = alpha.powers().collect_n(n_concrete + n_virtual);
let (concrete_alphas, virtual_alphas) = all_alphas.split_at(n_concrete);
// Each concrete claim group consumes the next `claim.len()` powers of alpha.
// Those powers go into its accumulators.
let mut offset = 0;
let accumulators: Vec<_> = self
.inner
.concrete_claims()
.map(|claim| {
let slice = &concrete_alphas[offset..offset + claim.len()];
offset += claim.len();
calculate_accumulators_batch(claim, slice)
})
.collect();
// Running claim for the plain (unmasked) part. Updated after every round.
let mut plain_sum = self.inner.batched_sum(alpha);
// One mask message per ZK round, plus its randomness and oracle.
let (masks, mask_randomness, mask_oracle) =
sample_masks::<EF, _, _, _, _>(k, &self.encoding, &self.mmcs, challenger, rng);
// Initial value of `sum_future_endpoints`. Each round subtracts its own mask's
// endpoint sum from it below.
let sum_endpoints_init =
observe_masks_and_mu_tilde::<F, EF, _>(&masks, k, ell_zk, challenger, zk_data);
let eps: EF = challenger.sample_algebra_element();
let mut rs: Vec<EF> = Vec::with_capacity(k);
let mut mask_evals_at_gamma: Vec<EF> = Vec::with_capacity(k);
let mut sum_future_endpoints = sum_endpoints_init;
// Powers 2^0 ..= 2^k, shared by every round's assembly.
let pow2: Vec<EF> = EF::TWO.powers().collect_n(k + 1);
let round_ctx = RoundContext {
k,
ell_zk,
pow2: &pow2,
eps,
};
for round_idx in 0..k {
// `j` is the 1-based round number.
let j = round_idx + 1;
let s_j = &masks[round_idx];
// s_j(0) + s_j(1) = 2*s_j[0] + sum(s_j[1..]), assuming `s_j` stores
// coefficients lowest-degree first (TODO confirm against `horner`).
// Subtracting it presumably leaves only the later masks' endpoint sum.
let s_j_endpoints = s_j[0].double() + s_j[1..].iter().copied().sum::<EF>();
sum_future_endpoints -= s_j_endpoints;
// Lagrange weights over the {0, 1, inf} grid at the challenges sampled so
// far (per the helper's name). `dot` folds one accumulator row with them.
let weights_lag = lagrange_weights_01inf_multi(&rs);
let dot = |row: &[EF]| {
dot_product::<EF, _, _>(row.iter().copied(), weights_lag.iter().copied())
};
// Plain round polynomial as two coefficients, taken from index 0 and 1 of
// each accumulator (presumably its value at 0 and its leading coefficient).
// Concrete claims were alpha-weighted in `calculate_accumulators_batch`.
// Virtual claims are weighted here.
let mut plain_c0: EF = accumulators.iter().map(|a| dot(&a[round_idx][0])).sum();
let mut plain_c_inf: EF = accumulators.iter().map(|a| dot(&a[round_idx][1])).sum();
for (vc, alpha_i) in self
.inner
.virtual_claims()
.iter()
.zip(virtual_alphas.iter().copied())
{
plain_c0 += alpha_i * dot(&vc.data[round_idx][0]);
plain_c_inf += alpha_i * dot(&vc.data[round_idx][1]);
}
// Build the round polynomial h_j from four inputs: the current mask, the
// evaluations of earlier masks at their challenges, the remaining masks'
// endpoint sum, and the plain piece.
let h = round_ctx.assemble(
RoundState {
j,
mask: s_j,
past_mask_evals: &mask_evals_at_gamma,
future_endpoints: sum_future_endpoints,
},
PlainPiece {
c0: plain_c0,
c_inf: plain_c_inf,
},
);
// Serialize h_j, bind it into the transcript, and record it for the proof.
let wire = round_poly_to_wire(&h);
challenger.observe_algebra_slice(&wire);
zk_data.round_coefficients.push(wire);
// Optional proof-of-work, done after observing h_j and before sampling gamma_j.
if pow_bits > 0 {
zk_data.pow_witnesses.push(challenger.grind(pow_bits));
}
// Round challenge. This round's mask is evaluated at it and fed to later
// rounds as `past_mask_evals`.
let gamma_j: EF = challenger.sample_algebra_element();
let s_j_at_gamma_j: EF = s_j.iter().copied().horner(gamma_j);
mask_evals_at_gamma.push(s_j_at_gamma_j);
// New plain claim p_j(gamma_j), presumably reconstructed from
// p_j(0) = c0, p_j(1) = plain_sum - c0, and c_inf.
plain_sum = extrapolate_01inf(plain_c0, plain_sum - plain_c0, plain_c_inf, gamma_j);
rs.push(gamma_j);
}
// Hand the remaining variables to a plain sumcheck over the residual product
// polynomial. Its claimed sum is the final plain claim scaled by eps.
let rs = Point::new(rs);
let prod_poly = self.inner.zk_residual_handoff(&rs, alpha, eps);
let residual_sum = eps * plain_sum;
debug_assert_eq!(
prod_poly.dot_product(),
residual_sum,
"residual product polynomial dot product must equal eps * plain_residual_sum",
);
ZkSumcheckHandoff {
residual_prover: SumcheckProver::new(prod_poly, residual_sum),
randomness: rs,
eps,
mask_messages: masks,
mask_randomness,
mask_oracle,
}
}
}
#[cfg(test)]
mod tests {
    use p3_field::{Field, PackedValue};
    use p3_util::log2_strict_usize;
    use proptest::prelude::*;

    use crate::strategy::VariableOrder;
    use crate::zk::test_helpers::{F, run_roundtrip};

    // `run_roundtrip` takes its arguments in this order throughout the module:
    // (order, n_vars, folding_factor, ell_zk, num_concrete, num_virtual, seed).

    /// Honest prefix-order run (8 vars, 3 ZK rounds, ell_zk = 4, one concrete
    /// and one virtual claim) must be accepted.
    #[test]
    fn prover_verifier_roundtrip_prefix() {
        run_roundtrip(VariableOrder::Prefix, 8, 3, 4, 1, 1, 0)
            .expect("honest roundtrip should accept");
    }

    /// Suffix-order counterpart of `prover_verifier_roundtrip_prefix`.
    #[test]
    fn prover_verifier_roundtrip_suffix() {
        run_roundtrip(VariableOrder::Suffix, 8, 3, 4, 1, 1, 0)
            .expect("honest roundtrip should accept");
    }

    /// Long mask (ell_zk = 32), prefix order. Per the test name, this targets
    /// the Horner evaluation path for long masks.
    #[test]
    fn long_mask_horner_path_prefix() {
        run_roundtrip(VariableOrder::Prefix, 8, 3, 32, 1, 1, 0)
            .expect("honest roundtrip should accept");
    }

    /// Long mask (ell_zk = 32), suffix order.
    #[test]
    fn long_mask_horner_path_suffix() {
        run_roundtrip(VariableOrder::Suffix, 8, 3, 32, 1, 1, 0)
            .expect("honest roundtrip should accept");
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(16))]

        /// Completeness over random shapes, prefix order. The packed width
        /// contributes `k_pack` variables that cannot be folded here, so
        /// `folding_factor` is drawn from `1..=n_vars - k_pack`.
        #[test]
        fn prop_completeness_prefix(
            n_vars in 3usize..=8,
            ell_zk in 3usize..=5,
            num_concrete in 0usize..=2,
            num_virtual in 0usize..=2,
            seed in 0u64..1024,
        ) {
            prop_assume!(num_concrete + num_virtual >= 1);
            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
            prop_assume!(n_vars > k_pack);
            let folding_factor = 1 + (seed as usize % (n_vars - k_pack));
            let outcome = run_roundtrip(
                VariableOrder::Prefix,
                n_vars,
                folding_factor,
                ell_zk,
                num_concrete,
                num_virtual,
                seed,
            );
            // Report the verifier's error so shrunk counterexamples are diagnosable.
            prop_assert!(
                outcome.is_ok(),
                "honest prefix roundtrip rejected: {:?}",
                outcome.err()
            );
        }

        /// Completeness over random shapes, suffix order. `folding_factor` is
        /// drawn from `1..=n_vars - 1`.
        #[test]
        fn prop_completeness_suffix(
            n_vars in 3usize..=8,
            ell_zk in 3usize..=5,
            num_concrete in 0usize..=2,
            num_virtual in 0usize..=2,
            seed in 0u64..1024,
        ) {
            prop_assume!(num_concrete + num_virtual >= 1);
            let folding_factor = 1 + (seed as usize % (n_vars - 1));
            let outcome = run_roundtrip(
                VariableOrder::Suffix,
                n_vars,
                folding_factor,
                ell_zk,
                num_concrete,
                num_virtual,
                seed,
            );
            // Report the verifier's error so shrunk counterexamples are diagnosable.
            prop_assert!(
                outcome.is_ok(),
                "honest suffix roundtrip rejected: {:?}",
                outcome.err()
            );
        }
    }
}