extern crate alloc;
use alloc::vec::Vec;
use alloc::{
format,
vec,
};
use lib_q_stark_air::{
Air,
AirBuilder,
BaseAir,
WindowAccess,
};
use lib_q_stark_field::integers::QuotientMap;
use lib_q_stark_field::{
Field,
PrimeCharacteristicRing,
};
use lib_q_stark_matrix::dense::RowMajorMatrix;
use super::recursive_types::{
MAX_FINAL_POLY_LOG_LEN,
MAX_FRI_ROUNDS,
SerializedFriRound,
};
use super::{
AirError,
TraceGenerator,
next_power_of_two,
validate_trace_dimensions,
};
/// Inclusive upper bound on the number of FRI queries accepted by `FriVerifierAir::new`.
pub const MAX_FRI_QUERIES: usize = 1_000;
/// AIR describing the constraints that check FRI verification values packed into a
/// single trace row.
///
/// The three shape parameters fix the trace layout: one 40-column section per round,
/// a final-polynomial section, and one 2-column section per query. They are range-checked
/// by `FriVerifierAir::new`.
#[derive(Debug, Clone)]
pub struct FriVerifierAir {
    /// Number of FRI folding rounds (validated to lie in `1..=MAX_FRI_ROUNDS`).
    num_rounds: usize,
    /// Base-2 logarithm of the final polynomial's coefficient count
    /// (validated to be at most `MAX_FINAL_POLY_LOG_LEN`).
    log_final_poly_len: usize,
    /// Number of FRI queries (validated to lie in `1..=MAX_FRI_QUERIES`).
    num_queries: usize,
}
impl FriVerifierAir {
    /// Creates a FRI verifier AIR for the given transcript shape.
    ///
    /// # Errors
    ///
    /// Returns [`AirError::InvalidDimensions`] when `num_rounds` is not in
    /// `1..=MAX_FRI_ROUNDS`, when `log_final_poly_len` is greater than
    /// `MAX_FINAL_POLY_LOG_LEN`, or when `num_queries` is not in `1..=MAX_FRI_QUERIES`.
    /// The checks run in that order and only the first failure is reported.
    pub fn new(
        num_rounds: usize,
        log_final_poly_len: usize,
        num_queries: usize,
    ) -> Result<Self, AirError> {
        let rejection = if !(1..=MAX_FRI_ROUNDS).contains(&num_rounds) {
            Some(format!(
                "Number of FRI rounds must be between 1 and {}",
                MAX_FRI_ROUNDS
            ))
        } else if log_final_poly_len > MAX_FINAL_POLY_LOG_LEN {
            Some(format!(
                "Log final poly len {} exceeds maximum {}",
                log_final_poly_len, MAX_FINAL_POLY_LOG_LEN
            ))
        } else if !(1..=MAX_FRI_QUERIES).contains(&num_queries) {
            Some(format!(
                "Number of queries must be between 1 and {}",
                MAX_FRI_QUERIES
            ))
        } else {
            None
        };

        match rejection {
            Some(reason) => Err(AirError::InvalidDimensions { reason }),
            None => Ok(Self {
                num_rounds,
                log_final_poly_len,
                num_queries,
            }),
        }
    }

    /// Number of FRI folding rounds this AIR checks.
    pub fn num_rounds(&self) -> usize {
        self.num_rounds
    }

    /// Base-2 logarithm of the final polynomial's coefficient count.
    pub fn log_final_poly_len(&self) -> usize {
        self.log_final_poly_len
    }

    /// Number of FRI queries laid out in the trace.
    pub fn num_queries(&self) -> usize {
        self.num_queries
    }

