use std::collections::HashMap;
use num_traits::{One, Zero};
use paste::paste;
use serde_json::to_string_pretty;
use stwo::core::channel::{Channel, MerkleChannel};
use stwo::core::fields::m31::BaseField;
use stwo::core::fields::qm31::SecureField;
use stwo::core::pcs::CommitmentSchemeVerifier;
use stwo::core::verifier::{verify_ex, VerificationError};
use stwo_cairo_common::builtins::*;
use stwo_cairo_common::memory::{LARGE_MEMORY_VALUE_ID_BASE, LOG_MEMORY_ADDRESS_BOUND};
use stwo_cairo_common::prover_types::cpu::{CasmState, PRIME};
use stwo_constraint_framework::PREPROCESSED_TRACE_IDX;
use thiserror::Error;
use tracing::{span, Level};
use crate::air::{MemorySection, PublicData, PublicMemory, PublicSegmentRanges, SegmentRange};
use crate::cairo_components::CairoComponents;
use crate::claims::{lookup_sum, CairoClaim};
use crate::components::memory_address_to_id::MEMORY_ADDRESS_TO_ID_SPLIT;
use crate::components::{
add_mod_builtin, bitwise_builtin, ec_op_builtin, mul_mod_builtin, pedersen_builtin,
pedersen_builtin_narrow_windows, poseidon_builtin, range_check96_builtin, range_check_builtin,
};
use crate::relations::CommonLookupElements;
use crate::CairoProofForRustVerifier;
/// Checks the public (non-cryptographic) part of a Cairo claim.
///
/// In order, this:
/// 1. validates the builtin claims against the public segment ranges (`verify_builtins`);
/// 2. validates the fixed words of the public program (`verify_program`);
/// 3. asserts the relations between the initial and final `pc` / `ap` / `fp` registers;
/// 4. accumulates the relation use counts and checks them (`check_relation_uses`);
/// 5. asserts that every "big" memory id stays below `PRIME`.
///
/// # Panics
/// Panics if any of the checks above fails, if `claim.memory_id_to_big` is `None`, or if the
/// number of big memory ids cannot be represented without overflow.
fn verify_claim(claim: &CairoClaim) {
    let PublicData {
        public_memory:
            PublicMemory {
                program,
                public_segments,
                output: _output,
                safe_call_ids: _safe_call_ids,
            },
        initial_state:
            CasmState {
                pc: initial_pc,
                ap: initial_ap,
                fp: initial_fp,
            },
        final_state:
            CasmState {
                pc: final_pc,
                ap: final_ap,
                fp: final_fp,
            },
    } = &claim.public_data;

    verify_builtins(
        &claim.add_mod_builtin,
        &claim.bitwise_builtin,
        &claim.mul_mod_builtin,
        &claim.pedersen_builtin,
        &claim.pedersen_builtin_narrow_windows,
        &claim.poseidon_builtin,
        &claim.range_check96_builtin,
        &claim.range_check_builtin,
        &claim.ec_op_builtin,
        public_segments,
    );

    verify_program(program, public_segments);

    // Register constraints on the initial and final states.
    assert_eq!(*initial_pc, BaseField::one());
    assert!(
        *initial_pc + BaseField::from(2) < *initial_ap,
        "Initial pc + 2 must be less than initial ap, but got initial_pc: {initial_pc}, initial_ap: {initial_ap}"
    );
    assert_eq!(initial_fp, final_fp);
    assert_eq!(initial_fp, initial_ap);
    assert_eq!(*final_pc, BaseField::from(5));
    assert!(initial_ap <= final_ap);

    let mut relation_uses = HashMap::<&'static str, u64>::new();
    claim.accumulate_relation_uses(&mut relation_uses);
    check_relation_uses(&relation_uses);

    // Count the big memory ids with checked `u64` arithmetic. The log sizes come from the claim,
    // so a plain `u32` shift/sum could wrap in release builds and let an oversized id range pass.
    let big_log_sizes = &claim
        .memory_id_to_big
        .as_ref()
        .expect("claim is missing the memory_id_to_big component")
        .big_log_sizes;
    let n_big_ids = big_log_sizes
        .iter()
        .try_fold(0u64, |total, &log_size| {
            total.checked_add(1u64.checked_shl(log_size)?)
        })
        .expect("memory_id_to_big log sizes overflow the big id count");

    // The largest big id is `LARGE_MEMORY_VALUE_ID_BASE + n_big_ids - 1`. Requiring it to be
    // `< PRIME` is the same as requiring the first unused id to be `<= PRIME`; this form needs no
    // subtraction, so an empty `big_log_sizes` cannot underflow.
    let first_unused_id = u64::from(LARGE_MEMORY_VALUE_ID_BASE)
        .checked_add(n_big_ids)
        .expect("big memory id range overflows u64");
    assert!(
        first_unused_id <= u64::from(PRIME),
        "Largest big memory id must be less than PRIME, but big ids extend up to {first_unused_id} (exclusive)"
    );
}
/// Logs every relation's use count and rejects any relation used `PRIME` times or more.
///
/// # Panics
/// Panics, listing the offending entries as pretty-printed JSON, if at least one relation has a
/// use count `>= PRIME`. Also panics if JSON serialization fails.
fn check_relation_uses(relation_uses: &HashMap<&'static str, u64>) {
    let all_relation_uses_pretty = to_string_pretty(&relation_uses).unwrap();
    log::info!("All relation uses:\n{all_relation_uses_pretty}");

    // Use counts at or above this limit are rejected.
    let limit: u64 = PRIME.into();
    let over_limit: Vec<_> = relation_uses
        .iter()
        .filter(|(_, uses)| **uses >= limit)
        .collect();

    if over_limit.is_empty() {
        return;
    }

    let over_limit_pretty = to_string_pretty(&over_limit).unwrap();
    panic!(
        "Found {} outstanding relations:\n{}",
        over_limit.len(),
        over_limit_pretty
    );
}
/// A relation identifier paired with a use count.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RelationUse {
    /// Static name identifying the relation.
    pub relation_id: &'static str,
    /// Use count associated with `relation_id`.
    pub uses: u64,
}
/// The two fields of a builtin component claim that `check_builtin` needs, independent of the
/// concrete builtin type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct BuiltinClaim {
    /// Address at which the builtin's memory segment starts.
    segment_start: u32,
    /// Base-2 logarithm used to size the builtin segment (`1 << log_size` units).
    log_size: u32,
}
#[allow(clippy::too_many_arguments)]
fn verify_builtins(
add_mod_builtin_claim: &Option<add_mod_builtin::Claim>,
bitwise_builtin_claim: &Option<bitwise_builtin::Claim>,
mul_mod_builtin_claim: &Option<mul_mod_builtin::Claim>,
pedersen_builtin_claim: &Option<pedersen_builtin::Claim>,
pedersen_builtin_narrow_windows_claim: &Option<pedersen_builtin_narrow_windows::Claim>,
poseidon_builtin_claim: &Option<poseidon_builtin::Claim>,
range_check_96_builtin_claim: &Option<range_check96_builtin::Claim>,
range_check_128_builtin_claim: &Option<range_check_builtin::Claim>,
ec_op_builtin_claim: &Option<ec_op_builtin::Claim>,
segment_ranges: &PublicSegmentRanges,
) {
let PublicSegmentRanges {
output,
pedersen,
range_check_128,
ecdsa,
bitwise,
ec_op,
keccak,
poseidon,
range_check_96,
add_mod,
mul_mod,
} = *segment_ranges;
if let Some(ecdsa) = ecdsa {
assert_eq!(
ecdsa.start_ptr.value, ecdsa.stop_ptr.value,
"ECDSA segment is not empty"
);
}
if let Some(keccak) = keccak {
assert_eq!(
keccak.start_ptr.value, keccak.stop_ptr.value,
"Keccak segment is not empty"
);
}
assert!(output.stop_ptr.value < 1 << 31);
assert!(output.start_ptr.value <= output.stop_ptr.value);
macro_rules! check_builtin_generic {
($name:ident) => {
paste! {
check_builtin(
(*[<$name _builtin_claim>])
.map(|claim| BuiltinClaim {
segment_start: claim.[<$name _builtin_segment_start>],
log_size: claim.log_size,
}),
$name,
stringify!($name),
[<$name:upper _BUILTIN_MEMORY_CELLS>]
);
}
};
}
check_builtin(
(*range_check_128_builtin_claim).map(|claim| BuiltinClaim {
segment_start: claim.range_check_builtin_segment_start,
log_size: claim.log_size,
}),
range_check_128,
"range_check_128",
RANGE_CHECK_BUILTIN_MEMORY_CELLS,
);
check_builtin(
(*range_check_96_builtin_claim).map(|claim| BuiltinClaim {
segment_start: claim.range_check96_builtin_segment_start,
log_size: claim.log_size,
}),
range_check_96,
"range_check_96",
RANGE_CHECK_96_BUILTIN_MEMORY_CELLS,
);
assert!(
!(pedersen_builtin_claim.is_some() && pedersen_builtin_narrow_windows_claim.is_some()),
"Both pedersen_builtin_claim and pedersen_builtin_narrow_windows_claim builtins cannot be used together");
if let Some(claim) = pedersen_builtin_claim {
check_builtin(
Some(BuiltinClaim {
segment_start: claim.pedersen_builtin_segment_start,
log_size: claim.log_size,
}),
pedersen,
"pedersen",
PEDERSEN_BUILTIN_MEMORY_CELLS,
);
} else {
check_builtin(
pedersen_builtin_narrow_windows_claim.map(|claim| BuiltinClaim {
segment_start: claim.pedersen_builtin_segment_start,
log_size: claim.log_size,
}),
pedersen,
"pedersen",
PEDERSEN_BUILTIN_NARROW_WINDOWS_MEMORY_CELLS,
);
};
check_builtin_generic!(bitwise);
check_builtin_generic!(add_mod);
check_builtin_generic!(mul_mod);
check_builtin_generic!(poseidon);
check_builtin_generic!(ec_op);
}
/// Asserts that selected words of the public program hold the expected fixed values.
///
/// Words 0, 1, 2, 4 and 5 are compared against constants; word 1 must hold the number of present
/// public segments. Word 3 is not checked.
///
/// # Panics
/// Panics on the first mismatching word, or if `program` is too short to index.
fn verify_program(program: &MemorySection, public_segments: &PublicSegmentRanges) {
    let n_builtins = public_segments.present_segments().len() as u32;

    // (program index, expected value) pairs, checked in ascending index order.
    // NOTE(review): these presumably encode the fixed bootstrap instructions and their
    // immediates; confirm against the program layout.
    let expected_words: [(usize, [u32; 8]); 5] = [
        (0, [0x7fff7fff, 0x4078001, 0, 0, 0, 0, 0, 0]),
        (1, [n_builtins, 0, 0, 0, 0, 0, 0, 0]),
        (2, [0x80018000, 0x11048001, 0, 0, 0, 0, 0, 0]),
        (4, [0x7fff7fff, 0x1078001, 0, 0, 0, 0, 0, 0]),
        (5, [0, 0, 0, 0, 0, 0, 0, 0]),
    ];

    for (index, expected) in expected_words {
        assert_eq!(program[index].1, expected);
    }
}
/// Checks one builtin's public segment range against its claim.
///
/// Returns immediately when the segment range is absent or empty. Otherwise the claim must be
/// present, and with `segment_end = segment_start + (1 << log_size) * n_cells`:
/// * `start_ptr <= stop_ptr`;
/// * `stop_ptr - start_ptr` is a multiple of `n_cells`;
/// * `start_ptr == segment_start`;
/// * `stop_ptr <= segment_end`;
/// * `segment_end < 2^31`.
///
/// `segment_end` is computed with checked arithmetic. The inputs come from the claim, so an
/// overflow must be a hard failure rather than a silent wrap-around in release builds.
///
/// # Panics
/// Panics if the claim is missing for a non-empty segment, if `n_cells` does not fit in `u32`,
/// if computing `segment_end` overflows `u32`, or if any of the conditions above is violated.
fn check_builtin(
    builtin_claim: Option<BuiltinClaim>,
    segment_range: Option<SegmentRange>,
    name: &str,
    n_cells: usize,
) {
    // Nothing to verify for a builtin with no segment, or with an empty one.
    let Some(segment_range) = segment_range else {
        return;
    };
    if segment_range.is_empty() {
        return;
    }

    let BuiltinClaim {
        segment_start,
        log_size,
    } = builtin_claim.unwrap_or_else(|| {
        panic!("Missing {name} builtin claim despite non-empty segment range {segment_range:?}")
    });

    let n_cells = u32::try_from(n_cells)
        .unwrap_or_else(|_| panic!("{name} builtin cell count {n_cells} does not fit in u32"));
    let segment_end = 1u32
        .checked_shl(log_size)
        .and_then(|n_instances| n_instances.checked_mul(n_cells))
        .and_then(|segment_len| segment_start.checked_add(segment_len))
        .unwrap_or_else(|| {
            panic!(
                "{name} builtin segment end overflows u32: segment_start: {segment_start}, log_size: {log_size}, n_cells: {n_cells}"
            )
        });

    let start_ptr = segment_range.start_ptr.value;
    let stop_ptr = segment_range.stop_ptr.value;

    // Checked first so that the subtraction below cannot underflow.
    assert!(
        start_ptr <= stop_ptr,
        "Range start should be less than or equal to range stop"
    );
    assert!(
        (stop_ptr - start_ptr).is_multiple_of(n_cells),
        "Builtin segment range must divisible by {n_cells} cells, but got start_ptr: {start_ptr}, stop_ptr: {stop_ptr}"
    );
    assert_eq!(
        start_ptr, segment_start,
        "Builtin segment start doesn't match claim"
    );
    assert!(
        stop_ptr <= segment_end,
        "Builtin stop pointer must be within the builtin segment"
    );
    assert!(
        segment_end < 1 << 31,
        "segment_end must be less than 2^31, but got {segment_end}"
    );
}
/// Number of proof-of-work bits that `verify_cairo_ex` passes to `verify_pow_nonce` when checking
/// the proof's `interaction_pow` nonce.
pub const INTERACTION_POW_BITS: u32 = 24;
/// Verifies a Cairo proof with the default setting for preprocessed columns.
///
/// Equivalent to calling [`verify_cairo_ex`] with `include_all_preprocessed_columns = false`.
///
/// # Errors
/// Returns whatever [`verify_cairo_ex`] returns.
pub fn verify_cairo<MC: MerkleChannel>(
    proof: CairoProofForRustVerifier<MC::H>,
) -> Result<(), CairoVerificationError> {
    const INCLUDE_ALL_PREPROCESSED_COLUMNS: bool = false;
    verify_cairo_ex::<MC>(proof, INCLUDE_ALL_PREPROCESSED_COLUMNS)
}
/// Verifies a Cairo proof, passing `include_all_preprocessed_columns` through to `verify_ex`.
///
/// The proof is destructured in the parameter list. Verification then runs in this order, with
/// every step reading from or mixing into the same `channel`, so the order is significant:
/// 1. Assert the `memory_address_to_id` size bound and run `verify_claim`.
/// 2. Mix `channel_salt` and the PCS config into the channel.
/// 3. Commit to `commitments[0]`, mix in the claim, commit to `commitments[1]`.
/// 4. Check the `interaction_pow` proof-of-work nonce, then mix it in.
/// 5. Draw the lookup elements, require the logup sum to be zero, mix in the interaction claim
///    and commit to `commitments[2]`.
/// 6. Build the components and delegate to `verify_ex`.
///
/// # Errors
/// * [`CairoVerificationError::ProofOfWork`] if `verify_pow_nonce` rejects `interaction_pow`.
/// * [`CairoVerificationError::InvalidLogupSum`] if `lookup_sum` is not zero.
/// * [`CairoVerificationError::Stark`] if `verify_ex` returns an error.
///
/// # Panics
/// Panics if `claim.memory_address_to_id` is `None`, if the size assertion below fails, or if
/// `verify_claim` panics.
/// NOTE(review): `stark_proof.commitments` is indexed with 0, 1 and 2 directly; presumably this
/// panics when the proof carries fewer than three commitments. Confirm against the proof type.
pub fn verify_cairo_ex<MC: MerkleChannel>(
CairoProofForRustVerifier {
claim,
interaction_pow,
interaction_claim,
stark_proof,
channel_salt,
preprocessed_trace_variant,
}: CairoProofForRustVerifier<MC::H>,
include_all_preprocessed_columns: bool,
) -> Result<(), CairoVerificationError> {
let _span = span!(Level::INFO, "verify_cairo").entered();
// The address-to-id table (`1 << log_size` rows times the split factor) must not exceed the
// memory address bound.
assert!(
(1 << claim.memory_address_to_id.as_ref().unwrap().log_size) * MEMORY_ADDRESS_TO_ID_SPLIT
<= (1 << LOG_MEMORY_ADDRESS_BOUND)
);
// Structural checks on the claim; these panic rather than return an error.
verify_claim(&claim);
// Start a fresh channel and mix in the salt and the PCS configuration.
let channel = &mut MC::C::default();
channel.mix_felts(&[channel_salt.into()]);
let pcs_config = stark_proof.config;
pcs_config.mix_into(channel);
let commitment_scheme_verifier = &mut CommitmentSchemeVerifier::<MC>::new(pcs_config);
// Take the log sizes from the claim, but replace the preprocessed-trace entry with the log sizes
// of the preprocessed trace variant named in the proof.
let mut log_sizes = claim.log_sizes();
log_sizes[PREPROCESSED_TRACE_IDX] = preprocessed_trace_variant
.to_preprocessed_trace()
.log_sizes();
// First commitment, then the claim, then the second commitment.
commitment_scheme_verifier.commit(stark_proof.commitments[0], &log_sizes[0], channel);
claim.mix_into::<MC>(channel);
commitment_scheme_verifier.commit(stark_proof.commitments[1], &log_sizes[1], channel);
// The proof-of-work nonce is checked before the lookup elements are drawn.
if !channel.verify_pow_nonce(INTERACTION_POW_BITS, interaction_pow) {
return Err(CairoVerificationError::ProofOfWork);
}
channel.mix_u64(interaction_pow);
let interaction_elements = CommonLookupElements::draw(channel);
// The lookup sum computed from the claim and interaction claim must be zero.
if lookup_sum(&claim, &interaction_elements, &interaction_claim) != SecureField::zero() {
return Err(CairoVerificationError::InvalidLogupSum);
}
// Interaction claim, then the third commitment.
interaction_claim.mix_into(channel);
commitment_scheme_verifier.commit(stark_proof.commitments[2], &log_sizes[2], channel);
// Build the components from the claims, the drawn elements and the preprocessed column ids.
let component_generator = CairoComponents::new(
&claim,
&interaction_elements,
&interaction_claim,
&preprocessed_trace_variant.to_preprocessed_trace().ids(),
);
let components = component_generator.components();
// Delegate the remaining verification to `verify_ex`, wrapping its error type.
verify_ex(
&components,
channel,
commitment_scheme_verifier,
stark_proof,
include_all_preprocessed_columns,
)
.map_err(CairoVerificationError::Stark)
}
/// Errors returned by `verify_cairo` and `verify_cairo_ex`.
///
/// Failures detected through assertions (for example inside `verify_claim`) panic instead and are
/// not represented here.
#[derive(Error, Debug)]
pub enum CairoVerificationError {
/// `lookup_sum` over the claim and interaction claim was not zero.
#[error("Invalid logup sum")]
InvalidLogupSum,
/// `verify_ex` rejected the proof; wraps the underlying `VerificationError`.
#[error("Stark verification error: {0}")]
Stark(#[from] VerificationError),
/// The `interaction_pow` nonce failed the `INTERACTION_POW_BITS` proof-of-work check.
#[error("Proof of work verification failed.")]
ProofOfWork,
}