proof_cat/poly/multilinear.rs
1//! Multilinear polynomials over the Boolean hypercube.
2//!
3//! A [`MultilinearPoly<F>`] is uniquely determined by its `2^n`
4//! evaluations on `{0,1}^n`. It supports efficient evaluation at
5//! arbitrary points via iterated partial evaluation (folding).
6//!
7//! This is the core data structure for the sumcheck protocol:
8//! each round binds one variable, halving the evaluation table.
9
10use field_cat::Field;
11
12use crate::error::Error;
13
/// The number of variables in a multilinear polynomial.
///
/// A polynomial with `n` variables is described by an evaluation
/// table of `2^n` entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NumVars(usize);

impl NumVars {
    /// Create a new variable count.
    ///
    /// Usable in const contexts.
    #[must_use]
    pub const fn new(n: usize) -> Self {
        Self(n)
    }

    /// The underlying count.
    #[must_use]
    pub const fn count(self) -> usize {
        self.0
    }
}

impl core::fmt::Display for NumVars {
    /// Formats the count as a plain integer.
    ///
    /// Delegates to `usize`'s `Display` so that width, fill and
    /// alignment flags supplied by the caller (e.g. `{:>4}`) are
    /// honoured rather than silently dropped.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(&self.0, f)
    }
}
37
38/// A multilinear polynomial represented by its evaluation table.
39///
40/// The table has `2^num_vars` entries in big-endian bit order:
41/// the first variable is the most significant bit. For a
42/// 2-variable polynomial `f(x0, x1)`:
43///
44/// | Index | x0 | x1 | Value |
45/// |-------|----|----|-------|
46/// | 0 | 0 | 0 | `evals[0]` |
47/// | 1 | 0 | 1 | `evals[1]` |
48/// | 2 | 1 | 0 | `evals[2]` |
49/// | 3 | 1 | 1 | `evals[3]` |
50///
51/// # Examples
52///
53/// ```
54/// use field_cat::F101;
55/// use proof_cat::MultilinearPoly;
56///
57/// // f(x0, x1): f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4
58/// let poly = MultilinearPoly::from_evals(vec![
59/// F101::new(1), F101::new(2), F101::new(3), F101::new(4),
60/// ])?;
61///
62/// assert_eq!(poly.num_vars().count(), 2);
63///
64/// // Sum over the Boolean hypercube: 1 + 2 + 3 + 4 = 10
65/// assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
66///
67/// // Evaluate at a Boolean point: f(1, 0) = 3
68/// let val = poly.evaluate(&[F101::new(1), F101::new(0)])?;
69/// assert_eq!(val, F101::new(3));
70/// # Ok::<(), proof_cat::Error>(())
71/// ```
72#[derive(Debug, Clone)]
73pub struct MultilinearPoly<F: Field> {
74 evals: Vec<F>,
75 num_vars: NumVars,
76}
77
78impl<F: Field> MultilinearPoly<F> {
79 /// Construct from an evaluation table.
80 ///
81 /// The table length must be a power of two (including 1).
82 ///
83 /// # Errors
84 ///
85 /// Returns [`Error::NotPowerOfTwo`] if `evals.len()` is not
86 /// a power of two or is zero.
87 ///
88 /// # Examples
89 ///
90 /// ```
91 /// use field_cat::F101;
92 /// use proof_cat::MultilinearPoly;
93 ///
94 /// // A 1-variable polynomial: f(0) = 3, f(1) = 7.
95 /// let poly = MultilinearPoly::from_evals(vec![
96 /// F101::new(3), F101::new(7),
97 /// ])?;
98 /// assert_eq!(poly.num_vars().count(), 1);
99 ///
100 /// // Non-power-of-two lengths are rejected.
101 /// let err = MultilinearPoly::<F101>::from_evals(vec![
102 /// F101::new(1), F101::new(2), F101::new(3),
103 /// ]);
104 /// assert!(err.is_err());
105 /// # Ok::<(), proof_cat::Error>(())
106 /// ```
107 pub fn from_evals(evals: Vec<F>) -> Result<Self, Error> {
108 let len = evals.len();
109 if len.is_power_of_two() {
110 // log2 of a power of two: count trailing zeros.
111 let num_vars = NumVars::new(
112 usize::try_from(len.trailing_zeros())
113 .map_err(|_| Error::NotPowerOfTwo { value: len })?,
114 );
115 Ok(Self { evals, num_vars })
116 } else {
117 Err(Error::NotPowerOfTwo { value: len })
118 }
119 }
120
121 /// The number of variables.
122 #[must_use]
123 pub fn num_vars(&self) -> NumVars {
124 self.num_vars
125 }
126
127 /// The evaluation table.
128 #[must_use]
129 pub fn evals(&self) -> &[F] {
130 &self.evals
131 }
132
133 /// Sum of all evaluations on the Boolean hypercube.
134 ///
135 /// This equals `sum_{x in {0,1}^n} f(x)`.
136 #[must_use]
137 pub fn sum_over_boolean_hypercube(&self) -> F {
138 self.evals.iter().cloned().fold(F::zero(), |acc, v| acc + v)
139 }
140
141 /// Evaluate the multilinear extension at an arbitrary point.
142 ///
143 /// Uses iterated partial evaluation: for each variable `i`,
144 /// the table is split in half and each pair `(lo, hi)` is
145 /// interpolated as `lo * (1 - r_i) + hi * r_i`. After `n`
146 /// rounds a single value remains.
147 ///
148 /// # Errors
149 ///
150 /// Returns [`Error::DimensionMismatch`] if `point.len() != num_vars`.
151 ///
152 /// # Examples
153 ///
154 /// ```
155 /// use field_cat::F101;
156 /// use proof_cat::MultilinearPoly;
157 ///
158 /// // f(x) = 3*(1-x) + 7*x = 3 + 4x
159 /// let poly = MultilinearPoly::from_evals(vec![
160 /// F101::new(3), F101::new(7),
161 /// ])?;
162 ///
163 /// // f(0) = 3, f(1) = 7, f(2) = 11
164 /// assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
165 /// assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
166 /// assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
167 /// # Ok::<(), proof_cat::Error>(())
168 /// ```
169 pub fn evaluate(&self, point: &[F]) -> Result<F, Error> {
170 if point.len() == self.num_vars.0 {
171 let final_table = point.iter().fold(self.evals.clone(), |table, r_i| {
172 let half = table.len() / 2;
173 (0..half)
174 .map(|j| {
175 let lo = table[j].clone();
176 let hi = table[j + half].clone();
177 // lo * (1 - r_i) + hi * r_i
178 lo * (F::one() - r_i.clone()) + hi * r_i.clone()
179 })
180 .collect()
181 });
182 // After num_vars folds, exactly one element remains.
183 final_table
184 .into_iter()
185 .next()
186 .ok_or(Error::DimensionMismatch {
187 expected: self.num_vars.0,
188 actual: point.len(),
189 })
190 } else {
191 Err(Error::DimensionMismatch {
192 expected: self.num_vars.0,
193 actual: point.len(),
194 })
195 }
196 }
197
198 /// Bind the first variable to `r`, producing a polynomial
199 /// with one fewer variable.
200 ///
201 /// The resulting table has `2^(n-1)` entries where each
202 /// entry `j` is `evals[2j] * (1 - r) + evals[2j+1] * r`.
203 ///
204 /// # Errors
205 ///
206 /// Returns [`Error::DimensionMismatch`] if the polynomial
207 /// has zero variables.
208 pub fn bind_first_var(&self, r: &F) -> Result<Self, Error> {
209 if self.num_vars.0 > 0 {
210 let half = self.evals.len() / 2;
211 let new_evals: Vec<F> = (0..half)
212 .map(|j| {
213 let lo = self.evals[j].clone();
214 let hi = self.evals[j + half].clone();
215 lo * (F::one() - r.clone()) + hi * r.clone()
216 })
217 .collect();
218 Self::from_evals(new_evals)
219 } else {
220 Err(Error::DimensionMismatch {
221 expected: 1,
222 actual: 0,
223 })
224 }
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    /// f(x0, x1) with f(0,0)=1, f(0,1)=2, f(1,0)=3, f(1,1)=4,
    /// i.e. f(x0, x1) = 1 + 2*x0 + x1.
    fn sample_two_var() -> Result<MultilinearPoly<F101>, Error> {
        MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])
    }

    #[test]
    fn from_evals_requires_power_of_two() {
        let result =
            MultilinearPoly::<F101>::from_evals(vec![F101::new(1), F101::new(2), F101::new(3)]);
        assert!(matches!(result, Err(Error::NotPowerOfTwo { value: 3 })));
    }

    #[test]
    fn from_evals_empty_fails() {
        let result = MultilinearPoly::<F101>::from_evals(vec![]);
        assert!(matches!(result, Err(Error::NotPowerOfTwo { value: 0 })));
    }

    #[test]
    fn single_element_poly() -> Result<(), Error> {
        // 0 variables, one evaluation.
        let poly = MultilinearPoly::from_evals(vec![F101::new(42)])?;
        assert_eq!(poly.num_vars().count(), 0);
        assert_eq!(poly.evaluate(&[])?, F101::new(42));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_boolean_points() -> Result<(), Error> {
        // f(0) = 3, f(1) = 7
        let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
        assert_eq!(poly.num_vars().count(), 1);
        assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_midpoint() -> Result<(), Error> {
        // f(0) = 3, f(1) = 7
        // f(r) = 3*(1-r) + 7*r = 3 + 4r
        // f(2) = 3 + 8 = 11
        let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
        assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
        Ok(())
    }

    #[test]
    fn two_var_evaluation() -> Result<(), Error> {
        let poly = sample_two_var()?;
        assert_eq!(poly.num_vars().count(), 2);
        // Boolean point checks:
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(0)])?, F101::new(1));
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(1)])?, F101::new(2));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(1)])?, F101::new(4));
        Ok(())
    }

    #[test]
    fn two_var_evaluation_off_hypercube() -> Result<(), Error> {
        // f(x0, x1) = 1 + 2*x0 + x1, so f(2, 3) = 1 + 4 + 3 = 8.
        let poly = sample_two_var()?;
        assert_eq!(poly.evaluate(&[F101::new(2), F101::new(3)])?, F101::new(8));
        Ok(())
    }

    #[test]
    fn sum_over_hypercube() -> Result<(), Error> {
        let poly = sample_two_var()?;
        // 1 + 2 + 3 + 4 = 10
        assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
        Ok(())
    }

    #[test]
    fn bind_first_var() -> Result<(), Error> {
        let poly = sample_two_var()?;
        // Bind x0 = 0: get f(0, x1) = [1, 2]
        let bound_zero = poly.bind_first_var(&F101::new(0))?;
        assert_eq!(bound_zero.num_vars().count(), 1);
        assert_eq!(bound_zero.evals(), &[F101::new(1), F101::new(2)]);

        // Bind x0 = 1: get f(1, x1) = [3, 4]
        let bound_one = poly.bind_first_var(&F101::new(1))?;
        assert_eq!(bound_one.evals(), &[F101::new(3), F101::new(4)]);
        Ok(())
    }

    #[test]
    fn bind_then_evaluate_matches_evaluate() -> Result<(), Error> {
        let poly = sample_two_var()?;
        // Bind x0 = 2 off the hypercube: f(2, x1) = 5 + x1 -> [5, 6].
        let bound = poly.bind_first_var(&F101::new(2))?;
        assert_eq!(bound.evals(), &[F101::new(5), F101::new(6)]);
        // Evaluating the remainder agrees with evaluating the original.
        assert_eq!(
            bound.evaluate(&[F101::new(3)])?,
            poly.evaluate(&[F101::new(2), F101::new(3)])?
        );
        Ok(())
    }

    #[test]
    fn bind_first_var_on_constant_fails() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(5)])?;
        let result = poly.bind_first_var(&F101::new(1));
        assert!(matches!(
            result,
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 0
            })
        ));
        Ok(())
    }

    #[test]
    fn dimension_mismatch_error() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(1), F101::new(2)])?;
        // One variable, but two coordinates supplied.
        let result = poly.evaluate(&[F101::new(0), F101::new(0)]);
        assert!(matches!(
            result,
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 2
            })
        ));
        Ok(())
    }
}