extern crate alloc;
use alloc::string::ToString;
use alloc::vec::Vec;
use alloc::{
format,
vec,
};
use lib_q_stark_air::{
Air,
AirBuilder,
BaseAir,
WindowAccess,
};
use lib_q_stark_field::Field;
use lib_q_stark_matrix::Matrix;
use lib_q_stark_matrix::dense::RowMajorMatrix;
use super::recursive_types::SerializedStarkProof;
use super::stark_verifier::RecursiveStarkVerificationInput;
use super::{
AirError,
StarkVerifierAir,
TraceGenerator,
validate_trace_dimensions,
};
/// Inputs for a batch of recursive STARK verifications: one
/// [`RecursiveStarkVerificationInput`] per proof in the batch.
///
/// The batch AIR's trace generation pairs the entries with its inner AIRs by
/// position and rejects a batch whose length differs from the number of proofs
/// the AIR was built for.
pub type BatchRecursiveStarkVerificationInput<F: Field, Ch: Field> =
Vec<RecursiveStarkVerificationInput<F, Ch>>;
/// Derives the single public value exposed by a batch of recursive verifications.
///
/// The `expected_public_values` of every proof are concatenated in slice order,
/// converted to `PoseidonField` elements and hashed with `Poseidon128`. The
/// first element of the hash output, converted back to `F`, is returned as a
/// one-element vector.
///
/// When the concatenation is empty (no proofs, or no proof carries any public
/// value) nothing is hashed and the result is `vec![F::ZERO]`.
pub fn batch_recursive_verifier_public_values<F, Ch>(
    serialized: &[SerializedStarkProof<F, Ch>],
) -> Vec<F>
where
    F: Field
        + lib_q_stark_field::BasedVectorSpace<lib_q_stark_mersenne31::Mersenne31>
        + From<lib_q_poseidon::PoseidonField>
        + Into<lib_q_poseidon::PoseidonField>,
    Ch: Field,
{
    use lib_q_poseidon::{
        Poseidon,
        Poseidon128,
        PoseidonField,
    };

    // Concatenate the public values of all proofs, preserving proof order.
    let mut flattened: Vec<F> = Vec::new();
    for proof in serialized {
        flattened.extend(proof.expected_public_values.iter().cloned());
    }

    // Nothing to commit to: fall back to the zero element instead of hashing.
    if flattened.is_empty() {
        return vec![F::ZERO];
    }

    let hash_input: Vec<PoseidonField> = flattened.into_iter().map(Into::into).collect();
    let digest = Poseidon128.hash(&hash_input);
    // Only the first digest element is exposed as the batch public value.
    vec![digest[0].into()]
}
/// AIR that verifies several serialized STARK proofs as one batch.
///
/// Holds one [`StarkVerifierAir`] per proof. The only constructor,
/// `BatchStarkVerifierAir::new`, refuses an empty batch and a batch whose
/// inner AIRs report different widths.
#[derive(Debug, Clone)]
pub struct BatchStarkVerifierAir<F: Field, Ch: Field = F> {
/// Per-proof verifier AIRs, in the order the proofs were passed to `new`.
airs: Vec<StarkVerifierAir<F, Ch>>,
}
impl<F: Field, Ch: Field> BatchStarkVerifierAir<F, Ch> {
    /// Builds a batch verifier AIR from `serialized_proofs`.
    ///
    /// One [`StarkVerifierAir`] is created per proof, in order, all sharing the
    /// same `merkle_tree_depth`, `log_final_poly_len` and `num_fri_queries`.
    ///
    /// # Errors
    ///
    /// * [`AirError::InvalidInput`] when `serialized_proofs` is empty.
    /// * The error of the first proof that [`StarkVerifierAir::new`] rejects.
    /// * [`AirError::InvalidInput`] when an inner AIR's width differs from the
    ///   width of the AIR built for proof 0.
    pub fn new(
        serialized_proofs: Vec<SerializedStarkProof<F, Ch>>,
        merkle_tree_depth: usize,
        log_final_poly_len: usize,
        num_fri_queries: usize,
    ) -> Result<Self, AirError> {
        if serialized_proofs.is_empty() {
            return Err(AirError::InvalidInput {
                reason: "BatchStarkVerifierAir requires at least one proof".to_string(),
            });
        }

        // Build the inner AIRs one by one; the first failure aborts construction.
        let mut airs: Vec<StarkVerifierAir<F, Ch>> = Vec::with_capacity(serialized_proofs.len());
        for proof in serialized_proofs {
            airs.push(StarkVerifierAir::new(
                proof,
                merkle_tree_depth,
                log_final_poly_len,
                num_fri_queries,
            )?);
        }

        // All inner AIRs must agree with proof 0 on the trace width.
        let expected_width = airs[0].width();
        let first_mismatch = airs
            .iter()
            .skip(1)
            .position(|air| air.width() != expected_width);
        if let Some(offset) = first_mismatch {
            // `position` counts from the first non-skipped element.
            let index = offset + 1;
            return Err(AirError::InvalidInput {
                reason: format!(
                    "Proof {} has width {} but proof 0 has {}",
                    index,
                    airs[index].width(),
                    expected_width
                ),
            });
        }

        Ok(Self { airs })
    }

