Skip to main content

p3_sumcheck/zk/
data.rs

1//! Transcript schema and oracle handle for the HVZK sumcheck.
2
3use alloc::vec::Vec;
4
5use p3_commit::Mmcs;
6use p3_field::{ExtensionField, Field, HornerIter};
7use p3_matrix::dense::RowMajorMatrix;
8use p3_multilinear_util::point::Point;
9use serde::{Deserialize, Serialize};
10
11use crate::strategy::SumcheckProver;
12
/// Per-round prover output of the HVZK sumcheck protocol.
///
/// - Prover writes;
/// - Verifier reads back during Fiat-Shamir replay.
///
/// One instance covers a full run of `k` rounds.
///
/// # Type parameters
///
/// - `F`: element type of the proof-of-work witnesses.
/// - `EF`: element type of the mask sum and of every round coefficient.
///
/// # Wire format
///
/// Per round, the polynomial has coefficient layout
///
/// ```text
///     [ c_0, c_1, c_2, ..., c_d ]    with  d = max(ell_zk - 1, 2)
/// ```
///
/// The linear coefficient `c_1` is dropped on the wire.
///
/// The verifier reconstructs `c_1` from the affine identity
///
/// ```text
///     h_j(0) + h_j(1) = 2 * c_0 + sum_{i >= 1} c_i = target
/// ```
///
/// applied to the previous round's target.
///
/// # Soundness link to Lemma 6.4
///
/// Valid transcripts form an affine subspace of dimension `1 + k * (ell_zk - 1)`.
/// The `k` dropped linear coefficients are exactly the redundant degrees of
/// freedom of the rank-nullity argument.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ZkSumcheckData<F, EF> {
    /// Sum of all mask polynomial evaluations across the boolean hypercube `{0,1}^k`.
    ///
    /// Observed on the transcript before the verifier samples the combining challenge.
    /// Lives in the extension field because the mask coefficients do.
    pub mu_tilde: EF,

    /// Message length of the zero-knowledge mask code.
    ///
    /// The verifier rejects up front if its own expected value disagrees with this.
    /// Pinning this in the transcript closes a non-injectivity gap in the
    /// wire-length check: lengths `2` and `3` share a wire layout, because both
    /// give `d = max(ell_zk - 1, 2) = 2`.
    pub ell_zk: usize,

    /// Per-round wire payload with the linear coefficient dropped.
    ///
    /// One entry per sumcheck round.
    /// Layout per entry: `[c_0, c_2, c_3, ..., c_d]` with `d = max(ell_zk - 1, 2)`.
    /// Each entry therefore holds `d` elements: the `d + 1` coefficients minus
    /// the dropped `c_1`.
    pub round_coefficients: Vec<Vec<EF>>,

    /// Per-round proof-of-work witnesses.
    ///
    /// Length equals the number of rounds when grinding is enabled.
    /// Empty when `pow_bits == 0`.
    pub pow_witnesses: Vec<F>,
}
68
69impl<F, EF: Field> Default for ZkSumcheckData<F, EF> {
70    fn default() -> Self {
71        Self {
72            // Real runs overwrite this in step 2 once the prover has summed the masks.
73            mu_tilde: EF::ZERO,
74            // Sentinel: honest runs set this to the encoding's message length; the verifier rejects 0.
75            ell_zk: 0,
76            // Filled with one wire entry per sumcheck round.
77            round_coefficients: Vec::new(),
78            // Filled only when grinding is enabled.
79            pow_witnesses: Vec::new(),
80        }
81    }
82}
83
/// Handle to one committed batch of interleaved mask codewords.
///
/// - Pairs the public Merkle root with the prover-side data needed to open
///   the batch at requested positions.
/// - Row `z` of the committed matrix holds position `z` of every mask in
///   the batch.
/// - One Merkle path therefore authenticates all of them.
///
/// # Tuple layout
///
/// - `.0`: the commitment type of the `Mmcs<EF>` implementation `M`.
/// - `.1`: the prover data of `M` for a committed `RowMajorMatrix<EF>`.
pub type MaskOracle<EF, M> = (
    <M as Mmcs<EF>>::Commitment,
    <M as Mmcs<EF>>::ProverData<RowMajorMatrix<EF>>,
);
95
/// Typed prover handoff produced by the HVZK sumcheck.
///
/// - Downstream code-switching needs both the residual prover and the
///   sampled `eps` scale.
/// - A named type makes the Construction 6.3 to Construction 9.7 boundary
///   explicit.
///
/// # Type parameters
///
/// - `F`: base field.
/// - `EF`: extension field of `F`; challenges, masks and randomness live here.
/// - `M`: the `Mmcs<EF>` implementation whose associated types make up
///   [`MaskOracle`].
pub struct ZkSumcheckHandoff<F, EF, M>
where
    F: Field,
    EF: ExtensionField<F>,
    M: Mmcs<EF>,
{
    /// Residual sumcheck prover whose claim is scaled by `eps`.
    pub residual_prover: SumcheckProver<F, EF>,
    /// Per-round sumcheck challenges.
    pub randomness: Point<EF>,
    /// Construction 6.3 combining challenge.
    pub eps: EF,
    /// Plain mask messages sampled by the prover, in round order.
    ///
    /// These are prover-only witnesses. Code-switch composition uses them to
    /// carry the verifier-visible masked residual as auxiliary linear claims.
    pub mask_messages: Vec<Vec<EF>>,
    /// Encoding randomness used for each mask, in round order.
    ///
    /// Prover-only. The HVZK base case reveals blinded combinations
    /// `r* = r' + gamma * r`, which requires the raw values.
    pub mask_randomness: Vec<Vec<EF>>,
    /// The batch's interleaved mask oracle: one commitment, `k` columns.
    ///
    /// Stored as the `(commitment, prover data)` pair described by [`MaskOracle`].
    pub mask_oracle: MaskOracle<EF, M>,
}
127
/// Typed verifier handoff produced by replaying an HVZK sumcheck transcript.
///
/// This mirrors [`ZkSumcheckHandoff`] without prover-only mask data:
/// it keeps the per-round challenges and `eps`, and carries the residual as a
/// plain claimed value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZkVerifierHandoff<EF> {
    /// Per-round sumcheck challenges.
    pub randomness: Point<EF>,
    /// Residual claim after replay.
    pub claimed_residual: EF,
    /// Construction 6.3 combining challenge.
    pub eps: EF,
}
140
141/// Evaluates the final verifier-visible mask residual after all HVZK sumcheck rounds.
142///
143/// For masks `s_j(X)` and verifier challenges `gamma_j`, the mask part of the
144/// final Construction 6.3 target is:
145///
146/// ```text
147///     sum_j s_j(gamma_j)
148/// ```
149///
150/// This is the closed form of the live/past/future mask recurrence used while
151/// assembling the round polynomials.
152#[must_use]
153pub fn mask_residual<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> EF
154where
155    EF: Field,
156{
157    assert_eq!(masks.len(), gammas.len());
158    masks
159        .iter()
160        .zip(gammas)
161        .map(|(mask, &gamma)| mask.iter().copied().horner(gamma))
162        .sum()
163}
164
165/// Linear covectors whose dot products with the masks equal [`mask_residual`].
166#[must_use]
167pub fn mask_residual_covectors<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> Vec<Vec<EF>>
168where
169    EF: Field,
170{
171    assert!(
172        masks
173            .iter()
174            .all(|mask| mask.len() == masks.first().map_or(0, Vec::len))
175    );
176    mask_residual_covectors_from_shape(masks.len(), masks.first().map_or(0, Vec::len), gammas)
177}
178
179/// Linear covectors for masks with a known rectangular shape.
180///
181/// The covector for mask `s_j` is `[1, gamma_j, gamma_j^2, ...]`.
182/// Code-switch composition carries these as the fresh sumcheck-mask claims.
183#[must_use]
184pub fn mask_residual_covectors_from_shape<EF: Field>(
185    mask_count: usize,
186    mask_len: usize,
187    gammas: &[EF],
188) -> Vec<Vec<EF>> {
189    assert_eq!(mask_count, gammas.len());
190    gammas
191        .iter()
192        .map(|gamma| gamma.powers().collect_n(mask_len))
193        .collect()
194}
195
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{Field, PrimeCharacteristicRing, dot_product};

    use super::{mask_residual, mask_residual_covectors};

    type F = BabyBear;
    type EF = BinomialExtensionField<F, 4>;

    /// Lifts small integers into the test extension field.
    fn ext_elems(values: &[u64]) -> Vec<EF> {
        values.iter().copied().map(EF::from_u64).collect()
    }

    /// `s(0) + s(1)` read off the coefficients: the constant term counts
    /// twice, every other coefficient once.
    fn endpoint_sum<EF: Field>(coeffs: &[EF]) -> EF {
        coeffs[0].double() + coeffs[1..].iter().copied().sum::<EF>()
    }

    /// Evaluates a polynomial given by low-to-high coefficients at `point`,
    /// folding from the highest coefficient down.
    fn evaluate<EF: Field>(coeffs: &[EF], point: EF) -> EF {
        coeffs
            .iter()
            .rev()
            .copied()
            .fold(EF::ZERO, |acc, coeff| acc * point + coeff)
    }

    /// Round-by-round reference: rebuilds the mask part of every round
    /// polynomial from live, past and future masks and returns the target
    /// left after the last round.
    fn reference_mask_recurrence<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> EF
    where
        EF: Field,
    {
        assert_eq!(masks.len(), gammas.len());
        let rounds = masks.len();
        if rounds == 0 {
            return EF::ZERO;
        }

        // two_pows[i] == 2^i for i in 0..=rounds.
        let two_pows: Vec<EF> = EF::TWO.powers().collect_n(rounds + 1);
        let mut past_evals = Vec::with_capacity(rounds);
        // Starts as the endpoint sum of every mask; each round removes its own
        // mask so that only strictly-future masks remain.
        let mut future_endpoints: EF = masks.iter().map(|mask| endpoint_sum(mask)).sum();
        let mut target = EF::ZERO;

        for (idx, (mask, &gamma)) in masks.iter().zip(gammas).enumerate() {
            // Rounds still to come after this one (`k - j` with `j = idx + 1`).
            let remaining = rounds - (idx + 1);
            future_endpoints -= endpoint_sum(mask);

            // Live mask, scaled by 2^(k - j), padded to at least three coefficients.
            let live_scale = two_pows[remaining];
            let mut round_poly = EF::zero_vec(mask.len().max(3));
            for (slot, &coeff) in round_poly.iter_mut().zip(mask.iter()) {
                *slot += live_scale * coeff;
            }

            // Past masks are already fixed at their challenges; they only
            // shift the constant term.
            let past_sum: EF = past_evals.iter().copied().sum();
            round_poly[0] += past_sum * live_scale;
            // Future masks contribute their endpoint sums, scaled by 2^(k - j - 1).
            if remaining > 0 {
                round_poly[0] += two_pows[remaining - 1] * future_endpoints;
            }

            target = evaluate(&round_poly, gamma);
            past_evals.push(evaluate(mask, gamma));
        }

        target
    }

    #[test]
    fn mask_residual_closed_form_matches_round_recurrence() {
        let masks = vec![
            ext_elems(&[3, 5, 7, 11]),
            ext_elems(&[13, 17, 19, 23]),
            ext_elems(&[29, 31, 37, 41]),
        ];
        let gammas = ext_elems(&[43, 47, 53]);

        let closed_form = mask_residual::<EF>(&masks, &gammas);
        let round_by_round = reference_mask_recurrence::<EF>(&masks, &gammas);

        assert_eq!(closed_form, round_by_round);
    }

    #[test]
    fn mask_residual_covectors_evaluate_closed_form() {
        let masks = vec![ext_elems(&[2, 3, 5]), ext_elems(&[7, 11, 13])];
        let gammas = ext_elems(&[17, 19]);
        let covectors = mask_residual_covectors::<EF>(&masks, &gammas);

        // Pair each mask with its covector and accumulate the dot products.
        let mut by_covectors = EF::ZERO;
        for (mask, covector) in masks.iter().zip(&covectors) {
            by_covectors +=
                dot_product::<EF, _, _>(mask.iter().copied(), covector.iter().copied());
        }

        assert_eq!(by_covectors, mask_residual::<EF>(&masks, &gammas));
    }
}