use std::sync::Arc;
use spongefish::{ProverState, VerificationError, VerificationResult, VerifierState};
use super::base_case::{BaseCase, verify_base_case};
use super::code::{Field, InterleavedCode, ReedSolomon};
use super::codeswitch::CodeswitchHandle;
use super::commitment::{
CodeCommitment, CodeCommitmentHandle, CodeCommitmentProverHandle, ExplicitCodeCommitmentHandle,
FoldedCodeCommitmentHandle, FoldedCodeCommitmentProverState,
};
use super::evader::{ETA, OodEvader, OodEvaderHandle};
use super::linear_form::{
FoldedFormHandle, LinearCombinationForm, LinearConstraint, LinearForm, LinearFormHandle,
};
use super::mask_stack::MaskStack;
use super::transcript_io::{
read_opening, sample_positions_prover, sample_positions_verifier, write_opening,
};
use super::vc::MerkleVc;
use crate::field::Goldilocks4;
use crate::hash::JV_RZK;
/// Interleaving factor passed to every `InterleavedCode` built in this file; each
/// codeswitch round also divides the inner message length by this factor.
const INTERLEAVING: usize = 4;
/// Multiplier from an inner message length to the size handed to `MerkleVc::new`
/// (presumably the inverse code rate — TODO confirm against `ReedSolomon`).
const RATE_INV: usize = 4;
/// Prover state of a folded commitment, as returned by `prove_sumcheck_zk`.
type ProverFolded = FoldedCodeCommitmentProverState<ReedSolomon<Goldilocks4>, MerkleVc>;
/// Prover-side commitment scheme: interleaved Reed–Solomon code + Merkle VC.
type ProverCommit = CodeCommitment<InterleavedCode<ReedSolomon<Goldilocks4>>, MerkleVc>;
/// Verifier view of a folded commitment, as returned by `verify_sumcheck_zk`.
type VerifierFolded = FoldedCodeCommitmentHandle<ReedSolomon<Goldilocks4>, MerkleVc>;
/// Verifier handle to an explicit commitment (code, VC parameters and root).
type VerifierExplicit =
    ExplicitCodeCommitmentHandle<InterleavedCode<ReedSolomon<Goldilocks4>>, MerkleVc>;
