use alloc::vec::Vec;
use core::borrow::Borrow;
use slop_algebra::{AbstractField, PrimeField32};
use slop_symmetric::CryptographicHasher;
use sp1_hypercube::{
verify_merkle_proof, HashableKey, InnerSC, MachineVerifier, MachineVerifierError,
SP1RecursionProof, ShardVerifier, DIGEST_SIZE, PROOF_MAX_NUM_PVS,
};
use sp1_primitives::{fri_params::recursion_fri_config, poseidon2_hasher, SP1Field};
use sp1_recursion_executor::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH};
use super::CompressedError;
use crate::{
blake3_hash,
compressed::{RECURSION_LOG_STACKING_HEIGHT, RECURSION_MAX_LOG_ROW_COUNT},
};
/// Global-context type parameter used for the hypercube verifier and proof types in this module.
type GC = sp1_primitives::SP1GlobalContext;
/// PCS proof type parameter of the compressed [`SP1RecursionProof`]s verified here.
type C = sp1_hypercube::SP1PcsProofInner;
/// Degree parameter passed to the recursion AIR for the compress stage.
pub const COMPRESS_DEGREE: usize = 3;
/// Recursion AIR used for compress-stage proofs, instantiated over the field `F`.
///
/// The generic is deliberately not named `SP1Field` so that it does not shadow the concrete
/// field type imported into this module. The trailing `2` is forwarded unchanged to
/// `RecursionAir` (NOTE(review): meaning defined by `sp1_recursion_machine`, confirm there).
pub type CompressAir<F> = sp1_recursion_machine::RecursionAir<F, COMPRESS_DEGREE, 2>;
/// Verifier for SP1 compressed (compress-stage recursion) proofs.
///
/// Pairs a machine verifier for [`CompressAir`] over [`SP1Field`] with the Merkle root under
/// which every accepted recursion verifying key must be included.
pub struct SP1CompressedVerifier {
    /// Shard-level verifier for the compress-stage recursion machine.
    verifier: MachineVerifier<GC, InnerSC<CompressAir<SP1Field>>>,
    /// Root of the Merkle tree of allowed recursion verifying keys.
    vk_merkle_root: [SP1Field; DIGEST_SIZE],
}
impl Default for SP1CompressedVerifier {
    /// Builds a verifier configured for compress-stage recursion proofs.
    ///
    /// The shard verifier uses the recursion FRI configuration together with
    /// [`RECURSION_LOG_STACKING_HEIGHT`] and [`RECURSION_MAX_LOG_ROW_COUNT`]. The accepted vk
    /// Merkle root is the root of the default [`crate::VerifierRecursionVks`] set.
    fn default() -> Self {
        let machine = CompressAir::<SP1Field>::compress_machine();
        // `machine` is not used after this call, so it is moved in rather than cloned.
        let recursion_shard_verifier = ShardVerifier::from_basefold_parameters(
            recursion_fri_config(),
            RECURSION_LOG_STACKING_HEIGHT,
            RECURSION_MAX_LOG_ROW_COUNT,
            machine,
        );
        let verifier = MachineVerifier::new(recursion_shard_verifier);
        let vk_merkle_root = crate::VerifierRecursionVks::default().root();
        Self { verifier, vk_merkle_root }
    }
}
impl SP1CompressedVerifier {
pub fn new() -> Self {
Self::default()
}
pub fn recursion_public_values_digest(
&self,
public_values: &RecursionPublicValues<SP1Field>,
) -> [SP1Field; 8] {
let hasher = poseidon2_hasher();
hasher.hash_slice(&public_values.as_array()[0..NUM_PV_ELMS_TO_HASH])
}
pub fn is_recursion_public_values_valid(
&self,
public_values: &RecursionPublicValues<SP1Field>,
) -> bool {
let expected_digest = self.recursion_public_values_digest(public_values);
public_values.digest.iter().copied().eq(expected_digest)
}
pub fn verify_compressed(
&self,
proof: &SP1RecursionProof<GC, C>,
vkey_hash: &[SP1Field; 8],
) -> Result<(), CompressedError> {
let SP1RecursionProof { vk: compress_vk, proof, vk_merkle_proof } = proof;
let mut challenger = self.verifier.challenger();
compress_vk.observe_into(&mut challenger);
self.verifier
.verify_shard(compress_vk, proof, &mut challenger)
.map_err(MachineVerifierError::InvalidShardProof)?;
if proof.public_values.len() != PROOF_MAX_NUM_PVS {
return Err(MachineVerifierError::InvalidPublicValues("invalid public values length"))?;
}
let public_values: &RecursionPublicValues<_> = proof.public_values.as_slice().borrow();
if !self.is_recursion_public_values_valid(public_values) {
return Err(MachineVerifierError::InvalidPublicValues(
"recursion public values are invalid",
)
.into());
}
verify_merkle_proof(vk_merkle_proof, compress_vk.hash_koalabear(), self.vk_merkle_root)
.map_err(CompressedError::InvalidVkey)?;
if public_values.vk_root != self.vk_merkle_root {
return Err(MachineVerifierError::InvalidPublicValues("vk merkle root mismatch"))?;
}
if public_values.is_complete != SP1Field::one() {
return Err(MachineVerifierError::InvalidPublicValues("is_complete is not 1").into());
}
if public_values.sp1_vk_digest != *vkey_hash {
return Err(MachineVerifierError::InvalidPublicValues("sp1 vk hash mismatch").into());
}
Ok(())
}
pub fn verify_compressed_with_public_values(
&self,
proof: &SP1RecursionProof<GC, C>,
sp1_public_inputs: &[u8],
vkey_hash: &[SP1Field; 8],
) -> Result<(), CompressedError> {
self.verify_compressed(proof, vkey_hash)?;
let SP1RecursionProof { proof, .. } = proof;
let public_values: &RecursionPublicValues<_> = proof.public_values.as_slice().borrow();
let committed_value_digest_bytes = public_values
.committed_value_digest
.iter()
.flat_map(|w| w.iter().map(|x| x.as_canonical_u32() as u8))
.collect::<Vec<_>>();
let sha256_digest = crate::sha256_hash(sp1_public_inputs);
let blake3_digest = blake3_hash(sp1_public_inputs);
if committed_value_digest_bytes.as_slice() != sha256_digest.as_slice()
&& committed_value_digest_bytes.as_slice() != blake3_digest.as_slice()
{
return Err(CompressedError::PublicValuesMismatch);
}
Ok(())
}
}