use p3_challenger::{FieldChallenger, GrindingChallenger};
use p3_field::{ExtensionField, Field, PackedFieldExtension, PackedValue, dot_product};
use p3_multilinear_util::point::Point;
use p3_multilinear_util::poly::Poly;
use p3_util::log2_strict_usize;
use tracing::instrument;
use crate::constraints::Constraint;
use crate::strategy::VariableOrder;
use crate::{SumcheckData, extrapolate_01inf};
/// Backing storage for the evaluation and weight tables of a [`ProductPolynomial`].
///
/// Both variants always hold tables of the same shape; which variant is active is an
/// implementation detail of [`ProductPolynomial`].
#[derive(Debug, Clone)]
enum MaybePacked<F: Field, EF: ExtensionField<F>> {
    /// SIMD-packed tables: every entry packs `F::Packing::WIDTH` consecutive scalar entries
    /// into one `EF::ExtensionPacking` value.
    Packed {
        /// Packed evaluations.
        evals: Poly<EF::ExtensionPacking>,
        /// Packed weights, laid out exactly like `evals`.
        weights: Poly<EF::ExtensionPacking>,
    },
    /// Scalar tables with one extension-field element per hypercube point. A `Packed` value
    /// is converted into this variant once only a single packed element remains.
    Unpacked {
        /// Scalar evaluations.
        evals: Poly<EF>,
        /// Scalar weights, laid out exactly like `evals`.
        weights: Poly<EF>,
    },
}
/// A pair of multilinear tables over the Boolean hypercube: evaluations `f` and weights `w`.
///
/// The quantity driven through sumcheck is `sum_x f(x) * w(x)` (see
/// [`ProductPolynomial::dot_product`]). Every round fixes one variable in both tables at once;
/// which variable is folded is decided by the stored [`VariableOrder`]. The tables may be kept
/// SIMD-packed or unpacked; a packed polynomial switches to unpacked storage on its own once
/// its packed tables shrink to a single element.
#[derive(Debug, Clone)]
pub struct ProductPolynomial<F: Field, EF: ExtensionField<F>> {
    /// Evaluation and weight tables, packed or unpacked.
    inner: MaybePacked<F, EF>,
    /// Order in which variables are fixed during folding.
    order: VariableOrder,
}
impl<F: Field, EF: ExtensionField<F>> ProductPolynomial<F, EF> {
pub fn new_packed(
order: VariableOrder,
evals: Poly<EF::ExtensionPacking>,
weights: Poly<EF::ExtensionPacking>,
) -> Self {
assert_eq!(evals.num_variables(), weights.num_variables());
let mut poly = Self {
inner: MaybePacked::Packed { evals, weights },
order,
};
poly.transition();
poly
}
pub const fn new_unpacked(order: VariableOrder, evals: Poly<EF>, weights: Poly<EF>) -> Self {
Self {
inner: MaybePacked::Unpacked { evals, weights },
order,
}
}
pub const fn order(&self) -> VariableOrder {
self.order
}
pub fn num_variables(&self) -> usize {
match &self.inner {
MaybePacked::Packed { evals, weights } => {
let k = evals.num_variables();
assert_eq!(k, weights.num_variables());
k + log2_strict_usize(F::Packing::WIDTH)
}
MaybePacked::Unpacked { evals, weights } => {
let k = evals.num_variables();
assert_eq!(k, weights.num_variables());
k
}
}
}
pub fn eval(&self, point: &Point<EF>) -> EF {
match &self.inner {
MaybePacked::Packed { evals, .. } => evals.eval_packed(point),
MaybePacked::Unpacked { evals, .. } => evals.eval_ext::<F>(point),
}
}
fn compress(&mut self, r: EF) {
let order = self.order;
match &mut self.inner {
MaybePacked::Packed { evals, weights } => {
order.fix_var(evals, r);
order.fix_var(weights, r);
}
MaybePacked::Unpacked { evals, weights } => {
order.fix_var(evals, r);
order.fix_var(weights, r);
}
}
}
fn transition(&mut self) {
if let MaybePacked::Packed { evals, weights } = &mut self.inner {
let k = evals.num_variables();
assert_eq!(k, weights.num_variables());
if k == 0 {
let evals =
EF::ExtensionPacking::to_ext_iter(evals.as_slice().iter().copied()).collect();
let weights =
EF::ExtensionPacking::to_ext_iter(weights.as_slice().iter().copied()).collect();
*self = Self::new_unpacked(self.order, Poly::new(evals), Poly::new(weights));
}
}
}
#[instrument(skip_all)]
pub fn round<Challenger>(
&mut self,
sumcheck_data: &mut SumcheckData<F, EF>,
challenger: &mut Challenger,
sum: &mut EF,
pow_bits: usize,
) -> EF
where
Challenger: FieldChallenger<F> + GrindingChallenger<Witness = F>,
{
let order = self.order;
let (c0, c_inf) = match &self.inner {
MaybePacked::Packed { evals, weights } => {
let (c0, c_inf) = order.sumcheck_coefficients(evals.as_slice(), weights.as_slice());
(
EF::ExtensionPacking::to_ext_iter([c0]).sum(),
EF::ExtensionPacking::to_ext_iter([c_inf]).sum(),
)
}
MaybePacked::Unpacked { evals, weights } => {
order.sumcheck_coefficients(evals.as_slice(), weights.as_slice())
}
};
let r = sumcheck_data.observe_and_sample(challenger, c0, c_inf, pow_bits);
self.compress(r);
*sum = extrapolate_01inf(c0, *sum - c0, c_inf, r);
debug_assert_eq!(*sum, self.dot_product());
self.transition();
r
}
pub(crate) fn round_coefficients(&self) -> (EF, EF) {
let order = self.order;
match &self.inner {
MaybePacked::Packed { evals, weights } => {
let (c0, c_inf) = order.sumcheck_coefficients(evals.as_slice(), weights.as_slice());
(
EF::ExtensionPacking::to_ext_iter([c0]).sum(),
EF::ExtensionPacking::to_ext_iter([c_inf]).sum(),
)
}
MaybePacked::Unpacked { evals, weights } => {
order.sumcheck_coefficients(evals.as_slice(), weights.as_slice())
}
}
}
pub(crate) fn fold_round(&mut self, r: EF) {
self.compress(r);
self.transition();
}
pub(crate) fn scale_weights(&mut self, scale: EF) {
match &mut self.inner {
MaybePacked::Packed { weights, .. } => {
for value in weights.as_mut_slice() {
*value *= scale;
}
}
MaybePacked::Unpacked { weights, .. } => {
for value in weights.as_mut_slice() {
*value *= scale;
}
}
}
}
pub(crate) fn accumulate_weights(&mut self, delta: &[EF]) {
assert_eq!(delta.len(), 1 << self.num_variables());
match &mut self.inner {
MaybePacked::Packed { weights, .. } => {
let width = F::Packing::WIDTH;
for (value, chunk) in weights.as_mut_slice().iter_mut().zip(delta.chunks(width)) {
*value += EF::ExtensionPacking::from_ext_slice(chunk);
}
}
MaybePacked::Unpacked { weights, .. } => {
for (value, &d) in weights.as_mut_slice().iter_mut().zip(delta) {
*value += d;
}
}
}
}
pub fn evals(&self) -> Poly<EF> {
match &self.inner {
MaybePacked::Packed { evals, .. } => Poly::new(
EF::ExtensionPacking::to_ext_iter(evals.as_slice().iter().copied()).collect(),
),
MaybePacked::Unpacked { evals, .. } => evals.clone(),
}
}
pub fn weights(&self) -> Poly<EF> {
match &self.inner {
MaybePacked::Packed { weights, .. } => Poly::new(
EF::ExtensionPacking::to_ext_iter(weights.as_slice().iter().copied()).collect(),
),
MaybePacked::Unpacked { weights, .. } => weights.clone(),
}
}
pub fn combine(&mut self, sum: &mut EF, constraint: &Constraint<F, EF>) {
match &mut self.inner {
MaybePacked::Packed { weights, .. } => {
constraint.combine_packed(weights, sum);
}
MaybePacked::Unpacked { weights, .. } => {
constraint.combine(weights, sum);
}
}
}
pub fn dot_product(&self) -> EF {
match &self.inner {
MaybePacked::Packed { evals, weights } => {
let sum_packed = dot_product(evals.iter().copied(), weights.iter().copied());
EF::ExtensionPacking::to_ext_iter([sum_packed]).sum()
}
MaybePacked::Unpacked { evals, weights } => {
dot_product(evals.iter().copied(), weights.iter().copied())
}
}
}
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_challenger::DuplexChallenger;
    use p3_field::extension::BinomialExtensionField;
    use p3_field::{Field, PrimeCharacteristicRing};
    use proptest::prelude::*;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::*;
    use crate::strategy::sumcheck_coefficients_prefix;

