use alloc::vec::Vec;
use itertools::Itertools;
use p3_field::{
ExtensionField, Field, HornerIter, PackedFieldExtension, PackedValue, PrimeCharacteristicRing,
dot_product,
};
use p3_matrix::Matrix;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_maybe_rayon::prelude::*;
use p3_multilinear_util::point::Point;
use p3_multilinear_util::poly::Poly;
use p3_util::log2_strict_usize;
use tracing::instrument;
/// Expands a batch of multiplicative points into the table of all subset products.
///
/// `points` has one row per variable and one column per point (`k × n`). The result is a
/// `2^k × n` matrix whose entry at row `b`, column `j` is the product of `points[i][j]` over
/// every bit `i` set in `b`. Row `0` is the empty product, i.e. all ones.
fn batch_pows<F: Field>(points: RowMajorMatrixView<'_, F>) -> RowMajorMatrix<F> {
    let num_vars = points.height();
    let num_points = points.width();
    let mut table = RowMajorMatrix::new(F::zero_vec(num_points * (1 << num_vars)), num_points);
    table.row_mut(0).fill(F::ONE);
    for (bit, bit_vars) in points.row_slices().enumerate() {
        // Rows [0, 2^bit) are complete; rows [2^bit, 2^(bit + 1)) are those rows times `bit_vars`.
        let (filled, mut pending) = table.split_rows_mut(1 << bit);
        for (src_row, dst_row) in filled.rows().zip(pending.rows_mut()) {
            for ((dst, src), &factor) in dst_row.iter_mut().zip(src_row).zip(bit_vars.iter()) {
                *dst = src * factor;
            }
        }
    }
    table
}
/// Packed counterpart of `batch_pows`.
///
/// `points` is a `k × n` matrix (one row per variable, one column per point). The first
/// `log2(F::Packing::WIDTH)` variables are expanded inside each packed element, one lane per
/// combination. The remaining variables are expanded across rows. The result is a
/// `2^(k - log2(WIDTH)) × n` matrix of packed values.
///
/// # Panics
///
/// Panics if `n == 0` or if `k < log2(F::Packing::WIDTH)`.
fn packed_batch_pows<F: Field>(points: RowMajorMatrixView<'_, F>) -> RowMajorMatrix<F::Packing> {
    let num_vars = points.height();
    let num_points = points.width();
    assert_ne!(num_points, 0);
    let log_width = log2_strict_usize(F::Packing::WIDTH);
    assert!(num_vars >= log_width);
    let (lane_vars, row_vars) = points.split_rows(log_width);
    let mut table = RowMajorMatrix::new(
        F::Packing::zero_vec(num_points * (1 << (num_vars - log_width))),
        num_points,
    );
    if log_width == 0 {
        // No lanes to expand: the empty product in every column.
        table.row_mut(0).fill(F::Packing::ONE);
    } else {
        // Row 0: for each point, pack the 2^log_width lane-variable products into one element.
        let per_point = lane_vars.transpose();
        for (slot, point_vars) in table.values.iter_mut().zip(per_point.row_slices()) {
            let lanes = batch_pows(RowMajorMatrixView::new(point_vars, 1));
            *slot = *F::Packing::from_slice(&lanes.values);
        }
    }
    for (bit, bit_vars) in row_vars.row_slices().enumerate() {
        // Double the filled prefix of rows by multiplying it with this variable.
        let (filled, mut pending) = table.split_rows_mut(1 << bit);
        for (src_row, dst_row) in filled.rows().zip(pending.rows_mut()) {
            for ((dst, src), &factor) in dst_row.iter_mut().zip(src_row).zip(bit_vars.iter()) {
                *dst = src * factor;
            }
        }
    }
    table
}
/// A batch of univariate "select" constraints attached to a polynomial with `num_variables`
/// variables.
///
/// Constraint `j` claims that the polynomial, with its entries read as univariate coefficients,
/// evaluates to `evaluations[j]` at the point `vars[j]`. `vars` and `evaluations` are always
/// kept the same length.
#[derive(Clone, Debug)]
pub struct SelectStatement<F, EF> {
    /// Number of variables of the constrained polynomial (its weight vectors have
    /// `2^num_variables` entries).
    num_variables: usize,
    /// Evaluation points, one per constraint.
    pub(crate) vars: Vec<F>,
    /// Claimed evaluations, parallel to `vars`.
    evaluations: Vec<EF>,
}
impl<F: Field, EF: ExtensionField<F>> SelectStatement<F, EF> {
    /// Creates an empty statement over `num_variables` variables.
    #[must_use]
    pub const fn initialize(num_variables: usize) -> Self {
        Self {
            num_variables,
            vars: Vec::new(),
            evaluations: Vec::new(),
        }
    }

