use crate::{
digest::{DigestComputer, SimpleDigestible},
errors::NovaError,
r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness, SparseMatrix},
spartan::{
compute_eval_table_sparse,
math::Math,
polys::{eq::EqPolynomial, multilinear::MultilinearPolynomial, multilinear::SparsePolynomial},
sumcheck::SumcheckProof,
PolyEvalInstance, PolyEvalWitness,
},
traits::{
evaluation::EvaluationEngineTrait,
snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait},
Engine, TranscriptEngineTrait,
},
CommitmentKey,
};
use ff::Field;
use once_cell::sync::OnceCell;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Prover key for [`RelaxedR1CSSNARK`], produced by `setup`.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub struct ProverKey<E: Engine, EE: EvaluationEngineTrait<E>> {
  /// Prover key of the polynomial evaluation engine `EE` (from `EE::setup`).
  pk_ee: EE::ProverKey,
  /// Digest of the matching verifier key; `prove` absorbs it into its transcript.
  vk_digest: E::Scalar,
}
/// Verifier key for [`RelaxedR1CSSNARK`], produced by `setup`.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub struct VerifierKey<E: Engine, EE: EvaluationEngineTrait<E>> {
  /// Verifier key of the polynomial evaluation engine `EE` (from `EE::setup`).
  vk_ee: EE::VerifierKey,
  /// The R1CS shape; `setup` stores the padded shape here.
  S: R1CSShape<E>,
  /// Lazily computed digest of this key. It is skipped by serde, so a deserialized key
  /// starts with an empty cell and recomputes the digest on the next `digest()` call.
  #[serde(skip, default = "OnceCell::new")]
  digest: OnceCell<E::Scalar>,
}
// Empty `SimpleDigestible` impl for `VerifierKey`; presumably this is what lets
// `DigestComputer::new(self)` in `digest()` accept a verifier key — confirm against `digest.rs`.
impl<E: Engine, EE: EvaluationEngineTrait<E>> SimpleDigestible for VerifierKey<E, EE> {}
impl<E: Engine, EE: EvaluationEngineTrait<E>> VerifierKey<E, EE> {
  /// Bundles an R1CS shape with the evaluation engine's verifier key.
  ///
  /// `setup` passes the already-padded shape. The digest cell starts empty and is filled
  /// lazily on the first call to `digest()`.
  fn new(shape: R1CSShape<E>, vk_ee: EE::VerifierKey) -> Self {
    let digest = OnceCell::new();
    Self {
      S: shape,
      vk_ee,
      digest,
    }
  }
}
impl<E: Engine, EE: EvaluationEngineTrait<E>> DigestHelperTrait<E> for VerifierKey<E, EE> {
  /// Returns the digest of this verifier key.
  ///
  /// The digest is computed with `DigestComputer` on first use and cached in `self.digest`,
  /// so later calls return the cached value.
  ///
  /// # Panics
  /// Panics with "Failure to retrieve digest!" if computing the digest fails; in that case
  /// nothing is cached.
  fn digest(&self) -> E::Scalar {
    let cached = self
      .digest
      .get_or_try_init(|| DigestComputer::<E::Scalar, _>::new(self).digest())
      .expect("Failure to retrieve digest!");
    // Field elements are `Copy`, so dereferencing yields an owned value.
    *cached
  }
}
/// A proof that a relaxed R1CS instance is satisfiable, produced by `prove` and checked by
/// `verify`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct RelaxedR1CSSNARK<E: Engine, EE: EvaluationEngineTrait<E>> {
  /// Outer sum-check proof (degree 3, `log2(num_cons)` rounds); fixes the point `r_x`.
  sc_proof_outer: SumcheckProof<E>,
  /// Claimed evaluations `(Az, Bz, Cz)` at `r_x`.
  claims_outer: (E::Scalar, E::Scalar, E::Scalar),
  /// Claimed evaluation of the error vector `E` at `r_x`.
  eval_E: E::Scalar,
  /// Inner sum-check proof (degree 2); fixes the point `r_y`.
  sc_proof_inner: SumcheckProof<E>,
  /// Claimed evaluation of the witness `W` at `r_y[1..]`.
  eval_W: E::Scalar,
  /// Sum-check proof from batching the `W` and `E` evaluation claims into one.
  sc_proof_batch: SumcheckProof<E>,
  /// Evaluations returned by the batching step (`claims_batch_left` in `prove`).
  evals_batch: Vec<E::Scalar>,
  /// Evaluation argument for the batched claim, produced by `EE::prove`.
  eval_arg: EE::EvaluationArgument,
}
/// Spartan-style SNARK for relaxed R1CS.
///
/// The protocol runs in three stages:
/// 1. An outer sum-check reduces satisfiability to claimed evaluations of `Az`, `Bz`, `Cz`
///    and `E` at a random point `r_x`.
/// 2. An inner sum-check reduces the random combination `Az + r*Bz + r^2*Cz` to an
///    evaluation of `z = (W, u, X)` at a point `r_y`.
/// 3. The remaining claims about the committed `W` and `E` are batched into a single
///    evaluation claim, which is proven with `EE`.
impl<E: Engine, EE: EvaluationEngineTrait<E>> RelaxedR1CSSNARKTrait<E> for RelaxedR1CSSNARK<E, EE> {
  type ProverKey = ProverKey<E, EE>;
  type VerifierKey = VerifierKey<E, EE>;

