use ark_ff::FftField;
use super::Config;
use crate::algebra::{geometric_accumulate, geometric_sequence};
/// Dimension parameters shared by the helpers in this file.
///
/// All fields are filled in by [`ProtocolDims::new`]; the per-field notes
/// below restate how that constructor computes each value.
#[derive(Clone, Copy, Debug)]
pub(super) struct ProtocolDims {
/// Initial number of variables of the blinded polynomial.
pub(super) mu: usize,
/// Initial number of variables of the blinding polynomial minus one; `new` rejects values below 1.
pub(super) ell: usize,
/// `mu % ell`; added to the start offset of every window with index >= 1.
pub(super) rem: usize,
/// `num_blinding_vecs - num_vectors`.
pub(super) nu: usize,
/// `1 << mu`.
pub(super) size: usize,
/// Vector count passed to `new` by the caller (its panic message calls this the witness count).
pub(super) num_vectors: usize,
/// `num_vectors` of the blinding polynomial's initial committer.
pub(super) num_blinding_vecs: usize,
}
impl ProtocolDims {
    /// Derives every dimension from `config` and the caller-supplied `num_vectors`.
    ///
    /// # Panics
    ///
    /// * if the blinding polynomial has fewer than two initial variables
    ///   (which would make `ell` zero);
    /// * if the blinding polynomial's initial committer has fewer vectors than
    ///   `num_vectors`.
    pub(super) fn new<F: FftField>(config: &Config<F>, num_vectors: usize) -> Self {
        let blinding = &config.blinding_polynomial;
        let mu = config.blinded_polynomial.initial_num_variables();

        // `ell` is one less than the blinding variable count and must stay positive,
        // because it is used as a modulus right below.
        let ell = Some(blinding.initial_num_variables())
            .filter(|&vars| vars >= 2)
            .map(|vars| vars - 1)
            .expect("blinding polynomial must have at least 2 variables (ell >= 1)");

        let num_blinding_vecs = blinding.initial_committer.num_vectors;
        let nu = match num_blinding_vecs.checked_sub(num_vectors) {
            Some(extra) => extra,
            None => panic!("blinding polynomial must commit more vectors than witness count"),
        };

        Self {
            mu,
            ell,
            rem: mu % ell,
            nu,
            size: 1 << mu,
            num_vectors,
            num_blinding_vecs,
        }
    }

    /// Number of `g` polynomials: `nu + 1`.
    pub(super) const fn num_g_polys(&self) -> usize {
        1 + self.nu
    }

