use crate::StatusCode;
use anyhow::Result;
use itertools::Itertools;
use num_bigint::BigUint;
use slop_algebra::PrimeField32;
use sp1_core_machine::io::SP1Stdin;
use sp1_hypercube::{air::PublicValues, PROOF_MAX_NUM_PVS};
use sp1_primitives::types::Elf;
use sp1_prover::{
verify::verify_public_values, worker::SP1NodeCore, SP1VerifyingKey, SP1_CIRCUIT_VERSION,
};
use sp1_recursion_executor::RecursionPublicValues;
use std::{
borrow::Borrow,
fmt,
future::{Future, IntoFuture},
str::FromStr,
};
use thiserror::Error;
mod execute;
mod prove;
pub use execute::ExecuteRequest;
pub(crate) use prove::BaseProveRequest;
pub use prove::{ProveRequest, SP1ProvingKey};
use crate::{SP1Proof, SP1ProofWithPublicValues};
/// A backend that can set up, execute, prove and verify SP1 programs.
///
/// Implementors must provide [`Prover::inner`], [`Prover::setup`] and [`Prover::prove`].
/// [`Prover::version`], [`Prover::execute`] and [`Prover::verify`] come with default
/// implementations built on top of those.
pub trait Prover: Clone + Send + Sync {
    /// Proving key type yielded by [`Prover::setup`].
    type ProvingKey: ProvingKey;

    /// Error type reported by this prover (e.g. by the future returned from [`Prover::setup`]).
    type Error: fmt::Debug + fmt::Display;

    /// Request type returned by [`Prover::prove`]; it may borrow the prover for `'a`.
    type ProveRequest<'a>: ProveRequest<'a, Self>
    where
        Self: 'a;

    // ----- required methods -----

    /// The [`SP1NodeCore`] handed to proof verification by the default [`Prover::verify`].
    fn inner(&self) -> &SP1NodeCore;

    /// Asynchronously derives a proving key for `elf`.
    fn setup(&self, elf: Elf) -> impl SendFutureResult<Self::ProvingKey, Self::Error>;

    /// Creates a proof request for the program behind `pk`, fed with `stdin`.
    fn prove<'a>(&'a self, pk: &'a Self::ProvingKey, stdin: SP1Stdin) -> Self::ProveRequest<'a>;

    // ----- provided methods -----

    /// Version string that a proof's `sp1_version` must equal for [`Prover::verify`] to accept
    /// it. Defaults to [`SP1_CIRCUIT_VERSION`].
    fn version(&self) -> &str {
        SP1_CIRCUIT_VERSION
    }

    /// Creates an [`ExecuteRequest`] that runs `elf` on `stdin` with this prover.
    fn execute(&self, elf: Elf, stdin: SP1Stdin) -> ExecuteRequest<'_, Self> {
        ExecuteRequest::new(self, elf, stdin)
    }

    /// Checks `proof` against `vkey`.
    ///
    /// `status_code` selects which program exit codes are acceptable. Passing `None` falls back
    /// to `StatusCode::SUCCESS`. The proof must also carry the version reported by
    /// [`Prover::version`].
    ///
    /// # Errors
    ///
    /// Returns an [`SP1VerificationError`] describing the first check that failed.
    fn verify(
        &self,
        proof: &SP1ProofWithPublicValues,
        vkey: &SP1VerifyingKey,
        status_code: Option<StatusCode>,
    ) -> Result<(), SP1VerificationError> {
        let node = self.inner();
        let expected_version = self.version();
        verify_proof(node, expected_version, proof, vkey, status_code)
    }
}
/// Proving key produced by [`Prover::setup`].
pub trait ProvingKey: Clone + Send + Sync {
    /// The program ELF associated with this key.
    fn elf(&self) -> &Elf;

    /// The verifying key that pairs with this proving key.
    fn verifying_key(&self) -> &SP1VerifyingKey;
}
/// Shorthand bound for "a `Send` future resolving to `Result<T, E>`".
///
/// This works as a trait alias. The blanket impl below makes every such future implement it,
/// so signatures can say `impl SendFutureResult<T, E>` instead of spelling out the full bound.
pub trait SendFutureResult<T, E>
where
    Self: Send + Future<Output = Result<T, E>>,
{
}

impl<Fut, T, E> SendFutureResult<T, E> for Fut where Fut: Send + Future<Output = Result<T, E>> {}
/// Shorthand bound for "a `Send` value convertible into a future resolving to `Result<T, E>`".
///
/// This works as a trait alias. The blanket impl below makes every matching [`IntoFuture`]
/// implement it automatically.
pub trait IntoSendFutureResult<T, E>
where
    Self: Send + IntoFuture<Output = Result<T, E>>,
{
}

