lambdaworks_math/polynomial/
dense_multilinear_poly.rs1use crate::{
2 field::{element::FieldElement, traits::IsField},
3 polynomial::{error::MultilinearError, Polynomial},
4};
5use alloc::{vec, vec::Vec};
6use core::ops::{Add, Index, Mul};
7#[cfg(feature = "parallel")]
8use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
9
10#[derive(Debug, PartialEq, Clone)]
12pub struct DenseMultilinearPolynomial<F: IsField>
13where
14 <F as IsField>::BaseType: Send + Sync,
15{
16 evals: Vec<FieldElement<F>>,
17 n_vars: usize,
18 len: usize,
19}
20
21impl<F: IsField> DenseMultilinearPolynomial<F>
22where
23 <F as IsField>::BaseType: Send + Sync,
24{
25 pub fn new(mut evals: Vec<FieldElement<F>>) -> Self {
28 while !evals.len().is_power_of_two() {
29 evals.push(FieldElement::zero());
30 }
31 let len = evals.len();
32 DenseMultilinearPolynomial {
33 n_vars: log_2(len),
34 evals,
35 len,
36 }
37 }
38
39 pub fn num_vars(&self) -> usize {
41 self.n_vars
42 }
43
44 pub fn evals(&self) -> &Vec<FieldElement<F>> {
46 &self.evals
47 }
48
49 #[allow(clippy::len_without_is_empty)]
51 pub fn len(&self) -> usize {
52 self.len
53 }
54
55 pub fn evaluate(&self, r: Vec<FieldElement<F>>) -> Result<FieldElement<F>, MultilinearError> {
58 if r.len() != self.num_vars() {
59 return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
60 r.len(),
61 self.num_vars(),
62 ));
63 }
64 let mut chis: Vec<FieldElement<F>> =
65 vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
66 let mut size = 1;
67 for j in r {
68 size *= 2;
69 for i in (0..size).rev().step_by(2) {
70 let half_i = i / 2;
71 let temp = &chis[half_i] * &j;
72 chis[i] = temp;
73 chis[i - 1] = &chis[half_i] - &chis[i];
74 }
75 }
76 #[cfg(feature = "parallel")]
77 let iter = (0..chis.len()).into_par_iter();
78 #[cfg(not(feature = "parallel"))]
79 let iter = 0..chis.len();
80 Ok(iter.map(|i| &self.evals[i] * &chis[i]).sum())
81 }
82
83 pub fn evaluate_with(
85 evals: &[FieldElement<F>],
86 r: &[FieldElement<F>],
87 ) -> Result<FieldElement<F>, MultilinearError> {
88 let mut chis: Vec<FieldElement<F>> =
89 vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
90 if chis.len() != evals.len() {
91 return Err(MultilinearError::ChisAndEvalsLengthMismatch(
92 chis.len(),
93 evals.len(),
94 ));
95 }
96 let mut size = 1;
97 for j in r {
98 size *= 2;
99 for i in (0..size).rev().step_by(2) {
100 let half_i = i / 2;
101 let temp = &chis[half_i] * j;
102 chis[i] = temp;
103 chis[i - 1] = &chis[half_i] - &chis[i];
104 }
105 }
106 Ok((0..evals.len()).map(|i| &evals[i] * &chis[i]).sum())
107 }
108
109 pub fn fix_first_variable(&self, r: &FieldElement<F>) -> DenseMultilinearPolynomial<F> {
124 let n = self.num_vars();
125 assert!(n > 0, "Cannot fix variable in a 0-variable polynomial");
126 let half = 1 << (n - 1);
127 let new_evals: Vec<FieldElement<F>> = (0..half)
128 .map(|j| {
129 let a = &self.evals[j];
130 let b = &self.evals[j + half];
131 a + r * (b - a)
132 })
133 .collect();
134 DenseMultilinearPolynomial::from((n - 1, new_evals))
135 }
136
137 pub fn to_evaluations(&self) -> Vec<FieldElement<F>> {
140 self.evals.clone()
141 }
142
143 pub fn to_univariate(&self) -> Polynomial<FieldElement<F>> {
147 let poly0 = self.fix_first_variable(&FieldElement::zero());
148 let poly1 = self.fix_first_variable(&FieldElement::one());
149 let sum0: FieldElement<F> = poly0.to_evaluations().into_iter().sum();
150 let sum1: FieldElement<F> = poly1.to_evaluations().into_iter().sum();
151 let diff = sum1 - &sum0;
152 Polynomial::new(&[sum0, diff])
153 }
154
155 pub fn scalar_mul(&self, scalar: &FieldElement<F>) -> Self {
157 let mut new_poly = self.clone();
158 new_poly.evals.iter_mut().for_each(|eval| *eval *= scalar);
159 new_poly
160 }
161
162 pub fn extend(&mut self, other: &DenseMultilinearPolynomial<F>) {
164 debug_assert_eq!(self.evals.len(), self.len);
165 debug_assert_eq!(other.evals.len(), self.len);
166 self.evals.extend(other.evals.iter().cloned());
167 self.n_vars += 1;
168 self.len *= 2;
169 debug_assert_eq!(self.evals.len(), self.len);
170 }
171
172 pub fn merge(polys: &[DenseMultilinearPolynomial<F>]) -> DenseMultilinearPolynomial<F> {
175 let mut z: Vec<FieldElement<F>> = Vec::new();
177 for poly in polys {
178 z.extend(poly.evals.iter().cloned());
179 }
180 z.resize(z.len().next_power_of_two(), FieldElement::zero());
181 DenseMultilinearPolynomial::new(z)
182 }
183
184 pub fn from_u64(evals: &[u64]) -> Self {
186 DenseMultilinearPolynomial::new(evals.iter().map(|&i| FieldElement::from(i)).collect())
187 }
188}
189
190impl<F: IsField> Index<usize> for DenseMultilinearPolynomial<F>
191where
192 <F as IsField>::BaseType: Send + Sync,
193{
194 type Output = FieldElement<F>;
195
196 #[inline(always)]
197 fn index(&self, index: usize) -> &FieldElement<F> {
198 &self.evals[index]
199 }
200}
201
202impl<F: IsField> Add for DenseMultilinearPolynomial<F>
205where
206 <F as IsField>::BaseType: Send + Sync,
207{
208 type Output = Result<Self, &'static str>;
209
210 fn add(self, other: Self) -> Self::Output {
211 if self.num_vars() != other.num_vars() {
212 return Err("Polynomials must have the same number of variables");
213 }
214 #[cfg(feature = "parallel")]
215 let evals = self.evals.into_par_iter().zip(other.evals.into_par_iter());
216 #[cfg(not(feature = "parallel"))]
217 let evals = self.evals.iter().zip(other.evals.iter());
218 let sum: Vec<FieldElement<F>> = evals.map(|(a, b)| a + b).collect();
219 Ok(DenseMultilinearPolynomial::new(sum))
220 }
221}
222
223impl<F: IsField> Mul<FieldElement<F>> for DenseMultilinearPolynomial<F>
224where
225 <F as IsField>::BaseType: Send + Sync,
226{
227 type Output = DenseMultilinearPolynomial<F>;
228
229 fn mul(self, rhs: FieldElement<F>) -> Self::Output {
230 Self::scalar_mul(&self, &rhs)
231 }
232}
233
234impl<F: IsField> Mul<&FieldElement<F>> for DenseMultilinearPolynomial<F>
235where
236 <F as IsField>::BaseType: Send + Sync,
237{
238 type Output = DenseMultilinearPolynomial<F>;
239
240 fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
241 Self::scalar_mul(&self, rhs)
242 }
243}
244
/// Integer base-2 logarithm used to derive the variable count from an evaluation length.
///
/// Returns `log2(n)` when `n` is a power of two, `ceil(log2(n))` (the bit length of `n`)
/// otherwise, and `0` for `n == 0`.
fn log_2(n: usize) -> usize {
    match n {
        0 => 0,
        p if p.is_power_of_two() => p.trailing_zeros() as usize,
        other => (usize::BITS - other.leading_zeros()) as usize,
    }
}
256
257impl<F: IsField> From<(usize, Vec<FieldElement<F>>)> for DenseMultilinearPolynomial<F>
258where
259 <F as IsField>::BaseType: Send + Sync,
260{
261 fn from((num_vars, evaluations): (usize, Vec<FieldElement<F>>)) -> Self {
262 assert_eq!(
263 evaluations.len(),
264 1 << num_vars,
265 "The size of evaluations should be 2^num_vars."
266 );
267 DenseMultilinearPolynomial {
268 n_vars: num_vars,
269 evals: evaluations,
270 len: 1 << num_vars,
271 }
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Returns the `2^k` weights `chi_w(r)` for every `w` in `{0,1}^k`, indexed big-endian
    /// (the first coordinate of `r` is the most significant bit of `w`).
    pub fn evals(r: Vec<FE>) -> Vec<FE> {
        let mut evals: Vec<FE> = vec![FE::one(); (2usize).pow(r.len() as u32)];
        let mut size = 1;
        for j in r {
            size *= 2;
            for i in (0..size).rev().step_by(2) {
                let scalar = evals[i / 2];
                evals[i] = scalar * j;
                evals[i - 1] = scalar - evals[i];
            }
        }
        evals
    }

    /// Splits `r` into a left part of `r.len() / 2` coordinates and a right part with the
    /// rest, and returns the chi-weights of each part.
    pub fn compute_factored_evals(r: Vec<FE>) -> (Vec<FE>, Vec<FE>) {
        let left_num_vars = r.len() / 2;
        let l = evals(r[..left_num_vars].to_vec());
        let r = evals(r[left_num_vars..].to_vec());
        (l, r)
    }

    /// Reference evaluator: views `z` as an `m x m` matrix with `m = 2^(num_vars / 2)`,
    /// computes `l^T * Z` with the left-half weights and dots the result with the
    /// right-half weights. Requires an even number of variables.
    fn evaluate_with_lr(z: &[FE], r: &[FE]) -> FE {
        // Take the variable count from the point *before* `r` is shadowed by the right-half
        // weights below; reading it afterwards only worked by coincidence for 2 and 4 vars.
        let num_vars = r.len();
        assert!(num_vars % 2 == 0);
        let m = 1usize << (num_vars / 2);
        assert_eq!(z.len(), m * m, "z must have 2^num_vars entries");
        let (l, r) = compute_factored_evals(r.to_vec());
        let lz = (0..m)
            .map(|i| {
                (0..m).fold(FE::zero(), |mut acc, j| {
                    acc += l[j] * z[j * m + i];
                    acc
                })
            })
            .collect::<Vec<FE>>();
        (0..lz.len()).map(|i| lz[i] * r[i]).sum()
    }

    #[test]
    fn evaluation() {
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        let r = vec![FE::from(4u64), FE::from(3u64)];
        let eval_with_lr = evaluate_with_lr(&z, &r);
        let poly = DenseMultilinearPolynomial::new(z);
        let eval = poly.evaluate(r).unwrap();
        assert_eq!(eval, FE::from(28u64));
        assert_eq!(eval_with_lr, eval);
    }

    #[test]
    fn evaluation_six_vars_matches_lr() {
        let z: Vec<FE> = (0..64u64).map(|i| FE::from(i)).collect();
        let r: Vec<FE> = [2u64, 3, 5, 7, 11, 13]
            .iter()
            .map(|&x| FE::from(x))
            .collect();
        let expected = evaluate_with_lr(&z, &r);
        let with = DenseMultilinearPolynomial::<F>::evaluate_with(&z, &r).unwrap();
        let poly = DenseMultilinearPolynomial::new(z);
        assert_eq!(poly.evaluate(r).unwrap(), expected);
        assert_eq!(with, expected);
    }

    #[test]
    fn evaluate_on_hypercube_returns_stored_value() {
        let z = vec![FE::from(1), FE::from(2), FE::from(3), FE::from(4)];
        let poly = DenseMultilinearPolynomial::new(z.clone());
        for (idx, expected) in z.iter().enumerate() {
            let point = vec![FE::from((idx >> 1) as u64 & 1), FE::from(idx as u64 & 1)];
            assert_eq!(poly.evaluate(point).unwrap(), *expected);
        }
    }

    #[test]
    fn evaluate_with() {
        let two = FE::from(2);
        let z = vec![
            FE::zero(),
            FE::zero(),
            FE::zero(),
            FE::one(),
            FE::one(),
            FE::one(),
            FE::zero(),
            two,
        ];
        let x = vec![FE::one(), FE::one(), FE::one()];
        let y = DenseMultilinearPolynomial::<F>::evaluate_with(z.as_slice(), x.as_slice()).unwrap();
        assert_eq!(y, two);
    }

    #[test]
    fn fix_first_variable_interpolates_halves() {
        let poly = DenseMultilinearPolynomial::new(vec![
            FE::from(1),
            FE::from(2),
            FE::from(3),
            FE::from(4),
        ]);
        assert_eq!(
            *poly.fix_first_variable(&FE::zero()).evals(),
            vec![FE::from(1), FE::from(2)]
        );
        assert_eq!(
            *poly.fix_first_variable(&FE::one()).evals(),
            vec![FE::from(3), FE::from(4)]
        );
        assert_eq!(
            *poly.fix_first_variable(&FE::from(2)).evals(),
            vec![FE::from(5), FE::from(6)]
        );
    }

    #[test]
    fn add() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(7); 4]);
        let c = a.add(b).unwrap();
        assert_eq!(*c.evals(), vec![FE::from(10); 4]);
    }

    #[test]
    fn mul() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = a.mul(&FE::from(2));
        assert_eq!(*b.evals(), vec![FE::from(6); 4]);
    }

    #[test]
    fn merge() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        let c = DenseMultilinearPolynomial::merge(&[a, b]);
        assert_eq!(c.len(), 8);
        assert_eq!(c[c.len() - 1], FE::zero());
        assert_eq!(c[c.len() - 2], FE::zero());
    }

    #[test]
    fn extend() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        a.extend(&b);
        assert_eq!(a.len(), 8);
        assert_eq!(a.num_vars(), 3);
    }

    #[test]
    #[should_panic]
    fn extend_unequal() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        a.extend(&b);
    }
}