    /// Forwards to the free function [`phi_i_bits`] using this instance's
    /// `mu`, `ell` and `rem`.
    pub(super) const fn phi_i_bits(&self, hypercube_idx: usize, phi_index: usize) -> usize {
        let Self { mu, ell, rem, .. } = *self;
        phi_i_bits(hypercube_idx, phi_index, mu, ell, rem)
    }
}
/// Extracts one `ell`-bit window from `hypercube_idx`, viewed as a `mu`-bit
/// number with offsets counted from the most significant bit.
///
/// Window `0` starts at offset `0`; window `k >= 1` starts at offset
/// `(k - 1) * ell + rem`. The returned value is the window's bits shifted
/// down to the low end, so it is always `< 2^ell`.
///
/// # Panics
///
/// Panics if the selected window would extend past bit `mu`.
const fn phi_i_bits(
    hypercube_idx: usize,
    phi_index: usize,
    mu: usize,
    ell: usize,
    rem: usize,
) -> usize {
    let window_start = match phi_index {
        0 => 0,
        k => (k - 1) * ell + rem,
    };
    let window_end = window_start + ell;
    assert!(window_end <= mu, "phi_i_bits: window exceeds mu");

    // Drop everything below the window, then keep only its `ell` bits.
    let below_window = mu - window_end;
    (hypercube_idx >> below_window) & ((1 << ell) - 1)
}
/// Finds `e` such that `gen^e == target`, recovering one bit of `e` per
/// iteration starting from the least significant bit.
///
/// For bit `b`, the running residue is squared `log_order - b - 1` times; a
/// result different from one means bit `b` is set, in which case the residue
/// is multiplied by `gen_inv^(2^b)` to clear it.
///
/// The panic message indicates `gen` is expected to generate a subgroup of
/// order `2^log_order` containing `target`.
///
/// # Panics
///
/// * in debug builds, if `gen * gen_inv != 1`;
/// * if the recovered exponent does not satisfy `gen^e == target`.
pub(super) fn discrete_log_pow2<F: FftField>(
    target: F,
    gen: F,
    gen_inv: F,
    log_order: u32,
) -> usize {
    debug_assert_eq!(gen * gen_inv, F::ONE, "gen_inv must be the inverse of gen");

    let mut exponent = 0usize;
    let mut residue = target;
    // Holds gen_inv^(2^bit) at the top of each iteration.
    let mut correction = gen_inv;

    for bit in 0..log_order {
        let squarings = log_order - bit - 1;
        let probe = (0..squarings).fold(residue, |mut acc, _| {
            acc.square_in_place();
            acc
        });
        if probe != F::ONE {
            exponent |= 1 << bit;
            residue *= correction;
        }
        correction.square_in_place();
    }

    assert_eq!(
        gen.pow([exponent as u64]),
        target,
        "discrete log verification failed: target not in ⟨gen⟩ of order 2^{log_order}"
    );
    exponent
}
/// Builds the length-`mu` point `(r_bar, z^(2^(k-1)), ..., z^2, z)` where
/// `k = mu - r_bar.len()`.
///
/// The first `r_bar.len()` coordinates are copied from `r_bar`; the remaining
/// `k` coordinates are successive squarings of `z`, stored with the highest
/// power first.
///
/// # Panics
///
/// Panics if `r_bar.len() > mu`.
pub(super) fn build_fold_args<F: FftField>(r_bar: &[F], z: F, mu: usize) -> Vec<F> {
    let num_folded_vars = r_bar.len();
    // Checked up front: the unchecked subtraction otherwise fails with an
    // unrelated overflow / capacity message.
    assert!(
        num_folded_vars <= mu,
        "build_fold_args: r_bar has {num_folded_vars} entries but mu is only {mu}"
    );

    let mut point = Vec::with_capacity(mu);
    point.extend_from_slice(r_bar);

    // Append z, z^2, z^4, ... in ascending order, then flip just that tail so
    // the highest power comes first; no temporary buffer is needed.
    let mut z_pow = z;
    for _ in num_folded_vars..mu {
        point.push(z_pow);
        z_pow.square_in_place();
    }
    point[num_folded_vars..].reverse();
    point
}
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(num_points = lambda_z_points.len(), mu = dims.mu, ell = dims.ell, num_g_polys = dims.num_g_polys())))]
pub(super) fn build_beq_tables<F: FftField>(
lambda_z_points: &[F],
eq_weights: &[F],
tau: F,
dims: ProtocolDims,
) -> Vec<Vec<F>> {
let mu = dims.mu;
let ell = dims.ell;
let rem = dims.rem;
let num_g_polys = dims.num_g_polys();
let half_size = 1usize << ell;
assert!(
eq_weights.len().is_power_of_two(),
"eq_weights length must be a power of 2, got {}",
eq_weights.len()
);
let num_folding_vars = eq_weights.len().trailing_zeros() as usize;
assert!(
num_folding_vars <= ell,
"folding factor num_folding_vars={num_folding_vars} must not exceed ell={ell} (would underflow m_cap in Φ₀ window)"
);
let num_m_bits = mu - num_folding_vars;
let tau_powers_full = geometric_sequence(tau, lambda_z_points.len() + 1);
let tau_powers = &tau_powers_full[1..];
let z_pows_all: Vec<Vec<F>> = lambda_z_points
.iter()
.map(|z| {
let mut z_pows = Vec::with_capacity(num_m_bits);
let mut z_pow = *z;
for _ in 0..num_m_bits {
z_pows.push(z_pow);
z_pow.square_in_place();
}
z_pows
})
.collect();
let num_points = lambda_z_points.len();
let mut tables = vec![vec![F::ZERO; half_size]; num_g_polys];
for (i, table) in tables.iter_mut().enumerate() {
let start_i = if i == 0 { 0 } else { (i - 1) * ell + rem };
let a_below = mu - start_i - ell; let a_above = start_i.saturating_sub(num_folding_vars); let m_cap = num_m_bits - a_below - a_above; let c_cap = ell - m_cap;
let eq_partial = if c_cap > 0 {
let mut eq_partial = vec![F::ZERO; 1 << c_cap];
let c_mask = (1 << c_cap) - 1;
for (c_idx, &weight) in eq_weights.iter().enumerate() {
eq_partial[c_idx & c_mask] += weight;
}
eq_partial
} else {
vec![F::ONE] };
let m_cap_size = 1usize << m_cap;
let mut scalars = Vec::with_capacity(num_points);
let mut bases = Vec::with_capacity(num_points);
for (j, &tp) in tau_powers.iter().enumerate() {
let z_pows = &z_pows_all[j];
let mut geo_below = F::ONE;
for &zp in z_pows.iter().take(a_below) {
geo_below *= F::ONE + zp;
}
let mut geo_above = F::ONE;
for &zp in z_pows.iter().skip(a_below + m_cap).take(a_above) {
geo_above *= F::ONE + zp;
}
scalars.push(tp * geo_below * geo_above);
bases.push(if a_below < num_m_bits {
z_pows[a_below]
} else {
F::ONE
});
}
let mut m_inner = vec![F::ZERO; m_cap_size];
geometric_accumulate(&mut m_inner, scalars, &bases);
if c_cap > 0 {
for (k_c, &ep) in eq_partial.iter().enumerate() {
for (k_m, &mi) in m_inner.iter().enumerate() {
table[k_c * m_cap_size + k_m] = ep * mi;
}
}
} else {
*table = m_inner;
}
}
tables
}
/// Result of `compute_rs_fold_blinding_coeffs`.
#[derive(Debug)]
pub(super) struct RsFoldCoeffs<F> {
/// One vector per masking polynomial. Entry 0 is the eq-weighted fold of
/// `g_polys[0]` plus `-rho` times the fold of `masking_polys[0]`; entry
/// `i >= 1` is the fold of `masking_polys[i]` scaled by `-rho * alpha_coeffs[i]`.
pub(super) masking_coeffs_all: Vec<Vec<F>>,
/// Eq-weighted folds of `g_polys[1..]`, in the same order.
pub(super) g_i_coeffs: Vec<Vec<F>>,
}
/// Folds the `g` and masking polynomials with `eq_weights` and combines the
/// results into [`RsFoldCoeffs`].
///
/// The hypercube index is split as `j * sub_poly_len + sub_idx`, with
/// `j` ranging over `eq_weights` and `sub_poly_len = 2^(mu - log2(eq_weights.len()))`.
/// For each index, `g_polys[0]` and every masking polynomial are read at
/// window 0 of that index, and `g_polys[k]` (`k >= 1`) at window `k`; each
/// value is scaled by `eq_weights[j]` and accumulated at position `sub_idx`.
///
/// # Panics
///
/// * if `eq_weights.len()` is not a power of two;
/// * if `g_polys` or `masking_polys` is empty, or `alpha_coeffs` has fewer
///   entries than `masking_polys`;
/// * if a window value indexes past the end of a polynomial.
#[cfg_attr(feature = "tracing", tracing::instrument(skip_all, fields(mu = dims.mu, ell = dims.ell, num_g_polys = g_polys.len())))]
pub(super) fn compute_rs_fold_blinding_coeffs<F: FftField>(
    eq_weights: &[F],
    g_polys: &[Vec<F>],
    masking_polys: &[Vec<F>],
    alpha_coeffs: &[F],
    rho: F,
    dims: ProtocolDims,
) -> RsFoldCoeffs<F> {
    assert!(
        eq_weights.len().is_power_of_two(),
        "eq_weights length must be a power of 2, got {}",
        eq_weights.len()
    );
    let num_folding_vars = eq_weights.len().trailing_zeros() as usize;
    let sub_poly_len = 1usize << (dims.mu - num_folding_vars);
    let neg_rho = -rho;

    let mut g0_fold = vec![F::ZERO; sub_poly_len];
    let mut masking_folds = vec![vec![F::ZERO; sub_poly_len]; masking_polys.len()];
    let mut g_i_coeffs = vec![vec![F::ZERO; sub_poly_len]; g_polys.len() - 1];

    // `eq_weights.len()` is a power of two, so iterating it visits exactly
    // `2^num_folding_vars` sub-polynomials.
    for (j, &eq_j) in eq_weights.iter().enumerate() {
        let block_base = j * sub_poly_len;
        for (sub_idx, g0_cell) in g0_fold.iter_mut().enumerate() {
            let full_idx = block_base + sub_idx;

            let phi_0_idx = dims.phi_i_bits(full_idx, 0);
            *g0_cell += eq_j * g_polys[0][phi_0_idx];
            for (fold, msk) in masking_folds.iter_mut().zip(masking_polys) {
                fold[sub_idx] += eq_j * msk[phi_0_idx];
            }

            for (offset, (fold, g_poly)) in
                g_i_coeffs.iter_mut().zip(&g_polys[1..]).enumerate()
            {
                let phi_i_idx = dims.phi_i_bits(full_idx, offset + 1);
                fold[sub_idx] += eq_j * g_poly[phi_i_idx];
            }
        }
    }

    // First entry mixes the g_0 fold with the first masking fold; the others
    // are plain rescalings of their masking fold.
    let mut masking_coeffs_all: Vec<Vec<F>> = Vec::with_capacity(masking_folds.len());
    let combined: Vec<F> = g0_fold
        .iter()
        .zip(&masking_folds[0])
        .map(|(&g, &msk)| g + neg_rho * msk)
        .collect();
    masking_coeffs_all.push(combined);
    for (i, fold) in masking_folds.iter().enumerate().skip(1) {
        let scale = neg_rho * alpha_coeffs[i];
        masking_coeffs_all.push(fold.iter().map(|&v| scale * v).collect());
    }

    RsFoldCoeffs {
        masking_coeffs_all,
        g_i_coeffs,
    }
}
/// Expands the `beq` tables into weight covectors of length `2^(ell + 1)`.
///
/// Each covector is laid out as consecutive pairs, one pair per table entry:
///
/// * covector 0: `(beq, -rho * beq)` from `beq_tables[0]`;
/// * one covector per `alpha` in `alpha_coeffs[1..num_vectors]`:
///   `(0, -rho * alpha * beq)` from `beq_tables[0]`;
/// * one covector per remaining table: `(beq, 0)`.
///
/// # Panics
///
/// Panics if `beq_tables` is empty or if `1..dims.num_vectors` is not a valid
/// range into `alpha_coeffs`.
pub(super) fn build_weight_covectors<F: FftField>(
    beq_tables: &[Vec<F>],
    rho: F,
    alpha_coeffs: &[F],
    dims: ProtocolDims,
) -> Vec<Vec<F>> {
    let full_size = 1usize << (dims.ell + 1);
    let neg_rho = -rho;

    // Writes `beq` into the even slot when `keep_even` is set and
    // `odd_scale * beq` into the odd slot when a scale is given; every other
    // slot stays zero.
    let interleave = |table: &[F], keep_even: bool, odd_scale: Option<F>| -> Vec<F> {
        let mut covector = vec![F::ZERO; full_size];
        for (pair, &beq) in covector.chunks_exact_mut(2).zip(table) {
            if keep_even {
                pair[0] = beq;
            }
            if let Some(scale) = odd_scale {
                pair[1] = scale * beq;
            }
        }
        covector
    };

    let mut weight_covectors: Vec<Vec<F>> = Vec::with_capacity(dims.num_blinding_vecs);

    let first_table = beq_tables[0].as_slice();
    weight_covectors.push(interleave(first_table, true, Some(neg_rho)));

    for &alpha in &alpha_coeffs[1..dims.num_vectors] {
        weight_covectors.push(interleave(first_table, false, Some(neg_rho * alpha)));
    }

    for table in &beq_tables[1..] {
        weight_covectors.push(interleave(table.as_slice(), true, None));
    }

    weight_covectors
}
/// Converts each point in `gamma_points` to an index: its discrete log with
/// respect to the round-0 committer's generator, multiplied by
/// `initial codeword length / round-0 codeword length`.
///
/// The log is taken modulo `2^t`, where `t` is the number of trailing zero
/// bits of the round-0 codeword length.
///
/// # Panics
///
/// * if the blinded polynomial has no round configs;
/// * if the round-0 generator has no inverse;
/// * if `discrete_log_pow2` fails its verification for some point.
pub(super) fn gamma_to_f_hat_indices<F: FftField>(
    gamma_points: &[F],
    config: &super::Config<F>,
) -> Vec<usize> {
    let blinded = &config.blinded_polynomial;
    assert!(
        !blinded.round_configs.is_empty(),
        "zkWHIR 2.0 requires at least one WHIR round"
    );

    let round0 = &blinded.round_configs[0].irs_committer;
    let round0_codeword_len = round0.codeword_length;
    let stride = blinded.initial_committer.codeword_length / round0_codeword_len;

    let gen_h = round0.generator();
    let gen_h_inv = gen_h.inverse().expect("generator must be invertible");
    let log_round0_len = round0_codeword_len.trailing_zeros();

    let mut indices = Vec::with_capacity(gamma_points.len());
    for &gamma in gamma_points {
        let exponent = discrete_log_pow2(gamma, gen_h, gen_h_inv, log_round0_len);
        indices.push(exponent * stride);
    }
    indices
}
/// Expands `r_bar` into a table of `2^r_bar.len()` products.
///
/// Entry `b` is the product over all `i` of `r_bar[i]` when the corresponding
/// bit of `b` is set and `1 - r_bar[i]` otherwise, with `r_bar[0]` paired with
/// the most significant bit of `b`. An empty `r_bar` yields `[1]`.
pub(super) fn compute_eq_weights<F: FftField>(r_bar: &[F]) -> Vec<F> {
    let mut weights = vec![F::ONE; 1usize << r_bar.len()];
    // Number of leading entries that currently hold valid partial products.
    let mut filled = 1usize;
    for &r in r_bar {
        // Walk parents from the back so that children written at `2 * parent`
        // and `2 * parent + 1` never clobber a parent that is still unread.
        for parent in (0..filled).rev() {
            let value = weights[parent];
            let with_r = value * r;
            weights[2 * parent + 1] = with_r;
            weights[2 * parent] = value - with_r;
        }
        filled <<= 1;
    }
    weights
}
/// Collects points together with two evaluation vectors per point; the three
/// lists are appended in lockstep by `push`.
#[derive(Debug)]
pub(super) struct LambdaAccumulator<F> {
/// Points in insertion order.
z_points: Vec<F>,
/// The `m` vector pushed with each point; `push` requires all of them to have the same length.
m_evals: Vec<Vec<F>>,
/// The `g` vector pushed with each point; `push` requires all of them to have the same length.
g_evals: Vec<Vec<F>>,
}
impl<F> LambdaAccumulator<F> {
    /// Creates an accumulator with no recorded points.
    pub(super) const fn new() -> Self {
        Self {
            z_points: Vec::new(),
            m_evals: Vec::new(),
            g_evals: Vec::new(),
        }
    }

