use alloc::vec;
use alloc::vec::Vec;
use itertools::Itertools;
use p3_field::{ExtensionField, Field, dot_product};
use p3_maybe_rayon::prelude::*;
use p3_multilinear_util::point::Point;
use p3_multilinear_util::poly::Poly;
use p3_multilinear_util::split_eq::SplitEq;
use p3_util::log2_strict_usize;
use crate::layout::ProverMultiClaim;
use crate::strategy::VariableOrder;
/// Expands the Boolean-hypercube evaluations of a multilinear polynomial into its
/// values on the grid `{0, 1, ∞}^n`, writing the result into `output`.
///
/// "Evaluation at ∞" in a variable is the leading coefficient in that variable,
/// i.e. `f(.., ∞, ..) = f(.., 1, ..) - f(.., 0, ..)`.
///
/// Layout: bit `v` of a Boolean index (`v = 0` being the least significant bit)
/// becomes the base-3 digit of weight `3^(n-1-v)` in the output index, with digit
/// `2` standing for `∞`. The lowest Boolean bit is therefore the most significant
/// ternary digit.
///
/// `scratch` serves as the second buffer of a ping-pong scheme; its contents on
/// return are unspecified.
///
/// # Panics
/// - if `boolean_evals.len()` is not a power of two;
/// - if `output.len()` or `scratch.len()` is not `3^n`.
fn evals_01inf_grid_into<F: Field>(boolean_evals: &[F], output: &mut [F], scratch: &mut [F]) {
    // Below this stride a stage consists of many narrow blocks, so parallelize across
    // blocks; at or above it there are few wide blocks, so parallelize inside each one.
    const PARALLEL_STRIDE_THRESHOLD: usize = 256;

    let num_variables = log2_strict_usize(boolean_evals.len());
    let grid_len = 3usize.pow(num_variables as u32);
    assert_eq!(output.len(), grid_len);
    assert_eq!(scratch.len(), grid_len);

    if num_variables == 0 {
        output[0] = boolean_evals[0];
        return;
    }

    // Every stage reads one buffer, writes the other, and then the roles swap. Seeding
    // `scratch` for an odd number of stages (and `output` for an even number) makes the
    // final stage land in `output`.
    let (mut src, mut dst): (&mut [F], &mut [F]) = if num_variables % 2 == 1 {
        (scratch, output)
    } else {
        (output, scratch)
    };
    src[..boolean_evals.len()].copy_from_slice(boolean_evals);

    for stage in 0..num_variables {
        // Each block holds `2 * stride` inputs (the stage variable at 0, then at 1) and
        // emits `3 * stride` outputs, inserting the new ternary digit as the lowest one.
        let stride = 3usize.pow(stage as u32);
        let num_blocks = 1usize << (num_variables - stage - 1);
        let input: &[F] = &src[..num_blocks * 2 * stride];
        let result: &mut [F] = &mut dst[..num_blocks * 3 * stride];

        if stride >= PARALLEL_STRIDE_THRESHOLD {
            let blocks = input.chunks(2 * stride).zip(result.chunks_mut(3 * stride));
            for (block_in, block_out) in blocks {
                let (at_zero, at_one) = block_in.split_at(stride);
                at_zero
                    .par_iter()
                    .zip(at_one.par_iter())
                    .zip(block_out.par_chunks_mut(3))
                    .for_each(|((&v0, &v1), triple)| {
                        triple[0] = v0;
                        triple[1] = v1;
                        triple[2] = v1 - v0;
                    });
            }
        } else {
            input
                .par_chunks(2 * stride)
                .zip(result.par_chunks_mut(3 * stride))
                .for_each(|(block_in, block_out)| {
                    let (at_zero, at_one) = block_in.split_at(stride);
                    for ((triple, &v0), &v1) in block_out.chunks_mut(3).zip(at_zero).zip(at_one)
                    {
                        triple[0] = v0;
                        triple[1] = v1;
                        triple[2] = v1 - v0;
                    }
                });
        }

        core::mem::swap(&mut src, &mut dst);
    }
}
/// Computes the `{0, 1, ∞}^n` grid of `evals` (same values as
/// `evals_01inf_grid_into`) and reverses the base-3 digits of every index.
///
/// After the reversal, bit `v` of a Boolean index maps to the ternary digit of
/// weight `3^v`, so the lowest Boolean bit is the lowest ternary digit.
///
/// # Panics
/// If `evals.len()` is not a power of two.
fn evals_01inf_grid_prefix<F: Field>(evals: &[F]) -> Vec<F> {
    /// Reverses the order of the lowest `num_digits` base-3 digits of `idx`.
    fn reverse_ternary_digits(mut idx: usize, num_digits: usize) -> usize {
        let mut rev = 0usize;
        for _ in 0..num_digits {
            rev = 3 * rev + (idx % 3);
            idx /= 3;
        }
        rev
    }

    let num_variables = log2_strict_usize(evals.len());
    let grid_len = 3usize.pow(num_variables as u32);
    let mut grid = F::zero_vec(grid_len);
    let mut scratch = F::zero_vec(grid_len);
    evals_01inf_grid_into(evals, &mut grid, &mut scratch);

    // After the call, `scratch` only holds intermediate values. Reuse it as the
    // destination of the digit-reversal permutation instead of allocating a third
    // buffer. The permutation is a bijection on `0..grid_len`, so every slot is written.
    for (src_idx, &value) in grid.iter().enumerate() {
        scratch[reverse_ternary_digits(src_idx, num_variables)] = value;
    }
    scratch
}
/// Builds the SVO accumulators for every round of a batched claim.
///
/// For round `r` there are `l = r + 1` active variables. The round-`r` partial
/// evaluations `opening.data()[r]` of all openings are combined with the weights
/// `alphas`, and the result is paired with the equality polynomial of the `l` active
/// SVO coordinates. For [`VariableOrder::Prefix`] these are the first `l` coordinates
/// of `z_svo`; for [`VariableOrder::Suffix`] they are the last `l`. Each round yields
/// `[acc_0, acc_inf]`, two vectors of length `3^r`.
///
/// # Panics
/// - if `alphas.len()` differs from `claim.len()`;
/// - if a partial evaluation does not have exactly `2^l` entries.
pub(crate) fn calculate_accumulators_batch<F: Field, EF: ExtensionField<F>>(
    claim: &ProverMultiClaim<F, EF>,
    alphas: &[EF],
) -> SvoAccumulators<EF> {
    assert_eq!(claim.len(), alphas.len());
    let svo_point = claim.point();
    let num_svo_vars = svo_point.num_variables_svo();
    let z_svo = svo_point.z_svo();

    (0..num_svo_vars)
        .map(|round_idx| {
            let num_active = round_idx + 1;

            // Random linear combination of this round's partial evaluations.
            let mut batched = Poly::<EF>::zero(num_active);
            for (opening, &alpha) in claim.openings().iter().zip(alphas) {
                let partial = &opening.data()[round_idx];
                for (dst, &value) in batched.as_mut_slice().iter_mut().zip_eq(partial.iter()) {
                    *dst += alpha * value;
                }
            }

            match svo_point.var_order() {
                VariableOrder::Prefix => {
                    let (active, _) = z_svo.split_at(num_active);
                    calculate_accumulator::<F, EF>(
                        num_active,
                        batched.as_slice(),
                        active.as_slice(),
                    )
                }
                VariableOrder::Suffix => {
                    let (_, active) = z_svo.split_at(num_svo_vars - num_active);
                    let eq_evals = Poly::new_from_point(active.as_slice(), EF::ONE);
                    // Digit-reversed grids put the first active variable (the one being
                    // bound this round) on the most significant ternary digit.
                    let eq_grid = evals_01inf_grid_prefix(eq_evals.as_slice());
                    let batched_grid = evals_01inf_grid_prefix(batched.as_slice());

                    let third = 3usize.pow(round_idx as u32);
                    let hadamard = |lhs: &[EF], rhs: &[EF]| -> Vec<EF> {
                        lhs.iter().zip(rhs).map(|(&a, &b)| a * b).collect()
                    };
                    [
                        hadamard(&eq_grid[..third], &batched_grid[..third]),
                        hadamard(&eq_grid[2 * third..], &batched_grid[2 * third..]),
                    ]
                }
            }
        })
        .collect()
}
/// Computes `[acc_0, acc_inf]` for `l` active variables.
///
/// `partial_evals` is a table over `total_vars >= l` variables, and `point` holds
/// the matching coordinates. The trailing `total_vars - l` coordinates of `point`
/// are summed out first: each contiguous chunk of `partial_evals` is dotted with their
/// equality polynomial. That leaves `2^l` values, which are combined with the equality
/// polynomial of the remaining leading coordinates. `l <= 3` dispatches to
/// hand-unrolled kernels.
fn calculate_accumulator<F: Field, EF: ExtensionField<F>>(
    l: usize,
    partial_evals: &[EF],
    point: &[EF],
) -> [Vec<EF>; 2] {
    let num_summed = log2_strict_usize(partial_evals.len()) - l;
    let (z_active, z_summed) = point.split_at(point.len() - num_summed);
    let eq_active = Poly::new_from_point(z_active, EF::ONE);
    let eq_summed = Poly::new_from_point(z_summed, EF::ONE);

    // Sum out the trailing variables; these occupy the low bits of each index.
    let reduced: Vec<EF> = partial_evals
        .chunks(eq_summed.num_evals())
        .map(|chunk| dot_product::<EF, _, _>(eq_summed.iter().copied(), chunk.iter().copied()))
        .collect();

    let eq_active = eq_active.as_slice();
    match l {
        1 => calculate_accumulator_1(eq_active, &reduced),
        2 => calculate_accumulator_2(eq_active, &reduced),
        3 => calculate_accumulator_3(eq_active, &reduced),
        _ => calculate_accumulator_general(l, eq_active, &reduced),
    }
}
/// Hand-unrolled `calculate_accumulator_general` for `l = 1`.
///
/// `eq0` and `reduced` hold values at `0` and `1`. Returns
/// `[[eq(0) * p(0)], [eq(∞) * p(∞)]]`, where `g(∞) = g(1) - g(0)`.
///
/// # Panics
/// If either slice does not have exactly 2 entries.
fn calculate_accumulator_1<EF: Field>(eq0: &[EF], reduced: &[EF]) -> [Vec<EF>; 2] {
    assert_eq!(eq0.len(), 2);
    assert_eq!(reduced.len(), 2);
    // Leading coefficient of the linear interpolant through the two values.
    let leading = |values: &[EF]| values[1] - values[0];
    let at_zero = eq0[0] * reduced[0];
    let at_inf = leading(eq0) * leading(reduced);
    [vec![at_zero], vec![at_inf]]
}
/// Hand-unrolled `calculate_accumulator_general` for `l = 2`.
///
/// In `g_xy`, `x` is the coordinate of the variable on index bit 0 and `y` the one
/// on bit 1 (so `eq0[1]` is `e10`). A digit `2` stands for `∞`, the difference of
/// the values at `1` and `0` in that coordinate. Returns `[acc_0, acc_inf]`: the
/// products `e * r` with `x = 0` and `x = ∞`, each listed for `y = 0, 1, ∞`.
///
/// # Panics
/// If either slice does not have exactly 4 entries.
fn calculate_accumulator_2<EF: Field>(eq0: &[EF], reduced: &[EF]) -> [Vec<EF>; 2] {
    assert_eq!(eq0.len(), 4);
    assert_eq!(reduced.len(), 4);
    let &[e00, e10, e01, e11] = eq0 else {
        unreachable!("length checked above")
    };
    let &[r00, r10, r01, r11] = reduced else {
        unreachable!("length checked above")
    };

    // Extend x to ∞ for y = 0 and y = 1.
    let (e20, e21) = (e10 - e00, e11 - e01);
    let (r20, r21) = (r10 - r00, r11 - r01);
    // Extend y to ∞ with x = 0.
    let (e02, r02) = (e01 - e00, r01 - r00);

    let with_x_zero = vec![e00 * r00, e01 * r01, e02 * r02];
    let with_x_inf = vec![e20 * r20, e21 * r21, (e21 - e20) * (r21 - r20)];
    [with_x_zero, with_x_inf]
}
/// Hand-unrolled `calculate_accumulator_general` for `l = 3`.
///
/// Naming: in `e_xyz` / `r_xyz` (values derived from `eq0` / `reduced`), `x`, `y` and `z`
/// are the coordinates of the variables on index bits 0, 1 and 2 (so `eq0[1]` is
/// `e_100`). A digit `2` stands for `∞`, the difference between the values at `1` and
/// `0` in that coordinate.
///
/// Returns `[acc_0, acc_inf]`: the products `e * r` with `x = 0` and with `x = ∞`, each
/// listing `(y, z)` over `{0, 1, ∞}^2` with `y` major and `z` minor.
///
/// # Panics
/// If either slice does not have exactly 8 entries.
fn calculate_accumulator_3<EF: Field>(eq0: &[EF], reduced: &[EF]) -> [Vec<EF>; 2] {
assert_eq!(eq0.len(), 8);
assert_eq!(reduced.len(), 8);
// Boolean corners, in index order (bit 0 = x, bit 1 = y, bit 2 = z).
let (e_000, e_100, e_010, e_110, e_001, e_101, e_011, e_111) = (
eq0[0], eq0[1], eq0[2], eq0[3], eq0[4], eq0[5], eq0[6], eq0[7],
);
let (r_000, r_100, r_010, r_110, r_001, r_101, r_011, r_111) = (
reduced[0], reduced[1], reduced[2], reduced[3], reduced[4], reduced[5], reduced[6],
reduced[7],
);
// Extend x to ∞ for every Boolean (y, z).
let e_200 = e_100 - e_000;
let e_210 = e_110 - e_010;
let e_201 = e_101 - e_001;
let e_211 = e_111 - e_011;
let r_200 = r_100 - r_000;
let r_210 = r_110 - r_010;
let r_201 = r_101 - r_001;
let r_211 = r_111 - r_011;
// Extend y to ∞ for x in {0, ∞} and Boolean z.
let e_020 = e_010 - e_000;
let e_220 = e_210 - e_200;
let e_021 = e_011 - e_001;
let e_221 = e_211 - e_201;
let r_020 = r_010 - r_000;
let r_220 = r_210 - r_200;
let r_021 = r_011 - r_001;
let r_221 = r_211 - r_201;
// Extend z to ∞ for every (x, y) with x in {0, ∞}.
let e_002 = e_001 - e_000;
let e_012 = e_011 - e_010;
let e_022 = e_021 - e_020;
let e_202 = e_201 - e_200;
let e_212 = e_211 - e_210;
let e_222 = e_221 - e_220;
let r_002 = r_001 - r_000;
let r_012 = r_011 - r_010;
let r_022 = r_021 - r_020;
let r_202 = r_201 - r_200;
let r_212 = r_211 - r_210;
let r_222 = r_221 - r_220;
[
// acc_0: x = 0.
vec![
e_000 * r_000,
e_001 * r_001,
e_002 * r_002,
e_010 * r_010,
e_011 * r_011,
e_012 * r_012,
e_020 * r_020,
e_021 * r_021,
e_022 * r_022,
],
// acc_inf: x = ∞.
vec![
e_200 * r_200,
e_201 * r_201,
e_202 * r_202,
e_210 * r_210,
e_211 * r_211,
e_212 * r_212,
e_220 * r_220,
e_221 * r_221,
e_222 * r_222,
],
]
}
/// Generic `[acc_0, acc_inf]` for `l >= 1` active variables.
///
/// Expands both `eq0` and `reduced_evals` to `{0, 1, ∞}^l` with
/// `evals_01inf_grid_into` and multiplies them pointwise. It returns the two slices
/// where the most significant ternary digit, which belongs to Boolean index bit 0,
/// is `0` (`acc_0`) and `∞` (`acc_inf`). Each slice has length `3^(l-1)`.
fn calculate_accumulator_general<F: Field, EF: ExtensionField<F>>(
    l: usize,
    eq0: &[EF],
    reduced_evals: &[EF],
) -> [Vec<EF>; 2] {
    let grid_len = 3usize.pow(l as u32);
    let mut scratch = EF::zero_vec(grid_len);
    let mut to_grid = |boolean_evals: &[EF]| -> Vec<EF> {
        let mut grid = EF::zero_vec(grid_len);
        evals_01inf_grid_into(boolean_evals, &mut grid, &mut scratch);
        grid
    };
    let eq_grid = to_grid(eq0);
    let poly_grid = to_grid(reduced_evals);

    let stride = 3usize.pow((l - 1) as u32);
    let hadamard = |lhs: &[EF], rhs: &[EF]| -> Vec<EF> {
        lhs.iter().zip(rhs).map(|(&a, &b)| a * b).collect()
    };
    let lead_zero = hadamard(&eq_grid[..stride], &poly_grid[..stride]);
    let lead_inf = hadamard(&eq_grid[2 * stride..], &poly_grid[2 * stride..]);
    [lead_zero, lead_inf]
}
/// Per-round SVO accumulators: entry `r` is the pair `[acc_0, acc_inf]` produced
/// for round `r` by `calculate_accumulators_batch`.
pub(super) type SvoAccumulators<EF> = Vec<[Vec<EF>; 2]>;
/// An evaluation point split into a block of SVO coordinates (`z_svo`) and the
/// remaining coordinates, held as a [`SplitEq`] (`z_split`).
#[derive(Debug, Clone)]
pub struct SvoPoint<F: Field, EF: ExtensionField<F>> {
/// The SVO coordinates: the first (`Prefix`) or last (`Suffix`) `l0` coordinates
/// of the original point.
pub(crate) z_svo: Point<EF>,
/// The coordinates not in `z_svo`, wrapped in a `SplitEq`.
pub(crate) z_split: SplitEq<F, EF>,
/// Which end of the original point `z_svo` was taken from.
var_order: VariableOrder,
}
impl<F: Field, EF: ExtensionField<F>> SvoPoint<F, EF> {
    /// Splits `point` into `l0` SVO coordinates and the remaining coordinates, which are
    /// stored as an unpacked [`SplitEq`].
    ///
    /// With [`VariableOrder::Prefix`] the SVO coordinates are the first `l0` entries of
    /// `point`; with [`VariableOrder::Suffix`] they are the last `l0`.
    ///
    /// # Panics
    /// If `l0` exceeds the number of variables of `point`.
    pub fn new_unpacked(l0: usize, point: &Point<EF>, var_order: VariableOrder) -> Self {
        assert!(l0 <= point.num_variables());
        let (z_svo, remainder) = match var_order {
            VariableOrder::Prefix => point.split_at(l0),
            VariableOrder::Suffix => {
                let (head, tail) = point.split_at(point.num_variables() - l0);
                (tail, head)
            }
        };
        Self {
            z_svo,
            z_split: SplitEq::new_unpacked(&remainder, EF::ONE),
            var_order,
        }
    }

