use spongefish::{ProverState, VerificationError, VerificationResult, VerifierState};
use crate::field::SumcheckField;
use super::code::{Field as WhirField, InterleavedCode, LinearCode};
use super::commitment::{
CodeCommitmentProverState, ExplicitCodeCommitmentHandle, FoldedCodeCommitmentHandle,
};
use super::linear_form::{FoldedFormHandle, LinearConstraint, LinearForm, LinearFormHandle};
use super::mask_stack::MaskOracleHandle;
use super::vc::{MerkleVc, VectorCommitment};
use crate::field::Goldilocks4;
/// Folds the leading variable of the zero-padded multilinear table `a` at point `w`, in place.
///
/// `a` is viewed as evaluations over a hypercube of size `a.len().next_power_of_two()`; entries
/// past `a.len()` are implicitly zero. With `half = next_power_of_two / 2`, output entry `k` is
/// `lo + w * (hi - lo)` where `lo = a[k]` and `hi = a[k + half]`. When `hi` lies in the implicit
/// zero padding this is `lo - w * lo`, which is the same zero-padding convention `round_poly`
/// uses for unpaired entries. Without it, the folded tables would no longer sum to the round
/// polynomial evaluated at `w`.
///
/// The table is truncated to `half` entries. Tables of length 0 or 1 are left untouched.
fn fold_in_place<F>(a: &mut Vec<F>, w: F)
where
    F: Copy + core::ops::Add<Output = F> + core::ops::Sub<Output = F> + core::ops::Mul<Output = F>,
{
    let n = a.len();
    if n <= 1 {
        return;
    }
    let half = n.next_power_of_two() >> 1;
    // `high` has `n - half <= half` entries, so it always fits within `low`.
    let (low, high) = a.split_at_mut(half);
    let (paired, padded) = low.split_at_mut(high.len());
    for (lo, hi) in paired.iter_mut().zip(high.iter()) {
        *lo = *lo + w * (*hi - *lo);
    }
    // Partner is an implicit zero: lo + w * (0 - lo) == lo - w * lo.
    for lo in padded.iter_mut() {
        *lo = *lo - w * *lo;
    }
    a.truncate(half);
}
/// Computes `(q(0), leading coefficient)` of the round polynomial `q(X) = Σ_k a_k(X) · b_k(X)`.
///
/// Each `a_k(X) = lo + X·(hi - lo)` pairs entry `k` with entry `k + half`, where
/// `half = a.len().next_power_of_two() / 2`. Entries without a partner treat `hi` as zero, so
/// their contribution to the leading coefficient is `lo_a · lo_b`. Only the first `a.len()`
/// entries of `b` participate. A table with fewer than two entries has no variable left; it
/// returns its single product (or zero) and a zero leading coefficient.
fn round_poly<F: crate::field::SumcheckField>(a: &[F], b: &[F]) -> (F, F) {
    let len = a.len();
    if len == 0 {
        return (F::ZERO, F::ZERO);
    }
    let b = &b[..len];
    if len == 1 {
        return (a[0] * b[0], F::ZERO);
    }
    let half = len.next_power_of_two() >> 1;
    let paired = len - half;
    let (a_lo, a_hi) = a.split_at(half);
    let (b_lo, b_hi) = b.split_at(half);
    let mut at_zero = F::ZERO;
    let mut leading = F::ZERO;
    for idx in 0..half {
        let (al, bl) = (a_lo[idx], b_lo[idx]);
        let prod = al * bl;
        at_zero += prod;
        leading += if idx < paired {
            (a_hi[idx] - al) * (b_hi[idx] - bl)
        } else {
            // Implicit zero partner: (0 - al) * (0 - bl) == al * bl.
            prod
        };
    }
    (at_zero, leading)
}
/// Number of coefficients in each per-round sumcheck mask polynomial.
///
/// `prove_sumcheck` reads each round's mask as `mask[0] + mask[1]·X + mask[2]·X²`, i.e. a
/// polynomial of degree at most 2, the same degree as the sumcheck round polynomial.
pub(crate) const HVZK_MASK_LENGTH: usize = 3;
/// Runs the prover side of the sumcheck that folds an interleaved code commitment.
///
/// Proves the claim `Σ_i msg[i] · constraint[i]` over `log2(interleaving)` rounds. In each round
/// the message vector `a` and constraint coefficients `b` are folded at the previous challenge.
/// The degree-2 round polynomial `q(X) = q0 + q1·X + q_inf·X²` of `Σ a(X)·b(X)` is then computed
/// with `round_poly`.
///
/// * Without masks (`mask_coeffs` empty) the prover sends `q0` and `q_inf` each round; the
///   verifier recovers `q1` from its running claim.
/// * With masks, `mask_coeffs` holds `HVZK_MASK_LENGTH` coefficients per round. The prover first
///   sends the hypercube sum of the masking polynomial and draws a random-linear-combination
///   coefficient `mask_rlc`. Each round it then sends `h0` and `h_inf` of
///   `h(X) = mask_part(X) + mask_rlc · q(X)`.
///
/// Returns:
/// * the folded prover state, with the message folded at all challenges;
/// * the folded constraint, not scaled by `mask_rlc`;
/// * the challenges in round order;
/// * `mask_rlc`, which is `ONE` when no masks are used.
///
/// # Panics
/// * If the interleaving factor is zero or not a power of two.
/// * If `msg` and the constraint have different lengths.
/// * If a non-empty `mask_coeffs` does not have exactly `num_rounds * HVZK_MASK_LENGTH` entries.
#[allow(clippy::type_complexity)]
pub(crate) fn prove_sumcheck<EC, VC>(
    transcript: &mut ProverState,
    input: CodeCommitmentProverState<InterleavedCode<EC>, VC>,
    constraint: LinearForm<EC::InputAlphabet>,
    mask_coeffs: &[EC::Alphabet],
) -> (
    super::commitment::FoldedCodeCommitmentProverState<EC, VC>,
    LinearForm<EC::InputAlphabet>,
    Vec<EC::Alphabet>,
    EC::Alphabet,
)
where
    EC: LinearCode,
    EC::Alphabet: crate::field::SumcheckField,
    VC: VectorCommitment<Alphabet = Vec<EC::OutputAlphabet>>,
{
    let interleaving = input.code.interleaving_factor();
    assert!(interleaving > 0 && interleaving.is_power_of_two());
    let num_rounds = interleaving.ilog2() as usize;
    assert_eq!(input.msg.len(), constraint.coefficients().len());
    let masks_present = !mask_coeffs.is_empty();
    if masks_present {
        assert_eq!(mask_coeffs.len(), num_rounds * HVZK_MASK_LENGTH);
    }
    // `msg` is cloned because `input` itself is moved into the returned folded state.
    let mut a = input.msg.clone();
    let mut b = constraint.into_coefficients();
    let mut mask_sum = EC::Alphabet::ZERO;
    let mut mask_rlc = EC::Alphabet::ONE;
    // In masked mode the prover sends only h0 and h_inf, so it must track the current
    // unmasked claim q(0) + q(1) itself in order to recover q1.
    let mut running_sum = if masks_present {
        let s: EC::Alphabet = a.iter().zip(b.iter()).map(|(x, y)| *x * *y).sum();
        Some(s)
    } else {
        None
    };
    if masks_present {
        // Each round mask s_j depends on a single variable, so its sum over the k-dimensional
        // hypercube is 2^(k-1) · (s_j(0) + s_j(1)).
        let sum_multiple_initial = pow2::<EC::Alphabet>(num_rounds.saturating_sub(1));
        let total_eval01: EC::Alphabet = mask_coeffs
            .chunks_exact(HVZK_MASK_LENGTH)
            .map(eval_01)
            .fold(EC::Alphabet::ZERO, |acc, x| acc + x);
        mask_sum = total_eval01 * sum_multiple_initial;
        transcript.prover_message(&mask_sum);
        mask_rlc = transcript.verifier_message();
    }
    let mut prev_challenge: Option<EC::Alphabet> = None;
    let mut all_challenges: Vec<EC::Alphabet> = Vec::with_capacity(num_rounds);
    // 1/2, used to split the constant mask contribution evenly between X = 0 and X = 1.
    let half_inv: EC::Alphabet = {
        (EC::Alphabet::ONE + EC::Alphabet::ONE)
            .inverse()
            .expect("char ≠ 2")
    };
    for round_idx in 0..num_rounds {
        // Folding is deferred by one round; the final fold happens after the loop.
        if let Some(w) = prev_challenge {
            fold_in_place(&mut a, w);
            fold_in_place(&mut b, w);
        }
        let (q0, q_inf) = round_poly(&a, &b);
        if !masks_present {
            transcript.prover_message(&q0);
            transcript.prover_message(&q_inf);
            let r: EC::Alphabet = transcript.verifier_message();
            all_challenges.push(r);
            prev_challenge = Some(r);
            continue;
        }
        let mask = &mask_coeffs[round_idx * HVZK_MASK_LENGTH..(round_idx + 1) * HVZK_MASK_LENGTH];
        // Number of hypercube points over the variables still free after this round.
        let sum_multiple = pow2::<EC::Alphabet>(num_rounds.saturating_sub(round_idx + 1));
        let current_sum = running_sum.expect("running sum tracked");
        // q(0) + q(1) = current_sum  ⇒  q1 = current_sum - 2·q0 - q_inf.
        let q1 = current_sum - q0.double() - q_inf;
        // Constant term contributed by all other round masks, chosen so that
        // h_mask(0) + h_mask(1) equals the current mask claim `mask_sum`.
        let constant_adj = (mask_sum - sum_multiple * eval_01(mask)) * half_inv;
        let h0 = sum_multiple * mask[0] + constant_adj + mask_rlc * q0;
        let h1 = sum_multiple * mask[1] + mask_rlc * q1;
        let h_inf = sum_multiple * mask[2] + mask_rlc * q_inf;
        transcript.prover_message(&h0);
        transcript.prover_message(&h_inf);
        let r: EC::Alphabet = transcript.verifier_message();
        all_challenges.push(r);
        // Advance both claims: the unmasked one becomes q(r), the mask one becomes h(r) - rlc·q(r).
        let new_sum = q0 + q1 * r + q_inf * r * r;
        let new_univariate_at_r = h0 + h1 * r + h_inf * r * r;
        mask_sum = new_univariate_at_r - mask_rlc * new_sum;
        running_sum = Some(new_sum);
        prev_challenge = Some(r);
    }
    // Apply the fold for the last challenge.
    if let Some(w) = prev_challenge {
        fold_in_place(&mut a, w);
        fold_in_place(&mut b, w);
    }
    (
        super::commitment::FoldedCodeCommitmentProverState {
            inner: input,
            msg: a,
        },
        LinearForm::new(b),
        all_challenges,
        mask_rlc,
    )
}
/// Zero-knowledge wrapper around `prove_sumcheck` for Reed–Solomon over `Goldilocks4`.
///
/// For every sumcheck round `j`:
/// * derives a mask polynomial of `Params::L_ZK` coefficients from `mask_seed` with salt
///   `salt || "::poly::" || le64(j)`;
/// * embeds it into a zero-padded message of `M_ZK - T_ZK` entries;
/// * encodes it with `T_ZK` randomness elements derived with salt `salt || "::r::" || le64(j)`;
/// * commits to the codeword with a Merkle tree and sends the root.
///
/// It then runs the masked sumcheck. Finally it sends each mask polynomial's evaluation at the
/// corresponding challenge.
///
/// Returns the folded state, the folded constraint, the mask oracle handles, the challenges,
/// the mask random-linear-combination coefficient, and the mask evaluations.
#[allow(clippy::type_complexity)]
pub(crate) fn prove_sumcheck_zk(
    transcript: &mut ProverState,
    input: CodeCommitmentProverState<
        InterleavedCode<super::code::ReedSolomon<Goldilocks4>>,
        MerkleVc,
    >,
    constraint: LinearForm<Goldilocks4>,
    mask_seed: &[u8; 32],
    salt: &[u8],
) -> (
    super::commitment::FoldedCodeCommitmentProverState<
        super::code::ReedSolomon<Goldilocks4>,
        MerkleVc,
    >,
    LinearForm<Goldilocks4>,
    Vec<MaskOracleHandle>,
    Vec<Goldilocks4>,
    Goldilocks4,
    Vec<Goldilocks4>,
) {
    /// Builds the domain-separated salt `salt || tag || le64(index)`.
    fn tagged_salt(salt: &[u8], tag: &[u8], index: usize) -> Vec<u8> {
        let mut out = Vec::with_capacity(salt.len() + tag.len() + 8);
        out.extend_from_slice(salt);
        out.extend_from_slice(tag);
        out.extend_from_slice(&(index as u64).to_le_bytes());
        out
    }

    let rounds = input.code.interleaving_factor().ilog2() as usize;
    let poly_len = crate::params::Params::L_ZK;
    let t_zk = crate::params::Params::T_ZK;
    let msg_len = crate::params::Params::M_ZK - t_zk;
    let zk_enc = super::encoding::ZkEncoding::new(msg_len, t_zk);

    let mut all_mask_coeffs: Vec<Goldilocks4> = Vec::with_capacity(rounds * HVZK_MASK_LENGTH);
    let mut mask_handles: Vec<MaskOracleHandle> = Vec::with_capacity(rounds);
    let mut mask_polys: Vec<Vec<Goldilocks4>> = Vec::with_capacity(rounds);
    for round in 0..rounds {
        let poly = super::base_case::derive_field_vec(
            mask_seed,
            &tagged_salt(salt, b"::poly::", round),
            poly_len,
        );
        all_mask_coeffs.extend_from_slice(&poly);

        // Mask coefficients occupy the prefix of an otherwise zero message.
        let mut msg = vec![Goldilocks4::ZERO; msg_len];
        msg[..poly.len()].copy_from_slice(&poly);

        let randomness =
            super::base_case::derive_field_vec(mask_seed, &tagged_salt(salt, b"::r::", round), t_zk);
        let slab = super::code::CodewordSlab::new(zk_enc.encode_with(&msg, &randomness), 1);
        let vc = MerkleVc::new(slab.positions());
        let (root, vc_state) = vc.commit_slab(slab);
        transcript.prover_message(&root);

        mask_polys.push(poly);
        mask_handles.push(MaskOracleHandle::new_prover(msg, randomness, vc, vc_state));
    }

    let (folded_state, folded_constraint, challenges, epsilon) =
        prove_sumcheck(transcript, input, constraint, &all_mask_coeffs);

    let mut mask_targets: Vec<Goldilocks4> = Vec::with_capacity(rounds);
    for (poly, &gamma) in mask_polys.iter().zip(&challenges) {
        // Horner evaluation: coefficients are stored in ascending degree order.
        let value = poly
            .iter()
            .rev()
            .fold(Goldilocks4::ZERO, |acc, &coeff| acc * gamma + coeff);
        transcript.prover_message(&value);
        mask_targets.push(value);
    }

    (
        folded_state,
        folded_constraint,
        mask_handles,
        challenges,
        epsilon,
        mask_targets,
    )
}
#[allow(clippy::type_complexity)]
pub(crate) fn verify_sumcheck_zk(
transcript: &mut VerifierState,
commitment: ExplicitCodeCommitmentHandle<
InterleavedCode<super::code::ReedSolomon<Goldilocks4>>,
MerkleVc,
>,
constraint: LinearConstraint<FoldedFormHandle<Goldilocks4>>,
) -> VerificationResult<(
FoldedCodeCommitmentHandle<super::code::ReedSolomon<Goldilocks4>, MerkleVc>,
LinearConstraint<FoldedFormHandle<Goldilocks4>>,
Vec<MaskOracleHandle>,
Vec<Goldilocks4>,
Goldilocks4,
Vec<Goldilocks4>,
)> {
let k = commitment.code.interleaving_factor().ilog2() as usize;
let mut mask_handles: Vec<MaskOracleHandle> = Vec::with_capacity(k);
for _ in 0..k {
let root: [u8; 32] = transcript.prover_message()?;
mask_handles.push(MaskOracleHandle::verifier_root_only(root));
}
let (folded_commitment, folded_constraint, challenges, epsilon) =
verify_sumcheck(transcript, commitment, constraint, true)?;
let mut mask_targets: Vec<Goldilocks4> = Vec::with_capacity(k);
for _ in 0..k {
mask_targets.push(transcript.prover_message()?);
}
Ok((
folded_commitment,
folded_constraint,
mask_handles,
challenges,
epsilon,
mask_targets,
))
}
/// Returns `p(0) + p(1)` for `p(X) = coeffs[0] + coeffs[1]·X + coeffs[2]·X² + …`.
///
/// `p(0)` is the constant coefficient and `p(1)` is the sum of all coefficients. An empty
/// coefficient list denotes the zero polynomial.
fn eval_01<F: crate::field::SumcheckField>(coeffs: &[F]) -> F {
    match coeffs.first() {
        None => F::ZERO,
        Some(&constant) => {
            let at_one = coeffs.iter().fold(F::ZERO, |acc, &c| acc + c);
            constant + at_one
        }
    }
}
/// Returns `2^n` as a field element, computed by doubling `ONE` `n` times.
fn pow2<F: crate::field::SumcheckField>(n: usize) -> F {
    (0..n).fold(F::ONE, |value, _| value.double())
}
/// Replays the verifier side of the (optionally masked) sumcheck over `num_rounds` rounds.
///
/// With `hvzk`, it first reads the prover's mask sum and draws `mask_rlc`, and the running claim
/// becomes `mask_sum + mask_rlc · claimed_sum`. Each round reads `h0` and `h_inf`, recovers
/// `h1 = claim - 2·h0 - h_inf` from `h(0) + h(1) = claim`, draws a challenge `r`, and sets the
/// claim to `h(r)`.
///
/// Returns the final claim, the challenges in round order, and `mask_rlc` (`ONE` without `hvzk`).
fn inline_sumcheck_verify<F: WhirField>(
    transcript: &mut VerifierState,
    claimed_sum: F,
    num_rounds: usize,
    hvzk: bool,
) -> VerificationResult<(F, Vec<F>, F)> {
    let (mut claim, mask_rlc) = if hvzk {
        let mask_sum: F = transcript.prover_message()?;
        let rlc: F = transcript.verifier_message();
        (mask_sum + rlc * claimed_sum, rlc)
    } else {
        (claimed_sum, F::ONE)
    };

    let mut challenges = Vec::with_capacity(num_rounds);
    for _ in 0..num_rounds {
        let h0: F = transcript.prover_message()?;
        let h_inf: F = transcript.prover_message()?;
        let r: F = transcript.verifier_message();
        let h1 = claim - h0.double() - h_inf;
        claim = h0 + h1 * r + h_inf * r * r;
        challenges.push(r);
    }

    Ok((claim, challenges, mask_rlc))
}
/// Verifies the sumcheck that folds an interleaved code commitment.
///
/// Checks that the interleaving factor is a power of two and that the constraint's form size
/// matches the committed message length. It then runs `log2(interleaving)` sumcheck rounds
/// (masked when `hvzk` is set).
///
/// Returns:
/// * the folded commitment handle;
/// * the folded constraint, whose form is scaled by the mask coefficient and evaluated at the
///   challenges, with value equal to the final claim;
/// * the challenges;
/// * the mask coefficient.
///
/// # Errors
/// Returns `VerificationError` on a shape mismatch or if any transcript read fails.
#[allow(clippy::type_complexity)]
pub(crate) fn verify_sumcheck<EC, VC, LFH>(
    transcript: &mut VerifierState,
    commitment: ExplicitCodeCommitmentHandle<InterleavedCode<EC>, VC>,
    constraint: LinearConstraint<LFH>,
    hvzk: bool,
) -> VerificationResult<(
    FoldedCodeCommitmentHandle<EC, VC>,
    LinearConstraint<FoldedFormHandle<EC::Alphabet>>,
    Vec<EC::Alphabet>,
    EC::Alphabet,
)>
where
    EC: LinearCode,
    VC: VectorCommitment<Alphabet = Vec<EC::Alphabet>>,
    LFH: LinearFormHandle<Alphabet = EC::Alphabet> + 'static,
{
    use super::commitment::CodeCommitmentHandle;

    let interleaving = commitment.code().interleaving_factor();
    // `0usize.is_power_of_two()` is false, so this also rejects an interleaving factor of 0.
    if !interleaving.is_power_of_two()
        || constraint.linear_form_handle.form_size() != commitment.msg_len()
    {
        return Err(VerificationError);
    }
    let num_rounds = interleaving.ilog2() as usize;

    let (final_claim, challenges, mask_rlc) =
        inline_sumcheck_verify(transcript, constraint.value, num_rounds, hvzk)?;

    let folded_form = FoldedFormHandle {
        linear_form_handle: Box::new(constraint.linear_form_handle),
        rand: challenges.clone(),
        scale: mask_rlc,
    };
    let folded_commitment = FoldedCodeCommitmentHandle {
        inner: commitment,
        rand: challenges.clone(),
    };
    let folded_constraint = LinearConstraint {
        linear_form_handle: folded_form,
        value: final_claim,
    };

    Ok((folded_commitment, folded_constraint, challenges, mask_rlc))
}