use crate::polynomials::MultiPoint;
use ark_ff::Field;
use std::ops::Mul;
/// Fills `dest` with `scale * eq(vars, b)` for every vertex `b` of the boolean hypercube
/// of dimension `vars.len()`. Bit `i` of the index `b` corresponds to `vars[i]`:
///
/// `dest[b] = scale * prod_i (if b_i == 1 { vars[i] } else { 1 - vars[i] })`.
///
/// The table is built in place by doubling. After `k` rounds, the prefix `dest[..2^k]`
/// holds the table over `vars[..k]`. Round `k` then splits every entry `e` into
/// `e * (1 - v)` (bit `k` clear) and `e * v` (bit `k` set). No field inversions are
/// needed, so the result stays correct when a coordinate equals one.
///
/// # Panics
/// Panics if `dest.len() != 2^vars.len()`.
fn eval_eq<F: Field>(dest: &mut [F], vars: &[F], scale: F) {
    assert_eq!(
        dest.len(),
        1_usize << vars.len(),
        "eq table must have 2^vars entries"
    );
    dest[0] = scale;
    for (i, &v) in vars.iter().enumerate() {
        let half = 1_usize << i;
        let (lo, hi) = dest[..2 * half].split_at_mut(half);
        for (l, h) in lo.iter_mut().zip(hi) {
            // Bit i set: e * v. Bit i clear: e - e * v = e * (1 - v).
            *h = *l * v;
            *l -= *h;
        }
    }
}

/// Equality-polynomial table for `point`; equivalent to `eq_subset(point, point.vars())`.
///
/// Entry `b` is `prod_i (if b_i == 1 { x_i } else { 1 - x_i })` over the first
/// `point.vars()` coordinates `x_i`. Bit `i` of the index selects coordinate `i`.
/// NOTE(review): this assumes `point.vars()` is the full number of coordinates, as
/// the name suggests.
pub fn eq<F: Field>(point: &MultiPoint<F>) -> Vec<F> {
    let n_log = point.vars();
    eq_subset(point, n_log)
}

/// Returns the first `2^n_log` entries of the equality-polynomial table of `point`.
///
/// Entry `b` (for `b < 2^n_log`) equals
/// `prod_{i < n_log} (if b_i == 1 { x_i } else { 1 - x_i }) * prod_{i >= n_log} (1 - x_i)`.
/// That is, it is the full table's entry at an index whose bits at and above `n_log`
/// are zero. With `n_log == 0` the result is the single value `prod_i (1 - x_i)`.
///
/// Coordinates equal to one are handled correctly, because no inversion of
/// `1 - x_i` is performed.
///
/// # Panics
/// Panics if `n_log` exceeds the number of coordinates of `point`.
pub fn eq_subset<F: Field>(point: &MultiPoint<F>, n_log: usize) -> Vec<F> {
    let vars = point.inner_ref();
    assert!(vars.len() >= n_log, "subset bigger than full set");
    let (expanded, pinned) = vars.split_at(n_log);
    // Coordinates outside the subset are at bit value 0 in every returned entry,
    // so each contributes the same (1 - x) factor to the whole table.
    let scale = pinned
        .iter()
        .map(|x| F::one() - x)
        .fold(F::one(), Mul::mul);
    let mut table = vec![F::zero(); 1_usize << n_log];
    eval_eq(&mut table, expanded, scale);
    table
}
#[test]
fn test_eq() {
    use crate::polynomials::{EvalsExt, SingleEval};
    use ark_vesta::Fr;
    use rand::{thread_rng, Rng};
    let mut rng = thread_rng();
    let vars = 4;
    // `vec![expr; n]` evaluates `expr` once and clones it. That would make every
    // coordinate (and every coefficient) identical and reduce this test to checking
    // that the eq table sums to one, so draw each value independently.
    let point: Vec<Fr> = (0..vars).map(|_| rng.gen::<Fr>()).collect();
    let point = MultiPoint::new(point);
    let eq_evals = eq(&point);
    let check_poly: Vec<Fr> = (0..eq_evals.len()).map(|_| rng.gen::<Fr>()).collect();
    // <eq(point, .), p> is the multilinear extension of `p` evaluated at `point`.
    let eq_eval = eq_evals
        .iter()
        .copied()
        .zip(check_poly.iter())
        .fold(Fr::from(0), |sum, (a, b)| sum + a * b);
    let check_poly: Vec<_> = check_poly.into_iter().map(SingleEval).collect();
    let check_eval = EvalsExt::eval_slow(check_poly, point).0;
    assert_eq!(eq_eval, check_eval);
}
#[test]
fn test_subset() {
    use ark_vesta::Fr;
    // eq_subset(point, 2) must agree with the first 2^2 entries of eq(point).
    let coords: Vec<Fr> = [2_u32, 3, 4, 5].iter().map(|&c| Fr::from(c)).collect();
    let point: MultiPoint<Fr> = MultiPoint::new(coords);
    let full = eq(&point);
    let partial = eq_subset(&point, 2);
    assert_eq!(full[..4], partial[..4]);
}