    /// Splits `point` into its first `l0` coordinates (SVO) and the rest, which are
    /// stored as a packed [`SplitEq`]. The variable order is always
    /// [`VariableOrder::Prefix`].
    ///
    /// # Panics
    /// If `l0` exceeds the number of variables of `point`.
    pub fn new_packed(l0: usize, point: &Point<EF>) -> Self {
        assert!(l0 <= point.num_variables());
        let (z_svo, remainder) = point.split_at(l0);
        Self {
            z_svo,
            z_split: SplitEq::new_packed(&remainder, EF::ONE),
            var_order: VariableOrder::Prefix,
        }
    }

    /// Scales `scale` by `eq(z_svo, rs)` and forwards to
    /// [`SplitEq::accumulate_into`] on the non-SVO coordinates.
    ///
    /// # Panics
    /// If `rs` does not have exactly `num_variables_svo()` coordinates.
    pub fn accumulate_into(&self, out: &mut [EF], rs: &Point<EF>, scale: EF) {
        assert_eq!(rs.num_variables(), self.num_variables_svo());
        let weight = scale * Point::eval_eq(self.z_svo.as_slice(), rs.as_slice());
        self.z_split.accumulate_into(out, Some(weight));
    }

    /// Packed counterpart of [`Self::accumulate_into`].
    ///
    /// # Panics
    /// - if the variable order is not [`VariableOrder::Prefix`];
    /// - if `rs` does not have exactly `num_variables_svo()` coordinates.
    pub fn accumulate_into_packed(
        &self,
        out: &mut [EF::ExtensionPacking],
        rs: &Point<EF>,
        scale: EF,
    ) {
        assert!(matches!(self.var_order, VariableOrder::Prefix));
        assert_eq!(rs.num_variables(), self.num_variables_svo());
        let weight = scale * Point::eval_eq(self.z_svo.as_slice(), rs.as_slice());
        self.z_split.accumulate_into_packed(out, Some(weight));
    }

