use field_cat::Field;
use crate::error::Error;
/// Number of Boolean variables of a multilinear polynomial.
///
/// Wrapping the count in its own type keeps it distinct from other `usize`
/// values (for example, evaluation-table lengths) at the type level.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct NumVars(usize);

impl NumVars {
    /// Wraps a raw variable count.
    #[must_use]
    pub fn new(n: usize) -> Self {
        NumVars(n)
    }

    /// Returns the wrapped variable count.
    #[must_use]
    pub fn count(self) -> usize {
        let NumVars(count) = self;
        count
    }
}

impl core::fmt::Display for NumVars {
    /// Writes the bare count; formatter flags such as width are not forwarded.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let NumVars(count) = *self;
        write!(f, "{}", count)
    }
}
/// A multilinear polynomial represented by its evaluations over the Boolean
/// hypercube `{0, 1}^n`.
///
/// Invariant: `evals.len() == 2^num_vars`. The only struct literal in this file
/// is in `from_evals`, which rejects lengths that are not powers of two.
#[derive(Debug, Clone)]
pub struct MultilinearPoly<F>
where
    F: Field,
{
    /// Hypercube evaluations; index bit order puts the first variable in the
    /// most significant position (the folding code pairs `j` with `j + len/2`).
    evals: Vec<F>,
    /// Number of variables, i.e. `log2(evals.len())`.
    num_vars: NumVars,
}
impl<F: Field> MultilinearPoly<F> {
pub fn from_evals(evals: Vec<F>) -> Result<Self, Error> {
let len = evals.len();
if len.is_power_of_two() {
let num_vars = NumVars::new(
usize::try_from(len.trailing_zeros())
.map_err(|_| Error::NotPowerOfTwo { value: len })?,
);
Ok(Self { evals, num_vars })
} else {
Err(Error::NotPowerOfTwo { value: len })
}
}
#[must_use]
pub fn num_vars(&self) -> NumVars {
self.num_vars
}
#[must_use]
pub fn evals(&self) -> &[F] {
&self.evals
}
#[must_use]
pub fn sum_over_boolean_hypercube(&self) -> F {
self.evals.iter().cloned().fold(F::zero(), |acc, v| acc + v)
}
pub fn evaluate(&self, point: &[F]) -> Result<F, Error> {
if point.len() == self.num_vars.0 {
let final_table = point.iter().fold(self.evals.clone(), |table, r_i| {
let half = table.len() / 2;
(0..half)
.map(|j| {
let lo = table[j].clone();
let hi = table[j + half].clone();
lo * (F::one() - r_i.clone()) + hi * r_i.clone()
})
.collect()
});
final_table
.into_iter()
.next()
.ok_or(Error::DimensionMismatch {
expected: self.num_vars.0,
actual: point.len(),
})
} else {
Err(Error::DimensionMismatch {
expected: self.num_vars.0,
actual: point.len(),
})
}
}
pub fn bind_first_var(&self, r: &F) -> Result<Self, Error> {
if self.num_vars.0 > 0 {
let half = self.evals.len() / 2;
let new_evals: Vec<F> = (0..half)
.map(|j| {
let lo = self.evals[j].clone();
let hi = self.evals[j + half].clone();
lo * (F::one() - r.clone()) + hi * r.clone()
})
.collect();
Self::from_evals(new_evals)
} else {
Err(Error::DimensionMismatch {
expected: 1,
actual: 0,
})
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use field_cat::F101;

    #[test]
    fn from_evals_requires_power_of_two() {
        let result =
            MultilinearPoly::<F101>::from_evals(vec![F101::new(1), F101::new(2), F101::new(3)]);
        assert!(matches!(result, Err(Error::NotPowerOfTwo { value: 3 })));
    }

    #[test]
    fn from_evals_empty_fails() {
        let result = MultilinearPoly::<F101>::from_evals(vec![]);
        assert!(matches!(result, Err(Error::NotPowerOfTwo { value: 0 })));
    }

    #[test]
    fn single_element_poly() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(42)])?;
        assert_eq!(poly.num_vars().count(), 0);
        assert_eq!(poly.evaluate(&[])?, F101::new(42));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_boolean_points() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
        assert_eq!(poly.num_vars().count(), 1);
        assert_eq!(poly.evaluate(&[F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1)])?, F101::new(7));
        Ok(())
    }

    #[test]
    fn one_var_evaluation_at_midpoint() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(3), F101::new(7)])?;
        assert_eq!(poly.evaluate(&[F101::new(2)])?, F101::new(11));
        Ok(())
    }

    #[test]
    fn two_var_evaluation() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        assert_eq!(poly.num_vars().count(), 2);
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(0)])?, F101::new(1));
        assert_eq!(poly.evaluate(&[F101::new(0), F101::new(1)])?, F101::new(2));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(0)])?, F101::new(3));
        assert_eq!(poly.evaluate(&[F101::new(1), F101::new(1)])?, F101::new(4));
        Ok(())
    }

    #[test]
    fn two_var_evaluation_off_hypercube() -> Result<(), Error> {
        // The multilinear extension of [1, 2, 3, 4] is f(x0, x1) = 1 + 2*x0 + x1,
        // so f(2, 3) = 8.
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        assert_eq!(poly.evaluate(&[F101::new(2), F101::new(3)])?, F101::new(8));
        Ok(())
    }

    #[test]
    fn sum_over_hypercube() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        assert_eq!(poly.sum_over_boolean_hypercube(), F101::new(10));
        Ok(())
    }

    #[test]
    fn bind_first_var() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        let bound_zero = poly.bind_first_var(&F101::new(0))?;
        assert_eq!(bound_zero.num_vars().count(), 1);
        assert_eq!(bound_zero.evals(), &[F101::new(1), F101::new(2)]);
        let bound_one = poly.bind_first_var(&F101::new(1))?;
        assert_eq!(bound_one.evals(), &[F101::new(3), F101::new(4)]);
        Ok(())
    }

    #[test]
    fn bind_then_evaluate_matches_full_evaluation() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![
            F101::new(1),
            F101::new(2),
            F101::new(3),
            F101::new(4),
        ])?;
        let bound = poly.bind_first_var(&F101::new(2))?;
        assert_eq!(bound.evals(), &[F101::new(5), F101::new(6)]);
        assert_eq!(
            bound.evaluate(&[F101::new(3)])?,
            poly.evaluate(&[F101::new(2), F101::new(3)])?
        );
        Ok(())
    }

    #[test]
    fn bind_first_var_without_variables_fails() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(42)])?;
        assert!(matches!(
            poly.bind_first_var(&F101::new(5)),
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 0
            })
        ));
        Ok(())
    }

    #[test]
    fn dimension_mismatch_error() -> Result<(), Error> {
        let poly = MultilinearPoly::from_evals(vec![F101::new(1), F101::new(2)])?;
        let result = poly.evaluate(&[F101::new(0), F101::new(0)]);
        assert!(matches!(
            result,
            Err(Error::DimensionMismatch {
                expected: 1,
                actual: 2
            })
        ));
        Ok(())
    }
}