    /// Total number of trace columns implied by the shape parameters.
    fn trace_width(&self) -> usize {
        // Per round: 32 commitment-hash bytes, then beta, folded, sibling, current,
        // domain-point inverse, x0, parity and roll-in.
        const ROUND_COLS: usize = 32 + 8;
        // Per query: the query index and its opened evaluation.
        const QUERY_COLS: usize = 2;

        let coeff_count = 1usize << self.log_final_poly_len;
        // Coefficients, one evaluation point, then one Horner accumulator per coefficient.
        let final_cols = 2 * coeff_count + 1;

        self.num_rounds * ROUND_COLS + final_cols + self.num_queries * QUERY_COLS
    }
}
impl<F: Field> BaseAir<F> for FriVerifierAir {
fn width(&self) -> usize {
self.trace_width()
}
}
impl<AB: AirBuilder> Air<AB> for FriVerifierAir
where
    AB::F: Field,
{
    /// Evaluates all FRI verifier constraints on the current row.
    ///
    /// The FRI columns start at column 0 when this AIR is used on its own.
    fn eval(&self, builder: &mut AB) {
        let Self {
            num_rounds,
            log_final_poly_len,
            num_queries,
        } = *self;

        let window = builder.main();
        let row = window.current_slice();

        Self::eval_with_offset(
            builder,
            row,
            0,
            num_rounds,
            log_final_poly_len,
            num_queries,
        );
    }
}
impl FriVerifierAir {
    /// Emits the FRI verifier constraints for one row whose FRI columns begin at `offset`.
    ///
    /// Layout relative to `offset` (must agree with `trace_width` and the trace generator):
    /// - `num_rounds` round sections of 40 columns each: 32 commitment-hash bytes, then
    ///   `beta`, `folded`, `sibling`, `current`, `domain_point_inv`, `x0`, `parity`, `roll_in`;
    /// - `2^log_final_poly_len` final-polynomial coefficients, one evaluation point, and one
    ///   Horner accumulator per coefficient;
    /// - `num_queries` query sections of 2 columns each: index, evaluation.
    ///
    /// Constraints emitted:
    /// - every round's `parity` is boolean (`parity * parity = parity`);
    /// - every round's `folded` equals
    ///   `e0 + (beta - x0) * (e1 - e0) * domain_point_inv + beta^2 * roll_in`, where
    ///   `(e0, e1)` is `(current, sibling)` when `parity = 0` and `(sibling, current)` when
    ///   `parity = 1`;
    /// - round `i`'s `current` equals round `i - 1`'s `folded`;
    /// - the Horner accumulators evaluate the final polynomial (coefficient `j` multiplies
    ///   `point^j`) at the evaluation point, and the last round's `folded` equals the result;
    /// - the first round's `current` equals the first query's evaluation.
    ///
    /// Commitment-hash bytes and query indices are laid out but not read here.
    ///
    /// # Panics
    ///
    /// Panics if `local` has fewer than `offset` plus the layout width above columns, or
    /// (with overflow checks enabled) if `log_final_poly_len >= usize::BITS`.
    pub fn eval_with_offset<AB: AirBuilder>(
        builder: &mut AB,
        local: &[AB::Var],
        offset: usize,
        num_rounds: usize,
        log_final_poly_len: usize,
        num_queries: usize,
    ) where
        AB::F: Field,
    {
        use lib_q_stark_field::PrimeCharacteristicRing;

        let per_round = 32 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1;
        // A power of two, so always at least 1.
        let final_poly_len = 1 << log_final_poly_len;
        let one = <AB::F as PrimeCharacteristicRing>::ONE;
        let one_expr = AB::Expr::from(one);

        // Per-round folding constraints.
        for round_idx in 0..num_rounds {
            let round_start = offset + round_idx * per_round;
            let beta_col = round_start + 32;
            let beta = local[beta_col];
            let folded_eval = local[beta_col + 1];
            let sibling_eval = local[beta_col + 2];
            let current_eval = local[beta_col + 3];
            let domain_point_inv = local[beta_col + 4];
            let x0 = local[beta_col + 5];
            let parity = local[beta_col + 6];
            let roll_in = local[beta_col + 7];

            let p = AB::Expr::from(parity);
            // `parity` multiplexes (current, sibling) into (e0, e1). Only p in {0, 1}
            // satisfies p * p = p; without this constraint a prover could choose any field
            // element and turn e0/e1 into arbitrary linear combinations of the two openings.
            builder.assert_eq(p.clone() * p.clone(), p.clone());

            let e0 = (one_expr.clone() - p.clone()) * AB::Expr::from(current_eval) +
                p.clone() * AB::Expr::from(sibling_eval);
            let e1 = p.clone() * AB::Expr::from(current_eval) +
                (one_expr.clone() - p) * AB::Expr::from(sibling_eval);
            let diff = e1 - e0.clone();
            let fold_part = e0 +
                (AB::Expr::from(beta) - AB::Expr::from(x0)) *
                    diff *
                    AB::Expr::from(domain_point_inv);
            let expected_folded =
                fold_part + AB::Expr::from(beta) * AB::Expr::from(beta) * AB::Expr::from(roll_in);
            builder.assert_eq(AB::Expr::from(folded_eval), expected_folded);
        }

        // Chain the rounds: each round starts from the previous round's folded value.
        for round_idx in 1..num_rounds {
            let prev_folded_col = offset + (round_idx - 1) * per_round + 32 + 1;
            let curr_current_col = offset + round_idx * per_round + 32 + 3;
            builder.assert_eq(
                local[prev_folded_col].into(),
                local[curr_current_col].into(),
            );
        }

        // Final polynomial: Horner evaluation starting from the highest-degree coefficient.
        let coeff_start = offset + num_rounds * per_round;
        let eval_point_col = coeff_start + final_poly_len;
        let horner_start = eval_point_col + 1;
        let eval_point = local[eval_point_col];
        builder.assert_eq(
            local[horner_start].into(),
            local[coeff_start + final_poly_len - 1].into(),
        );
        for i in 1..final_poly_len {
            let prev_horner = local[horner_start + i - 1];
            let coeff = local[coeff_start + final_poly_len - 1 - i];
            let expected = AB::Expr::from(prev_horner) * AB::Expr::from(eval_point) +
                AB::Expr::from(coeff);
            builder.assert_eq(local[horner_start + i].into(), expected);
        }
        // The last fold must land on the final polynomial's evaluation.
        if num_rounds > 0 {
            let last_folded_col = offset + (num_rounds - 1) * per_round + 32 + 1;
            let horner_result_col = horner_start + final_poly_len - 1;
            builder.assert_eq(
                local[last_folded_col].into(),
                local[horner_result_col].into(),
            );
        }

        // Only the first query's evaluation is tied to the round data.
        // NOTE(review): the remaining query sections are not constrained here; confirm this
        // is intentional.
        let queries_start = horner_start + final_poly_len;
        if num_rounds > 0 && num_queries > 0 {
            let first_round_current_col = offset + 32 + 3;
            let first_query_eval_col = queries_start + 1;
            builder.assert_eq(
                local[first_round_current_col].into(),
                local[first_query_eval_col].into(),
            );
        }
    }
}
/// Witness data from which `FriVerifierAir` builds its single-row trace.
///
/// `fri_rounds`, `round_betas`, `final_poly` and `query_indices` must have the exact lengths
/// implied by the AIR's parameters (the trace generator rejects mismatches). The other
/// per-round and per-query vectors may be shorter; missing entries are treated as zero.
#[derive(Debug, Clone)]
pub struct FriVerificationInput<F: Field> {
    /// One entry per FRI round. Each `commitment_hash` byte is written into the first 32
    /// columns of that round's section. NOTE(review): the `beta` bytes of each round are
    /// not read by the trace generator; it uses `round_betas` instead.
    pub fri_rounds: Vec<SerializedFriRound>,
    /// Per-round folding challenge `beta`; one entry per round.
    pub round_betas: Vec<F>,
    /// Final-polynomial coefficients, lowest degree first (coefficient `j` multiplies
    /// `point^j` in the Horner evaluation); length must be `2^log_final_poly_len`.
    pub final_poly: Vec<F>,
    /// One index per query, embedded into the trace as a field element.
    pub query_indices: Vec<usize>,
    /// Per-query opened evaluation; missing entries default to zero.
    pub query_evaluations: Vec<F>,
    /// Per-round evaluation at the queried point; missing entries default to zero.
    pub round_current_evals: Vec<F>,
    /// Per-round evaluation at the sibling point; missing entries default to zero.
    pub round_sibling_evals: Vec<F>,
    /// Per-round factor multiplying `(beta - x0) * (e1 - e0)` in the fold; missing entries
    /// default to zero.
    pub round_domain_point_inverses: Vec<F>,
    /// Per-round `x0` value subtracted from `beta` in the fold; missing entries default to zero.
    pub round_domain_point_x0: Vec<F>,
    /// Per-round selector swapping current/sibling into `(e0, e1)`; presumably expected to be
    /// 0 or 1 (TODO confirm). Missing entries default to zero.
    pub round_parity: Vec<F>,
    /// Point at which the final polynomial is evaluated via Horner's rule.
    pub final_poly_eval_point: F,
    /// Per-round roll-in value added as `beta^2 * roll_in` to the fold; missing entries
    /// default to zero.
    pub round_roll_ins: Vec<F>,
}
impl<F: Field> TraceGenerator<F, FriVerificationInput<F>> for FriVerifierAir {
    /// Builds the single-row FRI verifier trace.
    ///
    /// Layout (must agree with `trace_width` and `eval_with_offset`):
    /// - per round (40 columns): 32 commitment-hash bytes, `beta`, `folded`, `sibling`,
    ///   `current`, `domain_point_inv`, `x0`, `parity`, `roll_in`;
    /// - final-polynomial coefficients, the evaluation point, and the Horner accumulators;
    /// - per query (2 columns): index, evaluation.
    ///
    /// `folded` is computed here from the round inputs so that the folding constraint holds.
    /// The per-round witness vectors and `query_evaluations` may be shorter than required;
    /// missing entries are filled with zero.
    ///
    /// # Errors
    ///
    /// Returns `AirError::InvalidInput` when `fri_rounds`, `round_betas`, `final_poly` or
    /// `query_indices` do not have the length implied by the AIR's parameters, and propagates
    /// any error from `validate_trace_dimensions`.
    fn generate_trace(
        &self,
        inputs: &FriVerificationInput<F>,
    ) -> Result<RowMajorMatrix<F>, AirError> {
        /// Returns `values[idx]`, or zero when an optional witness vector is too short.
        fn value_or_zero<T: Field>(values: &[T], idx: usize) -> T {
            values.get(idx).copied().unwrap_or(T::ZERO)
        }

        if inputs.fri_rounds.len() != self.num_rounds {
            return Err(AirError::InvalidInput {
                reason: format!(
                    "FRI rounds length {} doesn't match expected {}",
                    inputs.fri_rounds.len(),
                    self.num_rounds
                ),
            });
        }
        if inputs.round_betas.len() != self.num_rounds {
            return Err(AirError::InvalidInput {
                reason: format!(
                    "round_betas length {} doesn't match expected {}",
                    inputs.round_betas.len(),
                    self.num_rounds
                ),
            });
        }
        let final_poly_len = 1 << self.log_final_poly_len;
        if inputs.final_poly.len() != final_poly_len {
            return Err(AirError::InvalidInput {
                reason: format!(
                    "Final poly length {} doesn't match expected {}",
                    inputs.final_poly.len(),
                    final_poly_len
                ),
            });
        }
        if inputs.query_indices.len() != self.num_queries {
            return Err(AirError::InvalidInput {
                reason: format!(
                    "Query indices length {} doesn't match expected {}",
                    inputs.query_indices.len(),
                    self.num_queries
                ),
            });
        }

        let width = self.trace_width();
        let num_rows_padded = next_power_of_two(1);
        validate_trace_dimensions(width, num_rows_padded)?;
        let mut trace_values = vec![F::ZERO; num_rows_padded * width];

        let per_round = 32 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1;
        let per_query = 1 + 1;
        let one_f = <F as PrimeCharacteristicRing>::ONE;

        for (round_idx, round) in inputs.fri_rounds.iter().enumerate() {
            let round_start = round_idx * per_round;
            for (i, &byte) in round.commitment_hash.iter().enumerate() {
                trace_values[round_start + i] =
                    F::from_prime_subfield(<F::PrimeSubfield as QuotientMap<u8>>::from_int(byte));
            }

            let beta_f = inputs.round_betas[round_idx];
            let current_f = value_or_zero(&inputs.round_current_evals, round_idx);
            let sibling_f = value_or_zero(&inputs.round_sibling_evals, round_idx);
            let domain_inv_f = value_or_zero(&inputs.round_domain_point_inverses, round_idx);
            let x0_f = value_or_zero(&inputs.round_domain_point_x0, round_idx);
            let parity_f = value_or_zero(&inputs.round_parity, round_idx);
            let roll_in_f = value_or_zero(&inputs.round_roll_ins, round_idx);

            // Same fold as the AIR: parity swaps (current, sibling) into (e0, e1).
            let e0_f = (one_f - parity_f) * current_f + parity_f * sibling_f;
            let e1_f = parity_f * current_f + (one_f - parity_f) * sibling_f;
            let folded_f =
                e0_f + (beta_f - x0_f) * (e1_f - e0_f) * domain_inv_f + beta_f * beta_f * roll_in_f;

            let beta_col = round_start + 32;
            trace_values[beta_col..beta_col + 8].copy_from_slice(&[
                beta_f,
                folded_f,
                sibling_f,
                current_f,
                domain_inv_f,
                x0_f,
                parity_f,
                roll_in_f,
            ]);
        }

        // Final polynomial; its length was validated above.
        let final_poly_start = self.num_rounds * per_round;
        let coeffs = inputs.final_poly.as_slice();
        trace_values[final_poly_start..final_poly_start + final_poly_len].copy_from_slice(coeffs);

        let eval_point_col = final_poly_start + final_poly_len;
        let eval_point = inputs.final_poly_eval_point;
        trace_values[eval_point_col] = eval_point;

        // Horner's rule from the highest-degree coefficient down:
        // acc_0 = c_{n-1}, acc_i = acc_{i-1} * x + c_{n-1-i}.
        let horner_start = eval_point_col + 1;
        let mut acc = F::ZERO;
        for (i, &coeff) in coeffs.iter().rev().enumerate() {
            acc = acc * eval_point + coeff;
            trace_values[horner_start + i] = acc;
        }

        let queries_start = horner_start + final_poly_len;
        for (query_idx, &index) in inputs.query_indices.iter().enumerate() {
            let query_start = queries_start + query_idx * per_query;
            trace_values[query_start] =
                F::from_prime_subfield(<F::PrimeSubfield as QuotientMap<usize>>::from_int(index));
            trace_values[query_start + 1] = value_or_zero(&inputs.query_evaluations, query_idx);
        }

        Ok(RowMajorMatrix::new(trace_values, width))
    }

