use alloc::vec::Vec;
use p3_commit::Mmcs;
use p3_field::{ExtensionField, Field, HornerIter};
use p3_matrix::dense::RowMajorMatrix;
use p3_multilinear_util::point::Point;
use serde::{Deserialize, Serialize};
use crate::strategy::SumcheckProver;
/// Serializable data for the zero-knowledge sumcheck phase.
///
/// NOTE(review): the per-field descriptions below are inferred from names and
/// types only — confirm against the code that fills and consumes this struct.
/// Field order is part of the serde encoding; do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZkSumcheckData<F, EF> {
/// Extension-field scalar; presumably the masked claimed sum ("mu tilde") — TODO confirm.
pub mu_tilde: EF,
/// Presumably the number of zero-knowledge sumcheck rounds — TODO confirm.
pub ell_zk: usize,
/// Coefficient vectors, presumably one round polynomial per round — TODO confirm.
pub round_coefficients: Vec<Vec<EF>>,
/// Base-field values; presumably proof-of-work (grinding) witnesses — TODO confirm.
pub pow_witnesses: Vec<F>,
}
impl<F, EF> Default for ZkSumcheckData<F, EF>
where
    EF: Field,
{
    /// Builds an empty record: zero `mu_tilde`, zero `ell_zk`, and no round
    /// coefficients or proof-of-work witnesses.
    ///
    /// Only `EF: Field` is required (for `EF::ZERO`). `F` appears solely inside
    /// an empty `Vec`, so it needs no bound.
    fn default() -> Self {
        let round_coefficients = Vec::new();
        let pow_witnesses = Vec::new();
        Self {
            mu_tilde: EF::ZERO,
            ell_zk: 0,
            round_coefficients,
            pow_witnesses,
        }
    }
}
/// A committed mask matrix: the `Mmcs` commitment paired with the prover-side
/// data `M` keeps for the committed `RowMajorMatrix<EF>`.
///
/// NOTE(review): the prover data is presumably retained so openings can be
/// produced later — confirm with the users of this alias.
pub type MaskOracle<EF, M> = (
<M as Mmcs<EF>>::Commitment,
<M as Mmcs<EF>>::ProverData<RowMajorMatrix<EF>>,
);
/// Prover-side state produced by the zero-knowledge sumcheck phase.
///
/// The bounds live on the struct (not just on impls) because the
/// `MaskOracle<EF, M>` field projects associated types of `M: Mmcs<EF>`.
///
/// NOTE(review): field meanings below are inferred from names and types —
/// TODO confirm against the producer of this struct.
pub struct ZkSumcheckHandoff<F, EF, M>
where
F: Field,
EF: ExtensionField<F>,
M: Mmcs<EF>,
{
/// Sumcheck prover to continue with; presumably for the residual claim after the ZK rounds.
pub residual_prover: SumcheckProver<F, EF>,
/// Point of extension-field values; presumably the verifier challenges sampled so far.
pub randomness: Point<EF>,
/// Extension-field scalar; presumably a mask-combination challenge — TODO confirm.
pub eps: EF,
/// Nested coefficient vectors; presumably the per-round mask messages sent to the verifier.
pub mask_messages: Vec<Vec<EF>>,
/// Nested extension-field values; presumably the randomness used for the mask polynomials.
pub mask_randomness: Vec<Vec<EF>>,
/// Commitment and prover data for the committed mask matrix.
pub mask_oracle: MaskOracle<EF, M>,
}
/// Verifier-side state produced by the zero-knowledge sumcheck phase.
///
/// NOTE(review): field meanings are inferred from names — TODO confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZkVerifierHandoff<EF> {
/// Point of extension-field values; presumably the challenges sampled during the ZK rounds.
pub randomness: Point<EF>,
/// Claimed value left to check; presumably the residual sumcheck claim.
pub claimed_residual: EF,
/// Extension-field scalar; presumably the same role as `ZkSumcheckHandoff::eps`.
pub eps: EF,
}
/// Evaluates the combined mask residual `sum_j s_j(gamma_j)`.
///
/// Each `masks[j]` is the coefficient vector of a univariate polynomial
/// `s_j`, constant term first. It is evaluated at `gammas[j]` with Horner's
/// rule, and the per-mask evaluations are added together. Masks may have
/// different lengths.
///
/// # Panics
///
/// Panics if `masks.len() != gammas.len()`.
#[must_use]
pub fn mask_residual<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> EF
where
    EF: Field,
{
    assert_eq!(masks.len(), gammas.len());
    let mut acc = EF::ZERO;
    for (coeffs, &point) in masks.iter().zip(gammas) {
        acc += coeffs.iter().copied().horner(point);
    }
    acc
}
/// Returns, for each mask, the linear functional that evaluates it at its gamma.
///
/// The `j`-th covector is `[1, gamma_j, gamma_j^2, ..., gamma_j^(len-1)]`, where
/// `len` is the common mask length. Summing `<masks[j], covectors[j]>` over `j`
/// therefore reproduces [`mask_residual`]. Only the shape of `masks` is
/// used, never their values.
///
/// # Panics
///
/// Panics if the masks do not all have the same length, or if
/// `masks.len() != gammas.len()`.
#[must_use]
pub fn mask_residual_covectors<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> Vec<Vec<EF>>
where
    EF: Field,
{
    // Common mask length, computed once (an empty `masks` slice gives 0).
    let mask_len = masks.first().map_or(0, Vec::len);
    assert!(
        masks.iter().all(|mask| mask.len() == mask_len),
        "all masks must have the same length (expected {mask_len})"
    );
    mask_residual_covectors_from_shape(masks.len(), mask_len, gammas)
}
/// Builds the mask-evaluation covectors from the mask shape alone.
///
/// Returns `mask_count` vectors. The `j`-th holds the first `mask_len` powers
/// of `gammas[j]`, starting from `gammas[j]^0`, so taking its inner product with
/// a length-`mask_len` coefficient vector (constant term first) evaluates that
/// polynomial at `gammas[j]`.
///
/// # Panics
///
/// Panics if `mask_count != gammas.len()`.
#[must_use]
pub fn mask_residual_covectors_from_shape<EF: Field>(
    mask_count: usize,
    mask_len: usize,
    gammas: &[EF],
) -> Vec<Vec<EF>> {
    assert_eq!(mask_count, gammas.len());
    let mut covectors = Vec::with_capacity(mask_count);
    for gamma in gammas {
        covectors.push(gamma.powers().collect_n(mask_len));
    }
    covectors
}
#[cfg(test)]
mod tests {
// Tests for the mask-residual closed form and its covector decomposition.
use alloc::vec;
use alloc::vec::Vec;
use p3_baby_bear::BabyBear;
use p3_field::extension::BinomialExtensionField;
use p3_field::{Field, PrimeCharacteristicRing, dot_product};
use super::{mask_residual, mask_residual_covectors};
type F = BabyBear;
type EF = BinomialExtensionField<F, 4>;
// Reference implementation that replays the sumcheck rounds for the additive
// mask g(X_1, ..., X_k) = sum_j s_j(X_j), summed over the Boolean hypercube.
// Round j builds the round polynomial h_j(X) with X_1..X_{j-1} bound to the
// earlier gammas and X_{j+1}..X_k still summed over {0, 1}, then evaluates it
// at gamma_j. It returns the last round's value h_k(gamma_k), which should
// equal the closed form sum_j s_j(gamma_j) computed by `mask_residual`.
// Coefficients are constant-term first. `mask[1..]` panics on an empty mask;
// the fixtures below never contain one.
fn reference_mask_recurrence<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> EF
where
EF: Field,
{
assert_eq!(masks.len(), gammas.len());
let k = masks.len();
if k == 0 {
return EF::ZERO;
}
// pow2[i] = 2^i: the number of Boolean assignments to i free variables.
let pow2: Vec<EF> = EF::TWO.powers().collect_n(k + 1);
// s_i(gamma_i) for every round already processed.
let mut mask_evals_at_gamma = Vec::with_capacity(k);
// s(0) + s(1) = 2*c_0 + sum_{i>=1} c_i, summed over all masks. Each round
// subtracts its own mask, leaving only the masks of later rounds.
let mut sum_future_endpoints: EF = masks
.iter()
.map(|mask| mask[0].double() + mask[1..].iter().copied().sum::<EF>())
.sum();
let mut target = EF::ZERO;
for (round_idx, (s_j, &gamma_j)) in masks.iter().zip(gammas).enumerate() {
// 1-indexed round number.
let j = round_idx + 1;
let s_j_endpoints = s_j[0].double() + s_j[1..].iter().copied().sum::<EF>();
sum_future_endpoints -= s_j_endpoints;
// NOTE(review): padded to at least 3 coefficients, presumably to mirror a
// minimum round-polynomial length; zero padding does not change h(gamma_j).
let h_size = s_j.len().max(3);
let mut h = EF::zero_vec(h_size);
// The k - j variables after round j are still free, so every term that
// does not involve them is counted 2^(k-j) times.
let mult_live = pow2[k - j];
for (i, &c) in s_j.iter().enumerate() {
h[i] += mult_live * c;
}
// Masks of earlier rounds are constants s_i(gamma_i) once their variable is bound.
let past_mask_sum: EF = mask_evals_at_gamma.iter().copied().sum();
h[0] += past_mask_sum * mult_live;
if j < k {
// Each later mask contributes (s_i(0) + s_i(1)) times 2^(k-j-1) for the
// other free variables.
h[0] += pow2[k - j - 1] * sum_future_endpoints;
}
// Horner evaluation h_j(gamma_j); only the final round's value is returned.
target = h
.iter()
.rev()
.copied()
.fold(EF::ZERO, |acc, coeff| acc * gamma_j + coeff);
let s_j_at_gamma = s_j
.iter()
.rev()
.copied()
.fold(EF::ZERO, |acc, coeff| acc * gamma_j + coeff);
mask_evals_at_gamma.push(s_j_at_gamma);
}
target
}
// The Horner closed form must agree with the round-by-round recurrence
// (k = 3 masks of 4 coefficients each).
#[test]
fn mask_residual_closed_form_matches_round_recurrence() {
let masks = vec![
vec![
EF::from_u64(3),
EF::from_u64(5),
EF::from_u64(7),
EF::from_u64(11),
],
vec![
EF::from_u64(13),
EF::from_u64(17),
EF::from_u64(19),
EF::from_u64(23),
],
vec![
EF::from_u64(29),
EF::from_u64(31),
EF::from_u64(37),
EF::from_u64(41),
],
];
let gammas = vec![EF::from_u64(43), EF::from_u64(47), EF::from_u64(53)];
assert_eq!(
mask_residual::<EF>(&masks, &gammas),
reference_mask_recurrence::<EF>(&masks, &gammas),
);
}
// Summing <mask_j, covector_j> over j must reproduce the Horner-based residual.
#[test]
fn mask_residual_covectors_evaluate_closed_form() {
let masks = vec![
vec![EF::from_u64(2), EF::from_u64(3), EF::from_u64(5)],
vec![EF::from_u64(7), EF::from_u64(11), EF::from_u64(13)],
];
let gammas = vec![EF::from_u64(17), EF::from_u64(19)];
let covectors = mask_residual_covectors::<EF>(&masks, &gammas);
let by_covectors = masks
.iter()
.zip(&covectors)
.map(|(mask, covector)| {
dot_product::<EF, _, _>(mask.iter().copied(), covector.iter().copied())
})
.sum::<EF>();
assert_eq!(by_covectors, mask_residual::<EF>(&masks, &gammas));
}
}