Skip to main content

proof_cat/poly/
multilinear.rs

//! Multilinear polynomials over the Boolean hypercube.
//!
//! An `n`-variable [`MultilinearPoly<F>`] is uniquely determined by its
//! `2^n` evaluations on `{0,1}^n`.  It supports efficient evaluation at
//! arbitrary points via iterated partial evaluation (folding).
//!
//! This is the core data structure for the sumcheck protocol:
//! each round binds one variable, halving the evaluation table.
9
10use field_cat::Field;
11
12use crate::error::Error;
13
/// The number of variables in a multilinear polynomial.
///
/// A newtype over `usize`; use [`NumVars::count`] to read the raw value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NumVars(usize);

impl NumVars {
    /// Wrap a raw variable count.
    #[must_use]
    pub fn new(n: usize) -> Self {
        Self(n)
    }

    /// The underlying count.
    #[must_use]
    pub fn count(self) -> usize {
        self.0
    }
}

impl core::fmt::Display for NumVars {
    /// Formats the count exactly like the underlying `usize`.
    ///
    /// Delegating to `usize`'s own `Display` impl (instead of
    /// `write!(f, "{}", ..)`) forwards the caller's width, fill,
    /// alignment and zero-padding flags, so `format!("{:>3}", n)` pads.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(&self.0, f)
    }
}
37
38/// A multilinear polynomial represented by its evaluation table.
39///
40/// The table has `2^num_vars` entries in big-endian bit order:
41/// the first variable is the most significant bit.  For a
42/// 2-variable polynomial `f(x0, x1)`:
43///
44/// | Index | x0 | x1 | Value |
45/// |-------|----|----|-------|
46/// | 0     | 0  | 0  | `evals[0]` |
47/// | 1     | 0  | 1  | `evals[1]` |
48/// | 2     | 1  | 0  | `evals[2]` |
49/// | 3     | 1  | 1  | `evals[3]` |
50///
51/// # Examples
52///
53/// ```
54/// use field_cat::F101;
55/// use proof_cat::MultilinearPoly;
56///
57/// // f(x0, x1): f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4
58/// let poly = MultilinearPoly::from_evals(vec![
59///     F101::new(1), F101::new(2), F101::new(3), F101::new(4),
60/// ])?;
61///
62/// assert_eq!(poly.num_vars().count(), 2);
63///
64/// // Sum over the Boolean hypercube: 1 + 2 + 3 + 4 = 10
65/// assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
66///
67/// // Evaluate at a Boolean point: f(1, 0) = 3
68/// let val = poly.evaluate(&[F101::new(1), F101::new(0)])?;
69/// assert_eq!(val, F101::new(3));
70/// # Ok::<(), proof_cat::Error>(())
71/// ```
72#[derive(Debug, Clone)]
73pub struct MultilinearPoly<F: Field> {
74    evals: Vec<F>,
75    num_vars: NumVars,
76}
77
78impl<F: Field> MultilinearPoly<F> {
79    /// Construct from an evaluation table.
80    ///
81    /// The table length must be a power of two (including 1).
82    ///
83    /// # Errors
84    ///
85    /// Returns [`Error::NotPowerOfTwo`] if `evals.len()` is not
86    /// a power of two or is zero.
87    ///
88    /// # Examples
89    ///
90    /// ```
91    /// use field_cat::F101;
92    /// use proof_cat::MultilinearPoly;
93    ///
94    /// // A 1-variable polynomial: f(0) = 3, f(1) = 7.
95    /// let poly = MultilinearPoly::from_evals(vec![
96    ///     F101::new(3), F101::new(7),
97    /// ])?;
98    /// assert_eq!(poly.num_vars().count(), 1);
99    ///
100    /// // Non-power-of-two lengths are rejected.
101    /// let err = MultilinearPoly::<F101>::from_evals(vec![
102    ///     F101::new(1), F101::new(2), F101::new(3),
103    /// ]);
104    /// assert!(err.is_err());
105    /// # Ok::<(), proof_cat::Error>(())
106    /// ```
107    pub fn from_evals(evals: Vec<F>) -> Result<Self, Error> {
108        let len = evals.len();
109        if len.is_power_of_two() {
110            // log2 of a power of two: count trailing zeros.
111            let num_vars = NumVars::new(
112                usize::try_from(len.trailing_zeros())
113                    .map_err(|_| Error::NotPowerOfTwo { value: len })?,
114            );
115            Ok(Self { evals, num_vars })
116        } else {
117            Err(Error::NotPowerOfTwo { value: len })
118        }
119    }
120
121    /// The number of variables.
122    #[must_use]
123    pub fn num_vars(&self) -> NumVars {
124        self.num_vars
125    }
126
127    /// The evaluation table.
128    #[must_use]
129    pub fn evals(&self) -> &[F] {
130        &self.evals
131    }
132
133    /// Sum of all evaluations on the Boolean hypercube.
134    ///
135    /// This equals `sum_{x in {0,1}^n} f(x)`.
136    #[must_use]
137    pub fn sum_over_boolean_hypercube(&self) -> F {
138        self.evals.iter().cloned().fold(F::zero(), |acc, v| acc + v)
139    }
140
141    /// Evaluate the multilinear extension at an arbitrary point.
142    ///
143    /// Uses iterated partial evaluation: for each variable `i`,
144    /// the table is split in half and each pair `(lo, hi)` is
145    /// interpolated as `lo * (1 - r_i) + hi * r_i`.  After `n`
146    /// rounds a single value remains.
147    ///
148    /// # Errors
149    ///
150    /// Returns [`Error::DimensionMismatch`] if `point.len() != num_vars`.
151    ///
152    /// # Examples
153    ///
154    /// ```
155    /// use field_cat::F101;
156    /// use proof_cat::MultilinearPoly;
157    ///
158    /// // f(x) = 3*(1-x) + 7*x = 3 + 4x
159    /// let poly = MultilinearPoly::from_evals(vec![
160    ///     F101::new(3), F101::new(7),
161    /// ])?;
162    ///
163    /// // f(0) = 3, f(1) = 7, f(2) = 11
164    /// assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
165    /// assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
166    /// assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
167    /// # Ok::<(), proof_cat::Error>(())
168    /// ```
169    pub fn evaluate(&self, point: &[F]) -> Result<F, Error> {
170        if point.len() == self.num_vars.0 {
171            let final_table = point.iter().fold(self.evals.clone(), |table, r_i| {
172                let half = table.len() / 2;
173                (0..half)
174                    .map(|j| {
175                        let lo = table[j].clone();
176                        let hi = table[j + half].clone();
177                        // lo * (1 - r_i) + hi * r_i
178                        lo * (F::one() - r_i.clone()) + hi * r_i.clone()
179                    })
180                    .collect()
181            });
182            // After num_vars folds, exactly one element remains.
183            final_table
184                .into_iter()
185                .next()
186                .ok_or(Error::DimensionMismatch {
187                    expected: self.num_vars.0,
188                    actual: point.len(),
189                })
190        } else {
191            Err(Error::DimensionMismatch {
192                expected: self.num_vars.0,
193                actual: point.len(),
194            })
195        }
196    }
197
198    /// Bind the first variable to `r`, producing a polynomial
199    /// with one fewer variable.
200    ///
201    /// The resulting table has `2^(n-1)` entries where each
202    /// entry `j` is `evals[2j] * (1 - r) + evals[2j+1] * r`.
203    ///
204    /// # Errors
205    ///
206    /// Returns [`Error::DimensionMismatch`] if the polynomial
207    /// has zero variables.
208    pub fn bind_first_var(&self, r: &F) -> Result<Self, Error> {
209        if self.num_vars.0 > 0 {
210            let half = self.evals.len() / 2;
211            let new_evals: Vec<F> = (0..half)
212                .map(|j| {
213                    let lo = self.evals[j].clone();
214                    let hi = self.evals[j + half].clone();
215                    lo * (F::one() - r.clone()) + hi * r.clone()
216                })
217                .collect();
218            Self::from_evals(new_evals)
219        } else {
220            Err(Error::DimensionMismatch {
221                expected: 1,
222                actual: 0,
223            })
224        }
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// Fixture `f(x0, x1)` with table `[1, 2, 3, 4]`, i.e. `f = 1 + 2*x0 + x1`.
    fn two_var_poly() -> Result<MultilinearPoly<F101>, Error> {
        MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])
    }

    #[test]
    fn from_evals_requires_power_of_two() {
        let result =
            MultilinearPoly::<F101>::from_evals(vec![F101::new(1), F101::new(2), F101::new(3)]);
        assert!(matches!(result, Err(Error::NotPowerOfTwo { value: 3 })));
    }

    #[test]
    fn from_evals_empty_fails() {
        let result = MultilinearPoly::<F101>::from_evals(vec![]);
        assert!(matches!(result, Err(Error::NotPowerOfTwo { value: 0 })));
    }

    #[test]
    fn single_element_poly() -> Result<(), Error> {
        // 0 variables, one evaluation.
        let poly = MultilinearPoly::from_evals(vec![F101::new(42)])?;
        assert_eq!(poly.num_vars().count(), 0);
        assert_eq!(poly.evaluate(&[])?, F101::new(42));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_boolean_points() -> Result<(), Error> {
        // f(0) = 3, f(1) = 7
        let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
        assert_eq!(poly.num_vars().count(), 1);
        assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_midpoint() -> Result<(), Error> {
        // f(0) = 3, f(1) = 7
        // f(r) = 3*(1-r) + 7*r = 3 + 4r
        // f(2) = 3 + 8 = 11
        let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
        assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
        Ok(())
    }

    #[test]
    fn two_var_evaluation() -> Result<(), Error> {
        // f(0,0) = 1, f(0,1) = 2, f(1,0) = 3, f(1,1) = 4
        let poly = two_var_poly()?;
        assert_eq!(poly.num_vars().count(), 2);
        // Boolean point checks:
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(0)])?, F101::new(1));
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(1)])?, F101::new(2));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(1)])?, F101::new(4));
        Ok(())
    }

    #[test]
    fn two_var_evaluation_off_hypercube() -> Result<(), Error> {
        // f = 1 + 2*x0 + x1, so f(2, 3) = 1 + 4 + 3 = 8.
        let poly = two_var_poly()?;
        assert_eq!(poly.evaluate(&[F101::new(2), F101::new(3)])?, F101::new(8));
        Ok(())
    }

    #[test]
    fn sum_over_hypercube() -> Result<(), Error> {
        let poly = two_var_poly()?;
        // 1 + 2 + 3 + 4 = 10
        assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
        Ok(())
    }

    #[test]
    fn bind_first_var() -> Result<(), Error> {
        // f(x0, x1): f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4
        let poly = two_var_poly()?;
        // Bind x0 = 0: get f(0, x1) = [1, 2]
        let bound_zero = poly.bind_first_var(&F101::new(0))?;
        assert_eq!(bound_zero.num_vars().count(), 1);
        assert_eq!(bound_zero.evals(), &[F101::new(1), F101::new(2)]);

        // Bind x0 = 1: get f(1, x1) = [3, 4]
        let bound_one = poly.bind_first_var(&F101::new(1))?;
        assert_eq!(bound_one.evals(), &[F101::new(3), F101::new(4)]);
        Ok(())
    }

    #[test]
    fn bind_then_evaluate_matches_direct_evaluation() -> Result<(), Error> {
        // f = 1 + 2*x0 + x1, so f(7, 13) = 1 + 14 + 13 = 28.
        let poly = two_var_poly()?;
        let bound = poly.bind_first_var(&F101::new(7))?;
        let via_bind = bound.evaluate(&[F101::new(13)])?;
        let direct = poly.evaluate(&[F101::new(7), F101::new(13)])?;
        assert_eq!(via_bind, direct);
        assert_eq!(direct, F101::new(28));
        Ok(())
    }

    #[test]
    fn bind_first_var_on_constant_fails() -> Result<(), Error> {
        let constant = MultilinearPoly::from_evals(vec![F101::new(5)])?;
        assert!(matches!(
            constant.bind_first_var(&F101::new(3)),
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 0
            })
        ));
        Ok(())
    }

    #[test]
    fn dimension_mismatch_error() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(1), F101::new(2)])?;
        // Too many coordinates.
        assert!(matches!(
            poly.evaluate(&[F101::new(0), F101::new(0)]),
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 2
            })
        ));
        // Too few coordinates.
        assert!(matches!(
            poly.evaluate(&[]),
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 0
            })
        ));
        Ok(())
    }
}