    /// Evaluates `poly` at this point and also returns the SVO partial evaluations.
    ///
    /// Entry `i` of the returned vector is `poly` with every coordinate fixed except
    /// the `i + 1` active SVO coordinates. These are the first `i + 1` of `z_svo` for
    /// `Prefix` and the last `i + 1` for `Suffix`.
    ///
    /// # Panics
    /// If `poly` does not have `num_variables()` variables.
    pub fn eval(&self, poly: &Poly<F>) -> (EF, Vec<Poly<EF>>) {
        assert_eq!(self.num_variables(), poly.num_variables());
        let num_svo = self.num_variables_svo();

        // First bind all non-SVO coordinates, leaving a table over the SVO ones.
        let compressed = match self.var_order {
            VariableOrder::Prefix => self.z_split.compress_suffix(poly),
            VariableOrder::Suffix => self.z_split.compress_prefix(poly),
        };

        let partial_evals: Vec<Poly<EF>> = (1..=num_svo)
            .map(|num_active| match self.var_order {
                VariableOrder::Prefix => {
                    let (_, inactive) = self.z_svo.split_at(num_active);
                    compressed.compress_suffix(&inactive, EF::ONE)
                }
                VariableOrder::Suffix => {
                    let (inactive, _) = self.z_svo.split_at(num_svo - num_active);
                    compressed.compress_prefix(&inactive, EF::ONE)
                }
            })
            .collect();

        (compressed.eval_base(&self.z_svo), partial_evals)
    }