    /// Number of proofs (and therefore inner AIRs) in this batch.
    pub fn num_proofs(&self) -> usize {
        self.airs.len()
    }
}
/// Exposes the trace width of the batch AIR.
impl<F: Field, Ch: Field> BaseAir<F> for BatchStarkVerifierAir<F, Ch> {
/// Returns the width reported by the first inner AIR.
///
/// `BatchStarkVerifierAir::new` rejects empty batches and batches whose inner
/// AIRs disagree on width, so for a value built through `new` the index is in
/// bounds and the result is the width shared by every inner AIR.
fn width(&self) -> usize {
self.airs[0].width()
}
}
/// Constraint evaluation for the batch AIR (proof and challenge field are both
/// `AB::F` here).
impl<AB: AirBuilder> Air<AB> for BatchStarkVerifierAir<AB::F, AB::F>
where
AB::F: Field + Sized + lib_q_stark_field::BasedVectorSpace<lib_q_stark_mersenne31::Mersenne31>,
{
/// Evaluates the constraints of every inner AIR through `builder`.
///
/// The slice returned by `current_slice()` on the builder's main trace
/// window is fetched once and handed to each inner AIR's `eval_on_slice`,
/// so all inner AIRs constrain the same slice.
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.current_slice();
// Every inner AIR is evaluated against the same `local` slice, in batch order.
for air in &self.airs {
air.eval_on_slice(builder, local);
}
}
}
impl<F, Ch> TraceGenerator<F, BatchRecursiveStarkVerificationInput<F, Ch>>
    for BatchStarkVerifierAir<F, Ch>
where
    F: Field
        + lib_q_stark_field::BasedVectorSpace<lib_q_stark_mersenne31::Mersenne31>
        + From<lib_q_poseidon::PoseidonField>
        + Into<lib_q_poseidon::PoseidonField>,
    Ch: Field,
{
    /// Generates the batch trace by stacking the traces of the inner AIRs.
    ///
    /// `inputs[i]` is fed to the inner AIR built for proof `i`; the resulting
    /// sub-traces are appended row by row, in batch order, into one row-major
    /// matrix whose width is the batch width.
    ///
    /// # Errors
    ///
    /// * [`AirError::InvalidInput`] when `inputs.len()` differs from the number
    ///   of proofs in the batch.
    /// * Any error returned by an inner AIR's `generate_trace`.
    /// * [`AirError::InvalidInput`] when a sub-trace's width differs from the
    ///   batch width, or when a sub-trace yields no value for a cell inside its
    ///   own reported dimensions.
    /// * Any error returned by `validate_trace_dimensions` for the final
    ///   `(width, num_rows)` pair.
    fn generate_trace(
        &self,
        inputs: &BatchRecursiveStarkVerificationInput<F, Ch>,
    ) -> Result<RowMajorMatrix<F>, AirError> {
        if inputs.len() != self.airs.len() {
            return Err(AirError::InvalidInput {
                reason: format!(
                    "Batch input len {} does not match num_proofs {}",
                    inputs.len(),
                    self.airs.len()
                ),
            });
        }

        let width = self.width();
        // Capacity is only a hint sized for one row per proof; taller
        // sub-traces simply grow the buffer.
        let mut values = Vec::with_capacity(self.airs.len() * width);
        // Counted from the sub-trace heights rather than derived as
        // `values.len() / width`, which would panic on a zero width before
        // `validate_trace_dimensions` gets a chance to reject it.
        let mut num_rows = 0usize;

        for (air, input) in self.airs.iter().zip(inputs.iter()) {
            let sub_trace = air.generate_trace(input)?;
            if sub_trace.width() != width {
                return Err(AirError::InvalidInput {
                    reason: format!(
                        "Generated row width {} does not match batch width {}",
                        sub_trace.width(),
                        width
                    ),
                });
            }

            let height = sub_trace.height();
            for r in 0..height {
                for c in 0..width {
                    // A missing cell must not be skipped: dropping it would
                    // shift every following value and misalign the whole trace.
                    let cell = sub_trace.get(r, c).ok_or_else(|| AirError::InvalidInput {
                        reason: format!(
                            "Generated trace has no value at row {}, column {}",
                            r, c
                        ),
                    })?;
                    values.push(cell);
                }
            }
            num_rows += height;
        }

        validate_trace_dimensions(width, num_rows)?;
        Ok(RowMajorMatrix::new(values, width))
    }

    /// Computes the batch public values for `inputs`.
    ///
    /// Collects the serialized proof of every input, in order, and delegates
    /// to [`batch_recursive_verifier_public_values`].
    fn public_values(&self, inputs: &BatchRecursiveStarkVerificationInput<F, Ch>) -> Vec<F> {
        // The helper takes a slice of proofs, so each proof is cloned out of
        // its input here.
        let serialized: Vec<SerializedStarkProof<F, Ch>> = inputs
            .iter()
            .map(|inp| inp.serialized_proof.clone())
            .collect();
        batch_recursive_verifier_public_values(&serialized)
    }
}
#[cfg(test)]
mod tests {
    use alloc::vec;
    use lib_q_stark::check_constraints;
    use lib_q_stark_air::BaseAir;
    use lib_q_stark_field::PrimeCharacteristicRing;
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_matrix::Matrix;
    use lib_q_stark_mersenne31::Mersenne31;
    use super::super::{
        CommitmentVerificationInput,
        ConstraintVerificationInput,
        FriVerificationInput,
        MerkleHash,
        MerkleProofInput,
        OpeningVerificationInput,
        SerializedFriRound,
    };
    use super::*;
    use crate::air::{
        MerkleInclusionAir,
        merkle_root_to_bytes,
    };

