Skip to main content

p3_sumcheck/zk/prover/
residual.rs

1//! HVZK overlay for an already-derived residual sumcheck claim.
2
3use alloc::vec::Vec;
4
5use p3_challenger::{CanObserve, FieldChallenger, GrindingChallenger};
6use p3_commit::Mmcs;
7use p3_field::{ExtensionField, Field, HornerIter};
8use p3_matrix::Matrix;
9use p3_multilinear_util::point::Point;
10use p3_zk_codes::ZkEncodingWithRandomness;
11use rand::Rng;
12
13use super::common::{observe_masks_and_mu_tilde, sample_masks};
14use super::round::{PlainPiece, RoundContext, RoundState, round_poly_to_wire};
15use crate::strategy::SumcheckProver;
16use crate::zk::{ZkSumcheckData, ZkSumcheckHandoff};
17
18impl<F, EF> SumcheckProver<F, EF>
19where
20    F: Field,
21    EF: ExtensionField<F>,
22{
23    /// Runs the HVZK sumcheck overlay on an already-derived residual product
24    /// polynomial.
25    ///
26    /// This is the post-code-switch analogue of `ZkPrefixProver::into_sumcheck`:
27    /// the caller has already reduced the layout-specific opening relation to a
28    /// product polynomial, and this method applies Construction 6.3's mask
29    /// transcript to the next batch of sumcheck rounds.
30    ///
31    /// # Joint claims and the auxiliary constant
32    ///
33    /// The committed-sumcheck relation (Definition 5.8 of eprint 2026/391)
34    /// pairs the source claim `<f, w>` with mask-oracle claims `<xi_i, u_i>`.
35    ///
36    /// - The mask-claim values are prover-only; their total is the
37    ///   auxiliary constant.
38    /// - The bound scalar is the joint claim: source claim plus that
39    ///   constant.
40    /// - The constant rides the affine chain with a `2^{-j}` carry per
41    ///   round:
42    ///
43    /// ```text
44    ///     h_j gains  eps * aux * 2^{-j}  on its constant slot
45    ///     =>  h_j(0) + h_j(1)  gains  eps * aux * 2^{-(j-1)}
46    ///     =>  the final residual gains  eps * aux * 2^{-k}
47    /// ```
48    ///
49    /// Downstream reductions must therefore scale the carried mask covectors
50    /// by `eps * 2^{-k}`.
51    ///
52    /// # Eval side
53    ///
54    /// - Only the weight side and the claim are scaled by `eps`.
55    /// - The evaluation side stays the honest folded message.
56    /// - An HVZK code-switch can therefore commit it verbatim.
57    #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
58    #[tracing::instrument(skip_all)]
59    pub fn into_zk_sumcheck<Enc, M, R, Ch>(
60        mut self,
61        zk_data: &mut ZkSumcheckData<F, EF>,
62        encoding: &Enc,
63        mmcs: &M,
64        folding_factor: usize,
65        pow_bits: usize,
66        aux_claim: EF,
67        challenger: &mut Ch,
68        rng: &mut R,
69    ) -> ZkSumcheckHandoff<F, EF, M>
70    where
71        Enc: ZkEncodingWithRandomness<EF>,
72        Enc::Codeword: Matrix<EF>,
73        M: Mmcs<EF>,
74        R: Rng,
75        Ch: FieldChallenger<F> + GrindingChallenger<Witness = F> + CanObserve<M::Commitment>,
76    {
77        assert!(F::TWO != F::ZERO, "Lemma 6.4 requires char(F) != 2");
78        assert!(folding_factor >= 1, "sumcheck requires at least one round");
79        assert!(
80            folding_factor <= self.num_variables(),
81            "folding_factor must be <= residual prover arity",
82        );
83
84        let ell_zk = encoding.message_len();
85        assert!(
86            ell_zk >= 3,
87            "mask degree ell_zk - 1 must cover the degree-2 plain piece (ell_zk >= 3)",
88        );
89
90        // Unlike the layout-driven path, this entry receives a scalar claim
91        // directly, so bind it before the masking prelude samples `eps`.
92        //
93        // The bound value is the joint claim, matching the verifier's view.
94        challenger.observe_algebra_element(self.claimed_sum() + aux_claim);
95
96        let (masks, mask_randomness, mask_oracle) =
97            sample_masks::<EF, _, _, _, _>(folding_factor, encoding, mmcs, challenger, rng);
98        let mut sum_future_endpoints = observe_masks_and_mu_tilde::<F, EF, _>(
99            &masks,
100            folding_factor,
101            ell_zk,
102            challenger,
103            zk_data,
104        );
105
106        let eps: EF = challenger.sample_algebra_element();
107        let mut rs = Vec::with_capacity(folding_factor);
108        let mut mask_evals_at_gamma = Vec::with_capacity(folding_factor);
109        let pow2: Vec<EF> = EF::TWO.powers().collect_n(folding_factor + 1);
110        let round_ctx = RoundContext {
111            k: folding_factor,
112            ell_zk,
113            pow2: &pow2,
114            eps,
115        };
116
117        // Running `aux * 2^{-j}` carry; halved once per round.
118        let half = EF::TWO.inverse();
119        let mut aux_carry = aux_claim;
120
121        for (round_idx, mask) in masks.iter().enumerate() {
122            let j = round_idx + 1;
123            let mask_endpoints = mask[0].double() + mask[1..].iter().copied().sum::<EF>();
124            sum_future_endpoints -= mask_endpoints;
125            aux_carry *= half;
126
127            let (plain_c0, plain_c_inf) = self.round_coefficients();
128            // The aux carry enters only the transmitted constant slot; the
129            // source-side fold below keeps the raw coefficients.
130            let h = round_ctx.assemble(
131                RoundState {
132                    j,
133                    mask,
134                    past_mask_evals: &mask_evals_at_gamma,
135                    future_endpoints: sum_future_endpoints,
136                },
137                PlainPiece {
138                    c0: plain_c0 + aux_carry,
139                    c_inf: plain_c_inf,
140                },
141            );
142            let wire = round_poly_to_wire(&h);
143            challenger.observe_algebra_slice(&wire);
144            zk_data.round_coefficients.push(wire);
145
146            if pow_bits > 0 {
147                zk_data.pow_witnesses.push(challenger.grind(pow_bits));
148            }
149
150            let gamma: EF = challenger.sample_algebra_element();
151            let mask_at_gamma = mask.iter().copied().horner(gamma);
152            mask_evals_at_gamma.push(mask_at_gamma);
153
154            self.fold_round_with_coefficients(plain_c0, plain_c_inf, gamma);
155            rs.push(gamma);
156        }
157
158        self.scale_weights_and_claim(eps);
159
160        ZkSumcheckHandoff {
161            residual_prover: self,
162            randomness: Point::new(rs),
163            eps,
164            mask_messages: masks,
165            mask_randomness,
166            mask_oracle,
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{PrimeCharacteristicRing, dot_product};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_multilinear_util::poly::Poly;
    use p3_zk_codes::{ZkEncoding, ZkEncodingWithRandomness};
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::product_polynomial::ProductPolynomial;
    use crate::strategy::VariableOrder;
    use crate::zk::test_helpers::{MyChallenger, MyMmcs, make_setup};
    use crate::zk::{ZkVerifier, mask_residual};

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;

    /// Deterministic encoding for tests.
    ///
    /// `sample_message` ignores the RNG and always returns
    /// `[100, 101, ..., 100 + ell_zk - 1]`. A test can therefore tell whether
    /// the prover drew its masks through the encoding.
    #[derive(Clone)]
    struct SentinelEncoding {
        ell_zk: usize,
    }

    impl ZkEncoding<EF> for SentinelEncoding {
        type Codeword = RowMajorMatrix<EF>;

        fn message_len(&self) -> usize {
            self.ell_zk
        }

        fn randomness_len(&self) -> usize {
            0
        }

        fn error(&self) -> f64 {
            0.0
        }

        fn sample_message<R: Rng>(&self, _rng: &mut R) -> Vec<EF> {
            (0..self.ell_zk)
                .map(|idx| EF::from_u64(100 + idx as u64))
                .collect()
        }

        fn query_bound(&self) -> usize {
            0
        }

        fn encode<R: Rng>(&self, msg: &[EF], _rng: &mut R) -> Self::Codeword {
            RowMajorMatrix::new_col(msg.to_vec())
        }

        fn sample_randomness<R: Rng>(&self, _rng: &mut R) -> Vec<EF> {
            Vec::new()
        }

        fn simulate<R: Rng>(&self, query_set: &[usize], _rng: &mut R) -> Vec<EF> {
            EF::zero_vec(query_set.len())
        }
    }

    impl ZkEncodingWithRandomness<EF> for SentinelEncoding {
        fn encode_with_randomness(&self, msg: &[EF], randomness: &[EF]) -> Self::Codeword {
            assert!(randomness.is_empty());
            RowMajorMatrix::new_col(msg.to_vec())
        }
    }

    /// Builds the residual prover over the fixed 8-entry evals/weights
    /// fixture and returns it together with its source claim
    /// `<evals, weights>`.
    fn residual_prover() -> (SumcheckProver<F, EF>, EF) {
        let evals = Poly::new((1..=8).map(EF::from_u64).collect::<Vec<_>>());
        let weights = Poly::new((11..=18).map(EF::from_u64).collect::<Vec<_>>());
        let claimed_sum = dot_product::<EF, _, _>(
            evals.as_slice().iter().copied(),
            weights.as_slice().iter().copied(),
        );
        let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        (SumcheckProver::new(poly, claimed_sum), claimed_sum)
    }

    /// Runs the ZK residual prover with `aux_claim`, then replays its
    /// transcript through `ZkVerifier::verify_claim` on the joint claim
    /// (source claim + `aux_claim`).
    ///
    /// It then checks that the verifier's residual equals
    /// `prover residual + mask residual + eps * aux * 2^{-k}`, as documented
    /// on `into_zk_sumcheck`.
    fn assert_handoff_replays(aux_claim: EF) {
        let (prover, claimed_sum) = residual_prover();

        let ell_zk = 4;
        let folding_factor = 2;
        let (perm, mmcs, encoding) = make_setup(17, ell_zk);
        let mut prover_challenger = MyChallenger::new(perm.clone());
        let mut verifier_challenger = MyChallenger::new(perm);
        let mut rng = SmallRng::seed_from_u64(19);
        let mut zk_data = ZkSumcheckData::<F, EF>::default();

        let prover_handoff = prover.into_zk_sumcheck(
            &mut zk_data,
            &encoding,
            &mmcs,
            folding_factor,
            0,
            aux_claim,
            &mut prover_challenger,
            &mut rng,
        );
        let mask_commitment = prover_handoff.mask_oracle.0.clone();

        // The prover binds the joint claim, so the verifier must be given it too.
        let verifier_handoff = ZkVerifier::<F, EF>::verify_claim::<MyMmcs, _>(
            &zk_data,
            &mask_commitment,
            ell_zk,
            folding_factor,
            0,
            claimed_sum + aux_claim,
            &mut verifier_challenger,
        )
        .expect("honest residual ZK handoff should verify");

        assert_eq!(verifier_handoff.randomness, prover_handoff.randomness);
        assert_eq!(verifier_handoff.eps, prover_handoff.eps);

        let gammas = prover_handoff
            .randomness
            .iter()
            .copied()
            .collect::<Vec<_>>();
        let final_mask_residual = mask_residual::<EF>(&prover_handoff.mask_messages, &gammas);

        // The aux constant is halved once per round on the constant slot, so
        // after k rounds it contributes eps * aux * 2^{-k} to the residual.
        let half = EF::TWO.inverse();
        let aux_tail =
            (0..folding_factor).fold(prover_handoff.eps * aux_claim, |acc, _| acc * half);

        assert_eq!(
            verifier_handoff.claimed_residual,
            prover_handoff.residual_prover.claimed_sum() + final_mask_residual + aux_tail,
        );
    }

    #[test]
    fn residual_prover_zk_handoff_replays_from_claim() {
        assert_handoff_replays(EF::ZERO);
    }

    #[test]
    fn residual_prover_zk_handoff_replays_from_joint_claim_with_aux() {
        // A nonzero aux constant exercises the 2^{-j} carry on the constant slot.
        assert_handoff_replays(EF::from_u64(7));
    }

    #[test]
    fn residual_zk_handoff_samples_masks_through_encoding() {
        let (prover, _) = residual_prover();

        let ell_zk = 4;
        let folding_factor = 2;
        let (perm, mmcs, _) = make_setup(23, ell_zk);
        let encoding = SentinelEncoding { ell_zk };
        let mut challenger = MyChallenger::new(perm);
        let mut rng = SmallRng::seed_from_u64(29);
        let mut zk_data = ZkSumcheckData::<F, EF>::default();

        let handoff = prover.into_zk_sumcheck(
            &mut zk_data,
            &encoding,
            &mmcs,
            folding_factor,
            0,
            EF::ZERO,
            &mut challenger,
            &mut rng,
        );

        let sentinel = encoding.sample_message(&mut rng);
        assert_eq!(handoff.mask_messages, vec![sentinel; folding_factor]);
    }
}