Skip to main content

proof_cat/poly/
multilinear.rs

1//! Multilinear polynomials over the Boolean hypercube.
2//!
3//! A [`MultilinearPoly<F>`] is uniquely determined by its `2^n`
4//! evaluations on `{0,1}^n`.  It supports efficient evaluation at
5//! arbitrary points via iterated partial evaluation (folding).
6//!
7//! This is the core data structure for the sumcheck protocol:
8//! each round binds one variable, halving the evaluation table.
9
10use crate::error::Error;
11use plonkish_cat::Field;
12
/// Newtype over `usize` holding how many variables a multilinear
/// polynomial has.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NumVars(usize);
16
17impl NumVars {
18    /// Create a new variable count.
19    #[must_use]
20    pub fn new(n: usize) -> Self {
21        Self(n)
22    }
23
24    /// The underlying count.
25    #[must_use]
26    pub fn count(self) -> usize {
27        self.0
28    }
29}
30
31impl core::fmt::Display for NumVars {
32    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
33        write!(f, "{}", self.0)
34    }
35}
36
37/// A multilinear polynomial represented by its evaluation table.
38///
39/// The table has `2^num_vars` entries in big-endian bit order:
40/// the first variable is the most significant bit.  For a
41/// 2-variable polynomial `f(x0, x1)`:
42///
43/// | Index | x0 | x1 | Value |
44/// |-------|----|----|-------|
45/// | 0     | 0  | 0  | `evals[0]` |
46/// | 1     | 0  | 1  | `evals[1]` |
47/// | 2     | 1  | 0  | `evals[2]` |
48/// | 3     | 1  | 1  | `evals[3]` |
49///
50/// # Examples
51///
52/// ```
53/// use plonkish_cat::F101;
54/// use proof_cat::MultilinearPoly;
55///
56/// // f(x0, x1): f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4
57/// let poly = MultilinearPoly::from_evals(vec![
58///     F101::new(1), F101::new(2), F101::new(3), F101::new(4),
59/// ])?;
60///
61/// assert_eq!(poly.num_vars().count(), 2);
62///
63/// // Sum over the Boolean hypercube: 1 + 2 + 3 + 4 = 10
64/// assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
65///
66/// // Evaluate at a Boolean point: f(1, 0) = 3
67/// let val = poly.evaluate(&[F101::new(1), F101::new(0)])?;
68/// assert_eq!(val, F101::new(3));
69/// # Ok::<(), proof_cat::Error>(())
70/// ```
71#[derive(Debug, Clone)]
72pub struct MultilinearPoly<F: Field> {
73    evals: Vec<F>,
74    num_vars: NumVars,
75}
76
77impl<F: Field> MultilinearPoly<F> {
78    /// Construct from an evaluation table.
79    ///
80    /// The table length must be a power of two (including 1).
81    ///
82    /// # Errors
83    ///
84    /// Returns [`Error::NotPowerOfTwo`] if `evals.len()` is not
85    /// a power of two or is zero.
86    ///
87    /// # Examples
88    ///
89    /// ```
90    /// use plonkish_cat::F101;
91    /// use proof_cat::MultilinearPoly;
92    ///
93    /// // A 1-variable polynomial: f(0) = 3, f(1) = 7.
94    /// let poly = MultilinearPoly::from_evals(vec![
95    ///     F101::new(3), F101::new(7),
96    /// ])?;
97    /// assert_eq!(poly.num_vars().count(), 1);
98    ///
99    /// // Non-power-of-two lengths are rejected.
100    /// let err = MultilinearPoly::<F101>::from_evals(vec![
101    ///     F101::new(1), F101::new(2), F101::new(3),
102    /// ]);
103    /// assert!(err.is_err());
104    /// # Ok::<(), proof_cat::Error>(())
105    /// ```
106    pub fn from_evals(evals: Vec<F>) -> Result<Self, Error> {
107        let len = evals.len();
108        if len.is_power_of_two() {
109            // log2 of a power of two: count trailing zeros.
110            let num_vars = NumVars::new(
111                usize::try_from(len.trailing_zeros())
112                    .map_err(|_| Error::NotPowerOfTwo { value: len })?,
113            );
114            Ok(Self { evals, num_vars })
115        } else {
116            Err(Error::NotPowerOfTwo { value: len })
117        }
118    }
119
120    /// The number of variables.
121    #[must_use]
122    pub fn num_vars(&self) -> NumVars {
123        self.num_vars
124    }
125
126    /// The evaluation table.
127    #[must_use]
128    pub fn evals(&self) -> &[F] {
129        &self.evals
130    }
131
132    /// Sum of all evaluations on the Boolean hypercube.
133    ///
134    /// This equals `sum_{x in {0,1}^n} f(x)`.
135    #[must_use]
136    pub fn sum_over_boolean_hypercube(&self) -> F {
137        self.evals
138            .iter()
139            .cloned()
140            .fold(F::zero(), |acc, v| acc + v)
141    }
142
143    /// Evaluate the multilinear extension at an arbitrary point.
144    ///
145    /// Uses iterated partial evaluation: for each variable `i`,
146    /// the table is split in half and each pair `(lo, hi)` is
147    /// interpolated as `lo * (1 - r_i) + hi * r_i`.  After `n`
148    /// rounds a single value remains.
149    ///
150    /// # Errors
151    ///
152    /// Returns [`Error::DimensionMismatch`] if `point.len() != num_vars`.
153    ///
154    /// # Examples
155    ///
156    /// ```
157    /// use plonkish_cat::F101;
158    /// use proof_cat::MultilinearPoly;
159    ///
160    /// // f(x) = 3*(1-x) + 7*x = 3 + 4x
161    /// let poly = MultilinearPoly::from_evals(vec![
162    ///     F101::new(3), F101::new(7),
163    /// ])?;
164    ///
165    /// // f(0) = 3, f(1) = 7, f(2) = 11
166    /// assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
167    /// assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
168    /// assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
169    /// # Ok::<(), proof_cat::Error>(())
170    /// ```
171    pub fn evaluate(&self, point: &[F]) -> Result<F, Error> {
172        if point.len() == self.num_vars.0 {
173            let final_table = point.iter().fold(self.evals.clone(), |table, r_i| {
174                let half = table.len() / 2;
175                (0..half)
176                    .map(|j| {
177                        let lo = table[j].clone();
178                        let hi = table[j + half].clone();
179                        // lo * (1 - r_i) + hi * r_i
180                        lo * (F::one() - r_i.clone()) + hi * r_i.clone()
181                    })
182                    .collect()
183            });
184            // After num_vars folds, exactly one element remains.
185            final_table
186                .into_iter()
187                .next()
188                .ok_or(Error::DimensionMismatch {
189                    expected: self.num_vars.0,
190                    actual: point.len(),
191                })
192        } else {
193            Err(Error::DimensionMismatch {
194                expected: self.num_vars.0,
195                actual: point.len(),
196            })
197        }
198    }
199
200    /// Bind the first variable to `r`, producing a polynomial
201    /// with one fewer variable.
202    ///
203    /// The resulting table has `2^(n-1)` entries where each
204    /// entry `j` is `evals[2j] * (1 - r) + evals[2j+1] * r`.
205    ///
206    /// # Errors
207    ///
208    /// Returns [`Error::DimensionMismatch`] if the polynomial
209    /// has zero variables.
210    pub fn bind_first_var(&self, r: &F) -> Result<Self, Error> {
211        if self.num_vars.0 > 0 {
212            let half = self.evals.len() / 2;
213            let new_evals: Vec<F> = (0..half)
214                .map(|j| {
215                    let lo = self.evals[j].clone();
216                    let hi = self.evals[j + half].clone();
217                    lo * (F::one() - r.clone()) + hi * r.clone()
218                })
219                .collect();
220            Self::from_evals(new_evals)
221        } else {
222            Err(Error::DimensionMismatch {
223                expected: 1,
224                actual: 0,
225            })
226        }
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use plonkish_cat::F101;

    /// The running 2-variable example
    /// `f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4`, i.e. `f = 1 + 2*x0 + x1`.
    fn two_var_poly() -> Result<MultilinearPoly<F101>, Error> {
        MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])
    }

    /// The running 1-variable example `f(0)=3, f(1)=7`, i.e. `f = 3 + 4x`.
    fn one_var_poly() -> Result<MultilinearPoly<F101>, Error> {
        MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])
    }

    #[test]
    fn from_evals_requires_power_of_two() {
        let result = MultilinearPoly::<F101>::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
        ]);
        assert!(result.is_err());
    }

    #[test]
    fn from_evals_empty_fails() {
        let result = MultilinearPoly::<F101>::from_evals(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn single_element_poly() -> Result<(), Error> {
        // 0 variables, one evaluation.
        let poly = MultilinearPoly::from_evals(vec![F101::new(42)])?;
        assert_eq!(poly.num_vars().count(), 0);
        assert_eq!(poly.evaluate(&[])?, F101::new(42));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_boolean_points() -> Result<(), Error> {
        let poly = one_var_poly()?;
        assert_eq!(poly.num_vars().count(), 1);
        assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_non_boolean_point() -> Result<(), Error> {
        // f(r) = 3*(1-r) + 7*r = 3 + 4r, so f(2) = 3 + 8 = 11
        let poly = one_var_poly()?;
        assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
        Ok(())
    }

    #[test]
    fn two_var_evaluation() -> Result<(), Error> {
        let poly = two_var_poly()?;
        assert_eq!(poly.num_vars().count(), 2);
        // Boolean point checks:
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(0)])?, F101::new(1));
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(1)])?, F101::new(2));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(1)])?, F101::new(4));
        Ok(())
    }

    #[test]
    fn two_var_evaluation_off_hypercube() -> Result<(), Error> {
        // f = 1 + 2*x0 + x1, so f(2, 3) = 1 + 4 + 3 = 8
        let poly = two_var_poly()?;
        assert_eq!(poly.evaluate(&[F101::new(2), F101::new(3)])?, F101::new(8));
        Ok(())
    }

    #[test]
    fn sum_over_hypercube() -> Result<(), Error> {
        let poly = two_var_poly()?;
        // 1 + 2 + 3 + 4 = 10
        assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
        Ok(())
    }

    #[test]
    fn bind_first_var() -> Result<(), Error> {
        let poly = two_var_poly()?;
        // Bind x0 = 0: get f(0, x1) = [1, 2]
        let bound_zero = poly.bind_first_var(&F101::new(0))?;
        assert_eq!(bound_zero.num_vars().count(), 1);
        assert_eq!(bound_zero.evals(), &[F101::new(1), F101::new(2)]);

        // Bind x0 = 1: get f(1, x1) = [3, 4]
        let bound_one = poly.bind_first_var(&F101::new(1))?;
        assert_eq!(bound_one.evals(), &[F101::new(3), F101::new(4)]);
        Ok(())
    }

    #[test]
    fn bind_first_var_at_non_boolean_point() -> Result<(), Error> {
        // f(5, x1) = 1 + 10 + x1, so the bound table is [11, 12].
        // This pins the pairing of entry j with entry j + half.
        let bound = two_var_poly()?.bind_first_var(&F101::new(5))?;
        assert_eq!(bound.evals(), &[F101::new(11), F101::new(12)]);
        Ok(())
    }

    #[test]
    fn bind_first_var_agrees_with_evaluate() -> Result<(), Error> {
        let poly = two_var_poly()?;
        let direct = poly.evaluate(&[F101::new(5), F101::new(7)])?;
        let via_bind = poly
            .bind_first_var(&F101::new(5))?
            .evaluate(&[F101::new(7)])?;
        assert_eq!(direct, via_bind);
        Ok(())
    }

    #[test]
    fn bind_first_var_on_constant_fails() -> Result<(), Error> {
        // A zero-variable polynomial has no variable left to bind.
        let constant = MultilinearPoly::from_evals(vec![F101::new(42)])?;
        assert!(constant.bind_first_var(&F101::new(5)).is_err());
        Ok(())
    }

    #[test]
    fn dimension_mismatch_error() -> Result<(), Error> {
        let poly = one_var_poly()?;
        // Too many evaluation coordinates.
        assert!(poly.evaluate(&[F101::new(0), F101::new(0)]).is_err());
        // Too few evaluation coordinates.
        assert!(poly.evaluate(&[]).is_err());
        Ok(())
    }
}