    /// Number of SVO coordinates.
    pub const fn num_variables_svo(&self) -> usize {
        self.z_svo.num_variables()
    }

    /// Total number of coordinates (SVO plus split-eq).
    pub const fn num_variables(&self) -> usize {
        self.z_svo.num_variables() + self.z_split.num_variables()
    }

    /// The SVO coordinates.
    pub const fn z_svo(&self) -> &Point<EF> {
        &self.z_svo
    }

    /// The non-SVO coordinates as a [`SplitEq`].
    pub const fn z_split(&self) -> &SplitEq<F, EF> {
        &self.z_split
    }

    /// Which end of the original point the SVO coordinates came from.
    pub const fn var_order(&self) -> VariableOrder {
        self.var_order
    }
}
#[cfg(test)]
mod test {
    use alloc::vec;
    use alloc::vec::Vec;

    use p3_field::extension::BinomialExtensionField;
    use p3_field::{PackedFieldExtension, PackedValue, PrimeCharacteristicRing, dot_product};
    use p3_koala_bear::KoalaBear;
    use proptest::prelude::*;
    use rand::rngs::SmallRng;
    use rand::{RngExt, SeedableRng};

    use super::*;
    use crate::lagrange::lagrange_weights_01inf_multi;
    use crate::layout::Opening;
    use crate::strategy::VariableOrder;

