p3_sumcheck/constraints/statement/select.rs
1use alloc::vec::Vec;
2
3use itertools::Itertools;
4use p3_field::{
5 ExtensionField, Field, HornerIter, PackedFieldExtension, PackedValue, PrimeCharacteristicRing,
6 dot_product,
7};
8use p3_matrix::Matrix;
9use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
10use p3_maybe_rayon::prelude::*;
11use p3_multilinear_util::point::Point;
12use p3_multilinear_util::poly::Poly;
13use p3_util::log2_strict_usize;
14use tracing::instrument;
15
16/// Expand powers-of-two seeds into the full power table via butterfly.
17///
18/// # Input
19///
20/// A `k × n` matrix where column `j` holds the squared powers of
21/// variable `v_j` in descending exponent order:
22///
23/// ```text
24/// row 0: [v_1^{2^{k-1}}, v_2^{2^{k-1}}, …, v_n^{2^{k-1}}]
25/// row 1: [v_1^{2^{k-2}}, v_2^{2^{k-2}}, …, v_n^{2^{k-2}}]
26/// ⋮
27/// row k-1: [v_1^1, v_2^1, …, v_n^1 ]
28/// ```
29///
30/// # Output
31///
32/// A `2^k × n` matrix where entry `[b, j] = v_j^b` (the full
33/// monomial power, not just a squared power).
34///
35/// # Algorithm
36///
37/// Uses a binary-tree butterfly. After processing row `i` of the
38/// input, the first `2^{i+1}` rows of the output are filled.
39/// Each step copies the existing rows and multiplies by the
40/// current squared power to fill the new rows:
41///
42/// ```text
43/// mat[b + 2^i, j] = mat[b, j] * points[i, j]
44/// ```
45fn batch_pows<F: Field>(points: RowMajorMatrixView<'_, F>) -> RowMajorMatrix<F> {
46 let k = points.height();
47 let n = points.width();
48
49 let mut mat = RowMajorMatrix::new(F::zero_vec(n * (1 << k)), n);
50
51 // Base case: v_j^0 = 1 for all j.
52 mat.row_mut(0).fill(F::ONE);
53
54 // Butterfly expansion: each input row doubles the number of filled rows.
55 points.row_slices().enumerate().for_each(|(i, vars)| {
56 let (lo, mut hi) = mat.split_rows_mut(1 << i);
57 lo.rows().zip(hi.rows_mut()).for_each(|(lo, hi)| {
58 // hi[j] = lo[j] * var[j], extending the power by 2^i.
59 vars.iter()
60 .zip(lo.zip(hi.iter_mut()))
61 .for_each(|(&var, (lo, hi))| *hi = lo * var);
62 });
63 });
64 mat
65}
66
67/// SIMD-packed variant of the power-table butterfly expansion.
68///
69/// # Overview
70///
71/// Splits the `k` input variables into two phases:
72///
73/// 1. **Packing phase** (first `k_pack` variables): Builds a small
74/// scalar power table per column, then packs it into a single
75/// SIMD lane. This fills the first row of the packed output.
76///
77/// 2. **Butterfly phase** (remaining `k - k_pack` variables): Applies
78/// the same butterfly as the scalar version, but operates on
79/// packed elements — multiplying all SIMD lanes in one instruction.
80///
81/// # Output
82///
83/// A `2^{k - k_pack} × n` matrix of packed elements, where
84/// unpacking row `r` column `j` yields the `F::Packing::WIDTH`
85/// consecutive scalar entries `v_j^{r * WIDTH}, …, v_j^{r * WIDTH + WIDTH - 1}`.
86fn packed_batch_pows<F: Field>(points: RowMajorMatrixView<'_, F>) -> RowMajorMatrix<F::Packing> {
87 let k = points.height();
88 let n = points.width();
89 assert_ne!(n, 0);
90 let k_pack = log2_strict_usize(F::Packing::WIDTH);
91 assert!(k >= k_pack);
92
93 let (init_vars, rest_vars) = points.split_rows(k_pack);
94 let mut mat = RowMajorMatrix::new(F::Packing::zero_vec(n * (1 << (k - k_pack))), n);
95
96 if k_pack > 0 {
97 // Packing phase: build a scalar 2^{k_pack}-row power table
98 // per column and pack it into one SIMD element.
99 init_vars
100 .transpose()
101 .row_slices()
102 .zip(mat.values.iter_mut())
103 .for_each(|(vars, packed)| {
104 let point = RowMajorMatrixView::new(vars, 1);
105 *packed = *F::Packing::from_slice(&batch_pows(point).values);
106 });
107 } else {
108 // No packing needed: WIDTH = 1, seed row is all ones.
109 mat.row_mut(0).fill(F::Packing::ONE);
110 }
111
112 // Butterfly phase: same expansion as the scalar version,
113 // but each multiply operates on WIDTH lanes simultaneously.
114 rest_vars.row_slices().enumerate().for_each(|(i, vars)| {
115 let (lo, mut hi) = mat.split_rows_mut(1 << i);
116 lo.rows().zip(hi.rows_mut()).for_each(|(lo, hi)| {
117 vars.iter()
118 .zip(lo.zip(hi.iter_mut()))
119 .for_each(|(&var, (lo, hi))| *hi = lo * var);
120 });
121 });
122 mat
123}
124
/// A batched system of `select`-based evaluation constraints for multilinear polynomials.
///
/// This struct represents a collection of evaluation constraints of the form `p(z_i) = s_i`
/// for a multilinear polynomial `p` over the Boolean hypercube `{0,1}^k`, where the
/// hypercube evaluations of `p` are read as the coefficients of a univariate polynomial
/// evaluated at `z_i`.
///
/// # The Select Function
///
/// For vectors `X, Y ∈ F^k`, the select function is defined as:
///
/// ```text
/// select(X, Y) = ∏_i (X_i · Y_i + (1 - Y_i))
/// ```
///
/// **Key Property:** When `Y ∈ {0,1}^k` is a Boolean vector and `X = pow(z)`:
///
/// ```text
/// select(pow(z), b) = z^{int(b)}
/// ```
///
/// where `pow(z) = (z, z^2, z^4, ..., z^{2^{k-1}})` and `int(b)` interprets the Boolean
/// vector `b` as an integer in binary (bit `i` of `int(b)` is `b_i`).
///
/// **Derivation:**
/// ```text
/// select(pow(z), b) = ∏_i (z^{2^i} · b_i + (1 - b_i))
///                   = ∏_{i: b_i=1} (z^{2^i})          [since b_i ∈ {0,1}]
///                   = z^{Σ_{i: b_i=1} 2^i}
///                   = z^{int(b)}
/// ```
///
/// # Verification Claims
///
/// Each constraint `(z_i, s_i)` in this statement asserts:
///
/// ```text
/// Σ_{b ∈ {0,1}^k} P(b) · select(pow(z_i), b) = s_i
/// ```
///
/// where `P(b)` are the evaluations of the polynomial over the Boolean hypercube.
///
/// # Batching
///
/// Multiple constraints are batched using random challenge `γ` to produce:
///
/// - **Weight polynomial**: `W(b) = Σ_i γ^i · select(pow(z_i), b)`
/// - **Target sum**: `S = Σ_i γ^i · s_i`
///
/// This reduces `n` separate verification claims to a single sumcheck:
///
/// ```text
/// Σ_{b ∈ {0,1}^k} P(b) · W(b) = S
/// ```
///
/// # Invariant
///
/// `vars` and `evaluations` always have the same length: the constructor asserts it
/// and constraints are only ever appended in pairs.
#[derive(Clone, Debug)]
pub struct SelectStatement<F, EF> {
    /// Number of variables `k` defining the Boolean hypercube `{0,1}^k`.
    ///
    /// This determines the dimension of the multilinear polynomial space and the size
    /// of the evaluation domain (2^k points).
    num_variables: usize,

