// p3_sumcheck/constraints/statement/select.rs

1use alloc::vec::Vec;
2
3use itertools::Itertools;
4use p3_field::{
5    ExtensionField, Field, HornerIter, PackedFieldExtension, PackedValue, PrimeCharacteristicRing,
6    dot_product,
7};
8use p3_matrix::Matrix;
9use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
10use p3_maybe_rayon::prelude::*;
11use p3_multilinear_util::point::Point;
12use p3_multilinear_util::poly::Poly;
13use p3_util::log2_strict_usize;
14use tracing::instrument;
15
16/// Expand powers-of-two seeds into the full power table via butterfly.
17///
18/// # Input
19///
20/// A `k × n` matrix where column `j` holds the squared powers of
21/// variable `v_j` in descending exponent order:
22///
23/// ```text
24/// row 0: [v_1^{2^{k-1}}, v_2^{2^{k-1}}, …, v_n^{2^{k-1}}]
25/// row 1: [v_1^{2^{k-2}}, v_2^{2^{k-2}}, …, v_n^{2^{k-2}}]
26///   ⋮
27/// row k-1: [v_1^1,         v_2^1,         …, v_n^1        ]
28/// ```
29///
30/// # Output
31///
32/// A `2^k × n` matrix where entry `[b, j] = v_j^b` (the full
33/// monomial power, not just a squared power).
34///
35/// # Algorithm
36///
37/// Uses a binary-tree butterfly. After processing row `i` of the
38/// input, the first `2^{i+1}` rows of the output are filled.
39/// Each step copies the existing rows and multiplies by the
40/// current squared power to fill the new rows:
41///
42/// ```text
43/// mat[b + 2^i, j] = mat[b, j] * points[i, j]
44/// ```
45fn batch_pows<F: Field>(points: RowMajorMatrixView<'_, F>) -> RowMajorMatrix<F> {
46    let k = points.height();
47    let n = points.width();
48
49    let mut mat = RowMajorMatrix::new(F::zero_vec(n * (1 << k)), n);
50
51    // Base case: v_j^0 = 1 for all j.
52    mat.row_mut(0).fill(F::ONE);
53
54    // Butterfly expansion: each input row doubles the number of filled rows.
55    points.row_slices().enumerate().for_each(|(i, vars)| {
56        let (lo, mut hi) = mat.split_rows_mut(1 << i);
57        lo.rows().zip(hi.rows_mut()).for_each(|(lo, hi)| {
58            // hi[j] = lo[j] * var[j], extending the power by 2^i.
59            vars.iter()
60                .zip(lo.zip(hi.iter_mut()))
61                .for_each(|(&var, (lo, hi))| *hi = lo * var);
62        });
63    });
64    mat
65}
66
67/// SIMD-packed variant of the power-table butterfly expansion.
68///
69/// # Overview
70///
71/// Splits the `k` input variables into two phases:
72///
73/// 1. **Packing phase** (first `k_pack` variables): Builds a small
74///    scalar power table per column, then packs it into a single
75///    SIMD lane. This fills the first row of the packed output.
76///
77/// 2. **Butterfly phase** (remaining `k - k_pack` variables): Applies
78///    the same butterfly as the scalar version, but operates on
79///    packed elements — multiplying all SIMD lanes in one instruction.
80///
81/// # Output
82///
83/// A `2^{k - k_pack} × n` matrix of packed elements, where
84/// unpacking row `r` column `j` yields the `F::Packing::WIDTH`
85/// consecutive scalar entries `v_j^{r * WIDTH}, …, v_j^{r * WIDTH + WIDTH - 1}`.
86fn packed_batch_pows<F: Field>(points: RowMajorMatrixView<'_, F>) -> RowMajorMatrix<F::Packing> {
87    let k = points.height();
88    let n = points.width();
89    assert_ne!(n, 0);
90    let k_pack = log2_strict_usize(F::Packing::WIDTH);
91    assert!(k >= k_pack);
92
93    let (init_vars, rest_vars) = points.split_rows(k_pack);
94    let mut mat = RowMajorMatrix::new(F::Packing::zero_vec(n * (1 << (k - k_pack))), n);
95
96    if k_pack > 0 {
97        // Packing phase: build a scalar 2^{k_pack}-row power table
98        // per column and pack it into one SIMD element.
99        init_vars
100            .transpose()
101            .row_slices()
102            .zip(mat.values.iter_mut())
103            .for_each(|(vars, packed)| {
104                let point = RowMajorMatrixView::new(vars, 1);
105                *packed = *F::Packing::from_slice(&batch_pows(point).values);
106            });
107    } else {
108        // No packing needed: WIDTH = 1, seed row is all ones.
109        mat.row_mut(0).fill(F::Packing::ONE);
110    }
111
112    // Butterfly phase: same expansion as the scalar version,
113    // but each multiply operates on WIDTH lanes simultaneously.
114    rest_vars.row_slices().enumerate().for_each(|(i, vars)| {
115        let (lo, mut hi) = mat.split_rows_mut(1 << i);
116        lo.rows().zip(hi.rows_mut()).for_each(|(lo, hi)| {
117            vars.iter()
118                .zip(lo.zip(hi.iter_mut()))
119                .for_each(|(&var, (lo, hi))| *hi = lo * var);
120        });
121    });
122    mat
123}
124
/// A batch of evaluation claims expressed through the `select` function.
///
/// The polynomial is given by its `2^k` values `P(b)` on the Boolean
/// hypercube `{0,1}^k`. Each constraint `(z_i, s_i)` claims that this table,
/// read as the coefficient list of a univariate polynomial, evaluates to
/// `s_i` at the point `z_i`.
///
/// # The `select` function
///
/// For `X, Y ∈ F^k`:
///
/// ```text
/// select(X, Y) = ∏_i (X_i · Y_i + (1 - Y_i))
/// ```
///
/// Write `pow(z) = (z, z^2, z^4, ..., z^{2^{k-1}})` and `int(b) = Σ_i b_i · 2^i`
/// for the integer whose binary digits are `b`. For Boolean `b`, the `i`-th
/// factor is `z^{2^i}` when `b_i = 1` and `1` when `b_i = 0`, hence:
///
/// ```text
/// select(pow(z), b) = ∏_{i: b_i = 1} z^{2^i}
///                   = z^{Σ_{i: b_i = 1} 2^i}
///                   = z^{int(b)}
/// ```
///
/// # Claims
///
/// Constraint `(z_i, s_i)` therefore asserts:
///
/// ```text
/// Σ_{b ∈ {0,1}^k} P(b) · select(pow(z_i), b) = s_i
/// ```
///
/// # Batching
///
/// A random challenge `γ` folds all `n` claims into one sumcheck claim
/// `Σ_{b ∈ {0,1}^k} P(b) · W(b) = S`, with
///
/// - weight polynomial `W(b) = Σ_i γ^i · select(pow(z_i), b)`, and
/// - target sum `S = Σ_i γ^i · s_i`.
#[derive(Clone, Debug)]
pub struct SelectStatement<F, EF> {
    /// Hypercube dimension `k`; the polynomial table has `2^k` entries.
    num_variables: usize,

