p3_sumcheck/zk/data.rs
1//! Transcript schema and oracle handle for the HVZK sumcheck.
2
3use alloc::vec::Vec;
4
5use p3_commit::Mmcs;
6use p3_field::{ExtensionField, Field, HornerIter};
7use p3_matrix::dense::RowMajorMatrix;
8use p3_multilinear_util::point::Point;
9use serde::{Deserialize, Serialize};
10
11use crate::strategy::SumcheckProver;
12
13/// Per-round prover output of the HVZK sumcheck protocol.
14///
15/// - Prover writes;
16/// - Verifier reads back during Fiat-Shamir replay.
17///
18/// One instance covers a full run of `k` rounds.
19///
20/// # Wire format
21///
22/// Per round, the polynomial has coefficient layout
23///
24/// ```text
25/// [ c_0, c_1, c_2, ..., c_d ] with d = max(ell_zk - 1, 2)
26/// ```
27///
28/// The linear coefficient `c_1` is dropped on the wire.
29///
30/// The verifier reconstructs `c_1` from the affine identity
31///
32/// ```text
33/// h_j(0) + h_j(1) = 2 * c_0 + sum_{i >= 1} c_i = target
34/// ```
35///
36/// applied to the previous round's target.
37///
38/// # Soundness link to Lemma 6.4
39///
40/// Valid transcripts form an affine subspace of dimension `1 + k * (ell_zk - 1)`.
41/// The `k` dropped linear coefficients are exactly the redundant degrees of freedom of the rank-nullity argument.
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ZkSumcheckData<F, EF> {
44 /// Sum of all mask polynomial evaluations across the boolean hypercube `{0,1}^k`.
45 ///
46 /// Observed on the transcript before the verifier samples the combining challenge.
47 /// Lives in the extension field because the mask coefficients do.
48 pub mu_tilde: EF,
49
50 /// Message length of the zero-knowledge mask code.
51 ///
52 /// The verifier rejects up front if its own expected value disagrees with this.
53 /// Pinning this in the transcript closes a non-injectivity gap in the wire-length check: lengths `2` and `3` share a wire layout.
54 pub ell_zk: usize,
55
56 /// Per-round wire payload with the linear coefficient dropped.
57 ///
58 /// One entry per sumcheck round.
59 /// Layout per entry: `[c_0, c_2, c_3, ..., c_d]` with `d = max(ell_zk - 1, 2)`.
60 pub round_coefficients: Vec<Vec<EF>>,
61
62 /// Per-round proof-of-work witnesses.
63 ///
64 /// Length equals the number of rounds when grinding is enabled.
65 /// Empty when `pow_bits == 0`.
66 pub pow_witnesses: Vec<F>,
67}
68
69impl<F, EF: Field> Default for ZkSumcheckData<F, EF> {
70 fn default() -> Self {
71 Self {
72 // Real runs overwrite this in step 2 once the prover has summed the masks.
73 mu_tilde: EF::ZERO,
74 // Sentinel: honest runs set this to the encoding's message length; the verifier rejects 0.
75 ell_zk: 0,
76 // Filled with one wire entry per sumcheck round.
77 round_coefficients: Vec::new(),
78 // Filled only when grinding is enabled.
79 pow_witnesses: Vec::new(),
80 }
81 }
82}
83
84/// Handle to one committed batch of interleaved mask codewords.
85///
86/// - Pairs the public Merkle root with the prover-side data needed to open
87/// the batch at requested positions.
88/// - Row `z` of the committed matrix holds position `z` of every mask in
89/// the batch.
90/// - One Merkle path therefore authenticates all of them.
91pub type MaskOracle<EF, M> = (
92 <M as Mmcs<EF>>::Commitment,
93 <M as Mmcs<EF>>::ProverData<RowMajorMatrix<EF>>,
94);
95
96/// Typed prover handoff produced by the HVZK sumcheck.
97///
98/// - Downstream code-switching needs both the residual prover and the
99/// sampled `eps` scale.
100/// - A named type makes the Construction 6.3 to Construction 9.7 boundary
101/// explicit.
102pub struct ZkSumcheckHandoff<F, EF, M>
103where
104 F: Field,
105 EF: ExtensionField<F>,
106 M: Mmcs<EF>,
107{
108 /// Residual sumcheck prover whose claim is scaled by `eps`.
109 pub residual_prover: SumcheckProver<F, EF>,
110 /// Per-round sumcheck challenges.
111 pub randomness: Point<EF>,
112 /// Construction 6.3 combining challenge.
113 pub eps: EF,
114 /// Plain mask messages sampled by the prover, in round order.
115 ///
116 /// These are prover-only witnesses. Code-switch composition uses them to
117 /// carry the verifier-visible masked residual as auxiliary linear claims.
118 pub mask_messages: Vec<Vec<EF>>,
119 /// Encoding randomness used for each mask, in round order.
120 ///
121 /// Prover-only. The HVZK base case reveals blinded combinations
122 /// `r* = r' + gamma * r`, which requires the raw values.
123 pub mask_randomness: Vec<Vec<EF>>,
124 /// The batch's interleaved mask oracle: one commitment, `k` columns.
125 pub mask_oracle: MaskOracle<EF, M>,
126}
127
128/// Typed verifier handoff produced by replaying an HVZK sumcheck transcript.
129///
130/// This mirrors [`ZkSumcheckHandoff`] without prover-only mask data.
131#[derive(Debug, Clone, PartialEq, Eq)]
132pub struct ZkVerifierHandoff<EF> {
133 /// Per-round sumcheck challenges.
134 pub randomness: Point<EF>,
135 /// Residual claim after replay.
136 pub claimed_residual: EF,
137 /// Construction 6.3 combining challenge.
138 pub eps: EF,
139}
140
141/// Evaluates the final verifier-visible mask residual after all HVZK sumcheck rounds.
142///
143/// For masks `s_j(X)` and verifier challenges `gamma_j`, the mask part of the
144/// final Construction 6.3 target is:
145///
146/// ```text
147/// sum_j s_j(gamma_j)
148/// ```
149///
150/// This is the closed form of the live/past/future mask recurrence used while
151/// assembling the round polynomials.
152#[must_use]
153pub fn mask_residual<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> EF
154where
155 EF: Field,
156{
157 assert_eq!(masks.len(), gammas.len());
158 masks
159 .iter()
160 .zip(gammas)
161 .map(|(mask, &gamma)| mask.iter().copied().horner(gamma))
162 .sum()
163}
164
165/// Linear covectors whose dot products with the masks equal [`mask_residual`].
166#[must_use]
167pub fn mask_residual_covectors<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> Vec<Vec<EF>>
168where
169 EF: Field,
170{
171 assert!(
172 masks
173 .iter()
174 .all(|mask| mask.len() == masks.first().map_or(0, Vec::len))
175 );
176 mask_residual_covectors_from_shape(masks.len(), masks.first().map_or(0, Vec::len), gammas)
177}
178
179/// Linear covectors for masks with a known rectangular shape.
180///
181/// The covector for mask `s_j` is `[1, gamma_j, gamma_j^2, ...]`.
182/// Code-switch composition carries these as the fresh sumcheck-mask claims.
183#[must_use]
184pub fn mask_residual_covectors_from_shape<EF: Field>(
185 mask_count: usize,
186 mask_len: usize,
187 gammas: &[EF],
188) -> Vec<Vec<EF>> {
189 assert_eq!(mask_count, gammas.len());
190 gammas
191 .iter()
192 .map(|gamma| gamma.powers().collect_n(mask_len))
193 .collect()
194}
195
// Unit tests: the closed-form mask residual is checked against an independent
// round-by-round recurrence, and the covectors are checked against the closed form.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{Field, PrimeCharacteristicRing, dot_product};

    use super::{mask_residual, mask_residual_covectors};

    type F = BabyBear;
    type EF = BinomialExtensionField<F, 4>;

    /// Independent round-by-round reference for [`mask_residual`].
    ///
    /// For each round `j` (1-based) it rebuilds the mask part `h` of the round
    /// polynomial from three pieces and evaluates it at `gamma_j`:
    ///
    /// - live: the current mask `s_j`, scaled by `2^(k - j)`;
    /// - past: `sum_{i < j} s_i(gamma_i)`, scaled by `2^(k - j)` and added to
    ///   the constant coefficient;
    /// - future: `2^(k - j - 1) * sum_{i > j} (s_i(0) + s_i(1))`, added to the
    ///   constant coefficient (absent in the last round).
    ///
    /// Returns the last round's `h(gamma_k)`. Polynomials are evaluated with
    /// hand-written Horner folds, so this does not rely on `horner`.
    ///
    /// Assumes every mask has at least one coefficient (`mask[0]` is indexed).
    fn reference_mask_recurrence<EF>(masks: &[Vec<EF>], gammas: &[EF]) -> EF
    where
        EF: Field,
    {
        assert_eq!(masks.len(), gammas.len());
        let k = masks.len();
        if k == 0 {
            return EF::ZERO;
        }

        // pow2[i] = 2^i for i in 0..=k.
        let pow2: Vec<EF> = EF::TWO.powers().collect_n(k + 1);
        // s_i(gamma_i) for every round processed so far.
        let mut mask_evals_at_gamma = Vec::with_capacity(k);
        // Running sum of s_i(0) + s_i(1) = 2 * c_0 + sum_{t >= 1} c_t over the
        // masks not yet processed; initially this includes every mask.
        let mut sum_future_endpoints: EF = masks
            .iter()
            .map(|mask| mask[0].double() + mask[1..].iter().copied().sum::<EF>())
            .sum();
        let mut target = EF::ZERO;

        for (round_idx, (s_j, &gamma_j)) in masks.iter().zip(gammas).enumerate() {
            // 1-based round number, matching the exponents used below.
            let j = round_idx + 1;
            // Remove the current mask so the running sum covers only later rounds.
            let s_j_endpoints = s_j[0].double() + s_j[1..].iter().copied().sum::<EF>();
            sum_future_endpoints -= s_j_endpoints;

            // At least 3 coefficients, mirroring `d = max(ell_zk - 1, 2)`.
            let h_size = s_j.len().max(3);
            let mut h = EF::zero_vec(h_size);
            // Live part: the current mask scaled by 2^(k - j).
            let mult_live = pow2[k - j];
            for (i, &c) in s_j.iter().enumerate() {
                h[i] += mult_live * c;
            }

            // Past part: masks already fixed at their challenges contribute a
            // constant, scaled by the same 2^(k - j).
            let past_mask_sum: EF = mask_evals_at_gamma.iter().copied().sum();
            h[0] += past_mask_sum * mult_live;
            // Future part: endpoint sums of later masks, scaled by 2^(k - j - 1).
            // The final round has no later masks.
            if j < k {
                h[0] += pow2[k - j - 1] * sum_future_endpoints;
            }

            // Horner evaluation of h at gamma_j (coefficients in ascending degree).
            // Overwritten each round; only the last round's value is returned.
            target = h
                .iter()
                .rev()
                .copied()
                .fold(EF::ZERO, |acc, coeff| acc * gamma_j + coeff);

            // Record s_j(gamma_j) for the "past" term of later rounds.
            let s_j_at_gamma = s_j
                .iter()
                .rev()
                .copied()
                .fold(EF::ZERO, |acc, coeff| acc * gamma_j + coeff);
            mask_evals_at_gamma.push(s_j_at_gamma);
        }

        target
    }

    // The closed form must agree with the round-by-round reference on a
    // three-round, length-4 example.
    #[test]
    fn mask_residual_closed_form_matches_round_recurrence() {
        let masks = vec![
            vec![
                EF::from_u64(3),
                EF::from_u64(5),
                EF::from_u64(7),
                EF::from_u64(11),
            ],
            vec![
                EF::from_u64(13),
                EF::from_u64(17),
                EF::from_u64(19),
                EF::from_u64(23),
            ],
            vec![
                EF::from_u64(29),
                EF::from_u64(31),
                EF::from_u64(37),
                EF::from_u64(41),
            ],
        ];
        let gammas = vec![EF::from_u64(43), EF::from_u64(47), EF::from_u64(53)];

        assert_eq!(
            mask_residual::<EF>(&masks, &gammas),
            reference_mask_recurrence::<EF>(&masks, &gammas),
        );
    }

    // Summing each mask's dot product with its covector must reproduce the
    // closed-form residual.
    #[test]
    fn mask_residual_covectors_evaluate_closed_form() {
        let masks = vec![
            vec![EF::from_u64(2), EF::from_u64(3), EF::from_u64(5)],
            vec![EF::from_u64(7), EF::from_u64(11), EF::from_u64(13)],
        ];
        let gammas = vec![EF::from_u64(17), EF::from_u64(19)];
        let covectors = mask_residual_covectors::<EF>(&masks, &gammas);
        let by_covectors = masks
            .iter()
            .zip(&covectors)
            .map(|(mask, covector)| {
                dot_product::<EF, _, _>(mask.iter().copied(), covector.iter().copied())
            })
            .sum::<EF>();

        assert_eq!(by_covectors, mask_residual::<EF>(&masks, &gammas));
    }
}
311}