use crate::constants::PARALLEL_THRESHOLD;
use crate::spartan::{math::Math, polys::eq::EqPolynomial};
use core::ops::{Add, Index};
use ff::PrimeField;
use itertools::Itertools as _;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// A multilinear polynomial stored as its table of evaluations over the boolean hypercube.
///
/// The constructor `new` enforces `Z.len() == 2^num_vars`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MultilinearPolynomial<Scalar: PrimeField> {
  /// Number of variables of the polynomial.
  num_vars: usize,
  /// Evaluation table; entry `i` is the value at the hypercube point encoded by `i`.
  pub Z: Vec<Scalar>,
}
impl<Scalar: PrimeField> MultilinearPolynomial<Scalar> {
  /// Wraps the evaluation table `Z`; the number of variables is derived from its length.
  ///
  /// # Panics
  /// Panics if `Z.len()` is not a power of two.
  pub fn new(Z: Vec<Scalar>) -> Self {
    let num_vars = Z.len().log_2();
    assert_eq!(Z.len(), 1 << num_vars);
    Self { num_vars, Z }
  }

  /// Number of variables of the polynomial.
  pub const fn get_num_vars(&self) -> usize {
    self.num_vars
  }

  /// Number of entries in the evaluation table.
  pub fn len(&self) -> usize {
    self.Z.len()
  }

  /// `true` when the evaluation table holds no entries.
  pub fn is_empty(&self) -> bool {
    self.Z.is_empty()
  }

  /// Fixes the top variable (the one selecting the upper or lower half of `Z`) to `r`.
  ///
  /// Every lower-half entry `lo` is replaced by `lo + r * (hi - lo)`, where `hi` is the
  /// matching upper-half entry; the upper half is then dropped and `num_vars` decremented.
  ///
  /// # Panics
  /// Panics if the polynomial has no variables left.
  pub fn bind_poly_var_top(&mut self, r: &Scalar) {
    assert!(self.num_vars > 0);
    let half = self.len() / 2;
    let (low, high) = self.Z.split_at_mut(half);
    let lerp = |lo: &mut Scalar, hi: &Scalar| *lo += *r * (*hi - *lo);
    if half >= PARALLEL_THRESHOLD {
      low
        .par_iter_mut()
        .zip(high.par_iter())
        .for_each(|(lo, hi)| lerp(lo, hi));
    } else {
      low
        .iter_mut()
        .zip(high.iter())
        .for_each(|(lo, hi)| lerp(lo, hi));
    }
    // the upper half has been folded into the lower half
    self.Z.truncate(half);
    self.num_vars -= 1;
  }

  /// Evaluates the polynomial at the point `r`.
  ///
  /// # Panics
  /// Panics if `r.len()` differs from the number of variables.
  pub fn evaluate(&self, r: &[Scalar]) -> Scalar {
    assert_eq!(r.len(), self.get_num_vars());
    Self::evaluate_with(&self.Z, r)
  }

  /// Evaluates the multilinear extension of the table `Z` at `r`.
  ///
  /// `r` is split into a leading part of `r.len() - r.len() / 2` coordinates and a trailing
  /// part of `r.len() / 2` coordinates. `Z` is read as a row-major matrix whose row index
  /// comes from the leading coordinates. Each row is contracted with the eq table of the
  /// trailing coordinates, and the row sums are contracted with the eq table of the leading
  /// coordinates. Only eq tables for the two halves of `r` are built, never one over all of `r`.
  ///
  /// Panics if `Z` holds fewer than `2^r.len()` entries; entries past that bound are ignored.
  pub fn evaluate_with(Z: &[Scalar], r: &[Scalar]) -> Scalar {
    let (r_hi, r_lo) = r.split_at(r.len() - r.len() / 2);
    let num_rows = 1usize << r_hi.len();
    let row_len = 1usize << r_lo.len();
    let eq_hi = EqPolynomial::evals_from_points(r_hi);
    let eq_lo = EqPolynomial::evals_from_points(r_lo);

    // inner contraction: one partial sum per row
    let row_sums: Vec<Scalar> = (0..num_rows)
      .into_par_iter()
      .map(|row| {
        Z[row * row_len..(row + 1) * row_len]
          .iter()
          .zip(eq_lo.iter())
          .map(|(z, e)| *z * *e)
          .sum()
      })
      .collect();

    // outer contraction over the rows
    eq_hi
      .into_par_iter()
      .zip(row_sums.into_par_iter())
      .map(|(e, s)| e * s)
      .sum()
  }