    /// Creates a statement from parallel lists of points and claimed evaluations.
    ///
    /// Constraint `j` claims that the polynomial evaluates to `evaluations[j]` at `vars[j]`.
    ///
    /// # Panics
    ///
    /// Panics if `vars` and `evaluations` have different lengths.
    #[must_use]
    pub const fn new(num_variables: usize, vars: Vec<F>, evaluations: Vec<EF>) -> Self {
        assert!(vars.len() == evaluations.len());
        Self {
            num_variables,
            vars,
            evaluations,
        }
    }

    /// Number of variables of the constrained polynomial.
    #[must_use]
    pub const fn num_variables(&self) -> usize {
        self.num_variables
    }

    /// Returns `true` if the statement holds no constraints.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        debug_assert!(self.vars.is_empty() == self.evaluations.is_empty());
        self.vars.is_empty()
    }

    /// Iterates over the constraints as `(point, claimed_evaluation)` pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&F, &EF)> {
        self.vars.iter().zip(self.evaluations.iter())
    }

    /// Number of constraints.
    #[must_use]
    pub const fn len(&self) -> usize {
        debug_assert!(self.vars.len() == self.evaluations.len());
        self.vars.len()
    }

    /// Checks every constraint directly against `poly`.
    ///
    /// The entries of `poly` are treated as univariate coefficients `c_0, c_1, ...` and evaluated
    /// at each point with Horner's rule, i.e. `Σ_i c_i · z^i`. Returns `true` iff every
    /// evaluation matches its claim. An empty statement is vacuously `true`.
    #[must_use]
    pub fn verify(&self, poly: &Poly<EF>) -> bool {
        self.iter().all(|(&var, &expected_eval)| {
            poly.iter().copied().horner::<EF, _>(var) == expected_eval
        })
    }

    /// Appends the constraint "the polynomial evaluates to `eval` at `var`".
    pub fn add_constraint(&mut self, var: F, eval: EF) {
        self.vars.push(var);
        self.evaluations.push(eval);
    }

    /// Folds every constraint into a running weight vector and claimed sum.
    ///
    /// Let `γ = challenge`, let `(z_j, s_j)` for `j in 0..n` be the constraints, and let
    /// `k = self.num_variables()`. This performs:
    ///
    /// - `acc_weights[b] += Σ_j γ^(shift + j) · z_j^b` for every `b in 0..2^k`, and
    /// - `acc_sum += Σ_j γ^(shift + j) · s_j`.
    ///
    /// When folding several statements with the same `challenge`, pass the number of
    /// constraints already folded as `shift`. Every constraint then gets a distinct power of `γ`.
    ///
    /// This builds an intermediate `2^k × n` table of base-field powers. See
    /// [`Self::combine_packed`] for a variant that avoids it.
    ///
    /// Does nothing if the statement is empty.
    ///
    /// # Panics
    ///
    /// Panics if the statement is non-empty and `acc_weights.num_variables() != k`.
    #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
    pub fn combine(
        &self,
        acc_weights: &mut Poly<EF>,
        acc_sum: &mut EF,
        challenge: EF,
        shift: usize,
    ) {
        if self.vars.is_empty() {
            return;
        }
        let n = self.len();
        let k = self.num_variables();
        // Same shape check as `combine_packed`. Without it, a mis-sized accumulator would be
        // silently truncated by the `zip` over weight rows below.
        assert_eq!(acc_weights.num_variables(), k);

        // pow_matrix[i * n + j] = z_j^(2^i), obtained by repeated squaring.
        let mut pow_matrix = F::zero_vec(k * n);
        for (j, &var) in self.vars.iter().enumerate() {
            let mut v = var;
            for i in 0..k {
                pow_matrix[i * n + j] = v;
                v = v.square();
            }
        }

        // `acc` is a row-major 2^k × n table with acc[b * n + j] = z_j^b. Row 0 is all ones.
        // Step i fills rows [2^i, 2^(i + 1)) from rows [0, 2^i) times z_j^(2^i).
        let mut acc = F::zero_vec((1 << k) * n);
        acc[..n].fill(F::ONE);
        for i in 0..k {
            let num_existing_rows = 1 << i;
            let (lo, hi) = acc.split_at_mut(num_existing_rows * n);
            let pow_row = &pow_matrix[i * n..(i + 1) * n];
            // `lo` is only read here, so its rows are borrowed immutably.
            lo.par_chunks(n)
                .zip(hi.par_chunks_mut(n))
                .for_each(|(lo_row, hi_row)| {
                    pow_row
                        .iter()
                        .zip(lo_row.iter())
                        .zip(hi_row.iter_mut())
                        .for_each(|((&z_pow, &lo_val), hi_val)| {
                            *hi_val = lo_val * z_pow;
                        });
                });
        }

        // Challenge weights γ^(shift + j) for j in 0..n.
        let challenges = challenge
            .shifted_powers(challenge.exp_u64(shift as u64))
            .collect_n(n);
        acc.par_chunks(n)
            .zip(acc_weights.as_mut_slice().par_iter_mut())
            .for_each(|(row, weight_out)| {
                *weight_out +=
                    dot_product::<EF, _, _>(challenges.iter().copied(), row.iter().copied());
            });
        // Folds the claimed evaluations with the same powers γ^(shift + j) used for the weights.
        self.combine_evals(acc_sum, challenge, shift);
    }

    /// Packed counterpart of [`Self::combine`].
    ///
    /// Accumulates the same weights and sum as `combine`. The difference is that `weights`
    /// stores the `2^k` weights packed `F::Packing::WIDTH` consecutive entries per element, so
    /// it has `k - log2(WIDTH)` variables.
    ///
    /// Two strategies are used:
    /// - If `2 · log2(WIDTH) > k`, the scaled powers `γ^(shift + j) · z_j^b` of each point are
    ///   generated directly and packed chunk by chunk.
    /// - Otherwise each point is expanded into `k` variables and the variables are split at
    ///   `k / 2`. The first part is expanded into a packed table and the second into a scalar
    ///   table. Each output element is `Σ_j α_j · left[i][j] · right[r][j]`, so a full
    ///   `2^k × n` table is never materialized.
    ///
    /// Does nothing if the statement is empty.
    ///
    /// # Panics
    ///
    /// Panics if the statement is non-empty and either `k < log2(WIDTH)` or
    /// `weights.num_variables() + log2(WIDTH) != k`.
    #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
    pub fn combine_packed(
        &self,
        weights: &mut Poly<EF::ExtensionPacking>,
        sum: &mut EF,
        challenge: EF,
        shift: usize,
    ) {
        if self.vars.is_empty() {
            return;
        }
        let n = self.len();
        let k = self.num_variables();
        let k_pack = log2_strict_usize(F::Packing::WIDTH);
        assert!(k >= k_pack);
        assert_eq!(weights.num_variables() + k_pack, k);
        self.combine_evals(sum, challenge, shift);
        if k_pack * 2 > k {
            // Small instance: split-table trick needs k / 2 >= k_pack, so expand directly.
            self.vars
                .iter()
                .zip(challenge.shifted_powers(challenge.exp_u64(shift as u64)))
                .for_each(|(&var, challenge)| {
                    // [α, α·z, α·z², ...] for this constraint's α = γ^(shift + j).
                    let pow = EF::from(var).shifted_powers(challenge).collect_n(1 << k);
                    weights
                        .as_mut_slice()
                        .iter_mut()
                        .zip_eq(pow.chunks(F::Packing::WIDTH))
                        .for_each(|(out, chunk)| {
                            *out += EF::ExtensionPacking::from_ext_slice(chunk);
                        });
                });
            return;
        }
        // NOTE(review): the index layout below relies on `Point::expand_from_univariate` and
        // `Point::transpose(.., true)` ordering the variables so that the first `k / 2` rows
        // index the low bits of the weight index; the packed/scalar round-trip tests cover it.
        let points = self
            .vars
            .iter()
            .map(|&var| Point::expand_from_univariate(var, k))
            .collect::<Vec<_>>();
        let points = Point::transpose(&points, true);
        let (left, right) = points.split_rows(k / 2);
        let left = packed_batch_pows(left);
        let right = batch_pows(right);
        let alphas = challenge
            .shifted_powers(challenge.exp_u64(shift as u64))
            .collect_n(n)
            .into_iter()
            .map(EF::ExtensionPacking::from)
            .collect::<Vec<_>>();
        // Each chunk of `left.height()` packed weights corresponds to one row of `right`.
        weights
            .as_mut_slice()
            .par_chunks_mut(left.height())
            .zip(right.par_row_slices())
            .for_each(|(out, right)| {
                out.iter_mut().zip(left.rows()).for_each(|(out, left)| {
                    *out += left
                        .zip(right.iter())
                        .zip(alphas.iter())
                        .map(|((left, &right), &alpha)| alpha * (left * right))
                        .sum::<EF::ExtensionPacking>();
                });
            });
    }

    /// Adds `Σ_j γ^(shift + j) · s_j` to `claimed_eval`, where `γ = challenge` and `s_j` are the
    /// claimed evaluations in insertion order.
    pub fn combine_evals(&self, claimed_eval: &mut EF, challenge: EF, shift: usize) {
        *claimed_eval += dot_product::<EF, _, _>(
            self.evaluations.iter().copied(),
            challenge
                .shifted_powers(challenge.exp_u64(shift as u64))
                .take(self.len()),
        );
    }
}
#[cfg(test)]
mod tests {
    // Unit and property tests for `SelectStatement`: construction, direct verification,
    // weight/sum folding (`combine`, `combine_evals`) and agreement of the packed path
    // (`combine_packed`) with the scalar path.
    use alloc::vec;
    use p3_baby_bear::BabyBear;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{PackedFieldExtension, PrimeCharacteristicRing};
    use proptest::prelude::*;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::*;

