lambdaworks_math/polynomial/
dense_multilinear_poly.rs1use crate::{
2 field::{element::FieldElement, traits::IsField},
3 polynomial::{error::MultilinearError, Polynomial},
4};
5use alloc::{vec, vec::Vec};
6use core::ops::{Add, Index, Mul};
7#[cfg(feature = "parallel")]
8use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
9
10#[derive(Debug, PartialEq, Clone)]
12pub struct DenseMultilinearPolynomial<F: IsField>
13where
14 <F as IsField>::BaseType: Send + Sync,
15{
16 evals: Vec<FieldElement<F>>,
17 n_vars: usize,
18 len: usize,
19}
20
21impl<F: IsField> DenseMultilinearPolynomial<F>
22where
23 <F as IsField>::BaseType: Send + Sync,
24{
25 pub fn new(mut evals: Vec<FieldElement<F>>) -> Self {
28 while !evals.len().is_power_of_two() {
29 evals.push(FieldElement::zero());
30 }
31 let len = evals.len();
32 DenseMultilinearPolynomial {
33 n_vars: log_2(len),
34 evals,
35 len,
36 }
37 }
38
39 pub fn num_vars(&self) -> usize {
41 self.n_vars
42 }
43
44 pub fn evals(&self) -> &Vec<FieldElement<F>> {
46 &self.evals
47 }
48
49 #[allow(clippy::len_without_is_empty)]
51 pub fn len(&self) -> usize {
52 self.len
53 }
54
55 pub fn evaluate(&self, r: Vec<FieldElement<F>>) -> Result<FieldElement<F>, MultilinearError> {
58 if r.len() != self.num_vars() {
59 return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
60 r.len(),
61 self.num_vars(),
62 ));
63 }
64 let mut chis: Vec<FieldElement<F>> =
65 vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
66 let mut size = 1;
67 for j in r {
68 size *= 2;
69 for i in (0..size).rev().step_by(2) {
70 let half_i = i / 2;
71 let temp = &chis[half_i] * &j;
72 chis[i] = temp;
73 chis[i - 1] = &chis[half_i] - &chis[i];
74 }
75 }
76 #[cfg(feature = "parallel")]
77 let iter = (0..chis.len()).into_par_iter();
78 #[cfg(not(feature = "parallel"))]
79 let iter = 0..chis.len();
80 Ok(iter.map(|i| &self.evals[i] * &chis[i]).sum())
81 }
82
83 pub fn evaluate_with(
85 evals: &[FieldElement<F>],
86 r: &[FieldElement<F>],
87 ) -> Result<FieldElement<F>, MultilinearError> {
88 let mut chis: Vec<FieldElement<F>> =
89 vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
90 if chis.len() != evals.len() {
91 return Err(MultilinearError::ChisAndEvalsLengthMismatch(
92 chis.len(),
93 evals.len(),
94 ));
95 }
96 let mut size = 1;
97 for j in r {
98 size *= 2;
99 for i in (0..size).rev().step_by(2) {
100 let half_i = i / 2;
101 let temp = &chis[half_i] * j;
102 chis[i] = temp;
103 chis[i - 1] = &chis[half_i] - &chis[i];
104 }
105 }
106 Ok((0..evals.len()).map(|i| &evals[i] * &chis[i]).sum())
107 }
108
109 pub fn fix_last_variable(&self, r: &FieldElement<F>) -> DenseMultilinearPolynomial<F> {
126 let n = self.num_vars();
127 assert!(n > 0, "Cannot fix variable in a 0-variable polynomial");
128 let half = 1 << (n - 1);
129 let new_evals: Vec<FieldElement<F>> = (0..half)
130 .map(|j| {
131 let a = &self.evals[j];
132 let b = &self.evals[j + half];
133 a + r * (b - a)
134 })
135 .collect();
136 DenseMultilinearPolynomial::from((n - 1, new_evals))
137 }
138
139 pub fn to_evaluations(&self) -> Vec<FieldElement<F>> {
142 self.evals.clone()
143 }
144
145 pub fn to_univariate(&self) -> Polynomial<FieldElement<F>> {
149 let poly0 = self.fix_last_variable(&FieldElement::zero());
150 let poly1 = self.fix_last_variable(&FieldElement::one());
151 let sum0: FieldElement<F> = poly0.to_evaluations().into_iter().sum();
152 let sum1: FieldElement<F> = poly1.to_evaluations().into_iter().sum();
153 let diff = sum1 - &sum0;
154 Polynomial::new(&[sum0, diff])
155 }
156
157 pub fn scalar_mul(&self, scalar: &FieldElement<F>) -> Self {
159 let mut new_poly = self.clone();
160 new_poly.evals.iter_mut().for_each(|eval| *eval *= scalar);
161 new_poly
162 }
163
164 pub fn extend(&mut self, other: &DenseMultilinearPolynomial<F>) {
166 debug_assert_eq!(self.evals.len(), self.len);
167 debug_assert_eq!(other.evals.len(), self.len);
168 self.evals.extend(other.evals.iter().cloned());
169 self.n_vars += 1;
170 self.len *= 2;
171 debug_assert_eq!(self.evals.len(), self.len);
172 }
173
174 pub fn merge(polys: &[DenseMultilinearPolynomial<F>]) -> DenseMultilinearPolynomial<F> {
177 let mut z: Vec<FieldElement<F>> = Vec::new();
179 for poly in polys {
180 z.extend(poly.evals.iter().cloned());
181 }
182 z.resize(z.len().next_power_of_two(), FieldElement::zero());
183 DenseMultilinearPolynomial::new(z)
184 }
185
186 pub fn from_u64(evals: &[u64]) -> Self {
188 DenseMultilinearPolynomial::new(evals.iter().map(|&i| FieldElement::from(i)).collect())
189 }
190}
191
192impl<F: IsField> Index<usize> for DenseMultilinearPolynomial<F>
193where
194 <F as IsField>::BaseType: Send + Sync,
195{
196 type Output = FieldElement<F>;
197
198 #[inline(always)]
199 fn index(&self, index: usize) -> &FieldElement<F> {
200 &self.evals[index]
201 }
202}
203
204impl<F: IsField> Add for DenseMultilinearPolynomial<F>
207where
208 <F as IsField>::BaseType: Send + Sync,
209{
210 type Output = Result<Self, &'static str>;
211
212 fn add(self, other: Self) -> Self::Output {
213 if self.num_vars() != other.num_vars() {
214 return Err("Polynomials must have the same number of variables");
215 }
216 #[cfg(feature = "parallel")]
217 let evals = self.evals.into_par_iter().zip(other.evals.into_par_iter());
218 #[cfg(not(feature = "parallel"))]
219 let evals = self.evals.iter().zip(other.evals.iter());
220 let sum: Vec<FieldElement<F>> = evals.map(|(a, b)| a + b).collect();
221 Ok(DenseMultilinearPolynomial::new(sum))
222 }
223}
224
225impl<F: IsField> Mul<FieldElement<F>> for DenseMultilinearPolynomial<F>
226where
227 <F as IsField>::BaseType: Send + Sync,
228{
229 type Output = DenseMultilinearPolynomial<F>;
230
231 fn mul(self, rhs: FieldElement<F>) -> Self::Output {
232 Self::scalar_mul(&self, &rhs)
233 }
234}
235
236impl<F: IsField> Mul<&FieldElement<F>> for DenseMultilinearPolynomial<F>
237where
238 <F as IsField>::BaseType: Send + Sync,
239{
240 type Output = DenseMultilinearPolynomial<F>;
241
242 fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
243 Self::scalar_mul(&self, rhs)
244 }
245}
246
/// Returns `ceil(log2(n))`, with `log_2(0) == 0`.
///
/// `n - 1` has exactly `ceil(log2(n))` significant bits for every `n >= 2`, so the
/// answer is the bit width minus its leading zeros.
fn log_2(n: usize) -> usize {
    match n {
        0 | 1 => 0,
        _ => (usize::BITS - (n - 1).leading_zeros()) as usize,
    }
}
258
259impl<F: IsField> From<(usize, Vec<FieldElement<F>>)> for DenseMultilinearPolynomial<F>
260where
261 <F as IsField>::BaseType: Send + Sync,
262{
263 fn from((num_vars, evaluations): (usize, Vec<FieldElement<F>>)) -> Self {
264 assert_eq!(
265 evaluations.len(),
266 1 << num_vars,
267 "The size of evaluations should be 2^num_vars."
268 );
269 DenseMultilinearPolynomial {
270 n_vars: num_vars,
271 evals: evaluations,
272 len: 1 << num_vars,
273 }
274 }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Reference chi table for the point `r` (first coordinate = most significant bit).
    pub fn evals(r: Vec<FE>) -> Vec<FE> {
        let mut evals: Vec<FE> = vec![FE::one(); (2usize).pow(r.len() as u32)];
        let mut size = 1;
        for j in r {
            size *= 2;
            for i in (0..size).rev().step_by(2) {
                let scalar = evals[i / 2];
                evals[i] = scalar * j;
                evals[i - 1] = scalar - evals[i];
            }
        }
        evals
    }

    /// Splits `r` in half and returns the chi tables of the left and right halves.
    pub fn compute_factored_evals(r: Vec<FE>) -> (Vec<FE>, Vec<FE>) {
        let left_num_vars = r.len() / 2;
        let l = evals(r[..left_num_vars].to_vec());
        let r = evals(r[left_num_vars..].to_vec());
        (l, r)
    }

    /// Evaluates via the square-root decomposition `L^T * Z * R`, viewing `z` as an
    /// `m x m` matrix with `m = 2^(num_vars / 2)`. Requires an even number of variables.
    fn evaluate_with_lr(z: &[FE], r: &[FE]) -> FE {
        // Use the point's length, not the length of a chi table: the right-hand table
        // returned below has 2^(num_vars / 2) entries.
        let num_vars = r.len();
        assert!(num_vars % 2 == 0, "evaluate_with_lr needs an even number of variables");
        let (left, right) = compute_factored_evals(r.to_vec());
        let m = 1usize << (num_vars / 2);
        assert_eq!(z.len(), m * m);
        let lz: Vec<FE> = (0..m)
            .map(|i| (0..m).fold(FE::zero(), |acc, j| acc + left[j] * z[j * m + i]))
            .collect();
        lz.iter().zip(&right).map(|(a, b)| *a * *b).sum()
    }

    #[test]
    fn evaluation() {
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        let r = vec![FE::from(4u64), FE::from(3u64)];
        let eval_with_lr = evaluate_with_lr(&z, &r);
        let poly = DenseMultilinearPolynomial::new(z);
        let eval = poly.evaluate(r).unwrap();
        assert_eq!(eval, FE::from(28u64));
        assert_eq!(eval_with_lr, eval);
    }

    #[test]
    fn evaluation_with_lr_six_vars() {
        let z: Vec<FE> = (0..64u64).map(|i| FE::from(i * i + 1)).collect();
        let r: Vec<FE> = (2..8u64).map(|i| FE::from(i)).collect();
        let eval_with_lr = evaluate_with_lr(&z, &r);
        let poly = DenseMultilinearPolynomial::new(z);
        assert_eq!(poly.evaluate(r).unwrap(), eval_with_lr);
    }

    #[test]
    fn evaluate_with() {
        let two = FE::from(2);
        let z = vec![
            FE::zero(),
            FE::zero(),
            FE::zero(),
            FE::one(),
            FE::one(),
            FE::one(),
            FE::zero(),
            two,
        ];
        let x = vec![FE::one(), FE::one(), FE::one()];
        let y = DenseMultilinearPolynomial::<F>::evaluate_with(z.as_slice(), x.as_slice()).unwrap();
        assert_eq!(y, two);
    }

    #[test]
    fn evaluate_with_length_mismatch() {
        let z = vec![FE::one(); 4];
        let x = vec![FE::one(); 3];
        assert!(DenseMultilinearPolynomial::<F>::evaluate_with(&z, &x).is_err());
    }

    #[test]
    fn fix_last_variable_matches_evaluate() {
        let z: Vec<FE> = (1..=8u64).map(|i| FE::from(i)).collect();
        let poly = DenseMultilinearPolynomial::new(z);
        let c = FE::from(5u64);
        let rest = vec![FE::from(7u64), FE::from(11u64)];
        let fixed = poly.fix_last_variable(&c);
        assert_eq!(fixed.num_vars(), 2);
        let mut full = vec![c];
        full.extend(rest.iter().copied());
        assert_eq!(fixed.evaluate(rest).unwrap(), poly.evaluate(full).unwrap());
    }

    #[test]
    fn to_univariate_sums_halves() {
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        let poly = DenseMultilinearPolynomial::new(z);
        // low half sums to 3, high half to 5, so g(X) = 3 + 2X
        assert_eq!(
            poly.to_univariate(),
            Polynomial::new(&[FE::from(3u64), FE::from(2u64)])
        );
    }

    #[test]
    fn add() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(7); 4]);
        let c = a.add(b).unwrap();
        assert_eq!(*c.evals(), vec![FE::from(10); 4]);
    }

    #[test]
    fn mul() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = a.mul(&FE::from(2));
        assert_eq!(*b.evals(), vec![FE::from(6); 4]);
    }

    #[test]
    fn merge() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        let c = DenseMultilinearPolynomial::merge(&[a, b]);
        assert_eq!(c.len(), 8);
        assert_eq!(c[c.len() - 1], FE::zero());
        assert_eq!(c[c.len() - 2], FE::zero());
    }

    #[test]
    fn extend() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        a.extend(&b);
        assert_eq!(a.len(), 8);
        assert_eq!(a.num_vars(), 3);
    }

    #[test]
    #[should_panic]
    fn extend_unequal() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        a.extend(&b);
    }
}