    /// All recorded points, in insertion order.
    pub(super) fn z_points(&self) -> &[F] {
        self.z_points.as_slice()
    }

    /// Records the point `z` with its `m` and `g` evaluation vectors.
    ///
    /// # Panics
    ///
    /// Panics if `m` (or `g`) differs in length from the first `m` (or `g`)
    /// vector recorded earlier.
    pub(super) fn push(&mut self, z: F, m: Vec<F>, g: Vec<F>) {
        if let Some(first) = self.m_evals.first() {
            assert!(
                m.len() == first.len(),
                "m_evals length mismatch: expected {}, got {}",
                first.len(),
                m.len()
            );
        }
        if let Some(first) = self.g_evals.first() {
            assert!(
                g.len() == first.len(),
                "g_evals length mismatch: expected {}, got {}",
                first.len(),
                g.len()
            );
        }
        self.z_points.push(z);
        self.m_evals.push(m);
        self.g_evals.push(g);
    }

    /// Number of recorded points.
    #[must_use]
    pub(super) const fn len(&self) -> usize {
        self.z_points.len()
    }

    /// Looks up one evaluation for the point at `lambda_idx`.
    ///
    /// Indices below `num_vectors` read from the `m` vector; larger indices
    /// read from the `g` vector at `vec_idx - num_vectors`.
    ///
    /// # Panics
    ///
    /// Panics if either index is out of range.
    pub(super) fn claim(&self, lambda_idx: usize, vec_idx: usize, num_vectors: usize) -> F
    where
        F: Copy,
    {
        match vec_idx.checked_sub(num_vectors) {
            None => self.m_evals[lambda_idx][vec_idx],
            Some(g_idx) => self.g_evals[lambda_idx][g_idx],
        }
    }
}
#[cfg(test)]
mod tests {
    use ark_ff::{FftField, Field};
    use proptest::prelude::*;