  /// Produces the prover and verifier keys for shape `S`.
  ///
  /// The verifier key stores the padded shape. The prover key stores only the evaluation
  /// engine's prover key and the verifier key's digest.
  ///
  /// # Errors
  /// Propagates any error returned by `EE::setup`.
  fn setup(
    ck: &CommitmentKey<E>,
    S: &R1CSShape<E>,
  ) -> Result<(Self::ProverKey, Self::VerifierKey), NovaError> {
    let (pk_ee, vk_ee) = EE::setup(ck)?;

    // `prove` pads the shape in the same way, so both sides agree on dimensions.
    let S = S.pad();

    let vk: VerifierKey<E, EE> = VerifierKey::new(S, vk_ee);

    let pk = ProverKey {
      pk_ee,
      vk_digest: vk.digest(),
    };

    Ok((pk, vk))
  }

  /// Proves that `(U, W)` satisfies the relaxed R1CS shape `S`.
  ///
  /// `S` and `W` are padded before use, matching the padding applied in `setup`.
  ///
  /// # Panics
  /// Panics if the padded shape is not regular (`S.is_regular_shape()` returns false).
  ///
  /// # Errors
  /// Propagates errors from the transcript, `S.multiply_vec`, the sum-check provers, the
  /// batching step (`batch_eval_reduce`) and `EE::prove`.
  fn prove(
    ck: &CommitmentKey<E>,
    pk: &Self::ProverKey,
    S: &R1CSShape<E>,
    U: &RelaxedR1CSInstance<E>,
    W: &RelaxedR1CSWitness<E>,
  ) -> Result<Self, NovaError> {
    // Pad the shape, then pad the witness to the padded shape.
    let S = S.pad();
    assert!(S.is_regular_shape());
    let W = W.pad(&S);

    // Fiat-Shamir transcript. Every absorb/squeeze below must happen in the same order,
    // with the same labels, as in `verify`.
    let mut transcript = E::TE::new(b"RelaxedR1CSSNARK");
    transcript.absorb(b"vk", &pk.vk_digest);
    transcript.absorb(b"U", U);

    // z = (W, u, X)
    let mut z = [W.W.clone(), vec![U.u], U.X.clone()].concat();

    // The inner sum-check has one extra round because z is resized to 2 * num_vars entries
    // before it runs.
    let (num_rounds_x, num_rounds_y) = (
      usize::try_from(S.num_cons.ilog2()).unwrap(),
      (usize::try_from(S.num_vars.ilog2()).unwrap() + 1),
    );

    // ---- outer sum-check ----
    let tau = (0..num_rounds_x)
      .map(|_i| transcript.squeeze(b"t"))
      .collect::<Result<Vec<_>, NovaError>>()?;

    let (mut poly_Az, mut poly_Bz, poly_Cz, mut poly_uCz_E) = {
      let (poly_Az, poly_Bz, poly_Cz) = S.multiply_vec(&z)?;
      // u * Cz + E: the right-hand side of the relaxed R1CS relation Az * Bz = u * Cz + E.
      let poly_uCz_E = (0..S.num_cons)
        .map(|i| U.u * poly_Cz[i] + W.E[i])
        .collect::<Vec<E::Scalar>>();
      (
        MultilinearPolynomial::new(poly_Az),
        MultilinearPolynomial::new(poly_Bz),
        MultilinearPolynomial::new(poly_Cz),
        MultilinearPolynomial::new(poly_uCz_E),
      )
    };

    let (sc_proof_outer, r_x, claims_outer) = SumcheckProof::prove_cubic_with_three_inputs(
      &E::Scalar::ZERO, // the claimed sum is zero
      tau,
      &mut poly_Az,
      &mut poly_Bz,
      &mut poly_uCz_E,
      &mut transcript,
    )?;

    // Claims left at the end of the outer sum-check.
    let (claim_Az, claim_Bz): (E::Scalar, E::Scalar) = (claims_outer[0], claims_outer[1]);
    let claim_Cz = poly_Cz.evaluate(&r_x);
    let eval_E = MultilinearPolynomial::new(W.E.clone()).evaluate(&r_x);
    transcript.absorb(
      b"claims_outer",
      &[claim_Az, claim_Bz, claim_Cz, eval_E].as_slice(),
    );

    // ---- inner sum-check over a random linear combination of the three claims ----
    let r = transcript.squeeze(b"r")?;
    let claim_inner_joint = claim_Az + r * claim_Bz + r * r * claim_Cz;

    let poly_ABC = {
      // Evaluation table of (A + r*B + r^2*C)(r_x, y) over all y.
      // Borrowing r_x is enough here; it is moved only later, into the batching instance.
      let evals_rx = EqPolynomial::evals_from_points(&r_x);
      let (evals_A, evals_B, evals_C) = compute_eval_table_sparse(&S, &evals_rx);

      assert_eq!(evals_A.len(), evals_B.len());
      assert_eq!(evals_A.len(), evals_C.len());
      (0..evals_A.len())
        .into_par_iter()
        .map(|i| evals_A[i] + r * evals_B[i] + r * r * evals_C[i])
        .collect::<Vec<E::Scalar>>()
    };

    // Zero-pad z to 2 * num_vars entries.
    let poly_z = {
      z.resize(S.num_vars * 2, E::Scalar::ZERO);
      z
    };

    let (sc_proof_inner, r_y, _claims_inner) = SumcheckProof::prove_quad_prod(
      &claim_inner_joint,
      num_rounds_y,
      &mut MultilinearPolynomial::new(poly_ABC),
      &mut MultilinearPolynomial::new(poly_z),
      &mut transcript,
    )?;

    // The verifier rebuilds z(r_y) as (1 - r_y[0]) * W(r_y[1..]) + r_y[0] * (u, X)(r_y[1..]),
    // so the prover supplies W's evaluation at r_y[1..].
    let eval_W = MultilinearPolynomial::evaluate_with(&W.W, &r_y[1..]);

    // Reduce the two claims, W(r_y[1..]) = eval_W and E(r_x) = eval_E, to one batched
    // claim that a single EE evaluation argument can prove.
    let w_vec = vec![PolyEvalWitness { p: W.W }, PolyEvalWitness { p: W.E }];
    let u_vec = vec![
      PolyEvalInstance {
        c: U.comm_W,
        x: r_y[1..].to_vec(),
        e: eval_W,
      },
      PolyEvalInstance {
        c: U.comm_E,
        x: r_x,
        e: eval_E,
      },
    ];

    let (batched_u, batched_w, _chal, sc_proof_batch, claims_batch_left) =
      super::batch_eval_reduce(u_vec, w_vec, &mut transcript)?;

    let eval_arg = EE::prove(
      ck,
      &pk.pk_ee,
      &mut transcript,
      &batched_u.c,
      &batched_w.p,
      &batched_u.x,
      &batched_u.e,
    )?;

    Ok(RelaxedR1CSSNARK {
      sc_proof_outer,
      claims_outer: (claim_Az, claim_Bz, claim_Cz),
      eval_E,
      sc_proof_inner,
      eval_W,
      sc_proof_batch,
      evals_batch: claims_batch_left,
      eval_arg,
    })
  }

