use std::{
array,
borrow::{Borrow, BorrowMut},
marker::PhantomData,
mem::MaybeUninit,
};
use itertools::Itertools;
use slop_air::Air;
use slop_algebra::{AbstractField, PrimeField32};
use slop_challenger::IopCtx;
use serde::{Deserialize, Serialize};
use sp1_core_machine::riscv::MAX_LOG_NUMBER_OF_SHARDS;
use sp1_recursion_compiler::ir::{Builder, Felt, IrIter};
use sp1_primitives::{SP1Field, SP1GlobalContext};
use sp1_recursion_executor::{RecursionPublicValues, RECURSIVE_PROOF_NUM_PV_ELTS};
use sp1_hypercube::{
air::{MachineAir, ShardRange, POSEIDON_NUM_WORDS, PV_DIGEST_NUM_WORDS},
MachineVerifyingKey, ShardProof, DIGEST_SIZE,
};
use crate::{
challenger::CanObserveVariable,
machine::{
assert_complete, assert_recursion_public_values_valid, recursion_public_values_digest,
root_public_values_digest,
},
shard::{MachineVerifyingKeyVariable, RecursiveShardVerifier, ShardProofVariable},
zerocheck::RecursiveVerifierConstraintFolder,
CircuitConfig, SP1FieldConfigVariable,
};
use sp1_recursion_compiler::circuit::CircuitV2Builder;
use super::InnerVal;
/// Marker type for the compress verifier circuit.
///
/// The struct stores no data: its only field is a `PhantomData` that ties the type to the
/// parameters `C`, `SC` and `A`. The functionality lives in the associated functions of its
/// `impl` block (see `verify`).
#[derive(Debug, Clone, Copy)]
pub struct SP1CompressVerifier<C, SC, A> {
// Zero-sized; only records the generic parameters.
_phantom: PhantomData<(C, SC, A)>,
}
/// Selects which digest function is applied to the aggregated public values at the end of
/// [`SP1CompressVerifier::verify`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PublicValuesOutputDigest {
    /// Compute the digest with `recursion_public_values_digest`.
    Reduce,
    /// Compute the digest with `root_public_values_digest`.
    Root,
}
/// In-circuit input of the compress verifier: the proofs to verify together with their
/// verifying keys, plus the completeness flag.
// The `Vec` of tuples of generic variable types is what triggers `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
pub struct SP1ShapedWitnessVariable<C: CircuitConfig, GC: SP1FieldConfigVariable<C>> {
/// `(verifying key, shard proof)` pairs, in the order in which they are verified and chained.
pub vks_and_proofs: Vec<(MachineVerifyingKeyVariable<C, GC>, ShardProofVariable<C, GC>)>,
/// Completeness flag; `verify` copies it into the output public values and passes it to
/// `assert_complete`.
pub is_complete: Felt<SP1Field>,
}
/// A machine verifying key paired with a shard proof.
pub type VkAndProof<GC, Proof> = (MachineVerifyingKey<GC>, ShardProof<GC, Proof>);
/// Concrete (serializable) list of verifying keys and shard proofs plus a completeness flag.
///
/// NOTE(review): presumably the value-level counterpart of `SP1ShapedWitnessVariable` —
/// confirm against the witness-reading code.
// The explicit serde bounds only require `ShardProof<GC, Proof>` to be (de)serializable,
// replacing the bounds that the derive would otherwise infer for the type parameters.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "ShardProof<GC,Proof>: Serialize"))]
#[serde(bound(deserialize = "ShardProof<GC,Proof>: Deserialize<'de>"))]
pub struct SP1ShapedWitnessValues<GC: IopCtx, Proof> {
/// `(verifying key, shard proof)` pairs; `range` reads the first and the last entry.
pub vks_and_proofs: Vec<VkAndProof<GC, Proof>>,
/// Completeness flag carried alongside the proofs.
pub is_complete: bool,
}
impl<GC: IopCtx, Proof> SP1ShapedWitnessValues<GC, Proof> {
    /// Returns the shard range covered by the proofs held in this witness.
    ///
    /// The range begins at the start of the first proof's public-value range and finishes at
    /// the end of the last proof's public-value range.
    ///
    /// # Panics
    ///
    /// Panics if `vks_and_proofs` is empty.
    pub fn range(&self) -> ShardRange
    where
        GC::F: PrimeField32,
    {
        let entries = &self.vks_and_proofs;
        // Indexing (rather than `first()`/`last()`) keeps the panic on an empty list.
        let (_, first_proof) = &entries[0];
        let (_, last_proof) = &entries[entries.len() - 1];

        // Reinterpret the flat public-value slices as typed public values.
        let first_pv: &RecursionPublicValues<GC::F> =
            first_proof.public_values.as_slice().borrow();
        let last_pv: &RecursionPublicValues<GC::F> =
            last_proof.public_values.as_slice().borrow();

        let lower = first_pv.range().start();
        let upper = last_pv.range().end();
        let span = lower..upper;
        span.into()
    }
}
impl<C, SC, A> SP1CompressVerifier<C, SC, A>
where
C: CircuitConfig<Bit = Felt<SP1Field>>,
A: MachineAir<InnerVal> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
{
/// Verifies every `(vk, shard_proof)` pair of `input` and commits one aggregated set of
/// recursion public values.
///
/// The function:
/// 1. verifies each shard proof with `machine.verify_shard`, using a fresh challenger that
///    has observed the fields of the corresponding verifying key;
/// 2. asserts, proof by proof, that each proof's "start" public values equal the previous
///    proof's "end" public values (the first proof seeds the chain), and that `vk_root`,
///    `sp1_vk_digest` and `proof_nonce` agree across all proofs;
/// 3. accumulates `num_included_shard`, `contains_first_shard` and the global cumulative sums;
/// 4. writes the aggregate into a fresh `RecursionPublicValues`, computes its digest
///    according to `kind`, calls `assert_complete`, and commits the public values.
///
/// # Arguments
///
/// * `builder` - circuit builder that receives all emitted operations and constraints.
/// * `machine` - recursive shard verifier used for each proof.
/// * `input` - proofs with their verifying keys, plus the `is_complete` flag.
/// * `vk_root` - value every proof's `vk_root` public value is asserted to equal.
/// * `kind` - selects the digest function applied to the aggregated public values.
///
/// # Panics
///
/// Panics while building the circuit if `input.vks_and_proofs` is empty.
pub fn verify(
builder: &mut Builder<C>,
machine: &RecursiveShardVerifier<SP1GlobalContext, A, C>,
input: SP1ShapedWitnessVariable<C, SP1GlobalContext>,
vk_root: [Felt<SP1Field>; DIGEST_SIZE],
kind: PublicValuesOutputDigest,
) {
let SP1ShapedWitnessVariable { vks_and_proofs, is_complete } = input;
// Flat buffer of `RECURSIVE_PROOF_NUM_PV_ELTS` placeholder felts that backs the output
// public values; it is viewed as a `RecursionPublicValues` struct through `borrow_mut`.
// SAFETY(review): relies on the all-zero bit pattern being a valid `Felt` — confirm against
// `Felt`'s definition. Fields are assigned below; whether every field of
// `RecursionPublicValues` gets assigned cannot be confirmed from this file alone.
let mut reduce_public_values_stream: Vec<Felt<_>> = (0..RECURSIVE_PROOF_NUM_PV_ELTS)
.map(|_| unsafe { MaybeUninit::zeroed().assume_init() })
.collect();
let compress_public_values: &mut RecursionPublicValues<_> =
reduce_public_values_stream.as_mut_slice().borrow_mut();
// Host-side check (not a circuit constraint): at least one proof is required. It also
// guarantees that the `i == 0` branch of the loop below runs.
assert!(!vks_and_proofs.is_empty());
// Running state used to chain consecutive proofs. Every local below starts as a placeholder
// and is overwritten in the `i == 0` branch of the loop before it is first read.
// SAFETY(review): same zeroed-`Felt` assumption as above.
let mut sp1_vk_digest: [Felt<_>; DIGEST_SIZE] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut pc: [Felt<_>; 3] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut current_exit_code: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
// NOTE(review): this placeholder comes from `builder.uninit()` whereas the neighbouring ones
// use zeroed memory; it is overwritten in the first iteration like the others.
let mut current_timestamp: [Felt<_>; 4] = array::from_fn(|_| builder.uninit());
let mut committed_value_digest: [[Felt<_>; 4]; PV_DIGEST_NUM_WORDS] =
array::from_fn(|_| array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() }));
let mut deferred_proofs_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut deferred_proof_index: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
let mut reconstruct_deferred_digest: [Felt<_>; POSEIDON_NUM_WORDS] =
core::array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
// One global cumulative sum per proof; combined after the loop.
let mut global_cumulative_sums = Vec::new();
let mut init_addr: [Felt<_>; 3] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut finalize_addr: [Felt<_>; 3] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut init_page_idx: [Felt<_>; 3] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut finalize_page_idx: [Felt<_>; 3] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
let mut commit_syscall: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
let mut commit_deferred_syscall: Felt<_> = unsafe { MaybeUninit::zeroed().assume_init() };
// Accumulators that start at the field's zero and are summed over all proofs.
let mut contains_first_shard: Felt<_> = builder.eval(SP1Field::zero());
let mut num_included_shard: Felt<_> = builder.eval(SP1Field::zero());
let mut proof_nonce: [Felt<_>; 4] =
array::from_fn(|_| unsafe { MaybeUninit::zeroed().assume_init() });
// Pass 1: verify each shard proof. The closure returns `()`, so the collected `Vec` carries
// no data and is dropped.
// NOTE(review): `ir_par_map_collect` presumably lets the IR treat each verification as an
// independent unit — confirm against `IrIter`.
vks_and_proofs.iter().ir_par_map_collect::<Vec<_>, _, _>(
builder,
|builder, (vk, shard_proof)| {
// Fresh challenger per proof, seeded with the fields of the verifying key.
let mut challenger = SP1GlobalContext::challenger_variable(builder);
challenger.observe(builder, vk.preprocessed_commit);
challenger.observe_slice(builder, vk.pc_start);
challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.x.0);
challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.y.0);
challenger.observe(builder, vk.enable_untrusted_programs);
// Observe six zero elements.
// NOTE(review): presumably padding so the transcript matches how the vk is observed
// elsewhere — confirm.
let zero: Felt<_> = builder.eval(SP1Field::zero());
for _ in 0..6 {
challenger.observe(builder, zero);
}
machine.verify_shard(builder, vk, shard_proof, &mut challenger);
},
);
// Pass 2: constrain the public values of consecutive proofs to chain together.
for (i, (_, shard_proof)) in vks_and_proofs.into_iter().enumerate() {
// View this proof's flat public values as a typed struct.
let current_public_values: &RecursionPublicValues<Felt<SP1Field>> =
shard_proof.public_values.as_slice().borrow();
assert_recursion_public_values_valid::<C, SP1GlobalContext>(
builder,
current_public_values,
);
// Every proof must carry the same `vk_root` as the one passed to this function.
for (expected, actual) in vk_root.iter().zip(current_public_values.vk_root.iter()) {
builder.assert_felt_eq(*expected, *actual);
}
// Range-check this proof's shard count.
// NOTE(review): `MAX_LOG_NUMBER_OF_SHARDS` is presumably a bit width — confirm against
// `range_check_felt`.
C::range_check_felt(
builder,
current_public_values.num_included_shard,
MAX_LOG_NUMBER_OF_SHARDS,
);
// Constrain the flag to be boolean: x * (x - 1) == 0.
builder.assert_felt_eq(
current_public_values.contains_first_shard
* (current_public_values.contains_first_shard - SP1Field::one()),
SP1Field::zero(),
);
// Accumulate the shard count and the first-shard flag over all proofs.
num_included_shard =
builder.eval(num_included_shard + current_public_values.num_included_shard);
contains_first_shard =
builder.eval(contains_first_shard + current_public_values.contains_first_shard);
global_cumulative_sums.push(current_public_values.global_cumulative_sum);
// First proof: seed both the output's start-of-range fields and the running state with this
// proof's start values, so the equality checks below hold trivially for it.
if i == 0 {
compress_public_values.prev_committed_value_digest =
current_public_values.prev_committed_value_digest;
committed_value_digest = current_public_values.prev_committed_value_digest;
compress_public_values.prev_deferred_proofs_digest =
current_public_values.prev_deferred_proofs_digest;
deferred_proofs_digest = current_public_values.prev_deferred_proofs_digest;
compress_public_values.prev_deferred_proof =
current_public_values.prev_deferred_proof;
deferred_proof_index = current_public_values.prev_deferred_proof;
compress_public_values.pc_start = current_public_values.pc_start;
pc = current_public_values.pc_start;
compress_public_values.initial_timestamp = current_public_values.initial_timestamp;
current_timestamp = current_public_values.initial_timestamp;
compress_public_values.previous_init_addr =
current_public_values.previous_init_addr;
init_addr = current_public_values.previous_init_addr;
compress_public_values.previous_finalize_addr =
current_public_values.previous_finalize_addr;
finalize_addr = current_public_values.previous_finalize_addr;
compress_public_values.previous_init_page_idx =
current_public_values.previous_init_page_idx;
init_page_idx = current_public_values.previous_init_page_idx;
compress_public_values.previous_finalize_page_idx =
current_public_values.previous_finalize_page_idx;
finalize_page_idx = current_public_values.previous_finalize_page_idx;
compress_public_values.start_reconstruct_deferred_digest =
current_public_values.start_reconstruct_deferred_digest;
reconstruct_deferred_digest =
current_public_values.start_reconstruct_deferred_digest;
compress_public_values.prev_exit_code = current_public_values.prev_exit_code;
current_exit_code = current_public_values.prev_exit_code;
compress_public_values.prev_commit_syscall =
current_public_values.prev_commit_syscall;
commit_syscall = current_public_values.prev_commit_syscall;
compress_public_values.prev_commit_deferred_syscall =
current_public_values.prev_commit_deferred_syscall;
commit_deferred_syscall = current_public_values.prev_commit_deferred_syscall;
compress_public_values.sp1_vk_digest = current_public_values.sp1_vk_digest;
sp1_vk_digest = current_public_values.sp1_vk_digest;
compress_public_values.proof_nonce = current_public_values.proof_nonce;
proof_nonce = current_public_values.proof_nonce;
}
// Each section below follows the same pattern: assert that the running value equals this
// proof's start value, then advance the running value to this proof's end value.
// Committed value digest.
for (word, current_word) in committed_value_digest
.iter()
.zip_eq(current_public_values.prev_committed_value_digest.iter())
{
for (limb, current_limb) in word.iter().zip_eq(current_word.iter()) {
builder.assert_felt_eq(*limb, *current_limb);
}
}
committed_value_digest = current_public_values.committed_value_digest;
// Deferred proofs digest.
for (limb, current_limb) in deferred_proofs_digest
.iter()
.zip_eq(current_public_values.prev_deferred_proofs_digest.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
deferred_proofs_digest = current_public_values.deferred_proofs_digest;
// Deferred proof value.
builder.assert_felt_eq(deferred_proof_index, current_public_values.prev_deferred_proof);
deferred_proof_index = current_public_values.deferred_proof;
// Program counter: `pc_start` must equal the previous proof's `next_pc`.
for (limb, current_limb) in pc.iter().zip(current_public_values.pc_start.iter()) {
builder.assert_felt_eq(*limb, *current_limb);
}
pc = current_public_values.next_pc;
// Timestamp: `initial_timestamp` must equal the previous proof's `last_timestamp`.
for (limb, current_limb) in
current_timestamp.iter().zip(current_public_values.initial_timestamp.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
current_timestamp = current_public_values.last_timestamp;
// Init address.
for (limb, current_limb) in
init_addr.iter().zip(current_public_values.previous_init_addr.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
init_addr = current_public_values.last_init_addr;
// Finalize address.
for (limb, current_limb) in
finalize_addr.iter().zip(current_public_values.previous_finalize_addr.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
finalize_addr = current_public_values.last_finalize_addr;
// Init page index.
for (limb, current_limb) in
init_page_idx.iter().zip(current_public_values.previous_init_page_idx.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
init_page_idx = current_public_values.last_init_page_idx;
// Finalize page index.
for (limb, current_limb) in finalize_page_idx
.iter()
.zip(current_public_values.previous_finalize_page_idx.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
finalize_page_idx = current_public_values.last_finalize_page_idx;
// Reconstructed deferred digest.
for (digest, current_digest) in reconstruct_deferred_digest
.iter()
.zip_eq(current_public_values.start_reconstruct_deferred_digest.iter())
{
builder.assert_felt_eq(*digest, *current_digest);
}
reconstruct_deferred_digest = current_public_values.end_reconstruct_deferred_digest;
// Exit code.
builder.assert_felt_eq(current_exit_code, current_public_values.prev_exit_code);
current_exit_code = current_public_values.exit_code;
// Commit syscall value.
builder.assert_felt_eq(commit_syscall, current_public_values.prev_commit_syscall);
commit_syscall = current_public_values.commit_syscall;
// Commit-deferred syscall value.
builder.assert_felt_eq(
commit_deferred_syscall,
current_public_values.prev_commit_deferred_syscall,
);
commit_deferred_syscall = current_public_values.commit_deferred_syscall;
// `sp1_vk_digest` and `proof_nonce` are never advanced: every proof must carry the same
// values as the first one.
for (digest, current) in sp1_vk_digest.iter().zip(current_public_values.sp1_vk_digest) {
builder.assert_felt_eq(*digest, current);
}
for (limb, current_limb) in
proof_nonce.iter().zip(current_public_values.proof_nonce.iter())
{
builder.assert_felt_eq(*limb, *current_limb);
}
}
// The accumulated totals get the same checks as the per-proof values: the summed shard count
// is range-checked, and the summed first-shard flag must itself be 0 or 1.
C::range_check_felt(builder, num_included_shard, MAX_LOG_NUMBER_OF_SHARDS);
builder.assert_felt_eq(
contains_first_shard * (contains_first_shard - SP1Field::one()),
SP1Field::zero(),
);
// Combine the per-proof global cumulative sums into one.
let global_cumulative_sum = builder.sum_digest_v2(global_cumulative_sums);
// Fill the end-of-range and aggregate fields from the final running state.
// `sp1_vk_digest` and `proof_nonce` were already written in the `i == 0` branch; the locals
// never change afterwards, so these assignments store the same values again.
compress_public_values.committed_value_digest = committed_value_digest;
compress_public_values.deferred_proofs_digest = deferred_proofs_digest;
compress_public_values.next_pc = pc;
compress_public_values.last_timestamp = current_timestamp;
compress_public_values.last_init_addr = init_addr;
compress_public_values.last_finalize_addr = finalize_addr;
compress_public_values.last_init_page_idx = init_page_idx;
compress_public_values.last_finalize_page_idx = finalize_page_idx;
compress_public_values.end_reconstruct_deferred_digest = reconstruct_deferred_digest;
compress_public_values.deferred_proof = deferred_proof_index;
compress_public_values.sp1_vk_digest = sp1_vk_digest;
compress_public_values.vk_root = vk_root;
compress_public_values.global_cumulative_sum = global_cumulative_sum;
compress_public_values.contains_first_shard = contains_first_shard;
compress_public_values.num_included_shard = num_included_shard;
compress_public_values.is_complete = is_complete;
compress_public_values.exit_code = current_exit_code;
compress_public_values.commit_syscall = commit_syscall;
compress_public_values.commit_deferred_syscall = commit_deferred_syscall;
compress_public_values.proof_nonce = proof_nonce;
// The digest is computed last, after the other fields above have been assigned; `kind`
// picks the digest function.
compress_public_values.digest = match kind {
PublicValuesOutputDigest::Reduce => {
recursion_public_values_digest::<C, SP1GlobalContext>(
builder,
compress_public_values,
)
}
PublicValuesOutputDigest::Root => {
root_public_values_digest::<C, SP1GlobalContext>(builder, compress_public_values)
}
};
// Delegate the completeness constraints to `assert_complete`, then commit a copy of the
// aggregated public values.
assert_complete(builder, compress_public_values, is_complete);
SP1GlobalContext::commit_recursion_public_values(builder, *compress_public_values);
}
}