    /// Evaluation points `[z_1, z_2, ..., z_n]` where each constraint checks `p(z_i) = s_i`.
    ///
    /// Each `z_i ∈ F` is a base field element. The `pow` map will expand it to
    /// `pow(z_i) = (z_i, z_i^2, z_i^4, ..., z_i^{2^{k-1}})` for the select function.
    pub(crate) vars: Vec<F>,

    /// Expected evaluation values `[s_1, s_2, ..., s_n]` corresponding to each constraint.
    ///
    /// Each `s_i ∈ EF` is an extension field element representing the claimed evaluation
    /// of the polynomial at point `z_i`. Entry `i` pairs with `vars[i]`.
    evaluations: Vec<EF>,
}
197
198impl<F: Field, EF: ExtensionField<F>> SelectStatement<F, EF> {
    /// Creates an empty select statement for polynomials over `{0,1}^k`.
    ///
    /// # Parameters
    ///
    /// - `num_variables`: The dimension `k` of the Boolean hypercube
    ///
    /// # Returns
    ///
    /// A statement holding no constraints; both the point list and the
    /// evaluation list start out empty.
    #[must_use]
    pub const fn initialize(num_variables: usize) -> Self {
        Self {
            num_variables,
            vars: Vec::new(),
            evaluations: Vec::new(),
        }
    }
216
    /// Creates a select statement pre-populated with constraints.
    ///
    /// # Parameters
    ///
    /// - `num_variables`: The dimension `k` of the Boolean hypercube
    /// - `vars`: Evaluation points `[z_1, ..., z_n]`
    /// - `evaluations`: Expected values `[s_1, ..., s_n]`
    ///
    /// # Panics
    ///
    /// Panics if the number of evaluation points differs from the number of
    /// expected values, i.e. `vars.len() != evaluations.len()`.
    #[must_use]
    pub const fn new(num_variables: usize, vars: Vec<F>, evaluations: Vec<EF>) -> Self {
        // Every point needs exactly one claimed value.
        assert!(vars.len() == evaluations.len());
        Self {
            num_variables,
            vars,
            evaluations,
        }
    }
237
    /// Returns the number of variables `k` defining the polynomial space dimension.
    ///
    /// This is the dimension of the Boolean hypercube `{0,1}^k` over which polynomials
    /// are defined, containing `2^k` evaluation points. The value is the one supplied
    /// at construction time.
    #[must_use]
    pub const fn num_variables(&self) -> usize {
        self.num_variables
    }
246
    /// Returns `true` if no constraints have been added to this statement.
    ///
    /// In debug builds this also checks that the point list and the evaluation
    /// list agree on being empty.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        debug_assert!(self.vars.is_empty() == self.evaluations.is_empty());
        self.vars.is_empty()
    }
253
254 /// Returns an iterator over constraint pairs `(z_i, s_i)`.
255 ///
256 /// Each pair represents one evaluation constraint: `p(z_i) = s_i`.
257 pub fn iter(&self) -> impl Iterator<Item = (&F, &EF)> {
258 self.vars.iter().zip(self.evaluations.iter())
259 }
260
    /// Returns the number of evaluation constraints `n` in this statement.
    ///
    /// In debug builds this also checks that the point list and the evaluation
    /// list have the same length.
    #[must_use]
    pub const fn len(&self) -> usize {
        debug_assert!(self.vars.len() == self.evaluations.len());
        self.vars.len()
    }