    type F = BabyBear;
    type EF = BinomialExtensionField<F, 4>;

    #[test]
    fn test_select_statement_initialize() {
        let statement = SelectStatement::<F, F>::initialize(3);
        assert_eq!(statement.num_variables(), 3);
        assert!(statement.is_empty());
        assert_eq!(statement.len(), 0);
    }

    #[test]
    fn test_select_statement_new() {
        let vars = vec![F::from_u64(5), F::from_u64(7)];
        let evaluations = vec![F::from_u64(10), F::from_u64(20)];
        let statement = SelectStatement::new(2, vars.clone(), evaluations.clone());
        assert_eq!(statement.num_variables(), 2);
        assert!(!statement.is_empty());
        assert_eq!(statement.len(), 2);
        assert_eq!(statement.vars, vars);
        assert_eq!(statement.evaluations, evaluations);
    }

    // `new` must reject point/evaluation lists of different lengths.
    #[test]
    #[should_panic(expected = "assertion")]
    fn test_select_statement_new_mismatched_lengths() {
        let vars = vec![F::from_u64(5)];
        let evaluations = vec![F::from_u64(10), F::from_u64(20)];
        let _ = SelectStatement::new(2, vars, evaluations);
    }

    #[test]
    fn test_select_statement_add_constraint() {
        let mut statement = SelectStatement::<F, F>::initialize(2);
        assert!(statement.is_empty());
        assert_eq!(statement.len(), 0);
        statement.add_constraint(F::from_u64(5), F::from_u64(10));
        assert!(!statement.is_empty());
        assert_eq!(statement.len(), 1);
        statement.add_constraint(F::from_u64(7), F::from_u64(20));
        assert_eq!(statement.len(), 2);
        // Constraints are iterated in insertion order.
        let constraints: Vec<_> = statement.iter().collect();
        assert_eq!(constraints.len(), 2);
        assert_eq!(*constraints[0].0, F::from_u64(5));
        assert_eq!(*constraints[0].1, F::from_u64(10));
        assert_eq!(*constraints[1].0, F::from_u64(7));
        assert_eq!(*constraints[1].1, F::from_u64(20));
    }

