use ark_ff::Field;
use ark_std::rand::{distributions::Standard, prelude::Distribution, Rng, RngCore};
/// A point given by its coordinates, one entry per variable, in order.
///
/// The wrapped vector is public; its length is the number of variables.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct MultilinearPoint<F>(pub Vec<F>);
impl<F> MultilinearPoint<F>
where
    F: Field,
{
    /// Returns the number of variables, i.e. the number of coordinates of the point.
    #[inline]
    pub const fn num_variables(&self) -> usize {
        self.0.len()
    }

    /// Evaluates the equality polynomial `eq(self, b)` at the Boolean hypercube
    /// point `b` whose bits are packed into `point`.
    ///
    /// The first coordinate `self.0[0]` is paired with the most significant bit
    /// (bit `n - 1`) and the last coordinate with bit 0. Each coordinate `x`
    /// contributes the factor `x` when its bit is 1 and `1 - x` when it is 0; the
    /// result is the product of all factors (`F::ONE` for a point with no variables).
    ///
    /// # Panics
    ///
    /// Panics if `point` has a set bit at position `n` or above, where `n` is
    /// [`Self::num_variables`].
    pub(crate) fn eq_poly(&self, mut point: usize) -> F {
        let n_variables = self.num_variables();
        // Written as a shift test rather than `point < 1 << n` so that it cannot
        // overflow: once `n >= usize::BITS`, every `usize` is in range.
        assert!(
            n_variables >= usize::BITS as usize || point >> n_variables == 0,
            "point {point} is out of range for {n_variables} variables"
        );
        let mut acc = F::ONE;
        // Walk the coordinates from last to first, consuming one bit (LSB first) each.
        for &x in self.0.iter().rev() {
            acc *= if point & 1 == 1 { x } else { F::ONE - x };
            point >>= 1;
        }
        acc
    }

    /// Returns the table of `eq(self, b)` over the whole hypercube, so that entry
    /// `i` equals `self.eq_poly(i)`; the table has `2^n` entries.
    ///
    /// The table is built by doubling. Each coordinate `x`, taken first to last,
    /// splits every current weight `w` into `w * (1 - x)` and `w * x`, and the new
    /// coordinate becomes the least-significant bit of the index. This costs `O(2^n)`
    /// field multiplications instead of the `O(n * 2^n)` of calling
    /// [`Self::eq_poly`] once per index.
    ///
    /// # Panics
    ///
    /// Panics if `n >= usize::BITS`, because the table length `2^n` would not fit in
    /// a `usize`.
    pub(crate) fn eq_weights(&self) -> Vec<F> {
        let n_variables = self.num_variables();
        assert!(
            n_variables < usize::BITS as usize,
            "cannot build an eq table for {n_variables} variables"
        );
        let mut weights = vec![F::ONE];
        for &x in &self.0 {
            weights = weights
                .iter()
                .flat_map(|&w| {
                    // `w - w * x == w * (1 - x)` exactly in a field, with one multiplication.
                    let high = w * x;
                    [w - high, high]
                })
                .collect();
        }
        weights
    }
}
impl<F> MultilinearPoint<F>
where
    Standard: Distribution<F>,
{
    /// Samples a point with `num_variables` coordinates.
    ///
    /// Each coordinate comes from its own `rng.gen()` call (the `Standard`
    /// distribution for `F`), drawn in coordinate order.
    pub fn rand(rng: &mut impl RngCore, num_variables: usize) -> Self {
        let coordinates = std::iter::repeat_with(|| rng.gen())
            .take(num_variables)
            .collect();
        Self(coordinates)
    }
}
impl<F> From<F> for MultilinearPoint<F> {
    /// Builds a one-variable point whose single coordinate is `coordinate`.
    fn from(coordinate: F) -> Self {
        Self(Vec::from([coordinate]))
    }
}
#[cfg(test)]
#[allow(
    clippy::identity_op,
    clippy::cast_sign_loss,
    clippy::erasing_op,
    clippy::should_panic_without_expect
)]
mod tests {
    use ark_ff::AdditiveGroup;
    use ark_std::rand::thread_rng;

    use super::*;
    use crate::algebra::fields::Field64;

    #[test]
    fn test_n_variables() {
        let point =
            MultilinearPoint::<Field64>(vec![Field64::from(1), Field64::from(0), Field64::from(1)]);
        assert_eq!(point.num_variables(), 3);
    }