    type F = KoalaBear;
    type EF = BinomialExtensionField<F, 4>;

    /// Allocating wrapper around `evals_01inf_grid_into`.
    fn evals_01inf_grid(boolean_evals: &[EF]) -> Vec<EF> {
        let num_variables = log2_strict_usize(boolean_evals.len());
        let output_len = 3usize.pow(num_variables as u32);
        let mut output = EF::zero_vec(output_len);
        let mut scratch = EF::zero_vec(output_len);
        evals_01inf_grid_into(boolean_evals, &mut output, &mut scratch);
        output
    }

    /// Point-by-point reference for `evals_01inf_grid_into`.
    ///
    /// Same layout: bit `v` of a Boolean index is the base-3 digit of weight
    /// `3^(n-1-v)`, and digit 2 stands for ∞. At ∞ a multilinear polynomial yields its
    /// leading coefficient in that variable, so each grid value is a signed sum of
    /// Boolean evaluations.
    fn evals_01inf_grid_naive(boolean_evals: &[EF]) -> Vec<EF> {
        let num_variables = log2_strict_usize(boolean_evals.len());
        (0..3usize.pow(num_variables as u32))
            .map(|grid_idx| {
                // digits[v] is the {0, 1, 2 = ∞} coordinate of variable v.
                let digits: Vec<usize> = (0..num_variables)
                    .map(|v| grid_idx / 3usize.pow((num_variables - 1 - v) as u32) % 3)
                    .collect();
                boolean_evals
                    .iter()
                    .enumerate()
                    .filter_map(|(bool_idx, &value)| {
                        digits
                            .iter()
                            .enumerate()
                            .try_fold(value, |acc, (v, &digit)| {
                                match (digit, (bool_idx >> v) & 1) {
                                    (0, 0) | (1, 1) | (2, 1) => Some(acc),
                                    (2, 0) => Some(-acc),
                                    _ => None,
                                }
                            })
                    })
                    .sum::<EF>()
            })
            .collect()
    }

    /// Grid index of the Boolean point `bool_idx`: bit `v` becomes the base-3 digit of
    /// weight `3^(num_variables-1-v)`.
    fn boolean_to_grid_index(bool_idx: usize, num_variables: usize) -> usize {
        (0..num_variables).fold(0, |acc, v| 3 * acc + ((bool_idx >> v) & 1))
    }

    /// Checks the fast grid expansion against the point-by-point reference.
    fn assert_evals_01inf_grid_correct(boolean_evals: &[EF]) {
        let num_variables = log2_strict_usize(boolean_evals.len());
        assert_eq!(
            evals_01inf_grid(boolean_evals),
            evals_01inf_grid_naive(boolean_evals),
            "grid mismatch against naive reference, num_variables={num_variables}"
        );
    }

    #[test]
    fn test_evals_01inf_grid_into_zero_vars() {
        let c = EF::from_u32(42);
        let input = [c];
        let mut output = [EF::ZERO];
        let mut scratch = [EF::ZERO];
        evals_01inf_grid_into(&input, &mut output, &mut scratch);
        assert_eq!(output, [c]);
    }

    #[test]
    fn test_evals_01inf_grid_into_one_var() {
        let f0 = EF::from_u32(3);
        let f1 = EF::from_u32(7);
        let input = [f0, f1];
        let mut output = [EF::ZERO; 3];
        let mut scratch = [EF::ZERO; 3];
        evals_01inf_grid_into(&input, &mut output, &mut scratch);
        assert_eq!(output[0], f0);
        assert_eq!(output[1], f1);
        assert_eq!(output[2], EF::from_u32(4));
    }

    #[test]
    fn test_evals_01inf_grid_into_two_vars_hand_computed() {
        let input = [1, 3, 5, 11].map(EF::from_u32);
        let mut output = [EF::ZERO; 9];
        let mut scratch = [EF::ZERO; 9];
        evals_01inf_grid_into(&input, &mut output, &mut scratch);
        let expected = [1, 5, 4, 3, 11, 8, 2, 6, 4].map(EF::from_u32);
        assert_eq!(output, expected);
        // Anchor the naive reference to the same hand-computed values.
        assert_eq!(evals_01inf_grid_naive(&input), expected.to_vec());
    }

    #[test]
    fn test_evals_01inf_grid_into_output_size() {
        for num_variables in 1..=5 {
            let input_len = 1 << num_variables;
            let output_len = 3usize.pow(num_variables as u32);
            let input = EF::zero_vec(input_len);
            // Non-zero sentinels: every output slot must be overwritten.
            let mut output = vec![EF::ONE; output_len];
            let mut scratch = vec![EF::ONE; output_len];
            evals_01inf_grid_into(&input, &mut output, &mut scratch);
            assert!(
                output.iter().all(|&v| v == EF::ZERO),
                "stale output data for num_variables={num_variables}"
            );
        }
    }

