// File: lambdaworks_math/polynomial/sparse_multilinear_poly.rs

1#[cfg(feature = "parallel")]
2use rayon::iter::{IntoParallelIterator, ParallelIterator};
3
4use crate::field::{element::FieldElement, traits::IsField};
5use crate::polynomial::error::MultilinearError;
6use alloc::vec::Vec;
7
8pub struct SparseMultilinearPolynomial<F: IsField>
9where
10    <F as IsField>::BaseType: Send + Sync,
11{
12    num_vars: usize,
13    evals: Vec<(usize, FieldElement<F>)>,
14}
15
16impl<F: IsField> SparseMultilinearPolynomial<F>
17where
18    <F as IsField>::BaseType: Send + Sync,
19{
20    /// Creates a new sparse multilinear polynomial
21    pub fn new(num_vars: usize, evals: Vec<(usize, FieldElement<F>)>) -> Self {
22        SparseMultilinearPolynomial { num_vars, evals }
23    }
24
25    /// Returns the number of variables of the sparse multilinear polynomial
26    pub fn num_vars(&self) -> usize {
27        self.num_vars
28    }
29
30    /// Computes the eq extension polynomial of the polynomial.
31    /// return 1 when a == r, otherwise return 0.
32    /// Chi is the Lagrange basis polynomial evaluated at r
33    /// a indicates the binary decomposition of the index of the Lagrange basis polynomial
34    fn compute_chi(a: &[bool], r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
35        assert_eq!(a.len(), r.len());
36        if a.len() != r.len() {
37            return Err(MultilinearError::ChisAndEvalsLengthMismatch(
38                a.len(),
39                r.len(),
40            ));
41        }
42        let mut chi_i = FieldElement::one();
43        for j in 0..r.len() {
44            if a[j] {
45                chi_i *= &r[j];
46            } else {
47                chi_i *= FieldElement::<F>::one() - &r[j];
48            }
49        }
50        Ok(chi_i)
51    }
52
53    // Takes O(n log n)
54    /// Evaluates the multilinear polynomial at a point r
55    pub fn evaluate(&self, r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
56        if r.len() != self.num_vars() {
57            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
58                r.len(),
59                self.num_vars(),
60            ));
61        }
62
63        #[cfg(feature = "parallel")]
64        let iter = (0..self.evals.len()).into_par_iter();
65
66        #[cfg(not(feature = "parallel"))]
67        let iter = 0..self.evals.len();
68
69        Ok(iter
70            .map(|i| {
71                let bits = get_bits(self.evals[i].0, r.len());
72                let mut chi_i = FieldElement::<F>::one();
73                for j in 0..r.len() {
74                    if bits[j] {
75                        chi_i *= &r[j];
76                    } else {
77                        chi_i *= FieldElement::<F>::one() - &r[j];
78                    }
79                }
80                chi_i * &self.evals[i].1
81            })
82            .sum())
83    }
84
85    // Takes O(n log n)
86    /// Evaluates the multilinear polynomial at a point r
87    pub fn evaluate_with(
88        num_vars: usize,
89        evals: &[(usize, FieldElement<F>)],
90        r: &[FieldElement<F>],
91    ) -> Result<FieldElement<F>, MultilinearError> {
92        assert_eq!(num_vars, r.len());
93        if r.len() != num_vars {
94            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
95                r.len(),
96                num_vars,
97            ));
98        }
99
100        #[cfg(feature = "parallel")]
101        let iter = (0..evals.len()).into_par_iter();
102
103        #[cfg(not(feature = "parallel"))]
104        let iter = 0..evals.len();
105        Ok(iter
106            .map(|i| {
107                let bits = get_bits(evals[i].0, r.len());
108                SparseMultilinearPolynomial::compute_chi(&bits, r).unwrap() * &evals[i].1
109            })
110            .sum())
111    }
112}
113
/// Returns the `num_bits`-bit binary decomposition of `n`, most significant bit first.
///
/// Entry `i` of the result is bit `num_bits - 1 - i` of `n`. As a result, the index of an
/// evaluation in a sparse multilinear polynomial maps to a boolean-hypercube point whose first
/// coordinate is the index's most significant bit. Only the low `num_bits` bits of `n` are read.
/// If `num_bits` exceeds `usize::BITS`, the extra leading positions are `false`.
fn get_bits(n: usize, num_bits: usize) -> Vec<bool> {
    (0..num_bits)
        .rev()
        // Guard the shift. Shifting by `usize::BITS` or more panics in debug builds and, in
        // release builds, masks the shift amount, which would produce wrong bits.
        .map(|shift| shift < usize::BITS as usize && (n >> shift) & 1 == 1)
        .collect()
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    use alloc::vec;

    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Sparse evaluations of p(x_1, x_2, x_3) = (x_1 + x_2) * x_3.
    /// Over the boolean cube (x_1 is the most significant index bit) they are
    /// [0, 0, 0, 1, 0, 1, 0, 2], so only indices 3, 5 and 7 are stored.
    fn sample_evals() -> Vec<(usize, FE)> {
        vec![(3, FE::one()), (5, FE::one()), (7, FE::from(2))]
    }

    #[test]
    fn evaluate() {
        let two = FE::from(2);
        let z = sample_evals();
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, z.clone());

        let x = vec![FE::one(), FE::one(), FE::one()];
        assert_eq!(m_poly.evaluate(x.as_slice()).unwrap(), two);
        assert_eq!(
            SparseMultilinearPolynomial::evaluate_with(3, &z, x.as_slice()).unwrap(),
            two
        );

        let x = vec![FE::one(), FE::zero(), FE::one()];
        assert_eq!(m_poly.evaluate(x.as_slice()).unwrap(), FE::one());
        assert_eq!(
            SparseMultilinearPolynomial::evaluate_with(3, &z, x.as_slice()).unwrap(),
            FE::one()
        );
    }

    #[test]
    fn evaluate_at_non_boolean_point() {
        // p is already multilinear, so its extension at (2, 3, 4) is (2 + 3) * 4 = 20.
        let z = sample_evals();
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, z.clone());

        let x = vec![FE::from(2), FE::from(3), FE::from(4)];
        assert_eq!(m_poly.evaluate(x.as_slice()).unwrap(), FE::from(20));
        assert_eq!(
            SparseMultilinearPolynomial::evaluate_with(3, &z, x.as_slice()).unwrap(),
            FE::from(20)
        );
    }

    #[test]
    fn evaluate_rejects_wrong_number_of_points() {
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, sample_evals());

        let x = vec![FE::one(), FE::one()];
        assert!(matches!(
            m_poly.evaluate(x.as_slice()),
            Err(MultilinearError::IncorrectNumberofEvaluationPoints(2, 3))
        ));
    }
}