use alloc::vec::Vec;
use core::ops;
use dusk_bytes::{DeserializableSlice, Serializable};
use dusk_curves::bls12_381::BlsScalar;
use ff::Field;
use merlin::Transcript;
use rand_core::{CryptoRng, RngCore};
use super::{Circuit, Composer, PlonkVersion};
use crate::commitment_scheme::CommitKey;
use crate::compiler::prover::linearization_poly::ProofEvaluations;
use crate::error::Error;
use crate::fft::{EvaluationDomain, Polynomial};
use crate::proof_system::proof::Proof;
use crate::proof_system::{
ProverKey, VerifierKey, linearization_poly, quotient_poly,
};
use crate::transcript::TranscriptProtocol;
#[derive(Clone)]
pub struct Prover {
label: Vec<u8>,
pub(crate) prover_key: ProverKey,
pub(crate) commit_key: CommitKey,
pub(crate) verifier_key: VerifierKey,
pub(crate) transcript: Transcript,
sigma_evals: [Vec<BlsScalar>; 4],
pub(crate) size: usize,
pub(crate) constraints: usize,
}
impl ops::Deref for Prover {
type Target = ProverKey;
fn deref(&self) -> &Self::Target {
&self.prover_key
}
}
impl Prover {
pub(crate) fn new(
label: Vec<u8>,
prover_key: ProverKey,
commit_key: CommitKey,
verifier_key: VerifierKey,
size: usize,
constraints: usize,
) -> Self {
let transcript =
Transcript::base(label.as_slice(), &verifier_key, constraints);
let sigma_evals = Self::compute_sigma_evals(&prover_key, constraints);
Self {
label,
prover_key,
commit_key,
verifier_key,
transcript,
sigma_evals,
size,
constraints,
}
}
fn compute_sigma_evals(
prover_key: &ProverKey,
constraints: usize,
) -> [Vec<BlsScalar>; 4] {
let domain = EvaluationDomain::new(constraints)
.expect("compiler constructed an invalid evaluation domain");
[
domain.fft(&prover_key.permutation.s_sigma_1.0),
domain.fft(&prover_key.permutation.s_sigma_2.0),
domain.fft(&prover_key.permutation.s_sigma_3.0),
domain.fft(&prover_key.permutation.s_sigma_4.0),
]
}
pub(crate) fn blind_poly<R>(
rng: &mut R,
witnesses: &[BlsScalar],
hiding_degree: usize,
domain: &EvaluationDomain,
) -> Polynomial
where
R: RngCore + CryptoRng,
{
let blinders: Vec<_> = (0..=hiding_degree)
.map(|_| BlsScalar::random(&mut *rng))
.collect();
Self::blind_poly_with_blinders(witnesses, &blinders, domain)
}
fn blind_poly_with_blinders(
witnesses: &[BlsScalar],
blinders: &[BlsScalar],
domain: &EvaluationDomain,
) -> Polynomial {
let mut w_vec_inverse = domain.ifft(witnesses);
for (i, blinding_scalar) in blinders.iter().copied().enumerate() {
w_vec_inverse[i] -= blinding_scalar;
w_vec_inverse.push(blinding_scalar);
}
Polynomial::from_coefficients_vec(w_vec_inverse)
}
fn prepare_serialize(
&self,
) -> (usize, Vec<u8>, Vec<u8>, [u8; VerifierKey::SIZE]) {
let prover_key = self.prover_key.to_var_bytes();
let commit_key = self.commit_key.to_raw_var_bytes();
let verifier_key = self.verifier_key.to_bytes();
let label_len = self.label.len();
let prover_key_len = prover_key.len();
let commit_key_len = commit_key.len();
let verifier_key_len = verifier_key.len();
let size =
48 + label_len + prover_key_len + commit_key_len + verifier_key_len;
(size, prover_key, commit_key, verifier_key)
}
pub fn serialized_size(&self) -> usize {
self.prepare_serialize().0
}
pub fn to_bytes(&self) -> Vec<u8> {
let (size, prover_key, commit_key, verifier_key) =
self.prepare_serialize();
let mut bytes = Vec::with_capacity(size);
let label_len = self.label.len() as u64;
let prover_key_len = prover_key.len() as u64;
let commit_key_len = commit_key.len() as u64;
let verifier_key_len = verifier_key.len() as u64;
let size = self.size as u64;
let constraints = self.constraints as u64;
bytes.extend(label_len.to_be_bytes());
bytes.extend(prover_key_len.to_be_bytes());
bytes.extend(commit_key_len.to_be_bytes());
bytes.extend(verifier_key_len.to_be_bytes());
bytes.extend(size.to_be_bytes());
bytes.extend(constraints.to_be_bytes());
bytes.extend(self.label.as_slice());
bytes.extend(prover_key);
bytes.extend(commit_key);
bytes.extend(verifier_key);
bytes
}
pub fn try_from_bytes<B>(bytes: B) -> Result<Self, Error>
where
B: AsRef<[u8]>,
{
let mut bytes = bytes.as_ref();
if bytes.len() < 48 {
return Err(Error::NotEnoughBytes);
}
let label_len = <[u8; 8]>::try_from(&bytes[..8]).expect("checked len");
let label_len = u64::from_be_bytes(label_len) as usize;
bytes = &bytes[8..];
let prover_key_len =
<[u8; 8]>::try_from(&bytes[..8]).expect("checked len");
let prover_key_len = u64::from_be_bytes(prover_key_len) as usize;
bytes = &bytes[8..];
let commit_key_len =
<[u8; 8]>::try_from(&bytes[..8]).expect("checked len");
let commit_key_len = u64::from_be_bytes(commit_key_len) as usize;
bytes = &bytes[8..];
let verifier_key_len =
<[u8; 8]>::try_from(&bytes[..8]).expect("checked len");
let verifier_key_len = u64::from_be_bytes(verifier_key_len) as usize;
bytes = &bytes[8..];
let size = <[u8; 8]>::try_from(&bytes[..8]).expect("checked len");
let size = u64::from_be_bytes(size) as usize;
bytes = &bytes[8..];
let constraints =
<[u8; 8]>::try_from(&bytes[..8]).expect("checked len");
let constraints = u64::from_be_bytes(constraints) as usize;
bytes = &bytes[8..];
let required_len = label_len
.checked_add(prover_key_len)
.and_then(|len| len.checked_add(commit_key_len))
.and_then(|len| len.checked_add(verifier_key_len))
.ok_or(Error::NotEnoughBytes)?;
if bytes.len() < required_len {
return Err(Error::NotEnoughBytes);
}
let label = &bytes[..label_len];
bytes = &bytes[label_len..];
let prover_key = &bytes[..prover_key_len];
bytes = &bytes[prover_key_len..];
let commit_key = &bytes[..commit_key_len];
bytes = &bytes[commit_key_len..];
let verifier_key = &bytes[..verifier_key_len];
let label = label.to_vec();
let prover_key = ProverKey::from_slice(prover_key)?;
let commit_key = CommitKey::from_raw_var_bytes(commit_key)?;
let verifier_key = VerifierKey::from_slice(verifier_key)?;
Ok(Self::new(
label,
prover_key,
commit_key,
verifier_key,
size,
constraints,
))
}
pub fn prove<C, R>(
&self,
rng: &mut R,
circuit: &C,
) -> Result<(Proof, Vec<BlsScalar>), Error>
where
C: Circuit,
R: RngCore + CryptoRng,
{
self.prove_with_version(rng, circuit, PlonkVersion::current())
}
pub fn prove_with_version<C, R>(
&self,
rng: &mut R,
circuit: &C,
version: PlonkVersion,
) -> Result<(Proof, Vec<BlsScalar>), Error>
where
C: Circuit,
R: RngCore + CryptoRng,
{
match version {
PlonkVersion::V1 => Err(Error::UnsupportedProvingVersion),
PlonkVersion::V2 => self.prove_legacy(rng, circuit, version),
PlonkVersion::V3 => self.prove_inner(rng, circuit, version),
}
}
fn prove_legacy<C, R>(
&self,
rng: &mut R,
circuit: &C,
version: PlonkVersion,
) -> Result<(Proof, Vec<BlsScalar>), Error>
where
C: Circuit,
R: RngCore + CryptoRng,
{
#[cfg(feature = "legacy-proving")]
{
self.prove_inner(rng, circuit, version)
}
#[cfg(not(feature = "legacy-proving"))]
{
let _ = (rng, circuit, version);
Err(Error::LegacyProvingDisabled)
}
}
fn transcript_for_version(&self, version: PlonkVersion) -> Transcript {
match version {
PlonkVersion::V1 | PlonkVersion::V2 => self.transcript.clone(),
PlonkVersion::V3 => Transcript::base_v3(
self.label.as_slice(),
&self.verifier_key,
self.constraints,
),
}
}
fn prove_inner<C, R>(
&self,
rng: &mut R,
circuit: &C,
version: PlonkVersion,
) -> Result<(Proof, Vec<BlsScalar>), Error>
where
C: Circuit,
R: RngCore + CryptoRng,
{
let prover = Composer::prove(self.constraints, circuit)?;
let constraints = self.constraints;
let size = self.size;
let domain = EvaluationDomain::new(constraints)?;
let mut transcript = self.transcript_for_version(version);
let public_inputs = prover.public_inputs();
let public_input_indexes = prover.public_input_indexes();
let dense_public_inputs = Composer::dense_public_inputs(
&public_input_indexes,
&public_inputs,
self.size,
);
public_inputs
.iter()
.for_each(|pi| transcript.append_scalar(b"pi", pi));
let mut a_scalars = vec![BlsScalar::zero(); size];
let mut b_scalars = vec![BlsScalar::zero(); size];
let mut c_scalars = vec![BlsScalar::zero(); size];
let mut d_scalars = vec![BlsScalar::zero(); size];
prover
.constraints
.iter()
.enumerate()
.for_each(|(i, constraint)| {
a_scalars[i] = prover[constraint.a];
b_scalars[i] = prover[constraint.b];
c_scalars[i] = prover[constraint.c];
d_scalars[i] = prover[constraint.d];
});
#[cfg(feature = "std")]
let wire_blinders = [
[BlsScalar::random(&mut *rng), BlsScalar::random(&mut *rng)],
[BlsScalar::random(&mut *rng), BlsScalar::random(&mut *rng)],
[BlsScalar::random(&mut *rng), BlsScalar::random(&mut *rng)],
[BlsScalar::random(&mut *rng), BlsScalar::random(&mut *rng)],
];
#[cfg(feature = "std")]
let ((a_poly, b_poly), (c_poly, d_poly)) = rayon::join(
|| {
rayon::join(
|| {
Self::blind_poly_with_blinders(
&a_scalars,
&wire_blinders[0],
&domain,
)
},
|| {
Self::blind_poly_with_blinders(
&b_scalars,
&wire_blinders[1],
&domain,
)
},
)
},
|| {
rayon::join(
|| {
Self::blind_poly_with_blinders(
&c_scalars,
&wire_blinders[2],
&domain,
)
},
|| {
Self::blind_poly_with_blinders(
&d_scalars,
&wire_blinders[3],
&domain,
)
},
)
},
);
#[cfg(not(feature = "std"))]
let a_poly = Self::blind_poly(rng, &a_scalars, 1, &domain);
#[cfg(not(feature = "std"))]
let b_poly = Self::blind_poly(rng, &b_scalars, 1, &domain);
#[cfg(not(feature = "std"))]
let c_poly = Self::blind_poly(rng, &c_scalars, 1, &domain);
#[cfg(not(feature = "std"))]
let d_poly = Self::blind_poly(rng, &d_scalars, 1, &domain);
#[cfg(feature = "std")]
let ((a_comm, b_comm), (c_comm, d_comm)) = rayon::join(
|| {
rayon::join(
|| self.commit_key.commit(&a_poly),
|| self.commit_key.commit(&b_poly),
)
},
|| {
rayon::join(
|| self.commit_key.commit(&c_poly),
|| self.commit_key.commit(&d_poly),
)
},
);
#[cfg(feature = "std")]
let (a_comm, b_comm, c_comm, d_comm) =
(a_comm?, b_comm?, c_comm?, d_comm?);
#[cfg(not(feature = "std"))]
let a_comm = self.commit_key.commit(&a_poly)?;
#[cfg(not(feature = "std"))]
let b_comm = self.commit_key.commit(&b_poly)?;
#[cfg(not(feature = "std"))]
let c_comm = self.commit_key.commit(&c_poly)?;
#[cfg(not(feature = "std"))]
let d_comm = self.commit_key.commit(&d_poly)?;
transcript.append_commitment(b"a_comm", &a_comm);
transcript.append_commitment(b"b_comm", &b_comm);
transcript.append_commitment(b"c_comm", &c_comm);
transcript.append_commitment(b"d_comm", &d_comm);
let beta = transcript.challenge_scalar(b"beta");
transcript.append_scalar(b"beta", &beta);
let gamma = transcript.challenge_scalar(b"gamma");
let sigma = [
self.sigma_evals[0].as_slice(),
self.sigma_evals[1].as_slice(),
self.sigma_evals[2].as_slice(),
self.sigma_evals[3].as_slice(),
];
let wires = [
a_scalars.as_slice(),
b_scalars.as_slice(),
c_scalars.as_slice(),
d_scalars.as_slice(),
];
let permutation = prover
.perm
.compute_permutation_vec(&domain, wires, &beta, &gamma, sigma);
let z_poly = Self::blind_poly(rng, &permutation, 2, &domain);
let z_comm = self.commit_key.commit(&z_poly)?;
transcript.append_commitment(b"z_comm", &z_comm);
let alpha = transcript.challenge_scalar(b"alpha");
let range_sep_challenge =
transcript.challenge_scalar(b"range separation challenge");
let logic_sep_challenge =
transcript.challenge_scalar(b"logic separation challenge");
let fixed_base_sep_challenge =
transcript.challenge_scalar(b"fixed base separation challenge");
let var_base_sep_challenge =
transcript.challenge_scalar(b"variable base separation challenge");
let pi_poly = domain.ifft(&dense_public_inputs);
let pi_poly = Polynomial::from_coefficients_vec(pi_poly);
let wires = (&a_poly, &b_poly, &c_poly, &d_poly);
let args = &(
alpha,
beta,
gamma,
range_sep_challenge,
logic_sep_challenge,
fixed_base_sep_challenge,
var_base_sep_challenge,
);
let t_poly = quotient_poly::compute(
&domain,
&self.prover_key,
&z_poly,
wires,
&pi_poly,
args,
)?;
let domain_size = domain.size();
let mut t_low_vec = t_poly[0..domain_size].to_vec();
let mut t_mid_vec = t_poly[domain_size..2 * domain_size].to_vec();
let mut t_high_vec = t_poly[2 * domain_size..3 * domain_size].to_vec();
let mut t_fourth_vec = t_poly[3 * domain_size..].to_vec();
let b_12 = BlsScalar::random(&mut *rng);
let b_13 = BlsScalar::random(&mut *rng);
let b_14 = BlsScalar::random(&mut *rng);
t_low_vec.push(b_12);
t_mid_vec[0] -= b_12;
t_mid_vec.push(b_13);
t_high_vec[0] -= b_13;
t_high_vec.push(b_14);
t_fourth_vec[0] -= b_14;
let t_low_poly = Polynomial::from_coefficients_vec(t_low_vec);
let t_mid_poly = Polynomial::from_coefficients_vec(t_mid_vec);
let t_high_poly = Polynomial::from_coefficients_vec(t_high_vec);
let t_fourth_poly = Polynomial::from_coefficients_vec(t_fourth_vec);
#[cfg(feature = "std")]
let ((t_low_comm, t_mid_comm), (t_high_comm, t_fourth_comm)) =
rayon::join(
|| {
rayon::join(
|| self.commit_key.commit(&t_low_poly),
|| self.commit_key.commit(&t_mid_poly),
)
},
|| {
rayon::join(
|| self.commit_key.commit(&t_high_poly),
|| self.commit_key.commit(&t_fourth_poly),
)
},
);
#[cfg(feature = "std")]
let (t_low_comm, t_mid_comm, t_high_comm, t_fourth_comm) =
(t_low_comm?, t_mid_comm?, t_high_comm?, t_fourth_comm?);
#[cfg(not(feature = "std"))]
let t_low_comm = self.commit_key.commit(&t_low_poly)?;
#[cfg(not(feature = "std"))]
let t_mid_comm = self.commit_key.commit(&t_mid_poly)?;
#[cfg(not(feature = "std"))]
let t_high_comm = self.commit_key.commit(&t_high_poly)?;
#[cfg(not(feature = "std"))]
let t_fourth_comm = self.commit_key.commit(&t_fourth_poly)?;
transcript.append_commitment(b"t_low_comm", &t_low_comm);
transcript.append_commitment(b"t_mid_comm", &t_mid_comm);
transcript.append_commitment(b"t_high_comm", &t_high_comm);
transcript.append_commitment(b"t_fourth_comm", &t_fourth_comm);
let z_challenge = transcript.challenge_scalar(b"z_challenge");
let a_eval = a_poly.evaluate(&z_challenge);
let b_eval = b_poly.evaluate(&z_challenge);
let c_eval = c_poly.evaluate(&z_challenge);
let d_eval = d_poly.evaluate(&z_challenge);
let s_sigma_1_eval = self
.prover_key
.permutation
.s_sigma_1
.0
.evaluate(&z_challenge);
let s_sigma_2_eval = self
.prover_key
.permutation
.s_sigma_2
.0
.evaluate(&z_challenge);
let s_sigma_3_eval = self
.prover_key
.permutation
.s_sigma_3
.0
.evaluate(&z_challenge);
let z_eval = z_poly.evaluate(&(z_challenge * domain.group_gen));
transcript.append_scalar(b"a_eval", &a_eval);
transcript.append_scalar(b"b_eval", &b_eval);
transcript.append_scalar(b"c_eval", &c_eval);
transcript.append_scalar(b"d_eval", &d_eval);
transcript.append_scalar(b"s_sigma_1_eval", &s_sigma_1_eval);
transcript.append_scalar(b"s_sigma_2_eval", &s_sigma_2_eval);
transcript.append_scalar(b"s_sigma_3_eval", &s_sigma_3_eval);
transcript.append_scalar(b"z_eval", &z_eval);
let a_w_eval = a_poly.evaluate(&(z_challenge * domain.group_gen));
let b_w_eval = b_poly.evaluate(&(z_challenge * domain.group_gen));
let d_w_eval = d_poly.evaluate(&(z_challenge * domain.group_gen));
let q_arith_eval =
self.prover_key.arithmetic.q_arith.0.evaluate(&z_challenge);
let q_c_eval = self.prover_key.arithmetic.q_c.0.evaluate(&z_challenge);
let q_l_eval = self.prover_key.arithmetic.q_l.0.evaluate(&z_challenge);
let q_r_eval = self.prover_key.arithmetic.q_r.0.evaluate(&z_challenge);
transcript.append_scalar(b"a_w_eval", &a_w_eval);
transcript.append_scalar(b"b_w_eval", &b_w_eval);
transcript.append_scalar(b"d_w_eval", &d_w_eval);
transcript.append_scalar(b"q_arith_eval", &q_arith_eval);
transcript.append_scalar(b"q_c_eval", &q_c_eval);
transcript.append_scalar(b"q_l_eval", &q_l_eval);
transcript.append_scalar(b"q_r_eval", &q_r_eval);
let evaluations = ProofEvaluations {
a_eval,
b_eval,
c_eval,
d_eval,
a_w_eval,
b_w_eval,
d_w_eval,
q_arith_eval,
q_c_eval,
q_l_eval,
q_r_eval,
s_sigma_1_eval,
s_sigma_2_eval,
s_sigma_3_eval,
z_eval,
};
let v_challenge = transcript.challenge_scalar(b"v_challenge");
let r_poly = linearization_poly::compute(
&self.prover_key,
&(
alpha,
beta,
gamma,
range_sep_challenge,
logic_sep_challenge,
fixed_base_sep_challenge,
var_base_sep_challenge,
z_challenge,
),
&z_poly,
&evaluations,
&domain,
&t_low_poly,
&t_mid_poly,
&t_high_poly,
&t_fourth_poly,
&public_inputs,
);
let aggregate_witness = CommitKey::compute_aggregate_witness(
&[
r_poly,
a_poly.clone(),
b_poly.clone(),
c_poly,
d_poly.clone(),
self.prover_key.permutation.s_sigma_1.0.clone(),
self.prover_key.permutation.s_sigma_2.0.clone(),
self.prover_key.permutation.s_sigma_3.0.clone(),
self.prover_key.arithmetic.q_arith.0.clone(),
self.prover_key.arithmetic.q_c.0.clone(),
self.prover_key.arithmetic.q_l.0.clone(),
self.prover_key.arithmetic.q_r.0.clone(),
],
&z_challenge,
&v_challenge,
);
let w_z_chall_comm = self.commit_key.commit(&aggregate_witness)?;
let v_w_challenge = transcript.challenge_scalar(b"v_w_challenge");
let shifted_aggregate_witness = CommitKey::compute_aggregate_witness(
&[z_poly, a_poly, b_poly, d_poly],
&(z_challenge * domain.group_gen),
&v_w_challenge,
);
let w_z_chall_w_comm =
self.commit_key.commit(&shifted_aggregate_witness)?;
let proof = Proof {
a_comm,
b_comm,
c_comm,
d_comm,
z_comm,
t_low_comm,
t_mid_comm,
t_high_comm,
t_fourth_comm,
w_z_chall_comm,
w_z_chall_w_comm,
evaluations,
};
Ok((proof, public_inputs))
}
}
#[cfg(test)]
mod tests {
use dusk_curves::bls12_381::BlsScalar;
use rand::SeedableRng;
use rand::rngs::StdRng;
use super::Prover;
use crate::error::Error;
use crate::prelude::{Circuit, Compiler, Composer, PublicParameters};
#[derive(Default)]
struct MinimalCircuit;
impl Circuit for MinimalCircuit {
fn circuit(&self, composer: &mut Composer) -> Result<(), Error> {
let w = composer.append_witness(BlsScalar::from(7u64));
composer.assert_equal_constant(w, BlsScalar::from(7u64), None);
Ok(())
}
}
#[test]
fn prover_try_from_bytes_rejects_malformed_commit_key_without_panicking() {
let mut rng = StdRng::seed_from_u64(42);
let pp = PublicParameters::setup(1 << 10, &mut rng)
.expect("public parameters should build");
let (prover, _) = Compiler::compile::<MinimalCircuit>(&pp, b"p1.4-3")
.expect("circuit should compile");
let mut bytes = prover.to_bytes();
bytes[16..24].copy_from_slice(&(0u64).to_be_bytes());
let result =
std::panic::catch_unwind(|| Prover::try_from_bytes(&bytes));
assert!(result.is_ok(), "deserializer should never panic");
assert!(
result.expect("checked above").is_err(),
"malformed input must be rejected"
);
}
#[test]
fn prover_try_from_bytes_rejects_overflow_lengths_without_panicking() {
let mut bytes = Vec::with_capacity(48);
bytes.extend_from_slice(&0u64.to_be_bytes()); bytes.extend_from_slice(&u64::MAX.to_be_bytes()); bytes.extend_from_slice(&u64::MAX.to_be_bytes()); bytes.extend_from_slice(&u64::MAX.to_be_bytes()); bytes.extend_from_slice(&0u64.to_be_bytes()); bytes.extend_from_slice(&0u64.to_be_bytes());
let result =
std::panic::catch_unwind(|| Prover::try_from_bytes(&bytes));
assert!(result.is_ok(), "deserializer should never panic");
assert!(matches!(
result.expect("checked above"),
Err(Error::NotEnoughBytes)
));
}
}