    #[test]
    fn test_evals_01inf_grid_into_result_lands_in_output() {
        let mut rng = SmallRng::seed_from_u64(123);
        for num_variables in 1..=4 {
            let input: Vec<EF> = (0..1 << num_variables).map(|_| rng.random()).collect();
            let output_len = 3usize.pow(num_variables as u32);
            // Garbage in both buffers: the result must not depend on prior contents.
            let mut output: Vec<EF> = (0..output_len).map(|_| rng.random()).collect();
            let mut scratch: Vec<EF> = (0..output_len).map(|_| rng.random()).collect();
            evals_01inf_grid_into(&input, &mut output, &mut scratch);
            assert_eq!(
                output,
                evals_01inf_grid_naive(&input),
                "ping-pong mismatch for num_variables={num_variables}"
            );
        }
    }

    #[test]
    fn test_evals_01inf_grid_into_preserves_boolean_points() {
        let mut rng = SmallRng::seed_from_u64(77);
        for num_variables in 1..=4 {
            let input: Vec<EF> = (0..1 << num_variables).map(|_| rng.random()).collect();
            let grid = evals_01inf_grid(&input);
            for (bool_idx, &input_val) in input.iter().enumerate() {
                assert_eq!(
                    grid[boolean_to_grid_index(bool_idx, num_variables)],
                    input_val,
                    "Boolean point mismatch at bool_idx={bool_idx}, num_variables={num_variables}"
                );
            }
        }
    }

    #[test]
    fn test_evals_01inf_grid_into_constant_polynomial() {
        let c = EF::from_u32(99);
        for num_variables in 0..=4 {
            let input = vec![c; 1 << num_variables];
            let output_len = 3usize.pow(num_variables as u32);
            let mut output = EF::zero_vec(output_len);
            let mut scratch = EF::zero_vec(output_len);
            evals_01inf_grid_into(&input, &mut output, &mut scratch);
            for (idx, &val) in output.iter().enumerate() {
                let has_inf = {
                    let mut tmp = idx;
                    let mut found = false;
                    for _ in 0..num_variables {
                        if tmp % 3 == 2 {
                            found = true;
                        }
                        tmp /= 3;
                    }
                    found
                };
                let expected = if has_inf { EF::ZERO } else { c };
                assert_eq!(
                    val, expected,
                    "constant polynomial mismatch at idx={idx}, num_variables={num_variables}"
                );
            }
        }
    }

    #[test]
    fn test_evals_01inf_grid_into_linearity() {
        let mut rng = SmallRng::seed_from_u64(55);
        let num_variables = 3;
        let n = 1 << num_variables;
        let f: Vec<EF> = (0..n).map(|_| rng.random()).collect();
        let g: Vec<EF> = (0..n).map(|_| rng.random()).collect();
        let a: EF = rng.random();
        let b: EF = rng.random();
        let combined: Vec<EF> = f
            .iter()
            .zip(g.iter())
            .map(|(&fi, &gi)| a * fi + b * gi)
            .collect();
        let grid_combined = evals_01inf_grid(&combined);
        let grid_f = evals_01inf_grid(&f);
        let grid_g = evals_01inf_grid(&g);
        let linear_combined: Vec<EF> = grid_f
            .iter()
            .zip(grid_g.iter())
            .map(|(&fi, &gi)| a * fi + b * gi)
            .collect();
        assert_eq!(grid_combined, linear_combined);
    }

    #[test]
    fn test_evals_01inf_grid_into_large_stride_branch_matches_naive() {
        // 7 variables reach stride 3^6 = 729, above the parallel-stride threshold.
        let num_variables = 7;
        let mut rng = SmallRng::seed_from_u64(2025);
        let evals: Vec<EF> = (0..1 << num_variables).map(|_| rng.random()).collect();
        assert_evals_01inf_grid_correct(evals.as_slice());
    }

    #[test]
    #[should_panic(expected = "Not a power of two")]
    fn test_evals_01inf_grid_into_panics_on_non_power_of_two_input() {
        let input = [EF::ZERO; 3];
        let mut output = [EF::ZERO; 3];
        let mut scratch = [EF::ZERO; 3];
        evals_01inf_grid_into(&input, &mut output, &mut scratch);
    }

    #[test]
    #[should_panic(expected = "assertion `left == right` failed")]
    fn test_evals_01inf_grid_into_panics_on_wrong_output_len() {
        let input = [EF::ZERO; 4];
        let mut output = [EF::ZERO; 8];
        let mut scratch = [EF::ZERO; 9];
        evals_01inf_grid_into(&input, &mut output, &mut scratch);
    }

    #[test]
    #[should_panic(expected = "assertion `left == right` failed")]
    fn test_evals_01inf_grid_into_panics_on_wrong_scratch_len() {
        let input = [EF::ZERO; 4];
        let mut output = [EF::ZERO; 9];
        let mut scratch = [EF::ZERO; 8];
        evals_01inf_grid_into(&input, &mut output, &mut scratch);
    }