    /// Public values: the first `2^log_final_poly_len` final-polynomial coefficients
    /// (fewer if `final_poly` is shorter).
    ///
    /// NOTE(review): `eval_with_offset` does not read public values, so these are not tied
    /// to the trace's coefficient columns by this AIR; confirm that is intended.
    fn public_values(&self, inputs: &FriVerificationInput<F>) -> Vec<F> {
        let final_poly_len = 1 << self.log_final_poly_len;
        inputs
            .final_poly
            .iter()
            .take(final_poly_len)
            .copied()
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use lib_q_stark::check_constraints;
    use lib_q_stark_air::BaseAir;
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_matrix::Matrix;
    use lib_q_stark_mersenne31::Mersenne31;

    use super::super::recursive_types::SerializedFriRound;
    use super::*;

    /// Complex extension of Mersenne31, used as the trace field in every test.
    type TestField = Complex<Mersenne31>;

    #[test]
    fn test_fri_verifier_air_new_valid() {
        let air = FriVerifierAir::new(8, 4, 10);
        assert!(air.is_ok());
        let air = air.unwrap();
        assert_eq!(air.num_rounds(), 8);
        assert_eq!(air.log_final_poly_len(), 4);
        assert_eq!(air.num_queries(), 10);
    }

    #[test]
    fn test_fri_verifier_air_new_invalid() {
        let result = FriVerifierAir::new(0, 4, 10);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
        let result = FriVerifierAir::new(MAX_FRI_ROUNDS + 1, 4, 10);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
        let result = FriVerifierAir::new(8, MAX_FINAL_POLY_LOG_LEN + 1, 10);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
        let result = FriVerifierAir::new(8, 4, 0);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
        let result = FriVerifierAir::new(8, 4, MAX_FRI_QUERIES + 1);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
    }

    #[test]
    fn test_fri_verifier_air_accepts_max_queries() {
        let air = FriVerifierAir::new(1, 0, MAX_FRI_QUERIES).unwrap();
        assert_eq!(air.num_queries(), MAX_FRI_QUERIES);
    }

    /// Pins the exact column layout rather than just a non-zero width.
    #[test]
    fn test_fri_verifier_air_width() {
        let air = FriVerifierAir::new(4, 3, 5).unwrap();
        let width = BaseAir::<TestField>::width(&air);
        // 4 rounds x 40 columns, + (8 coefficients + 1 eval point + 8 Horner steps),
        // + 5 queries x 2 columns.
        assert_eq!(width, 4 * 40 + (8 + 1 + 8) + 5 * 2);
    }

    #[test]
    fn test_generate_trace_basic() {
        let air = FriVerifierAir::new(2, 2, 1).unwrap();
        let zero = TestField::ZERO;
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![
                SerializedFriRound {
                    commitment_hash: [0u8; 32],
                    beta: vec![1, 2, 3],
                },
                SerializedFriRound {
                    commitment_hash: [1u8; 32],
                    beta: vec![4, 5, 6],
                },
            ],
            round_betas: vec![zero, zero],
            final_poly: vec![zero; 4],
            query_indices: vec![0],
            query_evaluations: vec![zero],
            round_current_evals: vec![zero, zero],
            round_sibling_evals: vec![zero, zero],
            round_domain_point_inverses: vec![zero, zero],
            round_domain_point_x0: vec![zero, zero],
            round_parity: vec![zero, zero],
            final_poly_eval_point: zero,
            round_roll_ins: vec![zero, zero],
        };
        let trace: RowMajorMatrix<TestField> = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.width(), BaseAir::<TestField>::width(&air));
        assert_eq!(trace.height(), 1);
    }

    #[test]
    fn test_generate_trace_mismatched_lengths() {
        let air = FriVerifierAir::new(2, 2, 1).unwrap();
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [0u8; 32],
                beta: vec![],
            }],
            round_betas: vec![],
            final_poly: vec![TestField::ZERO; 4],
            query_indices: vec![],
            query_evaluations: vec![],
            round_current_evals: vec![],
            round_sibling_evals: vec![],
            round_domain_point_inverses: vec![],
            round_domain_point_x0: vec![],
            round_parity: vec![],
            final_poly_eval_point: TestField::ZERO,
            round_roll_ins: vec![],
        };
        let result: Result<RowMajorMatrix<TestField>, _> = air.generate_trace(&input);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_rejects_round_beta_length_mismatch() {
        let air = FriVerifierAir::new(2, 1, 1).unwrap();
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![
                SerializedFriRound {
                    commitment_hash: [0u8; 32],
                    beta: vec![0u8; 8],
                };
                2
            ],
            round_betas: vec![TestField::ZERO],
            final_poly: vec![TestField::ZERO; 2],
            query_indices: vec![0],
            query_evaluations: vec![TestField::ZERO],
            round_current_evals: vec![TestField::ZERO; 2],
            round_sibling_evals: vec![TestField::ZERO; 2],
            round_domain_point_inverses: vec![TestField::ZERO; 2],
            round_domain_point_x0: vec![TestField::ZERO; 2],
            round_parity: vec![TestField::ZERO; 2],
            final_poly_eval_point: TestField::ZERO,
            round_roll_ins: vec![TestField::ZERO; 2],
        };
        let result: Result<RowMajorMatrix<TestField>, _> = air.generate_trace(&input);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_rejects_final_poly_length_mismatch() {
        let air = FriVerifierAir::new(1, 2, 1).unwrap();
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [0u8; 32],
                beta: vec![0u8; 8],
            }],
            round_betas: vec![TestField::ZERO],
            final_poly: vec![TestField::ZERO; 3],
            query_indices: vec![0],
            query_evaluations: vec![TestField::ZERO],
            round_current_evals: vec![TestField::ZERO],
            round_sibling_evals: vec![TestField::ZERO],
            round_domain_point_inverses: vec![TestField::ZERO],
            round_domain_point_x0: vec![TestField::ZERO],
            round_parity: vec![TestField::ZERO],
            final_poly_eval_point: TestField::ZERO,
            round_roll_ins: vec![TestField::ZERO],
        };
        let result: Result<RowMajorMatrix<TestField>, _> = air.generate_trace(&input);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_rejects_query_index_length_mismatch() {
        let air = FriVerifierAir::new(1, 1, 2).unwrap();
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [0u8; 32],
                beta: vec![0u8; 8],
            }],
            round_betas: vec![TestField::ZERO],
            final_poly: vec![TestField::ZERO; 2],
            query_indices: vec![0],
            query_evaluations: vec![TestField::ZERO],
            round_current_evals: vec![TestField::ZERO],
            round_sibling_evals: vec![TestField::ZERO],
            round_domain_point_inverses: vec![TestField::ONE],
            round_domain_point_x0: vec![TestField::ZERO],
            round_parity: vec![TestField::ZERO],
            final_poly_eval_point: TestField::ONE,
            round_roll_ins: vec![TestField::ZERO],
        };
        let result: Result<RowMajorMatrix<TestField>, _> = air.generate_trace(&input);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_fri_public_values_truncates_to_expected_poly_len() {
        let air = FriVerifierAir::new(1, 1, 1).unwrap();
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [0u8; 32],
                beta: vec![0u8; 8],
            }],
            round_betas: vec![TestField::ZERO],
            final_poly: vec![TestField::ZERO, TestField::ONE, TestField::ONE],
            query_indices: vec![0],
            query_evaluations: vec![TestField::ZERO],
            round_current_evals: vec![],
            round_sibling_evals: vec![],
            round_domain_point_inverses: vec![],
            round_domain_point_x0: vec![],
            round_parity: vec![],
            final_poly_eval_point: TestField::ZERO,
            round_roll_ins: vec![],
        };
        let public_values = air.public_values(&input);
        assert_eq!(public_values.len(), 2);
        assert_eq!(public_values[0], TestField::ZERO);
        assert_eq!(public_values[1], TestField::ONE);
    }

    #[test]
    fn test_fri_trace_generation_uses_default_zero_for_missing_round_vectors() {
        let air = FriVerifierAir::new(1, 1, 1).unwrap();
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [0u8; 32],
                beta: vec![0u8; 8],
            }],
            round_betas: vec![TestField::ZERO],
            final_poly: vec![TestField::ZERO; 2],
            query_indices: vec![0],
            query_evaluations: vec![],
            round_current_evals: vec![],
            round_sibling_evals: vec![],
            round_domain_point_inverses: vec![],
            round_domain_point_x0: vec![],
            round_parity: vec![],
            final_poly_eval_point: TestField::ZERO,
            round_roll_ins: vec![],
        };
        let trace: RowMajorMatrix<TestField> = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.height(), 1);
    }

    #[test]
    fn test_fri_trace_satisfies_constraints() {
        let air = FriVerifierAir::new(1, 1, 1).unwrap();
        let zero = TestField::ZERO;
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [0u8; 32],
                beta: vec![0u8; 8],
            }],
            round_betas: vec![zero],
            final_poly: vec![zero; 2],
            query_indices: vec![0],
            query_evaluations: vec![zero],
            round_current_evals: vec![zero],
            round_sibling_evals: vec![zero],
            round_domain_point_inverses: vec![zero],
            round_domain_point_x0: vec![zero],
            round_parity: vec![zero],
            final_poly_eval_point: zero,
            round_roll_ins: vec![zero],
        };
        let trace: RowMajorMatrix<TestField> = air.generate_trace(&input).expect("trace");
        let public_values: Vec<TestField> = air.public_values(&input);
        check_constraints(&air, &trace, &public_values);
    }

    /// A non-zero but internally consistent witness must satisfy every constraint.
    ///
    /// It has one round whose fold lands on the constant final polynomial, and the query
    /// opening matches the round's current evaluation.
    #[test]
    fn test_fri_consistent_nonzero_trace_satisfies_constraints() {
        let air = FriVerifierAir::new(1, 0, 1).unwrap();
        let one = TestField::ONE;
        let two = one + one;
        let three = two + one;
        // parity = 0 => e0 = current = 3, e1 = sibling = 1, so
        // folded = 3 + (beta - x0) * (e1 - e0) * inv + beta^2 * roll_in
        //        = 3 + (2 - 1) * (1 - 3) * 1 + 4 * 0 = 1 = final_poly[0].
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![SerializedFriRound {
                commitment_hash: [9u8; 32],
                beta: vec![0u8; 8],
            }],
            round_betas: vec![two],
            final_poly: vec![one],
            query_indices: vec![4],
            query_evaluations: vec![three],
            round_current_evals: vec![three],
            round_sibling_evals: vec![one],
            round_domain_point_inverses: vec![one],
            round_domain_point_x0: vec![one],
            round_parity: vec![TestField::ZERO],
            final_poly_eval_point: two,
            round_roll_ins: vec![TestField::ZERO],
        };
        let trace: RowMajorMatrix<TestField> = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.get(0, 32 + 1), Some(one));
        let public_values: Vec<TestField> = air.public_values(&input);
        check_constraints(&air, &trace, &public_values);
    }

    #[test]
    fn test_fri_trace_populates_commitment_folding_horner_and_queries() {
        let air = FriVerifierAir::new(2, 2, 2).unwrap();
        let one = TestField::ONE;
        let two = one + one;
        let three = two + one;
        let input = FriVerificationInput::<TestField> {
            fri_rounds: vec![
                SerializedFriRound {
                    commitment_hash: [5u8; 32],
                    beta: vec![1u8; 8],
                },
                SerializedFriRound {
                    commitment_hash: [7u8; 32],
                    beta: vec![2u8; 8],
                },
            ],
            round_betas: vec![three, two],
            final_poly: vec![one, two, three, TestField::ZERO],
            query_indices: vec![3, 9],
            query_evaluations: vec![three],
            round_current_evals: vec![two, one],
            round_sibling_evals: vec![one, three],
            round_domain_point_inverses: vec![one, one],
            round_domain_point_x0: vec![one, one],
            round_parity: vec![TestField::ZERO, one],
            final_poly_eval_point: two,
            round_roll_ins: vec![one, TestField::ZERO],
        };
        let trace: RowMajorMatrix<TestField> = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.height(), 1);
        assert_eq!(trace.width(), BaseAir::<TestField>::width(&air));
        let per_round = 32 + 1 + 1 + 1 + 1 + 1 + 1 + 1 + 1;
        let round0_start = 0;
        let round1_start = per_round;
        let round0_beta_col = round0_start + 32;
        let round1_beta_col = round1_start + 32;
        assert_eq!(
            trace.get(0, round0_start),
            Some(TestField::new_real(Mersenne31::new(5)))
        );
        assert_eq!(
            trace.get(0, round1_start),
            Some(TestField::new_real(Mersenne31::new(7)))
        );
        // Round 0: parity 0 keeps (current, sibling) as (e0, e1).
        let folded0 = trace.get(0, round0_beta_col + 1);
        let expected_folded0 = two + (three - one) * (one - two) * one + three * three * one;
        assert_eq!(folded0, Some(expected_folded0));
        assert_eq!(trace.get(0, round0_beta_col + 2), Some(one));
        assert_eq!(trace.get(0, round0_beta_col + 3), Some(two));
        assert_eq!(trace.get(0, round0_beta_col + 7), Some(one));
        // Round 1: parity 1 swaps them, so e0 = sibling = 3 and e1 = current = 1.
        let folded1 = trace.get(0, round1_beta_col + 1);
        let expected_folded1 = three + (two - one) * (one - three) * one;
        assert_eq!(folded1, Some(expected_folded1));
        let final_poly_start = 2 * per_round;
        let eval_point_col = final_poly_start + 4;
        let horner_start = eval_point_col + 1;
        assert_eq!(trace.get(0, eval_point_col), Some(two));
        assert_eq!(trace.get(0, horner_start), Some(TestField::ZERO));
        assert_eq!(trace.get(0, horner_start + 1), Some(three));
        assert_eq!(trace.get(0, horner_start + 2), Some(three * two + two));
        assert_eq!(
            trace.get(0, horner_start + 3),
            Some((three * two + two) * two + one)
        );
        let queries_start = horner_start + 4;
        assert_eq!(
            trace.get(0, queries_start),
            Some(TestField::new_real(Mersenne31::new(3)))
        );
        assert_eq!(trace.get(0, queries_start + 1), Some(three));
        assert_eq!(
            trace.get(0, queries_start + 2),
            Some(TestField::new_real(Mersenne31::new(9)))
        );
        // The second query has no evaluation supplied, so it defaults to zero.
        assert_eq!(trace.get(0, queries_start + 3), Some(TestField::ZERO));
    }
}