// provekit-whir 0.1.1
//
// An implementation of the WHIR polynomial commitment scheme.
use ark_ff::{AdditiveGroup, FftField, Field};
#[cfg(feature = "tracing")]
use tracing::instrument;

use super::{Commitment, Config};
use crate::{
    algebra::{
        dot,
        embedding::{Embedding, Identity},
        linear_form::{Evaluate, LinearForm, MultilinearExtension},
        tensor_product, MultilinearPoint,
    },
    hash::Hash,
    protocols::{geometric_challenge::geometric_challenge, whir::FinalClaim},
    transcript::{
        codecs::U64, Codec, Decoding, DuplexSpongeInterface, ProverMessage, VerificationResult,
        VerifierMessage, VerifierState,
    },
    utils::zip_strict,
    verify,
};

impl<M: Embedding> Config<M>
where
    M::Source: FftField,
    M::Target: FftField,
{
    /// Verify a batched WHIR proof for one or more commitments.
    ///
    /// `evaluations` is the caller-supplied, row-major matrix of claimed linear-form
    /// evaluations: one row per linear form and one column per committed vector. Its length
    /// must therefore be a multiple of the total number of committed vectors
    /// (`commitments.len() * self.initial_committer.num_vectors`).
    ///
    /// The verifier completes this matrix with the out-of-domain (OOD) rows of every
    /// commitment. A commitment's own OOD answers come from the commitment itself; the
    /// cross-terms (its OOD points evaluated on the other commitments' vectors) are read from
    /// the transcript. Geometric challenges then combine the vectors and the constraints into
    /// a single claimed sum. Round 0 opens the original commitment trees, lifting the
    /// in-domain values through the embedding. Later rounds work on the single batched vector.
    ///
    /// Returns a [`FinalClaim`] with three parts: the full evaluation point, the RLC
    /// coefficients applied to the caller's linear forms, and the value their combination must
    /// take at that point. That deferred claim is returned rather than checked here.
    ///
    /// # Errors
    ///
    /// Returns a verification error in any of these cases:
    /// - the evaluation matrix has the wrong shape;
    /// - a transcript read, proof-of-work, sumcheck or opening check fails;
    /// - the final polynomial evaluates to zero at the sumcheck point, which leaves the
    ///   deferred claim undetermined.
    #[allow(clippy::too_many_lines)]
    #[cfg_attr(feature = "tracing", instrument(skip_all, name = "whir::verify"))]
    pub fn verify<H>(
        &self,
        verifier_state: &mut VerifierState<'_, H>,
        commitments: &[&Commitment<M::Target>],
        evaluations: &[M::Target],
    ) -> VerificationResult<FinalClaim<M::Target>>
    where
        H: DuplexSpongeInterface,
        M::Target: Codec<[H::U]>,
        u8: Decoding<[H::U]>,
        [u8; 32]: Decoding<[H::U]>,
        U64: Codec<[H::U]>,
        Hash: ProverMessage<[H::U]>,
    {
        let num_vectors = commitments.len() * self.initial_committer.num_vectors;
        // Handle the empty batch before dividing by `num_vectors` (0 / 0 would panic).
        // With zero vectors the only well-formed evaluation matrix is the empty one.
        if num_vectors == 0 {
            verify!(evaluations.is_empty());
            return Ok(FinalClaim::default());
        }
        verify!(evaluations.len().is_multiple_of(num_vectors));
        let num_linear_forms = evaluations.len() / num_vectors;

        // Complete the constraint and evaluation matrix with OODs and their cross-terms.
        let (oods_evals, oods_matrix) = {
            let mut oods_evals = Vec::new();
            let mut oods_matrix = Vec::new();

            // OOD weights from each commitment, evaluated for each vector
            let mut vector_offset = 0;
            for commitment in commitments {
                for (weights, oods_row) in zip_strict(
                    commitment.out_of_domain().evaluators(self.initial_size()),
                    commitment.out_of_domain().rows(),
                ) {
                    for j in 0..num_vectors {
                        if j >= vector_offset && j < oods_row.len() + vector_offset {
                            // Answer already bound by this commitment.
                            oods_matrix.push(oods_row[j - vector_offset]);
                        } else {
                            // Cross-term against another commitment's vector: prover-supplied.
                            oods_matrix.push(verifier_state.prover_message()?);
                        }
                    }
                    oods_evals.push(weights);
                }
                vector_offset += commitment.num_vectors();
            }
            (oods_evals, oods_matrix)
        };

        // Random linear combination of the vectors.
        let vector_rlc_coeffs = geometric_challenge(verifier_state, num_vectors);

        // Random linear combination of the constraints.
        let constraint_rlc_coeffs: Vec<M::Target> =
            geometric_challenge(verifier_state, oods_evals.len() + num_linear_forms);
        let (initial_form_rlc_coeffs, oods_rlc_coeffs) =
            constraint_rlc_coeffs.split_at(num_linear_forms);

        // Compute "The Sum"
        let mut the_sum = zip_strict(
            initial_form_rlc_coeffs,
            evaluations.chunks_exact(num_vectors),
        )
        .map(|(poly_coeff, row)| *poly_coeff * dot(&vector_rlc_coeffs, row))
        .sum::<M::Target>();
        the_sum += zip_strict(oods_rlc_coeffs, oods_matrix.chunks_exact(num_vectors))
            .map(|(poly_coeff, row)| *poly_coeff * dot(&vector_rlc_coeffs, row))
            .sum::<M::Target>();
        let mut round_constraints = vec![(oods_rlc_coeffs.to_vec(), oods_evals)];

        let mut round_folding_randomness = Vec::new();

        // Run initial sumcheck on batched vector with combined statement
        let folding_randomness = if constraint_rlc_coeffs.is_empty() {
            // There are no constraints yet, so we can skip the sumcheck.
            // (If we did run it, all sumcheck polynomials would be constant zero)
            assert_eq!(the_sum, M::Target::ZERO);
            let folding_randomness =
                verifier_state.verifier_message_vec(self.initial_sumcheck.num_rounds);
            self.initial_skip_pow.verify(verifier_state)?;
            MultilinearPoint(folding_randomness)
        } else {
            self.initial_sumcheck.verify(verifier_state, &mut the_sum)?
        };
        round_folding_randomness.push(folding_randomness);

        let (final_vector, final_sumcheck_randomness) = if self.round_configs.is_empty() {
            // 0-rounds case: open initial commitment and run final sumcheck directly.
            let final_vector =
                verifier_state.prover_messages_vec(self.final_sumcheck.initial_size)?;
            self.final_pow.verify(verifier_state)?;

            let in_domain = self.initial_committer.verify(verifier_state, commitments)?;
            let in_domain = in_domain.lift(self.embedding());

            // Check every opened in-domain value directly against the prover's final vector.
            for (weights, evals) in zip_strict(
                in_domain.evaluators(final_vector.len()),
                in_domain.values(&tensor_product(
                    &vector_rlc_coeffs,
                    &round_folding_randomness.last().unwrap().eq_weights(),
                )),
            ) {
                verify!(weights.evaluate(&Identity::<M::Target>::new(), &final_vector) == evals);
            }

            let final_sumcheck_randomness =
                self.final_sumcheck.verify(verifier_state, &mut the_sum)?;
            round_folding_randomness.push(final_sumcheck_randomness.clone());
            (final_vector, final_sumcheck_randomness)
        } else {
            // Round 0: open initial commitments with embedding lift and tensor_product.
            let round0_config = &self.round_configs[0];
            let commitment_h = round0_config
                .irs_committer
                .receive_commitment(verifier_state)?;
            round0_config.pow.verify(verifier_state)?;

            let in_domain = self.initial_committer.verify(verifier_state, commitments)?;
            // TODO: Skip lift and keep initial in-domain in subfield for evaluation.
            let in_domain = in_domain.lift(self.embedding());

            let constraint_weights = commitment_h
                .out_of_domain()
                .evaluators(round0_config.initial_size())
                .chain(in_domain.evaluators(round0_config.initial_size()))
                .collect::<Vec<_>>();
            let constraint_values = commitment_h
                .out_of_domain()
                .values(&[M::Target::ONE])
                .chain(in_domain.values(&tensor_product(
                    &vector_rlc_coeffs,
                    &round_folding_randomness.last().unwrap().eq_weights(),
                )))
                .collect::<Vec<_>>();
            let constraint_rlc_coeffs =
                geometric_challenge(verifier_state, constraint_values.len());
            the_sum += dot(&constraint_rlc_coeffs, &constraint_values);
            round_constraints.push((constraint_rlc_coeffs, constraint_weights));

            let folding_randomness = round0_config
                .sumcheck
                .verify(verifier_state, &mut the_sum)?;
            round_folding_randomness.push(folding_randomness.clone());

            // Rounds 1..N + final round.
            let remaining = super::rounds::verify_remaining_rounds(
                &self.round_configs,
                &super::rounds::FinalRoundConfig {
                    sumcheck: &self.final_sumcheck,
                    pow: &self.final_pow,
                },
                verifier_state,
                &mut the_sum,
                &commitment_h,
                &folding_randomness,
            )?;

            round_constraints.extend(remaining.round_constraints);
            round_folding_randomness.extend(remaining.round_folding_randomness);
            round_folding_randomness.push(remaining.final_sumcheck_randomness.clone());
            (remaining.final_vector, remaining.final_sumcheck_randomness)
        };

        // Compute folding randomness across all rounds
        let evaluation_point = round_folding_randomness
            .into_iter()
            .flat_map(|poly| poly.0.into_iter())
            .collect::<Vec<_>>();

        // Compute the claimed rlc of the linear form mles from the sumcheck invariant.
        let poly_eval = MultilinearExtension::new(final_sumcheck_randomness.0)
            .evaluate(&Identity::new(), &final_vector);
        // `final_vector` is prover-supplied, so `poly_eval` can be zero. Field division by
        // zero panics, so reject the proof instead of crashing the verifier.
        verify!(poly_eval != M::Target::ZERO);
        let mut linear_form_rlc = the_sum / poly_eval;

        // Subtract all internal linear forms.
        for (round, (weights_rlc_coeffs, weights)) in round_constraints.into_iter().enumerate() {
            // Round 0 constraints live over the initial variables; round `p + 1` constraints
            // over the variables remaining at the start of `round_configs[p]`.
            let num_variables = round.checked_sub(1).map_or_else(
                || self.initial_num_variables(),
                |p| self.round_configs[p].initial_num_variables(),
            );
            let start = evaluation_point.len().saturating_sub(num_variables);
            for (rlc_coeff, weights) in zip_strict(weights_rlc_coeffs, weights) {
                linear_form_rlc -= rlc_coeff * weights.mle_evaluate(&evaluation_point[start..]);
            }
        }

        // Return the evaluation point and the claimed values of the deferred weights.
        Ok(FinalClaim {
            evaluation_point,
            rlc_coefficients: initial_form_rlc_coeffs.to_vec(),
            linear_form_rlc,
        })
    }
}