  /// Evaluates several tables at the same point `r`, sharing the eq tables between them.
  ///
  /// Returns the evaluations in the same order as `Zs`, or an empty vector when `Zs` is
  /// empty.
  ///
  /// # Panics
  /// Panics unless every table in `Zs` has exactly `2^r.len()` entries.
  pub fn multi_evaluate_with(Zs: &[&[Scalar]], r: &[Scalar]) -> Vec<Scalar> {
    let num_polys = Zs.len();
    if num_polys == 0 {
      return Vec::new();
    }
    let s = r.len();
    let n = 1usize << s;
    assert!(Zs.iter().all(|z| z.len() == n));

    let (r_hi, r_lo) = r.split_at(s - s / 2);
    let num_rows = 1usize << r_hi.len();
    let row_len = 1usize << r_lo.len();
    let eq_hi = EqPolynomial::evals_from_points(r_hi);
    let eq_lo = EqPolynomial::evals_from_points(r_lo);

    // partial[row * num_polys + p] holds row `row` of table `p` contracted with eq_lo
    let mut partial = vec![Scalar::ZERO; num_rows * num_polys];
    partial
      .par_chunks_mut(num_polys)
      .enumerate()
      .for_each(|(row, acc)| {
        let base = row * row_len;
        for (j, e) in eq_lo[..row_len].iter().enumerate() {
          for (acc_p, z) in acc.iter_mut().zip(Zs) {
            *acc_p += z[base + j] * *e;
          }
        }
      });

    // contract each table's row sums with eq_hi
    (0..num_polys)
      .map(|p| {
        let term = |row: usize| eq_hi[row] * partial[row * num_polys + p];
        if num_rows >= PARALLEL_THRESHOLD {
          (0..num_rows).into_par_iter().map(term).sum()
        } else {
          (0..num_rows).map(term).sum()
        }
      })
      .collect()
  }
}
impl<Scalar: PrimeField> Index<usize> for MultilinearPolynomial<Scalar> {
  type Output = Scalar;

  /// Returns the entry at position `index` of the evaluation table.
  ///
  /// # Panics
  /// Panics if `index >= self.len()`.
  #[inline(always)]
  fn index(&self, index: usize) -> &Self::Output {
    &self.Z[index]
  }
}
/// A multilinear polynomial in `num_vars` variables whose evaluation table is `Z`
/// followed by zeros, i.e. only the first `Z.len()` hypercube entries can be non-zero.
pub struct SparsePolynomial<Scalar: PrimeField> {
  /// Number of variables of the polynomial.
  num_vars: usize,
  /// Leading entries of the evaluation table; all later entries are zero.
  pub Z: Vec<Scalar>,
}

impl<Scalar: PrimeField> SparsePolynomial<Scalar> {
  /// Creates the polynomial whose evaluation table is `Z` padded with zeros to
  /// `2^num_vars` entries.
  pub fn new(num_vars: usize, Z: Vec<Scalar>) -> Self {
    SparsePolynomial { num_vars, Z }
  }

  /// Evaluates the polynomial at `r` in time proportional to `Z.len()` rather than
  /// `2^num_vars`.
  ///
  /// Let `m = log2(Z.len().next_power_of_two())`. Every non-zero entry has an index below
  /// `2^m`. The leading `num_vars - m` coordinates therefore only ever meet the value 0 on
  /// the support, which gives
  ///
  /// `eval(r) = prod_{i < num_vars - m} (1 - r_i) * sum_j Z_j * eq(r[num_vars - m..], j)`.
  ///
  /// This split is valid for every `Z.len() <= 2^num_vars`, including a completely filled
  /// table and `num_vars == 0`.
  ///
  /// # Panics
  /// Panics if `r.len() != num_vars`, or if `Z` holds more than `2^num_vars` entries.
  pub fn evaluate(&self, r: &[Scalar]) -> Scalar {
    assert_eq!(self.num_vars, r.len());
    let num_vars_z = self.Z.len().next_power_of_two().log_2();
    assert!(
      num_vars_z <= self.num_vars,
      "sparse polynomial has more entries than its {}-variable hypercube",
      self.num_vars
    );

    // r_zero addresses the all-zero high part of the index, r_support the populated prefix
    let (r_zero, r_support) = r.split_at(self.num_vars - num_vars_z);

    let chis = EqPolynomial::evals_from_points(r_support);
    let eval_partial: Scalar = self
      .Z
      .iter()
      .zip(chis.iter())
      .map(|(z, chi)| *z * *chi)
      .sum();

    let common: Scalar = r_zero.iter().map(|r_i| Scalar::ONE - *r_i).product();

    common * eval_partial
  }
}
impl<Scalar: PrimeField> Add for MultilinearPolynomial<Scalar> {
  type Output = Result<Self, &'static str>;

