use crate::{
degree::DegreeEnv,
polynomials::Evals,
sumcheck::SumcheckFunction,
symbolic::{
compute::MvPoly,
evaluate::{MvEvaluator, MvIr},
expression::{compute_mv_poly, ExpEnv, Expression, VarOrChall},
id_map::IdMap,
message_eval::MessageEvaluator,
},
};
use ark_ff::Field;
use std::collections::BTreeMap;
/// Evaluator for the sumcheck function `S`, built from its symbolic expression.
///
/// The program run by `inner` addresses variables and challenges through
/// compact `u8` ids; `var_map` and `chall_map` translate those ids back to the
/// index types declared by `S`.
#[derive(Debug)]
pub struct SumcheckEvaluator<F: Field, S: SumcheckFunction<F>> {
    /// Message evaluator running the transpiled program over `u8` ids.
    inner: MessageEvaluator<F, u8>,
    /// `var_map[id]` is the `S` variable index behind compact variable `id`.
    var_map: Vec<S::Idx>,
    /// `chall_map[id]` is the `S` challenge index behind compact challenge `id`.
    chall_map: Vec<S::ChallIdx>,
    /// Expected length of the evaluator's result: the function's degree plus one.
    message_len: usize,
    /// `message_len` zeros, used to reset the stack of `inner`.
    accumulator_init: Vec<F>,
}
/// Field-by-field clone, written by hand rather than derived (a derive would
/// additionally demand `S: Clone`).
impl<F: Field + Clone, S: SumcheckFunction<F>> Clone for SumcheckEvaluator<F, S> {
    fn clone(&self) -> Self {
        let Self {
            inner,
            var_map,
            chall_map,
            message_len,
            accumulator_init,
        } = self;
        Self {
            inner: inner.clone(),
            var_map: var_map.clone(),
            chall_map: chall_map.clone(),
            message_len: *message_len,
            accumulator_init: accumulator_init.clone(),
        }
    }
}
/// Variable type of the symbolic polynomial for `S`: a `VarOrChall` over the
/// function's variable index (`Idx`) and challenge index (`ChallIdx`).
type Var<F, S> = VarOrChall<<S as SumcheckFunction<F>>::Idx, <S as SumcheckFunction<F>>::ChallIdx>;
impl<F: Field, S> Default for SumcheckEvaluator<F, S>
where
S: SumcheckFunction<F>,
{
fn default() -> Self {
Self::new(None)
}
}
impl<F: Field, S> SumcheckEvaluator<F, S>
where
S: SumcheckFunction<F>,
{
pub fn new(f: Option<&S>) -> Self {
let env = ExpEnv;
let exp: Expression<F, S::Idx, S::ChallIdx> = {
match f {
Some(f) => f.symbolic_function(env).unwrap(),
None => S::function(env),
}
};
let poly: MvPoly<F, Var<F, S>> = compute_mv_poly(exp);
let evaluator = MvEvaluator::new(poly);
let program: &[MvIr<F, Var<F, S>>] = evaluator.program();
let (ir, var_map, chall_map) = Self::transpile(program);
let message_len = Self::message_len(f);
let inner: MessageEvaluator<F, u8> = MessageEvaluator::new(ir, message_len);
let accumulator_init = vec![F::zero(); message_len];
Self {
inner,
var_map,
chall_map,
message_len,
accumulator_init,
}
}
/// Rewrites `program` so that variables and challenges are addressed through
/// compact `u8` ids obtained from `IdMap::get_id`.
///
/// Returns, in order:
/// 1. the rewritten program, with one extra `MvIr::Add` appended at the end;
/// 2. the `S::Idx` for each entry produced by the variable id map's `finish()`;
/// 3. the result of the challenge id map's `finish()`.
///
/// # Panics
///
/// Panics if an id handed out by either id map does not fit in a `u8`.
// Clippy flags the three-vector tuple return type as complex; silenced here.
#[allow(clippy::type_complexity)]
fn transpile(
    program: &[MvIr<F, Var<F, S>>],
) -> (
    Vec<MvIr<F, VarOrChall<u8, u8>>>,
    Vec<S::Idx>,
    Vec<S::ChallIdx>,
) {
    // Tag every MLE slot of `S` with its position in the flattened layout; that
    // position serves as a temporary id for the slot.
    let positions = S::map_evals(S::KINDS, |_| ())
        .flatten_vec()
        .into_iter()
        .enumerate()
        .map(|(position, _)| position)
        .collect();
    let positions: S::Mles<usize> = S::Mles::unflatten_vec(positions);

    // Remembers which `S::Idx` each temporary id came from.
    let mut idx_of_position: BTreeMap<usize, S::Idx> = BTreeMap::new();
    let mut var_ids = IdMap::new();
    let mut chall_ids = IdMap::new();

    // Translates one symbolic variable or challenge into its compact form.
    let mut compact = |var: Var<F, S>| -> VarOrChall<u8, u8> {
        match var {
            VarOrChall::Var(idx) => {
                let position = *positions.index(idx);
                idx_of_position.insert(position, idx);
                VarOrChall::Var(u8::try_from(var_ids.get_id(position)).unwrap())
            }
            VarOrChall::Challenge(chall) => {
                VarOrChall::Challenge(u8::try_from(chall_ids.get_id(chall)).unwrap())
            }
        }
    };

    let mut ir: Vec<MvIr<F, VarOrChall<u8, u8>>> = Vec::with_capacity(program.len() + 1);
    for instruction in program {
        let rewritten = match instruction {
            MvIr::PushChild(coeff, var) => MvIr::PushChild(*coeff, compact(*var)),
            MvIr::Add => MvIr::Add,
            MvIr::Mul(var) => MvIr::Mul(compact(*var)),
            MvIr::AddConstantTerm(coeff) => MvIr::AddConstantTerm(*coeff),
        };
        ir.push(rewritten);
    }
    // The transpiled program always ends with one additional `Add`.
    ir.push(MvIr::Add);

    let var_lookup = var_ids
        .finish()
        .into_iter()
        .map(|position| *idx_of_position.get(&position).unwrap())
        .collect();
    (ir, var_lookup, chall_ids.finish())
}
/// Computes the message length: the degree obtained by evaluating the function
/// in a `DegreeEnv`, plus one.
///
/// Uses `f.symbolic_function` when an instance is given and the associated
/// `S::function` otherwise.
///
/// # Panics
///
/// Panics if the value returned by `f.symbolic_function` cannot be unwrapped.
fn message_len(f: Option<&S>) -> usize {
    let env = DegreeEnv::new();
    let degree = if let Some(function) = f {
        function.symbolic_function(env).unwrap()
    } else {
        S::function(env)
    };
    degree.0 + 1
}
/// Loads into `inner` the value of every challenge listed in `chall_map`:
/// compact challenge `slot` receives `challs[chall_map[slot]]`.
fn set_challs(&mut self, challs: &S::Challs) {
    for (slot, &index) in self.chall_map.iter().enumerate() {
        self.inner.set_chall(slot, challs[index]);
    }
}
/// Runs the program once on a pair of evaluation tables.
///
/// For every entry of `var_map`, the value at that index in the first table
/// and in the second table are handed to `inner` as a pair; then `inner` is
/// evaluated.
///
/// # Panics
///
/// Panics if the current result of `inner` does not have `message_len` entries.
fn eval_accumulate(&mut self, evals: [&S::Mles<F>; 2]) {
    let [first, second] = evals;
    assert_eq!(self.inner.result().len(), self.message_len);
    for (slot, &index) in self.var_map.iter().enumerate() {
        let pair = (*first.index(index), *second.index(index));
        self.inner.set_var(slot, pair);
    }
    self.inner.eval();
}
/// Prepares the evaluator for accumulation and returns a handle to drive it.
///
/// The stack of `inner` is reset to `accumulator_init`, the challenge values
/// are loaded from `challenges`, and the returned `EvalAccumulator` mutably
/// borrows `self`.
pub fn accumulator(&mut self, challenges: &S::Challs) -> EvalAccumulator<'_, F, S> {
    self.inner.set_stack(self.accumulator_init.as_slice());
    self.set_challs(challenges);
    EvalAccumulator(self)
}
}
/// Handle wrapping a mutable borrow of a `SumcheckEvaluator`; it is what
/// `SumcheckEvaluator::accumulator` returns.
pub struct EvalAccumulator<'a, F: Field, S: SumcheckFunction<F>>(
    &'a mut SumcheckEvaluator<F, S>,
);
impl<F: Field, S: SumcheckFunction<F>> EvalAccumulator<'_, F, S> {
    /// Forwards one pair of evaluation tables to the wrapped evaluator.
    pub fn eval_accumulate(&mut self, evals: [&S::Mles<F>; 2]) {
        let evaluator = &mut *self.0;
        evaluator.eval_accumulate(evals);
    }

    /// Evaluates one pair of tables, returns a copy of the resulting message,
    /// and resets the evaluator's stack to `accumulator_init`.
    ///
    /// # Panics
    ///
    /// Panics if the result does not have `message_len` entries.
    pub fn eval_and_zero(&mut self, evals: [&S::Mles<F>; 2]) -> Vec<F> {
        let evaluator = &mut *self.0;
        evaluator.eval_accumulate(evals);
        assert_eq!(evaluator.message_len, evaluator.inner.result().len());
        let message = evaluator.inner.result().to_vec();
        evaluator.inner.set_stack(&evaluator.accumulator_init);
        message
    }

    /// Consumes the handle, returning a copy of the current result and leaving
    /// the evaluator with an empty stack.
    ///
    /// # Panics
    ///
    /// Panics if the result does not have `message_len` entries.
    pub fn finish(self) -> Vec<F> {
        let evaluator = &mut *self.0;
        assert_eq!(evaluator.message_len, evaluator.inner.result().len());
        let message = evaluator.inner.result().to_vec();
        evaluator.inner.set_stack(&[]);
        message
    }
}