use super::CompositePolynomialBuilder;
use crate::base::{polynomial::MultilinearExtension, scalar::Scalar};
use alloc::{boxed::Box, vec::Vec};
/// Tag that selects how the terms of a [`SumcheckSubpolynomial`] are handed to
/// the composite polynomial builder in [`SumcheckSubpolynomial::compose`].
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub enum SumcheckSubpolynomialType {
/// Terms are added through the builder's `produce_fr_multiplicand`.
/// NOTE(review): presumably a constraint that must hold at every evaluation
/// point individually — confirm against the builder's documentation.
Identity,
/// Terms are added through the builder's `produce_zerosum_multiplicand`.
/// NOTE(review): presumably a constraint that only has to sum to zero over
/// all evaluation points — confirm against the builder's documentation.
ZeroSum,
}
/// A single term of a [`SumcheckSubpolynomial`]: a scalar coefficient paired
/// with a list of boxed multilinear extensions that may borrow data for `'a`.
///
/// The extensions are referred to as "multiplicands" elsewhere in this file, so
/// the term presumably denotes the coefficient times the product of the
/// extensions — TODO confirm against `CompositePolynomialBuilder`.
pub type SumcheckSubpolynomialTerm<'a, S> = (S, Vec<Box<dyn MultilinearExtension<S> + 'a>>);
/// A list of [`SumcheckSubpolynomialTerm`]s together with the
/// [`SumcheckSubpolynomialType`] that applies to all of them.
#[derive(Debug)]
pub struct SumcheckSubpolynomial<'a, S: Scalar> {
/// The (coefficient, multiplicands) terms, in the order supplied to `new`.
terms: Vec<SumcheckSubpolynomialTerm<'a, S>>,
/// The type tag shared by every term of this subpolynomial.
subpolynomial_type: SumcheckSubpolynomialType,
}
impl<'a, S: Scalar> SumcheckSubpolynomial<'a, S> {
    /// Builds a subpolynomial of the given type from its list of terms.
    ///
    /// The terms are stored as given; no validation or reordering is performed.
    pub fn new(
        subpolynomial_type: SumcheckSubpolynomialType,
        terms: Vec<SumcheckSubpolynomialTerm<'a, S>>,
    ) -> Self {
        Self {
            subpolynomial_type,
            terms,
        }
    }

    /// Feeds every term of this subpolynomial into `composite_polynomial`.
    ///
    /// Each term's coefficient is multiplied by `group_multiplier` before it is
    /// handed to the builder. `Identity` subpolynomials use
    /// `produce_fr_multiplicand`; `ZeroSum` subpolynomials use
    /// `produce_zerosum_multiplicand`.
    pub fn compose(
        &self,
        composite_polynomial: &mut CompositePolynomialBuilder<S>,
        group_multiplier: S,
    ) {
        let kind = self.subpolynomial_type;
        for (coefficient, multiplicands) in &self.terms {
            // Scale once, then dispatch on the subpolynomial type.
            let scaled = *coefficient * group_multiplier;
            match kind {
                SumcheckSubpolynomialType::Identity => {
                    composite_polynomial.produce_fr_multiplicand(&scaled, multiplicands);
                }
                SumcheckSubpolynomialType::ZeroSum => {
                    composite_polynomial.produce_zerosum_multiplicand(&scaled, multiplicands);
                }
            }
        }
    }

    /// Returns the type tag of this subpolynomial.
    pub(crate) fn subpolynomial_type(&self) -> SumcheckSubpolynomialType {
        self.subpolynomial_type
    }

    /// Iterates over the terms with their coefficients scaled by `multiplier`.
    ///
    /// Each item is `(type, multiplier * coefficient, multiplicands)`, where
    /// `type` is this subpolynomial's type and `multiplicands` borrows the
    /// term's list of multilinear extensions. Terms are yielded in stored order.
    pub(crate) fn iter_mul_by(
        &self,
        multiplier: S,
    ) -> impl Iterator<
        Item = (
            SumcheckSubpolynomialType,
            S,
            &Vec<Box<dyn MultilinearExtension<S> + 'a>>,
        ),
    > {
        let kind = self.subpolynomial_type;
        self.terms
            .iter()
            .map(move |(coefficient, multiplicands)| {
                let scaled = multiplier * *coefficient;
                (kind, scaled, multiplicands)
            })
    }
}
#[cfg(test)]
mod tests {
    use super::{SumcheckSubpolynomial, SumcheckSubpolynomialTerm, SumcheckSubpolynomialType};
    use crate::base::scalar::test_scalar::TestScalar;
    use alloc::boxed::Box;

    /// `iter_mul_by` must yield every term in order, tagged with the
    /// subpolynomial's type and with its coefficient scaled by the multiplier.
    #[test]
    fn test_iter_mul_by() {
        let first_mle = vec![TestScalar::from(1), TestScalar::from(2)];
        let second_mle = vec![TestScalar::from(3), TestScalar::from(4)];
        let terms: Vec<SumcheckSubpolynomialTerm<_>> = vec![
            (TestScalar::from(2), vec![Box::new(&first_mle)]),
            (TestScalar::from(3), vec![Box::new(&second_mle)]),
        ];
        let subpoly = SumcheckSubpolynomial::new(SumcheckSubpolynomialType::Identity, terms);

        let multiplier = TestScalar::from(5);
        let mut scaled_terms = subpoly.iter_mul_by(multiplier);

        // Coefficients 2 and 3, each scaled by 5.
        for expected_coeff in [TestScalar::from(10), TestScalar::from(15)] {
            let (kind, coeff, _multiplicands) = scaled_terms.next().unwrap();
            assert_eq!(kind, SumcheckSubpolynomialType::Identity);
            assert_eq!(coeff, expected_coeff);
        }

        // Exactly two terms were supplied, so the iterator is now exhausted.
        assert!(scaled_terms.next().is_none());
    }
}