/// Prover side of one codeswitch round.
///
/// The round does four things:
/// * commits the current folded message under `output_commitment`;
/// * answers OOD (presumably out-of-domain) samples on it;
/// * opens `queries` positions of the input codeword;
/// * batches the resulting constraints into the running linear form and runs a ZK
///   sumcheck on the new commitment.
pub(crate) struct ProverCodeswitch {
    /// Evaluator applied to the output commitment's message at the OOD seeds.
    ood_ze: OodEvader,
    /// Number of input-codeword positions opened in this round.
    queries: usize,
    /// Commitment scheme used for this round's output message.
    output_commitment: ProverCommit,
}
impl ProverCodeswitch {
    /// Runs one codeswitch round on the prover side.
    ///
    /// Transcript order, which must mirror `VerifierCodeswitch::verify`:
    /// 1. output root;
    /// 2. padding root;
    /// 3. `ETA` OOD seeds (challenge);
    /// 4. OOD answers;
    /// 5. query positions (challenge);
    /// 6. openings;
    /// 7. batching randomness (challenge);
    /// 8. padding target;
    /// 9. the sumcheck messages.
    ///
    /// * `input` - folded commitment state from the previous step. Its message is
    ///   re-committed, and its codeword is opened at the sampled positions.
    /// * `constraint` - running linear form. The OOD and query constraints are
    ///   added to it.
    /// * `mask_seed`, `salt` - seed material for the padding mask and for
    ///   `prove_sumcheck_zk`.
    /// * `mask_stack` - receives the padding mask and this round's sumcheck masks.
    ///
    /// Returns the folded constraint (multiplied by the sumcheck's `epsilon`) and
    /// the folded output state.
    fn prove(
        &self,
        transcript: &mut ProverState,
        input: ProverFolded,
        mut constraint: LinearForm<Goldilocks4>,
        mask_seed: &[u8; 32],
        salt: &[u8],
        mask_stack: &mut super::mask_stack::MaskStack,
    ) -> (LinearForm<Goldilocks4>, ProverFolded) {
        // Re-commit the input's folded message under this round's output code.
        let output = self
            .output_commitment
            .commit(transcript, input.msg().to_vec());
        // Padding mask: `l_zk_inner` message entries plus `t_zk` randomness. Both are
        // derived deterministically from `mask_seed` and salted labels, encoded with
        // `ZkEncoding`, and committed as a width-1 slab under a fresh Merkle VC.
        let l_zk_inner = crate::params::Params::M_ZK - crate::params::Params::T_ZK;
        let t_zk = crate::params::Params::T_ZK;
        let zk_enc = super::encoding::ZkEncoding::new(l_zk_inner, t_zk);
        let padding_msg = super::base_case::derive_field_vec(
            mask_seed,
            &codeswitch_pad_msg_salt(salt),
            l_zk_inner,
        );
        let padding_r =
            super::base_case::derive_field_vec(mask_seed, &codeswitch_pad_r_salt(salt), t_zk);
        let padding_codeword = zk_enc.encode_with(&padding_msg, &padding_r);
        let padding_slab = super::code::CodewordSlab::new(padding_codeword, 1);
        let padding_vc = super::vc::MerkleVc::new(padding_slab.positions());
        let (padding_root, padding_vc_state) = padding_vc.commit_slab(padding_slab);
        transcript.prover_message(&padding_root);
        // OOD answers are additively offset by the leading entries of `padding_msg`
        // (presumably to hide the raw evaluations). `padding_target` below records
        // the batched size of that offset.
        let ood_seeds: Vec<Goldilocks4> = (0..ETA).map(|_| transcript.verifier_message()).collect();
        let mut ood_answers = self.ood_ze.apply(&output.msg, &ood_seeds);
        for (i, y) in ood_answers.iter_mut().enumerate() {
            *y += padding_msg[i];
        }
        transcript.prover_message(&ood_answers);
        // Spot-check queries on the input codeword.
        let positions = sample_positions_prover(transcript, self.queries, input.codeword_len());
        let openings = input.open(&positions);
        write_opening(transcript, &openings);
        // One batching coefficient per OOD answer, then one per query.
        let batch_rand = transcript.verifier_messages_vec(ood_answers.len() + self.queries);
        for (i, ze_constraint) in self
            .ood_ze
            .expanded_constraint(&ood_seeds)
            .into_iter()
            .enumerate()
        {
            constraint += LinearForm::new(ze_constraint) * batch_rand[i];
        }
        // Brings `apply_transpose` into scope.
        use super::codeswitch::TransposeCode;
        // Query constraints are folded into one selector over codeword positions
        // (repeated positions accumulate) and then mapped through the input code's
        // `apply_transpose`.
        let mut selector = vec![Goldilocks4::ZERO; input.codeword_len()];
        for (i, &pos) in positions.iter().enumerate() {
            selector[pos] += batch_rand[ood_answers.len() + i];
        }
        constraint += LinearForm::new(input.code().apply_transpose(&selector));
        // Weights on the padding message: only the first `ETA` entries, which are the
        // ones added to the OOD answers (assumes `ood_answers.len() == ETA`).
        // padding_target = sum_{i < ETA} batch_rand[i] * padding_msg[i].
        let mut padding_sl_o = vec![Goldilocks4::ZERO; l_zk_inner];
        padding_sl_o[..ETA].copy_from_slice(&batch_rand[..ETA]);
        let padding_target: Goldilocks4 = padding_sl_o
            .iter()
            .zip(&padding_msg)
            .map(|(a, b)| *a * *b)
            .sum();
        transcript.prover_message(&padding_target);
        let padding_handle = super::mask_stack::MaskOracleHandle::new_prover(
            padding_msg,
            padding_r,
            padding_vc,
            padding_vc_state,
        );
        mask_stack.push_padding_mask(padding_handle, Goldilocks4::ONE, padding_sl_o);
        // Overwrite the target of the most recently pushed mask constraint
        // (presumably the padding mask just pushed) with `padding_target`.
        if let Some(last_mc) = mask_stack.constraints.last_mut() {
            last_mc.target = padding_target;
        }
        // ZK sumcheck on the output commitment. Previously stacked mask weights
        // are rescaled by `epsilon` before this round's masks are appended.
        let (folded_state, folded_constraint, mask_handles, gammas, epsilon, mask_targets) =
            super::sumcheck::prove_sumcheck_zk(transcript, output, constraint, mask_seed, salt);
        mask_stack.scale_alphas(epsilon);
        mask_stack.push_sumcheck_masks(mask_handles, gammas, mask_targets);
        let folded_constraint = folded_constraint * epsilon;
        (folded_constraint, folded_state)
    }
}
/// Label used to derive a codeswitch round's padding message: `salt`
/// immediately followed by the ASCII suffix `::cs::pad_msg`.
fn codeswitch_pad_msg_salt(salt: &[u8]) -> Vec<u8> {
    [salt, b"::cs::pad_msg".as_slice()].concat()
}
/// Label used to derive a codeswitch round's padding randomness: `salt`
/// immediately followed by the ASCII suffix `::cs::pad_r`.
fn codeswitch_pad_r_salt(salt: &[u8]) -> Vec<u8> {
    [salt, b"::cs::pad_r".as_slice()].concat()
}
/// Prover for the concrete protocol. It holds three pieces:
/// * an initial commitment to the caller's message padded to `internal_len`;
/// * a chain of codeswitch rounds;
/// * a final base case.
pub(crate) struct ConcreteWhirProtocol {
    /// Commitment scheme for the internal (message + padding) vector.
    initial_commitment: ProverCommit,
    /// Codeswitch rounds, executed in order by `prove`.
    rounds: Vec<ProverCodeswitch>,
    /// Terminal step run on the last folded state.
    final_base_case: BaseCase,
    /// Length of the caller's message `c`; a power of two.
    msg_len: usize,
    /// `(msg_len + hvzk_budget).next_power_of_two()`: length of `c` plus padding.
    internal_len: usize,
}
impl ConcreteWhirProtocol {
    /// Builds the prover's round schedule.
    ///
    /// Starting from an inner message length of `internal_len / INTERLEAVING`,
    /// codeswitch rounds are added while the current inner length exceeds
    /// `threshold`. Each round divides the inner length by `INTERLEAVING`.
    ///
    /// # Panics
    /// * If `msg_len` is not a power of two.
    /// * If `internal_len` is smaller than, or not a multiple of, `INTERLEAVING`.
    /// * If a round would have to shrink an inner length that is not a multiple of
    ///   `INTERLEAVING`. Integer division would silently truncate it, for example
    ///   to a zero-length code. Use a larger `threshold`.
    pub(crate) fn build(
        msg_len: usize,
        hvzk_budget: usize,
        queries: usize,
        threshold: usize,
    ) -> Self {
        assert!(msg_len.is_power_of_two(), "msg_len must be a power of two");
        let internal_len = (msg_len + hvzk_budget).next_power_of_two();
        assert!(
            internal_len >= INTERLEAVING && internal_len % INTERLEAVING == 0,
            "internal_len ({internal_len}) must be a non-zero multiple of \
             INTERLEAVING ({INTERLEAVING})"
        );
        let initial_inner_msg_len = internal_len / INTERLEAVING;
        let initial_commitment = make_commitment(initial_inner_msg_len);
        let mut rounds = Vec::new();
        let mut current_inner_msg_len = initial_inner_msg_len;
        while current_inner_msg_len > threshold {
            // Without this check, e.g. 2 / 4 == 0 would yield a zero-length output code.
            assert!(
                current_inner_msg_len % INTERLEAVING == 0,
                "inner message length {current_inner_msg_len} is not a multiple of \
                 INTERLEAVING ({INTERLEAVING}); increase `threshold`"
            );
            let next_inner_msg_len = current_inner_msg_len / INTERLEAVING;
            rounds.push(ProverCodeswitch {
                ood_ze: OodEvader::new(current_inner_msg_len),
                queries,
                output_commitment: make_commitment(next_inner_msg_len),
            });
            current_inner_msg_len = next_inner_msg_len;
        }
        Self {
            initial_commitment,
            rounds,
            final_base_case: BaseCase::new(queries, crate::params::Params::T_ZK),
            msg_len,
            internal_len,
        }
    }
    /// Commits to `c` extended with `internal_len - msg_len` elements derived from
    /// `sigma` via `derive_r_zk`.
    ///
    /// Returns the commitment root and the signer state. The state holds the
    /// extended vector and the cached initial commitment state used by `prove`.
    ///
    /// # Panics
    /// If `c.len() != msg_len`.
    pub(crate) fn commit(
        &self,
        c: &[Goldilocks4],
        sigma: &[u8; 32],
    ) -> ([u8; 32], WhirSignerState) {
        assert_eq!(c.len(), self.msg_len);
        let r_zk = derive_r_zk(sigma, self.internal_len - self.msg_len);
        let mut internal = Vec::with_capacity(self.internal_len);
        internal.extend_from_slice(c);
        internal.extend(r_zk);
        // The vector is needed both by the commitment and in the signer state.
        let (root, state) = self.initial_commitment.commit_only(internal.clone());
        (
            root,
            WhirSignerState {
                internal,
                initial_state: Some(state),
            },
        )
    }
    /// Proves the linear form `alpha_message` against the vector committed by
    /// `commit`. The form is zero-extended over the padding to `internal_len`.
    ///
    /// Transcript order:
    /// 1. the initial root;
    /// 2. the initial ZK sumcheck (salt `sumcheck::initial`);
    /// 3. each codeswitch round (salt `sumcheck::post-codeswitch-{i}`);
    /// 4. the base case.
    ///
    /// `mask_seed` is passed to every sumcheck, every round and the base case.
    ///
    /// # Panics
    /// * If `alpha_message.len() != msg_len`.
    /// * If `state.internal.len() != internal_len`.
    /// * If `state.initial_state` is `None`.
    pub(crate) fn prove(
        &self,
        transcript: &mut ProverState,
        state: &WhirSignerState,
        alpha_message: Vec<Goldilocks4>,
        mask_seed: &[u8; 32],
    ) {
        assert_eq!(alpha_message.len(), self.msg_len);
        assert_eq!(state.internal.len(), self.internal_len);
        // The padding coordinates carry zero weight.
        let mut alpha_full = alpha_message;
        alpha_full.resize(self.internal_len, Goldilocks4::ZERO);
        let initial_constraint = LinearForm::new(alpha_full);
        let cached_initial = state
            .initial_state
            .as_ref()
            .expect("WhirSignerState::initial_state missing; keygen must populate it");
        cached_initial.write_root_to(transcript);
        let initial_state = cached_initial.clone();
        let mut mask_stack = MaskStack::new();
        let (mut folded_state, mut constraint, mask_handles, gammas, epsilon, mask_targets) =
            super::sumcheck::prove_sumcheck_zk(
                transcript,
                initial_state,
                initial_constraint,
                mask_seed,
                b"sumcheck::initial",
            );
        mask_stack.scale_alphas(epsilon);
        mask_stack.push_sumcheck_masks(mask_handles, gammas, mask_targets);
        constraint = constraint * epsilon;
        for (round_idx, round) in self.rounds.iter().enumerate() {
            let salt = format!("sumcheck::post-codeswitch-{round_idx}");
            let (next_constraint, next_folded) = round.prove(
                transcript,
                folded_state,
                constraint,
                mask_seed,
                salt.as_bytes(),
                &mut mask_stack,
            );
            folded_state = next_folded;
            constraint = next_constraint;
        }
        let coefficients: Vec<Goldilocks4> = constraint.coefficients().to_vec();
        self.final_base_case.prove(
            transcript,
            folded_state,
            &coefficients,
            &mask_stack,
            mask_seed,
        );
    }
}
/// Verifier side of one codeswitch round; the counterpart of
/// `ProverCodeswitch::prove`.
pub(crate) struct VerifierCodeswitch {
    /// Produces the OOD constraint handles for a set of OOD seeds.
    ood_ze: OodEvaderHandle,
    /// Number of input-codeword positions checked in this round.
    queries: usize,
    /// Code and VC parameters of this round's output commitment. Its `commitment`
    /// field is a placeholder; `verify` reads the real root from the transcript.
    output_commitment: VerifierExplicit,
}
impl VerifierCodeswitch {
    /// Checks one codeswitch round.
    ///
    /// Returns the batched, folded constraint and the folded output commitment for
    /// the next round.
    ///
    /// Transcript order, which must mirror `ProverCodeswitch::prove`:
    /// 1. output root;
    /// 2. padding root;
    /// 3. `ETA` OOD seeds (challenge);
    /// 4. OOD answers;
    /// 5. query positions (challenge);
    /// 6. openings of `input`;
    /// 7. batching randomness (challenge);
    /// 8. the padding target;
    /// 9. the sumcheck messages.
    ///
    /// # Errors
    /// Returns `VerificationError` in any of these cases:
    /// * a transcript read fails;
    /// * the openings do not verify against `input`;
    /// * `verify_sumcheck_zk` fails.
    fn verify(
        &self,
        transcript: &mut VerifierState,
        input: &VerifierFolded,
        constraint: LinearConstraint<FoldedFormHandle<Goldilocks4>>,
        mask_stack: &mut super::mask_stack::MaskStack,
    ) -> VerificationResult<(
        LinearConstraint<FoldedFormHandle<Goldilocks4>>,
        VerifierFolded,
    )> {
        let output = ExplicitCodeCommitmentHandle {
            code: self.output_commitment.code.clone(),
            vc: self.output_commitment.vc.clone(),
            commitment: transcript.prover_message()?,
        };
        let padding_root: [u8; 32] = transcript.prover_message()?;
        let ood_seeds: Vec<Goldilocks4> = (0..ETA).map(|_| transcript.verifier_message()).collect();
        let ze_constraints = self.ood_ze.zero_evader_handles(&ood_seeds);
        let ood_answers = transcript.prover_messages_vec::<Goldilocks4>(ze_constraints.len())?;
        let positions = sample_positions_verifier(transcript, self.queries, input.codeword_len());
        // The openings arrive as a Merkle multiproof. Its byte size is computed
        // from the distinct positions and the per-opening path length.
        let path_len_per_opening =
            input.codeword_len().next_power_of_two().trailing_zeros() as usize;
        let mut sorted_unique = positions.clone();
        sorted_unique.sort_unstable();
        sorted_unique.dedup();
        let multiproof_bytes = crate::merkle::multiproof_size(&sorted_unique, path_len_per_opening);
        let openings = read_opening(transcript, self.queries, INTERLEAVING, multiproof_bytes)?;
        let opened = input.verify_openings(&positions, &openings)?;
        let ood_len = ood_answers.len();
        let batch_rand: Vec<Goldilocks4> = (0..ood_len + self.queries)
            .map(|_| transcript.verifier_message())
            .collect();
        // Batch the incoming constraint (weight 1), each OOD constraint and each
        // query constraint. The claimed value is batched with the same coefficients
        // from the OOD answers and the opened values. Repeated positions add
        // separate terms, which matches the prover's accumulated selector.
        let mut value = constraint.value;
        let mut forms: Vec<Box<dyn LinearFormHandle<Alphabet = Goldilocks4>>> =
            vec![Box::new(constraint.linear_form_handle)];
        let mut coeffs = vec![<Goldilocks4 as crate::field::SumcheckField>::ONE];
        for (i, (answer, ze_constraint)) in ood_answers.into_iter().zip(ze_constraints).enumerate()
        {
            value += answer * batch_rand[i];
            forms.push(Box::new(ze_constraint));
            coeffs.push(batch_rand[i]);
        }
        for (i, (&pos, opening)) in positions.iter().zip(opened).enumerate() {
            value += opening * batch_rand[ood_len + i];
            forms.push(Box::new(input.code().apply_transpose_handle(pos)));
            coeffs.push(batch_rand[ood_len + i]);
        }
        // Padding mask: weights on its first `ETA` entries are the OOD batching
        // coefficients. Only its root is known to the verifier here.
        let padding_target: Goldilocks4 = transcript.prover_message()?;
        let l_zk_inner = crate::params::Params::M_ZK - crate::params::Params::T_ZK;
        let mut padding_sl_o = vec![Goldilocks4::ZERO; l_zk_inner];
        padding_sl_o[..ETA].copy_from_slice(&batch_rand[..ETA]);
        let padding_mask_handle =
            super::mask_stack::MaskOracleHandle::verifier_root_only(padding_root);
        mask_stack.push_padding_mask(
            padding_mask_handle,
            <Goldilocks4 as crate::field::SumcheckField>::ONE,
            padding_sl_o,
        );
        // Overwrite the target of the most recently pushed mask constraint
        // (presumably the padding mask just pushed) with the prover's claim.
        if let Some(last_mc) = mask_stack.constraints.last_mut() {
            last_mc.target = padding_target;
        }
        // The stack's joint mask value is subtracted from the sumcheck target here
        // and added back, multiplied by `epsilon`, after the sumcheck.
        let mask_carry_in_pre = mask_stack.joint_mask_value();
        let batched_constraint = LinearConstraint {
            linear_form_handle: FoldedFormHandle {
                linear_form_handle: Box::new(LinearCombinationForm {
                    linear_form_handles: forms,
                    combination_rand: coeffs,
                }),
                rand: Vec::new(),
                scale: <Goldilocks4 as crate::field::SumcheckField>::ONE,
            },
            value: value - mask_carry_in_pre,
        };
        let (folded_output, mut folded_constraint, mask_handles, gammas, epsilon, mask_targets) =
            super::sumcheck::verify_sumcheck_zk(transcript, output, batched_constraint)?;
        folded_constraint.value += epsilon * mask_carry_in_pre;
        mask_stack.scale_alphas(epsilon);
        mask_stack.push_sumcheck_masks(mask_handles, gammas, mask_targets);
        Ok((folded_constraint, folded_output))
    }
}
pub(crate) struct ConcreteWhirVerifier {
initial_commitment: VerifierExplicit,
rounds: Vec<VerifierCodeswitch>,
final_base_case: BaseCase,
msg_len: usize,
internal_len: usize,
}
impl ConcreteWhirVerifier {
pub(crate) fn build(
msg_len: usize,
hvzk_budget: usize,
queries: usize,
threshold: usize,
) -> Self {
assert!(msg_len.is_power_of_two(), "msg_len must be a power of two");
let internal_len = (msg_len + hvzk_budget).next_power_of_two();
let initial_inner_msg_len = internal_len / INTERLEAVING;
let initial_commitment = make_explicit_handle(initial_inner_msg_len);
let mut rounds = Vec::new();
let mut current_inner_msg_len = initial_inner_msg_len;
while current_inner_msg_len > threshold {
let next_inner_msg_len = current_inner_msg_len / INTERLEAVING;
rounds.push(VerifierCodeswitch {
ood_ze: OodEvaderHandle::new(current_inner_msg_len),
queries,
output_commitment: make_explicit_handle(next_inner_msg_len),
});
current_inner_msg_len = next_inner_msg_len;
}
Self {
initial_commitment,
rounds,
final_base_case: BaseCase::new(queries, crate::params::Params::T_ZK),
msg_len,
internal_len,
}
}
pub(crate) fn verify<LFH>(
&self,
transcript: &mut VerifierState,
expected_root: [u8; 32],
alpha_handle: LFH,
value: Goldilocks4,
) -> VerificationResult<()>
where
LFH: LinearFormHandle<Alphabet = Goldilocks4> + 'static,
{
assert_eq!(alpha_handle.form_size(), self.msg_len);
let nu = self.msg_len.trailing_zeros();
let nu_prime = self.internal_len.trailing_zeros();
let embedded = MessageEmbeddedHandle {
inner: Box::new(alpha_handle),
nu,
nu_prime,
};
let initial_commitment_root: [u8; 32] = transcript.prover_message()?;
if initial_commitment_root != expected_root {
return Err(VerificationError);
}
let initial_commitment = ExplicitCodeCommitmentHandle {
code: self.initial_commitment.code.clone(),
vc: self.initial_commitment.vc.clone(),
commitment: initial_commitment_root,
};
let wrapped_initial = LinearConstraint {
linear_form_handle: FoldedFormHandle {
linear_form_handle: Box::new(embedded),
rand: Vec::new(),
scale: <Goldilocks4 as crate::field::SumcheckField>::ONE,
},
value,
};
let mut mask_stack = super::mask_stack::MaskStack::new();
let (mut folded_commitment, mut constraint, mask_handles, gammas, epsilon, mask_targets) =
super::sumcheck::verify_sumcheck_zk(transcript, initial_commitment, wrapped_initial)?;
mask_stack.scale_alphas(epsilon); mask_stack.push_sumcheck_masks(mask_handles, gammas, mask_targets);
for round in &self.rounds {
let (next_constraint, next_folded) =
round.verify(transcript, &folded_commitment, constraint, &mut mask_stack)?;
constraint = next_constraint;
folded_commitment = next_folded;
}
verify_base_case(
transcript,
self.final_base_case.queries,
self.final_base_case.mask_queries,
folded_commitment,
constraint,
&mask_stack.oracles,
&mask_stack.constraints,
)
}
}
/// Prover-side commitment scheme for an inner message length of `inner_msg_len`.
///
/// The code is a Reed–Solomon code interleaved `INTERLEAVING` times. The Merkle
/// VC is sized `inner_msg_len * RATE_INV`.
fn make_commitment(inner_msg_len: usize) -> ProverCommit {
    let vc_size = inner_msg_len * RATE_INV;
    CodeCommitment::new(
        Arc::new(InterleavedCode::new(
            ReedSolomon::<Goldilocks4>::new(inner_msg_len),
            INTERLEAVING,
        )),
        Arc::new(MerkleVc::new(vc_size)),
    )
}
/// Verifier-side handle with the same code/VC parameters as `make_commitment`.
///
/// The root is all-zero (`Default`) as a placeholder. Callers replace it with a
/// root read from the transcript.
fn make_explicit_handle(inner_msg_len: usize) -> VerifierExplicit {
    let vc_size = inner_msg_len * RATE_INV;
    ExplicitCodeCommitmentHandle::new(
        Arc::new(InterleavedCode::new(
            ReedSolomon::<Goldilocks4>::new(inner_msg_len),
            INTERLEAVING,
        )),
        Arc::new(MerkleVc::new(vc_size)),
        <[u8; 32]>::default(),
    )
}
/// Prover state returned by `commit_only` on the initial commitment. It is cached
/// in `WhirSignerState` so that `prove` can reuse it.
pub(crate) type ProverCommitState = super::commitment::CodeCommitmentProverState<
    super::code::InterleavedCode<super::code::ReedSolomon<Goldilocks4>>,
    super::vc::MerkleVc,