    /// Base-field evaluation points `[z_1, ..., z_n]`, one per constraint.
    ///
    /// When weights are built, each `z_i` is expanded to
    /// `pow(z_i) = (z_i, z_i^2, z_i^4, ..., z_i^{2^{k-1}})`.
    pub(crate) vars: Vec<F>,

    /// Claimed values `[s_1, ..., s_n]` in the extension field, aligned
    /// index-by-index with `vars`.
    evaluations: Vec<EF>,
}
197
198impl<F: Field, EF: ExtensionField<F>> SelectStatement<F, EF> {
199    /// Creates an empty select statement for polynomials over `{0,1}^k`.
200    ///
201    /// # Parameters
202    ///
203    /// - `num_variables`: The dimension `k` of the Boolean hypercube
204    ///
205    /// # Returns
206    ///
207    /// An initialized statement with no constraints, ready to accept constraints.
208    #[must_use]
209    pub const fn initialize(num_variables: usize) -> Self {
210        Self {
211            num_variables,
212            vars: Vec::new(),
213            evaluations: Vec::new(),
214        }
215    }
216
217    /// Creates a select statement pre-populated with constraints.
218    ///
219    /// # Parameters
220    ///
221    /// - `num_variables`: The dimension `k` of the Boolean hypercube
222    /// - `vars`: Evaluation points `[z_1, ..., z_n]`
223    /// - `evaluations`: Expected values `[s_1, ..., s_n]`
224    ///
225    /// # Panics
226    ///
227    /// Panics if the nu
228    #[must_use]
229    pub const fn new(num_variables: usize, vars: Vec<F>, evaluations: Vec<EF>) -> Self {
230        assert!(vars.len() == evaluations.len());
231        Self {
232            num_variables,
233            vars,
234            evaluations,
235        }
236    }
237
238    /// Returns the number of variables `k` defining the polynomial space dimension.
239    ///
240    /// This is the dimension of the Boolean hypercube `{0,1}^k` over which polynomials
241    /// are defined, containing `2^k` evaluation points.
242    #[must_use]
243    pub const fn num_variables(&self) -> usize {
244        self.num_variables
245    }
246
247    /// Returns `true` if no constraints have been added to this statement.
248    #[must_use]
249    pub const fn is_empty(&self) -> bool {
250        debug_assert!(self.vars.is_empty() == self.evaluations.is_empty());
251        self.vars.is_empty()
252    }
253
254    /// Returns an iterator over constraint pairs `(z_i, s_i)`.
255    ///
256    /// Each pair represents one evaluation constraint: `p(z_i) = s_i`.
257    pub fn iter(&self) -> impl Iterator<Item = (&F, &EF)> {
258        self.vars.iter().zip(self.evaluations.iter())
259    }
260
261    /// Returns the number of evaluation constraints `n` in this statement.
262    #[must_use]
263    pub const fn len(&self) -> usize {
264        debug_assert!(self.vars.len() == self.evaluations.len());
265        self.vars.len()
266    }
267
268    /// Verifies that a given polynomial satisfies all constraints in the statement.
269    ///
270    /// For each constraint `(z_i, s_i)`, this method interprets the evaluation table as
271    /// coefficients of a univariate polynomial, evaluates it at `z_i` using Horner's method,
272    /// and checks if the result equals the expected value `s_i`.
273    ///
274    /// For a polynomial represented by evaluations `[c_0, c_1, ..., c_{2^k-1}]`:
275    ///
276    /// ```text
277    /// p(z) = c_0 + z(c_1 + z(c_2 + z(...)))
278    /// ```
279    ///
280    /// This is computed right-to-left as:
281    /// ```text
282    /// acc = 0
283    /// for i = 2^k-1 down to 0:
284    ///     acc = acc * z + c_i
285    /// ```
286    ///
287    /// # Parameters
288    ///
289    /// - `poly`: Evaluation table treated as univariate polynomial coefficients
290    ///
291    /// # Returns
292    ///
293    /// `true` if all constraints are satisfied, `false` otherwise.
294    #[must_use]
295    pub fn verify(&self, poly: &Poly<EF>) -> bool {
296        self.iter().all(|(&var, &expected_eval)| {
297            // Evaluate the polynomial at `var` using Horner's method.
298            // This computes: p(var) = c_0 + var(c_1 + var(c_2 + ...))
299            poly.iter().copied().horner::<EF, _>(var) == expected_eval
300        })
301    }
302
303    /// Adds a single evaluation constraint `p(z) = s` to the statement.
304    ///
305    /// # Parameters
306    ///
307    /// - `var`: Evaluation point `z ∈ F`
308    /// - `eval`: Expected evaluation value `s ∈ EF`
309    pub fn add_constraint(&mut self, var: F, eval: EF) {
310        self.vars.push(var);
311        self.evaluations.push(eval);
312    }
313
314    /// Batches all constraints into a single weighted polynomial and target sum for sumcheck.
315    ///
316    /// Given constraints `p(z_1) = s_1, ..., p(z_n) = s_n`, this method transforms them into
317    /// a single sumcheck claim using random challenge `γ`:
318    ///
319    /// ```text
320    /// Σ_{b ∈ {0,1}^k} P(b) · W(b) = S
321    /// ```
322    ///
323    /// where:
324    /// - **Weight polynomial**: `W(b) = Σ_i γ^{i+shift} · select(pow(z_i), b)`
325    /// - **Target sum**: `S = Σ_i γ^{i+shift} · s_i`
326    ///
327    /// The method computes `W(b)` for all `b ∈ {0,1}^k` and `S`, adding them to the
328    /// provided accumulators.
329    ///
330    /// # Parameters
331    ///
332    /// - `acc_weights`: Accumulator for the weight polynomial `W(b)`. Must have `2^k` entries.
333    ///   This method **adds** the batched weights to existing values.
334    ///
335    /// - `acc_sum`: Accumulator for the target sum `S`. This method **adds** the batched
336    ///   evaluations to the existing value.
337    ///
338    /// - `challenge`: Random challenge `γ ∈ EF` used for batching.
339    ///
340    /// - `shift`: Power offset for challenge. Constraint `i` uses weight `γ^{i+shift}`.
341    ///   Allows multiple statement types to use non-overlapping challenge powers.
342    /// Batches all constraints into a single weighted polynomial and target sum for sumcheck.
343    ///
344    /// # Algorithm
345    ///
346    /// Three stages:
347    ///
348    /// 1. **Power map**: Build a `k × n` matrix where row `i`, column `j`
349    ///    holds `z_j^{2^i}`. Stored as a flat row-major buffer so each
350    ///    butterfly step reads a contiguous row (cache-friendly).
351    ///
352    /// 2. **Butterfly expansion**: Expand the power map into the full
353    ///    `2^k × n` select matrix using the same binary-tree doubling as
354    ///    the scalar power table. Entry `[b, j] = z_j^b`.
355    ///
356    /// 3. **Challenge combination**: Dot each row of the select matrix
357    ///    with the challenge power vector to produce the weight polynomial.
358    #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
359    pub fn combine(
360        &self,
361        acc_weights: &mut Poly<EF>,
362        acc_sum: &mut EF,
363        challenge: EF,
364        shift: usize,
365    ) {
366        // Early return for empty statement:
367        //
368        // No constraints means no contribution to the batched claim.
369        if self.vars.is_empty() {
370            return;
371        }
372
373        // Extract dimensions for clarity.
374        //
375        // Number of constraints
376        let n = self.len();
377        // Dimension of Boolean hypercube
378        let k = self.num_variables();
379
380        // ---------------------------------------------------------------
381        // Stage 1: Build the k × n power-of-two matrix.
382        // ---------------------------------------------------------------
383        //
384        // Row i contains [z_1^{2^i}, z_2^{2^i}, ..., z_n^{2^i}].
385        // Stored as a flat Vec<F> of size k * n in row-major order.
386        let mut pow_matrix = F::zero_vec(k * n);
387        for (j, &var) in self.vars.iter().enumerate() {
388            let mut v = var;
389            for i in 0..k {
390                // pow_matrix[i * n + j] = z_j^{2^i}
391                pow_matrix[i * n + j] = v;
392                v = v.square();
393            }
394        }
395
396        // ---------------------------------------------------------------
397        // Stage 2: Butterfly expansion into the 2^k × n select matrix.
398        // ---------------------------------------------------------------
399        //
400        // After iteration i, the first 2^{i+1} rows are filled.
401        // Entry [b, j] = z_j^b.
402        let mut acc = F::zero_vec((1 << k) * n);
403
404        // Base case: z_j^0 = 1 for all j.
405        acc[..n].fill(F::ONE);
406
407        for i in 0..k {
408            let num_existing_rows = 1 << i;
409            let (lo, hi) = acc.split_at_mut(num_existing_rows * n);
410
411            // Contiguous row slice — no per-iteration allocation.
412            let pow_row = &pow_matrix[i * n..(i + 1) * n];
413
414            // For each existing row, compute the new row:
415            //   acc[b + 2^i, j] = acc[b, j] * z_j^{2^i}
416            lo.par_chunks_mut(n)
417                .zip(hi.par_chunks_mut(n))
418                .for_each(|(lo_row, hi_row)| {
419                    pow_row
420                        .iter()
421                        .zip(lo_row.iter())
422                        .zip(hi_row.iter_mut())
423                        .for_each(|((&z_pow, &lo_val), hi_val)| {
424                            *hi_val = lo_val * z_pow;
425                        });
426                });
427        }
428
429        // ---------------------------------------------------------------
430        // Stage 3: Combine with challenge powers.
431        // ---------------------------------------------------------------
432
433        // Precompute [gamma^shift, gamma^{shift+1}, ..., gamma^{shift+n-1}].
434        let challenges = challenge
435            .shifted_powers(challenge.exp_u64(shift as u64))
436            .collect_n(n);
437
438        // W(b) += sum_i gamma^{i+shift} * z_i^b
439        acc.par_chunks(n)
440            .zip(acc_weights.as_mut_slice().par_iter_mut())
441            .for_each(|(row, weight_out)| {
442                *weight_out +=
443                    dot_product::<EF, _, _>(challenges.iter().copied(), row.iter().copied());
444            });
445
446        // S += sum_i gamma^{i+shift} * s_i
447        *acc_sum +=
448            dot_product::<EF, _, _>(challenges.into_iter(), self.evaluations.iter().copied());
449    }
450
451    /// SIMD-packed variant of constraint batching.
452    ///
453    /// # Overview
454    ///
455    /// Produces the same result as the scalar version, but stores the
456    /// weight polynomial in packed form (one SIMD element per
457    /// `Packing::WIDTH` consecutive hypercube entries).
458    ///
459    /// # Algorithm
460    ///
461    /// For small `k` (where `2 * k_pack > k`), falls back to a naive
462    /// per-constraint loop using shifted powers.
463    ///
464    /// For larger `k`, uses the split-and-dot approach:
465    ///
466    /// 1. Expand each evaluation point into its power-map representation.
467    /// 2. Transpose into a `k × n` matrix and split at `k / 2`.
468    /// 3. Build the packed left-half power table and the scalar right-half
469    ///    power table.
470    /// 4. For each pair of rows (left packed, right scalar), compute the
471    ///    weighted dot product with the challenge powers.
472    #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
473    pub fn combine_packed(
474        &self,
475        weights: &mut Poly<EF::ExtensionPacking>,
476        sum: &mut EF,
477        challenge: EF,
478        shift: usize,
479    ) {
480        if self.vars.is_empty() {
481            return;
482        }
483
484        let n = self.len();
485        let k = self.num_variables();
486        let k_pack = log2_strict_usize(F::Packing::WIDTH);
487        assert!(k >= k_pack);
488        assert_eq!(weights.num_variables() + k_pack, k);
489
490        // Accumulate the scalar target sum first.
491        self.combine_evals(sum, challenge, shift);
492
493        // Naive fallback: when there aren't enough variables for the
494        // split approach, compute shifted powers directly per constraint.
495        if k_pack * 2 > k {
496            self.vars
497                .iter()
498                .zip(challenge.shifted_powers(challenge.exp_u64(shift as u64)))
499                .for_each(|(&var, challenge)| {
500                    // gamma^{shift+i} * [1, z, z^2, ..., z^{2^k - 1}]
501                    let pow = EF::from(var).shifted_powers(challenge).collect_n(1 << k);
502                    weights
503                        .as_mut_slice()
504                        .iter_mut()
505                        .zip_eq(pow.chunks(F::Packing::WIDTH))
506                        .for_each(|(out, chunk)| {
507                            *out += EF::ExtensionPacking::from_ext_slice(chunk);
508                        });
509                });
510            return;
511        }
512
513        // Split approach: expand each var into its power-map form,
514        // transpose, and split into left (packed) and right (scalar) halves.
515        let points = self
516            .vars
517            .iter()
518            .map(|&var| Point::expand_from_univariate(var, k))
519            .collect::<Vec<_>>();
520        let points = Point::transpose(&points, true);
521        let (left, right) = points.split_rows(k / 2);
522
523        // Left half → packed power table (operates in SIMD lanes).
524        let left = packed_batch_pows(left);
525        // Right half → scalar power table.
526        let right = batch_pows(right);
527
528        // Broadcast challenge powers into packed form for dot products.
529        let alphas = challenge
530            .shifted_powers(challenge.exp_u64(shift as u64))
531            .collect_n(n)
532            .into_iter()
533            .map(EF::ExtensionPacking::from)
534            .collect::<Vec<_>>();
535
536        // For each right-half row, dot all left-half rows against it
537        // (weighted by the challenge powers) and accumulate into the
538        // packed weight polynomial.
539        weights
540            .as_mut_slice()
541            .par_chunks_mut(left.height())
542            .zip(right.par_row_slices())
543            .for_each(|(out, right)| {
544                out.iter_mut().zip(left.rows()).for_each(|(out, left)| {
545                    *out += left
546                        .zip(right.iter())
547                        .zip(alphas.iter())
548                        .map(|((left, &right), &alpha)| alpha * (left * right))
549                        .sum::<EF::ExtensionPacking>();
550                });
551            });
552    }
553
554    /// Batches expected evaluation values into a single target sum using challenge powers.
555    ///
556    /// Computes and adds to `claimed_eval`:
557    ///
558    /// ```text
559    /// S = Σ_i γ^{i+shift} · s_i
560    /// ```
561    ///
562    /// where `s_i` are the expected evaluation values in `self.evaluations`.
563    ///
564    /// # Parameters
565    ///
566    /// - `claimed_eval`: Accumulator for the target sum. This method **adds** the batched
567    ///   evaluations to the existing value.
568    ///
569    /// - `challenge`: Random challenge `γ ∈ EF` used for batching.
570    ///
571    /// - `shift`: Power offset. Constraint `i` uses weight `γ^{i+shift}`.
572    pub fn combine_evals(&self, claimed_eval: &mut EF, challenge: EF, shift: usize) {
573        // Compute: Σ_i γ^{i+shift} · s_i
574        // This is equivalent to dot_product(evaluations, [γ^shift, γ^{shift+1}, ...])
575        *claimed_eval += dot_product::<EF, _, _>(
576            self.evaluations.iter().copied(),
577            challenge
578                .shifted_powers(challenge.exp_u64(shift as u64))
579                .take(self.len()),
580        );
581    }
582}
583
584#[cfg(test)]
585mod tests {
586    use alloc::vec;
587
588    use p3_baby_bear::BabyBear;
589    use p3_field::extension::BinomialExtensionField;
590    use p3_field::{PackedFieldExtension, PrimeCharacteristicRing};
591    use proptest::prelude::*;
592    use rand::rngs::SmallRng;
593    use rand::{RngExt, SeedableRng};
594
595    use super::*;
596
597    type F = BabyBear;
598    type EF = BinomialExtensionField<F, 4>;
599
600    #[test]
601    fn test_select_statement_initialize() {
602        // Test that initialize creates an empty statement with correct num_variables.
603        let statement = SelectStatement::<F, F>::initialize(3);
604
605        // The statement should have 3 variables.
606        assert_eq!(statement.num_variables(), 3);
607        // The statement should be empty (no constraints).
608        assert!(statement.is_empty());
609        // The length should be 0.
610        assert_eq!(statement.len(), 0);
611    }
612
613    #[test]
614    fn test_select_statement_new() {
615        // Test that new creates a statement with pre-populated constraints.
616        let vars = vec![F::from_u64(5), F::from_u64(7)];
617        let evaluations = vec![F::from_u64(10), F::from_u64(20)];
618
619        let statement = SelectStatement::new(2, vars.clone(), evaluations.clone());
620
621        // The statement should have 2 variables.
622        assert_eq!(statement.num_variables(), 2);
623        // The statement should not be empty.
624        assert!(!statement.is_empty());
625        // The statement should have 2 constraints.
626        assert_eq!(statement.len(), 2);
627        // The vars and evaluations should match.
628        assert_eq!(statement.vars, vars);
629        assert_eq!(statement.evaluations, evaluations);
630    }
631
632    #[test]
633    #[should_panic(expected = "assertion")]
634    fn test_select_statement_new_mismatched_lengths() {
635        // Test that new panics when vars.len() != evaluations.len().
636        let vars = vec![F::from_u64(5)];
637        let evaluations = vec![F::from_u64(10), F::from_u64(20)];
638
639        // This should panic due to length mismatch.
640        let _ = SelectStatement::new(2, vars, evaluations);
641    }
642
643    #[test]
644    fn test_select_statement_add_constraint() {
645        // Test adding constraints one at a time.
646        let mut statement = SelectStatement::<F, F>::initialize(2);
647
648        // Initially empty.
649        assert!(statement.is_empty());
650        assert_eq!(statement.len(), 0);
651
652        // Add first constraint: p(5) = 10.
653        statement.add_constraint(F::from_u64(5), F::from_u64(10));
654        assert!(!statement.is_empty());
655        assert_eq!(statement.len(), 1);
656
657        // Add second constraint: p(7) = 20.
658        statement.add_constraint(F::from_u64(7), F::from_u64(20));
659        assert_eq!(statement.len(), 2);
660
661        // Verify the constraints were added correctly.
662        let constraints: Vec<_> = statement.iter().collect();
663        assert_eq!(constraints.len(), 2);
664        assert_eq!(*constraints[0].0, F::from_u64(5));
665        assert_eq!(*constraints[0].1, F::from_u64(10));
666        assert_eq!(*constraints[1].0, F::from_u64(7));
667        assert_eq!(*constraints[1].1, F::from_u64(20));
668    }
669
670    #[test]
671    fn test_select_statement_verify_basic() {
672        // Test the verify method with a simple polynomial.
673        //
674        // Create a polynomial with evaluations [c0, c1, c2, c3] over {0,1}^2.
675        let c0 = F::from_u64(1);
676        let c1 = F::from_u64(2);
677        let c2 = F::from_u64(3);
678        let c3 = F::from_u64(4);
679        let poly = Poly::new(vec![c0, c1, c2, c3]);
680
681        // Create a statement with k=2 variables.
682        let k = 2;
683        let mut statement = SelectStatement::<F, F>::initialize(k);
684
685        // The polynomial evaluations [c0, c1, c2, c3] can be interpreted as a univariate polynomial:
686        // p(z) = c0 + c1*z + c2*z^2 + c3*z^3
687        //
688        // Test p(0) = c0 = 1.
689        let z0 = F::ZERO;
690        let eval0 = c0;
691        statement.add_constraint(z0, eval0);
692        assert!(statement.verify(&poly));
693
694        // Test p(1) = c0 + c1 + c2 + c3
695        let mut statement2 = SelectStatement::<F, F>::initialize(k);
696        let z1 = F::ONE;
697        let eval1 = c0 + c1 + c2 + c3;
698        statement2.add_constraint(z1, eval1);
699        assert!(statement2.verify(&poly));
700
701        // Test p(2) = c0 + c1*2 + c2*4 + c3*8
702        let mut statement3 = SelectStatement::<F, F>::initialize(k);
703        let z2 = F::from_u64(2);
704        let eval2 = c0 + c1 * z2 + c2 * z2 * z2 + c3 * z2 * z2 * z2;
705        statement3.add_constraint(z2, eval2);
706        assert!(statement3.verify(&poly));
707
708        // Test a failing verification: p(1) = wrong_eval
709        let mut statement4 = SelectStatement::<F, F>::initialize(k);
710        let wrong_eval = F::from_u64(56765);
711        statement4.add_constraint(z1, wrong_eval);
712        assert!(!statement4.verify(&poly));
713    }
714
715    #[test]
716    fn test_select_statement_combine_single_constraint() {
717        // Test combining a single constraint.
718        //
719        // For k=2 variables, we have a 2^2 = 4-point domain.
720        let k = 2;
721
722        // Create a statement with one constraint: p(z) = s.
723        let mut statement = SelectStatement::<F, F>::initialize(k);
724        let z = F::from_u64(5);
725        let s = F::from_u64(100);
726        statement.add_constraint(z, s);
727
728        // The challenge γ is unused for a single constraint (it would multiply by γ^0 = 1).
729        let gamma = F::from_u64(2);
730        let shift = 0;
731
732        // Initialize accumulators.
733        let mut acc_weights = Poly::zero(k);
734        let mut acc_sum = F::ZERO;
735
736        // Combine the constraints.
737        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
738
739        // The target sum should be S = γ^0 · s = 1 · s = s.
740        let expected_sum = s;
741        assert_eq!(acc_sum, expected_sum);
742
743        // The weight polynomial should be W(b) = select(pow(z), b) for all b ∈ {0,1}^k.
744        //
745        // Verify each entry manually using the property: select(pow(z), b) = z^b.
746        for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
747            let expected_weight = z.exp_u64(b as u64);
748            assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
749        }
750    }
751
752    #[test]
753    fn test_select_statement_combine_multiple_constraints() {
754        // Test combining multiple constraints with batching.
755        //
756        // For k=2 variables, we have a 2^2 = 4-point domain.
757        let k = 2;
758
759        // Create a statement with two constraints:
760        // - Constraint 0: p(z0) = s0
761        // - Constraint 1: p(z1) = s1
762        let mut statement = SelectStatement::<F, F>::initialize(k);
763        let z0 = F::from_u64(3);
764        let s0 = F::from_u64(10);
765        let z1 = F::from_u64(7);
766        let s1 = F::from_u64(20);
767        statement.add_constraint(z0, s0);
768        statement.add_constraint(z1, s1);
769
770        // Use challenge γ for batching.
771        let gamma = F::from_u64(2);
772        let shift = 0;
773
774        // Initialize accumulators.
775        let mut acc_weights = Poly::zero(k);
776        let mut acc_sum = F::ZERO;
777
778        // Combine the constraints.
779        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
780
781        // The target sum should be:
782        // S = γ^0 · s0 + γ^1 · s1 = 1·s0 + γ·s1 = s0 + gamma*s1.
783        let expected_sum = s0 + gamma * s1;
784        assert_eq!(acc_sum, expected_sum);
785
786        // The weight polynomial should be:
787        // W(b) = γ^0 · select(pow(z0), b) + γ^1 · select(pow(z1), b)
788        //      = select(pow(z0), b) + gamma · select(pow(z1), b)
789        // Using the property: select(pow(z), b) = z^b.
790        for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
791            let weight0 = z0.exp_u64(b as u64);
792            let weight1 = z1.exp_u64(b as u64);
793            let expected_weight = weight0 + gamma * weight1;
794            assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
795        }
796    }
797
798    #[test]
799    fn test_select_statement_combine_with_shift() {
800        // Test combining constraints with a non-zero shift parameter.
801        //
802        // The shift parameter allows multiple statement types to use non-overlapping
803        // challenge powers for batching.
804        let k = 1;
805
806        // Create a statement with one constraint: p(z) = s.
807        let mut statement = SelectStatement::<F, F>::initialize(k);
808        let z = F::from_u64(5);
809        let s = F::from_u64(100);
810        statement.add_constraint(z, s);
811
812        // Use challenge γ with shift.
813        // This means the constraint will be weighted by γ^{0+shift} = γ^shift.
814        let gamma = F::from_u64(2);
815        let shift = 3;
816
817        // Initialize accumulators.
818        let mut acc_weights = Poly::zero(k);
819        let mut acc_sum = F::ZERO;
820
821        // Combine the constraints.
822        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
823
824        // The target sum should be S = γ^shift · s.
825        let gamma_to_shift = gamma.exp_u64(shift as u64);
826        let expected_sum = gamma_to_shift * s;
827        assert_eq!(acc_sum, expected_sum);
828
829        // The weight polynomial should be W(b) = γ^shift · select(pow(z), b).
830        // Using the property: select(pow(z), b) = z^b.
831        for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
832            let select_val = z.exp_u64(b as u64);
833            let expected_weight = gamma_to_shift * select_val;
834            assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
835        }
836    }
837
838    #[test]
839    fn test_select_statement_combine_empty() {
840        // Test that combining an empty statement does nothing.
841        let k = 2;
842        let statement = SelectStatement::<F, F>::initialize(k);
843
844        // Initialize accumulators with non-zero values.
845        let w0 = F::from_u64(1);
846        let w1 = F::from_u64(2);
847        let w2 = F::from_u64(3);
848        let w3 = F::from_u64(4);
849        let mut acc_weights = Poly::new(vec![w0, w1, w2, w3]);
850        let initial_sum = F::from_u64(99);
851        let mut acc_sum = initial_sum;
852
853        // Store original values.
854        let original_weights = acc_weights.clone();
855        let original_sum = acc_sum;
856
857        // Combine the empty statement.
858        let gamma = F::from_u64(2);
859        let shift = 0;
860        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
861
862        // The accumulators should remain unchanged.
863        assert_eq!(acc_weights, original_weights);
864        assert_eq!(acc_sum, original_sum);
865    }
866
867    #[test]
868    fn test_select_statement_combine_accumulation() {
869        // Test that combine properly accumulates (adds to) existing values.
870        //
871        // This is important for batching multiple statements together.
872        let k = 1;
873
874        // Create first statement with constraint p(z1) = s1.
875        let mut statement1 = SelectStatement::<F, F>::initialize(k);
876        let z1 = F::from_u64(2);
877        let s1 = F::from_u64(5);
878        statement1.add_constraint(z1, s1);
879
880        // Create second statement with constraint p(z2) = s2.
881        let mut statement2 = SelectStatement::<F, F>::initialize(k);
882        let z2 = F::from_u64(3);
883        let s2 = F::from_u64(7);
884        statement2.add_constraint(z2, s2);
885
886        let gamma = F::from_u64(2);
887        let shift = 0;
888
889        // Initialize accumulators.
890        let mut acc_weights = Poly::zero(k);
891        let mut acc_sum = F::ZERO;
892
893        // Combine first statement.
894        statement1.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
895
896        // Store intermediate values.
897        let intermediate_weights = acc_weights.clone();
898        let intermediate_sum = acc_sum;
899
900        // Combine second statement (should add to existing values).
901        statement2.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
902
903        // The accumulated sum should be intermediate_sum + s2.
904        let expected_sum = intermediate_sum + s2;
905        assert_eq!(acc_sum, expected_sum);
906
907        // The accumulated weights should be the sum of both select functions.
908        // Using the property: select(pow(z), b) = z^b.
909        let domain_size = 1 << k;
910        for b in 0..domain_size {
911            let weight2 = z2.exp_u64(b as u64);
912            let expected_weight = intermediate_weights.as_slice()[b] + weight2;
913            assert_eq!(
914                acc_weights.as_slice()[b],
915                expected_weight,
916                "Accumulated weight mismatch at index {b}"
917            );
918        }
919    }
920
921    #[test]
922    fn test_select_statement_combine_evals() {
923        // Test the combine_evals method.
924        let k = 2;
925
926        // Create a statement with two constraints.
927        let mut statement = SelectStatement::<F, F>::initialize(k);
928        let s0 = F::from_u64(10);
929        let s1 = F::from_u64(20);
930        statement.add_constraint(F::from_u64(3), s0);
931        statement.add_constraint(F::from_u64(7), s1);
932
933        let gamma = F::from_u64(2);
934        let shift = 1;
935
936        // Test combine_evals.
937        let mut claimed_eval = F::ZERO;
938        statement.combine_evals(&mut claimed_eval, gamma, shift);
939
940        // Expected: S = γ^{shift} · s0 + γ^{shift+1} · s1 = γ^1·s0 + γ^2·s1.
941        let gamma_1 = gamma.exp_u64(shift as u64);
942        let gamma_2 = gamma.exp_u64((shift + 1) as u64);
943        let expected = gamma_1 * s0 + gamma_2 * s1;
944        assert_eq!(claimed_eval, expected);
945    }
946
947    #[test]
948    fn test_select_statement_combine_evals_accumulation() {
949        // Test that combine_evals properly accumulates.
950        let k = 1;
951
952        let mut statement = SelectStatement::<F, F>::initialize(k);
953        let s = F::from_u64(10);
954        statement.add_constraint(F::from_u64(5), s);
955
956        let gamma = F::from_u64(3);
957        let shift = 0;
958
959        // Start with a non-zero claimed_eval.
960        let initial_eval = F::from_u64(42);
961        let mut claimed_eval = initial_eval;
962
963        // Combine evals should add to the existing value.
964        statement.combine_evals(&mut claimed_eval, gamma, shift);
965
966        // Expected: initial_eval + γ^0 · s = initial_eval + 1·s = initial_eval + s.
967        let expected = initial_eval + s;
968        assert_eq!(claimed_eval, expected);
969    }
970
971    #[test]
972    fn test_select_combine_consistency_with_verify() {
973        // Test that combine and verify are consistent.
974        //
975        // If we create a polynomial that satisfies the constraints, then:
976        // 1. verify() should return true
977        // 2. The combined weights should correctly compute the polynomial evaluations
978        let k = 2;
979
980        // Create a simple polynomial: evaluations [c0, c1, c2, c3].
981        let c0 = F::from_u64(1);
982        let c1 = F::from_u64(2);
983        let c2 = F::from_u64(3);
984        let c3 = F::from_u64(4);
985        let poly = Poly::new(vec![c0, c1, c2, c3]);
986
987        // Create constraints that match the polynomial.
988        // Using Horner evaluation: p(z) = c0 + c1*z + c2*z^2 + c3*z^3.
989        let mut statement = SelectStatement::<F, F>::initialize(k);
990
991        // Evaluate p(z) at z using Horner's method.
992        let z = F::from_u64(2);
993        let expected_eval: F = poly.iter().copied().horner(z);
994        statement.add_constraint(z, expected_eval);
995
996        // Verify should pass.
997        assert!(statement.verify(&poly));
998
999        // Now combine and check that the weight polynomial correctly represents
1000        // the select function.
1001        let gamma = F::from_u64(3);
1002        let shift = 0;
1003        let mut acc_weights = Poly::zero(k);
1004        let mut acc_sum = F::ZERO;
1005        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
1006
1007        // The sum should match the expected evaluation.
1008        assert_eq!(acc_sum, expected_eval);
1009
1010        // The weight polynomial should satisfy:
1011        // Σ_{b ∈ {0,1}^k} poly(b) · W(b) = expected_eval
1012        let mut computed_sum = F::ZERO;
1013        for (poly_val, acc_weight) in poly.as_slice().iter().zip(acc_weights.as_slice().iter()) {
1014            computed_sum += *poly_val * *acc_weight;
1015        }
1016        assert_eq!(computed_sum, expected_eval);
1017    }
1018
1019    proptest! {
1020        #[test]
1021        fn prop_select_statement_combine_sum(
1022            // Number of variables (1 to 4 for reasonable test size).
1023            k in 1usize..=4,
1024            // Number of constraints (1 to 5).
1025            num_constraints in 1usize..=5,
1026            // Random evaluation points (avoiding 0 for better coverage).
1027            // Generate exactly num_constraints values.
1028            z_values in prop::collection::vec(1u32..100, 1..=5),
1029            // Random expected evaluations.
1030            s_values in prop::collection::vec(0u32..100, 1..=5),
1031            // Random challenge.
1032            challenge in 1u32..50,
1033        ) {
1034            // Ensure we have enough values for the test.
1035            let actual_num_constraints = num_constraints.min(z_values.len()).min(s_values.len());
1036            if actual_num_constraints == 0 {
1037                return Ok(());
1038            }
1039
1040            let z_values = &z_values[..actual_num_constraints];
1041            let s_values = &s_values[..actual_num_constraints];
1042
1043            // Create statement with random constraints.
1044            let mut statement = SelectStatement::<F, F>::initialize(k);
1045            for (&z, &s) in z_values.iter().zip(s_values.iter()) {
1046                statement.add_constraint(F::from_u32(z), F::from_u32(s));
1047            }
1048
1049            let gamma = F::from_u32(challenge);
1050
1051            // Combine with shift=0.
1052            let mut acc_weights = Poly::zero(k);
1053            let mut acc_sum = F::ZERO;
1054            statement.combine(&mut acc_weights, &mut acc_sum, gamma, 0);
1055
1056            // Compute expected sum manually: S = Σ_i γ^i · s_i.
1057            let mut expected_sum = F::ZERO;
1058            for (i, &s) in s_values.iter().enumerate() {
1059                expected_sum += gamma.exp_u64(i as u64) * F::from_u32(s);
1060            }
1061
1062            prop_assert_eq!(acc_sum, expected_sum);
1063        }
1064    }
1065
1066    proptest! {
1067        #[test]
1068        fn prop_select_statement_verify(
1069            // Polynomial evaluations (2^k values for k=3).
1070            poly_evals in prop::collection::vec(0u32..100, 8),
1071            // Evaluation point (avoiding 0 for better coverage).
1072            z in 1u32..50,
1073        ) {
1074            let k = 3; // Fixed k=3 gives 2^3 = 8 evaluations.
1075            let poly = Poly::new(poly_evals.into_iter().map(F::from_u32).collect());
1076
1077            // Compute expected evaluation using Horner's method.
1078            let z_field = F::from_u32(z);
1079            let expected_eval: F = poly.iter().copied().horner(z_field);
1080
1081            // Create statement with correct constraint.
1082            let mut statement = SelectStatement::<F, F>::initialize(k);
1083            statement.add_constraint(z_field, expected_eval);
1084
1085            // Verify should pass.
1086            prop_assert!(statement.verify(&poly));
1087
1088            // Add a wrong constraint (off by 1, unless it wraps to same value).
1089            let wrong_eval = expected_eval + F::ONE;
1090            if wrong_eval != expected_eval {
1091                statement.add_constraint(z_field, wrong_eval);
1092                // Verify should fail now.
1093                prop_assert!(!statement.verify(&poly));
1094            }
1095        }
1096    }
1097
1098    proptest! {
1099        #[test]
1100        fn prop_combine_evals_consistency(
1101            // Number of constraints.
1102            num_constraints in 1usize..=5,
1103            // Random evaluations.
1104            s_values in prop::collection::vec(0u32..100, 1..=5),
1105            // Random challenge.
1106            challenge in 1u32..50,
1107            // Random shift.
1108            shift in 0usize..3,
1109        ) {
1110            let s_values = &s_values[..num_constraints.min(s_values.len())];
1111
1112            // Create statement with arbitrary z values (they don't matter for this test).
1113            let mut statement = SelectStatement::<F, F>::initialize(2);
1114            for &s in s_values {
1115                statement.add_constraint(F::from_u32(1), F::from_u32(s));
1116            }
1117
1118            let gamma = F::from_u32(challenge);
1119
1120            // Method 1: Use combine_evals.
1121            let mut claimed_eval1 = F::ZERO;
1122            statement.combine_evals(&mut claimed_eval1, gamma, shift);
1123
1124            // Method 2: Compute manually.
1125            let mut claimed_eval2 = F::ZERO;
1126            for (i, &s) in s_values.iter().enumerate() {
1127                claimed_eval2 += gamma.exp_u64((i + shift) as u64) * F::from_u32(s);
1128            }
1129
1130            prop_assert_eq!(claimed_eval1, claimed_eval2);
1131        }
1132    }
1133
1134    proptest! {
1135        #[test]
1136        fn prop_packed_combine_roundtrip(
1137            // Number of variables (covers both naive and split paths).
1138            k in 4usize..10,
1139            // Number of select constraints per batch.
1140            n in 1usize..12,
1141            // Challenge power offset.
1142            shift in 0usize..5,
1143            // RNG seed for reproducible randomness.
1144            seed in 0u64..100,
1145        ) {
1146            type PackedExt = <EF as ExtensionField<F>>::ExtensionPacking;
1147
1148            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
1149            if k < k_pack {
1150                return Ok(());
1151            }
1152
1153            let mut rng = SmallRng::seed_from_u64(seed);
1154            let challenge: EF = rng.random();
1155
1156            // Generate n random evaluation points and expected values.
1157            let vars = (0..n).map(|_| rng.random()).collect::<Vec<F>>();
1158            let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
1159
1160            let statement = SelectStatement::<F, EF>::new(k, vars, evals);
1161
1162            // Scalar path: combine into a 2^k evaluation list.
1163            let mut scalar_weights = Poly::<EF>::zero(k);
1164            let mut scalar_sum = EF::ZERO;
1165            statement.combine(&mut scalar_weights, &mut scalar_sum, challenge, shift);
1166
1167            // Packed path: combine into a 2^{k - k_pack} packed list.
1168            let mut packed_weights = Poly::<PackedExt>::zero(k - k_pack);
1169            let mut packed_sum = EF::ZERO;
1170            statement.combine_packed(&mut packed_weights, &mut packed_sum, challenge, shift);
1171
1172            // Unpack the packed result and compare element-by-element.
1173            let unpacked =
1174                <PackedExt as PackedFieldExtension<F, EF>>::to_ext_iter(
1175                    packed_weights.as_slice().iter().copied(),
1176                )
1177                .collect::<Vec<_>>();
1178            prop_assert_eq!(scalar_weights.as_slice(), &unpacked[..]);
1179
1180            // The scalar sums must match exactly.
1181            prop_assert_eq!(scalar_sum, packed_sum);
1182        }
1183
1184        #[test]
1185        fn prop_packed_combine_accumulation(
1186            k in 4usize..10,
1187            seed in 0u64..50,
1188        ) {
1189            type PackedExt = <EF as ExtensionField<F>>::ExtensionPacking;
1190
1191            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
1192            if k < k_pack {
1193                return Ok(());
1194            }
1195
1196            let mut rng = SmallRng::seed_from_u64(seed);
1197            let challenge: EF = rng.random();
1198
1199            let mut s_wt = Poly::<EF>::zero(k);
1200            let mut p_wt = Poly::<PackedExt>::zero(k - k_pack);
1201            let mut s_sum = EF::ZERO;
1202            let mut p_sum = EF::ZERO;
1203            let mut shift = 0;
1204
1205            // Two batches with different constraint counts.
1206            for n in [3, 7] {
1207                let vars = (0..n).map(|_| rng.random()).collect::<Vec<F>>();
1208                let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
1209                let stmt = SelectStatement::<F, EF>::new(k, vars, evals);
1210
1211                stmt.combine(&mut s_wt, &mut s_sum, challenge, shift);
1212                stmt.combine_packed(&mut p_wt, &mut p_sum, challenge, shift);
1213                shift += stmt.len();
1214            }
1215
1216            // Verify accumulated results match after both batches.
1217            let unpacked =
1218                <PackedExt as PackedFieldExtension<F, EF>>::to_ext_iter(
1219                    p_wt.as_slice().iter().copied(),
1220                )
1221                .collect::<Vec<_>>();
1222            prop_assert_eq!(s_wt.as_slice(), &unpacked[..]);
1223            prop_assert_eq!(s_sum, p_sum);
1224        }
1225    }
1226}