// lambdaworks_math/polynomial/dense_multilinear_poly.rs

1use crate::{
2    field::{element::FieldElement, traits::IsField},
3    polynomial::{error::MultilinearError, Polynomial},
4};
5use alloc::{vec, vec::Vec};
6use core::ops::{Add, Index, Mul};
7#[cfg(feature = "parallel")]
8use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
9
10/// Represents a multilinear polynomial as a vector of evaluations (FieldElements) in Lagrange basis.
11#[derive(Debug, PartialEq, Clone)]
12pub struct DenseMultilinearPolynomial<F: IsField>
13where
14    <F as IsField>::BaseType: Send + Sync,
15{
16    evals: Vec<FieldElement<F>>,
17    n_vars: usize,
18    len: usize,
19}
20
21impl<F: IsField> DenseMultilinearPolynomial<F>
22where
23    <F as IsField>::BaseType: Send + Sync,
24{
25    /// Constructs a new multilinear polynomial from a collection of evaluations.
26    /// Pads non-power-of-2 evaluations with zeros.
27    pub fn new(mut evals: Vec<FieldElement<F>>) -> Self {
28        while !evals.len().is_power_of_two() {
29            evals.push(FieldElement::zero());
30        }
31        let len = evals.len();
32        DenseMultilinearPolynomial {
33            n_vars: log_2(len),
34            evals,
35            len,
36        }
37    }
38
39    /// Returns the number of variables.
40    pub fn num_vars(&self) -> usize {
41        self.n_vars
42    }
43
44    /// Returns a reference to the evaluations vector.
45    pub fn evals(&self) -> &Vec<FieldElement<F>> {
46        &self.evals
47    }
48
49    /// Returns the total number of evaluations (2^num_vars).
50    #[allow(clippy::len_without_is_empty)]
51    pub fn len(&self) -> usize {
52        self.len
53    }
54
55    /// Evaluates `self` at the point `r` (a vector of FieldElements) in O(n) time.
56    /// `r` must have a value for each variable.
57    pub fn evaluate(&self, r: Vec<FieldElement<F>>) -> Result<FieldElement<F>, MultilinearError> {
58        if r.len() != self.num_vars() {
59            return Err(MultilinearError::IncorrectNumberofEvaluationPoints(
60                r.len(),
61                self.num_vars(),
62            ));
63        }
64        let mut chis: Vec<FieldElement<F>> =
65            vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
66        let mut size = 1;
67        for j in r {
68            size *= 2;
69            for i in (0..size).rev().step_by(2) {
70                let half_i = i / 2;
71                let temp = &chis[half_i] * &j;
72                chis[i] = temp;
73                chis[i - 1] = &chis[half_i] - &chis[i];
74            }
75        }
76        #[cfg(feature = "parallel")]
77        let iter = (0..chis.len()).into_par_iter();
78        #[cfg(not(feature = "parallel"))]
79        let iter = 0..chis.len();
80        Ok(iter.map(|i| &self.evals[i] * &chis[i]).sum())
81    }
82
83    /// Evaluates a slice of evaluations with the given point `r`.
84    pub fn evaluate_with(
85        evals: &[FieldElement<F>],
86        r: &[FieldElement<F>],
87    ) -> Result<FieldElement<F>, MultilinearError> {
88        let mut chis: Vec<FieldElement<F>> =
89            vec![FieldElement::one(); (2usize).pow(r.len() as u32)];
90        if chis.len() != evals.len() {
91            return Err(MultilinearError::ChisAndEvalsLengthMismatch(
92                chis.len(),
93                evals.len(),
94            ));
95        }
96        let mut size = 1;
97        for j in r {
98            size *= 2;
99            for i in (0..size).rev().step_by(2) {
100                let half_i = i / 2;
101                let temp = &chis[half_i] * j;
102                chis[i] = temp;
103                chis[i - 1] = &chis[half_i] - &chis[i];
104            }
105        }
106        Ok((0..evals.len()).map(|i| &evals[i] * &chis[i]).sum())
107    }
108
109    /// Fixes the first variable to the given value `r` and returns a new DenseMultilinearPolynomial
110    /// with one fewer variable.
111    ///
112    /// Combines each pair of evaluations as: new_eval = a + r * (b - a)
113    ///  This reduces the polynomial by one variable, allowing it to later be collapsed
114    /// into a univariate polynomial by summing over the remaining variables.
115    ///
116    /// Example (2 variables): evaluations are ordered as:
117    ///     [f(0,0), f(0,1), f(1,0), f(1,1)]
118    /// Fixing the first variable `x = r` produces evaluations of a 1-variable polynomial:
119    ///     [f(r,0), f(r,1)]
120    /// computed explicitly as:
121    ///     f(r,0) = f(0,0) + r * ( f(1,0) - f(0,0)),
122    ///     f(r,1) = f(0,1) + r * (f(1,1) - f(0,1))
123    pub fn fix_first_variable(&self, r: &FieldElement<F>) -> DenseMultilinearPolynomial<F> {
124        let n = self.num_vars();
125        assert!(n > 0, "Cannot fix variable in a 0-variable polynomial");
126        let half = 1 << (n - 1);
127        let new_evals: Vec<FieldElement<F>> = (0..half)
128            .map(|j| {
129                let a = &self.evals[j];
130                let b = &self.evals[j + half];
131                a + r * (b - a)
132            })
133            .collect();
134        DenseMultilinearPolynomial::from((n - 1, new_evals))
135    }
136
137    /// Returns the evaluations of the polynomial on the Boolean hypercube \(\{0,1\}^n\).
138    /// Since we are in Lagrange basis, this is just the elements stored in self.evals.
139    pub fn to_evaluations(&self) -> Vec<FieldElement<F>> {
140        self.evals.clone()
141    }
142
143    /// Collapses the last variable by fixing it to 0 and 1,
144    /// sums the evaluations, and returns a univariate polynomial (as a Polynomial)
145    /// of the form: sum0 + (sum1 - sum0) * x.
146    pub fn to_univariate(&self) -> Polynomial<FieldElement<F>> {
147        let poly0 = self.fix_first_variable(&FieldElement::zero());
148        let poly1 = self.fix_first_variable(&FieldElement::one());
149        let sum0: FieldElement<F> = poly0.to_evaluations().into_iter().sum();
150        let sum1: FieldElement<F> = poly1.to_evaluations().into_iter().sum();
151        let diff = sum1 - &sum0;
152        Polynomial::new(&[sum0, diff])
153    }
154
155    /// Multiplies the polynomial by a scalar.
156    pub fn scalar_mul(&self, scalar: &FieldElement<F>) -> Self {
157        let mut new_poly = self.clone();
158        new_poly.evals.iter_mut().for_each(|eval| *eval *= scalar);
159        new_poly
160    }
161
162    /// Extends this DenseMultilinearPolynomial by concatenating another polynomial of the same length.
163    pub fn extend(&mut self, other: &DenseMultilinearPolynomial<F>) {
164        debug_assert_eq!(self.evals.len(), self.len);
165        debug_assert_eq!(other.evals.len(), self.len);
166        self.evals.extend(other.evals.iter().cloned());
167        self.n_vars += 1;
168        self.len *= 2;
169        debug_assert_eq!(self.evals.len(), self.len);
170    }
171
172    /// Merges a series of DenseMultilinearPolynomials into one polynomial.
173    /// Zero-pads the final merged polynomial to the next power-of-two length if necessary.
174    pub fn merge(polys: &[DenseMultilinearPolynomial<F>]) -> DenseMultilinearPolynomial<F> {
175        // TODO (performance): pre-allocate vector to avoid repeated resizing.
176        let mut z: Vec<FieldElement<F>> = Vec::new();
177        for poly in polys {
178            z.extend(poly.evals.iter().cloned());
179        }
180        z.resize(z.len().next_power_of_two(), FieldElement::zero());
181        DenseMultilinearPolynomial::new(z)
182    }
183
184    /// Constructs a DenseMultilinearPolynomial from a slice of u64 values.
185    pub fn from_u64(evals: &[u64]) -> Self {
186        DenseMultilinearPolynomial::new(evals.iter().map(|&i| FieldElement::from(i)).collect())
187    }
188}
189
190impl<F: IsField> Index<usize> for DenseMultilinearPolynomial<F>
191where
192    <F as IsField>::BaseType: Send + Sync,
193{
194    type Output = FieldElement<F>;
195
196    #[inline(always)]
197    fn index(&self, index: usize) -> &FieldElement<F> {
198        &self.evals[index]
199    }
200}
201
202/// Adds two DenseMultilinearPolynomials.
203/// Assumes that both polynomials have the same number of variables.
204impl<F: IsField> Add for DenseMultilinearPolynomial<F>
205where
206    <F as IsField>::BaseType: Send + Sync,
207{
208    type Output = Result<Self, &'static str>;
209
210    fn add(self, other: Self) -> Self::Output {
211        if self.num_vars() != other.num_vars() {
212            return Err("Polynomials must have the same number of variables");
213        }
214        #[cfg(feature = "parallel")]
215        let evals = self.evals.into_par_iter().zip(other.evals.into_par_iter());
216        #[cfg(not(feature = "parallel"))]
217        let evals = self.evals.iter().zip(other.evals.iter());
218        let sum: Vec<FieldElement<F>> = evals.map(|(a, b)| a + b).collect();
219        Ok(DenseMultilinearPolynomial::new(sum))
220    }
221}
222
223impl<F: IsField> Mul<FieldElement<F>> for DenseMultilinearPolynomial<F>
224where
225    <F as IsField>::BaseType: Send + Sync,
226{
227    type Output = DenseMultilinearPolynomial<F>;
228
229    fn mul(self, rhs: FieldElement<F>) -> Self::Output {
230        Self::scalar_mul(&self, &rhs)
231    }
232}
233
234impl<F: IsField> Mul<&FieldElement<F>> for DenseMultilinearPolynomial<F>
235where
236    <F as IsField>::BaseType: Send + Sync,
237{
238    type Output = DenseMultilinearPolynomial<F>;
239
240    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
241        Self::scalar_mul(&self, rhs)
242    }
243}
244
/// Helper function to calculate `ceil(log₂(n))`, with `log_2(0) == 0`.
///
/// For powers of two this is the exact exponent (`log_2(8) == 3`); other values round up
/// (`log_2(5) == 3`).
fn log_2(n: usize) -> usize {
    if n <= 1 {
        return 0;
    }
    // For n >= 2, ceil(log2(n)) is the bit length of n - 1.
    (usize::BITS - (n - 1).leading_zeros()) as usize
}
256
257impl<F: IsField> From<(usize, Vec<FieldElement<F>>)> for DenseMultilinearPolynomial<F>
258where
259    <F as IsField>::BaseType: Send + Sync,
260{
261    fn from((num_vars, evaluations): (usize, Vec<FieldElement<F>>)) -> Self {
262        assert_eq!(
263            evaluations.len(),
264            1 << num_vars,
265            "The size of evaluations should be 2^num_vars."
266        );
267        DenseMultilinearPolynomial {
268            n_vars: num_vars,
269            evals: evaluations,
270            len: 1 << num_vars,
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::field::fields::u64_prime_field::U64PrimeField;
    const ORDER: u64 = 101;
    type F = U64PrimeField<ORDER>;
    type FE = FieldElement<F>;

    /// Reference computation of the Lagrange basis values over `{0,1}^r.len()`, indexed so that
    /// `r[0]` is bound to the most significant bit of the index.
    pub fn evals(r: Vec<FE>) -> Vec<FE> {
        let mut evals: Vec<FE> = vec![FE::one(); (2usize).pow(r.len() as u32)];
        let mut size = 1;
        for j in r {
            size *= 2;
            for i in (0..size).rev().step_by(2) {
                let scalar = evals[i / 2];
                evals[i] = scalar * j;
                evals[i - 1] = scalar - evals[i];
            }
        }
        evals
    }

    /// Splits `r` into its first `r.len() / 2` variables and the rest, returning the basis
    /// values of each half.
    pub fn compute_factored_evals(r: Vec<FE>) -> (Vec<FE>, Vec<FE>) {
        let left_num_vars = r.len() / 2;
        let l = evals(r[..left_num_vars].to_vec());
        let r = evals(r[left_num_vars..].to_vec());
        (l, r)
    }

    /// Evaluates the multilinear extension of `z` at `r` as `L^T · Z · R`.
    ///
    /// `Z` is `z` viewed as a row-major `l.len() x r.len()` matrix: rows are indexed by the
    /// first variables, which are the high bits of the index.
    fn evaluate_with_lr(z: &[FE], r: &[FE]) -> FE {
        let (l, r) = compute_factored_evals(r.to_vec());
        // Derive the matrix shape from the factored basis vectors themselves, so the helper
        // works for any (including odd) number of variables.
        let (rows, cols) = (l.len(), r.len());
        assert_eq!(z.len(), rows * cols);
        // Compute vector-matrix product between L and Z (viewed as a matrix).
        let lz = (0..cols)
            .map(|i| {
                (0..rows).fold(FE::zero(), |mut acc, j| {
                    acc += l[j] * z[j * cols + i];
                    acc
                })
            })
            .collect::<Vec<FE>>();
        // Compute dot product between LZ and R.
        (0..lz.len()).map(|i| lz[i] * r[i]).sum()
    }

    #[test]
    fn evaluation() {
        // Example: Z = [1, 2, 1, 4]
        let z = vec![FE::one(), FE::from(2u64), FE::one(), FE::from(4u64)];
        // r = [4, 3]
        let r = vec![FE::from(4u64), FE::from(3u64)];
        let eval_with_lr = evaluate_with_lr(&z, &r);
        let poly = DenseMultilinearPolynomial::new(z);
        let eval = poly.evaluate(r).unwrap();
        assert_eq!(eval, FE::from(28u64));
        assert_eq!(eval_with_lr, eval);
    }

    #[test]
    fn evaluation_matches_factored_evaluation() {
        // Covers odd variable counts and shapes where the factored matrix is not square.
        for num_vars in 1..=6u64 {
            let z: Vec<FE> = (0..(1u64 << num_vars))
                .map(|i| FE::from(7 * i + 3))
                .collect();
            let r: Vec<FE> = (1..=num_vars).map(|i| FE::from(5 * i + 2)).collect();
            let poly = DenseMultilinearPolynomial::new(z.clone());
            assert_eq!(poly.evaluate(r.clone()).unwrap(), evaluate_with_lr(&z, &r));
        }
    }

    #[test]
    fn evaluate_with() {
        let two = FE::from(2);
        let z = vec![
            FE::zero(),
            FE::zero(),
            FE::zero(),
            FE::one(),
            FE::one(),
            FE::one(),
            FE::zero(),
            two,
        ];
        let x = vec![FE::one(), FE::one(), FE::one()];
        let y = DenseMultilinearPolynomial::<F>::evaluate_with(z.as_slice(), x.as_slice()).unwrap();
        assert_eq!(y, two);
    }

    #[test]
    fn fix_first_variable() {
        // f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4. Fixing x = 2 gives
        // f(2,0) = 1 + 2*(3-1) = 5 and f(2,1) = 2 + 2*(4-2) = 6.
        let poly = DenseMultilinearPolynomial::<F>::from_u64(&[1, 2, 3, 4]);
        let fixed = poly.fix_first_variable(&FE::from(2u64));
        assert_eq!(fixed.num_vars(), 1);
        assert_eq!(*fixed.evals(), vec![FE::from(5u64), FE::from(6u64)]);
    }

    #[test]
    fn to_univariate() {
        // First variable fixed to 0: 1 + 2 = 3. Fixed to 1: 3 + 4 = 7. Result: 3 + 4x.
        let poly = DenseMultilinearPolynomial::<F>::from_u64(&[1, 2, 3, 4]);
        let expected = Polynomial::new(&[FE::from(3u64), FE::from(4u64)]);
        assert_eq!(poly.to_univariate(), expected);
    }

    #[test]
    fn add() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(7); 4]);
        let c = a.add(b).unwrap();
        assert_eq!(*c.evals(), vec![FE::from(10); 4]);
    }

    #[test]
    fn mul() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = a.mul(&FE::from(2));
        assert_eq!(*b.evals(), vec![FE::from(6); 4]);
    }

    // Take a multilinear polynomial of length 2^2 and merge with a polynomial of 2^1.
    // The resulting polynomial should be padded to length 2^3 = 8 and the last two evaluations should be FE::zero().
    #[test]
    fn merge() {
        let a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        let c = DenseMultilinearPolynomial::merge(&[a, b]);
        assert_eq!(c.len(), 8);
        assert_eq!(c[c.len() - 1], FE::zero());
        assert_eq!(c[c.len() - 2], FE::zero());
    }

    #[test]
    fn extend() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        a.extend(&b);
        assert_eq!(a.len(), 8);
        assert_eq!(a.num_vars(), 3);
    }

    #[test]
    #[should_panic]
    fn extend_unequal() {
        let mut a = DenseMultilinearPolynomial::new(vec![FE::from(3); 4]);
        let b = DenseMultilinearPolynomial::new(vec![FE::from(3); 2]);
        a.extend(&b);
    }
}