    use super::{discrete_log_pow2, phi_i_bits};
    use crate::algebra::fields::Field64;

    /// Checks `phi_i_bits` against `(hypercube_idx, phi_index, expected)`
    /// triples for one fixed `(mu, ell, rem)` configuration.
    fn assert_windows(mu: usize, ell: usize, rem: usize, cases: &[(usize, usize, usize)]) {
        for &(hypercube_idx, phi_index, expected) in cases {
            assert_eq!(
                phi_i_bits(hypercube_idx, phi_index, mu, ell, rem),
                expected
            );
        }
    }

    #[test]
    fn phi_i_bits_phi0_extracts_top_ell_bits() {
        let (mu, ell) = (8, 3);
        assert_windows(
            mu,
            ell,
            mu % ell,
            &[
                (0b1010_1010, 0, 0b101),
                (0b1110_0000, 0, 0b111),
                (0b0001_1111, 0, 0b000),
            ],
        );
    }

    #[test]
    fn phi_i_bits_phi1_extracts_after_rem() {
        assert_windows(8, 3, 2, &[(0b1011_0010, 1, 0b110)]);
    }

    #[test]
    fn phi_i_bits_phi2_extracts_next_window() {
        assert_windows(8, 3, 2, &[(0b1011_0101, 2, 0b101)]);
    }