    // `verify` reads the polynomial's entries as coefficients c0 + c1·z + c2·z² + c3·z³.
    #[test]
    fn test_select_statement_verify_basic() {
        let c0 = F::from_u64(1);
        let c1 = F::from_u64(2);
        let c2 = F::from_u64(3);
        let c3 = F::from_u64(4);
        let poly = Poly::new(vec![c0, c1, c2, c3]);
        let k = 2;
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let z0 = F::ZERO;
        let eval0 = c0;
        statement.add_constraint(z0, eval0);
        assert!(statement.verify(&poly));
        let mut statement2 = SelectStatement::<F, F>::initialize(k);
        let z1 = F::ONE;
        let eval1 = c0 + c1 + c2 + c3;
        statement2.add_constraint(z1, eval1);
        assert!(statement2.verify(&poly));
        let mut statement3 = SelectStatement::<F, F>::initialize(k);
        let z2 = F::from_u64(2);
        let eval2 = c0 + c1 * z2 + c2 * z2 * z2 + c3 * z2 * z2 * z2;
        statement3.add_constraint(z2, eval2);
        assert!(statement3.verify(&poly));
        let mut statement4 = SelectStatement::<F, F>::initialize(k);
        let wrong_eval = F::from_u64(56765);
        statement4.add_constraint(z1, wrong_eval);
        assert!(!statement4.verify(&poly));
    }

