1use alloc::vec::Vec;
4
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_commit::Mmcs;
7use p3_field::{ExtensionField, Field, HornerIter};
8use p3_matrix::Matrix;
9use p3_multilinear_util::point::Point;
10use p3_zk_codes::ZkEncodingWithRandomness;
11use rand::Rng;
12
13use super::common::{observe_masks_and_mu_tilde, sample_masks};
14use super::round::{PlainPiece, RoundContext, RoundState, round_poly_to_wire};
15use crate::strategy::SumcheckProver;
16use crate::zk::{ZkSumcheckData, ZkSumcheckHandoff};
17
18impl<F, EF> SumcheckProver<F, EF>
19where
20 F: Field,
21 EF: ExtensionField<F>,
22{
23 #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
58 #[tracing::instrument(skip_all)]
59 pub fn into_zk_sumcheck<Enc, M, R, Ch>(
60 mut self,
61 zk_data: &mut ZkSumcheckData<F, EF>,
62 encoding: &Enc,
63 mmcs: &M,
64 folding_factor: usize,
65 pow_bits: usize,
66 aux_claim: EF,
67 challenger: &mut Ch,
68 rng: &mut R,
69 ) -> ZkSumcheckHandoff<F, EF, M>
70 where
71 Enc: ZkEncodingWithRandomness<EF>,
72 Enc::Codeword: Matrix<EF>,
73 M: Mmcs<EF>,
74 R: Rng,
75 Ch: FieldChallenger<F> + GrindingChallenger<Witness = F> + CanObserve<M::Commitment>,
76 {
77 assert!(F::TWO != F::ZERO, "Lemma 6.4 requires char(F) != 2");
78 assert!(folding_factor >= 1, "sumcheck requires at least one round");
79 assert!(
80 folding_factor <= self.num_variables(),
81 "folding_factor must be <= residual prover arity",
82 );
83
84 let ell_zk = encoding.message_len();
85 assert!(
86 ell_zk >= 3,
87 "mask degree ell_zk - 1 must cover the degree-2 plain piece (ell_zk >= 3)",
88 );
89
90 challenger.observe_algebra_element(self.claimed_sum() + aux_claim);
95
96 let (masks, mask_randomness, mask_oracle) =
97 sample_masks::<EF, _, _, _, _>(folding_factor, encoding, mmcs, challenger, rng);
98 let mut sum_future_endpoints = observe_masks_and_mu_tilde::<F, EF, _>(
99 &masks,
100 folding_factor,
101 ell_zk,
102 challenger,
103 zk_data,
104 );
105
106 let eps: EF = challenger.sample_algebra_element();
107 let mut rs = Vec::with_capacity(folding_factor);
108 let mut mask_evals_at_gamma = Vec::with_capacity(folding_factor);
109 let pow2: Vec<EF> = EF::TWO.powers().collect_n(folding_factor + 1);
110 let round_ctx = RoundContext {
111 k: folding_factor,
112 ell_zk,
113 pow2: &pow2,
114 eps,
115 };
116
117 let half = EF::TWO.inverse();
119 let mut aux_carry = aux_claim;
120
121 for (round_idx, mask) in masks.iter().enumerate() {
122 let j = round_idx + 1;
123 let mask_endpoints = mask[0].double() + mask[1..].iter().copied().sum::<EF>();
124 sum_future_endpoints -= mask_endpoints;
125 aux_carry *= half;
126
127 let (plain_c0, plain_c_inf) = self.round_coefficients();
128 let h = round_ctx.assemble(
131 RoundState {
132 j,
133 mask,
134 past_mask_evals: &mask_evals_at_gamma,
135 future_endpoints: sum_future_endpoints,
136 },
137 PlainPiece {
138 c0: plain_c0 + aux_carry,
139 c_inf: plain_c_inf,
140 },
141 );
142 let wire = round_poly_to_wire(&h);
143 challenger.observe_algebra_slice(&wire);
144 zk_data.round_coefficients.push(wire);
145
146 if pow_bits > 0 {
147 zk_data.pow_witnesses.push(challenger.grind(pow_bits));
148 }
149
150 let gamma: EF = challenger.sample_algebra_element();
151 let mask_at_gamma = mask.iter().copied().horner(gamma);
152 mask_evals_at_gamma.push(mask_at_gamma);
153
154 self.fold_round_with_coefficients(plain_c0, plain_c_inf, gamma);
155 rs.push(gamma);
156 }
157
158 self.scale_weights_and_claim(eps);
159
160 ZkSumcheckHandoff {
161 residual_prover: self,
162 randomness: Point::new(rs),
163 eps,
164 mask_messages: masks,
165 mask_randomness,
166 mask_oracle,
167 }
168 }
169}
170
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{PrimeCharacteristicRing, dot_product};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_multilinear_util::poly::Poly;
    use p3_zk_codes::{ZkEncoding, ZkEncodingWithRandomness};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::product_polynomial::ProductPolynomial;
    use crate::strategy::VariableOrder;
    use crate::zk::test_helpers::{MyChallenger, MyMmcs, make_setup};
    use crate::zk::{ZkVerifier, mask_residual};

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;

    /// Mask message length used by every test in this module.
    const ELL_ZK: usize = 4;
    /// Number of sumcheck rounds run by every test in this module.
    const FOLDING_FACTOR: usize = 2;

    /// Builds a prover over `evals = [1..=8]` and `weights = [11..=18]`.
    /// Returns it together with the honest claimed sum `<evals, weights>`.
    fn toy_prover() -> (SumcheckProver<F, EF>, EF) {
        let evals: Vec<EF> = (1_u64..=8).map(EF::from_u64).collect();
        let weights: Vec<EF> = (11_u64..=18).map(EF::from_u64).collect();
        let claimed_sum = dot_product::<EF, _, _>(evals.iter().copied(), weights.iter().copied());
        let poly = ProductPolynomial::<F, EF>::new_unpacked(
            VariableOrder::Prefix,
            Poly::new(evals),
            Poly::new(weights),
        );
        (SumcheckProver::new(poly, claimed_sum), claimed_sum)
    }

    /// Deterministic encoding for tests.
    ///
    /// Every sampled message is `[100, 101, ..., 100 + ell_zk - 1]` regardless of the
    /// RNG. Codewords are the message laid out as a single column. No encoding
    /// randomness is used.
    #[derive(Clone)]
    struct SentinelEncoding {
        ell_zk: usize,
    }

    impl ZkEncoding<EF> for SentinelEncoding {
        type Codeword = RowMajorMatrix<EF>;

        fn message_len(&self) -> usize {
            self.ell_zk
        }

        fn randomness_len(&self) -> usize {
            0
        }

        fn error(&self) -> f64 {
            0.0
        }

        fn sample_message<R: Rng>(&self, _rng: &mut R) -> Vec<EF> {
            (100_u64..).take(self.ell_zk).map(EF::from_u64).collect()
        }

        fn query_bound(&self) -> usize {
            0
        }

        fn encode<R: Rng>(&self, msg: &[EF], _rng: &mut R) -> Self::Codeword {
            RowMajorMatrix::new_col(msg.to_vec())
        }

        fn sample_randomness<R: Rng>(&self, _rng: &mut R) -> Vec<EF> {
            Vec::new()
        }

        fn simulate<R: Rng>(&self, query_set: &[usize], _rng: &mut R) -> Vec<EF> {
            EF::zero_vec(query_set.len())
        }
    }

    impl ZkEncodingWithRandomness<EF> for SentinelEncoding {
        fn encode_with_randomness(&self, msg: &[EF], randomness: &[EF]) -> Self::Codeword {
            assert!(randomness.is_empty());
            RowMajorMatrix::new_col(msg.to_vec())
        }
    }

    /// An honest prover transcript replays on the verifier side. The verifier then
    /// agrees on the challenges and `eps`, and on a residual claim equal to the residual
    /// prover's claim plus the masks' residual at the sampled point.
    #[test]
    fn residual_prover_zk_handoff_replays_from_claim() {
        let (prover, claimed_sum) = toy_prover();

        let (perm, mmcs, encoding) = make_setup(17, ELL_ZK);
        let mut prover_challenger = MyChallenger::new(perm.clone());
        let mut verifier_challenger = MyChallenger::new(perm);
        let mut rng = SmallRng::seed_from_u64(19);
        let mut zk_data = ZkSumcheckData::<F, EF>::default();

        let prover_handoff = prover.into_zk_sumcheck(
            &mut zk_data,
            &encoding,
            &mmcs,
            FOLDING_FACTOR,
            0,
            EF::ZERO,
            &mut prover_challenger,
            &mut rng,
        );
        let mask_commitment = prover_handoff.mask_oracle.0.clone();

        let verifier_handoff = ZkVerifier::<F, EF>::verify_claim::<MyMmcs, _>(
            &zk_data,
            &mask_commitment,
            ELL_ZK,
            FOLDING_FACTOR,
            0,
            claimed_sum,
            &mut verifier_challenger,
        )
        .expect("honest residual ZK handoff should verify");

        assert_eq!(verifier_handoff.randomness, prover_handoff.randomness);
        assert_eq!(verifier_handoff.eps, prover_handoff.eps);

        let gammas: Vec<_> = prover_handoff.randomness.iter().copied().collect();
        let final_mask_residual = mask_residual::<EF>(&prover_handoff.mask_messages, &gammas);
        let expected_residual = prover_handoff.residual_prover.claimed_sum() + final_mask_residual;
        assert_eq!(verifier_handoff.claimed_residual, expected_residual);
    }

    /// Masks are drawn through `ZkEncoding::sample_message`. With the sentinel
    /// encoding, every round's mask is therefore the fixed sentinel message.
    #[test]
    fn residual_zk_handoff_samples_masks_through_encoding() {
        let (prover, _claimed_sum) = toy_prover();

        let (perm, mmcs, _) = make_setup(23, ELL_ZK);
        let encoding = SentinelEncoding { ell_zk: ELL_ZK };
        let mut challenger = MyChallenger::new(perm);
        let mut rng = SmallRng::seed_from_u64(29);
        let mut zk_data = ZkSumcheckData::<F, EF>::default();

        let handoff = prover.into_zk_sumcheck(
            &mut zk_data,
            &encoding,
            &mmcs,
            FOLDING_FACTOR,
            0,
            EF::ZERO,
            &mut challenger,
            &mut rng,
        );

        let sentinel = encoding.sample_message(&mut rng);
        assert_eq!(handoff.mask_messages, vec![sentinel; FOLDING_FACTOR]);
    }
}