    #[test]
    fn phi_i_bits_rem_zero() {
        assert_windows(
            6,
            2,
            0,
            &[
                (0b11_00_10, 0, 0b11),
                (0b11_00_10, 2, 0b00),
                (0b11_00_10, 3, 0b10),
            ],
        );
    }

    #[test]
    fn phi_i_bits_single_bit_ell() {
        // With rem = 0, windows 0 and 1 both read the top bit.
        assert_windows(
            4,
            1,
            0,
            &[
                (0b1011, 0, 1),
                (0b1011, 1, 1),
                (0b1011, 2, 0),
                (0b1011, 3, 1),
                (0b1011, 4, 1),
            ],
        );
    }

    #[test]
    fn phi_i_bits_boundary_all_ones() {
        assert_windows(6, 3, 0, &[(0b111_111, 0, 0b111), (0b111_111, 2, 0b111)]);
    }

    #[test]
    fn phi_i_bits_boundary_all_zeros() {
        assert_windows(6, 3, 0, &[(0, 0, 0), (0, 2, 0)]);
    }

    /// Calls `discrete_log_pow2`, computing the generator's inverse first.
    fn dlog(target: Field64, gen: Field64, log_order: u32) -> usize {
        let gen_inv = gen.inverse().expect("generator must be invertible");
        discrete_log_pow2(target, gen, gen_inv, log_order)
    }