    #[test]
    fn test_eq_poly_all_zeros() {
        let ml_point = MultilinearPoint(vec![Field64::ZERO; 4]);
        let binary_point = 0b0000;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ONE);
    }

    #[test]
    fn test_eq_poly_all_ones() {
        let ml_point = MultilinearPoint(vec![Field64::ONE; 4]);
        let binary_point = 0b1111;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ONE);
    }

    #[test]
    fn test_eq_poly_mixed_bits_match() {
        let ml_point = MultilinearPoint(vec![
            Field64::ONE,
            Field64::ZERO,
            Field64::ONE,
            Field64::ZERO,
        ]);
        let binary_point = 0b1010;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ONE);
    }

    #[test]
    fn test_eq_poly_mixed_bits_mismatch() {
        let ml_point = MultilinearPoint(vec![
            Field64::ONE,
            Field64::ZERO,
            Field64::ONE,
            Field64::ZERO,
        ]);
        let binary_point = 0b1100;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ZERO);
    }

    #[test]
    fn test_eq_poly_single_variable_match() {
        let ml_point = MultilinearPoint(vec![Field64::ONE]);
        let binary_point = 0b1;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ONE);
    }

    #[test]
    fn test_eq_poly_single_variable_mismatch() {
        let ml_point = MultilinearPoint(vec![Field64::ONE]);
        let binary_point = 0b0;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ZERO);
    }

    #[test]
    fn test_eq_poly_large_binary_number_match() {
        let ml_point = MultilinearPoint(vec![
            Field64::ONE,
            Field64::ONE,
            Field64::ZERO,
            Field64::ONE,
            Field64::ZERO,
            Field64::ONE,
            Field64::ONE,
            Field64::ZERO,
        ]);
        let binary_point = 0b1101_0110;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ONE);
    }

    #[test]
    fn test_eq_poly_large_binary_number_mismatch() {
        let ml_point = MultilinearPoint(vec![
            Field64::ONE,
            Field64::ONE,
            Field64::ZERO,
            Field64::ONE,
            Field64::ZERO,
            Field64::ONE,
            Field64::ONE,
            Field64::ZERO,
        ]);
        let binary_point = 0b1101_0111;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ZERO);
    }

    #[test]
    fn test_eq_poly_empty_vector() {
        let ml_point = MultilinearPoint::<Field64>(vec![]);
        let binary_point = 0;
        assert_eq!(ml_point.eq_poly(binary_point), Field64::ONE);
    }

    #[test]
    #[should_panic]
    fn test_eq_poly_point_out_of_range() {
        // 0b100 needs three bits, but the point only has two variables.
        let ml_point = MultilinearPoint(vec![Field64::ZERO; 2]);
        let _ = ml_point.eq_poly(0b100);
    }

    #[test]
    fn test_equality() {
        let point = MultilinearPoint(vec![Field64::from(0), Field64::from(0)]);
        assert_eq!(point.eq_poly(0b00), Field64::from(1));
        assert_eq!(point.eq_poly(0b01), Field64::from(0));
        assert_eq!(point.eq_poly(0b10), Field64::from(0));
        assert_eq!(point.eq_poly(0b11), Field64::from(0));

        let point = MultilinearPoint(vec![Field64::from(1), Field64::from(0)]);
        assert_eq!(point.eq_poly(0b00), Field64::from(0));
        assert_eq!(point.eq_poly(0b01), Field64::from(0));
        assert_eq!(point.eq_poly(0b10), Field64::from(1));
        assert_eq!(point.eq_poly(0b11), Field64::from(0));
    }

    #[test]
    fn test_eq_weights_empty_point() {
        let point = MultilinearPoint::<Field64>(vec![]);
        assert_eq!(point.eq_weights(), vec![Field64::ONE]);
    }

    #[test]
    fn test_eq_weights_matches_eq_poly() {
        let point = MultilinearPoint::<Field64>::rand(&mut thread_rng(), 5);
        let weights = point.eq_weights();
        assert_eq!(weights.len(), 1 << 5);
        for (index, weight) in weights.iter().enumerate() {
            assert_eq!(*weight, point.eq_poly(index));
        }
    }

    #[test]
    fn test_eq_weights_boolean_point_is_indicator() {
        // At a Boolean point the table is 1 at that vertex and 0 everywhere else.
        let point = MultilinearPoint(vec![Field64::ONE, Field64::ZERO, Field64::ONE]);
        let weights = point.eq_weights();
        assert_eq!(weights.len(), 8);
        for (index, weight) in weights.iter().enumerate() {
            let expected = if index == 0b101 {
                Field64::ONE
            } else {
                Field64::ZERO
            };
            assert_eq!(*weight, expected);
        }
    }

    #[test]
    fn test_eq_weights_sum_to_one() {
        // The sum over all vertices is the product of (1 - x_i) + x_i = 1.
        let point = MultilinearPoint::<Field64>::rand(&mut thread_rng(), 4);
        let sum = point
            .eq_weights()
            .iter()
            .fold(Field64::ZERO, |acc, &w| acc + w);
        assert_eq!(sum, Field64::ONE);
    }

    #[test]
    fn test_from_scalar() {
        let point = MultilinearPoint::from(Field64::from(7));
        assert_eq!(point, MultilinearPoint(vec![Field64::from(7)]));
        assert_eq!(point.num_variables(), 1);
    }

    #[test]
    fn test_multilinear_point_rand_not_all_same() {
        const K: usize = 20;
        const N: usize = 10;

        let mut rng = thread_rng();
        let mut all_same_count = 0;
        for _ in 0..K {
            let point = MultilinearPoint::<Field64>::rand(&mut rng, N);
            let first = point.0[0];
            if point.0.iter().all(|&x| x == first) {
                all_same_count += 1;
            }
        }
        assert!(
            all_same_count < K,
            "rand generated uniform points in all {K} trials"
        );
    }
}