use std::{
borrow::{Borrow, BorrowMut},
marker::PhantomData,
};
use itertools::Itertools;
use slop_air::Air;
use slop_algebra::{AbstractField, PrimeField32};
use slop_challenger::IopCtx;
use sp1_primitives::{SP1Field, SP1GlobalContext};
use serde::{Deserialize, Serialize};
use sp1_core_machine::riscv::RiscvAir;
use sp1_hypercube::air::{PublicValues, SP1CorePublicValues};
use sp1_hypercube::{air::ShardRange, MachineVerifyingKey, ShardProof};
use sp1_recursion_compiler::{
circuit::CircuitV2Builder,
ir::{Builder, Config, Felt},
};
use sp1_recursion_executor::{RecursionPublicValues, DIGEST_SIZE, RECURSIVE_PROOF_NUM_PV_ELTS};
use crate::{
challenger::CanObserveVariable,
machine::{assert_complete, recursion_public_values_digest},
shard::{MachineVerifyingKeyVariable, RecursiveShardVerifier, ShardProofVariable},
zerocheck::RecursiveVerifierConstraintFolder,
CircuitConfig, SP1FieldConfigVariable,
};
/// In-circuit witness consumed by [`SP1RecursiveVerifier::verify`].
///
/// Every value is a circuit variable; `verify` destructures all of these fields.
pub struct SP1RecursionWitnessVariable<C, SC>
where
    C: CircuitConfig,
    SC: SP1FieldConfigVariable<C>,
{
    /// Verifying key the shard proof is checked against.
    pub vk: MachineVerifyingKeyVariable<C, SC>,
    /// Shard proofs to verify (`verify` asserts there is exactly one).
    pub shard_proofs: Vec<ShardProofVariable<C, SC>>,
    /// Completeness flag, forwarded into the recursion public values.
    pub is_complete: Felt<SP1Field>,
    /// Verifying-key root digest, forwarded into the recursion public values.
    pub vk_root: [Felt<SP1Field>; DIGEST_SIZE],
    /// Deferred-proof reconstruction digest, used as both start and end digest.
    pub reconstruct_deferred_digest: [Felt<SP1Field>; DIGEST_SIZE],
    /// Deferred-proof count, used as both previous and current count.
    pub num_deferred_proofs: Felt<SP1Field>,
}
/// Out-of-circuit witness values for normalizing core shard proofs.
///
/// NOTE(review): this appears to be the value-level counterpart of
/// [`SP1RecursionWitnessVariable`] (same field names, with `is_complete` as a plain `bool`);
/// confirm against the code that writes this witness into the circuit.
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "ShardProof<GC,Proof>: Serialize"))]
#[serde(bound(deserialize = "ShardProof<GC,Proof>: Deserialize<'de>"))]
pub struct SP1NormalizeWitnessValues<GC: IopCtx, Proof> {
    /// Verifying key the shard proofs are checked against.
    pub vk: MachineVerifyingKey<GC>,
    /// Shard proofs to normalize; [`Self::range`] reads the first and the last of them.
    pub shard_proofs: Vec<ShardProof<GC, Proof>>,
    /// Completeness flag.
    pub is_complete: bool,
    /// Verifying-key root digest.
    pub vk_root: [GC::F; DIGEST_SIZE],
    /// Deferred-proof reconstruction digest.
    ///
    /// Sized by `DIGEST_SIZE` (rather than a literal length) so it stays in lock-step with
    /// `vk_root` and with the in-circuit `[Felt<SP1Field>; DIGEST_SIZE]` field.
    pub reconstruct_deferred_digest: [GC::F; DIGEST_SIZE],
    /// Number of deferred proofs; [`Self::range`] uses it as both ends of the
    /// deferred-proof range.
    pub num_deferred_proofs: GC::F,
}
impl<GC: IopCtx, Proof> SP1NormalizeWitnessValues<GC, Proof> {
    /// Returns the [`ShardRange`] covered by this witness.
    ///
    /// The execution range runs from the `range().start()` of the first shard proof's
    /// [`SP1CorePublicValues`] to the `range().end()` of the last one. The deferred-proof
    /// range is the zero-length pair `(n, n)` with `n = num_deferred_proofs`.
    ///
    /// # Panics
    ///
    /// Panics if `shard_proofs` is empty.
    pub fn range(&self) -> ShardRange
    where
        GC::F: PrimeField32,
    {
        // A non-empty slice always has both a first and a last element; for a single proof
        // they are the same element.
        let (Some(first), Some(last)) = (self.shard_proofs.first(), self.shard_proofs.last())
        else {
            panic!("SP1NormalizeWitnessValues::range requires at least one shard proof");
        };

        let start_pv: &SP1CorePublicValues<GC::F> = first.public_values.as_slice().borrow();
        let end_pv: &SP1CorePublicValues<GC::F> = last.public_values.as_slice().borrow();

        let mut range: ShardRange = (start_pv.range().start()..end_pv.range().end()).into();

        // u32 -> u64 is lossless; `From` makes that explicit instead of an `as` cast.
        let num_deferred_proofs = u64::from(self.num_deferred_proofs.as_canonical_u32());
        range.deferred_proof_range = (num_deferred_proofs, num_deferred_proofs);
        range
    }
}
/// Data-less handle for the in-circuit verifier of a single core shard proof.
///
/// The circuit configuration `C` exists only at the type level via `PhantomData`.
#[derive(Clone, Copy, Debug)]
pub struct SP1RecursiveVerifier<C>
where
    C: Config,
{
    _phantom: PhantomData<C>,
}
impl<C> SP1RecursiveVerifier<C>
where
    C: CircuitConfig<Bit = Felt<SP1Field>>,
{
    /// Verifies one core (RISC-V) shard proof inside the recursion circuit and commits the
    /// resulting [`RecursionPublicValues`].
    ///
    /// Steps, in emission order:
    /// 1. If the shard is the first execution shard, constrains its `pc_start` to `vk.pc_start`.
    /// 2. Collects the vk's initial global cumulative sum (selected on
    ///    `is_first_execution_shard`) and the shard's own global cumulative sum.
    /// 3. Seeds a challenger with the verifying key and runs `machine.verify_shard`.
    /// 4. Constrains `is_untrusted_programs_enabled` to match the vk.
    /// 5. Builds, digests, completeness-checks and commits the recursion public values.
    ///
    /// NOTE(review): every `builder`/`challenger` call below emits circuit operations or
    /// transcript observations, so reordering, adding or removing calls presumably changes the
    /// compiled recursion program — confirm before refactoring.
    ///
    /// # Panics
    ///
    /// Panics (while building the circuit) if `input.shard_proofs.len() != 1`, or if `zip_eq`
    /// finds that the public-values `pc_start` and `vk.pc_start` have different lengths.
    pub fn verify(
        builder: &mut Builder<C>,
        machine: &RecursiveShardVerifier<SP1GlobalContext, RiscvAir<SP1Field>, C>,
        input: SP1RecursionWitnessVariable<C, SP1GlobalContext>,
    ) where
        RiscvAir<SP1Field>: for<'b> Air<RecursiveVerifierConstraintFolder<'b>>,
    {
        let SP1RecursionWitnessVariable {
            vk,
            shard_proofs,
            is_complete,
            vk_root,
            reconstruct_deferred_digest,
            num_deferred_proofs,
        } = input;
        // This verifier handles exactly one shard proof per invocation.
        assert!(shard_proofs.len() == 1);
        let shard_proof = &shard_proofs[0];
        let mut global_cumulative_sums = Vec::new();
        // Reinterpret the flat public-values slice as the structured core public values.
        let public_values: &PublicValues<[Felt<_>; 4], [Felt<_>; 3], [Felt<_>; 4], Felt<_>> =
            shard_proof.public_values.as_slice().borrow();
        // Gated equality: `flag * (pc - vk_pc) == 0`, so `pc_start` must match the vk only when
        // `is_first_execution_shard` is nonzero. Assumes that flag is boolean-constrained
        // elsewhere — TODO confirm.
        for (pc, vk_pc) in public_values.pc_start.iter().zip_eq(vk.pc_start.iter()) {
            builder.assert_felt_eq(
                public_values.is_first_execution_shard * (*pc - *vk_pc),
                SP1Field::zero(),
            );
        }
        // The vk's initial global cumulative sum, selected on `is_first_execution_shard`
        // (presumably the identity sum when the flag is 0 — confirm in the builder).
        global_cumulative_sums.push(builder.select_global_cumulative_sum(
            public_values.is_first_execution_shard,
            vk.initial_global_cumulative_sum,
        ));
        // Seed the Fiat–Shamir transcript with the verifying key. The observation order
        // presumably must mirror the native verifier exactly.
        let mut challenger = SP1GlobalContext::challenger_variable(builder);
        challenger.observe(builder, vk.preprocessed_commit);
        challenger.observe_slice(builder, vk.pc_start);
        challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.x.0);
        challenger.observe_slice(builder, vk.initial_global_cumulative_sum.0.y.0);
        challenger.observe(builder, vk.enable_untrusted_programs);
        // NOTE(review): six zero felts are observed after the vk fields — presumably padding to
        // a fixed vk-observation layout shared with the native verifier; confirm.
        let zero: Felt<_> = builder.eval(SP1Field::zero());
        for _ in 0..6 {
            challenger.observe(builder, zero);
        }
        tracing::debug_span!("verify shard")
            .in_scope(|| machine.verify_shard(builder, &vk, shard_proof, &mut challenger));
        // The shard's claimed untrusted-programs flag must agree with the verifying key.
        builder.assert_felt_eq(
            public_values.is_untrusted_programs_enabled,
            vk.enable_untrusted_programs,
        );
        // Combine the (selected) vk initial sum with this shard's global cumulative sum.
        global_cumulative_sums.push(public_values.global_cumulative_sum);
        let global_cumulative_sum = builder.sum_digest_v2(global_cumulative_sums);
        {
            let vk_digest = vk.hash(builder);
            // NOTE(review): a second `zero` is evaluated here instead of reusing the one above;
            // merging them would presumably alter the emitted circuit.
            let zero: Felt<_> = builder.eval(SP1Field::zero());
            let mut recursion_public_values_stream = [zero; RECURSIVE_PROOF_NUM_PV_ELTS];
            let recursion_public_values: &mut RecursionPublicValues<_> =
                recursion_public_values_stream.as_mut_slice().borrow_mut();
            // Execution-state boundaries are copied straight from the shard's public values.
            recursion_public_values.prev_committed_value_digest =
                public_values.prev_committed_value_digest;
            recursion_public_values.committed_value_digest = public_values.committed_value_digest;
            recursion_public_values.prev_deferred_proofs_digest =
                public_values.prev_deferred_proofs_digest;
            recursion_public_values.deferred_proofs_digest = public_values.deferred_proofs_digest;
            // No deferred proofs are consumed here: previous and current counts are identical.
            recursion_public_values.prev_deferred_proof = num_deferred_proofs;
            recursion_public_values.deferred_proof = num_deferred_proofs;
            recursion_public_values.pc_start = public_values.pc_start;
            recursion_public_values.next_pc = public_values.next_pc;
            recursion_public_values.initial_timestamp = public_values.initial_timestamp;
            recursion_public_values.last_timestamp = public_values.last_timestamp;
            recursion_public_values.previous_init_addr = public_values.previous_init_addr;
            recursion_public_values.last_init_addr = public_values.last_init_addr;
            recursion_public_values.previous_finalize_addr = public_values.previous_finalize_addr;
            recursion_public_values.last_finalize_addr = public_values.last_finalize_addr;
            recursion_public_values.previous_init_page_idx = public_values.previous_init_page_idx;
            recursion_public_values.last_init_page_idx = public_values.last_init_page_idx;
            recursion_public_values.previous_finalize_page_idx =
                public_values.previous_finalize_page_idx;
            recursion_public_values.last_finalize_page_idx = public_values.last_finalize_page_idx;
            // The reconstruct-deferred digest passes through unchanged (start == end).
            recursion_public_values.start_reconstruct_deferred_digest = reconstruct_deferred_digest;
            recursion_public_values.end_reconstruct_deferred_digest = reconstruct_deferred_digest;
            recursion_public_values.sp1_vk_digest = vk_digest;
            recursion_public_values.vk_root = vk_root;
            recursion_public_values.global_cumulative_sum = global_cumulative_sum;
            recursion_public_values.contains_first_shard = public_values.is_first_execution_shard;
            // Exactly one shard is included by this verifier.
            recursion_public_values.num_included_shard = builder.eval(SP1Field::one());
            recursion_public_values.is_complete = is_complete;
            recursion_public_values.prev_exit_code = public_values.prev_exit_code;
            recursion_public_values.exit_code = public_values.exit_code;
            recursion_public_values.prev_commit_syscall = public_values.prev_commit_syscall;
            recursion_public_values.commit_syscall = public_values.commit_syscall;
            recursion_public_values.prev_commit_deferred_syscall =
                public_values.prev_commit_deferred_syscall;
            recursion_public_values.commit_deferred_syscall = public_values.commit_deferred_syscall;
            recursion_public_values.proof_nonce = public_values.proof_nonce;
            // The digest is computed from the fields assigned above, then stored in `digest`.
            recursion_public_values.digest = recursion_public_values_digest::<C, SP1GlobalContext>(
                builder,
                recursion_public_values,
            );
            assert_complete(builder, recursion_public_values, is_complete);
            SP1GlobalContext::commit_recursion_public_values(builder, *recursion_public_values);
        }
    }
}