lambdaworks_math/polynomial/sparse_multilinear_poly.rs

1#[cfg(feature = "parallel")]
2use rayon::iter::{IntoParallelIterator, ParallelIterator};
3
4use crate::field::{element::FieldElement, traits::IsField};
5use crate::polynomial::error::MultilinearError;
6use alloc::vec::Vec;
7
/// A multilinear polynomial stored sparsely as a list of evaluations over the boolean
/// hypercube `{0, 1}^num_vars`.
///
/// Only the listed `(index, value)` pairs are stored; hypercube points that do not appear
/// in `evals` contribute nothing when the polynomial is evaluated (i.e. they are treated
/// as zero).
pub struct SparseMultilinearPolynomial<F: IsField>
where
    <F as IsField>::BaseType: Send + Sync,
{
    /// Number of variables of the polynomial.
    num_vars: usize,
    /// `(index, value)` pairs; `index` is decoded into `num_vars` bits, most significant
    /// bit first, to obtain the hypercube point that takes `value`.
    evals: Vec<(usize, FieldElement<F>)>,
}
15
16impl<F: IsField> SparseMultilinearPolynomial<F>
17where
18    <F as IsField>::BaseType: Send + Sync,
19{
20    pub fn new(num_vars: usize, evals: Vec<(usize, FieldElement<F>)>) -> Self {
21        SparseMultilinearPolynomial { num_vars, evals }
22    }
23
24    pub fn num_vars(&self) -> usize {
25        self.num_vars
26    }
27
28    /// Computes the eq extension polynomial of the polynomial.
29    /// return 1 when a == r, otherwise return 0.
30    fn compute_chi(a: &[bool], r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
31        assert_eq!(a.len(), r.len());
32        if a.len() != r.len() {
33            return Err(MultilinearError::ChisAndEvalsLengthMismatch(
34                a.len(),
35                r.len(),
36            ));
37        }
38        let mut chi_i = FieldElement::one();
39        for j in 0..r.len() {
40            if a[j] {
41                chi_i *= &r[j];
42            } else {
43                chi_i *= FieldElement::<F>::one() - &r[j];
44            }
45        }
46        Ok(chi_i)
47    }
48
49    // Takes O(n log n)
50    pub fn evaluate(&self, r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
51        if r.len() != self.num_vars() {
52            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
53                r.len(),
54                self.num_vars(),
55            ));
56        }
57
58        #[cfg(feature = "parallel")]
59        let iter = (0..self.evals.len()).into_par_iter();
60
61        #[cfg(not(feature = "parallel"))]
62        let iter = 0..self.evals.len();
63
64        Ok(iter
65            .map(|i| {
66                let bits = get_bits(self.evals[i].0, r.len());
67                let mut chi_i = FieldElement::<F>::one();
68                for j in 0..r.len() {
69                    if bits[j] {
70                        chi_i *= &r[j];
71                    } else {
72                        chi_i *= FieldElement::<F>::one() - &r[j];
73                    }
74                }
75                chi_i * &self.evals[i].1
76            })
77            .sum())
78    }
79
80    // Takes O(n log n)
81    pub fn evaluate_with(
82        num_vars: usize,
83        evals: &[(usize, FieldElement<F>)],
84        r: &[FieldElement<F>],
85    ) -> Result<FieldElement<F>, MultilinearError> {
86        assert_eq!(num_vars, r.len());
87        if r.len() != num_vars {
88            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
89                r.len(),
90                num_vars,
91            ));
92        }
93
94        #[cfg(feature = "parallel")]
95        let iter = (0..evals.len()).into_par_iter();
96
97        #[cfg(not(feature = "parallel"))]
98        let iter = 0..evals.len();
99        Ok(iter
100            .map(|i| {
101                let bits = get_bits(evals[i].0, r.len());
102                SparseMultilinearPolynomial::compute_chi(&bits, r).unwrap() * &evals[i].1
103            })
104            .sum())
105    }
106}
107
/// Returns the `num_bits`-bit decomposition of `n`, most significant bit first.
///
/// Element `j` of the result is bit `num_bits - 1 - j` of `n`, so the first entry
/// corresponds to the first variable of the multilinear polynomial. This is how the
/// `index` of a sparse evaluation is mapped to a hypercube point.
///
/// Bits of `n` at positions `>= num_bits` are ignored. Positions at or above `usize::BITS`
/// are reported as `false` (a `usize` cannot have them set) instead of overflowing the
/// shift, which would panic in debug builds and yield wrong bits in release builds.
fn get_bits(n: usize, num_bits: usize) -> Vec<bool> {
    (0..num_bits)
        .rev()
        .map(|position| position < usize::BITS as usize && (n >> position) & 1 == 1)
        .collect()
}
114
#[cfg(test)]
mod test {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    use alloc::vec;

    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Sparse hypercube evaluations of p(x_1, x_2, x_3) = (x_1 + x_2) * x_3.
    ///
    /// The full table over {0,1}^3 (index read MSB-first as x_1 x_2 x_3) is
    /// [0, 0, 0, 1, 0, 1, 0, 2]; only indices 3, 5 and 7 are non-zero.
    fn sample_evals() -> Vec<(usize, FE)> {
        vec![(3, FE::one()), (5, FE::one()), (7, FE::from(2))]
    }

    /// Asserts that both `evaluate` and `evaluate_with` return `expected` at `point`.
    fn assert_evaluates_to(point: &[FE], expected: FE) {
        let evals = sample_evals();
        let poly = SparseMultilinearPolynomial::<F>::new(3, evals.clone());
        assert_eq!(poly.evaluate(point).unwrap(), expected);
        assert_eq!(
            SparseMultilinearPolynomial::<F>::evaluate_with(3, &evals, point).unwrap(),
            expected
        );
    }

    #[test]
    fn evaluate() {
        // Boolean points reproduce the hypercube table.
        assert_evaluates_to(&[FE::one(), FE::one(), FE::one()], FE::from(2));
        assert_evaluates_to(&[FE::one(), FE::zero(), FE::one()], FE::one());
        assert_evaluates_to(&[FE::zero(), FE::zero(), FE::zero()], FE::zero());
        // p is multilinear, so its extension equals p off the hypercube: (2 + 3) * 4 = 20.
        assert_evaluates_to(&[FE::from(2), FE::from(3), FE::from(4)], FE::from(20));
    }

    #[test]
    fn evaluate_rejects_wrong_number_of_points() {
        let poly = SparseMultilinearPolynomial::<F>::new(3, sample_evals());
        assert!(matches!(
            poly.evaluate(&[FE::one()]),
            Err(MultilinearError::IncorrectNumberofEvaluationPoints(1, 3))
        ));
    }
}