  /// Adds two polynomials entry by entry of their evaluation tables.
  ///
  /// # Errors
  /// Returns `Err` when the operands have different numbers of variables.
  fn add(self, other: Self) -> Self::Output {
    if self.get_num_vars() == other.get_num_vars() {
      let summed: Vec<Scalar> = self
        .Z
        .into_iter()
        .zip(other.Z)
        .map(|(lhs, rhs)| lhs + rhs)
        .collect();
      Ok(Self::new(summed))
    } else {
      Err("The two polynomials must have the same number of variables")
    }
  }
}
#[cfg(test)]
mod tests {
  use super::*;
  use crate::provider::{bn256_grumpkin::bn256, pasta::pallas, secp_secq::secp256k1};
  use rand_chacha::ChaCha20Rng;
  use rand_core::{CryptoRng, RngCore, SeedableRng};

  /// Builds a polynomial with `len` evaluations (`len` must be a power of two), all equal
  /// to `value`. It goes through `MultilinearPolynomial::new` so that `num_vars` is
  /// `log2(len)` and stays consistent with `Z`.
  fn make_mlp<F: PrimeField>(len: usize, value: F) -> MultilinearPolynomial<F> {
    MultilinearPolynomial::new(vec![value; len])
  }

  fn test_multilinear_polynomial_with<F: PrimeField>() {
    let two = F::from(2);
    let z = vec![
      F::ZERO,
      F::ZERO,
      F::ZERO,
      F::ONE,
      F::ZERO,
      F::ONE,
      F::ZERO,
      two,
    ];
    let m_poly = MultilinearPolynomial::<F>::new(z.clone());
    assert_eq!(m_poly.get_num_vars(), 3);
    let x = vec![F::ONE, F::ONE, F::ONE];
    assert_eq!(m_poly.evaluate(x.as_slice()), two);
    let y = MultilinearPolynomial::<F>::evaluate_with(z.as_slice(), x.as_slice());
    assert_eq!(y, two);
  }

  fn test_sparse_polynomial_with<F: PrimeField>() {
    let mut z = vec![F::ONE, F::ONE, F::from(2)];
    let m_poly = SparsePolynomial::<F>::new(4, z.clone());
    // the dense counterpart is the same table zero-padded to 2^4 entries
    z.resize(16, F::ZERO);
    let m_poly_dense = MultilinearPolynomial::new(z);
    let x = vec![F::from(5), F::from(8), F::from(5), F::from(3)];
    assert_eq!(
      m_poly.evaluate(x.as_slice()),
      m_poly_dense.evaluate(x.as_slice())
    );
  }

  #[test]
  fn test_multilinear_polynomial() {
    test_multilinear_polynomial_with::<pallas::Scalar>();
    test_multilinear_polynomial_with::<bn256::Scalar>();
    test_multilinear_polynomial_with::<secp256k1::Scalar>();
  }

  #[test]
  fn test_sparse_polynomial() {
    test_sparse_polynomial_with::<pallas::Scalar>();
    test_sparse_polynomial_with::<bn256::Scalar>();
    test_sparse_polynomial_with::<secp256k1::Scalar>();
  }

  fn test_mlp_add_with<F: PrimeField>() {
    let mlp1 = make_mlp(4, F::from(3));
    let mlp2 = make_mlp(4, F::from(7));
    assert_eq!(mlp1.get_num_vars(), 2);
    let mlp3 = mlp1.add(mlp2).unwrap();
    assert_eq!(mlp3.get_num_vars(), 2);
    assert_eq!(mlp3.Z, vec![F::from(10); 4]);

    // operands over different numbers of variables are rejected
    let mismatched = make_mlp(4, F::ONE).add(make_mlp(8, F::ONE));
    assert!(mismatched.is_err());
  }

  #[test]
  fn test_mlp_add() {
    test_mlp_add_with::<pallas::Scalar>();
    test_mlp_add_with::<bn256::Scalar>();
    test_mlp_add_with::<secp256k1::Scalar>();
  }

  fn test_evaluation_with<F: PrimeField>() {
    // a constant polynomial evaluates to that constant everywhere
    let dense_poly: MultilinearPolynomial<F> = MultilinearPolynomial::new(vec![F::from(8); 4]);
    assert_eq!(
      dense_poly.evaluate(vec![F::from(3), F::from(4)].as_slice()),
      F::from(8)
    );
  }

  #[test]
  fn test_evaluation() {
    test_evaluation_with::<pallas::Scalar>();
    test_evaluation_with::<bn256::Scalar>();
    test_evaluation_with::<secp256k1::Scalar>();
  }

