use crate::curve::{Fr, outer};
use base_crypto::hash::{HashOutput, persistent_hash};
use lazy_static::lazy_static;
use lru::LruCache;
use midnight_curves::Bls12;
use midnight_proofs::{
poly::kzg::params::{ParamsKZG, ParamsVerifierKZG},
utils::SerdeFormat,
};
use midnight_zk_stdlib::{MidnightCircuit, MidnightPK, MidnightVK, Relation};
#[cfg(feature = "proptest")]
use proptest::arbitrary::Arbitrary;
#[cfg(feature = "proptest")]
use proptest_derive::Arbitrary;
use rand::distributions::{Distribution, Standard};
use rand::{CryptoRng, Rng};
use serde::{Deserialize, Deserializer, Serialize, Serializer, ser::Error as SerError};
use serialize::{
Deserializable, Serializable, Tagged, VecExt, tag_enforcement_test, tagged_deserialize,
};
#[cfg(feature = "proptest")]
use serialize::{NoStrategy, simple_arbitrary};
use std::fmt::Debug;
use std::hash::{Hash, Hasher};
use std::io::Write;
use std::io::{self, Read};
#[cfg(feature = "proptest")]
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
use std::{any::Any, cmp::Ordering};
use std::{borrow::Cow, num::NonZeroUsize};
use storage_core::Storable;
use storage_core::arena::ArenaKey;
use storage_core::db::DB;
use storage_core::storable::Loader;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Source of prover-side KZG public parameters, looked up by the degree parameter `k`.
pub trait ParamsProverProvider {
/// Fetches the prover parameters for degree parameter `k`.
///
/// # Errors
/// Returns an I/O error when the parameters cannot be obtained.
#[allow(async_fn_in_trait)]
async fn get_params(&self, k: u8) -> io::Result<ParamsProver>;
}
/// BLAKE2b state used as the Fiat-Shamir transcript hash when verifying proofs in this module.
pub type TranscriptHash = blake2b_simd::State;
impl ParamsProverProvider for base_crypto::data_provider::MidnightDataProvider {
    /// Looks up the parameter file for `k` through the data provider and parses it
    /// with [`ParamsProver::read`].
    async fn get_params(&self, k: u8) -> io::Result<ParamsProver> {
        let file_name = Self::name_k(k);
        let not_found_msg = format!("public parameters for k={k} not found in cache");
        let source = self.get_file(&file_name, &not_found_msg).await?;
        ParamsProver::read(source)
    }
}
/// Prover-side KZG parameters over BLS12-381, shared cheaply through an `Arc`.
#[derive(Clone)]
pub struct ParamsProver(pub Arc<ParamsKZG<Bls12>>);
impl AsRef<ParamsKZG<Bls12>> for ParamsProver {
    /// Borrows the underlying parameters out of the shared `Arc`.
    fn as_ref(&self) -> &ParamsKZG<Bls12> {
        self.0.as_ref()
    }
}
impl ParamsProver {
pub fn read<R: Read>(mut reader: R) -> io::Result<Self> {
Ok(ParamsProver(Arc::new(ParamsKZG::read_custom(
&mut reader,
SerdeFormat::RawBytesUnchecked,
)?)))
}
pub(crate) fn as_verifier(&self) -> ParamsVerifier {
ParamsVerifier(Arc::new(self.0.verifier_params()))
}
}
/// Maximum degree parameter supported on the verifier side.
/// NOTE(review): presumably matches the embedded `bls_midnight_2p14` parameters (2^14); confirm.
pub const VERIFIER_MAX_DEGREE: u8 = 14;
/// Verifier-side KZG parameters over BLS12-381, shared through an `Arc`.
#[derive(Clone)]
pub struct ParamsVerifier(Arc<ParamsVerifierKZG<Bls12>>);
impl ParamsVerifier {
    /// Reads full prover-format parameters from `reader` and keeps only the verifier part.
    ///
    /// # Errors
    /// Propagates any error from [`ParamsProver::read`].
    pub fn read<R: Read>(reader: R) -> io::Result<Self> {
        let prover_params = ParamsProver::read(reader)?;
        Ok(prover_params.as_verifier())
    }
}
/// Bytes of the bundled `bls_midnight_2p14` parameter file, embedded at compile time.
const PARAMS_VERIFIER_RAW: &[u8] = include_bytes!("../static/bls_midnight_2p14");
lazy_static! {
/// Verifier parameters decoded from the embedded `PARAMS_VERIFIER_RAW` bytes on first access.
///
/// Panics on first access if the embedded bytes cannot be parsed.
pub static ref PARAMS_VERIFIER: ParamsVerifier = ParamsVerifier::read(PARAMS_VERIFIER_RAW).expect("Static verifier parameters should be valid.");
}
/// A serialized zero-knowledge proof held as opaque bytes (tag `proof[v5]`).
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serializable, Storable)]
#[storable(base)]
#[tag = "proof[v5]"]
pub struct Proof(pub Vec<u8>);
tag_enforcement_test!(Proof);
/// A prover key whose decoded form is produced lazily and shared between clones.
#[derive(Debug, Clone)]
pub struct ProverKey<T: Zkir>(Arc<Mutex<InnerProverKey<T>>>);
impl<T: Zkir> From<MidnightPK<T>> for ProverKey<T> {
    /// Wraps an already-decoded key, so no later decoding is needed.
    fn from(pk: MidnightPK<T>) -> Self {
        let state = InnerProverKey::Initialized(Arc::new(pk));
        Self(Arc::new(Mutex::new(state)))
    }
}
#[allow(async_fn_in_trait)]
pub trait Zkir: Relation + Tagged + Deserializable + Any + Send + Sync + Debug {
fn check(&self, preimage: &ProofPreimage) -> Result<Vec<Option<usize>>, ProvingError>;
async fn prove(
&self,
rng: impl Rng + CryptoRng,
params: &impl ParamsProverProvider,
pk: ProverKey<Self>,
preimage: &ProofPreimage,
) -> Result<(Proof, Vec<Fr>, Vec<Option<usize>>), ProvingError>;
fn k(&self) -> u8 {
MidnightCircuit::from_relation(self).min_k() as u8
}
async fn keygen_vk(
&self,
params: &impl ParamsProverProvider,
) -> Result<VerifierKey, anyhow::Error> {
use midnight_zk_stdlib::setup_vk;
let vk = VerifierKey::from(setup_vk(params.get_params(self.k()).await?.as_ref(), self));
Ok(vk)
}
async fn keygen(
&self,
params: &impl ParamsProverProvider,
) -> Result<(ProverKey<Self>, VerifierKey), anyhow::Error> {
use midnight_zk_stdlib::{setup_pk, setup_vk};
let vk = setup_vk(params.get_params(self.k()).await?.as_ref(), self);
let pk = setup_pk(self, &vk);
Ok((ProverKey::from(pk), VerifierKey::from(vk)))
}
}
impl<T: Zkir> PartialEq for ProverKey<T> {
    /// Two prover keys are equal when their serialized byte forms are identical.
    fn eq(&self, other: &Self) -> bool {
        let encode = |key: &Self| -> Vec<u8> {
            let mut bytes = Vec::new();
            Serializable::serialize(key, &mut bytes)
                .expect("In-memory serialization must succeed");
            bytes
        };
        encode(self) == encode(other)
    }
}
impl<T: Zkir> Eq for ProverKey<T> {}
impl<T: Zkir> Distribution<ProverKey<T>> for Standard {
    /// Samples an undecoded prover key holding between 0 and 31 random bytes.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ProverKey<T> {
        let size: u8 = rng.gen_range(0..32);
        let mut bytes = Vec::with_bounded_capacity(size as usize);
        // `with_bounded_capacity` only reserves space. The vector needs a real length
        // before `fill_bytes`; otherwise it fills an empty slice and every sample is empty.
        bytes.resize(size as usize, 0u8);
        rng.fill_bytes(&mut bytes);
        ProverKey(Arc::new(Mutex::new(InnerProverKey::Uninitialized(bytes))))
    }
}
/// Decoding state behind a [`ProverKey`].
#[derive(Debug, Clone)]
pub(crate) enum InnerProverKey<T: Zkir> {
/// Gzip-compressed key bytes that have not been decoded yet.
Uninitialized(Vec<u8>),
/// Bytes whose decode attempt failed; kept so they can still be re-serialized verbatim.
Invalid(Vec<u8>),
/// The decoded key.
Initialized(Arc<MidnightPK<T>>),
}
impl<T: Zkir> Tagged for ProverKey<T> {
    /// Tag of the form `prover-key[v7](<ir tag>)`.
    fn tag() -> Cow<'static, str> {
        Cow::Owned(Self::tag_unique_factor())
    }
    /// The same string as [`Tagged::tag`], since it embeds the IR's tag.
    fn tag_unique_factor() -> String {
        format!("prover-key[v7]({})", T::tag())
    }
}
/// Gzip level used when re-serializing a decoded prover key.
const PK_COMPRESSION_LEVEL: u32 = 6;
/// Maximum number of decoded prover keys kept in `PK_CACHE`.
const PK_CACHE_SIZE: usize = 5;
lazy_static! {
/// Process-wide LRU cache of decoded prover keys, keyed by the `persistent_hash` of
/// their compressed bytes. Values are type-erased and recovered with a downcast.
static ref PK_CACHE: Mutex<LruCache<HashOutput, Arc<dyn Any + Send + Sync>>> =
Mutex::new(LruCache::new(NonZeroUsize::new(PK_CACHE_SIZE).unwrap()));
}
impl<T: Zkir> InnerProverKey<T> {
fn try_cache(&mut self) {
let hash = match self {
InnerProverKey::Uninitialized(data) => persistent_hash(&data[..]),
_ => return,
};
if let Some(pk) = PK_CACHE
.lock()
.ok()
.and_then(|mut c| c.get(&hash).cloned())
.and_then(|ptr| ptr.downcast().ok())
{
*self = InnerProverKey::Initialized(pk);
}
}
}
impl<T: Zkir> ProverKey<T> {
pub fn init(&self) -> Result<Arc<MidnightPK<T>>, ProvingError> {
let mut mutex = self.0.lock().expect("mutex is not poisoned");
mutex.try_cache();
let data = match &*mutex {
InnerProverKey::Initialized(key) => {
return Ok(key.clone());
}
InnerProverKey::Invalid(_) => {
return Err(anyhow::anyhow!("known invalid verifier key"));
}
InnerProverKey::Uninitialized(data) => data.clone(),
};
let inner_reader = &mut &data[..];
let mut reader = flate2::read::GzDecoder::new(inner_reader);
let read_inner = |reader| {
let pk = MidnightPK::<T>::read(reader, SerdeFormat::RawBytesUnchecked)?;
Ok(pk)
};
let res: Result<_, ProvingError> = read_inner(&mut reader);
match res {
Ok(pk) => {
let key = Arc::new(pk);
PK_CACHE
.lock()
.ok()
.and_then(|mut c| c.put(persistent_hash(&data), key.clone()));
*mutex = InnerProverKey::Initialized(key.clone());
Ok(key)
}
Err(e) => {
*mutex = InnerProverKey::Invalid(data);
Err(e)
}
}
}
fn inner_serialize<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
match &*self.0.lock().expect("mutex is not poisoned") {
InnerProverKey::Uninitialized(data) | InnerProverKey::Invalid(data) => {
writer.write_all(data)?;
Ok(())
}
InnerProverKey::Initialized(key) => {
let mut writer = flate2::write::GzEncoder::new(
writer,
flate2::Compression::new(PK_COMPRESSION_LEVEL),
);
key.write(&mut writer, SerdeFormat::RawBytesUnchecked)
}
}
}
}
/// A `Write` sink that discards its input and only counts the bytes it receives.
struct Count(usize);
impl std::io::Write for Count {
    /// Accepts the whole buffer and adds its length to the running total.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        self.0 += accepted;
        Ok(accepted)
    }
    /// Nothing is buffered, so flushing does nothing.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
