proof_cat/poly/multilinear.rs
1//! Multilinear polynomials over the Boolean hypercube.
2//!
3//! A [`MultilinearPoly<F>`] is uniquely determined by its `2^n`
4//! evaluations on `{0,1}^n`. It supports efficient evaluation at
5//! arbitrary points via iterated partial evaluation (folding).
6//!
7//! This is the core data structure for the sumcheck protocol:
8//! each round binds one variable, halving the evaluation table.
9
10use crate::error::Error;
11use plonkish_cat::Field;
12
/// How many variables a multilinear polynomial depends on (its arity `n`).
///
/// A polynomial with `n` variables has an evaluation table of `2^n` entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NumVars(usize);

impl NumVars {
    /// Wrap a raw variable count `n`.
    #[must_use]
    pub fn new(n: usize) -> Self {
        NumVars(n)
    }

    /// Unwrap to the raw variable count.
    #[must_use]
    pub fn count(self) -> usize {
        let NumVars(n) = self;
        n
    }
}

impl core::fmt::Display for NumVars {
    /// Prints the bare count. Outer formatter flags (width, fill) are
    /// not forwarded, because the value is written through a fresh `write!`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let n = self.count();
        write!(f, "{}", n)
    }
}
35}
36
37/// A multilinear polynomial represented by its evaluation table.
38///
39/// The table has `2^num_vars` entries in big-endian bit order:
40/// the first variable is the most significant bit. For a
41/// 2-variable polynomial `f(x0, x1)`:
42///
43/// | Index | x0 | x1 | Value |
44/// |-------|----|----|-------|
45/// | 0 | 0 | 0 | `evals[0]` |
46/// | 1 | 0 | 1 | `evals[1]` |
47/// | 2 | 1 | 0 | `evals[2]` |
48/// | 3 | 1 | 1 | `evals[3]` |
49///
50/// # Examples
51///
52/// ```
53/// use plonkish_cat::F101;
54/// use proof_cat::MultilinearPoly;
55///
56/// // f(x0, x1): f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4
57/// let poly = MultilinearPoly::from_evals(vec![
58/// F101::new(1), F101::new(2), F101::new(3), F101::new(4),
59/// ])?;
60///
61/// assert_eq!(poly.num_vars().count(), 2);
62///
63/// // Sum over the Boolean hypercube: 1 + 2 + 3 + 4 = 10
64/// assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
65///
66/// // Evaluate at a Boolean point: f(1, 0) = 3
67/// let val = poly.evaluate(&[F101::new(1), F101::new(0)])?;
68/// assert_eq!(val, F101::new(3));
69/// # Ok::<(), proof_cat::Error>(())
70/// ```
71#[derive(Debug, Clone)]
72pub struct MultilinearPoly<F: Field> {
73 evals: Vec<F>,
74 num_vars: NumVars,
75}
76
77impl<F: Field> MultilinearPoly<F> {
78 /// Construct from an evaluation table.
79 ///
80 /// The table length must be a power of two (including 1).
81 ///
82 /// # Errors
83 ///
84 /// Returns [`Error::NotPowerOfTwo`] if `evals.len()` is not
85 /// a power of two or is zero.
86 ///
87 /// # Examples
88 ///
89 /// ```
90 /// use plonkish_cat::F101;
91 /// use proof_cat::MultilinearPoly;
92 ///
93 /// // A 1-variable polynomial: f(0) = 3, f(1) = 7.
94 /// let poly = MultilinearPoly::from_evals(vec![
95 /// F101::new(3), F101::new(7),
96 /// ])?;
97 /// assert_eq!(poly.num_vars().count(), 1);
98 ///
99 /// // Non-power-of-two lengths are rejected.
100 /// let err = MultilinearPoly::<F101>::from_evals(vec![
101 /// F101::new(1), F101::new(2), F101::new(3),
102 /// ]);
103 /// assert!(err.is_err());
104 /// # Ok::<(), proof_cat::Error>(())
105 /// ```
106 pub fn from_evals(evals: Vec<F>) -> Result<Self, Error> {
107 let len = evals.len();
108 if len.is_power_of_two() {
109 // log2 of a power of two: count trailing zeros.
110 let num_vars = NumVars::new(
111 usize::try_from(len.trailing_zeros())
112 .map_err(|_| Error::NotPowerOfTwo { value: len })?,
113 );
114 Ok(Self { evals, num_vars })
115 } else {
116 Err(Error::NotPowerOfTwo { value: len })
117 }
118 }
119
120 /// The number of variables.
121 #[must_use]
122 pub fn num_vars(&self) -> NumVars {
123 self.num_vars
124 }
125
126 /// The evaluation table.
127 #[must_use]
128 pub fn evals(&self) -> &[F] {
129 &self.evals
130 }
131
132 /// Sum of all evaluations on the Boolean hypercube.
133 ///
134 /// This equals `sum_{x in {0,1}^n} f(x)`.
135 #[must_use]
136 pub fn sum_over_boolean_hypercube(&self) -> F {
137 self.evals
138 .iter()
139 .cloned()
140 .fold(F::zero(), |acc, v| acc + v)
141 }
142
143 /// Evaluate the multilinear extension at an arbitrary point.
144 ///
145 /// Uses iterated partial evaluation: for each variable `i`,
146 /// the table is split in half and each pair `(lo, hi)` is
147 /// interpolated as `lo * (1 - r_i) + hi * r_i`. After `n`
148 /// rounds a single value remains.
149 ///
150 /// # Errors
151 ///
152 /// Returns [`Error::DimensionMismatch`] if `point.len() != num_vars`.
153 ///
154 /// # Examples
155 ///
156 /// ```
157 /// use plonkish_cat::F101;
158 /// use proof_cat::MultilinearPoly;
159 ///
160 /// // f(x) = 3*(1-x) + 7*x = 3 + 4x
161 /// let poly = MultilinearPoly::from_evals(vec![
162 /// F101::new(3), F101::new(7),
163 /// ])?;
164 ///
165 /// // f(0) = 3, f(1) = 7, f(2) = 11
166 /// assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
167 /// assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
168 /// assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
169 /// # Ok::<(), proof_cat::Error>(())
170 /// ```
171 pub fn evaluate(&self, point: &[F]) -> Result<F, Error> {
172 if point.len() == self.num_vars.0 {
173 let final_table = point.iter().fold(self.evals.clone(), |table, r_i| {
174 let half = table.len() / 2;
175 (0..half)
176 .map(|j| {
177 let lo = table[j].clone();
178 let hi = table[j + half].clone();
179 // lo * (1 - r_i) + hi * r_i
180 lo * (F::one() - r_i.clone()) + hi * r_i.clone()
181 })
182 .collect()
183 });
184 // After num_vars folds, exactly one element remains.
185 final_table
186 .into_iter()
187 .next()
188 .ok_or(Error::DimensionMismatch {
189 expected: self.num_vars.0,
190 actual: point.len(),
191 })
192 } else {
193 Err(Error::DimensionMismatch {
194 expected: self.num_vars.0,
195 actual: point.len(),
196 })
197 }
198 }
199
200 /// Bind the first variable to `r`, producing a polynomial
201 /// with one fewer variable.
202 ///
203 /// The resulting table has `2^(n-1)` entries where each
204 /// entry `j` is `evals[2j] * (1 - r) + evals[2j+1] * r`.
205 ///
206 /// # Errors
207 ///
208 /// Returns [`Error::DimensionMismatch`] if the polynomial
209 /// has zero variables.
210 pub fn bind_first_var(&self, r: &F) -> Result<Self, Error> {
211 if self.num_vars.0 > 0 {
212 let half = self.evals.len() / 2;
213 let new_evals: Vec<F> = (0..half)
214 .map(|j| {
215 let lo = self.evals[j].clone();
216 let hi = self.evals[j + half].clone();
217 lo * (F::one() - r.clone()) + hi * r.clone()
218 })
219 .collect();
220 Self::from_evals(new_evals)
221 } else {
222 Err(Error::DimensionMismatch {
223 expected: 1,
224 actual: 0,
225 })
226 }
227 }
228}
229
230#[cfg(test)]
231mod tests {
232 use super::*;
233 use plonkish_cat::F101;
234
235 #[test]
236 fn from_evals_requires_power_of_two() {
237 let result = MultilinearPoly::<F101>::from_evals(vec![
238 F101::new(1),
239 F101::new(2),
240 F101::new(3),
241 ]);
242 assert!(result.is_err());
243 }
244
245 #[test]
246 fn from_evals_empty_fails() {
247 let result = MultilinearPoly::<F101>::from_evals(vec![]);
248 assert!(result.is_err());
249 }
250
251 #[test]
252 fn single_element_poly() -> Result<(), Error> {
253 // 0 variables, one evaluation.
254 let poly = MultilinearPoly::from_evals(vec![F101::new(42)])?;
255 assert_eq!(poly.num_vars().count(), 0);
256 assert_eq!(poly.evaluate(&[])?, F101::new(42));
257 Ok(())
258 }
259
260 #[test]
261 fn one_var_evaluation_at_boolean_points() -> Result<(), Error> {
262 // f(0) = 3, f(1) = 7
263 let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
264 assert_eq!(poly.num_vars().count(), 1);
265 assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
266 assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
267 Ok(())
268 }
269
270 #[test]
271 fn one_var_evaluation_at_midpoint() -> Result<(), Error> {
272 // f(0) = 3, f(1) = 7
273 // f(r) = 3*(1-r) + 7*r = 3 + 4r
274 // f(2) = 3 + 8 = 11
275 let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
276 assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
277 Ok(())
278 }
279
280 #[test]
281 fn two_var_evaluation() -> Result<(), Error> {
282 // f(x0, x1) with evaluations:
283 // f(0,0) = 1, f(0,1) = 2, f(1,0) = 3, f(1,1) = 4
284 let poly = MultilinearPoly::from_evals(vec![
285 F101::new(1),
286 F101::new(2),
287 F101::new(3),
288 F101::new(4),
289 ])?;
290 assert_eq!(poly.num_vars().count(), 2);
291 // Boolean point checks:
292 assert_eq!(poly.evaluate(&[F101::new(0), F101::new(0)])?, F101::new(1));
293 assert_eq!(poly.evaluate(&[F101::new(0), F101::new(1)])?, F101::new(2));
294 assert_eq!(poly.evaluate(&[F101::new(1), F101::new(0)])?, F101::new(3));
295 assert_eq!(poly.evaluate(&[F101::new(1), F101::new(1)])?, F101::new(4));
296 Ok(())
297 }
298
299 #[test]
300 fn sum_over_hypercube() -> Result<(), Error> {
301 let poly = MultilinearPoly::from_evals(vec![
302 F101::new(1),
303 F101::new(2),
304 F101::new(3),
305 F101::new(4),
306 ])?;
307 // 1 + 2 + 3 + 4 = 10
308 assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
309 Ok(())
310 }
311
312 #[test]
313 fn bind_first_var() -> Result<(), Error> {
314 // f(x0, x1): f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4
315 let poly = MultilinearPoly::from_evals(vec![
316 F101::new(1),
317 F101::new(2),
318 F101::new(3),
319 F101::new(4),
320 ])?;
321 // Bind x0 = 0: get f(0, x1) = [1, 2]
322 let bound_zero = poly.bind_first_var(&F101::new(0))?;
323 assert_eq!(bound_zero.num_vars().count(), 1);
324 assert_eq!(bound_zero.evals(), &[F101::new(1), F101::new(2)]);
325
326 // Bind x0 = 1: get f(1, x1) = [3, 4]
327 let bound_one = poly.bind_first_var(&F101::new(1))?;
328 assert_eq!(bound_one.evals(), &[F101::new(3), F101::new(4)]);
329 Ok(())
330 }
331
332 #[test]
333 fn dimension_mismatch_error() {
334 let poly =
335 MultilinearPoly::from_evals(vec![F101::new(1), F101::new(2)]).unwrap_or_else(|_| {
336 MultilinearPoly::from_evals(vec![F101::new(0)]).unwrap_or_else(|_| unreachable!())
337 });
338 // Wrong number of evaluation coordinates.
339 let result = poly.evaluate(&[F101::new(0), F101::new(0)]);
340 assert!(result.is_err());
341 }
342}