lambdaworks_math/polynomial/dense_multilinear_poly.rs

1use crate::{
2    field::{element::FieldElement, traits::IsField},
3    polynomial::{error::MultilinearError, Polynomial},
4};
5use alloc::{vec, vec::Vec};
6use core::ops::{Add, Index, Mul};
7#[cfg(feature = "parallel")]
8use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
9
10/// Represents a multilinear polynomial as a vector of evaluations (FieldElements) in Lagrange basis.
11#[derive(Debug, PartialEq, Clone)]
12pub struct DenseMultilinearPolynomial<F: IsField>
13where
14    <F as IsField>::BaseType: Send + Sync,
15{
16    evals: Vec<FieldElement<F>>,
17    n_vars: usize,
18    len: usize,
19}
20
21impl<F: IsField> DenseMultilinearPolynomial<F>
22where
23    <F as IsField>::BaseType: Send + Sync,
24{
25    /// Constructs a new multilinear polynomial from a collection of evaluations.
26    /// Pads non-power-of-2 evaluations with zeros.
27    pub fn new(mut evals: Vec<FieldElement<F>>) -> Self {
28        while !evals.len().is_power_of_two() {
29            evals.push(FieldElement::zero());
30        }
31        let len = evals.len();
32        DenseMultilinearPolynomial {
33            n_vars: log_2(len),
34            evals,
35            len,
36        }
37    }
38
39    /// Returns the number of variables.
40    pub fn num_vars(&self) -> usize {
41        self.n_vars
42    }
43
44    /// Returns a reference to the evaluations vector.
45    pub fn evals(&self) -> &Vec<FieldElement<F>> {
46        &self.evals
47    }
48
49    /// Returns the total number of evaluations (2^num_vars).
50    #[allow(clippy::len_without_is_empty)]
51    pub fn len(&self) -> usize {
52        self.len
53    }
54
55    /// Evaluates `self` at the point `r` (a vector of FieldElements) in O(n) time.
56    /// `r` must have a value for each variable.
57    pub fn evaluate(&self, r: Vec<FieldElement<F>>) -> Result<FieldElement<F>, MultilinearError> {
58        if r.len() != self.num_vars() {
59            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
60                r.len(),
61                self.num_vars(),
62            ));
63        }
64        let mut chis: Vec<FieldElement<F>> =
65            vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
66        let mut size = 1;
67        for j in r {
68            size *= 2;
69            for i in (0..size).rev().step_by(2) {
70                let half_i = i / 2;
71                let temp = &chis[half_i] * &j;
72                chis[i] = temp;
73                chis[i - 1] = &chis[half_i] - &chis[i];
74            }
75        }
76        #[cfg(feature = "parallel")]
77        let iter = (0..chis.len()).into_par_iter();
78        #[cfg(not(feature = "parallel"))]
79        let iter = 0..chis.len();
80        Ok(iter.map(|i| &self.evals[i] * &chis[i]).sum())
81    }
82
83    /// Evaluates a slice of evaluations with the given point `r`.
84    pub fn evaluate_with(
85        evals: &[FieldElement<F>],
86        r: &[FieldElement<F>],
87    ) -> Result<FieldElement<F>, MultilinearError> {
88        let mut chis: Vec<FieldElement<F>> =
89            vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
90        if chis.len() != evals.len() {
91            return Err(MultilinearError::ChisAndEvalsLengthMismatch(
92                chis.len(),
93                evals.len(),
94            ));
95        }
96        let mut size = 1;
97        for j in r {
98            size *= 2;
99            for i in (0..size).rev().step_by(2) {
100                let half_i = i / 2;
101                let temp = &chis[half_i] * j;
102                chis[i] = temp;
103                chis[i - 1] = &chis[half_i] - &chis[i];
104            }
105        }
106        Ok((0..evals.len()).map(|i| &evals[i] * &chis[i]).sum())
107    }
108
109    /// Fixes the last variable to the given value `r` and returns a new DenseMultilinearPolynomial
110    /// with one fewer variable.
111    /// Evaluations are ordered so that the first half corresponds to the last variable = 0,
112    /// and the second half corresponds to the last variable = 1.
113    ///
114    /// Combines each pair of evaluations as: new_eval = a + r * (b - a)
115    ///  This reduces the polynomial by one variable, allowing it to later be collapsed
116    /// into a univariate polynomial by summing over the remaining variables.
117    ///
118    /// Example (2 variables): evaluations ordered as:
119    ///     [f(0,0), f(0,1), f(1,0), f(1,1)]
120    /// Fixing the second variable `y = r` produces evaluations of a 1-variable polynomial:
121    ///     [f(0,r), f(1,r)]
122    /// computed explicitly as:
123    ///     f(0,r) = f(0,0) + r*(f(0,1)-f(0,0)),
124    ///     f(1,r) = f(1,0) + r*(f(1,1)-f(1,0))
125    pub fn fix_last_variable(&self, r: &FieldElement<F>) -> DenseMultilinearPolynomial<F> {
126        let n = self.num_vars();
127        assert!(n > 0, "Cannot fix variable in a 0-variable polynomial");
128        let half = 1 << (n - 1);
129        let new_evals: Vec<FieldElement<F>> = (0..half)
130            .map(|j| {
131                let a = &self.evals[j];
132                let b = &self.evals[j + half];
133                a + r * (b - a)
134            })
135            .collect();
136        DenseMultilinearPolynomial::from((n - 1, new_evals))
137    }
138
139    /// Returns the evaluations of the polynomial on the Boolean hypercube \(\{0,1\}^n\).
140    /// Since we are in Lagrange basis, this is just the elements stored in self.evals.
141    pub fn to_evaluations(&self) -> Vec<FieldElement<F>> {
142        self.evals.clone()
143    }
144
145    /// Collapses the last variable by fixing it to 0 and 1,
146    /// sums the evaluations, and returns a univariate polynomial (as a Polynomial)
147    /// of the form: sum0 + (sum1 - sum0) * x.
148    pub fn to_univariate(&self) -> Polynomial<FieldElement<F>> {
149        let poly0 = self.fix_last_variable(&FieldElement::zero());
150        let poly1 = self.fix_last_variable(&FieldElement::one());
151        let sum0: FieldElement<F> = poly0.to_evaluations().into_iter().sum();
152        let sum1: FieldElement<F> = poly1.to_evaluations().into_iter().sum();
153        let diff = sum1 - &sum0;
154        Polynomial::new(&[sum0, diff])
155    }
156
157    /// Multiplies the polynomial by a scalar.
158    pub fn scalar_mul(&self, scalar: &FieldElement<F>) -> Self {
159        let mut new_poly = self.clone();
160        new_poly.evals.iter_mut().for_each(|eval| *eval *= scalar);
161        new_poly
162    }
163
164    /// Extends this DenseMultilinearPolynomial by concatenating another polynomial of the same length.
165    pub fn extend(&mut self, other: &DenseMultilinearPolynomial<F>) {
166        debug_assert_eq!(self.evals.len(), self.len);
167        debug_assert_eq!(other.evals.len(), self.len);
168        self.evals.extend(other.evals.iter().cloned());
169        self.n_vars += 1;
170        self.len *= 2;
171        debug_assert_eq!(self.evals.len(), self.len);
172    }
173
174    /// Merges a series of DenseMultilinearPolynomials into one polynomial.
175    /// Zero-pads the final merged polynomial to the next power-of-two length if necessary.
176    pub fn merge(polys: &[DenseMultilinearPolynomial<F>]) -> DenseMultilinearPolynomial<F> {
177        // TODO (performance): pre-allocate vector to avoid repeated resizing.
178        let mut z: Vec<FieldElement<F>> = Vec::new();
179        for poly in polys {
180            z.extend(poly.evals.iter().cloned());
181        }
182        z.resize(z.len().next_power_of_two(), FieldElement::zero());
183        DenseMultilinearPolynomial::new(z)
184    }
185
186    /// Constructs a DenseMultilinearPolynomial from a slice of u64 values.
187    pub fn from_u64(evals: &[u64]) -> Self {
188        DenseMultilinearPolynomial::new(evals.iter().map(|&i| FieldElement::from(i)).collect())
189    }
190}
191
192impl<F: IsField> Index<usize> for DenseMultilinearPolynomial<F>
193where
194    <F as IsField>::BaseType: Send + Sync,
195{
196    type Output = FieldElement<F>;
197
198    #[inline(always)]
199    fn index(&self, index: usize) -> &FieldElement<F> {
200        &self.evals[index]
201    }
202}
203
204/// Adds two DenseMultilinearPolynomials.
205/// Assumes that both polynomials have the same number of variables.
206impl<F: IsField> Add for DenseMultilinearPolynomial<F>
207where
208    <F as IsField>::BaseType: Send + Sync,
209{
210    type Output = Result<Self, &'static str>;
211
212    fn add(self, other: Self) -> Self::Output {
213        if self.num_vars() != other.num_vars() {
214            return Err("Polynomials must have the same number of variables");
215        }
216        #[cfg(feature = "parallel")]
217        let evals = self.evals.into_par_iter().zip(other.evals.into_par_iter());
218        #[cfg(not(feature = "parallel"))]
219        let evals = self.evals.iter().zip(other.evals.iter());
220        let sum: Vec<FieldElement<F>> = evals.map(|(a, b)| a + b).collect();
221        Ok(DenseMultilinearPolynomial::new(sum))
222    }
223}
224
225impl<F: IsField> Mul<FieldElement<F>> for DenseMultilinearPolynomial<F>
226where
227    <F as IsField>::BaseType: Send + Sync,
228{
229    type Output = DenseMultilinearPolynomial<F>;
230
231    fn mul(self, rhs: FieldElement<F>) -> Self::Output {
232        Self::scalar_mul(&self, &rhs)
233    }
234}
235
236impl<F: IsField> Mul<&FieldElement<F>> for DenseMultilinearPolynomial<F>
237where
238    <F as IsField>::BaseType: Send + Sync,
239{
240    type Output = DenseMultilinearPolynomial<F>;
241
242    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
243        Self::scalar_mul(&self, rhs)
244    }
245}
246
/// Helper returning ⌈log₂(n)⌉ (the exact exponent when `n` is a power of two), with
/// `log_2(0) == 0`.
fn log_2(n: usize) -> usize {
    match n {
        0 => 0,
        // 2^k has exactly k trailing zeros.
        _ if n.is_power_of_two() => n.trailing_zeros() as usize,
        // Otherwise the bit length of n is floor(log2 n) + 1 == ceil(log2 n).
        _ => (usize::BITS - n.leading_zeros()) as usize,
    }
}
258
259impl<F: IsField> From<(usize, Vec<FieldElement<F>>)> for DenseMultilinearPolynomial<F>
260where
261    <F as IsField>::BaseType: Send + Sync,
262{
263    fn from((num_vars, evaluations): (usize, Vec<FieldElement<F>>)) -> Self {
264        assert_eq!(
265            evaluations.len(),
266            1 << num_vars,
267            "The size of evaluations should be 2^num_vars."
268        );
269        DenseMultilinearPolynomial {
270            n_vars: num_vars,
271            evals: evaluations,
272            len: 1 << num_vars,
273        }
274    }
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Reference computation of the Lagrange basis values `chi_i(r)`, using the same ordering as
    /// `DenseMultilinearPolynomial::evaluate` (`r[0]` is the most significant index bit).
    pub fn evals(r: Vec<FE>) -> Vec<FE> {
        let mut evals: Vec<FE> = vec![FE::one(); (2usize).pow(r.len() as u32)];
        let mut size = 1;
        for j in r {
            size *= 2;
            for i in (0..size).rev().step_by(2) {
                let scalar = evals[i / 2];
                evals[i] = scalar * j;
                evals[i - 1] = scalar - evals[i];
            }
        }
        evals
    }

    /// Splits `r` into its first half and second half of coordinates and returns the basis
    /// values for each half.
    pub fn compute_factored_evals(r: Vec<FE>) -> (Vec<FE>, Vec<FE>) {
        let left_num_vars = r.len() / 2;
        let l = evals(r[..left_num_vars].to_vec());
        let r = evals(r[left_num_vars..].to_vec());
        (l, r)
    }

    /// Evaluates `z` at `r` by viewing the evaluations as an `m x m` row-major matrix (rows
    /// indexed by the first half of the variables, columns by the second half) and computing
    /// `L^T · Z · R`. Requires an even number of variables so the matrix is square.
    fn evaluate_with_lr(z: &[FE], r: &[FE]) -> FE {
        assert!(
            r.len() % 2 == 0,
            "evaluate_with_lr needs an even number of variables"
        );
        let (l, r) = compute_factored_evals(r.to_vec());
        // With an even number of variables both factors have m = sqrt(z.len()) entries.
        let m = l.len();
        assert_eq!(r.len(), m);
        assert_eq!(z.len(), m * m);
        // Vector-matrix product between L and Z.
        let lz: Vec<FE> = (0..m)
            .map(|i| (0..m).fold(FE::zero(), |acc, j| acc + l[j] * z[j * m + i]))
            .collect();
        // Dot product between LZ and R.
        lz.iter().zip(r.iter()).map(|(a, b)| a * b).sum()
    }

    #[test]
    fn evaluation() {
        // Example: Z = [1, 2, 1, 4]
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        // r = [4, 3]
        let r = vec![FE::from(4u64), FE::from(3u64)];
        let eval_with_lr = evaluate_with_lr(&z, &r);
        let poly = DenseMultilinearPolynomial::new(z);
        let eval = poly.evaluate(r).unwrap();
        assert_eq!(eval, FE::from(28u64));
        assert_eq!(eval_with_lr, eval);
    }

    #[test]
    fn evaluation_with_lr_six_vars() {
        let z: Vec<FE> = (0..64u64).map(|i| FE::from(i * i + 1)).collect();
        let r: Vec<FE> = (1..=6u64).map(|i| FE::from(3 * i + 2)).collect();
        let poly = DenseMultilinearPolynomial::new(z.clone());
        assert_eq!(evaluate_with_lr(&z, &r), poly.evaluate(r).unwrap());
    }

    #[test]
    fn evaluate_rejects_wrong_number_of_points() {
        let poly = DenseMultilinearPolynomial::new(vec![FE::one(); 4]);
        assert!(poly.evaluate(vec![FE::one()]).is_err());
        let z = vec![FE::one(); 4];
        let x = vec![FE::one()];
        assert!(DenseMultilinearPolynomial::<F>::evaluate_with(z.as_slice(), x.as_slice()).is_err());
    }

    #[test]
    fn evaluate_with() {
        let two = FE::from(2);
        let z = vec![
            FE::zero(),
            FE::zero(),
            FE::zero(),
            FE::one(),
            FE::one(),
            FE::one(),
            FE::zero(),
            two,
        ];
        let x = vec![FE::one(), FE::one(), FE::one()];
        let y = DenseMultilinearPolynomial::<F>::evaluate_with(z.as_slice(), x.as_slice()).unwrap();
        assert_eq!(y, two);
    }

    #[test]
    fn new_pads_to_power_of_two() {
        let poly =
            DenseMultilinearPolynomial::new(vec![FE::one(), FE::from(2u64), FE::from(3u64)]);
        assert_eq!(poly.len(), 4);
        assert_eq!(poly.num_vars(), 2);
        assert_eq!(poly[3], FE::zero());
        let empty = DenseMultilinearPolynomial::<F>::new(vec![]);
        assert_eq!(empty.len(), 1);
        assert_eq!(empty.num_vars(), 0);
    }

    #[test]
    fn fix_last_variable_binds_first_coordinate() {
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        let poly = DenseMultilinearPolynomial::new(z);
        let fixed = poly.fix_last_variable(&FE::from(5u64));
        assert_eq!(fixed.num_vars(), 1);
        // Pairs (z[0], z[2]) and (z[1], z[3]): 1 + 5*(1-1) = 1, 2 + 5*(4-2) = 12.
        assert_eq!(*fixed.evals(), vec![FE::one(), FE::from(12u64)]);
        assert_eq!(
            fixed.evaluate(vec![FE::from(3u64)]).unwrap(),
            poly.evaluate(vec![FE::from(5u64), FE::from(3u64)]).unwrap()
        );
    }

    #[test]
    fn to_univariate() {
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        let poly = DenseMultilinearPolynomial::new(z);
        let uni = poly.to_univariate();
        // First half sums to 1 + 2 = 3, second half to 1 + 4 = 5.
        assert_eq!(uni.evaluate(&FE::zero()), FE::from(3u64));
        assert_eq!(uni.evaluate(&FE::one()), FE::from(5u64));
    }

    #[test]
    fn add() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(7); 4]);
        let c = a.add(b).unwrap();
        assert_eq!(*c.evals(), vec![FE::from(10); 4]);
    }

    #[test]
    fn add_mismatched_vars() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(7); 2]);
        assert!(a.add(b).is_err());
    }

    #[test]
    fn mul() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = a.mul(&FE::from(2));
        assert_eq!(*b.evals(), vec![FE::from(6); 4]);
    }

    // Take a multilinear polynomial of length 2^2 and merge with a polynomial of 2^1.
    // The resulting polynomial should be padded to length 2^3 = 8 and the last two evaluations should be FE::zero().
    #[test]
    fn merge() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        let c = DenseMultilinearPolynomial::merge(&[a, b]);
        assert_eq!(c.len(), 8);
        assert_eq!(c[c.len() - 1], FE::zero());
        assert_eq!(c[c.len() - 2], FE::zero());
    }

    #[test]
    fn extend() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        a.extend(&b);
        assert_eq!(a.len(), 8);
        assert_eq!(a.num_vars(), 3);
    }

    #[test]
    #[should_panic]
    fn extend_unequal() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        a.extend(&b);
    }
}