    #[test]
    fn test_batch_svo_accumulators() {
        let k = 12;
        let n_polys = 3;
        let mut rng = SmallRng::seed_from_u64(0);
        let polys = (0..n_polys)
            .map(|_| Poly::<F>::rand(&mut rng, k))
            .collect::<Vec<_>>();
        let alphas = (0..polys.len())
            .map(|_| rng.random::<EF>())
            .collect::<Vec<_>>();
        let point = Point::<EF>::rand(&mut rng, k);
        // Both orders share the same checks; only the side from which sumcheck
        // variables are bound differs.
        for var_order in [VariableOrder::Suffix, VariableOrder::Prefix] {
            for l0 in 0..=k / 2 {
                let svo_point = SvoPoint::<F, EF>::new_unpacked(l0, &point, var_order);
                let openings = polys
                    .iter()
                    .map(|poly| {
                        let (eval, partial_evals) = svo_point.eval(poly);
                        let opening = Opening {
                            poly_idx: None,
                            eval,
                            data: partial_evals,
                        };
                        assert_eq!(opening.eval(), poly.eval_base(&point));
                        opening
                    })
                    .collect::<Vec<_>>();
                let claim = ProverMultiClaim::new(svo_point, openings);
                let accumulators = calculate_accumulators_batch(&claim, &alphas);
                if l0 == 0 {
                    assert!(accumulators.is_empty());
                    continue;
                }
                // Reference: batched full-SVO polynomial and its eq table.
                let mut poly = Poly::<EF>::zero(l0);
                claim
                    .openings()
                    .iter()
                    .zip(alphas.iter())
                    .for_each(|(opening, &alpha)| {
                        let full_svo_poly = opening
                            .data()
                            .last()
                            .expect("l0 > 0 guarantees one SVO partial polynomial");
                        poly.as_mut_slice()
                            .iter_mut()
                            .zip(full_svo_poly.iter())
                            .for_each(|(out, &value)| *out += alpha * value);
                    });
                let mut eq = Poly::new_from_point(claim.point().z_svo().as_slice(), EF::ONE);
                let mut rs = Vec::with_capacity(l0);
                for [acc0, acc_inf] in accumulators.iter() {
                    let weights = lagrange_weights_01inf_multi(rs.as_slice());
                    let c0 =
                        dot_product::<EF, _, _>(acc0.iter().copied(), weights.iter().copied());
                    let cinf =
                        dot_product::<EF, _, _>(acc_inf.iter().copied(), weights.iter().copied());
                    let (c0_ref, cinf_ref) =
                        var_order.sumcheck_coefficients(poly.as_slice(), eq.as_slice());
                    assert_eq!(c0, c0_ref);
                    assert_eq!(cinf, cinf_ref);
                    let r: EF = rng.random();
                    match var_order {
                        VariableOrder::Prefix => {
                            poly.fix_prefix_var_mut(r);
                            eq.fix_prefix_var_mut(r);
                        }
                        VariableOrder::Suffix => {
                            poly.fix_suffix_var_mut(r);
                            eq.fix_suffix_var_mut(r);
                        }
                    }
                    rs.push(r);
                }
            }
        }
    }

    proptest! {
        #[test]
        fn prop_evals_01inf_grid_matches_naive(num_variables in 1usize..=5) {
            let mut rng = SmallRng::seed_from_u64(num_variables as u64);
            let evals: Vec<EF> = (0..1 << num_variables).map(|_| rng.random()).collect();
            assert_evals_01inf_grid_correct(evals.as_slice());
        }

        #[test]
        fn prop_accumulators_specialization_matches_general(k in 10usize..=14) {
            let mut rng = SmallRng::seed_from_u64(k as u64);
            let poly = Poly::new((0..1 << k).map(|_| rng.random()).collect());
            let point = Point::<EF>::rand(&mut rng, k);
            let (z_svo, z_rest) = point.split_at(k / 2);
            let split_eq = SplitEq::<F, EF>::new_packed(&z_rest, EF::ONE);
            let partial_evals = split_eq.compress_suffix(&poly);
            for l in 1..k / 2 {
                let dispatched =
                    calculate_accumulator::<F, EF>(l, partial_evals.as_slice(), z_svo.as_slice());
                let eq0 = Poly::new_from_point(&z_svo.as_slice()[..l], EF::ONE);
                let eq1 = Poly::new_from_point(&z_svo.as_slice()[l..], EF::ONE);
                let reduced: Vec<EF> = partial_evals
                    .as_slice()
                    .chunks(eq1.num_evals())
                    .map(|chunk| dot_product::<EF, _, _>(eq1.iter().copied(), chunk.iter().copied()))
                    .collect();
                let general = calculate_accumulator_general::<F, EF>(l, eq0.as_slice(), &reduced);
                prop_assert_eq!(dispatched, general);
            }
        }

        #[test]
        fn prop_accumulators_1_matches_general(seed in 0u64..1000) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let eq0: Vec<EF> = (0..2).map(|_| rng.random()).collect();
            let reduced: Vec<EF> = (0..2).map(|_| rng.random()).collect();
            let fast = calculate_accumulator_1(&eq0, &reduced);
            let general = calculate_accumulator_general::<F, EF>(1, &eq0, &reduced);
            prop_assert_eq!(fast, general);
        }

