use crate::eq::eq;
use ark_ff::Field;
use std::vec::IntoIter;
use transcript::Message;
#[derive(Clone, Debug)]
pub struct MultiPoint<F: Field>(Vec<F>);
impl<F: Field> Message<F> for MultiPoint<F> {
fn len(vars: usize, _param_resolver: &transcript::params::ParamResolver) -> usize {
vars
}
fn to_field_elements(&self) -> Vec<F> {
self.0.clone()
}
}
impl<F: Field> From<Vec<F>> for MultiPoint<F> {
fn from(value: Vec<F>) -> Self {
Self(value)
}
}
impl<F: Field> From<&[F]> for MultiPoint<F> {
fn from(value: &[F]) -> Self {
Self(value.to_vec())
}
}
impl<F: Field> AsRef<[F]> for MultiPoint<F> {
fn as_ref(&self) -> &[F] {
&self.0
}
}
impl<F: Field> MultiPoint<F> {
pub fn new(vars: Vec<F>) -> Self {
MultiPoint(vars)
}
pub(crate) fn pop(mut self) -> (Self, F) {
let var = self.0.pop().unwrap();
(self, var)
}
pub(crate) fn pop_mut(&mut self) -> F {
self.0.pop().unwrap()
}
pub fn vars(&self) -> usize {
self.0.len()
}
pub fn inner(self) -> Vec<F> {
self.0
}
pub fn inner_ref(&self) -> &[F] {
&self.0
}
pub fn eval_as_eq(&self, point: &Self) -> F {
assert_eq!(self.0.len(), point.0.len());
self.0
.iter()
.zip(point.0.iter())
.fold(F::one(), |acc, (a, b)| {
let var = *a * b + (F::one() - a) * (F::one() - b);
acc * var
})
}
}
pub trait Evals<V>: Sized + Clone {
type Idx: Copy;
fn index(&self, index: Self::Idx) -> &V;
fn combine<C: Fn(V, V) -> V>(&self, other: &Self, f: C) -> Self;
fn flatten(self, vec: &mut Vec<V>);
fn unflatten(elems: &mut IntoIter<V>) -> Self;
fn flatten_vec(self) -> Vec<V> {
let mut vec = vec![];
self.flatten(&mut vec);
vec
}
fn unflatten_vec(vec: Vec<V>) -> Self {
let mut iter = vec.into_iter();
Self::unflatten(&mut iter)
}
}
pub trait EvalsExt<F: Field>: Evals<F> + Sized {
fn fix_var(mut mle: Vec<Self>, var: F) -> Vec<Self> {
let half_len = mle.len() / 2;
let one_minus_var = F::one() - var;
let (left, right) = mle.split_at_mut(half_len);
let f = |a, b| one_minus_var * a + var * b;
for (left, right) in left.iter_mut().zip(right) {
let left: &mut Self = left;
let comb = left.combine(right, f);
*left = comb;
}
mle.truncate(half_len);
mle
}
fn eval_slow(mle: Vec<Self>, point: MultiPoint<F>) -> Self {
assert_eq!(
mle.len().ilog2() as usize,
point.vars(),
"number of variables missmatch"
);
let (point, var) = point.pop();
let mle = Self::fix_var(mle, var);
if point.vars() == 0 {
mle.into_iter().next().unwrap()
} else {
Self::eval_slow(mle, point)
}
}
fn eval(mles: &[Self], point: MultiPoint<F>) -> Self {
use std::iter::Iterator;
assert_eq!(
mles.len().ilog2() as usize,
point.vars(),
"number of variables missmatch"
);
let eq: Vec<F> = eq(&point);
let dummy = mles[0].clone().flatten_vec();
let dummy: Self = Self::unflatten_vec(vec![F::zero(); dummy.len()]);
eq.into_iter().zip(mles).fold(dummy.clone(), |acc, x| {
let acc: Self = acc;
let (eq_eval, eval): (F, &Self) = x;
acc.combine(eval, |a, b| a + b * eq_eval)
})
}
fn eval_iter<M>(mut mles: M, point: MultiPoint<F>) -> Self
where
M: Iterator<Item = Self>,
{
let mut eq = eq(&point).into_iter();
let first: Self = mles.next().unwrap();
let first_eq = eq.next().unwrap();
let mut res = first.combine(&first, |e, _| e * first_eq);
loop {
match (mles.next(), eq.next()) {
(None, None) => {
break;
}
(None, Some(_)) | (Some(_), None) => {
panic!("unexpected number of evaluations")
}
(Some(e), Some(eq_eval)) => {
res = res.combine(&e, |a, b| a + b * eq_eval);
}
}
}
todo!()
}
}
impl<F, T> EvalsExt<F> for T
where
T: Evals<F> + Sized,
F: Field,
{
}
#[derive(Clone)]
pub struct SingleEval<F>(pub F);
impl<F> SingleEval<F> {
pub fn from_vec(mle: Vec<F>) -> Vec<Self> {
mle.into_iter().map(Self).collect()
}
}
impl<V: Copy> Evals<V> for SingleEval<V> {
type Idx = ();
fn combine<C: Fn(V, V) -> V>(&self, other: &Self, f: C) -> Self {
SingleEval(f(self.0, other.0))
}
fn index(&self, _index: Self::Idx) -> &V {
&self.0
}
fn flatten(self, vec: &mut Vec<V>) {
vec.push(self.0);
}
fn unflatten(elems: &mut IntoIter<V>) -> Self {
Self(elems.next().unwrap())
}
}
impl<F: Clone> SingleEval<F> {
pub fn from_field_elements(evals: &[F]) -> Vec<Self> {
evals.iter().cloned().map(SingleEval).collect()
}
}
pub mod simple_eval {
use super::Evals;
use crate::utils::ZeroCheckAvailable;
use std::fmt::Debug;
#[derive(Clone, Copy, Debug)]
pub struct SimpleEval<F, const N: usize>([F; N]);
impl<F, const N: usize> SimpleEval<F, N> {
pub const fn new(inner: [F; N]) -> Self {
Self(inner)
}
pub fn map<V, M>(self, f: M) -> SimpleEval<V, N>
where
M: Fn(F) -> V,
{
SimpleEval(self.0.map(f))
}
pub fn inner(&self) -> &[F; N] {
&self.0
}
}
impl ZeroCheckAvailable for usize {
fn zerocheck_eq() -> Self {
0
}
}
impl<V: Copy + Debug, const N: usize> Evals<V> for SimpleEval<V, N> {
type Idx = usize;
fn index(&self, index: Self::Idx) -> &V {
&self.0[index]
}
fn combine<C: Fn(V, V) -> V>(&self, other: &Self, f: C) -> Self {
let mut res = self.0;
for (i, res) in res.iter_mut().enumerate() {
*res = f(*res, other.0[i]);
}
Self(res)
}
fn flatten(self, vec: &mut Vec<V>) {
for i in 0..N {
vec.push(self.0[i]);
}
}
fn unflatten(elems: &mut std::vec::IntoIter<V>) -> Self {
let elems: Vec<V> = elems.take(N).collect();
Self(elems.try_into().unwrap())
}
}
}