    /// Squares the field's two-adic root of unity
    /// `TWO_ADICITY - log_order` times.
    fn subgroup_gen(log_order: u32) -> Field64 {
        let max_log = Field64::TWO_ADICITY;
        assert!(log_order <= max_log);
        (0..(max_log - log_order)).fold(Field64::TWO_ADIC_ROOT_OF_UNITY, |mut root, _| {
            root.square_in_place();
            root
        })
    }

    #[test]
    fn dlog_identity_is_zero() {
        for log_order in 1..=8 {
            let found = dlog(Field64::ONE, subgroup_gen(log_order), log_order);
            assert_eq!(found, 0, "dlog(1, gen, {log_order}) should be 0");
        }
    }

    #[test]
    fn dlog_generator_is_one() {
        for log_order in 1..=8 {
            let gen = subgroup_gen(log_order);
            let found = dlog(gen, gen, log_order);
            assert_eq!(found, 1, "dlog(gen, gen, {log_order}) should be 1");
        }
    }

    #[test]
    fn dlog_known_powers() {
        let log_order = 4;
        let gen = subgroup_gen(log_order);
        for i in 0..16usize {
            let found = dlog(gen.pow([i as u64]), gen, log_order);
            assert_eq!(found, i, "dlog(gen^{i}, gen, {log_order}) should be {i}");
        }
    }