267
268 /// Verifies that a given polynomial satisfies all constraints in the statement.
269 ///
270 /// For each constraint `(z_i, s_i)`, this method interprets the evaluation table as
271 /// coefficients of a univariate polynomial, evaluates it at `z_i` using Horner's method,
272 /// and checks if the result equals the expected value `s_i`.
273 ///
274 /// For a polynomial represented by evaluations `[c_0, c_1, ..., c_{2^k-1}]`:
275 ///
276 /// ```text
277 /// p(z) = c_0 + z(c_1 + z(c_2 + z(...)))
278 /// ```
279 ///
280 /// This is computed right-to-left as:
281 /// ```text
282 /// acc = 0
283 /// for i = 2^k-1 down to 0:
284 /// acc = acc * z + c_i
285 /// ```
286 ///
287 /// # Parameters
288 ///
289 /// - `poly`: Evaluation table treated as univariate polynomial coefficients
290 ///
291 /// # Returns
292 ///
293 /// `true` if all constraints are satisfied, `false` otherwise.
294 #[must_use]
295 pub fn verify(&self, poly: &Poly<EF>) -> bool {
296 self.iter().all(|(&var, &expected_eval)| {
297 // Evaluate the polynomial at `var` using Horner's method.
298 // This computes: p(var) = c_0 + var(c_1 + var(c_2 + ...))
299 poly.iter().copied().horner::<EF, _>(var) == expected_eval
300 })
301 }
302
    /// Adds a single evaluation constraint `p(z) = s` to the statement.
    ///
    /// The point and its claimed value are appended together, so the two
    /// internal lists stay the same length.
    ///
    /// # Parameters
    ///
    /// - `var`: Evaluation point `z ∈ F`
    /// - `eval`: Expected evaluation value `s ∈ EF`
    pub fn add_constraint(&mut self, var: F, eval: EF) {
        self.vars.push(var);
        self.evaluations.push(eval);
    }