  /// Verifies a proof produced by `prove` for instance `U` against verifier key `vk`.
  ///
  /// # Errors
  /// Returns `NovaError::InvalidSumcheckProof` if the final claim of either sum-check does
  /// not match the value recomputed from the proof's claimed evaluations. Also propagates
  /// errors from the transcript, the sum-check verifiers, `batch_eval_verify` and
  /// `EE::verify`.
  fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance<E>) -> Result<(), NovaError> {
    // Replays the prover's transcript: same labels, same order.
    let mut transcript = E::TE::new(b"RelaxedR1CSSNARK");
    transcript.absorb(b"vk", &vk.digest());
    transcript.absorb(b"U", U);

    let (num_rounds_x, num_rounds_y) = (
      usize::try_from(vk.S.num_cons.ilog2()).unwrap(),
      (usize::try_from(vk.S.num_vars.ilog2()).unwrap() + 1),
    );

    // ---- outer sum-check ----
    let tau = (0..num_rounds_x)
      .map(|_i| transcript.squeeze(b"t"))
      .collect::<Result<EqPolynomial<_>, NovaError>>()?;

    let (claim_outer_final, r_x) =
      self
        .sc_proof_outer
        .verify(E::Scalar::ZERO, num_rounds_x, 3, &mut transcript)?;

    // The outer sum-check must end at eq(tau, r_x) * (Az * Bz - u * Cz - E).
    let (claim_Az, claim_Bz, claim_Cz) = self.claims_outer;
    let taus_bound_rx = tau.evaluate(&r_x);
    let claim_outer_final_expected =
      taus_bound_rx * (claim_Az * claim_Bz - U.u * claim_Cz - self.eval_E);
    if claim_outer_final != claim_outer_final_expected {
      return Err(NovaError::InvalidSumcheckProof);
    }

    transcript.absorb(
      b"claims_outer",
      &[
        self.claims_outer.0,
        self.claims_outer.1,
        self.claims_outer.2,
        self.eval_E,
      ]
      .as_slice(),
    );

    // ---- inner sum-check ----
    let r = transcript.squeeze(b"r")?;
    let claim_inner_joint =
      self.claims_outer.0 + r * self.claims_outer.1 + r * r * self.claims_outer.2;

    let (claim_inner_final, r_y) =
      self
        .sc_proof_inner
        .verify(claim_inner_joint, num_rounds_y, 2, &mut transcript)?;

    // z(r_y): r_y[0] selects between the witness half (claimed eval_W) and the
    // public-IO half (u, X), which the verifier evaluates itself.
    let eval_Z = {
      let eval_X = {
        let X = std::iter::once(U.u)
          .chain(U.X.iter().copied())
          .collect::<Vec<E::Scalar>>();
        SparsePolynomial::new(vk.S.num_vars.log_2(), X).evaluate(&r_y[1..])
      };
      (E::Scalar::ONE - r_y[0]) * self.eval_W + r_y[0] * eval_X
    };

    // Evaluates each sparse matrix M at (r_x, r_y), i.e. sum_{(i, j)} eq(r_x, i) * eq(r_y, j) * M[i][j],
    // using the eq tables for r_x and r_y.
    let multi_evaluate = |M_vec: &[&SparseMatrix<E::Scalar>],
                          r_x: &[E::Scalar],
                          r_y: &[E::Scalar]|
     -> Vec<E::Scalar> {
      let evaluate_with_table =
        |M: &SparseMatrix<E::Scalar>, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar {
          M.indptr
            .par_windows(2)
            .enumerate()
            .map(|(row_idx, ptrs)| {
              M.get_row_unchecked(ptrs.try_into().unwrap())
                .map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val)
                .sum::<E::Scalar>()
            })
            .sum()
        };

      let (T_x, T_y) = rayon::join(
        || EqPolynomial::evals_from_points(r_x),
        || EqPolynomial::evals_from_points(r_y),
      );

      (0..M_vec.len())
        .into_par_iter()
        .map(|i| evaluate_with_table(M_vec[i], &T_x, &T_y))
        .collect()
    };

    // The inner sum-check must end at (A + r*B + r^2*C)(r_x, r_y) * z(r_y).
    let evals = multi_evaluate(&[&vk.S.A, &vk.S.B, &vk.S.C], &r_x, &r_y);
    let claim_inner_final_expected = (evals[0] + r * evals[1] + r * r * evals[2]) * eval_Z;
    if claim_inner_final != claim_inner_final_expected {
      return Err(NovaError::InvalidSumcheckProof);
    }

    // Same two claims as the prover batched, now checked against the commitments in U.
    let u_vec: Vec<PolyEvalInstance<E>> = vec![
      PolyEvalInstance {
        c: U.comm_W,
        x: r_y[1..].to_vec(),
        e: self.eval_W,
      },
      PolyEvalInstance {
        c: U.comm_E,
        x: r_x,
        e: self.eval_E,
      },
    ];

    let (batched_u, _chal) = super::batch_eval_verify(
      u_vec,
      &mut transcript,
      &self.sc_proof_batch,
      &self.evals_batch,
    )?;

    EE::verify(
      &vk.vk_ee,
      &mut transcript,
      &batched_u.c,
      &batched_u.x,
      &batched_u.e,
      &self.eval_arg,
    )?;

    Ok(())
  }
}