  /// Samples a polynomial in `num_vars` variables with uniformly random evaluations.
  // The lint is silenced because dropping the `&mut` it flags would move `rng` out of the
  // `FnMut` closure below (presumably a clippy false positive for `impl RngCore` params).
  #[allow(clippy::needless_borrows_for_generic_args)]
  fn random<R: RngCore + CryptoRng, Scalar: PrimeField>(
    num_vars: usize,
    mut rng: &mut R,
  ) -> MultilinearPolynomial<Scalar> {
    MultilinearPolynomial::new(
      std::iter::from_fn(|| Some(Scalar::random(&mut rng)))
        .take(1 << num_vars)
        .collect(),
    )
  }

  /// Binds the top variable of a copy of `poly` once per entry of `values`, in order.
  fn bind_sequence<F: PrimeField>(
    poly: &MultilinearPolynomial<F>,
    values: &[F],
  ) -> MultilinearPolynomial<F> {
    assert!(poly.Z.len().is_power_of_two());
    assert!(poly.Z.len() >= 1 << values.len());
    let mut tmp = poly.clone();
    for v in values.iter() {
      tmp.bind_poly_var_top(v);
    }
    tmp
  }

  fn bind_and_evaluate_with<F: PrimeField>() {
    let n = 7;
    for seed in 0..50u8 {
      let mut rng = ChaCha20Rng::from_seed([seed; 32]);
      let poly = random::<_, F>(n, &mut rng);
      let pt: Vec<F> = std::iter::from_fn(|| Some(F::random(&mut rng)))
        .take(n)
        .collect();
      // binding every variable must leave exactly the evaluation at `pt`
      assert_eq!(poly.evaluate(&pt), bind_sequence(&poly, &pt).Z[0]);
    }
  }

  #[test]
  fn test_bind_and_evaluate() {
    bind_and_evaluate_with::<pallas::Scalar>();
    bind_and_evaluate_with::<bn256::Scalar>();
    bind_and_evaluate_with::<secp256k1::Scalar>();
  }

  fn test_multi_evaluate_with_matches_single<F: PrimeField>() {
    let mut rng = ChaCha20Rng::from_seed([42u8; 32]);
    let num_vars = 6;
    let poly1 = random::<_, F>(num_vars, &mut rng);
    let poly2 = random::<_, F>(num_vars, &mut rng);
    let poly3 = random::<_, F>(num_vars, &mut rng);
    let pt: Vec<F> = std::iter::from_fn(|| Some(F::random(&mut rng)))
      .take(num_vars)
      .collect();
    let expected: Vec<F> = [&poly1, &poly2, &poly3]
      .iter()
      .map(|p| MultilinearPolynomial::evaluate_with(&p.Z, &pt))
      .collect();
    let zs: Vec<&[F]> = vec![&poly1.Z, &poly2.Z, &poly3.Z];
    let result = MultilinearPolynomial::multi_evaluate_with(&zs, &pt);
    assert_eq!(result, expected);
  }

  fn test_multi_evaluate_with_single_poly<F: PrimeField>() {
    let mut rng = ChaCha20Rng::from_seed([7u8; 32]);
    let num_vars = 4;
    let poly = random::<_, F>(num_vars, &mut rng);
    let pt: Vec<F> = std::iter::from_fn(|| Some(F::random(&mut rng)))
      .take(num_vars)
      .collect();
    let expected = MultilinearPolynomial::evaluate_with(&poly.Z, &pt);
    let zs: Vec<&[F]> = vec![&poly.Z];
    let result = MultilinearPolynomial::multi_evaluate_with(&zs, &pt);
    assert_eq!(result, vec![expected]);
  }

  fn test_multi_evaluate_with_known_values<F: PrimeField>() {
    let two = F::from(2);
    let z1 = vec![
      F::ZERO,
      F::ZERO,
      F::ZERO,
      F::ONE,
      F::ZERO,
      F::ONE,
      F::ZERO,
      two,
    ];
    let z2 = vec![F::from(5); 8];
    let pt = vec![F::ONE, F::ONE, F::ONE];
    let result = MultilinearPolynomial::multi_evaluate_with(&[z1.as_slice(), z2.as_slice()], &pt);
    assert_eq!(result[0], two);
    assert_eq!(result[1], F::from(5));
  }

  #[test]
  fn test_multi_evaluate_with() {
    test_multi_evaluate_with_matches_single::<pallas::Scalar>();
    test_multi_evaluate_with_matches_single::<bn256::Scalar>();
    test_multi_evaluate_with_matches_single::<secp256k1::Scalar>();
    test_multi_evaluate_with_single_poly::<pallas::Scalar>();
    test_multi_evaluate_with_single_poly::<bn256::Scalar>();
    test_multi_evaluate_with_single_poly::<secp256k1::Scalar>();
    test_multi_evaluate_with_known_values::<pallas::Scalar>();
    test_multi_evaluate_with_known_values::<bn256::Scalar>();
    test_multi_evaluate_with_known_values::<secp256k1::Scalar>();
  }
}