313
314 /// Batches all constraints into a single weighted polynomial and target sum for sumcheck.
315 ///
316 /// Given constraints `p(z_1) = s_1, ..., p(z_n) = s_n`, this method transforms them into
317 /// a single sumcheck claim using random challenge `γ`:
318 ///
319 /// ```text
320 /// Σ_{b ∈ {0,1}^k} P(b) · W(b) = S
321 /// ```
322 ///
323 /// where:
324 /// - **Weight polynomial**: `W(b) = Σ_i γ^{i+shift} · select(pow(z_i), b)`
325 /// - **Target sum**: `S = Σ_i γ^{i+shift} · s_i`
326 ///
327 /// The method computes `W(b)` for all `b ∈ {0,1}^k` and `S`, adding them to the
328 /// provided accumulators.
329 ///
330 /// # Parameters
331 ///
332 /// - `acc_weights`: Accumulator for the weight polynomial `W(b)`. Must have `2^k` entries.
333 /// This method **adds** the batched weights to existing values.
334 ///
335 /// - `acc_sum`: Accumulator for the target sum `S`. This method **adds** the batched
336 /// evaluations to the existing value.
337 ///
338 /// - `challenge`: Random challenge `γ ∈ EF` used for batching.
339 ///
340 /// - `shift`: Power offset for challenge. Constraint `i` uses weight `γ^{i+shift}`.
341 /// Allows multiple statement types to use non-overlapping challenge powers.
342 /// Batches all constraints into a single weighted polynomial and target sum for sumcheck.
343 ///
344 /// # Algorithm
345 ///
346 /// Three stages:
347 ///
348 /// 1. **Power map**: Build a `k × n` matrix where row `i`, column `j`
349 /// holds `z_j^{2^i}`. Stored as a flat row-major buffer so each
350 /// butterfly step reads a contiguous row (cache-friendly).
351 ///
352 /// 2. **Butterfly expansion**: Expand the power map into the full
353 /// `2^k × n` select matrix using the same binary-tree doubling as
354 /// the scalar power table. Entry `[b, j] = z_j^b`.
355 ///
356 /// 3. **Challenge combination**: Dot each row of the select matrix
357 /// with the challenge power vector to produce the weight polynomial.
358 #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
359 pub fn combine(
360 &self,
361 acc_weights: &mut Poly<EF>,
362 acc_sum: &mut EF,
363 challenge: EF,
364 shift: usize,
365 ) {
366 // Early return for empty statement:
367 //
368 // No constraints means no contribution to the batched claim.
369 if self.vars.is_empty() {
370 return;
371 }
372
373 // Extract dimensions for clarity.
374 //
375 // Number of constraints
376 let n = self.len();
377 // Dimension of Boolean hypercube
378 let k = self.num_variables();
379
380 // ---------------------------------------------------------------
381 // Stage 1: Build the k × n power-of-two matrix.
382 // ---------------------------------------------------------------
383 //
384 // Row i contains [z_1^{2^i}, z_2^{2^i}, ..., z_n^{2^i}].
385 // Stored as a flat Vec<F> of size k * n in row-major order.
386 let mut pow_matrix = F::zero_vec(k * n);
387 for (j, &var) in self.vars.iter().enumerate() {
388 let mut v = var;
389 for i in 0..k {
390 // pow_matrix[i * n + j] = z_j^{2^i}
391 pow_matrix[i * n + j] = v;
392 v = v.square();
393 }
394 }
395
396 // ---------------------------------------------------------------
397 // Stage 2: Butterfly expansion into the 2^k × n select matrix.
398 // ---------------------------------------------------------------
399 //
400 // After iteration i, the first 2^{i+1} rows are filled.
401 // Entry [b, j] = z_j^b.
402 let mut acc = F::zero_vec((1 << k) * n);
403
404 // Base case: z_j^0 = 1 for all j.
405 acc[..n].fill(F::ONE);
406
407 for i in 0..k {
408 let num_existing_rows = 1 << i;
409 let (lo, hi) = acc.split_at_mut(num_existing_rows * n);
410
411 // Contiguous row slice — no per-iteration allocation.
412 let pow_row = &pow_matrix[i * n..(i + 1) * n];
413
414 // For each existing row, compute the new row:
415 // acc[b + 2^i, j] = acc[b, j] * z_j^{2^i}
416 lo.par_chunks_mut(n)
417 .zip(hi.par_chunks_mut(n))
418 .for_each(|(lo_row, hi_row)| {
419 pow_row
420 .iter()
421 .zip(lo_row.iter())
422 .zip(hi_row.iter_mut())
423 .for_each(|((&z_pow, &lo_val), hi_val)| {
424 *hi_val = lo_val * z_pow;
425 });
426 });
427 }
428
429 // ---------------------------------------------------------------
430 // Stage 3: Combine with challenge powers.
431 // ---------------------------------------------------------------
432
433 // Precompute [gamma^shift, gamma^{shift+1}, ..., gamma^{shift+n-1}].
434 let challenges = challenge
435 .shifted_powers(challenge.exp_u64(shift as u64))
436 .collect_n(n);
437
438 // W(b) += sum_i gamma^{i+shift} * z_i^b
439 acc.par_chunks(n)
440 .zip(acc_weights.as_mut_slice().par_iter_mut())
441 .for_each(|(row, weight_out)| {
442 *weight_out +=
443 dot_product::<EF, _, _>(challenges.iter().copied(), row.iter().copied());
444 });
445
446 // S += sum_i gamma^{i+shift} * s_i
447 *acc_sum +=
448 dot_product::<EF, _, _>(challenges.into_iter(), self.evaluations.iter().copied());
449 }
450
451 /// SIMD-packed variant of constraint batching.
452 ///
453 /// # Overview
454 ///
455 /// Produces the same result as the scalar version, but stores the
456 /// weight polynomial in packed form (one SIMD element per
457 /// `Packing::WIDTH` consecutive hypercube entries).
458 ///
459 /// # Algorithm
460 ///
461 /// For small `k` (where `2 * k_pack > k`), falls back to a naive
462 /// per-constraint loop using shifted powers.
463 ///
464 /// For larger `k`, uses the split-and-dot approach:
465 ///
466 /// 1. Expand each evaluation point into its power-map representation.
467 /// 2. Transpose into a `k × n` matrix and split at `k / 2`.
468 /// 3. Build the packed left-half power table and the scalar right-half
469 /// power table.
470 /// 4. For each pair of rows (left packed, right scalar), compute the
471 /// weighted dot product with the challenge powers.
472 #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
473 pub fn combine_packed(
474 &self,
475 weights: &mut Poly<EF::ExtensionPacking>,
476 sum: &mut EF,
477 challenge: EF,
478 shift: usize,
479 ) {
480 if self.vars.is_empty() {
481 return;
482 }
483
484 let n = self.len();
485 let k = self.num_variables();
486 let k_pack = log2_strict_usize(F::Packing::WIDTH);
487 assert!(k >= k_pack);
488 assert_eq!(weights.num_variables() + k_pack, k);
489
490 // Accumulate the scalar target sum first.
491 self.combine_evals(sum, challenge, shift);
492
493 // Naive fallback: when there aren't enough variables for the
494 // split approach, compute shifted powers directly per constraint.
495 if k_pack * 2 > k {
496 self.vars
497 .iter()
498 .zip(challenge.shifted_powers(challenge.exp_u64(shift as u64)))
499 .for_each(|(&var, challenge)| {
500 // gamma^{shift+i} * [1, z, z^2, ..., z^{2^k - 1}]
501 let pow = EF::from(var).shifted_powers(challenge).collect_n(1 << k);
502 weights
503 .as_mut_slice()
504 .iter_mut()
505 .zip_eq(pow.chunks(F::Packing::WIDTH))
506 .for_each(|(out, chunk)| {
507 *out += EF::ExtensionPacking::from_ext_slice(chunk);
508 });
509 });
510 return;
511 }
512
513 // Split approach: expand each var into its power-map form,
514 // transpose, and split into left (packed) and right (scalar) halves.
515 let points = self
516 .vars
517 .iter()
518 .map(|&var| Point::expand_from_univariate(var, k))
519 .collect::<Vec<_>>();
520 let points = Point::transpose(&points, true);
521 let (left, right) = points.split_rows(k / 2);
522
523 // Left half → packed power table (operates in SIMD lanes).
524 let left = packed_batch_pows(left);
525 // Right half → scalar power table.
526 let right = batch_pows(right);
527
528 // Broadcast challenge powers into packed form for dot products.
529 let alphas = challenge
530 .shifted_powers(challenge.exp_u64(shift as u64))
531 .collect_n(n)
532 .into_iter()
533 .map(EF::ExtensionPacking::from)
534 .collect::<Vec<_>>();
535
536 // For each right-half row, dot all left-half rows against it
537 // (weighted by the challenge powers) and accumulate into the
538 // packed weight polynomial.
539 weights
540 .as_mut_slice()
541 .par_chunks_mut(left.height())
542 .zip(right.par_row_slices())
543 .for_each(|(out, right)| {
544 out.iter_mut().zip(left.rows()).for_each(|(out, left)| {
545 *out += left
546 .zip(right.iter())
547 .zip(alphas.iter())
548 .map(|((left, &right), &alpha)| alpha * (left * right))
549 .sum::<EF::ExtensionPacking>();
550 });
551 });
552 }
553
554 /// Batches expected evaluation values into a single target sum using challenge powers.
555 ///
556 /// Computes and adds to `claimed_eval`:
557 ///
558 /// ```text
559 /// S = Σ_i γ^{i+shift} · s_i
560 /// ```
561 ///
562 /// where `s_i` are the expected evaluation values in `self.evaluations`.
563 ///
564 /// # Parameters
565 ///
566 /// - `claimed_eval`: Accumulator for the target sum. This method **adds** the batched
567 /// evaluations to the existing value.
568 ///
569 /// - `challenge`: Random challenge `γ ∈ EF` used for batching.
570 ///
571 /// - `shift`: Power offset. Constraint `i` uses weight `γ^{i+shift}`.
572 pub fn combine_evals(&self, claimed_eval: &mut EF, challenge: EF, shift: usize) {
573 // Compute: Σ_i γ^{i+shift} · s_i
574 // This is equivalent to dot_product(evaluations, [γ^shift, γ^{shift+1}, ...])
575 *claimed_eval += dot_product::<EF, _, _>(
576 self.evaluations.iter().copied(),
577 challenge
578 .shifted_powers(challenge.exp_u64(shift as u64))
579 .take(self.len()),
580 );
581 }
582}
583
584#[cfg(test)]
585mod tests {
586 use alloc::vec;
587
588 use p3_baby_bear::BabyBear;
589 use p3_field::extension::BinomialExtensionField;
590 use p3_field::{PackedFieldExtension, PrimeCharacteristicRing};
591 use proptest::prelude::*;
592 use rand::rngs::SmallRng;
593 use rand::{RngExt, SeedableRng};
594
595 use super::*;
596
597 type F = BabyBear;
598 type EF = BinomialExtensionField<F, 4>;
599
600 #[test]
601 fn test_select_statement_initialize() {
602 // Test that initialize creates an empty statement with correct num_variables.
603 let statement = SelectStatement::<F, F>::initialize(3);
604
605 // The statement should have 3 variables.
606 assert_eq!(statement.num_variables(), 3);
607 // The statement should be empty (no constraints).
608 assert!(statement.is_empty());
609 // The length should be 0.
610 assert_eq!(statement.len(), 0);
611 }
612
613 #[test]
614 fn test_select_statement_new() {
615 // Test that new creates a statement with pre-populated constraints.
616 let vars = vec![F::from_u64(5), F::from_u64(7)];
617 let evaluations = vec![F::from_u64(10), F::from_u64(20)];
618
619 let statement = SelectStatement::new(2, vars.clone(), evaluations.clone());
620
621 // The statement should have 2 variables.
622 assert_eq!(statement.num_variables(), 2);
623 // The statement should not be empty.
624 assert!(!statement.is_empty());
625 // The statement should have 2 constraints.
626 assert_eq!(statement.len(), 2);
627 // The vars and evaluations should match.
628 assert_eq!(statement.vars, vars);
629 assert_eq!(statement.evaluations, evaluations);
630 }
631
632 #[test]
633 #[should_panic(expected = "assertion")]
634 fn test_select_statement_new_mismatched_lengths() {
635 // Test that new panics when vars.len() != evaluations.len().
636 let vars = vec![F::from_u64(5)];
637 let evaluations = vec![F::from_u64(10), F::from_u64(20)];
638
639 // This should panic due to length mismatch.
640 let _ = SelectStatement::new(2, vars, evaluations);
641 }
642
643 #[test]
644 fn test_select_statement_add_constraint() {
645 // Test adding constraints one at a time.
646 let mut statement = SelectStatement::<F, F>::initialize(2);
647
648 // Initially empty.
649 assert!(statement.is_empty());
650 assert_eq!(statement.len(), 0);
651
652 // Add first constraint: p(5) = 10.
653 statement.add_constraint(F::from_u64(5), F::from_u64(10));
654 assert!(!statement.is_empty());
655 assert_eq!(statement.len(), 1);
656
657 // Add second constraint: p(7) = 20.
658 statement.add_constraint(F::from_u64(7), F::from_u64(20));
659 assert_eq!(statement.len(), 2);
660
661 // Verify the constraints were added correctly.
662 let constraints: Vec<_> = statement.iter().collect();
663 assert_eq!(constraints.len(), 2);
664 assert_eq!(*constraints[0].0, F::from_u64(5));
665 assert_eq!(*constraints[0].1, F::from_u64(10));
666 assert_eq!(*constraints[1].0, F::from_u64(7));
667 assert_eq!(*constraints[1].1, F::from_u64(20));
668 }
669
670 #[test]
671 fn test_select_statement_verify_basic() {
672 // Test the verify method with a simple polynomial.
673 //
674 // Create a polynomial with evaluations [c0, c1, c2, c3] over {0,1}^2.
675 let c0 = F::from_u64(1);
676 let c1 = F::from_u64(2);
677 let c2 = F::from_u64(3);
678 let c3 = F::from_u64(4);
679 let poly = Poly::new(vec![c0, c1, c2, c3]);
680
681 // Create a statement with k=2 variables.
682 let k = 2;
683 let mut statement = SelectStatement::<F, F>::initialize(k);
684
685 // The polynomial evaluations [c0, c1, c2, c3] can be interpreted as a univariate polynomial:
686 // p(z) = c0 + c1*z + c2*z^2 + c3*z^3
687 //
688 // Test p(0) = c0 = 1.
689 let z0 = F::ZERO;
690 let eval0 = c0;
691 statement.add_constraint(z0, eval0);
692 assert!(statement.verify(&poly));
693
694 // Test p(1) = c0 + c1 + c2 + c3
695 let mut statement2 = SelectStatement::<F, F>::initialize(k);
696 let z1 = F::ONE;
697 let eval1 = c0 + c1 + c2 + c3;
698 statement2.add_constraint(z1, eval1);
699 assert!(statement2.verify(&poly));
700
701 // Test p(2) = c0 + c1*2 + c2*4 + c3*8
702 let mut statement3 = SelectStatement::<F, F>::initialize(k);
703 let z2 = F::from_u64(2);
704 let eval2 = c0 + c1 * z2 + c2 * z2 * z2 + c3 * z2 * z2 * z2;
705 statement3.add_constraint(z2, eval2);
706 assert!(statement3.verify(&poly));
707
708 // Test a failing verification: p(1) = wrong_eval
709 let mut statement4 = SelectStatement::<F, F>::initialize(k);
710 let wrong_eval = F::from_u64(56765);
711 statement4.add_constraint(z1, wrong_eval);
712 assert!(!statement4.verify(&poly));
713 }
714
715 #[test]
716 fn test_select_statement_combine_single_constraint() {
717 // Test combining a single constraint.
718 //
719 // For k=2 variables, we have a 2^2 = 4-point domain.
720 let k = 2;
721
722 // Create a statement with one constraint: p(z) = s.
723 let mut statement = SelectStatement::<F, F>::initialize(k);
724 let z = F::from_u64(5);
725 let s = F::from_u64(100);
726 statement.add_constraint(z, s);
727
728 // The challenge γ is unused for a single constraint (it would multiply by γ^0 = 1).
729 let gamma = F::from_u64(2);
730 let shift = 0;
731
732 // Initialize accumulators.
733 let mut acc_weights = Poly::zero(k);
734 let mut acc_sum = F::ZERO;
735
736 // Combine the constraints.
737 statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
738
739 // The target sum should be S = γ^0 · s = 1 · s = s.
740 let expected_sum = s;
741 assert_eq!(acc_sum, expected_sum);
742
743 // The weight polynomial should be W(b) = select(pow(z), b) for all b ∈ {0,1}^k.
744 //
745 // Verify each entry manually using the property: select(pow(z), b) = z^b.
746 for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
747 let expected_weight = z.exp_u64(b as u64);
748 assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
749 }
750 }
751
752 #[test]
753 fn test_select_statement_combine_multiple_constraints() {
754 // Test combining multiple constraints with batching.
755 //
756 // For k=2 variables, we have a 2^2 = 4-point domain.
757 let k = 2;
758
759 // Create a statement with two constraints:
760 // - Constraint 0: p(z0) = s0
761 // - Constraint 1: p(z1) = s1
762 let mut statement = SelectStatement::<F, F>::initialize(k);
763 let z0 = F::from_u64(3);
764 let s0 = F::from_u64(10);
765 let z1 = F::from_u64(7);
766 let s1 = F::from_u64(20);
767 statement.add_constraint(z0, s0);
768 statement.add_constraint(z1, s1);
769
770 // Use challenge γ for batching.
771 let gamma = F::from_u64(2);
772 let shift = 0;
773
774 // Initialize accumulators.
775 let mut acc_weights = Poly::zero(k);
776 let mut acc_sum = F::ZERO;
777
778 // Combine the constraints.
779 statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
780
781 // The target sum should be:
782 // S = γ^0 · s0 + γ^1 · s1 = 1·s0 + γ·s1 = s0 + gamma*s1.
783 let expected_sum = s0 + gamma * s1;
784 assert_eq!(acc_sum, expected_sum);
785
786 // The weight polynomial should be:
787 // W(b) = γ^0 · select(pow(z0), b) + γ^1 · select(pow(z1), b)
788 // = select(pow(z0), b) + gamma · select(pow(z1), b)
789 // Using the property: select(pow(z), b) = z^b.
790 for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
791 let weight0 = z0.exp_u64(b as u64);
792 let weight1 = z1.exp_u64(b as u64);
793 let expected_weight = weight0 + gamma * weight1;
794 assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
795 }
796 }
797
798 #[test]
799 fn test_select_statement_combine_with_shift() {
800 // Test combining constraints with a non-zero shift parameter.
801 //
802 // The shift parameter allows multiple statement types to use non-overlapping
803 // challenge powers for batching.
804 let k = 1;
805
806 // Create a statement with one constraint: p(z) = s.
807 let mut statement = SelectStatement::<F, F>::initialize(k);
808 let z = F::from_u64(5);
809 let s = F::from_u64(100);
810 statement.add_constraint(z, s);
811
812 // Use challenge γ with shift.
813 // This means the constraint will be weighted by γ^{0+shift} = γ^shift.
814 let gamma = F::from_u64(2);
815 let shift = 3;
816
817 // Initialize accumulators.
818 let mut acc_weights = Poly::zero(k);
819 let mut acc_sum = F::ZERO;
820
821 // Combine the constraints.
822 statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
823
824 // The target sum should be S = γ^shift · s.
825 let gamma_to_shift = gamma.exp_u64(shift as u64);
826 let expected_sum = gamma_to_shift * s;
827 assert_eq!(acc_sum, expected_sum);
828
829 // The weight polynomial should be W(b) = γ^shift · select(pow(z), b).
830 // Using the property: select(pow(z), b) = z^b.
831 for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
832 let select_val = z.exp_u64(b as u64);
833 let expected_weight = gamma_to_shift * select_val;
834 assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
835 }
836 }
837
838 #[test]
839 fn test_select_statement_combine_empty() {
840 // Test that combining an empty statement does nothing.
841 let k = 2;
842 let statement = SelectStatement::<F, F>::initialize(k);
843
844 // Initialize accumulators with non-zero values.
845 let w0 = F::from_u64(1);
846 let w1 = F::from_u64(2);
847 let w2 = F::from_u64(3);
848 let w3 = F::from_u64(4);
849 let mut acc_weights = Poly::new(vec![w0, w1, w2, w3]);
850 let initial_sum = F::from_u64(99);
851 let mut acc_sum = initial_sum;
852
853 // Store original values.
854 let original_weights = acc_weights.clone();
855 let original_sum = acc_sum;
856
857 // Combine the empty statement.
858 let gamma = F::from_u64(2);
859 let shift = 0;
860 statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
861
862 // The accumulators should remain unchanged.
863 assert_eq!(acc_weights, original_weights);
864 assert_eq!(acc_sum, original_sum);
865 }
866
867 #[test]
868 fn test_select_statement_combine_accumulation() {
869 // Test that combine properly accumulates (adds to) existing values.
870 //
871 // This is important for batching multiple statements together.
872 let k = 1;
873
874 // Create first statement with constraint p(z1) = s1.
875 let mut statement1 = SelectStatement::<F, F>::initialize(k);
876 let z1 = F::from_u64(2);
877 let s1 = F::from_u64(5);
878 statement1.add_constraint(z1, s1);
879
880 // Create second statement with constraint p(z2) = s2.
881 let mut statement2 = SelectStatement::<F, F>::initialize(k);
882 let z2 = F::from_u64(3);
883 let s2 = F::from_u64(7);
884 statement2.add_constraint(z2, s2);
885
886 let gamma = F::from_u64(2);
887 let shift = 0;
888
889 // Initialize accumulators.
890 let mut acc_weights = Poly::zero(k);
891 let mut acc_sum = F::ZERO;
892
893 // Combine first statement.
894 statement1.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
895
896 // Store intermediate values.
897 let intermediate_weights = acc_weights.clone();
898 let intermediate_sum = acc_sum;
899
900 // Combine second statement (should add to existing values).
901 statement2.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
902
903 // The accumulated sum should be intermediate_sum + s2.
904 let expected_sum = intermediate_sum + s2;
905 assert_eq!(acc_sum, expected_sum);
906
907 // The accumulated weights should be the sum of both select functions.
908 // Using the property: select(pow(z), b) = z^b.
909 let domain_size = 1 << k;
910 for b in 0..domain_size {
911 let weight2 = z2.exp_u64(b as u64);
912 let expected_weight = intermediate_weights.as_slice()[b] + weight2;
913 assert_eq!(
914 acc_weights.as_slice()[b],
915 expected_weight,
916 "Accumulated weight mismatch at index {b}"
917 );
918 }
919 }
920
#[test]
fn test_select_statement_combine_evals() {
    // Exercise `combine_evals` on a two-constraint statement with a non-zero shift.
    let num_vars = 2;
    let (eval_a, eval_b) = (F::from_u64(10), F::from_u64(20));

    let mut statement = SelectStatement::<F, F>::initialize(num_vars);
    statement.add_constraint(F::from_u64(3), eval_a);
    statement.add_constraint(F::from_u64(7), eval_b);

    let gamma = F::from_u64(2);
    let shift = 1;

    // Fold the claimed evaluations into a zero-initialised accumulator.
    let mut claimed_eval = F::ZERO;
    statement.combine_evals(&mut claimed_eval, gamma, shift);

    // The i-th evaluation is expected to be weighted by γ^{shift + i}:
    //   S = γ^1 · eval_a + γ^2 · eval_b.
    let weight_a = gamma.exp_u64(shift as u64);
    let weight_b = gamma.exp_u64((shift + 1) as u64);
    assert_eq!(claimed_eval, weight_a * eval_a + weight_b * eval_b);
}
946
#[test]
fn test_select_statement_combine_evals_accumulation() {
    // `combine_evals` must add onto whatever the accumulator already holds.
    let num_vars = 1;
    let eval = F::from_u64(10);

    let mut statement = SelectStatement::<F, F>::initialize(num_vars);
    statement.add_constraint(F::from_u64(5), eval);

    let gamma = F::from_u64(3);
    let shift = 0;

    // Seed the accumulator with a non-zero starting value.
    let starting_value = F::from_u64(42);
    let mut claimed_eval = starting_value;
    statement.combine_evals(&mut claimed_eval, gamma, shift);

    // With shift = 0 the single evaluation is scaled by γ^0 = 1, so the
    // accumulator should simply grow by `eval`.
    assert_eq!(claimed_eval, starting_value + eval);
}
970
#[test]
fn test_select_combine_consistency_with_verify() {
    // `verify` and `combine` must agree on a statement derived from a concrete
    // polynomial:
    //   1. `verify` accepts the polynomial, and
    //   2. the combined weights reproduce the constrained evaluation.
    let num_vars = 2;

    // Polynomial given by the table [1, 2, 3, 4].
    let poly = Poly::new(vec![
        F::from_u64(1),
        F::from_u64(2),
        F::from_u64(3),
        F::from_u64(4),
    ]);

    // Derive the constraint value from the polynomial itself via Horner's method,
    // so the constraint is satisfied by construction.
    let point = F::from_u64(2);
    let target: F = poly.iter().copied().horner(point);

    let mut statement = SelectStatement::<F, F>::initialize(num_vars);
    statement.add_constraint(point, target);

    // A constraint built from the polynomial must verify against it.
    assert!(statement.verify(&poly));

    // Combine into fresh accumulators.
    let gamma = F::from_u64(3);
    let shift = 0;
    let mut acc_weights = Poly::zero(num_vars);
    let mut acc_sum = F::ZERO;
    statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);

    // A single constraint at shift 0 contributes its evaluation unscaled.
    assert_eq!(acc_sum, target);

    // Inner product of the polynomial table with the weight table:
    //   Σ_b poly(b) · W(b) = target.
    let weighted_sum = poly
        .as_slice()
        .iter()
        .zip(acc_weights.as_slice().iter())
        .fold(F::ZERO, |acc, (&value, &weight)| acc + value * weight);
    assert_eq!(weighted_sum, target);
}
1018
1019 proptest! {
1020 #[test]
1021 fn prop_select_statement_combine_sum(
1022 // Number of variables (1 to 4 for reasonable test size).
1023 k in 1usize..=4,
1024 // Number of constraints (1 to 5).
1025 num_constraints in 1usize..=5,
1026 // Random evaluation points (avoiding 0 for better coverage).
1027 // Generate exactly num_constraints values.
1028 z_values in prop::collection::vec(1u32..100, 1..=5),
1029 // Random expected evaluations.
1030 s_values in prop::collection::vec(0u32..100, 1..=5),
1031 // Random challenge.
1032 challenge in 1u32..50,
1033 ) {
1034 // Ensure we have enough values for the test.
1035 let actual_num_constraints = num_constraints.min(z_values.len()).min(s_values.len());
1036 if actual_num_constraints == 0 {
1037 return Ok(());
1038 }
1039
1040 let z_values = &z_values[..actual_num_constraints];
1041 let s_values = &s_values[..actual_num_constraints];
1042
1043 // Create statement with random constraints.
1044 let mut statement = SelectStatement::<F, F>::initialize(k);
1045 for (&z, &s) in z_values.iter().zip(s_values.iter()) {
1046 statement.add_constraint(F::from_u32(z), F::from_u32(s));
1047 }
1048
1049 let gamma = F::from_u32(challenge);
1050
1051 // Combine with shift=0.
1052 let mut acc_weights = Poly::zero(k);
1053 let mut acc_sum = F::ZERO;
1054 statement.combine(&mut acc_weights, &mut acc_sum, gamma, 0);
1055
1056 // Compute expected sum manually: S = Σ_i γ^i · s_i.
1057 let mut expected_sum = F::ZERO;
1058 for (i, &s) in s_values.iter().enumerate() {
1059 expected_sum += gamma.exp_u64(i as u64) * F::from_u32(s);
1060 }
1061
1062 prop_assert_eq!(acc_sum, expected_sum);
1063 }
1064 }
1065
1066 proptest! {
1067 #[test]
1068 fn prop_select_statement_verify(
1069 // Polynomial evaluations (2^k values for k=3).
1070 poly_evals in prop::collection::vec(0u32..100, 8),
1071 // Evaluation point (avoiding 0 for better coverage).
1072 z in 1u32..50,
1073 ) {
1074 let k = 3; // Fixed k=3 gives 2^3 = 8 evaluations.
1075 let poly = Poly::new(poly_evals.into_iter().map(F::from_u32).collect());
1076
1077 // Compute expected evaluation using Horner's method.
1078 let z_field = F::from_u32(z);
1079 let expected_eval: F = poly.iter().copied().horner(z_field);
1080
1081 // Create statement with correct constraint.
1082 let mut statement = SelectStatement::<F, F>::initialize(k);
1083 statement.add_constraint(z_field, expected_eval);
1084
1085 // Verify should pass.
1086 prop_assert!(statement.verify(&poly));
1087
1088 // Add a wrong constraint (off by 1, unless it wraps to same value).
1089 let wrong_eval = expected_eval + F::ONE;
1090 if wrong_eval != expected_eval {
1091 statement.add_constraint(z_field, wrong_eval);
1092 // Verify should fail now.
1093 prop_assert!(!statement.verify(&poly));
1094 }
1095 }
1096 }
1097
1098 proptest! {
1099 #[test]
1100 fn prop_combine_evals_consistency(
1101 // Number of constraints.
1102 num_constraints in 1usize..=5,
1103 // Random evaluations.
1104 s_values in prop::collection::vec(0u32..100, 1..=5),
1105 // Random challenge.
1106 challenge in 1u32..50,
1107 // Random shift.
1108 shift in 0usize..3,
1109 ) {
1110 let s_values = &s_values[..num_constraints.min(s_values.len())];
1111
1112 // Create statement with arbitrary z values (they don't matter for this test).
1113 let mut statement = SelectStatement::<F, F>::initialize(2);
1114 for &s in s_values {
1115 statement.add_constraint(F::from_u32(1), F::from_u32(s));
1116 }
1117
1118 let gamma = F::from_u32(challenge);
1119
1120 // Method 1: Use combine_evals.
1121 let mut claimed_eval1 = F::ZERO;
1122 statement.combine_evals(&mut claimed_eval1, gamma, shift);
1123
1124 // Method 2: Compute manually.
1125 let mut claimed_eval2 = F::ZERO;
1126 for (i, &s) in s_values.iter().enumerate() {
1127 claimed_eval2 += gamma.exp_u64((i + shift) as u64) * F::from_u32(s);
1128 }
1129
1130 prop_assert_eq!(claimed_eval1, claimed_eval2);
1131 }
1132 }
1133
1134 proptest! {
1135 #[test]
1136 fn prop_packed_combine_roundtrip(
1137 // Number of variables (covers both naive and split paths).
1138 k in 4usize..10,
1139 // Number of select constraints per batch.
1140 n in 1usize..12,
1141 // Challenge power offset.
1142 shift in 0usize..5,
1143 // RNG seed for reproducible randomness.
1144 seed in 0u64..100,
1145 ) {
1146 type PackedExt = <EF as ExtensionField<F>>::ExtensionPacking;
1147
1148 let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
1149 if k < k_pack {
1150 return Ok(());
1151 }
1152
1153 let mut rng = SmallRng::seed_from_u64(seed);
1154 let challenge: EF = rng.random();
1155
1156 // Generate n random evaluation points and expected values.
1157 let vars = (0..n).map(|_| rng.random()).collect::<Vec<F>>();
1158 let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
1159
1160 let statement = SelectStatement::<F, EF>::new(k, vars, evals);
1161
1162 // Scalar path: combine into a 2^k evaluation list.
1163 let mut scalar_weights = Poly::<EF>::zero(k);
1164 let mut scalar_sum = EF::ZERO;
1165 statement.combine(&mut scalar_weights, &mut scalar_sum, challenge, shift);
1166
1167 // Packed path: combine into a 2^{k - k_pack} packed list.
1168 let mut packed_weights = Poly::<PackedExt>::zero(k - k_pack);
1169 let mut packed_sum = EF::ZERO;
1170 statement.combine_packed(&mut packed_weights, &mut packed_sum, challenge, shift);
1171
1172 // Unpack the packed result and compare element-by-element.
1173 let unpacked =
1174 <PackedExt as PackedFieldExtension<F, EF>>::to_ext_iter(
1175 packed_weights.as_slice().iter().copied(),
1176 )
1177 .collect::<Vec<_>>();
1178 prop_assert_eq!(scalar_weights.as_slice(), &unpacked[..]);
1179
1180 // The scalar sums must match exactly.
1181 prop_assert_eq!(scalar_sum, packed_sum);
1182 }
1183
1184 #[test]
1185 fn prop_packed_combine_accumulation(
1186 k in 4usize..10,
1187 seed in 0u64..50,
1188 ) {
1189 type PackedExt = <EF as ExtensionField<F>>::ExtensionPacking;
1190
1191 let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
1192 if k < k_pack {
1193 return Ok(());
1194 }
1195
1196 let mut rng = SmallRng::seed_from_u64(seed);
1197 let challenge: EF = rng.random();
1198
1199 let mut s_wt = Poly::<EF>::zero(k);
1200 let mut p_wt = Poly::<PackedExt>::zero(k - k_pack);
1201 let mut s_sum = EF::ZERO;
1202 let mut p_sum = EF::ZERO;
1203 let mut shift = 0;
1204
1205 // Two batches with different constraint counts.
1206 for n in [3, 7] {
1207 let vars = (0..n).map(|_| rng.random()).collect::<Vec<F>>();
1208 let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
1209 let stmt = SelectStatement::<F, EF>::new(k, vars, evals);
1210
1211 stmt.combine(&mut s_wt, &mut s_sum, challenge, shift);
1212 stmt.combine_packed(&mut p_wt, &mut p_sum, challenge, shift);
1213 shift += stmt.len();
1214 }
1215
1216 // Verify accumulated results match after both batches.
1217 let unpacked =
1218 <PackedExt as PackedFieldExtension<F, EF>>::to_ext_iter(
1219 p_wt.as_slice().iter().copied(),
1220 )
1221 .collect::<Vec<_>>();
1222 prop_assert_eq!(s_wt.as_slice(), &unpacked[..]);
1223 prop_assert_eq!(s_sum, p_sum);
1224 }
1225 }
1226}