use crate::error::DoryError;
use crate::setup::ProverSetup;
use super::arithmetic::{DoryRoutines, Field, Group, PairingCurve};
use crate::mode::Mode;
/// Extension of [`Polynomial`] with helpers built on the multilinear Lagrange basis.
pub trait MultilinearLagrange<F: Field>: Polynomial<F> {
    /// Multiplies `left_vec` against this polynomial's coefficients viewed as a matrix.
    ///
    /// NOTE(review): the matrix shape is presumably `2^nu x 2^sigma` (matching the lengths
    /// produced by `compute_evaluation_vectors`) — confirm against implementors.
    fn vector_matrix_product(&self, left_vec: &[F], nu: usize, sigma: usize) -> Vec<F>;

    /// Writes the multilinear Lagrange basis of `point` into `output`.
    ///
    /// The default implementation ignores `self` and forwards to the crate-level
    /// `multilinear_lagrange_basis` helper.
    fn lagrange_basis(&self, output: &mut [F], point: &[F]) {
        multilinear_lagrange_basis(output, point);
    }

    /// Returns the `(left, right)` evaluation vectors for `point`, of lengths `2^nu` and
    /// `2^sigma` respectively.
    ///
    /// The default implementation ignores `self` and forwards to `compute_left_right_vectors`.
    fn compute_evaluation_vectors(
        &self,
        point: &[F],
        nu: usize,
        sigma: usize,
    ) -> (Vec<F>, Vec<F>) {
        let (left, right) = compute_left_right_vectors(point, nu, sigma);
        (left, right)
    }
}
/// A polynomial over the field `F` in `num_vars()` variables.
pub trait Polynomial<F: Field> {
    /// Number of variables of the polynomial.
    fn num_vars(&self) -> usize;

    /// Returns the value of the polynomial at `point` (implementor-defined).
    fn evaluate(&self, point: &[F]) -> F;

    /// Number of entries backing the polynomial; defaults to `2^num_vars()`.
    fn len(&self) -> usize {
        let vars = self.num_vars();
        1 << vars
    }

    /// Returns `true` when `len()` is zero; with the default `len` this is never the case.
    fn is_empty(&self) -> bool {
        0 == self.len()
    }

    /// Commits to the polynomial using `setup`, with `nu` and `sigma` as layout parameters.
    ///
    /// Returns a `GT` element, a vector of `G1` elements and a field element, or a
    /// [`DoryError`]. NOTE(review): the exact meaning of each tuple component is defined by
    /// implementors — confirm there.
    // The nested tuple inside `Result` trips clippy's `type_complexity`; a named alias was not
    // introduced, so the lint is silenced deliberately.
    #[allow(clippy::type_complexity)]
    fn commit<E, Mo, M1>(
        &self,
        nu: usize,
        sigma: usize,
        setup: &ProverSetup<E>,
    ) -> Result<(E::GT, Vec<E::G1>, F), DoryError>
    where
        E: PairingCurve,
        Mo: Mode,
        M1: DoryRoutines<E::G1>,
        E::G1: Group<Scalar = F>,
        E::GT: Group<Scalar = F>;
}
/// Fills `output` with the multilinear Lagrange basis evaluated at `point`.
///
/// Entry `i` receives `prod_j (if bit j of i is set { point[j] } else { 1 - point[j] })`, so
/// `point[0]` selects the least-significant bit of the index.
///
/// `output` may be shorter than `2^point.len()`. In that case only the first `output.len()`
/// basis values are written, and any variable whose bit lies beyond that prefix contributes only
/// its `1 - point[j]` factor.
///
/// If `point` is empty, the (at most one) output entry is set to one. An empty `output` is left
/// untouched. Otherwise every entry of `output` is overwritten, and its prior contents are
/// ignored.
///
/// # Panics
///
/// Panics if `output.len() > 2^point.len()`.
pub(crate) fn multilinear_lagrange_basis<F: Field>(output: &mut [F], point: &[F]) {
    // `1 << point.len()` overflows once the point has `usize::BITS` or more variables.
    // Debug builds panic, and release builds mask the shift and compute a bogus bound.
    // No slice can exceed `2^usize::BITS` entries, so the bound trivially holds there.
    assert!(
        point.len() >= usize::BITS as usize || output.len() <= (1 << point.len()),
        "Output length must be at most 2^point.len()"
    );
    if point.is_empty() || output.is_empty() {
        // Empty product: the single basis value for a zero-variable point is one.
        output.fill(F::one());
        return;
    }
    let one_minus_p0 = F::one() - point[0];
    output[0] = one_minus_p0;
    if output.len() > 1 {
        output[1] = point[0];
    }
    // Invariant: entries [0, mid) hold the basis over the variables processed so far.
    // `mid` only grows while it is below `output.len()`, so unlike a recomputed `1 << level`
    // it cannot overflow for very long points.
    let mut mid: usize = 2;
    for p in &point[1..] {
        let one_minus_p = F::one() - p;
        if mid >= output.len() {
            // Every kept index has this variable's bit clear, so only (1 - p) applies.
            for val in output.iter_mut() {
                *val = val.mul(&one_minus_p);
            }
        } else {
            // Index `i + mid` (bit set) takes the `p` factor; index `i` (bit clear) takes `1 - p`.
            // `right` is shorter than `left` when `output.len()` is not a power of two.
            let (left, right) = output.split_at_mut(mid);
            let k = left.len().min(right.len());
            for (l, r) in left[..k].iter_mut().zip(right[..k].iter_mut()) {
                let l_val = *l;
                *r = l_val.mul(p);
                *l = l_val.mul(&one_minus_p);
            }
            for l in left[k..].iter_mut() {
                *l = l.mul(&one_minus_p);
            }
            // Saturate rather than wrap; any saturated value is >= output.len() anyway.
            mid = mid.saturating_mul(2);
        }
    }
}
/// Builds the `(left, right)` Lagrange-basis vectors for `point`.
///
/// The returned vectors have lengths `2^nu` (left) and `2^sigma` (right). The first `sigma`
/// coordinates of `point` determine `right`; the remaining coordinates determine `left`. All
/// entries not written below are zero.
///
/// - Empty `point`: both vectors are the unit vector with a one at index 0.
/// - `point.len() <= sigma`: the first `2^point.len()` entries of `right` hold the basis of the
///   whole point, and `left` is the unit vector at index 0.
/// - Otherwise: `right` holds the full basis of `point[..sigma]`. `left` holds the basis of
///   `point[sigma..]`, zero-padded when that tail has fewer than `nu` coordinates and truncated
///   to `2^nu` entries when it has more.
pub fn compute_left_right_vectors<F: Field>(
    point: &[F],
    nu: usize,
    sigma: usize,
) -> (Vec<F>, Vec<F>) {
    let dim = point.len();
    let mut left = vec![F::zero(); 1 << nu];
    let mut right = vec![F::zero(); 1 << sigma];

    if dim == 0 {
        left[0] = F::one();
        right[0] = F::one();
    } else if dim <= sigma {
        multilinear_lagrange_basis(&mut right[..1 << dim], point);
        left[0] = F::one();
    } else {
        let (right_point, left_point) = point.split_at(sigma);
        multilinear_lagrange_basis(&mut right, right_point);
        // Only as many left entries as the tail can address, capped at the full `2^nu`.
        let left_len = 1 << left_point.len().min(nu);
        multilinear_lagrange_basis(&mut left[..left_len], left_point);
    }

    (left, right)
}