    // With one constraint and shift 0 the challenge weight is γ^0 = 1, so weights are z^b.
    #[test]
    fn test_select_statement_combine_single_constraint() {
        let k = 2;
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let z = F::from_u64(5);
        let s = F::from_u64(100);
        statement.add_constraint(z, s);
        let gamma = F::from_u64(2);
        let shift = 0;
        let mut acc_weights = Poly::zero(k);
        let mut acc_sum = F::ZERO;
        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        let expected_sum = s;
        assert_eq!(acc_sum, expected_sum);
        for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
            let expected_weight = z.exp_u64(b as u64);
            assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
        }
    }

    // Constraint j is scaled by γ^j.
    #[test]
    fn test_select_statement_combine_multiple_constraints() {
        let k = 2;
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let z0 = F::from_u64(3);
        let s0 = F::from_u64(10);
        let z1 = F::from_u64(7);
        let s1 = F::from_u64(20);
        statement.add_constraint(z0, s0);
        statement.add_constraint(z1, s1);
        let gamma = F::from_u64(2);
        let shift = 0;
        let mut acc_weights = Poly::zero(k);
        let mut acc_sum = F::ZERO;
        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        let expected_sum = s0 + gamma * s1;
        assert_eq!(acc_sum, expected_sum);
        for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
            let weight0 = z0.exp_u64(b as u64);
            let weight1 = z1.exp_u64(b as u64);
            let expected_weight = weight0 + gamma * weight1;
            assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
        }
    }

    // A non-zero shift scales every contribution by γ^shift.
    #[test]
    fn test_select_statement_combine_with_shift() {
        let k = 1;
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let z = F::from_u64(5);
        let s = F::from_u64(100);
        statement.add_constraint(z, s);
        let gamma = F::from_u64(2);
        let shift = 3;
        let mut acc_weights = Poly::zero(k);
        let mut acc_sum = F::ZERO;
        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        let gamma_to_shift = gamma.exp_u64(shift as u64);
        let expected_sum = gamma_to_shift * s;
        assert_eq!(acc_sum, expected_sum);
        for (b, acc_weight) in acc_weights.as_slice().iter().enumerate() {
            let select_val = z.exp_u64(b as u64);
            let expected_weight = gamma_to_shift * select_val;
            assert_eq!(*acc_weight, expected_weight, "Weight mismatch at index {b}");
        }
    }

    // An empty statement leaves the accumulators untouched.
    #[test]
    fn test_select_statement_combine_empty() {
        let k = 2;
        let statement = SelectStatement::<F, F>::initialize(k);
        let w0 = F::from_u64(1);
        let w1 = F::from_u64(2);
        let w2 = F::from_u64(3);
        let w3 = F::from_u64(4);
        let mut acc_weights = Poly::new(vec![w0, w1, w2, w3]);
        let initial_sum = F::from_u64(99);
        let mut acc_sum = initial_sum;
        let original_weights = acc_weights.clone();
        let original_sum = acc_sum;
        let gamma = F::from_u64(2);
        let shift = 0;
        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        assert_eq!(acc_weights, original_weights);
        assert_eq!(acc_sum, original_sum);
    }

    // `combine` adds to, rather than overwrites, existing accumulator contents.
    #[test]
    fn test_select_statement_combine_accumulation() {
        let k = 1;
        let mut statement1 = SelectStatement::<F, F>::initialize(k);
        let z1 = F::from_u64(2);
        let s1 = F::from_u64(5);
        statement1.add_constraint(z1, s1);
        let mut statement2 = SelectStatement::<F, F>::initialize(k);
        let z2 = F::from_u64(3);
        let s2 = F::from_u64(7);
        statement2.add_constraint(z2, s2);
        let gamma = F::from_u64(2);
        let shift = 0;
        let mut acc_weights = Poly::zero(k);
        let mut acc_sum = F::ZERO;
        statement1.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        let intermediate_weights = acc_weights.clone();
        let intermediate_sum = acc_sum;
        statement2.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        let expected_sum = intermediate_sum + s2;
        assert_eq!(acc_sum, expected_sum);
        let domain_size = 1 << k;
        for b in 0..domain_size {
            let weight2 = z2.exp_u64(b as u64);
            let expected_weight = intermediate_weights.as_slice()[b] + weight2;
            assert_eq!(
                acc_weights.as_slice()[b],
                expected_weight,
                "Accumulated weight mismatch at index {b}"
            );
        }
    }

    #[test]
    fn test_select_statement_combine_evals() {
        let k = 2;
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let s0 = F::from_u64(10);
        let s1 = F::from_u64(20);
        statement.add_constraint(F::from_u64(3), s0);
        statement.add_constraint(F::from_u64(7), s1);
        let gamma = F::from_u64(2);
        let shift = 1;
        let mut claimed_eval = F::ZERO;
        statement.combine_evals(&mut claimed_eval, gamma, shift);
        let gamma_1 = gamma.exp_u64(shift as u64);
        let gamma_2 = gamma.exp_u64((shift + 1) as u64);
        let expected = gamma_1 * s0 + gamma_2 * s1;
        assert_eq!(claimed_eval, expected);
    }

    #[test]
    fn test_select_statement_combine_evals_accumulation() {
        let k = 1;
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let s = F::from_u64(10);
        statement.add_constraint(F::from_u64(5), s);
        let gamma = F::from_u64(3);
        let shift = 0;
        let initial_eval = F::from_u64(42);
        let mut claimed_eval = initial_eval;
        statement.combine_evals(&mut claimed_eval, gamma, shift);
        let expected = initial_eval + s;
        assert_eq!(claimed_eval, expected);
    }

    // For a single constraint with shift 0, <poly, weights> must equal the claimed evaluation.
    #[test]
    fn test_select_combine_consistency_with_verify() {
        let k = 2;
        let c0 = F::from_u64(1);
        let c1 = F::from_u64(2);
        let c2 = F::from_u64(3);
        let c3 = F::from_u64(4);
        let poly = Poly::new(vec![c0, c1, c2, c3]);
        let mut statement = SelectStatement::<F, F>::initialize(k);
        let z = F::from_u64(2);
        let expected_eval: F = poly.iter().copied().horner(z);
        statement.add_constraint(z, expected_eval);
        assert!(statement.verify(&poly));
        let gamma = F::from_u64(3);
        let shift = 0;
        let mut acc_weights = Poly::zero(k);
        let mut acc_sum = F::ZERO;
        statement.combine(&mut acc_weights, &mut acc_sum, gamma, shift);
        assert_eq!(acc_sum, expected_eval);
        let mut computed_sum = F::ZERO;
        for (poly_val, acc_weight) in poly.as_slice().iter().zip(acc_weights.as_slice().iter()) {
            computed_sum += *poly_val * *acc_weight;
        }
        assert_eq!(computed_sum, expected_eval);
    }

    proptest! {
        // The folded sum equals Σ_i γ^i · s_i for random small statements.
        #[test]
        fn prop_select_statement_combine_sum(
            k in 1usize..=4,
            num_constraints in 1usize..=5,
            z_values in prop::collection::vec(1u32..100, 1..=5),
            s_values in prop::collection::vec(0u32..100, 1..=5),
            challenge in 1u32..50,
        ) {
            let actual_num_constraints = num_constraints.min(z_values.len()).min(s_values.len());
            if actual_num_constraints == 0 {
                return Ok(());
            }
            let z_values = &z_values[..actual_num_constraints];
            let s_values = &s_values[..actual_num_constraints];
            let mut statement = SelectStatement::<F, F>::initialize(k);
            for (&z, &s) in z_values.iter().zip(s_values.iter()) {
                statement.add_constraint(F::from_u32(z), F::from_u32(s));
            }
            let gamma = F::from_u32(challenge);
            let mut acc_weights = Poly::zero(k);
            let mut acc_sum = F::ZERO;
            statement.combine(&mut acc_weights, &mut acc_sum, gamma, 0);
            let mut expected_sum = F::ZERO;
            for (i, &s) in s_values.iter().enumerate() {
                expected_sum += gamma.exp_u64(i as u64) * F::from_u32(s);
            }
            prop_assert_eq!(acc_sum, expected_sum);
        }
    }

    proptest! {
        // A correct claim verifies; adding any wrong claim makes verification fail.
        #[test]
        fn prop_select_statement_verify(
            poly_evals in prop::collection::vec(0u32..100, 8),
            z in 1u32..50,
        ) {
            let k = 3;
            let poly = Poly::new(poly_evals.into_iter().map(F::from_u32).collect());
            let z_field = F::from_u32(z);
            let expected_eval: F = poly.iter().copied().horner(z_field);
            let mut statement = SelectStatement::<F, F>::initialize(k);
            statement.add_constraint(z_field, expected_eval);
            prop_assert!(statement.verify(&poly));
            // `x + 1 != x` in any field, so this claim is always wrong.
            let wrong_eval = expected_eval + F::ONE;
            statement.add_constraint(z_field, wrong_eval);
            prop_assert!(!statement.verify(&poly));
        }
    }

    proptest! {
        // `combine_evals` matches the explicit sum Σ_i γ^(i + shift) · s_i.
        #[test]
        fn prop_combine_evals_consistency(
            num_constraints in 1usize..=5,
            s_values in prop::collection::vec(0u32..100, 1..=5),
            challenge in 1u32..50,
            shift in 0usize..3,
        ) {
            let s_values = &s_values[..num_constraints.min(s_values.len())];
            let mut statement = SelectStatement::<F, F>::initialize(2);
            for &s in s_values {
                statement.add_constraint(F::from_u32(1), F::from_u32(s));
            }
            let gamma = F::from_u32(challenge);
            let mut claimed_eval1 = F::ZERO;
            statement.combine_evals(&mut claimed_eval1, gamma, shift);
            let mut claimed_eval2 = F::ZERO;
            for (i, &s) in s_values.iter().enumerate() {
                claimed_eval2 += gamma.exp_u64((i + shift) as u64) * F::from_u32(s);
            }
            prop_assert_eq!(claimed_eval1, claimed_eval2);
        }
    }

    proptest! {
        // `combine_packed` produces exactly the scalar `combine` result once unpacked.
        #[test]
        fn prop_packed_combine_roundtrip(
            k in 4usize..10,
            n in 1usize..12,
            shift in 0usize..5,
            seed in 0u64..100,
        ) {
            type PackedExt = <EF as ExtensionField<F>>::ExtensionPacking;
            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
            if k < k_pack {
                return Ok(());
            }
            let mut rng = SmallRng::seed_from_u64(seed);
            let challenge: EF = rng.random();
            let vars = (0..n).map(|_| rng.random()).collect::<Vec<F>>();
            let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
            let statement = SelectStatement::<F, EF>::new(k, vars, evals);
            let mut scalar_weights = Poly::<EF>::zero(k);
            let mut scalar_sum = EF::ZERO;
            statement.combine(&mut scalar_weights, &mut scalar_sum, challenge, shift);
            let mut packed_weights = Poly::<PackedExt>::zero(k - k_pack);
            let mut packed_sum = EF::ZERO;
            statement.combine_packed(&mut packed_weights, &mut packed_sum, challenge, shift);
            let unpacked =
                <PackedExt as PackedFieldExtension<F, EF>>::to_ext_iter(
                    packed_weights.as_slice().iter().copied(),
                )
                .collect::<Vec<_>>();
            prop_assert_eq!(scalar_weights.as_slice(), &unpacked[..]);
            prop_assert_eq!(scalar_sum, packed_sum);
        }

        // Scalar and packed paths stay in agreement across successive folds with growing shift.
        #[test]
        fn prop_packed_combine_accumulation(
            k in 4usize..10,
            seed in 0u64..50,
        ) {
            type PackedExt = <EF as ExtensionField<F>>::ExtensionPacking;
            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
            if k < k_pack {
                return Ok(());
            }
            let mut rng = SmallRng::seed_from_u64(seed);
            let challenge: EF = rng.random();
            let mut s_wt = Poly::<EF>::zero(k);
            let mut p_wt = Poly::<PackedExt>::zero(k - k_pack);
            let mut s_sum = EF::ZERO;
            let mut p_sum = EF::ZERO;
            let mut shift = 0;
            for n in [3, 7] {
                let vars = (0..n).map(|_| rng.random()).collect::<Vec<F>>();
                let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
                let stmt = SelectStatement::<F, EF>::new(k, vars, evals);
                stmt.combine(&mut s_wt, &mut s_sum, challenge, shift);
                stmt.combine_packed(&mut p_wt, &mut p_sum, challenge, shift);
                shift += stmt.len();
            }
            let unpacked =
                <PackedExt as PackedFieldExtension<F, EF>>::to_ext_iter(
                    p_wt.as_slice().iter().copied(),
                )
                .collect::<Vec<_>>();
            prop_assert_eq!(s_wt.as_slice(), &unpacked[..]);
            prop_assert_eq!(s_sum, p_sum);
        }
    }
}