    type TestField = Complex<Mersenne31>;
    type TestProof = SerializedStarkProof<TestField, TestField>;
    type TestInput = RecursiveStarkVerificationInput<TestField, TestField>;
    type TestBatchAir = BatchStarkVerifierAir<TestField, TestField>;
    type TestBatchInput = BatchRecursiveStarkVerificationInput<TestField, TestField>;

    /// Verifier parameters shared by every batch AIR built in these tests.
    const MERKLE_TREE_DEPTH: usize = 4;
    const LOG_FINAL_POLY_LEN: usize = 1;
    const NUM_FRI_QUERIES: usize = 2;

    /// Minimal all-zero serialized proof used as the base fixture.
    fn sample_serialized_proof() -> TestProof {
        TestProof {
            degree_bits: 8,
            num_quotient_chunks: 2,
            trace_width: 4,
            is_zk: false,
            trace_commitment_hash: [0u8; 32],
            quotient_commitment_hash: [1u8; 32],
            random_commitment_hash: None,
            trace_local: vec![TestField::ZERO; 4],
            trace_next: vec![TestField::ZERO; 4],
            quotient_chunks: vec![vec![TestField::ZERO; 1]; 2],
            random_values: None,
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [2u8; 32],
                beta: vec![0u8; 8],
            }],
            final_poly: vec![TestField::ZERO; 2],
            pow_witness: vec![],
            zeta: TestField::ZERO,
            zeta_next: TestField::ZERO,
            alpha: TestField::ZERO,
            expected_public_values: vec![TestField::ZERO],
        }
    }

    /// Builds a placeholder recursive verification input matching `serialized`.
    fn sample_recursive_input(serialized: &TestProof) -> TestInput {
        let tree_depth = 4;
        let zero_hash = MerkleHash::hash_data(b"z");
        let merkle_proof = MerkleProofInput {
            leaf: b"leaf".to_vec(),
            leaf_hash_direct: None,
            path_bits: vec![false; tree_depth],
            siblings: vec![zero_hash; tree_depth],
        };
        let merkle_air = MerkleInclusionAir::new(tree_depth).expect("MerkleInclusionAir");
        let root_field = merkle_air.public_values(&merkle_proof)[0];
        let root_bytes = merkle_root_to_bytes(&root_field);

        // One opening per local/next trace column plus one per quotient chunk.
        let opening_count = serialized.trace_width * 2 + serialized.num_quotient_chunks;

        RecursiveStarkVerificationInput {
            serialized_proof: serialized.clone(),
            commitment_inputs: CommitmentVerificationInput {
                expected_roots: vec![root_bytes, root_bytes],
                merkle_proofs: vec![merkle_proof.clone(), merkle_proof.clone()],
            },
            fri_inputs: FriVerificationInput {
                fri_rounds: serialized.fri_rounds.clone(),
                round_betas: vec![TestField::ZERO],
                final_poly: vec![TestField::ZERO; 2],
                query_indices: vec![0, 0],
                query_evaluations: vec![TestField::ZERO, TestField::ZERO],
                round_current_evals: vec![TestField::ZERO],
                round_sibling_evals: vec![TestField::ZERO],
                round_domain_point_inverses: vec![TestField::ZERO],
                round_domain_point_x0: vec![TestField::ZERO],
                round_parity: vec![TestField::ZERO],
                final_poly_eval_point: TestField::ZERO,
                round_roll_ins: vec![TestField::ZERO],
            },
            constraint_inputs: ConstraintVerificationInput {
                quotient_chunks: vec![TestField::ZERO; serialized.num_quotient_chunks],
                trace_local: vec![TestField::ZERO; serialized.trace_width],
                trace_next: vec![TestField::ZERO; serialized.trace_width],
                zeta: TestField::ZERO,
                alpha: TestField::ZERO,
                public_values: serialized.expected_public_values.clone(),
            },
            opening_inputs: OpeningVerificationInput {
                opened_values: vec![TestField::ZERO; opening_count],
                domain_points: vec![TestField::ZERO; opening_count],
                merkle_proofs: vec![merkle_proof; opening_count],
                expected_roots: vec![TestField::ZERO; opening_count],
            },
        }
    }

    /// Two-proof batch whose proofs differ only in their expected public
    /// values, together with the matching recursive inputs.
    fn two_proof_fixture() -> (TestBatchAir, TestBatchInput) {
        let first = sample_serialized_proof();
        let mut second = sample_serialized_proof();
        second.expected_public_values = vec![TestField::ONE];
        let inputs = vec![
            sample_recursive_input(&first),
            sample_recursive_input(&second),
        ];
        let air = TestBatchAir::new(
            vec![first, second],
            MERKLE_TREE_DEPTH,
            LOG_FINAL_POLY_LEN,
            NUM_FRI_QUERIES,
        )
        .unwrap();
        (air, inputs)
    }

    /// Single-proof batch AIR built with the shared verifier parameters.
    fn single_proof_air(proof: TestProof) -> TestBatchAir {
        TestBatchAir::new(
            vec![proof],
            MERKLE_TREE_DEPTH,
            LOG_FINAL_POLY_LEN,
            NUM_FRI_QUERIES,
        )
        .unwrap()
    }

    #[test]
    fn test_batch_air_new_rejects_empty_proofs() {
        let outcome = TestBatchAir::new(
            vec![],
            MERKLE_TREE_DEPTH,
            LOG_FINAL_POLY_LEN,
            NUM_FRI_QUERIES,
        );
        assert!(matches!(outcome, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_batch_air_generate_trace_rejects_input_length_mismatch() {
        let batch_air = single_proof_air(sample_serialized_proof());
        let outcome = batch_air.generate_trace(&vec![]);
        assert!(matches!(outcome, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_batch_air_generate_trace_and_public_values() {
        let (batch_air, inputs) = two_proof_fixture();
        let trace = batch_air
            .generate_trace(&inputs)
            .expect("batch trace should generate");
        assert_eq!(batch_air.num_proofs(), 2);
        assert_eq!(trace.height(), 2);
        assert_eq!(trace.width(), BaseAir::<TestField>::width(&batch_air));
        let pvs = batch_air.public_values(&inputs);
        assert_eq!(pvs.len(), 1);
    }

    #[test]
    fn test_batch_public_values_empty_flattened_returns_zero() {
        let mut serialized = sample_serialized_proof();
        serialized.expected_public_values = vec![];
        let inputs = vec![sample_recursive_input(&serialized)];
        let batch_air = single_proof_air(serialized);
        let public_values = batch_air.public_values(&inputs);
        assert_eq!(public_values, vec![TestField::ZERO]);
    }

    #[test]
    #[should_panic]
    fn test_batch_constraints_detect_placeholder_trace() {
        let (batch_air, inputs) = two_proof_fixture();
        let trace = batch_air
            .generate_trace(&inputs)
            .expect("batch trace should generate");
        let public_values = batch_air.public_values(&inputs);
        check_constraints(&batch_air, &trace, &public_values);
    }
}