    #[test]
    fn dlog_order_1() {
        let gen = subgroup_gen(1);
        let cases = [(Field64::ONE, 0), (gen, 1)];
        for (target, expected) in cases {
            assert_eq!(dlog(target, gen, 1), expected);
        }
    }

    #[test]
    fn dlog_larger_group() {
        let log_order = 10;
        let gen = subgroup_gen(log_order);
        let exponents: [usize; 6] = [0, 1, 2, 511, 512, 1023];
        for i in exponents {
            let found = dlog(gen.pow([i as u64]), gen, log_order);
            assert_eq!(found, i, "failed for i={i}");
        }
    }

    proptest! {
        #[test]
        fn dlog_roundtrip(log_order in 1u32..=12, idx in 0u32..4096) {
            // Reduce the raw index into the subgroup of order 2^log_order.
            let i = (idx % (1u32 << log_order)) as usize;
            let gen = subgroup_gen(log_order);
            let result = dlog(gen.pow([i as u64]), gen, log_order);
            prop_assert_eq!(result, i, "dlog roundtrip failed for log_order={}, i={}", log_order, i);
        }

        #[test]
        fn phi_i_bits_in_range(
            mu in 4usize..=16,
            ell in 1usize..=4,
            hypercube_idx_raw in 0usize..65536,
        ) {
            prop_assume!(ell <= mu);
            let rem = mu % ell;
            let num_phis = 1 + (mu - rem) / ell;
            let hypercube_idx = hypercube_idx_raw % (1usize << mu);

            for phi_index in 0..num_phis {
                let start = match phi_index {
                    0 => 0,
                    k => (k - 1) * ell + rem,
                };
                if start + ell > mu {
                    break;
                }
                let result = phi_i_bits(hypercube_idx, phi_index, mu, ell, rem);
                prop_assert!(
                    result < (1 << ell),
                    "phi_i_bits({}, {}, {}, {}, {}) = {} >= 2^{}",
                    hypercube_idx, phi_index, mu, ell, rem, result, ell
                );
            }
        }

        #[test]
        fn phi_i_bits_partition_no_remainder(
            mu_factor in 1usize..=4,
            ell in 2usize..=4,
            hypercube_idx_raw in 0usize..65536,
        ) {
            let mu = mu_factor * ell;
            prop_assume!(mu <= 16);
            let rem = 0;
            let num_phis = mu / ell;
            let b = hypercube_idx_raw % (1usize << mu);

            // Window 0 supplies the top `ell` bits; with rem = 0 window 1
            // duplicates it, so the remaining windows start at index 2.
            let mut reconstructed = phi_i_bits(b, 0, mu, ell, rem) << (mu - ell);
            for phi_idx in 2..=num_phis {
                let start = (phi_idx - 1) * ell;
                if start + ell > mu {
                    break;
                }
                let bits = phi_i_bits(b, phi_idx, mu, ell, rem);
                reconstructed |= bits << (mu - start - ell);
            }

            prop_assert_eq!(
                reconstructed, b,
                "partition reconstruction failed for b={:#b}, mu={}, ell={}", b, mu, ell
            );
        }
    }
}