lambdaworks_math/polynomial/
sparse_multilinear_poly.rs1#[cfg(feature = "parallel")]
2use rayon::iter::{IntoParallelIterator, ParallelIterator};
3
4use crate::field::{element::FieldElement, traits::IsField};
5use crate::polynomial::error::MultilinearError;
6use alloc::vec::Vec;
7
8pub struct SparseMultilinearPolynomial<F: IsField>
9where
10 <F as IsField>::BaseType: Send + Sync,
11{
12 num_vars: usize,
13 evals: Vec<(usize, FieldElement<F>)>,
14}
15
16impl<F: IsField> SparseMultilinearPolynomial<F>
17where
18 <F as IsField>::BaseType: Send + Sync,
19{
20 pub fn new(num_vars: usize, evals: Vec<(usize, FieldElement<F>)>) -> Self {
22 SparseMultilinearPolynomial { num_vars, evals }
23 }
24
25 pub fn num_vars(&self) -> usize {
27 self.num_vars
28 }
29
30 fn compute_chi(a: &[bool], r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
35 assert_eq!(a.len(), r.len());
36 if a.len() != r.len() {
37 return Err(MultilinearError::ChisAndEvalsLengthMismatch(
38 a.len(),
39 r.len(),
40 ));
41 }
42 let mut chi_i = FieldElement::one();
43 for j in 0..r.len() {
44 if a[j] {
45 chi_i *= &r[j];
46 } else {
47 chi_i *= FieldElement::<F>::one() - &r[j];
48 }
49 }
50 Ok(chi_i)
51 }
52
53 pub fn evaluate(&self, r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
56 if r.len() != self.num_vars() {
57 return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
58 r.len(),
59 self.num_vars(),
60 ));
61 }
62
63 #[cfg(feature = "parallel")]
64 let iter = (0..self.evals.len()).into_par_iter();
65
66 #[cfg(not(feature = "parallel"))]
67 let iter = 0..self.evals.len();
68
69 Ok(iter
70 .map(|i| {
71 let bits = get_bits(self.evals[i].0, r.len());
72 let mut chi_i = FieldElement::<F>::one();
73 for j in 0..r.len() {
74 if bits[j] {
75 chi_i *= &r[j];
76 } else {
77 chi_i *= FieldElement::<F>::one() - &r[j];
78 }
79 }
80 chi_i * &self.evals[i].1
81 })
82 .sum())
83 }
84
85 pub fn evaluate_with(
88 num_vars: usize,
89 evals: &[(usize, FieldElement<F>)],
90 r: &[FieldElement<F>],
91 ) -> Result<FieldElement<F>, MultilinearError> {
92 assert_eq!(num_vars, r.len());
93 if r.len() != num_vars {
94 return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
95 r.len(),
96 num_vars,
97 ));
98 }
99
100 #[cfg(feature = "parallel")]
101 let iter = (0..evals.len()).into_par_iter();
102
103 #[cfg(not(feature = "parallel"))]
104 let iter = 0..evals.len();
105 Ok(iter
106 .map(|i| {
107 let bits = get_bits(evals[i].0, r.len());
108 SparseMultilinearPolynomial::compute_chi(&bits, r).unwrap() * &evals[i].1
109 })
110 .sum())
111 }
112}
113
/// Returns the lowest `num_bits` bits of `n`, most-significant first.
///
/// Entry `i` of the result is bit `num_bits - 1 - i` of `n`, so entry `0` is
/// the highest requested bit. Bits of `n` above position `num_bits - 1` are
/// ignored. Positions at or beyond `usize::BITS` are reported as `false`
/// instead of overflowing the shift, so `num_bits` may exceed the width of `usize`.
fn get_bits(n: usize, num_bits: usize) -> Vec<bool> {
    (0..num_bits)
        .rev()
        // Short-circuit keeps the shift amount below usize::BITS.
        .map(|bit| bit < usize::BITS as usize && (n >> bit) & 1 == 1)
        .collect()
}
120
#[cfg(test)]
mod test {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    use alloc::vec;

    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Three-variable sparse polynomial with f(0,1,1) = 1, f(1,0,1) = 1,
    /// f(1,1,1) = 2 and zero elsewhere on the hypercube.
    fn sample_evals() -> Vec<(usize, FE)> {
        vec![(3, FE::one()), (5, FE::one()), (7, FE::from(2))]
    }

    /// Asserts that both evaluation entry points agree on `expected` at `x`.
    fn check(x: &[FE], expected: FE) {
        let z = sample_evals();
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, z.clone());
        assert_eq!(m_poly.evaluate(x).unwrap(), expected);
        assert_eq!(
            SparseMultilinearPolynomial::evaluate_with(3, &z, x).unwrap(),
            expected
        );
    }

    #[test]
    fn evaluate() {
        // Boolean points recover the stored values (or zero when absent).
        check(&[FE::one(), FE::one(), FE::one()], FE::from(2));
        check(&[FE::one(), FE::zero(), FE::one()], FE::one());
        check(&[FE::zero(), FE::one(), FE::one()], FE::one());
        check(&[FE::zero(), FE::zero(), FE::zero()], FE::zero());

        // Non-Boolean point exercises the multilinear extension:
        // (1-2)*3*4 + 2*(1-3)*4 + 2*(2*3*4) = -12 - 16 + 48 = 20.
        check(&[FE::from(2), FE::from(3), FE::from(4)], FE::from(20));
    }

    #[test]
    fn evaluate_rejects_wrong_number_of_points() {
        let m_poly = SparseMultilinearPolynomial::<F>::new(3, sample_evals());
        assert!(matches!(
            m_poly.evaluate(&[FE::one()]),
            Err(MultilinearError::IncorrectNumberofEvaluationPoints(1, 3))
        ));
    }
}