use alloc::vec::Vec;
use itertools::Itertools;
use p3_field::{
ExtensionField, Field, PackedFieldExtension, PackedValue, PrimeCharacteristicRing, dot_product,
};
use p3_matrix::Matrix;
use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixView};
use p3_maybe_rayon::prelude::*;
use p3_multilinear_util::eq_batch::eval_eq_batch;
use p3_multilinear_util::point::Point;
use p3_multilinear_util::poly::Poly;
use p3_util::log2_strict_usize;
use tracing::instrument;
/// Builds the challenge-weighted eq tables of a batch of points, one column per point.
///
/// `points` has one row per variable and one column per point. The returned matrix has
/// `2^k` rows (`k = points.height()`) and the same width as `points`. Entry `(b, j)` equals
/// `alpha^j * prod_i (b_i ? p_j[i] : 1 - p_j[i])`, where `b_i` is bit `i` of the row index
/// `b` and `p_j[i]` is the entry at row `i`, column `j` of `points`.
///
/// # Panics
/// Panics if `points` has zero columns.
fn batch_eqs<F: Field, EF: ExtensionField<F>>(
    points: RowMajorMatrixView<'_, EF>,
    alpha: EF,
) -> RowMajorMatrix<EF> {
    let num_vars = points.height();
    let width = points.width();
    assert_ne!(width, 0);

    let mut table = RowMajorMatrix::new(EF::zero_vec(width * (1 << num_vars)), width);

    // Row 0 starts out as the batching coefficients alpha^0, ..., alpha^{width-1}.
    let coeffs = alpha.powers().collect_n(width);
    table.row_mut(0).copy_from_slice(&coeffs);

    // Variable `i` doubles the populated prefix. Each row `r < 2^i` is split into a
    // `(1 - x)` part (kept in place) and an `x` part (written to row `r + 2^i`).
    for (i, coords) in points.row_slices().enumerate() {
        let (mut low, mut high) = table.split_rows_mut(1 << i);
        for (low_row, high_row) in low.rows_mut().zip(high.rows_mut()) {
            let cells = low_row.iter_mut().zip(high_row.iter_mut());
            for ((lo, hi), &x) in cells.zip(coords.iter()) {
                *hi = *lo * x;
                *lo -= *hi;
            }
        }
    }

    table
}
/// Builds the (unweighted) eq tables of a batch of points in packed form, one column per point.
///
/// `points` has one row per variable and one column per point. Let `W = F::Packing::WIDTH`
/// and `w = log2(W)`:
/// - The first `w` variables are expanded inside each packed element.
/// - The remaining `k - w` variables index the `2^(k - w)` rows. Row bit `i` corresponds
///   to variable `w + i`, selecting `x` when the bit is set and `1 - x` when it is clear.
///
/// # Panics
/// Panics if:
/// - `points` has zero columns;
/// - `W` is rejected by `log2_strict_usize`;
/// - there are fewer than `w` variables.
fn packed_batch_eqs<F: Field, EF: ExtensionField<F>>(
    points: RowMajorMatrixView<'_, EF>,
) -> RowMajorMatrix<EF::ExtensionPacking> {
    let num_vars = points.height();
    let width = points.width();
    assert_ne!(width, 0);
    let log_lanes = log2_strict_usize(F::Packing::WIDTH);
    assert!(num_vars >= log_lanes);

    let (lane_vars, row_vars) = points.split_rows(log_lanes);
    let num_rows = 1 << (num_vars - log_lanes);
    let mut table =
        RowMajorMatrix::new(EF::ExtensionPacking::zero_vec(width * num_rows), width);

    // Seed row 0 with each point's eq table over the in-lane variables.
    if log_lanes == 0 {
        // A single lane: the eq table over zero variables is the constant 1.
        table.row_mut(0).fill(EF::ExtensionPacking::ONE);
    } else {
        // Transposing yields one row of `log_lanes` coordinates per point. The coordinates
        // are reversed before expansion. Presumably this makes the first in-lane variable
        // the lowest lane bit, matching the row-bit order used below.
        // TODO confirm against `Poly::new_from_point`'s ordering.
        let per_point = lane_vars.transpose();
        for (coords, slot) in per_point.row_slices().zip(table.values.iter_mut()) {
            let reversed: Vec<EF> = coords.iter().rev().copied().collect();
            let lanes = Poly::new_from_point(&reversed, EF::ONE);
            *slot = EF::ExtensionPacking::from_ext_slice(lanes.as_slice());
        }
    }

    // Same doubling scheme as `batch_eqs`, applied to the out-of-lane variables.
    for (i, coords) in row_vars.row_slices().enumerate() {
        let (mut low, mut high) = table.split_rows_mut(1 << i);
        for (low_row, high_row) in low.rows_mut().zip(high.rows_mut()) {
            let cells = low_row.iter_mut().zip(high_row.iter_mut());
            for ((lo, hi), &x) in cells.zip(coords.iter()) {
                *hi = *lo * x;
                *lo -= *hi;
            }
        }
    }

    table
}
/// A batch of equality constraints `poly(points[i]) == evaluations[i]` over polynomials
/// in `num_variables` variables.
///
/// `points[i]` and `evaluations[i]` belong together, so the two vectors are expected to
/// have equal length. The methods `len`/`is_empty` check this invariant with
/// `debug_assert!`.
///
/// NOTE(review): both vectors are `pub`, so code outside this type can break the
/// equal-length invariant. Prefer `add_evaluated_constraint` / `concatenate`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EqStatement<F> {
// Number of variables every constraint point must have (enforced on insertion).
num_variables: usize,
// Constraint points, in insertion order.
pub points: Vec<Point<F>>,
// Claimed evaluations, one per entry of `points`.
pub evaluations: Vec<F>,
}
impl<F: Field> EqStatement<F> {
    /// Creates an empty statement whose constraints live over `num_variables` variables.
    #[must_use]
    pub const fn initialize(num_variables: usize) -> Self {
        Self {
            num_variables,
            points: Vec::new(),
            evaluations: Vec::new(),
        }
    }

