use crate::{
digest::{DigestComputer, SimpleDigestible},
errors::NovaError,
r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness, SparseMatrix},
spartan::{
compute_eval_table_sparse,
polys::{eq::EqPolynomial, multilinear::MultilinearPolynomial, multilinear::SparsePolynomial},
powers,
sumcheck::SumcheckProof,
PolyEvalInstance, PolyEvalWitness,
},
traits::{
evaluation::EvaluationEngineTrait,
snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait},
Engine, TranscriptEngineTrait,
},
CommitmentKey,
};
use ff::Field;
use itertools::Itertools as _;
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Prover key for [`RelaxedR1CSSNARK`].
///
/// Pairs the evaluation-engine prover key with the digest of the verifier key that was
/// generated alongside it; `prove` absorbs that digest into its transcript.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub struct ProverKey<E: Engine, EE: EvaluationEngineTrait<E>> {
  /// Prover key of the polynomial evaluation engine `EE`.
  pk_ee: EE::ProverKey,
  /// Digest of the matching verifier key (as returned by `VerifierKey::digest`).
  vk_digest: E::Scalar,
}
/// Verifier key for [`RelaxedR1CSSNARK`].
///
/// Holds the R1CS shape that proofs are checked against and the evaluation-engine verifier
/// key. Its digest is computed lazily and is never serialized.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub struct VerifierKey<E: Engine, EE: EvaluationEngineTrait<E>> {
  /// Verifier key of the polynomial evaluation engine `EE`.
  vk_ee: EE::VerifierKey,
  /// R1CS shape; when built through `setup` this is the padded shape.
  S: R1CSShape<E>,
  /// Cached digest of this key. Skipped by serde and recreated empty on deserialization.
  #[serde(skip, default = "OnceCell::new")]
  digest: OnceCell<E::Scalar>,
}
// Marker impl opting `VerifierKey` into `SimpleDigestible`; presumably this is the bound that
// `DigestComputer::new(self)` needs when the key's digest is computed.
impl<E, EE> SimpleDigestible for VerifierKey<E, EE>
where
  E: Engine,
  EE: EvaluationEngineTrait<E>,
{
}
impl<E: Engine, EE: EvaluationEngineTrait<E>> VerifierKey<E, EE> {
  /// Assembles a verifier key from an R1CS shape and an evaluation-engine verifier key.
  ///
  /// The digest cell starts empty and is filled on the first call to `digest()`.
  fn new(shape: R1CSShape<E>, vk_ee: EE::VerifierKey) -> Self {
    let empty_digest = OnceCell::new();
    Self {
      S: shape,
      vk_ee,
      digest: empty_digest,
    }
  }
}
impl<E: Engine, EE: EvaluationEngineTrait<E>> DigestHelperTrait<E> for VerifierKey<E, EE> {
  /// Returns the digest of this verifier key, computing it on first use and caching it.
  ///
  /// # Panics
  /// Panics if the underlying `DigestComputer` fails.
  fn digest(&self) -> E::Scalar {
    let cached = self
      .digest
      .get_or_try_init(|| DigestComputer::<E::Scalar, _>::new(self).digest());
    cached.cloned().expect("Failure to retrieve digest!")
  }
}
/// Proof that a relaxed R1CS instance is satisfied, as produced by `prove` and checked by
/// `verify`.
///
/// Field order is part of the serialized format and must not be changed.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct RelaxedR1CSSNARK<E: Engine, EE: EvaluationEngineTrait<E>> {
  /// Outer sum-check proof (claimed sum is zero).
  sc_proof_outer: SumcheckProof<E>,
  /// Claimed values `(Az, Bz, Cz)` at the outer sum-check point `r_x`.
  claims_outer: (E::Scalar, E::Scalar, E::Scalar),
  /// Claimed evaluation of the error vector `E` at `r_x`.
  eval_E: E::Scalar,
  /// Inner sum-check proof.
  sc_proof_inner: SumcheckProof<E>,
  /// Claimed evaluation of the witness `W` at `r_y[1..]`.
  eval_W: E::Scalar,
  /// Batching sum-check proof that merges the `W` and `E` evaluation claims.
  sc_proof_batch: SumcheckProof<E>,
  /// Per-claim values output by the batching sum-check (one for `W`, one for `E`, in that order).
  evals_batch: Vec<E::Scalar>,
  /// Evaluation argument for the single batched evaluation claim.
  eval_arg: EE::EvaluationArgument,
}
impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for RelaxedR1CSSNARK<E, EE> {
  type ProverKey = ProverKey<E, EE>;
  type VerifierKey = VerifierKey<E, EE>;

  /// Generates the prover and verifier keys for shape `S`.
  ///
  /// The shape is padded (`S.pad()`) before it is stored in the verifier key. The prover key
  /// records the verifier key's digest so that `prove` can bind it into the transcript.
  fn setup(
    ck: &CommitmentKey<E>,
    S: &R1CSShape<E>,
  ) -> Result<(Self::ProverKey, Self::VerifierKey), NovaError> {
    let (pk_ee, vk_ee) = EE::setup(ck);

    let S = S.pad();

    let vk: VerifierKey<E, EE> = VerifierKey::new(S, vk_ee);

    let pk = ProverKey {
      pk_ee,
      vk_digest: vk.digest(),
    };

    Ok((pk, vk))
  }

  /// Proves that `(U, W)` satisfies the relaxed R1CS shape `S`.
  ///
  /// The protocol runs, in order:
  /// 1. An outer sum-check (claim zero) over `eq(tau, x) * (Az(x) * Bz(x) - (u * Cz(x) + E(x)))`.
  /// 2. An inner sum-check reducing the random combination `Az + r*Bz + r^2*Cz` to a claim about `z`.
  /// 3. A batching sum-check merging the `W` and `E` evaluation claims into one.
  /// 4. A single evaluation-engine argument for the merged claim.
  ///
  /// Every transcript absorb/squeeze below must stay in lock-step with `verify`.
  ///
  /// # Errors
  /// Propagates errors from the transcript, `S.multiply_vec`, the sum-check provers and `EE::prove`.
  ///
  /// # Panics
  /// Panics if the padded shape is not regular (`S.is_regular_shape()`).
  fn prove(
    ck: &CommitmentKey<E>,
    pk: &Self::ProverKey,
    S: &R1CSShape<E>,
    U: &RelaxedR1CSInstance<E>,
    W: &RelaxedR1CSWitness<E>,
  ) -> Result<Self, NovaError> {
    let S = S.pad();
    // The round counts below use ilog2 of the dimensions, so the shape must be regular.
    assert!(S.is_regular_shape());

    let W = W.pad(&S);
    let mut transcript = E::TE::new(b"RelaxedR1CSSNARK");

    // Bind the verifier key (via its digest) and the instance into the transcript.
    transcript.absorb(b"vk", &pk.vk_digest);
    transcript.absorb(b"U", U);

    // Full assignment z = (W, u, X).
    let mut z = [W.W.clone(), vec![U.u], U.X.clone()].concat();

    // The extra round in y selects between the witness half and the (u, X) half of z.
    let (num_rounds_x, num_rounds_y) = (
      usize::try_from(S.num_cons.ilog2()).unwrap(),
      (usize::try_from(S.num_vars.ilog2()).unwrap() + 1),
    );

    // ---- Outer sum-check ----
    let tau = (0..num_rounds_x)
      .map(|_i| transcript.squeeze(b"t"))
      .collect::<Result<EqPolynomial<_>, NovaError>>()?;

    let mut poly_tau = MultilinearPolynomial::new(tau.evals());
    let (mut poly_Az, mut poly_Bz, poly_Cz, mut poly_uCz_E) = {
      let (poly_Az, poly_Bz, poly_Cz) = S.multiply_vec(&z)?;
      let poly_uCz_E = (0..S.num_cons)
        .map(|i| U.u * poly_Cz[i] + W.E[i])
        .collect::<Vec<E::Scalar>>();
      (
        MultilinearPolynomial::new(poly_Az),
        MultilinearPolynomial::new(poly_Bz),
        MultilinearPolynomial::new(poly_Cz),
        MultilinearPolynomial::new(poly_uCz_E),
      )
    };

    // tau * (Az * Bz - (u * Cz + E))
    let comb_func_outer =
      |poly_A_comp: &E::Scalar,
       poly_B_comp: &E::Scalar,
       poly_C_comp: &E::Scalar,
       poly_D_comp: &E::Scalar|
       -> E::Scalar { *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) };
    let (sc_proof_outer, r_x, claims_outer) = SumcheckProof::prove_cubic_with_additive_term(
      &E::Scalar::ZERO, // a satisfying assignment makes the outer sum zero
      num_rounds_x,
      &mut poly_tau,
      &mut poly_Az,
      &mut poly_Bz,
      &mut poly_uCz_E,
      comb_func_outer,
      &mut transcript,
    )?;

    // Index 0 belongs to poly_tau; indices 1 and 2 are the bound values of Az and Bz.
    let (claim_Az, claim_Bz): (E::Scalar, E::Scalar) = (claims_outer[1], claims_outer[2]);
    let claim_Cz = poly_Cz.evaluate(&r_x);
    let eval_E = MultilinearPolynomial::new(W.E.clone()).evaluate(&r_x);
    transcript.absorb(
      b"claims_outer",
      &[claim_Az, claim_Bz, claim_Cz, eval_E].as_slice(),
    );

    // ---- Inner sum-check ----
    let r = transcript.squeeze(b"r")?;
    let claim_inner_joint = claim_Az + r * claim_Bz + r * r * claim_Cz;

    // Evaluation table of (A + r*B + r^2*C)(r_x, ·).
    let poly_ABC = {
      let evals_rx = EqPolynomial::evals_from_points(&r_x);

      let (evals_A, evals_B, evals_C) = compute_eval_table_sparse(&S, &evals_rx);

      assert_eq!(evals_A.len(), evals_B.len());
      assert_eq!(evals_A.len(), evals_C.len());
      (0..evals_A.len())
        .into_par_iter()
        .map(|i| evals_A[i] + r * evals_B[i] + r * r * evals_C[i])
        .collect::<Vec<E::Scalar>>()
    };

    // Zero-pad z to 2 * num_vars entries to match num_rounds_y variables.
    let poly_z = {
      z.resize(S.num_vars * 2, E::Scalar::ZERO);
      z
    };

    let comb_func = |poly_A_comp: &E::Scalar, poly_B_comp: &E::Scalar| -> E::Scalar {
      *poly_A_comp * *poly_B_comp
    };
    let (sc_proof_inner, r_y, _claims_inner) = SumcheckProof::prove_quad(
      &claim_inner_joint,
      num_rounds_y,
      &mut MultilinearPolynomial::new(poly_ABC),
      &mut MultilinearPolynomial::new(poly_z),
      comb_func,
      &mut transcript,
    )?;

    // ---- Batch the W and E evaluation claims ----
    // W is claimed at r_y[1..] (the witness half of z), E at r_x. Both are reduced to one claim.
    let eval_W = MultilinearPolynomial::evaluate_with(&W.W, &r_y[1..]);

    let w_vec = vec![PolyEvalWitness { p: W.W }, PolyEvalWitness { p: W.E }];
    let u_vec = vec![
      PolyEvalInstance {
        c: U.comm_W,
        x: r_y[1..].to_vec(),
        e: eval_W,
      },
      PolyEvalInstance {
        c: U.comm_E,
        x: r_x,
        e: eval_E,
      },
    ];

    let (batched_u, batched_w, sc_proof_batch, claims_batch_left) =
      batch_eval_prove(u_vec, w_vec, &mut transcript)?;

    let eval_arg = EE::prove(
      ck,
      &pk.pk_ee,
      &mut transcript,
      &batched_u.c,
      &batched_w.p,
      &batched_u.x,
      &batched_u.e,
    )?;

    Ok(RelaxedR1CSSNARK {
      sc_proof_outer,
      claims_outer: (claim_Az, claim_Bz, claim_Cz),
      eval_E,
      sc_proof_inner,
      eval_W,
      sc_proof_batch,
      evals_batch: claims_batch_left,
      eval_arg,
    })
  }

  /// Verifies this proof against verifier key `vk` and instance `U`.
  ///
  /// The verifier replays the prover's transcript operations in the same order, checks the final
  /// claims of the outer and inner sum-checks, and then checks the batched evaluation argument.
  ///
  /// # Errors
  /// Returns `NovaError::InvalidSumcheckProof` when a sum-check final claim does not match.
  /// Also propagates errors from the transcript, the sum-check verifiers, `batch_eval_verify`
  /// and `EE::verify`.
  fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance<E>) -> Result<(), NovaError> {
    let mut transcript = E::TE::new(b"RelaxedR1CSSNARK");

    // Bind the verifier key digest and the instance, mirroring `prove`.
    transcript.absorb(b"vk", &vk.digest());
    transcript.absorb(b"U", U);

    let (num_rounds_x, num_rounds_y) = (
      usize::try_from(vk.S.num_cons.ilog2()).unwrap(),
      (usize::try_from(vk.S.num_vars.ilog2()).unwrap() + 1),
    );

    // ---- Outer sum-check ----
    let tau = (0..num_rounds_x)
      .map(|_i| transcript.squeeze(b"t"))
      .collect::<Result<EqPolynomial<_>, NovaError>>()?;

    let (claim_outer_final, r_x) =
      self
        .sc_proof_outer
        .verify(E::Scalar::ZERO, num_rounds_x, 3, &mut transcript)?;

    // The final outer claim must equal tau(r_x) * (Az * Bz - u * Cz - E) at r_x.
    let (claim_Az, claim_Bz, claim_Cz) = self.claims_outer;
    let taus_bound_rx = tau.evaluate(&r_x);
    let claim_outer_final_expected =
      taus_bound_rx * (claim_Az * claim_Bz - U.u * claim_Cz - self.eval_E);
    if claim_outer_final != claim_outer_final_expected {
      return Err(NovaError::InvalidSumcheckProof);
    }

    transcript.absorb(
      b"claims_outer",
      &[claim_Az, claim_Bz, claim_Cz, self.eval_E].as_slice(),
    );

    // ---- Inner sum-check ----
    let r = transcript.squeeze(b"r")?;
    let claim_inner_joint = claim_Az + r * claim_Bz + r * r * claim_Cz;

    let (claim_inner_final, r_y) =
      self
        .sc_proof_inner
        .verify(claim_inner_joint, num_rounds_y, 2, &mut transcript)?;

    // z(r_y) = (1 - r_y[0]) * W(r_y[1..]) + r_y[0] * (u, X)(r_y[1..])
    let eval_Z = {
      let eval_X = {
        // Public IO is (u, X): entry 0 is u, entries 1.. are X.
        let poly_X: Vec<(usize, E::Scalar)> = std::iter::once((0, U.u))
          .chain(U.X.iter().enumerate().map(|(i, x_i)| (i + 1, *x_i)))
          .collect();
        SparsePolynomial::new(usize::try_from(vk.S.num_vars.ilog2()).unwrap(), poly_X)
          .evaluate(&r_y[1..])
      };
      (E::Scalar::ONE - r_y[0]) * self.eval_W + r_y[0] * eval_X
    };

    // Evaluates each sparse matrix M at (r_x, r_y) as sum over nonzeros of
    // eq(r_x, row) * eq(r_y, col) * M[row][col].
    let multi_evaluate = |M_vec: &[&SparseMatrix<E::Scalar>],
                          r_x: &[E::Scalar],
                          r_y: &[E::Scalar]|
     -> Vec<E::Scalar> {
      let evaluate_with_table =
        |M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
          M.indptr
            .par_windows(2)
            .enumerate()
            .map(|(row_idx, ptrs)| {
              M.get_row_unchecked(ptrs.try_into().unwrap())
                .map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
                .sum::<E::Scalar>()
            })
            .sum()
        };

      let (T_x, T_y) = rayon::join(
        || EqPolynomial::evals_from_points(r_x),
        || EqPolynomial::evals_from_points(r_y),
      );

      (0..M_vec.len())
        .into_par_iter()
        .map(|i| evaluate_with_table(M_vec[i], &T_x, &T_y))
        .collect()
    };

    let evals = multi_evaluate(&[&vk.S.A, &vk.S.B, &vk.S.C], &r_x, &r_y);

    let claim_inner_final_expected = (evals[0] + r * evals[1] + r * r * evals[2]) * eval_Z;
    if claim_inner_final != claim_inner_final_expected {
      return Err(NovaError::InvalidSumcheckProof);
    }

    // ---- Batched W / E evaluation claims ----
    // The order must match the prover: W at r_y[1..], then E at r_x.
    let u_vec: Vec<PolyEvalInstance<E>> = vec![
      PolyEvalInstance {
        c: U.comm_W,
        x: r_y[1..].to_vec(),
        e: self.eval_W,
      },
      PolyEvalInstance {
        c: U.comm_E,
        x: r_x,
        e: self.eval_E,
      },
    ];

    let batched_u = batch_eval_verify(
      u_vec,
      &mut transcript,
      &self.sc_proof_batch,
      &self.evals_batch,
    )?;

    EE::verify(
      &vk.vk_ee,
      &mut transcript,
      &batched_u.c,
      &batched_u.x,
      &batched_u.e,
      &self.eval_arg,
    )?;

    Ok(())
  }
}
/// Reduces several polynomial-evaluation claims, possibly over different numbers of variables,
/// to a single claim.
///
/// Each input pair `(u_vec[i], w_vec[i])` claims `Pᵢ(xᵢ) = eᵢ`, where `Pᵢ = w_vec[i].p` has
/// `2^nᵢ` evaluations and `nᵢ = u_vec[i].x.len()`. The function:
/// 1. squeezes a challenge `ρ` and runs a batched quadratic sum-check over `Pᵢ · eq(xᵢ, ·)`,
///    weighted by powers of `ρ`;
/// 2. absorbs the per-claim values output by that sum-check (label `b"l"`);
/// 3. squeezes `γ` and folds everything into one instance/witness with `batch_diff_size`.
///
/// Returns the joint instance, the joint witness, the batching sum-check proof, and the
/// per-claim sum-check outputs, which the verifier receives as `evals_batch`.
///
/// # Errors
/// Propagates transcript and sum-check errors.
///
/// # Panics
/// Panics if `w_vec.len() != u_vec.len()` or if some `w_vec[i].p.len() != 2^nᵢ`.
pub(in crate::spartan) fn batch_eval_prove<E: Engine>(
  u_vec: Vec<PolyEvalInstance<E>>,
  w_vec: Vec<PolyEvalWitness<E>>,
  transcript: &mut E::TE,
) -> Result<
  (
    PolyEvalInstance<E>,
    PolyEvalWitness<E>,
    SumcheckProof<E>,
    Vec<E::Scalar>,
  ),
  NovaError,