impl<Fut, T, E> IntoSendFutureResult<T, E> for Fut where
    Fut: Send + IntoFuture<Output = Result<T, E>>
{
}
#[derive(Error, Debug)]
pub enum SP1VerificationError {
#[error("Invalid public values")]
InvalidPublicValues,
#[error("Version mismatch: {0}")]
VersionMismatch(String),
#[error("Core machine verification error: {0}")]
Core(anyhow::Error),
#[error("Recursion verification error: {0}")]
Recursion(anyhow::Error),
#[error("Plonk verification error: {0}")]
Plonk(anyhow::Error),
#[error("Groth16 verification error: {0}")]
Groth16(anyhow::Error),
#[error("Unexpected error: {0:?}")]
Other(anyhow::Error),
#[error("Unexpected exit code: {0}")]
UnexpectedExitCode(u32),
}
pub(crate) fn verify_proof(
node: &SP1NodeCore,
version: &str,
bundle: &SP1ProofWithPublicValues,
vkey: &SP1VerifyingKey,
status_code: Option<StatusCode>,
) -> Result<(), SP1VerificationError> {
let status_code = status_code.unwrap_or(StatusCode::SUCCESS);
if bundle.sp1_version != version {
return Err(SP1VerificationError::VersionMismatch(bundle.sp1_version.clone()));
}
match &bundle.proof {
SP1Proof::Core(proof) => {
if proof.is_empty() {
return Err(SP1VerificationError::Core(anyhow::anyhow!("Empty core proof")));
}
if proof.last().unwrap().public_values.len() != PROOF_MAX_NUM_PVS {
return Err(SP1VerificationError::InvalidPublicValues);
}
let public_values: &PublicValues<[_; 4], [_; 3], [_; 4], _> =
proof.last().unwrap().public_values.as_slice().borrow();
if !status_code.is_accepted_code(public_values.exit_code.as_canonical_u32()) {
return Err(SP1VerificationError::UnexpectedExitCode(
public_values.exit_code.as_canonical_u32(),
));
}
let committed_value_digest_bytes = public_values
.committed_value_digest
.iter()
.flat_map(|w| w.iter().map(|x| x.as_canonical_u32() as u8))
.collect_vec();
if committed_value_digest_bytes != bundle.public_values.hash()
&& committed_value_digest_bytes != bundle.public_values.blake3_hash()
{
tracing::error!("committed value digest doesnt match");
return Err(SP1VerificationError::InvalidPublicValues);
}
}
SP1Proof::Compressed(proof) => {
if proof.proof.public_values.len() != PROOF_MAX_NUM_PVS {
return Err(SP1VerificationError::InvalidPublicValues);
}
let public_values: &RecursionPublicValues<_> =
proof.proof.public_values.as_slice().borrow();
if !status_code.is_accepted_code(public_values.exit_code.as_canonical_u32()) {
return Err(SP1VerificationError::UnexpectedExitCode(
public_values.exit_code.as_canonical_u32(),
));
}
let committed_value_digest_bytes = public_values
.committed_value_digest
.iter()
.flat_map(|w| w.iter().map(|x| x.as_canonical_u32() as u8))
.collect_vec();
if committed_value_digest_bytes != bundle.public_values.hash()
&& committed_value_digest_bytes != bundle.public_values.blake3_hash()
{
return Err(SP1VerificationError::InvalidPublicValues);
}
}
SP1Proof::Plonk(proof) => {
let exit_code = BigUint::from_str(&proof.public_inputs[2])
.map_err(|e| SP1VerificationError::Plonk(anyhow::anyhow!(e)))?;
let exit_code_u32 =
u32::try_from(&exit_code).map_err(|_| SP1VerificationError::InvalidPublicValues)?;
if !status_code.is_accepted_code(exit_code_u32) {
return Err(SP1VerificationError::UnexpectedExitCode(exit_code_u32));
}
let public_values_hash = BigUint::from_str(&proof.public_inputs[1])
.map_err(|e| SP1VerificationError::Plonk(anyhow::anyhow!(e)))?;
verify_public_values(&bundle.public_values, public_values_hash)
.map_err(SP1VerificationError::Plonk)?;
}
SP1Proof::Groth16(proof) => {
let exit_code = BigUint::from_str(&proof.public_inputs[2])
.map_err(|e| SP1VerificationError::Plonk(anyhow::anyhow!(e)))?;
let exit_code_u32 =
u32::try_from(&exit_code).map_err(|_| SP1VerificationError::InvalidPublicValues)?;
if !status_code.is_accepted_code(exit_code_u32) {
return Err(SP1VerificationError::UnexpectedExitCode(exit_code_u32));
}
let public_values_hash = BigUint::from_str(&proof.public_inputs[1])
.map_err(|e| SP1VerificationError::Groth16(anyhow::anyhow!(e)))?;
verify_public_values(&bundle.public_values, public_values_hash)
.map_err(SP1VerificationError::Groth16)?;
}
}
node.verify(vkey, &bundle.proof).map_err(|e| match bundle.proof {
SP1Proof::Core(_) => SP1VerificationError::Core(e),
SP1Proof::Compressed(_) => SP1VerificationError::Recursion(e),
SP1Proof::Plonk(_) => SP1VerificationError::Plonk(e),
SP1Proof::Groth16(_) => SP1VerificationError::Groth16(e),
})
}