    /// Returns the number of variables every constraint point must have.
    #[must_use]
    pub const fn num_variables(&self) -> usize {
        self.num_variables
    }

    /// Returns `true` if the statement holds no constraints.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        debug_assert!(self.points.is_empty() == self.evaluations.is_empty());
        self.points.is_empty()
    }

    /// Iterates over the `(point, claimed evaluation)` pairs in insertion order.
    pub fn iter(&self) -> impl Iterator<Item = (&Point<F>, &F)> {
        self.points.iter().zip(self.evaluations.iter())
    }

    /// Returns the number of constraints.
    #[must_use]
    pub const fn len(&self) -> usize {
        debug_assert!(self.points.len() == self.evaluations.len());
        self.points.len()
    }

    /// Returns `true` iff `poly` evaluates to the claimed value at every constraint point.
    ///
    /// An empty statement is vacuously satisfied.
    #[must_use]
    pub fn verify(&self, poly: &Poly<F>) -> bool {
        self.iter()
            .all(|(point, &expected_eval)| poly.eval_base(point) == expected_eval)
    }

    /// Appends every constraint of `other` to `self`, keeping `other`'s order.
    ///
    /// # Panics
    /// Panics if the two statements have different numbers of variables.
    pub fn concatenate(&mut self, other: &Self) {
        assert_eq!(self.num_variables, other.num_variables);
        self.points.extend_from_slice(&other.points);
        self.evaluations.extend_from_slice(&other.evaluations);
    }

    /// Records the constraint `poly(point) == eval`.
    ///
    /// # Panics
    /// Panics if `point` does not have exactly `self.num_variables()` coordinates.
    pub fn add_evaluated_constraint(&mut self, point: Point<F>, eval: F) {
        assert_eq!(point.num_variables(), self.num_variables());
        self.points.push(point);
        self.evaluations.push(eval);
    }

    /// Batches all constraints into one weight table and one claimed sum, using successive
    /// powers of `challenge`.
    ///
    /// Constraint `i` is weighted by `challenge^i`:
    /// - Its eq table over the hypercube (the table `Poly::new_from_point(point,
    ///   challenge^i)` produces) is combined into `acc_weights`.
    /// - Its claimed evaluation, times `challenge^i`, is added to `acc_sum`.
    ///
    /// With `INITIALIZED = true` the weights are added to the current contents of
    /// `acc_weights`. With `INITIALIZED = false` they replace those contents. `acc_sum` is
    /// always accumulated into.
    ///
    /// An empty statement contributes nothing. In overwrite mode (`INITIALIZED = false`)
    /// `acc_weights` is therefore cleared to zero rather than left holding stale values.
    #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
    pub fn combine_hypercube<Base, const INITIALIZED: bool>(
        &self,
        acc_weights: &mut Poly<F>,
        acc_sum: &mut F,
        challenge: F,
    ) where
        Base: Field,
        F: ExtensionField<Base>,
    {
        if self.points.is_empty() {
            if !INITIALIZED {
                // Overwrite semantics with zero terms: the batched weights are all zero.
                acc_weights.as_mut_slice().fill(F::ZERO);
            }
            return;
        }
        let challenges = challenge.powers().collect_n(self.len());
        let points_matrix = Point::transpose(&self.points, false);
        eval_eq_batch::<Base, F, INITIALIZED>(
            points_matrix.as_view(),
            acc_weights.as_mut_slice(),
            &challenges,
        );
        // Reuse the already-materialized powers rather than recomputing them.
        *acc_sum +=
            dot_product::<F, _, _>(self.evaluations.iter().copied(), challenges.into_iter());
    }

    /// Packed variant of [`Self::combine_hypercube`].
    ///
    /// Let `W = Base::Packing::WIDTH` and `w = log2(W)`. `weights` holds
    /// `2^(num_variables - w)` packed elements. Each one packs `W` consecutive hypercube
    /// entries of the batched weight table. `sum` is always accumulated into, and
    /// `INITIALIZED` selects add (`true`) or overwrite (`false`) for `weights`.
    ///
    /// Two strategies:
    /// - Small `k` (`2*w > k`): each point's eq table is expanded directly and packed
    ///   chunk by chunk.
    /// - Otherwise: the variables are split in half. One half becomes a packed, unweighted
    ///   eq table, the other a challenge-weighted scalar table. They are combined with one
    ///   dot product per output element.
    ///
    /// An empty statement contributes nothing. In overwrite mode `weights` is cleared to zero.
    ///
    /// # Panics
    /// For a non-empty statement, panics if either holds:
    /// - `num_variables < w`;
    /// - `weights.num_variables() + w != num_variables`.
    #[instrument(skip_all, fields(num_constraints = self.len(), num_variables = self.num_variables()))]
    pub fn combine_hypercube_packed<Base, const INITIALIZED: bool>(
        &self,
        weights: &mut Poly<F::ExtensionPacking>,
        sum: &mut F,
        challenge: F,
    ) where
        Base: Field,
        F: ExtensionField<Base>,
    {
        if self.points.is_empty() {
            if !INITIALIZED {
                // Overwrite semantics with zero terms: the batched weights are all zero.
                weights.as_mut_slice().fill(F::ExtensionPacking::ZERO);
            }
            return;
        }
        let k = self.num_variables();
        let k_pack = log2_strict_usize(Base::Packing::WIDTH);
        assert!(k >= k_pack);
        assert_eq!(weights.num_variables() + k_pack, k);
        self.combine_evals(sum, challenge);

        if k_pack * 2 > k {
            // Too few variables for the split below: the packed half needs at least `k_pack`
            // variables. Expand each point's eq table directly instead.
            self.points
                .iter()
                .zip(challenge.powers())
                .enumerate()
                .for_each(|(i, (point, coeff))| {
                    let eq = Poly::new_from_point(point.as_slice(), coeff);
                    weights
                        .as_mut_slice()
                        .iter_mut()
                        .zip_eq(eq.as_slice().chunks(Base::Packing::WIDTH))
                        .for_each(|(out, chunk)| {
                            let packed = F::ExtensionPacking::from_ext_slice(chunk);
                            // Only the first constraint overwrites, and only in overwrite mode.
                            if INITIALIZED || i > 0 {
                                *out += packed;
                            } else {
                                *out = packed;
                            }
                        });
                });
            return;
        }

        // The first k/2 rows of the transposed points feed the packed, unweighted table
        // (`left`). The remaining rows feed the scalar table weighted by powers of
        // `challenge` (`right`).
        let points = Point::transpose(&self.points, true);
        let (left, right) = points.split_rows(k / 2);
        let left = packed_batch_eqs::<Base, F>(left);
        let right = batch_eqs::<Base, F>(right, challenge);

        // Each chunk of `left.height()` outputs pairs with one row of `right`. Every output
        // is the dot product, over all points, of a `left` row with that `right` row.
        weights
            .as_mut_slice()
            .par_chunks_mut(left.height())
            .zip_eq(right.par_row_slices())
            .for_each(|(out, right)| {
                out.iter_mut().zip(left.rows()).for_each(|(out, left)| {
                    if INITIALIZED {
                        *out +=
                            dot_product::<F::ExtensionPacking, _, _>(left, right.iter().copied());
                    } else {
                        *out = dot_product(left, right.iter().copied());
                    }
                });
            });
    }

    /// Adds `sum_i gamma^i * evaluations[i]` to `claimed_eval`.
    pub fn combine_evals(&self, claimed_eval: &mut F, gamma: F) {
        *claimed_eval += dot_product(self.evaluations.iter().copied(), gamma.powers());
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use proptest::prelude::*;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::*;

    impl<F: Field> EqStatement<F> {
        /// Test-only constructor from parallel point / evaluation lists.
        ///
        /// # Panics
        /// Panics if any of the following holds:
        /// - the lists differ in length;
        /// - `points` is empty;
        /// - the points do not all have the same number of variables.
        #[must_use]
        pub fn new_hypercube(points: Vec<Point<F>>, evaluations: Vec<F>) -> Self {
            assert_eq!(
                points.len(),
                evaluations.len(),
                "Number of points ({}) must match number of evaluations ({})",
                points.len(),
                evaluations.len()
            );
            let num_variables = points
                .iter()
                .map(Point::num_variables)
                .all_equal_value()
                .expect("points must be non-empty and share the same number of variables");
            Self {
                num_variables,
                points,
                evaluations,
            }
        }
    }

    type F = BabyBear;
    type EF = BinomialExtensionField<F, 4>;

    #[test]
    fn test_statement_combine_single_constraint() {
        let mut statement = EqStatement::initialize(1);
        let point = Point::new(vec![F::ONE]);
        let expected_eval = F::from_u64(7);
        statement.add_evaluated_constraint(point.clone(), expected_eval);

        let challenge = F::from_u64(2);
        let mut combined_evals = Poly::zero(statement.num_variables());
        let mut combined_sum = F::ZERO;
        statement.combine_hypercube::<_, false>(&mut combined_evals, &mut combined_sum, challenge);

        // A single constraint is weighted by challenge^0 = 1.
        let expected_combined_evals = Poly::new_from_point(point.as_slice(), F::ONE);
        assert_eq!(combined_evals, expected_combined_evals);
        assert_eq!(combined_sum, expected_eval);
    }

    #[test]
    fn test_statement_with_multiple_constraints() {
        let mut statement = EqStatement::initialize(2);
        let point1 = Point::new(vec![F::ONE, F::ZERO]);
        let eval1 = F::from_u64(5);
        statement.add_evaluated_constraint(point1.clone(), eval1);
        let point2 = Point::new(vec![F::ZERO, F::ONE]);
        let eval2 = F::from_u64(7);
        statement.add_evaluated_constraint(point2.clone(), eval2);

        let challenge = F::from_u64(2);
        let mut combined_evals = Poly::zero(statement.num_variables());
        let mut combined_sum = F::ZERO;
        statement.combine_hypercube::<_, false>(&mut combined_evals, &mut combined_sum, challenge);

        let expected_eq1 = Poly::new_from_point(point1.as_slice(), F::ONE);
        let expected_eq2 = Poly::new_from_point(point2.as_slice(), challenge);
        let expected_combined_evals = Poly::new(
            expected_eq1
                .iter()
                .zip(expected_eq2.iter())
                .map(|(&a, &b)| a + b)
                .collect(),
        );
        let expected_combined_sum = eval1 + challenge * eval2;
        assert_eq!(combined_evals, expected_combined_evals);
        assert_eq!(combined_sum, expected_combined_sum);
    }

    #[test]
    fn test_compute_evaluation_weight() {
        // For a single variable, eq(x, y) = x*y + (1 - x)(1 - y). Compare against that
        // closed form instead of against `eq_poly` itself.
        let x = F::from_u64(3);
        let y = F::from_u64(2);
        let point = Point::new(vec![x]);
        let folding_randomness = Point::new(vec![y]);
        let expected = x * y + (F::ONE - x) * (F::ONE - y);
        assert_eq!(point.eq_poly(&folding_randomness), expected);
    }

    #[test]
    fn test_constructors_and_basic_properties() {
        let point = Point::new(vec![F::ONE]);
        let eval = F::from_u64(42);
        let statement = EqStatement::new_hypercube(vec![point], vec![eval]);
        assert_eq!(statement.num_variables(), 1);
        assert_eq!(statement.len(), 1);
        assert!(!statement.is_empty());

        let empty_statement = EqStatement::<F>::initialize(2);
        assert_eq!(empty_statement.num_variables(), 2);
        assert_eq!(empty_statement.len(), 0);
        assert!(empty_statement.is_empty());
    }

    #[test]
    fn test_verify_constraints() {
        let poly = Poly::new(vec![F::from_u64(1), F::from_u64(2)]);
        let mut statement = EqStatement::<F>::initialize(1);
        statement.add_evaluated_constraint(Point::new(vec![F::ZERO]), F::from_u64(1));
        assert!(statement.verify(&poly));
        statement.add_evaluated_constraint(Point::new(vec![F::ONE]), F::from_u64(5));
        assert!(!statement.verify(&poly));
    }

    #[test]
    fn test_concatenate() {
        let mut statement1 = EqStatement::<F>::initialize(1);
        let mut statement2 = EqStatement::<F>::initialize(1);
        statement1.add_evaluated_constraint(Point::new(vec![F::ZERO]), F::from_u64(10));
        statement2.add_evaluated_constraint(Point::new(vec![F::ONE]), F::from_u64(20));
        statement1.concatenate(&statement2);
        assert_eq!(statement1.len(), 2);
    }

    #[test]
    #[should_panic(expected = "assertion `left == right` failed")]
    fn test_concatenate_mismatched_variables() {
        let mut statement1 = EqStatement::<F>::initialize(2);
        let statement2 = EqStatement::<F>::initialize(3);
        statement1.concatenate(&statement2);
    }

    #[test]
    fn test_add_evaluated_constraint() {
        let poly = Poly::new(vec![F::from_u64(1), F::from_u64(2)]);
        let point = Point::new(vec![F::ZERO]);
        let mut statement = EqStatement::<F>::initialize(1);
        let eval = poly.eval_base(&point);
        statement.add_evaluated_constraint(point, eval);
        assert_eq!(statement.len(), 1);
        assert!(statement.verify(&poly));
        assert_eq!(statement.points.len(), 1);
    }

    #[test]
    #[should_panic(expected = "assertion `left == right` failed")]
    fn test_wrong_variable_count() {
        let mut statement = EqStatement::<F>::initialize(1);
        let wrong_point = Point::new(vec![F::ONE, F::ZERO]);
        statement.add_evaluated_constraint(wrong_point, F::from_u64(5));
    }

    #[test]
    fn test_combine_operations() {
        // An empty statement contributes nothing to either the weights or the sum.
        let empty_statement = EqStatement::<F>::initialize(1);
        let mut combined_evals = Poly::zero(empty_statement.num_variables());
        let mut combined_sum = F::ZERO;
        empty_statement.combine_hypercube::<_, false>(
            &mut combined_evals,
            &mut combined_sum,
            F::from_u64(42),
        );
        assert_eq!(combined_sum, F::ZERO);
        assert_eq!(combined_evals, Poly::<F>::zero(1));

        // 3 * 2^0 + 7 * 2^1 = 17.
        let mut statement = EqStatement::<F>::initialize(1);
        statement.add_evaluated_constraint(Point::new(vec![F::ZERO]), F::from_u64(3));
        statement.add_evaluated_constraint(Point::new(vec![F::ONE]), F::from_u64(7));
        let mut claimed_eval = F::ZERO;
        statement.combine_evals(&mut claimed_eval, F::from_u64(2));
        assert_eq!(claimed_eval, F::from_u64(17));
    }

    proptest! {
        #[test]
        fn prop_statement_workflow(
            poly_evals in prop::collection::vec(0u32..100, 16),
            challenge in 1u32..50,
            point_coords in prop::collection::vec(0u32..10, 8),
        ) {
            let poly = Poly::new(poly_evals.into_iter().map(F::from_u32).collect());
            let mut statement = EqStatement::<F>::initialize(4);

            let point1 = Point::new(vec![
                F::from_u32(point_coords[0]), F::from_u32(point_coords[1]),
                F::from_u32(point_coords[2]), F::from_u32(point_coords[3])
            ]);
            let point2 = Point::new(vec![
                F::from_u32(point_coords[4]), F::from_u32(point_coords[5]),
                F::from_u32(point_coords[6]), F::from_u32(point_coords[7])
            ]);
            let eval1 = poly.eval_base(&point1);
            let eval2 = poly.eval_base(&point2);
            statement.add_evaluated_constraint(point1, eval1);
            statement.add_evaluated_constraint(point2, eval2);
            prop_assert!(statement.verify(&poly));

            let gamma = F::from_u32(challenge);
            let mut combined_poly = Poly::zero(statement.num_variables());
            let mut combined_sum = F::ZERO;
            statement.combine_hypercube::<_, false>(&mut combined_poly, &mut combined_sum, gamma);
            prop_assert_eq!(combined_poly.num_variables(), 4);

            let mut claimed_eval = F::ZERO;
            statement.combine_evals(&mut claimed_eval, gamma);
            prop_assert_eq!(combined_sum, claimed_eval);

            let wrong_point = Point::new(vec![F::ZERO, F::ZERO, F::ZERO, F::ZERO]);
            let wrong_eval = F::from_u32(999);
            let actual_eval = poly.eval_base(&wrong_point);
            if wrong_eval != actual_eval {
                statement.add_evaluated_constraint(wrong_point, wrong_eval);
                prop_assert!(!statement.verify(&poly));
            }
        }
    }

    #[test]
    #[should_panic(expected = "Number of points (2) must match number of evaluations (1)")]
    fn test_new_mismatched_lengths() {
        let points = vec![Point::new(vec![F::ONE]), Point::new(vec![F::ZERO])];
        let evaluations = vec![F::from_u64(100)];
        let _ = EqStatement::new_hypercube(points, evaluations);
    }

    proptest! {
        #[test]
        fn prop_packed_combine_roundtrip(
            k in 4usize..10,
            n in 1usize..12,
            seed in 0u64..100,
        ) {
            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
            if k < k_pack {
                return Ok(());
            }
            let mut rng = SmallRng::seed_from_u64(seed);
            let challenge: EF = rng.random();
            let points = (0..n)
                .map(|_| Point::rand(&mut rng, k))
                .collect::<Vec<_>>();
            let evals = (0..n).map(|_| rng.random()).collect::<Vec<EF>>();
            let statement = EqStatement::<EF>::new_hypercube(points, evals);

            let mut scalar_weights = Poly::<EF>::zero(k);
            let mut scalar_sum = EF::ZERO;
            statement.combine_hypercube::<F, false>(
                &mut scalar_weights, &mut scalar_sum, challenge,
            );

            let mut packed_weights =
                Poly::<<EF as ExtensionField<F>>::ExtensionPacking>::zero(k - k_pack);
            let mut packed_sum = EF::ZERO;
            statement.combine_hypercube_packed::<F, false>(
                &mut packed_weights, &mut packed_sum, challenge,
            );

            let unpacked =
                <<EF as ExtensionField<F>>::ExtensionPacking as PackedFieldExtension<F, EF>>::to_ext_iter(
                    packed_weights.as_slice().iter().copied(),
                )
                .collect::<Vec<_>>();
            prop_assert_eq!(scalar_weights.as_slice(), &unpacked[..]);
            prop_assert_eq!(scalar_sum, packed_sum);
        }

        #[test]
        fn prop_packed_combine_accumulation(
            k in 4usize..10,
            seed in 0u64..50,
        ) {
            let k_pack = log2_strict_usize(<F as Field>::Packing::WIDTH);
            if k < k_pack {
                return Ok(());
            }
            let mut rng = SmallRng::seed_from_u64(seed);
            let challenge: EF = rng.random();

            // First batch overwrites the weight buffers.
            let points1 = (0..3)
                .map(|_| Point::rand(&mut rng, k))
                .collect::<Vec<_>>();
            let evals1 = (0..3).map(|_| rng.random()).collect::<Vec<EF>>();
            let stmt1 = EqStatement::<EF>::new_hypercube(points1, evals1);
            let mut s_wt = Poly::<EF>::zero(k);
            let mut s_sum = EF::ZERO;
            stmt1.combine_hypercube::<F, false>(&mut s_wt, &mut s_sum, challenge);
            let mut p_wt =
                Poly::<<EF as ExtensionField<F>>::ExtensionPacking>::zero(k - k_pack);
            let mut p_sum = EF::ZERO;
            stmt1.combine_hypercube_packed::<F, false>(&mut p_wt, &mut p_sum, challenge);

            // Second batch accumulates on top.
            let points2 = (0..5)
                .map(|_| Point::rand(&mut rng, k))
                .collect::<Vec<_>>();
            let evals2 = (0..5).map(|_| rng.random()).collect::<Vec<EF>>();
            let stmt2 = EqStatement::<EF>::new_hypercube(points2, evals2);
            stmt2.combine_hypercube::<F, true>(&mut s_wt, &mut s_sum, challenge);
            stmt2.combine_hypercube_packed::<F, true>(&mut p_wt, &mut p_sum, challenge);

            let unpacked =
                <<EF as ExtensionField<F>>::ExtensionPacking as PackedFieldExtension<F, EF>>::to_ext_iter(
                    p_wt.as_slice().iter().copied(),
                )
                .collect::<Vec<_>>();
            prop_assert_eq!(s_wt.as_slice(), &unpacked[..]);
            prop_assert_eq!(s_sum, p_sum);
        }
    }
}