> {
  let claim_count = u_vec.len();
  assert_eq!(w_vec.len(), claim_count);

  // nᵢ for every claim.
  let num_rounds: Vec<usize> = u_vec.iter().map(|inst| inst.x.len()).collect();

  // Every witness polynomial must carry exactly 2^nᵢ evaluations.
  for (wit, &n_i) in w_vec.iter().zip_eq(num_rounds.iter()) {
    assert_eq!(wit.p.len(), 1 << n_i);
  }

  // Challenge for the random linear combination of the claims.
  let rho = transcript.squeeze(b"r")?;
  let powers_of_rho = powers::<E>(&rho, claim_count);

  // Split the instances into claimed values, commitments and eq(xᵢ, ·) tables.
  let mut claims = Vec::with_capacity(claim_count);
  let mut comms = Vec::with_capacity(claim_count);
  let mut polys_eq: Vec<MultilinearPolynomial<E::Scalar>> = Vec::with_capacity(claim_count);
  for inst in u_vec {
    claims.push(inst.e);
    comms.push(inst.c);
    polys_eq.push(MultilinearPolynomial::new(
      EqPolynomial::evals_from_points(&inst.x),
    ));
  }

  // The sum-check consumes its polynomials, and w_vec is still needed for the joint witness,
  // so copies of the witness polynomials are handed over here.
  let mut polys_P: Vec<MultilinearPolynomial<E::Scalar>> = Vec::with_capacity(claim_count);
  for wit in w_vec.iter() {
    polys_P.push(MultilinearPolynomial::new(wit.p.clone()));
  }

  // Checks eᵢ = Σₓ Pᵢ(x) · eq(xᵢ, x) for every i, batched.
  let product = |p_val: &E::Scalar, eq_val: &E::Scalar| -> E::Scalar { *p_val * *eq_val };
  let (sc_proof_batch, r, (claims_batch_left, _)): (_, _, (Vec<E::Scalar>, Vec<E::Scalar>)) =
    SumcheckProof::prove_quad_batch(
      &claims,
      &num_rounds,
      polys_P,
      polys_eq,
      &powers_of_rho,
      product,
      transcript,
    )?;

  transcript.absorb(b"l", &claims_batch_left.as_slice());

  // Challenge for folding all claims at the common point r into one.
  let gamma = transcript.squeeze(b"g")?;

  let u_joint =
    PolyEvalInstance::batch_diff_size(&comms, &claims_batch_left, &num_rounds, r, gamma);
  let w_joint = PolyEvalWitness::batch_diff_size(w_vec, gamma);

  Ok((u_joint, w_joint, sc_proof_batch, claims_batch_left))
}
/// Verifier side of `batch_eval_prove`: checks the batching sum-check and folds the claims in
/// `u_vec` into a single evaluation instance.
///
/// `evals_batch` holds the per-claim values the prover sent (taken from the proof), in the same
/// order as `u_vec`. The transcript operations mirror the prover's:
/// squeeze `ρ` (`b"r"`), run the sum-check, absorb `evals_batch` (`b"l"`), then squeeze `γ` (`b"g"`).
///
/// # Errors
/// Returns `NovaError::InvalidSumcheckProof` if `evals_batch` does not contain exactly one value
/// per claim, or if the sum-check's final claim does not match the recomputed value. Also
/// propagates transcript and sum-check verification errors.
///
/// # Panics
/// Panics if `u_vec` is empty (callers always pass at least one claim).
pub(in crate::spartan) fn batch_eval_verify<E: Engine>(
  u_vec: Vec<PolyEvalInstance<E>>,
  transcript: &mut E::TE,
  sc_proof_batch: &SumcheckProof<E>,
  evals_batch: &[E::Scalar],
) -> Result<PolyEvalInstance<E>, NovaError> {
  let num_claims = u_vec.len();
  // `evals_batch` is read from a deserialized proof. A wrong length means the proof is
  // malformed, so it is rejected with an error instead of panicking the verifier.
  if evals_batch.len() != num_claims {
    return Err(NovaError::InvalidSumcheckProof);
  }

  // Challenge for the random linear combination, matching the prover.
  let rho = transcript.squeeze(b"r")?;
  let powers_of_rho = powers::<E>(&rho, num_claims);

  // nᵢ for each claim and n = maxᵢ nᵢ.
  let num_rounds = u_vec.iter().map(|u| u.x.len()).collect::<Vec<_>>();
  let num_rounds_max = *num_rounds
    .iter()
    .max()
    .expect("batch_eval_verify requires at least one claim");

  let claims = u_vec.iter().map(|u| u.e).collect::<Vec<_>>();

  let (claim_batch_final, r) =
    sc_proof_batch.verify_batch(&claims, &num_rounds, &powers_of_rho, 2, transcript)?;

  // Expected final claim: Σᵢ ρⁱ · evals_batch[i] · eq(xᵢ, r_hi), where r_hi holds the last nᵢ
  // coordinates of r.
  let claim_batch_final_expected = {
    let evals_r = u_vec.iter().map(|u| {
      let (_, r_hi) = r.split_at(num_rounds_max - u.x.len());
      EqPolynomial::new(r_hi.to_vec()).evaluate(&u.x)
    });

    zip_with!(
      (evals_r, evals_batch.iter(), powers_of_rho.iter()),
      |e_i, p_i, rho_i| e_i * *p_i * rho_i
    )
    .sum()
  };

  if claim_batch_final != claim_batch_final_expected {
    return Err(NovaError::InvalidSumcheckProof);
  }

  transcript.absorb(b"l", &evals_batch);

  // Challenge for folding all claims at the common point r into one.
  let gamma = transcript.squeeze(b"g")?;

  let comms = u_vec.into_iter().map(|u| u.c).collect::<Vec<_>>();

  let u_joint = PolyEvalInstance::batch_diff_size(&comms, evals_batch, &num_rounds, r, gamma);

  Ok(u_joint)
}