lambdaworks_math/polynomial/
sparse_multilinear_poly.rs1#[cfg(feature = "parallel")]
2use rayon::iter::{IntoParallelIterator, ParallelIterator};
3
4use crate::field::{element::FieldElement, traits::IsField};
5use crate::polynomial::error::MultilinearError;
6use alloc::vec::Vec;
7
8pub struct SparseMultilinearPolynomial<F: IsField>
9where
10 <F as IsField>::BaseType: Send + Sync,
11{
12 num_vars: usize,
13 evals: Vec<(usize, FieldElement<F>)>,
14}
15
16impl<F: IsField> SparseMultilinearPolynomial<F>
17where
18 <F as IsField>::BaseType: Send + Sync,
19{
20 pub fn new(num_vars: usize, evals: Vec<(usize, FieldElement<F>)>) -> Self {
21 SparseMultilinearPolynomial { num_vars, evals }
22 }
23
24 pub fn num_vars(&self) -> usize {
25 self.num_vars
26 }
27
28 fn compute_chi(a: &[bool], r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
31 assert_eq!(a.len(), r.len());
32 if a.len() != r.len() {
33 return Err(MultilinearError::ChisAndEvalsLengthMismatch(
34 a.len(),
35 r.len(),
36 ));
37 }
38 let mut chi_i = FieldElement::one();
39 for j in 0..r.len() {
40 if a[j] {
41 chi_i *= &r[j];
42 } else {
43 chi_i *= FieldElement::<F>::one() - &r[j];
44 }
45 }
46 Ok(chi_i)
47 }
48
49 pub fn evaluate(&self, r: &[FieldElement<F>]) -> Result<FieldElement<F>, MultilinearError> {
51 if r.len() != self.num_vars() {
52 return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
53 r.len(),
54 self.num_vars(),
55 ));
56 }
57
58 #[cfg(feature = "parallel")]
59 let iter = (0..self.evals.len()).into_par_iter();
60
61 #[cfg(not(feature = "parallel"))]
62 let iter = 0..self.evals.len();
63
64 Ok(iter
65 .map(|i| {
66 let bits = get_bits(self.evals[i].0, r.len());
67 let mut chi_i = FieldElement::<F>::one();
68 for j in 0..r.len() {
69 if bits[j] {
70 chi_i *= &r[j];
71 } else {
72 chi_i *= FieldElement::<F>::one() - &r[j];
73 }
74 }
75 chi_i * &self.evals[i].1
76 })
77 .sum())
78 }
79
80 pub fn evaluate_with(
82 num_vars: usize,
83 evals: &[(usize, FieldElement<F>)],
84 r: &[FieldElement<F>],
85 ) -> Result<FieldElement<F>, MultilinearError> {
86 assert_eq!(num_vars, r.len());
87 if r.len() != num_vars {
88 return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
89 r.len(),
90 num_vars,
91 ));
92 }
93
94 #[cfg(feature = "parallel")]
95 let iter = (0..evals.len()).into_par_iter();
96
97 #[cfg(not(feature = "parallel"))]
98 let iter = 0..evals.len();
99 Ok(iter
100 .map(|i| {
101 let bits = get_bits(evals[i].0, r.len());
102 SparseMultilinearPolynomial::compute_chi(&bits, r).unwrap() * &evals[i].1
103 })
104 .sum())
105 }
106}
107
/// Returns the lowest `num_bits` bits of `n` as booleans, most-significant bit first.
///
/// Bit positions at or beyond `usize::BITS` are reported as `false` (they are zero for any
/// `usize`), so `num_bits` may exceed the word size without an overflowing shift.
fn get_bits(n: usize, num_bits: usize) -> Vec<bool> {
    (0..num_bits)
        .rev()
        // The bounds test short-circuits before `>>`. Shifting by `usize::BITS` or more
        // would panic in debug builds and silently mask the shift amount in release builds.
        .map(|position| position < usize::BITS as usize && (n >> position) & 1 == 1)
        .collect()
}
114
#[cfg(test)]
mod test {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    use alloc::vec;

    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Sparse evaluations shared by the tests: f(011) = 1, f(101) = 1, f(111) = 2.
    fn sample_evals() -> Vec<(usize, FE)> {
        vec![(3, FE::one()), (5, FE::one()), (7, FE::from(2))]
    }

    /// Asserts that both `evaluate` and `evaluate_with` return `expected` at point `r`.
    fn assert_evaluates_to(r: &[FE], expected: FE) {
        let evals = sample_evals();
        let poly = SparseMultilinearPolynomial::<F>::new(3, evals.clone());
        assert_eq!(poly.evaluate(r).unwrap(), expected);
        assert_eq!(
            SparseMultilinearPolynomial::<F>::evaluate_with(3, &evals, r).unwrap(),
            expected
        );
    }

    #[test]
    fn evaluate() {
        // On hypercube points the polynomial reproduces the stored values...
        assert_evaluates_to(&[FE::one(), FE::one(), FE::one()], FE::from(2));
        assert_evaluates_to(&[FE::one(), FE::zero(), FE::one()], FE::one());
        // ...and a hypercube point missing from the sparse list evaluates to zero.
        assert_evaluates_to(&[FE::zero(), FE::zero(), FE::zero()], FE::zero());
    }

    #[test]
    fn evaluate_off_hypercube() {
        // r = (2, 3, 4):
        //   index 3 (011): (1 - 2) * 3 * 4     = -12
        //   index 5 (101): 2 * (1 - 3) * 4     = -16
        //   index 7 (111): 2 * (2 * 3 * 4)     =  48
        // total = 20 (mod 101)
        assert_evaluates_to(&[FE::from(2), FE::from(3), FE::from(4)], FE::from(20));
    }
}