impl<T: Zkir> Serializable for ProverKey<T> {
    /// Writes the key as a `u64` byte-length prefix followed by the key bytes.
    ///
    /// The bytes are rendered exactly once, into a buffer. Rendering twice (once to
    /// measure, once to write) would gzip-compress a decoded key twice. It would also
    /// release the inner lock between the passes, so a concurrent `init` could change
    /// the encoding and leave a prefix that disagrees with the payload.
    fn serialize(&self, writer: &mut impl Write) -> std::io::Result<()> {
        let mut body = Vec::new();
        self.inner_serialize(&mut body)?;
        Serializable::serialize(&(body.len() as u64), writer)?;
        writer.write_all(&body)
    }
    /// Size of the length prefix plus the key bytes. For a decoded key this requires
    /// compressing it.
    fn serialized_size(&self) -> usize {
        let mut writer = Count(0);
        self.inner_serialize(&mut writer).ok();
        (writer.0 as u64).serialized_size() + writer.0
    }
}
impl<T: Zkir> Deserializable for ProverKey<T> {
fn deserialize(reader: &mut impl Read, recursion_depth: u32) -> Result<Self, std::io::Error> {
let buf = <Vec<u8> as Deserializable>::deserialize(reader, recursion_depth)?;
let mut pk = InnerProverKey::Uninitialized(buf);
pk.try_cache();
Ok(Self(Arc::new(Mutex::new(pk))))
}
}
/// A verifier key that is decoded lazily on first use. Its state lives behind a shared
/// `Arc<Mutex<..>>`, which `Clone` shares rather than copies.
#[derive(Debug, Storable)]
#[storable(base)]
pub struct VerifierKey(Arc<Mutex<InnerVerifierKey>>);
#[cfg(feature = "proptest")]
simple_arbitrary!(VerifierKey);
impl Tagged for VerifierKey {
    /// Fixed serialization tag `verifier-key[v6]`.
    fn tag() -> Cow<'static, str> {
        Cow::Borrowed("verifier-key[v6]")
    }
    /// The same string as [`Tagged::tag`], owned.
    fn tag_unique_factor() -> String {
        Self::tag().into_owned()
    }
}
tag_enforcement_test!(VerifierKey);
impl Distribution<VerifierKey> for Standard {
    /// Samples an undecoded verifier key holding between 0 and 255 random bytes.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> VerifierKey {
        let size: u8 = rng.r#gen();
        let mut bytes = Vec::with_bounded_capacity(size as usize);
        // `with_bounded_capacity` only reserves space. The vector needs a real length
        // before `fill_bytes`; otherwise it fills an empty slice and every sample is empty.
        bytes.resize(size as usize, 0u8);
        rng.fill_bytes(&mut bytes);
        VerifierKey(Arc::new(Mutex::new(InnerVerifierKey::Uninitialized(bytes))))
    }
}
impl From<MidnightVK> for VerifierKey {
fn from(vk: MidnightVK) -> Self {
VerifierKey(Arc::new(Mutex::new(InnerVerifierKey::Initialized(vk))))
}
}
/// Decoding state behind a [`VerifierKey`].
#[derive(Debug, Clone)]
#[allow(dead_code)] #[allow(clippy::large_enum_variant)]
pub(crate) enum InnerVerifierKey {
/// Serialized key bytes (`SerdeFormat::Processed`) not decoded yet.
Uninitialized(Vec<u8>),
/// Bytes known not to decode. NOTE(review): no code in this file constructs this variant.
Invalid(Vec<u8>),
/// The decoded key.
Initialized(MidnightVK),
}
impl Clone for VerifierKey {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
impl Deserializable for VerifierKey {
    /// Reads the key's byte buffer and wraps it as an undecoded key.
    ///
    /// # Errors
    /// Fails with `InvalidData` when the buffer exceeds 50,000 bytes, and propagates
    /// errors from reading the buffer itself.
    fn deserialize(
        reader: &mut impl std::io::Read,
        recursion_depth: u32,
    ) -> Result<Self, std::io::Error> {
        const MAX_EXPECTED_SIZE: usize = 50_000;
        let buf = <Vec<u8> as Deserializable>::deserialize(reader, recursion_depth)?;
        let declared = buf.len();
        if declared <= MAX_EXPECTED_SIZE {
            let state = InnerVerifierKey::Uninitialized(buf);
            return Ok(Self(Arc::new(Mutex::new(state))));
        }
        Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "Declared vk size {declared} exceeded permitted limit of {MAX_EXPECTED_SIZE}"
            ),
        ))
    }
}
/// Placeholder relation used as the relation type parameter of
/// `midnight_zk_stdlib::verify` in `VerifierKey::verify`.
///
/// Only `format_instance` is implemented; it passes the instance through unchanged.
/// The other methods panic, on the assumption that verification never calls them.
#[derive(Clone)]
struct DummyRelation;
impl Relation for DummyRelation {
type Instance = Vec<outer::Scalar>;
type Witness = ();
/// Returns the public inputs unchanged.
fn format_instance(
instance: &Self::Instance,
) -> Result<Vec<outer::Scalar>, midnight_proofs::plonk::Error> {
Ok(instance.clone())
}
/// Never expected to run; panics.
fn circuit(
&self,
_std_lib: &midnight_zk_stdlib::ZkStdLib,
_layouter: &mut impl midnight_proofs::circuit::Layouter<outer::Scalar>,
_instance: midnight_proofs::circuit::Value<Self::Instance>,
_witness: midnight_proofs::circuit::Value<Self::Witness>,
) -> Result<(), midnight_proofs::plonk::Error> {
unimplemented!("should not attempt to execute dummy relation")
}
/// Never expected to run; panics.
fn read_relation<R: io::Read>(_reader: &mut R) -> io::Result<Self> {
unimplemented!("should not attempt to read dummy relation")
}
/// Never expected to run; panics.
fn write_relation<W: io::Write>(&self, _writer: &mut W) -> io::Result<()> {
unimplemented!("should not attempt to write dummy relation")
}
}
impl Serialize for VerifierKey {
    /// Emits the crate's binary `Serializable` encoding as a serde byte string.
    fn serialize<S: Serializer>(&self, ser: S) -> Result<S::Ok, S::Error> {
        let mut encoded = Vec::new();
        match <VerifierKey as Serializable>::serialize(self, &mut encoded) {
            Ok(()) => ser.serialize_bytes(&encoded),
            Err(err) => Err(S::Error::custom(err)),
        }
    }
}
impl<'de> Deserialize<'de> for VerifierKey {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let bytes = serde_bytes::ByteBuf::deserialize(deserializer)?;
<VerifierKey as Deserializable>::deserialize(&mut &bytes[..], 0)
.map_err(serde::de::Error::custom)
}
}
#[allow(clippy::derived_hash_with_manual_eq)]
impl Hash for VerifierKey {
    /// Hashes the serialized bytes, matching the serialization-based `PartialEq`.
    /// A serialization error is ignored, and whatever was written is still hashed.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let mut encoded = Vec::new();
        let _ = Serializable::serialize(&self, &mut encoded);
        state.write(&encoded);
    }
}
impl Serializable for VerifierKey {
    /// Writes the key as a `u64` byte-length prefix followed by the key bytes.
    ///
    /// The bytes are rendered exactly once, into a buffer. Rendering twice (once to
    /// measure, once to write) would release the inner lock between the passes, so a
    /// concurrent decode could change the encoding and leave a prefix that disagrees
    /// with the payload.
    fn serialize(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
        let mut body = Vec::new();
        self.inner_serialize(&mut body)?;
        Serializable::serialize(&(body.len() as u64), writer)?;
        writer.write_all(&body)
    }
    /// Size of the length prefix plus the key bytes.
    fn serialized_size(&self) -> usize {
        let mut writer = Count(0);
        self.inner_serialize(&mut writer).ok();
        (writer.0 as u64).serialized_size() + writer.0
    }
}
impl VerifierKey {
    /// Decodes the verifier key if it has not been decoded yet.
    ///
    /// # Errors
    /// Propagates any error from [`VerifierKey::force_init`].
    pub fn init(&self) -> Result<(), VerifyingError> {
        self.force_init()?;
        Ok(())
    }
    /// Returns a clone of the decoded key. On first use it decodes the stored bytes
    /// (`SerdeFormat::Processed`) and keeps the decoded form.
    ///
    /// # Errors
    /// Fails if the key is marked invalid or the stored bytes do not decode. A failed
    /// decode leaves the stored bytes untouched, so a later call tries again.
    ///
    /// # Panics
    /// Panics if the inner mutex is poisoned.
    #[allow(dead_code)]
    pub(crate) fn force_init(&self) -> Result<MidnightVK, VerifyingError> {
        let mut mutex = self.0.lock().expect("mutex is not poisoned");
        let data = match &*mutex {
            InnerVerifierKey::Initialized(key) => {
                return Ok(key.clone());
            }
            InnerVerifierKey::Invalid(_) => {
                return Err(anyhow::anyhow!("known invalid verifier key"));
            }
            InnerVerifierKey::Uninitialized(data) => data.clone(),
        };
        let reader = &mut &data[..];
        let vk = MidnightVK::read(reader, SerdeFormat::Processed)
            .map_err(|_| anyhow::anyhow!("problem reading the verifier key"))?;
        *mutex = InnerVerifierKey::Initialized(vk.clone());
        Ok(vk)
    }
    /// Writes this key's bytes to `writer`, without a length prefix: stored bytes as-is,
    /// or the decoded key re-encoded in `SerdeFormat::Processed`.
    ///
    /// # Panics
    /// Panics if the inner mutex is poisoned.
    fn inner_serialize<W: std::io::Write>(&self, mut writer: W) -> std::io::Result<()> {
        match &*self.0.lock().expect("mutex is not poisoned") {
            InnerVerifierKey::Uninitialized(data) | InnerVerifierKey::Invalid(data) => {
                writer.write_all(data)
            }
            InnerVerifierKey::Initialized(key) => key.write(&mut writer, SerdeFormat::Processed),
        }
    }
    /// Verifies `proof` against this key and the public inputs `statement`.
    ///
    /// # Errors
    /// Fails if the key cannot be decoded, or with "Invalid proof" if verification fails.
    pub fn verify<F: Iterator<Item = Fr>>(
        &self,
        params: &ParamsVerifier,
        proof: &Proof,
        statement: F,
    ) -> Result<(), VerifyingError> {
        let vk = self.force_init()?;
        let pi = statement.map(|f| f.0).collect::<Vec<_>>();
        trace!(statement = ?pi, "verifying proof against statement");
        midnight_zk_stdlib::verify::<DummyRelation, TranscriptHash>(
            &params.0, &vk, &pi, None, &proof.0,
        )
        .map_err(|_| anyhow::anyhow!("Invalid proof"))
    }
    /// Mock verification: only counts the public inputs and delegates to
    /// `crate::mock_verify::mock_verify_for`.
    #[cfg(feature = "mock-verify")]
    pub fn mock_verify<F: Iterator<Item = Fr>>(&self, statement: F) -> Result<(), VerifyingError> {
        let pi_len = statement.count();
        crate::mock_verify::mock_verify_for(pi_len)
    }
    /// Verifies several `(key, proof, statement)` triples in one batch.
    ///
    /// # Errors
    /// Fails on the first key that cannot be decoded, or with "Invalid proof" if the
    /// batch does not verify.
    pub fn batch_verify<
        'a,
        F: Iterator<Item = Fr>,
        V: Iterator<Item = (&'a VerifierKey, &'a Proof, F)>,
    >(
        params: &ParamsVerifier,
        parts: V,
    ) -> Result<(), VerifyingError> {
        use midnight_zk_stdlib::batch_verify;
        let mut vks = vec![];
        let mut pis = vec![];
        let mut proofs = vec![];
        for (vk, proof, stmt) in parts {
            let pi = stmt.map(|f| f.0).collect::<Vec<_>>();
            let vk = vk.force_init()?;
            vks.push(vk);
            pis.push(pi);
            proofs.push(proof.0.clone());
        }
        batch_verify::<TranscriptHash>(&params.0, &vks, &pis, &proofs)
            .map_err(|_| anyhow::anyhow!("Invalid proof"))
    }
    /// Mock batch verification: runs [`VerifierKey::mock_verify`] for each triple,
    /// ignoring the proofs.
    #[cfg(feature = "mock-verify")]
    pub fn mock_batch_verify<
        'a,
        F: Iterator<Item = Fr>,
        V: Iterator<Item = (&'a VerifierKey, &'a Proof, F)>,
    >(
        parts: V,
    ) -> Result<(), VerifyingError> {
        for (vk, _proof, stmt) in parts {
            vk.mock_verify(stmt)?;
        }
        Ok(())
    }
}
/// Identifier passed to a [`Resolver`] to look up proving key material.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serializable)]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
pub struct KeyLocation(pub Cow<'static, str>);
impl Zeroize for KeyLocation {
    /// Wipes an owned string in place, then resets the location to the empty string.
    /// Borrowed (static) strings cannot be wiped and are simply replaced.
    fn zeroize(&mut self) {
        match &mut self.0 {
            Cow::Owned(owned) => owned.zeroize(),
            Cow::Borrowed(_) => {}
        }
        self.0 = Cow::Borrowed("");
    }
}
impl Tagged for KeyLocation {
    /// Serialized with the plain `string` tag.
    fn tag() -> Cow<'static, str> {
        Cow::Borrowed("string")
    }
    /// The same string as [`Tagged::tag`], owned.
    fn tag_unique_factor() -> String {
        Self::tag().into_owned()
    }
}
/// Raw bytes carried under the `wrapped-ir` tag.
/// NOTE(review): presumably a serialized IR, going by the name; confirm with producers.
#[derive(Serializable)]
#[tag = "wrapped-ir"]
pub struct WrappedIr(pub Vec<u8>);
tag_enforcement_test!(WrappedIr);
/// Tagged-serialized key material needed to prove a circuit. `ProofPreimage::prove`
/// decodes each field with `tagged_deserialize`.
#[derive(Clone, Serializable)]
#[tag = "proving-data"]
pub struct ProvingKeyMaterial {
/// Tagged `ProverKey<Z>` bytes.
pub prover_key: Vec<u8>,
/// Tagged `VerifierKey` bytes.
pub verifier_key: Vec<u8>,
/// Tagged bytes of the circuit IR `Z`.
pub ir_source: Vec<u8>,
}
tag_enforcement_test!(ProvingKeyMaterial);
/// Maps a [`KeyLocation`] to its proving key material.
pub trait Resolver {
/// Returns `Ok(None)` when no material exists for `key`. `ProofPreimage::prove` turns
/// that into a "failed to find proving key" error.
#[allow(async_fn_in_trait)]
async fn resolve_key(&self, key: KeyLocation) -> io::Result<Option<ProvingKeyMaterial>>;
}
/// Abstraction over something that can check and prove [`ProofPreimage`]s.
/// NOTE(review): no implementation is visible here; semantics below are inferred from
/// the signatures only.
#[allow(async_fn_in_trait)]
pub trait ProvingProvider {
/// Checks `preimage` without proving; the return type mirrors `Zkir::check`.
async fn check(&self, preimage: &ProofPreimage) -> Result<Vec<Option<usize>>, anyhow::Error>;
/// Consumes the provider to prove `preimage`. `overwrite_binding_input` presumably
/// replaces `preimage.binding_input` when `Some`; confirm against implementations.
async fn prove(
self,
preimage: &ProofPreimage,
overwrite_binding_input: Option<Fr>,
) -> Result<Proof, anyhow::Error>;
/// Produces another provider from this one (taking `&mut self`); exact semantics are
/// implementation-defined.
fn split(&mut self) -> Self;
}
/// Everything needed to produce a proof: circuit inputs, transcripts, the binding input,
/// an optional communications commitment, and where to find the proving key.
/// Derives `ZeroizeOnDrop`, so its contents are wiped when it is dropped.
#[derive(
Clone,
Debug,
PartialEq,
Eq,
PartialOrd,
Ord,
Serializable,
Hash,
Storable,
Zeroize,
ZeroizeOnDrop,
)]
#[storable(base)]
#[tag = "proof-preimage"]
#[cfg_attr(feature = "proptest", derive(Arbitrary))]
pub struct ProofPreimage {
pub inputs: Vec<Fr>,
pub private_transcript: Vec<Fr>,
pub public_transcript_inputs: Vec<Fr>,
pub public_transcript_outputs: Vec<Fr>,
pub binding_input: Fr,
pub communications_commitment: Option<(Fr, Fr)>,
/// Where `prove` resolves the proving key material from.
pub key_location: KeyLocation,
}
tag_enforcement_test!(ProofPreimage);
impl ProofPreimage {
#[allow(unused_variables)]
pub fn check(&self, ir: &impl Zkir) -> Result<Vec<Option<usize>>, ProvingError> {
ir.check(self)
}
#[allow(unreachable_code, unused_variables)]
pub async fn prove<Z: Zkir>(
&self,
rng: impl Rng + CryptoRng,
params: &impl ParamsProverProvider,
resolver: &impl Resolver,
) -> Result<(Proof, Vec<Option<usize>>), ProvingError> {
let proof_data = resolver
.resolve_key(self.key_location.clone())
.await?
.ok_or(anyhow::Error::msg(format!(
"failed to find proving key for '{}'",
&self.key_location.0
)))?;
let ir = tagged_deserialize::<Z>(&mut &proof_data.ir_source[..])?;
let verifier_key = tagged_deserialize::<VerifierKey>(&mut &proof_data.verifier_key[..])?;
let prover_key = tagged_deserialize::<ProverKey<Z>>(&mut &proof_data.prover_key[..])?;
let (proof, pis, pi_skips) = ir.prove(rng, params, prover_key, self).await?;
debug!("proof created; verifying to make sure");
let k = verifier_key.force_init()?.k();
if let Err(e) = verifier_key.verify(
¶ms.get_params(k).await?.as_verifier(),
&proof,
pis.iter().copied(),
) {
error!(error = ?e, ?pis, ?ir, "self-verification failed! This may be a bug, check that your keys match!");
return Err(e);
}
debug!("proof ok");
Ok((proof, pi_skips))
}
}
impl PartialEq for VerifierKey {
    /// Two verifier keys are equal when their serialized byte forms are identical.
    fn eq(&self, other: &Self) -> bool {
        let encode = |key: &Self| -> Vec<u8> {
            let mut bytes = Vec::new();
            Serializable::serialize(key, &mut bytes)
                .expect("In-memory serialization must succeed");
            bytes
        };
        let lhs = encode(self);
        let rhs = encode(other);
        lhs == rhs
    }
}
impl Eq for VerifierKey {}
impl PartialOrd for VerifierKey {
    /// Total order; defers to [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for VerifierKey {
    /// Orders keys by the lexicographic order of their serialized bytes.
    fn cmp(&self, other: &Self) -> Ordering {
        let to_bytes = |key: &Self| -> Vec<u8> {
            let mut out = Vec::new();
            Serializable::serialize(key, &mut out)
                .expect("In-memory serialization must succeed");
            out
        };
        let lhs = to_bytes(self);
        let rhs = to_bytes(other);
        lhs.cmp(&rhs)
    }
}
/// Error type returned by proving operations.
pub type ProvingError = anyhow::Error;
/// Error type returned by verification operations.
pub type VerifyingError = anyhow::Error;