>;
/// Signer-side state produced by `ConcreteWhirProtocol::commit` and consumed by
/// `ConcreteWhirProtocol::prove`.
pub(crate) struct WhirSignerState {
    /// The committed vector: the caller's message followed by `derive_r_zk` output.
    pub(crate) internal: Vec<Goldilocks4>,
    /// Cached initial commitment state; `prove` panics if this is `None`.
    pub(crate) initial_state: Option<ProverCommitState>,
}
/// Presents a linear form over `2^nu` entries as a form over `2^nu_prime`
/// entries. The inner form occupies the low `2^nu` slots and the rest are zero.
pub(crate) struct MessageEmbeddedHandle<F: Field> {
    /// The unpadded form, of size `2^nu`.
    inner: Box<dyn LinearFormHandle<Alphabet = F>>,
    /// log2 of the inner form's size.
    nu: u32,
    /// log2 of the embedded form's size; expected to be at least `nu`.
    nu_prime: u32,
}
impl<F: Field> LinearFormHandle for MessageEmbeddedHandle<F> {
    type Alphabet = F;

    fn form_size(&self) -> usize {
        1usize << self.nu_prime
    }

    /// Folds the embedded form with `rand`, most significant variable first:
    /// `v'[k] = v[k] + (v[k + half] - v[k]) * r`.
    ///
    /// The first `nu_prime - nu` challenges hit variables whose upper half is all
    /// zeros, so each one only scales the vector by `1 - r`. Any remaining
    /// challenges are forwarded to the inner form.
    fn folded_form(&self, rand: &[Self::Alphabet]) -> Vec<Self::Alphabet> {
        let padding_vars = (self.nu_prime - self.nu) as usize;
        let (padding_rand, inner_rand) = rand.split_at(rand.len().min(padding_vars));
        let scale = padding_rand
            .iter()
            .fold(F::ONE, |acc, &r_i| acc * (F::ONE - r_i));
        let scaled: Vec<F> = self
            .inner
            .folded_form(inner_rand)
            .iter()
            .map(|&x| scale * x)
            .collect();
        if padding_rand.len() == padding_vars {
            // Every padding variable has been folded away: only the inner part remains.
            return scaled;
        }
        // Some padding variables are still live, so zeros remain above the
        // scaled inner block.
        let mut out = vec![F::ZERO; 1usize << (self.nu_prime as usize - rand.len())];
        out[..scaled.len()].copy_from_slice(&scaled);
        out
    }
}
/// Expands `sigma` into `count` field elements with
/// `crate::hash::shake_field_elements`.
///
/// `JV_RZK` is passed as the first argument, presumably as a domain-separation
/// tag. `ConcreteWhirProtocol::commit` appends the output to the message as padding.
fn derive_r_zk(sigma: &[u8; 32], count: usize) -> Vec<Goldilocks4> {
    crate::hash::shake_field_elements(JV_RZK, &[sigma], count)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alpha::BatchedAlpha;
    use crate::field::Goldilocks;
    use crate::lift::MonomialLift;

    /// Lifts a small integer into `Goldilocks4` as the element `(n, 0, 0, 0)`.
    fn g(n: u64) -> Goldilocks4 {
        let limbs = [
            Goldilocks::new(n),
            Goldilocks::new(0),
            Goldilocks::new(0),
            Goldilocks::new(0),
        ];
        Goldilocks4::new(limbs)
    }

    /// Reference fold, most significant variable first. Each challenge `w` halves
    /// the vector via `v'[k] = v[k] + (v[k + half] - v[k]) * w`.
    fn manual_fold(v: Vec<Goldilocks4>, rand: &[Goldilocks4]) -> Vec<Goldilocks4> {
        rand.iter().fold(v, |cur, &w| {
            let (lo, hi) = cur.split_at(cur.len() / 2);
            lo.iter().zip(hi).map(|(&a, &b)| a + (b - a) * w).collect()
        })
    }

    #[test]
    fn embedded_handle_matches_zero_padded_explicit_fold() {
        let shapes: [(u32, u32); 5] = [(2, 4), (3, 5), (4, 6), (4, 9), (5, 7)];
        for (nu, nu_prime) in shapes {
            for k in [1usize, 2, 4] {
                let xs: Vec<Goldilocks4> = (7..7 + k as u64).map(g).collect();
                let betas: Vec<Goldilocks4> = (101..101 + k as u64).map(g).collect();
                let embedded = MessageEmbeddedHandle {
                    inner: Box::new(BatchedAlpha::new(&xs, betas.clone(), nu)),
                    nu,
                    nu_prime,
                };
                // Dense reference: sum_j beta_j * lift(x_j) in the low 2^nu slots of
                // an otherwise-zero vector of length 2^nu'.
                let m = 1usize << nu;
                let mut explicit = vec![Goldilocks4::ZERO; 1usize << nu_prime];
                for (x, &beta) in xs.iter().zip(betas.iter()) {
                    let lift = MonomialLift::new(*x, nu);
                    let u = lift.materialize();
                    assert_eq!(u.len(), m);
                    explicit
                        .iter_mut()
                        .take(m)
                        .zip(u.iter())
                        .for_each(|(a, &uk)| *a += beta * uk);
                }
                for r in 0..=nu_prime {
                    let rand: Vec<Goldilocks4> =
                        (0..r).map(|i| g(2000 + u64::from(i))).collect();
                    let symbolic = embedded.folded_form(&rand);
                    let ref_fold = manual_fold(explicit.clone(), &rand);
                    assert_eq!(symbolic, ref_fold, "ν={nu} ν'={nu_prime} K={k} R={r}");
                }
            }
        }
    }

    #[test]
    fn derive_r_zk_is_deterministic() {
        let sigma = [9u8; 32];
        let (first, second) = (derive_r_zk(&sigma, 100), derive_r_zk(&sigma, 100));
        assert_eq!(first, second);
        assert_eq!(first.len(), 100);
    }

    #[test]
    fn derive_r_zk_distinct_seeds_diverge() {
        assert_ne!(derive_r_zk(&[0u8; 32], 10), derive_r_zk(&[1u8; 32], 10));
    }
}