use alloc::vec::Vec;
use p3_field::{ExtensionField, Field};
/// Round-independent parameters used to assemble masked sumcheck round polynomials.
#[derive(Debug, Clone, Copy)]
pub(super) struct RoundContext<'a, F, EF> {
    /// Index of the final round; `assemble` expects round indices `j` in `1..=k`.
    pub k: usize,
    /// Number of coefficient slots reserved for the per-round mask polynomial.
    /// The assembled polynomial has `max(ell_zk, 3)` coefficients.
    pub ell_zk: usize,
    /// Table where `pow2[i]` is expected to equal `2^i`. `assemble` reads index
    /// `k - j` and, when `j < k`, index `k - j - 1`.
    pub pow2: &'a [F],
    /// Scale applied to both transmitted coefficients of the plain (unmasked) piece.
    pub eps: EF,
}

impl<F, EF> RoundContext<'_, F, EF>
where
    F: Field,
    EF: ExtensionField<F>,
{
    /// Builds the coefficient vector (lowest degree first) of the round-`j` polynomial.
    ///
    /// With `live = pow2[k - j]` and `future = pow2[k - j - 1]`:
    ///
    /// * `h[i] += live * mask[i]` for every live mask coefficient;
    /// * `h[0] += live * sum(past_mask_evals)`;
    /// * `h[0] += future * future_endpoints`, only when `j < k`;
    /// * `h[0] += eps * plain.c0` and `h[2] += eps * plain.c_inf`.
    ///
    /// The plain piece adds nothing to the linear slot `h[1]`; that slot carries
    /// only the live mask's linear coefficient.
    ///
    /// The result has `max(ell_zk, 3)` entries. In debug builds it is checked
    /// against an independently derived value of `h(0) + h(1)`.
    ///
    /// # Panics
    ///
    /// * if `state.j > k`, or if `pow2` has no entry at index `k - j`;
    /// * if `state.mask` has more than `max(ell_zk, 3)` coefficients;
    /// * in debug builds only: if `state.j == 0`, or if the consistency check fails.
    pub(super) fn assemble(
        &self,
        state: RoundState<'_, F, EF>,
        plain: PlainPiece<EF>,
    ) -> Vec<EF> {
        let h_size = self.ell_zk.max(3);
        debug_assert!(
            (1..=self.k).contains(&state.j),
            "round index {} is outside 1..={}",
            state.j,
            self.k,
        );
        debug_assert!(
            state.mask.len() <= h_size,
            "mask has {} coefficients but the round polynomial has only {} slots",
            state.mask.len(),
            h_size,
        );

        let mut h: Vec<EF> = EF::zero_vec(h_size);
        let mult_live = self.pow2[self.k - state.j];
        // Indexing (rather than zipping) keeps an oversized mask a hard panic in
        // release builds instead of silently truncating it.
        for (i, &c) in state.mask.iter().enumerate() {
            h[i] += mult_live * c;
        }
        let past_mask_sum: EF = state.past_mask_evals.iter().copied().sum();
        h[0] += past_mask_sum * mult_live;
        if state.j < self.k {
            let mult_future = self.pow2[self.k - state.j - 1];
            h[0] += mult_future * state.future_endpoints;
        }
        h[0] += self.eps * plain.c0;
        h[2] += self.eps * plain.c_inf;

        #[cfg(debug_assertions)]
        {
            // Re-derive h(0) + h(1) = 2*h[0] + h[1] + h[2] + ... from the inputs.
            // The constant past term appears at both endpoints, so it carries
            // 2 * live. Doubling `mult_live` (instead of reading `pow2[k - j + 1]`)
            // keeps this check from needing a table entry the release path never uses.
            let mult_past = mult_live.double();
            // Live mask g(X) = sum_i mask[i] X^i, so g(0) = mask[0] and g(1) = sum(mask).
            // `first()` keeps an empty mask from panicking in debug builds only.
            let mask_at_zero = state.mask.first().copied().unwrap_or(F::ZERO);
            let mask_at_one: F = state.mask.iter().copied().sum();
            let s_j_endpoints = mask_at_zero + mask_at_one;
            // Endpoint sum of c0 + c_inf * X^2; the plain linear coefficient is never
            // added to `h`, so it is excluded here as well.
            let plain_transmitted = plain.c0.double() + plain.c_inf;
            // Each product intentionally pairs a different scale with a different
            // term; the lint's mismatched-operand heuristic is a false positive here.
            #[allow(clippy::suspicious_operation_groupings)]
            let mut expected: EF = self.eps * plain_transmitted
                + past_mask_sum * mult_past
                + mult_live * s_j_endpoints;
            if state.j < self.k {
                // Future endpoints also count at both endpoints of h; this relies on
                // 2 * pow2[k - j - 1] == pow2[k - j].
                expected += mult_live * state.future_endpoints;
            }
            debug_assert_eq!(
                h[0].double() + h[1..].iter().copied().sum::<EF>(),
                expected,
                "round polynomial affine consistency check failed at round {}",
                state.j,
            );
        }

        h
    }
}
/// Inputs that change from one round to the next, consumed by `RoundContext::assemble`.
#[derive(Clone, Copy, Debug)]
pub(super) struct RoundState<'a, F, EF> {
    /// Index of the current round, compared against `RoundContext::k`.
    pub j: usize,
    /// This round's mask coefficients, lowest degree first; entry `i` lands in slot `i`.
    pub mask: &'a [F],
    /// Values whose sum is added to the constant slot (presumably the earlier
    /// rounds' mask polynomials evaluated at their challenges — confirm with caller).
    pub past_mask_evals: &'a [EF],
    /// Aggregate added to the constant slot when later rounds remain (presumably the
    /// summed `g_i(0) + g_i(1)` of the later mask polynomials — confirm with caller).
    pub future_endpoints: F,
}
/// The transmitted coefficients of the plain (unmasked) round-polynomial piece.
#[derive(Clone, Copy, Debug)]
pub(super) struct PlainPiece<EF> {
    /// Constant coefficient; scaled by `eps` and added to slot 0.
    pub c0: EF,
    /// Coefficient scaled by `eps` and added to the quadratic slot (slot 2);
    /// presumably the leading coefficient ("evaluation at infinity").
    pub c_inf: EF,
}
/// Encodes a round polynomial for transmission by omitting its linear coefficient.
///
/// For `h = [h0, h1, h2, h3, ...]` the result is `[h0, h2, h3, ...]`, exactly one
/// entry shorter than `h`.
///
/// # Panics
///
/// Panics if `h` has fewer than two coefficients.
pub(super) fn round_poly_to_wire<EF: Copy>(h: &[EF]) -> Vec<EF> {
    let (constant, above_linear) = (&h[..1], &h[2..]);
    constant.iter().chain(above_linear).copied().collect()
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use p3_field::PrimeCharacteristicRing;

    use super::*;
    use crate::zk::test_helpers::{EF, F};

    /// Powers of two `[2^0, 2^1, ..., 2^k]`.
    fn pow2_table(k: usize) -> Vec<F> {
        F::TWO.powers().collect_n(k + 1)
    }

    /// A plain piece that contributes nothing to the assembled polynomial.
    fn zero_plain() -> PlainPiece<EF> {
        PlainPiece {
            c0: EF::ZERO,
            c_inf: EF::ZERO,
        }
    }

    #[test]
    fn round_poly_to_wire_drops_the_linear_coefficient() {
        let h: Vec<EF> = vec![
            EF::from_u32(10),
            EF::from_u32(20),
            EF::from_u32(30),
            EF::from_u32(40),
        ];
        let wire = round_poly_to_wire(&h);
        assert_eq!(
            wire,
            vec![EF::from_u32(10), EF::from_u32(30), EF::from_u32(40)]
        );
    }

    #[test]
    fn round_poly_to_wire_handles_minimum_length() {
        let h: Vec<EF> = vec![EF::from_u32(1), EF::from_u32(2), EF::from_u32(3)];
        let wire = round_poly_to_wire(&h);
        assert_eq!(wire, vec![EF::from_u32(1), EF::from_u32(3)]);
    }

    #[test]
    fn round_poly_to_wire_on_linear_input_keeps_only_constant() {
        // Two coefficients: dropping the linear one leaves just the constant.
        let h: Vec<EF> = vec![EF::from_u32(4), EF::from_u32(9)];
        assert_eq!(round_poly_to_wire(&h), vec![EF::from_u32(4)]);
    }

    #[test]
    fn assemble_output_length_matches_max_of_mask_len_and_quadratic() {
        let k = 1;
        let pow2 = pow2_table(k);
        // ell_zk = 2 is padded up to the quadratic minimum of 3 slots.
        let mask2 = vec![F::ZERO; 2];
        let h = RoundContext {
            k,
            ell_zk: 2,
            pow2: &pow2,
            eps: EF::ZERO,
        }
        .assemble(
            RoundState {
                j: 1,
                mask: &mask2,
                past_mask_evals: &[],
                future_endpoints: F::ZERO,
            },
            zero_plain(),
        );
        assert_eq!(h, vec![EF::ZERO; 3]);

        // ell_zk = 5 exceeds the minimum, so it sets the length.
        let mask5 = vec![F::ZERO; 5];
        let h = RoundContext {
            k,
            ell_zk: 5,
            pow2: &pow2,
            eps: EF::ZERO,
        }
        .assemble(
            RoundState {
                j: 1,
                mask: &mask5,
                past_mask_evals: &[],
                future_endpoints: F::ZERO,
            },
            zero_plain(),
        );
        assert_eq!(h, vec![EF::ZERO; 5]);
    }

    #[test]
    fn assemble_live_mask_lands_at_correct_slots() {
        // k = j = 1: live multiplier pow2[0] = 1, so the mask is copied verbatim.
        let k = 1;
        let pow2 = pow2_table(k);
        let mask = vec![
            F::from_u32(7),
            F::from_u32(11),
            F::from_u32(13),
            F::from_u32(17),
        ];
        let h = RoundContext {
            k,
            ell_zk: 4,
            pow2: &pow2,
            eps: EF::ZERO,
        }
        .assemble(
            RoundState {
                j: 1,
                mask: &mask,
                past_mask_evals: &[],
                future_endpoints: F::ZERO,
            },
            zero_plain(),
        );
        let expected = vec![
            EF::from_u32(7),
            EF::from_u32(11),
            EF::from_u32(13),
            EF::from_u32(17),
        ];
        assert_eq!(h, expected);
    }

    #[test]
    fn assemble_future_term_present_only_when_j_lt_k() {
        let k = 2;
        let pow2 = pow2_table(k);
        let mask = vec![F::ZERO, F::ZERO];
        let ctx = RoundContext {
            k,
            ell_zk: 2,
            pow2: &pow2,
            eps: EF::ZERO,
        };
        // j = 1 < k: future multiplier pow2[0] = 1 puts 999 on the constant slot.
        let h_first = ctx.assemble(
            RoundState {
                j: 1,
                mask: &mask,
                past_mask_evals: &[],
                future_endpoints: F::from_u32(999),
            },
            zero_plain(),
        );
        assert_eq!(h_first, vec![EF::from_u32(999), EF::ZERO, EF::ZERO]);
        // j = k: no later rounds, so future_endpoints is ignored.
        let h_last = ctx.assemble(
            RoundState {
                j: 2,
                mask: &mask,
                past_mask_evals: &[],
                future_endpoints: F::from_u32(999),
            },
            zero_plain(),
        );
        assert_eq!(h_last, vec![EF::ZERO; 3]);
    }

    #[test]
    fn assemble_past_mask_sum_lands_on_constant_slot() {
        // j = k = 2: live multiplier pow2[0] = 1, so h[0] = 7 + 11.
        let k = 2;
        let pow2 = pow2_table(k);
        let mask = vec![F::ZERO, F::ZERO];
        let past = vec![EF::from_u32(7), EF::from_u32(11)];
        let h = RoundContext {
            k,
            ell_zk: 2,
            pow2: &pow2,
            eps: EF::ZERO,
        }
        .assemble(
            RoundState {
                j: 2,
                mask: &mask,
                past_mask_evals: &past,
                future_endpoints: F::ZERO,
            },
            zero_plain(),
        );
        assert_eq!(h, vec![EF::from_u32(18), EF::ZERO, EF::ZERO]);
    }

    #[test]
    fn assemble_scales_live_and_past_terms_by_live_multiplier() {
        // k = 3, j = 2: live multiplier pow2[1] = 2, future multiplier pow2[0] = 1.
        // With both multipliers distinct from each other and from 1, swapping
        // them (or dropping one) changes the result.
        let k = 3;
        let pow2 = pow2_table(k);
        let mask = vec![F::from_u32(1), F::from_u32(2), F::from_u32(3)];
        let past = vec![EF::from_u32(5)];
        let h = RoundContext {
            k,
            ell_zk: 3,
            pow2: &pow2,
            eps: EF::ZERO,
        }
        .assemble(
            RoundState {
                j: 2,
                mask: &mask,
                past_mask_evals: &past,
                future_endpoints: F::from_u32(7),
            },
            zero_plain(),
        );
        // h[0] = 2*1 + 2*5 + 1*7 = 19, h[1] = 2*2 = 4, h[2] = 2*3 = 6.
        assert_eq!(
            h,
            vec![EF::from_u32(19), EF::from_u32(4), EF::from_u32(6)]
        );
    }

    #[test]
    fn assemble_satisfies_affine_consistency() {
        let k = 1;
        let pow2 = pow2_table(k);
        let mask = vec![F::from_u32(2), F::from_u32(3), F::from_u32(5)];
        let plain = PlainPiece {
            c0: EF::from_u32(7),
            c_inf: EF::from_u32(13),
        };
        let eps = EF::from_u32(31);
        let h = RoundContext {
            k,
            ell_zk: 3,
            pow2: &pow2,
            eps,
        }
        .assemble(
            RoundState {
                j: 1,
                mask: &mask,
                past_mask_evals: &[],
                future_endpoints: F::ZERO,
            },
            plain,
        );
        // h[0] = 2 + 31*7 = 219, h[1] = 3, h[2] = 5 + 31*13 = 408.
        let expected = vec![EF::from_u32(219), EF::from_u32(3), EF::from_u32(408)];
        assert_eq!(h, expected);
        // h(0) + h(1) = 2*h[0] + h[1] + h[2].
        let actual_target = h[0].double() + h[1..].iter().copied().sum::<EF>();
        // Live mask at both endpoints: g(0) + g(1) = 2*2 + 3 + 5.
        let live = EF::from_u32(2).double() + EF::from_u32(3) + EF::from_u32(5);
        let plain_transmitted = plain.c0.double() + plain.c_inf;
        let expected_target = eps * plain_transmitted + live;
        assert_eq!(actual_target, expected_target);
        assert_eq!(actual_target, EF::from_u32(849));
    }
}