    type F = BabyBear;
    type EF = BinomialExtensionField<BabyBear, 4>;
    type EP = <EF as ExtensionField<F>>::ExtensionPacking;
    type Perm = Poseidon2BabyBear<16>;
    type TestChallenger = DuplexChallenger<F, Perm, 16, 8>;

    /// Deterministic challenger so that runs are reproducible.
    fn make_challenger() -> TestChallenger {
        let perm = Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(42));
        DuplexChallenger::new(perm)
    }

    /// Packs a scalar table into SIMD-width chunks, one packed entry per chunk of
    /// consecutive scalars (the layout `accumulate_weights` also assumes).
    fn pack(values: &[EF]) -> Poly<EP> {
        Poly::new(
            values
                .chunks(<F as Field>::Packing::WIDTH)
                .map(EP::from_ext_slice)
                .collect(),
        )
    }

    #[test]
    fn test_num_variables_unpacked_variant() {
        let evals = Poly::new(vec![EF::ONE; 8]);
        let weights = Poly::new(vec![EF::TWO; 8]);
        let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        assert_eq!(poly.num_variables(), 3);
    }

    #[test]
    fn test_dot_product_manual_calculation() {
        let e0 = EF::from_u64(1);
        let e1 = EF::from_u64(2);
        let e2 = EF::from_u64(3);
        let e3 = EF::from_u64(4);
        let w0 = EF::from_u64(5);
        let w1 = EF::from_u64(6);
        let w2 = EF::from_u64(7);
        let w3 = EF::from_u64(8);
        let evals = Poly::new(vec![e0, e1, e2, e3]);
        let weights = Poly::new(vec![w0, w1, w2, w3]);
        let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        let expected = e0 * w0 + e1 * w1 + e2 * w2 + e3 * w3;
        assert_eq!(poly.dot_product(), expected);
    }

    #[test]
    fn test_sumcheck_coefficients_manual_calculation() {
        let e0 = EF::from_u64(3);
        let e1 = EF::from_u64(7);
        let w0 = EF::from_u64(2);
        let w1 = EF::from_u64(5);
        let evals = Poly::new(vec![e0, e1]);
        let weights = Poly::new(vec![w0, w1]);
        let (h0, h_inf) = sumcheck_coefficients_prefix(evals.as_slice(), weights.as_slice());
        let expected_h0 = e0 * w0;
        assert_eq!(h0, expected_h0);
        // Leading coefficient of (e0 + X(e1 - e0)) * (w0 + X(w1 - w0)).
        let expected_h_inf = (e1 - e0) * (w1 - w0);
        assert_eq!(h_inf, expected_h_inf);
        let h_1 = e1 * w1;
        let sum = e0 * w0 + e1 * w1;
        assert_eq!(h0 + h_1, sum);
    }

    #[test]
    fn test_compress_manual_calculation() {
        let e0 = EF::from_u64(1);
        let e1 = EF::from_u64(2);
        let e2 = EF::from_u64(5);
        let e3 = EF::from_u64(8);
        let w0 = EF::from_u64(3);
        let w1 = EF::from_u64(4);
        let w2 = EF::from_u64(6);
        let w3 = EF::from_u64(7);
        let evals = Poly::new(vec![e0, e1, e2, e3]);
        let weights = Poly::new(vec![w0, w1, w2, w3]);
        let mut poly =
            ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        let initial_sum = e0 * w0 + e1 * w1 + e2 * w2 + e3 * w3;
        assert_eq!(poly.dot_product(), initial_sum);
        let r = EF::from_u64(2);
        poly.compress(r);
        // Prefix order folds the most significant index bit: pairs (e0, e2) and (e1, e3).
        let folded_evals = poly.evals();
        let expected_e0 = e0 + r * (e2 - e0);
        let expected_e1 = e1 + r * (e3 - e1);
        assert_eq!(folded_evals.as_slice(), &[expected_e0, expected_e1]);
        // h(X) = h_0 + b X + a X^2, interpolated from h(0), h(1) and the leading coefficient.
        let h_0 = e0 * w0 + e1 * w1;
        let a = (e2 - e0) * (w2 - w0) + (e3 - e1) * (w3 - w1);
        let h_1 = e2 * w2 + e3 * w3;
        let b = h_1 - h_0 - a;
        let h_r = h_0 + b * r + a * r.square();
        assert_eq!(poly.dot_product(), h_r);
    }

    #[test]
    fn test_eval_multilinear_interpolation() {
        let e0 = EF::from_u64(2);
        let e1 = EF::from_u64(5);
        let e2 = EF::from_u64(3);
        let e3 = EF::from_u64(11);
        let evals = Poly::new(vec![e0, e1, e2, e3]);
        let weights = Poly::new(vec![EF::ONE; 4]);
        let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        let x0 = EF::from_u64(3);
        let x1 = EF::from_u64(4);
        let point = Point::new(vec![x0, x1]);
        let one = EF::ONE;
        let expected = e0 * (one - x0) * (one - x1)
            + e1 * (one - x0) * x1
            + e2 * x0 * (one - x1)
            + e3 * x0 * x1;
        assert_eq!(poly.eval(&point), expected);
    }

    #[test]
    fn test_transition_from_packed_to_unpacked() {
        let simd_width = <F as Field>::Packing::WIDTH;
        let simd_log = log2_strict_usize(simd_width);
        let num_variables = simd_log + 2;
        let num_evals = 1 << num_variables;
        let evals_scalar = vec![EF::ONE; num_evals];
        let weights_scalar = vec![EF::ONE; num_evals];
        let mut poly = ProductPolynomial::<F, EF>::new_packed(
            VariableOrder::Prefix,
            pack(&evals_scalar),
            pack(&weights_scalar),
        );
        match &poly.inner {
            MaybePacked::Packed {
                evals: packed_evals,
                weights: packed_weights,
            } => {
                let expected_packed_len = num_evals / simd_width;
                assert_eq!(packed_evals.num_evals(), expected_packed_len);
                assert_eq!(packed_weights.num_evals(), expected_packed_len);
            }
            MaybePacked::Unpacked { .. } => {
                panic!("Expected Packed variant initially");
            }
        }
        assert_eq!(poly.num_variables(), num_variables);
        // Two folds consume both packed-level variables, leaving a single packed element.
        for _ in 0..2 {
            let challenge = EF::from_u64(7);
            poly.compress(challenge);
            poly.transition();
        }
        match &poly.inner {
            MaybePacked::Unpacked { evals, weights } => {
                assert_eq!(evals.num_evals(), simd_width);
                assert_eq!(weights.num_evals(), simd_width);
            }
            MaybePacked::Packed { .. } => {
                panic!("Expected Unpacked variant after transition");
            }
        }
        assert_eq!(poly.num_variables(), simd_log);
    }

    #[test]
    fn test_new_packed_with_single_element_transitions() {
        let simd_width = <F as Field>::Packing::WIDTH;
        let evals_scalar: Vec<EF> = (0..simd_width).map(|i| EF::from_u64(i as u64)).collect();
        let weights_scalar: Vec<EF> = (0..simd_width)
            .map(|i| EF::from_u64(100 + i as u64))
            .collect();
        let evals = Poly::new(vec![EP::from_ext_slice(&evals_scalar)]);
        let weights = Poly::new(vec![EP::from_ext_slice(&weights_scalar)]);
        let poly = ProductPolynomial::<F, EF>::new_packed(VariableOrder::Prefix, evals, weights);
        match &poly.inner {
            MaybePacked::Unpacked {
                evals: unpacked_evals,
                weights: unpacked_weights,
            } => {
                assert_eq!(unpacked_evals.as_slice(), &evals_scalar);
                assert_eq!(unpacked_weights.as_slice(), &weights_scalar);
            }
            MaybePacked::Packed { .. } => {
                panic!("Expected Unpacked variant after transition from single packed element");
            }
        }
        assert_eq!(poly.num_variables(), log2_strict_usize(simd_width));
    }

    #[test]
    fn test_round_updates_sum_correctly() {
        let e0 = EF::from_u64(2);
        let e1 = EF::from_u64(5);
        let e2 = EF::from_u64(3);
        let e3 = EF::from_u64(7);
        let w0 = EF::from_u64(1);
        let w1 = EF::from_u64(4);
        let w2 = EF::from_u64(2);
        let w3 = EF::from_u64(6);
        let evals = Poly::new(vec![e0, e1, e2, e3]);
        let weights = Poly::new(vec![w0, w1, w2, w3]);
        let mut poly =
            ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        let mut sum = e0 * w0 + e1 * w1 + e2 * w2 + e3 * w3;
        assert_eq!(poly.dot_product(), sum);
        let mut sumcheck_data = SumcheckData::default();
        let mut challenger = make_challenger();
        let _r = poly.round(&mut sumcheck_data, &mut challenger, &mut sum, 0);
        assert_eq!(poly.dot_product(), sum);
        assert!(!sumcheck_data.polynomial_evaluations.is_empty());
    }

    #[test]
    fn test_round_multiple_rounds() {
        let mut rng = SmallRng::seed_from_u64(123);
        let num_variables = 4;
        let num_evals = 1 << num_variables;
        let evals: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let weights: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let mut poly = ProductPolynomial::<F, EF>::new_unpacked(
            VariableOrder::Prefix,
            Poly::new(evals),
            Poly::new(weights),
        );
        let mut sum = poly.dot_product();
        let mut sumcheck_data = SumcheckData::default();
        let mut challenger = make_challenger();
        for expected_vars in (1..=num_variables).rev() {
            assert_eq!(poly.num_variables(), expected_vars);
            let _ = poly.round(&mut sumcheck_data, &mut challenger, &mut sum, 0);
            assert_eq!(poly.dot_product(), sum);
        }
        assert_eq!(poly.num_variables(), 0);
    }

    /// Running every round on packed and unpacked copies of the same tables must produce
    /// identical coefficients, challenges, sums and final tables.
    #[test]
    fn test_round_packed_matches_unpacked() {
        let simd_width = <F as Field>::Packing::WIDTH;
        let num_variables = log2_strict_usize(simd_width) + 2;
        let num_evals = 1 << num_variables;
        let mut rng = SmallRng::seed_from_u64(789);
        let evals: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let weights: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let mut packed = ProductPolynomial::<F, EF>::new_packed(
            VariableOrder::Prefix,
            pack(&evals),
            pack(&weights),
        );
        let mut unpacked = ProductPolynomial::<F, EF>::new_unpacked(
            VariableOrder::Prefix,
            Poly::new(evals),
            Poly::new(weights),
        );
        let mut packed_sum = packed.dot_product();
        let mut unpacked_sum = unpacked.dot_product();
        assert_eq!(packed_sum, unpacked_sum);
        let mut packed_data = SumcheckData::default();
        let mut unpacked_data = SumcheckData::default();
        let mut packed_challenger = make_challenger();
        let mut unpacked_challenger = make_challenger();
        for _ in 0..num_variables {
            assert_eq!(packed.round_coefficients(), unpacked.round_coefficients());
            let r_packed =
                packed.round(&mut packed_data, &mut packed_challenger, &mut packed_sum, 0);
            let r_unpacked = unpacked.round(
                &mut unpacked_data,
                &mut unpacked_challenger,
                &mut unpacked_sum,
                0,
            );
            assert_eq!(r_packed, r_unpacked);
            assert_eq!(packed_sum, unpacked_sum);
            assert_eq!(packed.num_variables(), unpacked.num_variables());
        }
        assert_eq!(packed.num_variables(), 0);
        assert_eq!(packed.evals().as_slice(), unpacked.evals().as_slice());
        assert_eq!(packed.weights().as_slice(), unpacked.weights().as_slice());
    }

    #[test]
    fn test_dot_product_packed_matches_scalar() {
        let simd_width = <F as Field>::Packing::WIDTH;
        let num_variables = log2_strict_usize(simd_width) + 1;
        let num_evals = 1 << num_variables;
        let mut rng = SmallRng::seed_from_u64(456);
        let evals_scalar: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let weights_scalar: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let expected: EF = evals_scalar
            .iter()
            .zip(weights_scalar.iter())
            .map(|(&e, &w)| e * w)
            .sum();
        let unpacked_poly = ProductPolynomial::<F, EF>::new_unpacked(
            VariableOrder::Prefix,
            Poly::new(evals_scalar.clone()),
            Poly::new(weights_scalar.clone()),
        );
        assert_eq!(unpacked_poly.dot_product(), expected);
        let packed_poly = ProductPolynomial::<F, EF>::new_packed(
            VariableOrder::Prefix,
            pack(&evals_scalar),
            pack(&weights_scalar),
        );
        assert_eq!(packed_poly.dot_product(), expected);
    }

    /// `scale_weights` and `accumulate_weights` must agree between packed and unpacked storage.
    #[test]
    fn test_weight_updates_packed_match_unpacked() {
        let simd_width = <F as Field>::Packing::WIDTH;
        let num_variables = log2_strict_usize(simd_width) + 2;
        let num_evals = 1 << num_variables;
        let mut rng = SmallRng::seed_from_u64(1011);
        let evals: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let weights: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let delta: Vec<EF> = (0..num_evals).map(|_| EF::from_u64(rng.random())).collect();
        let mut packed = ProductPolynomial::<F, EF>::new_packed(
            VariableOrder::Prefix,
            pack(&evals),
            pack(&weights),
        );
        let mut unpacked = ProductPolynomial::<F, EF>::new_unpacked(
            VariableOrder::Prefix,
            Poly::new(evals),
            Poly::new(weights),
        );
        let scale = EF::from_u64(3);
        packed.scale_weights(scale);
        unpacked.scale_weights(scale);
        packed.accumulate_weights(&delta);
        unpacked.accumulate_weights(&delta);
        assert_eq!(packed.weights().as_slice(), unpacked.weights().as_slice());
        assert_eq!(packed.dot_product(), unpacked.dot_product());
    }

    #[test]
    fn test_evals_extraction() {
        let e0 = EF::from_u64(10);
        let e1 = EF::from_u64(20);
        let e2 = EF::from_u64(30);
        let e3 = EF::from_u64(40);
        let evals = Poly::new(vec![e0, e1, e2, e3]);
        let weights = Poly::new(vec![EF::ONE; 4]);
        let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        let extracted = poly.evals();
        assert_eq!(extracted.as_slice(), &[e0, e1, e2, e3]);
    }

    #[test]
    fn test_combine_updates_weights_and_sum() {
        use crate::constraints::statement::EqStatement;
        let num_variables = 2;
        let evals = Poly::new(vec![EF::ONE; 4]);
        let weights = Poly::new(vec![EF::ONE; 4]);
        let mut poly =
            ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals.clone(), weights);
        let initial_dot = poly.dot_product();
        assert_eq!(initial_dot, EF::from_u64(4));
        let mut eq_statement = EqStatement::initialize(num_variables);
        let point = Point::new(vec![EF::from_u64(2), EF::from_u64(3)]);
        let eval = evals.eval_ext::<F>(&point);
        eq_statement.add_evaluated_constraint(point, eval);
        let challenge = EF::from_u64(7);
        let constraint = Constraint::<F, EF>::new_eq_only(challenge, eq_statement);
        let mut sum = poly.dot_product();
        poly.combine(&mut sum, &constraint);
        assert_eq!(poly.dot_product(), sum);
    }

    #[test]
    fn test_eval_at_boolean_points() {
        let e00 = EF::from_u64(1);
        let e01 = EF::from_u64(2);
        let e10 = EF::from_u64(3);
        let e11 = EF::from_u64(4);
        let evals = Poly::new(vec![e00, e01, e10, e11]);
        let weights = Poly::new(vec![EF::ONE; 4]);
        let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix, evals, weights);
        let point_00 = Point::new(vec![EF::ZERO, EF::ZERO]);
        assert_eq!(poly.eval(&point_00), e00);
        let point_01 = Point::new(vec![EF::ZERO, EF::ONE]);
        assert_eq!(poly.eval(&point_01), e01);
        let point_10 = Point::new(vec![EF::ONE, EF::ZERO]);
        assert_eq!(poly.eval(&point_10), e10);
        let point_11 = Point::new(vec![EF::ONE, EF::ONE]);
        assert_eq!(poly.eval(&point_11), e11);
    }

    proptest! {
        #[test]
        fn prop_dot_product_consistency(seed in 0u64..1000) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let num_variables = 3;
            let num_evals = 1 << num_variables;
            let evals: Vec<EF> = (0..num_evals)
                .map(|_| EF::from_u64(u64::from(rng.random::<u32>())))
                .collect();
            let weights: Vec<EF> = (0..num_evals)
                .map(|_| EF::from_u64(u64::from(rng.random::<u32>())))
                .collect();
            let poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix,
                Poly::new(evals.clone()),
                Poly::new(weights.clone()),
            );
            let expected: EF = evals
                .iter()
                .zip(weights.iter())
                .map(|(&e, &w)| e * w)
                .sum();
            prop_assert_eq!(poly.dot_product(), expected);
        }

        #[test]
        fn prop_compress_maintains_invariant(seed in 0u64..1000, challenge_val in 1u64..100) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let num_variables = 3;
            let num_evals = 1 << num_variables;
            let evals: Vec<EF> = (0..num_evals)
                .map(|_| EF::from_u64(u64::from(rng.random::<u32>())))
                .collect();
            let weights: Vec<EF> = (0..num_evals)
                .map(|_| EF::from_u64(u64::from(rng.random::<u32>())))
                .collect();
            let mut poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix,
                Poly::new(evals),
                Poly::new(weights),
            );
            let (c0, c_inf) = match &poly.inner {
                MaybePacked::Unpacked {
                    evals: unpacked_evals,
                    weights: unpacked_weights,
                } => sumcheck_coefficients_prefix(unpacked_evals.as_slice(), unpacked_weights.as_slice()),
                MaybePacked::Packed { .. } => unreachable!(),
            };
            let initial_sum = poly.dot_product();
            // h(1) follows from the claimed sum h(0) + h(1).
            let c1 = initial_sum - c0;
            let r = EF::from_u64(challenge_val);
            poly.compress(r);
            let h_r = extrapolate_01inf(c0, c1, c_inf, r);
            prop_assert_eq!(poly.dot_product(), h_r);
        }

        #[test]
        fn prop_round_maintains_invariant(seed in 0u64..1000) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let num_variables = 4;
            let num_evals = 1 << num_variables;
            let evals: Vec<EF> = (0..num_evals)
                .map(|_| EF::from_u64(u64::from(rng.random::<u32>())))
                .collect();
            let weights: Vec<EF> = (0..num_evals)
                .map(|_| EF::from_u64(u64::from(rng.random::<u32>())))
                .collect();
            let mut poly = ProductPolynomial::<F, EF>::new_unpacked(VariableOrder::Prefix,
                Poly::new(evals),
                Poly::new(weights),
            );
            let mut sum = poly.dot_product();
            let mut sumcheck_data = SumcheckData::default();
            let perm = Perm::new_from_rng_128(&mut SmallRng::seed_from_u64(seed + 1000));
            let mut challenger: TestChallenger = DuplexChallenger::new(perm);
            let _ = poly.round(&mut sumcheck_data, &mut challenger, &mut sum, 0);
            prop_assert_eq!(poly.dot_product(), sum);
        }
    }
}