        #[test]
        fn prop_accumulators_2_matches_general(seed in 0u64..1000) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let eq0: Vec<EF> = (0..4).map(|_| rng.random()).collect();
            let reduced: Vec<EF> = (0..4).map(|_| rng.random()).collect();
            let fast = calculate_accumulator_2(&eq0, &reduced);
            let general = calculate_accumulator_general::<F, EF>(2, &eq0, &reduced);
            prop_assert_eq!(fast, general);
        }

        #[test]
        fn prop_accumulators_3_matches_general(seed in 0u64..1000) {
            let mut rng = SmallRng::seed_from_u64(seed);
            let eq0: Vec<EF> = (0..8).map(|_| rng.random()).collect();
            let reduced: Vec<EF> = (0..8).map(|_| rng.random()).collect();
            let fast = calculate_accumulator_3(&eq0, &reduced);
            let general = calculate_accumulator_general::<F, EF>(3, &eq0, &reduced);
            prop_assert_eq!(fast, general);
        }

        #[test]
        fn prop_evals_01inf_grid_preserves_boolean_points(num_variables in 1usize..=6) {
            let mut rng = SmallRng::seed_from_u64(num_variables as u64 + 1000);
            let input: Vec<EF> = (0..1 << num_variables).map(|_| rng.random()).collect();
            let grid = evals_01inf_grid(&input);
            for (bool_idx, &input_val) in input.iter().enumerate() {
                prop_assert_eq!(grid[boolean_to_grid_index(bool_idx, num_variables)], input_val);
            }
        }
    }

    #[test]
    fn test_svo_point_eval() {
        let assert_eval = |svo_point: &SvoPoint<F, EF>, poly: &Poly<F>, point: &Point<EF>| {
            let e0 = poly.eval_base(point);
            let (e1, partial_evals) = svo_point.eval(poly);
            assert_eq!(e0, e1);
            assert_eq!(partial_evals.len(), svo_point.num_variables_svo());
            match svo_point.var_order() {
                VariableOrder::Prefix => {
                    partial_evals.iter().enumerate().for_each(|(i, pe0)| {
                        let (_point_lo, point_hi) = point.split_at(i + 1);
                        assert_eq!(pe0, &poly.compress_suffix(&point_hi, EF::ONE));
                        assert_eq!(e0, pe0.eval_base(&svo_point.z_svo().split_at(i + 1).0));
                    });
                }
                VariableOrder::Suffix => {
                    partial_evals.iter().enumerate().for_each(|(i, pe0)| {
                        let (point_lo, point_hi) = point.split_at(point.num_variables() - i - 1);
                        assert_eq!(pe0, &poly.compress_prefix(&point_lo, EF::ONE));
                        assert_eq!(e0, pe0.eval_base(&point_hi));
                    });
                }
            }
        };
        let k = 12;
        let mut rng = SmallRng::seed_from_u64(11);
        let poly = Poly::<F>::rand(&mut rng, k);
        let point = Point::<EF>::rand(&mut rng, k);
        for l0 in 0..=k {
            let unpacked_prefix =
                SvoPoint::<F, EF>::new_unpacked(l0, &point, VariableOrder::Prefix);
            assert_eval(&unpacked_prefix, &poly, &point);
        }
        for l0 in 0..=k {
            let unpacked_suffix =
                SvoPoint::<F, EF>::new_unpacked(l0, &point, VariableOrder::Suffix);
            assert_eval(&unpacked_suffix, &poly, &point);
        }
        for l0 in 0..=k {
            let packed_prefix = SvoPoint::<F, EF>::new_packed(l0, &point);
            assert_eval(&packed_prefix, &poly, &point);
        }
    }

    #[test]
    fn test_svo_point_accumulate() {
        type PackedEF = <EF as ExtensionField<F>>::ExtensionPacking;
        let mut rng = SmallRng::seed_from_u64(0);
        let assert_accumulate_unpacked =
            |svo_point: &SvoPoint<F, EF>, point: &Point<EF>, scale: EF, rs: &Point<EF>| {
                let eq = Poly::new_from_point(point.as_slice(), EF::ONE);
                let expected = match svo_point.var_order() {
                    VariableOrder::Prefix => eq.compress_prefix(rs, scale),
                    VariableOrder::Suffix => eq.compress_suffix(rs, scale),
                };
                let mut out = Poly::<EF>::zero(expected.num_variables());
                svo_point.accumulate_into(out.as_mut_slice(), rs, scale);
                assert_eq!(out, expected);
            };
        let assert_accumulate_packed =
            |svo_point: &SvoPoint<F, EF>, point: &Point<EF>, scale: EF, rs: &Point<EF>| {
                let eq = Poly::new_from_point(point.as_slice(), EF::ONE);
                let expected = eq.compress_prefix(rs, scale);
                let k_pack = log2_strict_usize(<<F as Field>::Packing as PackedValue>::WIDTH);
                assert!(expected.num_variables() >= k_pack);
                let mut out = Poly::<PackedEF>::zero(expected.num_variables() - k_pack);
                svo_point.accumulate_into_packed(out.as_mut_slice(), rs, scale);
                let unpacked =
                    <PackedEF as PackedFieldExtension<F, EF>>::to_ext_iter(out.iter().copied())
                        .take(expected.num_evals())
                        .collect::<Vec<_>>();
                assert_eq!(unpacked, expected.as_slice());
            };
        let k = 12;
        let k_pack = log2_strict_usize(<<F as Field>::Packing as PackedValue>::WIDTH);
        let point = Point::<EF>::rand(&mut rng, k);
        let scale: EF = rng.random();
        for l0 in 0..=k {
            let unpacked_prefix =
                SvoPoint::<F, EF>::new_unpacked(l0, &point, VariableOrder::Prefix);
            assert_eq!(unpacked_prefix.var_order(), VariableOrder::Prefix);
            assert_eq!(unpacked_prefix.num_variables(), k);
            assert_eq!(unpacked_prefix.num_variables_svo(), l0);
            assert_accumulate_unpacked(&unpacked_prefix, &point, scale, &Point::rand(&mut rng, l0));
        }
        for l0 in 0..=k {
            let unpacked_suffix =
                SvoPoint::<F, EF>::new_unpacked(l0, &point, VariableOrder::Suffix);
            assert_eq!(unpacked_suffix.var_order(), VariableOrder::Suffix);
            assert_eq!(unpacked_suffix.num_variables(), k);
            assert_eq!(unpacked_suffix.num_variables_svo(), l0);
            assert_accumulate_unpacked(&unpacked_suffix, &point, scale, &Point::rand(&mut rng, l0));
        }
        for l0 in 0..=k {
            if k - l0 >= k_pack {
                let packed_prefix = SvoPoint::<F, EF>::new_packed(l0, &point);
                assert_eq!(packed_prefix.var_order(), VariableOrder::Prefix);
                assert_eq!(packed_prefix.num_variables(), k);
                assert_eq!(packed_prefix.num_variables_svo(), l0);
                assert_accumulate_packed(&packed_prefix, &point, scale, &Point::rand(&mut rng, l0));
            }
        }
    }
}