Skip to main content

lib_q_zkp/air/
mod.rs

1//! AIR (Algebraic Intermediate Representation) module
2//!
3//! This module provides standalone AIR implementations for common proof types
4//! used in zero-knowledge proofs. Each AIR defines constraints that can be
5//! verified using STARK proving systems.
6//!
7//! # Available AIRs
8//!
9//! - [`crate::air::arithmetic::ArithmeticAir`] - Basic arithmetic operations (multiplication constraints)
10//! - [`crate::air::range_proof::RangeProofAir`] - Proves a value is within a specified range
11//! - [`crate::air::hash_preimage::HashPreimageAir`] - Proves knowledge of a Poseidon-128 preimage (industry-standard for STARK constraint encoding)
12//! - [`crate::air::merkle_inclusion::MerkleInclusionAir`] - Proves membership in a Merkle tree
13//!
14//! # Security
15//!
16//! All AIR implementations follow these security principles:
17//! - Input validation to prevent DoS attacks
18//! - Automatic zeroization of secret data via `SecretWitness`
19//! - Constant-time operations where applicable
20//!
21//! # Example
22//!
23//! ```rust,ignore
24//! use lib_q_zkp::air::{ArithmeticAir, TraceGenerator};
25//! use lib_q_stark_field::extension::Complex;
26//! use lib_q_stark_mersenne31::Mersenne31;
27//!
28//! type Val = Complex<Mersenne31>;
29//!
30//! // Create an AIR for 3 multiplication operations
31//! let air = ArithmeticAir::new(3).unwrap();
32//!
33//! // Generate a trace
34//! let inputs = vec![(Val::from(2u32), Val::from(3u32))];
35//! let trace = air.generate_trace(&inputs)?;
36//! ```
37
38extern crate alloc;
39
40use alloc::string::{
41    String,
42    ToString,
43};
44use alloc::vec::Vec;
45use core::fmt;
46
47use lib_q_poseidon::{
48    PoseidonField,
49    PoseidonParams,
50    sbox,
51};
52use lib_q_stark_field::{
53    BasedVectorSpace,
54    Field,
55    PrimeCharacteristicRing,
56};
57use lib_q_stark_matrix::dense::RowMajorMatrix;
58use lib_q_stark_mersenne31::Mersenne31;
59
60pub mod anonymous_auth;
61pub mod arithmetic;
62pub mod batch_stark_verifier;
63pub mod commitment_verifier;
64pub mod constraint_verifier;
65pub mod credential;
66pub mod fri_verifier;
67pub mod hash_preimage;
68pub mod hash_preimage_nist;
69pub mod identity_proof;
70pub mod merkle_inclusion;
71pub mod opening_verifier;
72pub mod poseidon_gadget;
73pub mod poseidon_hash;
74pub mod range_proof;
75pub mod recursive_types;
76pub mod session_key;
77pub mod stark_verifier;
78pub mod state_transition;
79pub mod transaction;
80pub mod verifier_utils;
81
82pub use anonymous_auth::{
83    AnonymousAuthAir,
84    AnonymousAuthInput,
85};
86pub use arithmetic::ArithmeticAir;
87pub use batch_stark_verifier::{
88    BatchRecursiveStarkVerificationInput,
89    BatchStarkVerifierAir,
90    batch_recursive_verifier_public_values,
91};
92#[cfg(all(feature = "recursive-proofs-experimental", feature = "std"))]
93pub use commitment_verifier::debug_commitment_trace_sanity_check;
94pub use commitment_verifier::{
95    CommitmentVerificationInput,
96    CommitmentVerifierAir,
97};
98pub use constraint_verifier::{
99    ConstraintVerificationInput,
100    ConstraintVerifierAir,
101};
102pub use credential::{
103    CredentialAir,
104    CredentialInput,
105    CredentialSchema,
106};
107pub use fri_verifier::{
108    FriVerificationInput,
109    FriVerifierAir,
110};
111pub use hash_preimage::HashPreimageAir;
112pub use hash_preimage_nist::{
113    HASH_OUTPUT_BYTES,
114    HashPreimageNistAir,
115    HashPreimageNistInput,
116    expected_hash_to_public_values,
117};
118pub use identity_proof::{
119    IdentityProofAir,
120    IdentityProofInput,
121    MlDsaLevel,
122};
123pub use merkle_inclusion::{
124    MerkleHash,
125    MerkleInclusionAir,
126    MerkleProofInput,
127};
128pub use opening_verifier::{
129    OpeningVerificationInput,
130    OpeningVerifierAir,
131};
132pub use poseidon_gadget::PoseidonGadget;
133pub use poseidon_hash::PoseidonHashAir;
134pub use range_proof::RangeProofAir;
135pub use recursive_types::{
136    RecursiveStarkInput,
137    SerializedFriRound,
138    SerializedStarkProof,
139};
140pub use session_key::{
141    KdfAlgorithm,
142    KdfParams,
143    SessionKeyDerivationAir,
144    SessionKeyInput,
145};
146/// Trait for PCS commitments that are Poseidon Merkle roots. Used by the recursive verifier.
147/// The only implementation is when `recursive-proofs-experimental` is enabled (Hash in stark_verifier).
148pub trait PoseidonCommitmentRoot {
149    fn to_poseidon_root_bytes(&self) -> [u8; recursive_types::COMMITMENT_HASH_SIZE];
150}
151
152#[cfg(feature = "recursive-proofs-experimental")]
153pub use stark_verifier::{
154    MerklePathExtractable,
155    build_recursive_verification_input_from_proof,
156    build_recursive_verification_input_from_proof_with_poseidon,
157};
158pub use stark_verifier::{
159    RecursiveStarkVerificationInput,
160    StarkVerifierAir,
161    build_recursive_verification_input,
162};
163
/// Upper bound on the number of operations in one AIR instance (2^20 = 1,048,576).
///
/// Guards against memory-exhaustion attacks via oversized AIR configurations.
pub const MAX_OPERATIONS: usize = 1_048_576;

/// Upper bound on trace width in columns (2^17 = 131,072).
///
/// Recursive `StarkVerifierAir` traces can be wider than 65,536 columns, so the
/// limit was raised to 131,072 to support aggregation.
pub const MAX_TRACE_WIDTH: usize = 131_072;

/// Upper bound on trace height in rows (2^24 = 16,777,216).
///
/// Guards against memory exhaustion from very tall traces.
pub const MAX_TRACE_HEIGHT: usize = 16_777_216;
174
/// Errors produced while configuring AIRs, generating traces, or verifying witnesses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AirError {
    /// The AIR was configured with dimensions that make no sense.
    InvalidDimensions {
        /// Human-readable explanation of the bad dimension.
        reason: String,
    },

    /// A size parameter is larger than the permitted limit.
    ExceedsMaxSize {
        /// Which parameter went over its limit (e.g. `"width"`).
        parameter: String,
        /// The largest value that is accepted.
        max: usize,
        /// The value that was actually supplied.
        actual: usize,
    },

    /// The inputs handed to trace generation are unusable.
    InvalidInput {
        /// Human-readable explanation of what is wrong with the input.
        reason: String,
    },

    /// The trace width differs from the width the AIR expects.
    TraceMismatch {
        /// Width the AIR requires.
        expected_width: usize,
        /// Width the trace actually has.
        actual_width: usize,
    },

    /// The witness violates one of the AIR's constraints.
    InvalidWitness {
        /// Name or description of the violated constraint.
        constraint: String,
    },

    /// An unexpected failure occurred while evaluating the AIR.
    InternalError {
        /// Human-readable explanation of the failure.
        reason: String,
    },

    /// The requested functionality needs a feature that is not enabled.
    NotSupported {
        /// What was requested but is unavailable.
        reason: String,
    },

    /// No FRI commit-phase openings exist for the queried index.
    MissingFriCommitPhaseOpenings,

    /// The number of FRI commit-phase steps differs from the number of rounds.
    FriRoundCountMismatch,
}

impl fmt::Display for AirError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidDimensions { reason } => write!(f, "Invalid AIR dimensions: {}", reason),
            Self::ExceedsMaxSize {
                parameter,
                max,
                actual,
            } => write!(
                f,
                "AIR parameter '{}' exceeds maximum: max={}, actual={}",
                parameter, max, actual
            ),
            Self::InvalidInput { reason } => {
                write!(f, "Invalid input for trace generation: {}", reason)
            }
            Self::TraceMismatch {
                expected_width,
                actual_width,
            } => write!(
                f,
                "Trace width mismatch: expected {}, got {}",
                expected_width, actual_width
            ),
            Self::InvalidWitness { constraint } => write!(
                f,
                "Invalid witness: constraint '{}' not satisfied",
                constraint
            ),
            Self::InternalError { reason } => write!(f, "Internal AIR error: {}", reason),
            Self::NotSupported { reason } => write!(f, "Not supported: {}", reason),
            // Unit variants carry no data, so the fixed message is written directly.
            Self::MissingFriCommitPhaseOpenings => {
                f.write_str("FRI commit-phase openings missing for query index")
            }
            Self::FriRoundCountMismatch => {
                f.write_str("FRI commit-phase step count does not match number of rounds")
            }
        }
    }
}
288
289impl From<AirError> for lib_q_core::Error {
290    fn from(err: AirError) -> Self {
291        lib_q_core::Error::InternalError {
292            operation: "AIR operation".into(),
293            details: err.to_string(),
294        }
295    }
296}
297
298/// Trait for AIRs that can generate execution traces from inputs
299///
300/// This trait extends the basic AIR functionality with the ability to
301/// generate valid execution traces from given inputs. The trace can then
302/// be used with STARK proving to generate proofs.
303///
304/// # Type Parameters
305///
306/// - `F`: The field type for trace values
307/// - `I`: The input type for trace generation
308pub trait TraceGenerator<F: Field, I> {
309    /// Generate an execution trace from the given inputs
310    ///
311    /// # Arguments
312    ///
313    /// * `inputs` - The inputs to generate the trace from
314    ///
315    /// # Returns
316    ///
317    /// A `RowMajorMatrix<F>` containing the trace, or an error if trace
318    /// generation fails.
319    ///
320    /// # Errors
321    ///
322    /// Returns `AirError` if:
323    /// - Input dimensions are invalid
324    /// - Input values don't produce a valid trace
325    /// - Memory allocation fails
326    fn generate_trace(&self, inputs: &I) -> Result<RowMajorMatrix<F>, AirError>;
327
328    /// Get the public values from the given inputs
329    ///
330    /// Public values are the values that are shared between prover and verifier.
331    /// These are typically outputs or commitments that are part of the statement
332    /// being proven.
333    ///
334    /// # Arguments
335    ///
336    /// * `inputs` - The inputs to extract public values from
337    ///
338    /// # Returns
339    ///
340    /// A vector of public field elements
341    fn public_values(&self, inputs: &I) -> Vec<F> {
342        let _ = inputs;
343        Vec::new()
344    }
345}
346
347/// Helper function to validate trace dimensions
348///
349/// # Arguments
350///
351/// * `width` - Trace width (number of columns)
352/// * `height` - Trace height (number of rows)
353///
354/// # Returns
355///
356/// `Ok(())` if dimensions are valid, `Err(AirError)` otherwise
357pub fn validate_trace_dimensions(width: usize, height: usize) -> Result<(), AirError> {
358    if width == 0 {
359        return Err(AirError::InvalidDimensions {
360            reason: "Trace width must be greater than 0".into(),
361        });
362    }
363
364    if width > MAX_TRACE_WIDTH {
365        return Err(AirError::ExceedsMaxSize {
366            parameter: "width".into(),
367            max: MAX_TRACE_WIDTH,
368            actual: width,
369        });
370    }
371
372    if height == 0 {
373        return Err(AirError::InvalidDimensions {
374            reason: "Trace height must be greater than 0".into(),
375        });
376    }
377
378    if height > MAX_TRACE_HEIGHT {
379        return Err(AirError::ExceedsMaxSize {
380            parameter: "height".into(),
381            max: MAX_TRACE_HEIGHT,
382            actual: height,
383        });
384    }
385
386    if !height.is_power_of_two() {
387        return Err(AirError::InvalidDimensions {
388            reason: "Trace height must be a power of 2".into(),
389        });
390    }
391
392    Ok(())
393}
394
/// Round up to the next power of 2.
///
/// `0` rounds up to `1`, the smallest power of two, so the result is always a
/// non-zero value usable as a trace height.
///
/// # Arguments
///
/// * `n` - The number to round up
///
/// # Returns
///
/// The smallest power of 2 that is >= n.
///
/// # Panics
///
/// Panics if that power of two does not fit in `usize`, i.e. when
/// `n > 2^(usize::BITS - 1)`. This happens in every build profile. The bare
/// `usize::next_power_of_two` would instead silently return `0` in release builds,
/// which violates the ">= n" contract.
pub fn next_power_of_two(n: usize) -> usize {
    // `checked_next_power_of_two` already maps 0 to Some(1); no special case needed.
    n.checked_next_power_of_two()
        .expect("next_power_of_two: result does not fit in usize")
}
410
411/// Convert PoseidonField to any Field F that supports u32 conversion
412///
413/// Converts PoseidonField (`Complex<Mersenne31>`) to the target field type,
414/// preserving both real and imaginary parts via basis decomposition.
415///
416/// # Arguments
417///
418/// * `pf` - The PoseidonField (`Complex<Mersenne31>`) to convert
419///
420/// # Returns
421///
422/// The field element in type F
423pub fn poseidon_to_field<F: Field + BasedVectorSpace<Mersenne31>>(pf: &PoseidonField) -> F {
424    let coeffs: &[Mersenne31] = pf.as_basis_coefficients_slice();
425    F::from_basis_coefficients_fn(|i| {
426        if i < coeffs.len() {
427            coeffs[i]
428        } else {
429            <Mersenne31 as PrimeCharacteristicRing>::ZERO
430        }
431    })
432}
433
434/// Convert slice of PoseidonField to `Vec<F>`
435///
436/// # Arguments
437///
438/// * `slice` - Slice of PoseidonField elements
439///
440/// # Returns
441///
442/// Vector of field elements in type F
443pub fn poseidon_slice_to_field<F: Field + BasedVectorSpace<Mersenne31>>(
444    slice: &[PoseidonField],
445) -> Vec<F> {
446    slice.iter().map(poseidon_to_field).collect()
447}
448
449/// Convert PoseidonField hash output to bytes
450///
451/// Uses RawDataSerializable to convert `Complex<Mersenne31>` elements to bytes.
452/// Each Complex element produces 8 bytes (4 for real, 4 for imag).
453///
454/// # Arguments
455///
456/// * `hash` - Slice of PoseidonField elements (hash output)
457///
458/// # Returns
459///
460/// Vector of bytes representing the hash
461pub fn poseidon_field_to_bytes(hash: &[PoseidonField]) -> Vec<u8> {
462    use lib_q_stark_field::RawDataSerializable;
463    // Complex<Mersenne31> has NUM_BYTES = 8 (4 real + 4 imag)
464    hash.iter().flat_map(|f| (*f).into_bytes()).collect()
465}
466
467/// Serialize a Merkle root (single PoseidonField) to a fixed 32-byte array.
468///
469/// Uses RawDataSerializable: one Complex&lt;Mersenne31&gt; produces 8 bytes (4 real + 4 imag, LE).
470/// The result is zero-padded to 32 bytes for a fixed-size root representation.
471///
472/// # Arguments
473///
474/// * `root` - The Merkle root as a Poseidon field element
475///
476/// # Returns
477///
478/// 32-byte array suitable for `verify_membership` and related APIs
479#[must_use]
480pub fn merkle_root_to_bytes(root: &PoseidonField) -> [u8; 32] {
481    use lib_q_stark_field::RawDataSerializable;
482    let mut out = [0u8; 32];
483    let bytes: Vec<u8> = (*root).into_bytes().into_iter().collect();
484    let n = core::cmp::min(bytes.len(), 32);
485    out[..n].copy_from_slice(&bytes[..n]);
486    out
487}
488
489/// Deserialize a Merkle root from bytes back to a PoseidonField.
490///
491/// Expects at least 8 bytes: first 4 bytes (u32 LE) = real part, next 4 = imag part
492/// of Complex&lt;Mersenne31&gt;. Used by verifiers to reconstruct the expected public value.
493///
494/// # Arguments
495///
496/// * `bytes` - At least 8 bytes (extra bytes are ignored)
497///
498/// # Returns
499///
500/// The root as PoseidonField, or InvalidInput if bytes.len() &lt; 8
501pub fn merkle_root_from_bytes(bytes: &[u8]) -> Result<PoseidonField, AirError> {
502    use lib_q_stark_field::extension::Complex;
503    use lib_q_stark_field::integers::QuotientMap;
504    use lib_q_stark_mersenne31::Mersenne31;
505
506    if bytes.len() < 8 {
507        return Err(AirError::InvalidInput {
508            reason: alloc::format!(
509                "Merkle root bytes must have at least 8 bytes, got {}",
510                bytes.len()
511            ),
512        });
513    }
514    let mut real_bytes = [0u8; 4];
515    let mut imag_bytes = [0u8; 4];
516    real_bytes.copy_from_slice(&bytes[0..4]);
517    imag_bytes.copy_from_slice(&bytes[4..8]);
518    let real = Mersenne31::from_int(u32::from_le_bytes(real_bytes));
519    let imag = Mersenne31::from_int(u32::from_le_bytes(imag_bytes));
520    Ok(Complex::new_complex(real, imag))
521}
522
523/// Compute one Poseidon permutation row: state in, intermediates, state out.
524///
525/// Uses `params.state_width` (e.g. 5 for Poseidon-128). Caller must pass at least
526/// `params.state_width` elements in `state`. Returns (final_state, intermediates).
527pub fn compute_poseidon_row(
528    state: &[PoseidonField],
529    params: &PoseidonParams,
530) -> (Vec<PoseidonField>, Vec<PoseidonField>) {
531    use lib_q_stark_field::extension::Complex;
532    use lib_q_stark_mersenne31::Mersenne31;
533
534    let n = params.state_width;
535    assert!(state.len() >= n, "state must have at least {} elements", n);
536    let zero = Complex::<Mersenne31>::new_complex(Mersenne31::ZERO, Mersenne31::ZERO);
537    let mut intermediates = Vec::new();
538    let mut round_idx = 0usize;
539    let mut s: Vec<PoseidonField> = state[0..n].to_vec();
540    let full_half = params.full_rounds / 2;
541
542    for _ in 0..full_half {
543        let after_arc: Vec<PoseidonField> = (0..n)
544            .map(|i| s[i] + params.round_constants[round_idx + i])
545            .collect();
546        round_idx += n;
547        intermediates.extend(after_arc.iter().cloned());
548        let after_sbox: Vec<PoseidonField> = (0..n).map(|i| sbox(after_arc[i])).collect();
549        intermediates.extend(after_sbox.iter().cloned());
550        let mut next_s = alloc::vec![zero; n];
551        for (i, next_s_i) in next_s.iter_mut().enumerate().take(n) {
552            for (j, &after_sbox_j) in after_sbox.iter().enumerate().take(n) {
553                *next_s_i += params.mds_matrix[i][j] * after_sbox_j;
554            }
555        }
556        intermediates.extend(next_s.iter().cloned());
557        s = next_s;
558    }
559    for _ in 0..params.partial_rounds {
560        let after_arc: Vec<PoseidonField> = (0..n)
561            .map(|i| s[i] + params.round_constants[round_idx + i])
562            .collect();
563        round_idx += n;
564        intermediates.extend(after_arc.iter().cloned());
565        let mut after_sbox = alloc::vec![zero; n];
566        after_sbox[0] = sbox(after_arc[0]);
567        after_sbox[1..n].copy_from_slice(&after_arc[1..n]);
568        intermediates.extend(after_sbox.iter().cloned());
569        let mut next_s = alloc::vec![zero; n];
570        for (i, next_s_i) in next_s.iter_mut().enumerate().take(n) {
571            for (j, &after_sbox_j) in after_sbox.iter().enumerate().take(n) {
572                *next_s_i += params.mds_matrix[i][j] * after_sbox_j;
573            }
574        }
575        intermediates.extend(next_s.iter().cloned());
576        s = next_s;
577    }
578    for _ in 0..full_half {
579        let after_arc: Vec<PoseidonField> = (0..n)
580            .map(|i| s[i] + params.round_constants[round_idx + i])
581            .collect();
582        round_idx += n;
583        intermediates.extend(after_arc.iter().cloned());
584        let after_sbox: Vec<PoseidonField> = (0..n).map(|i| sbox(after_arc[i])).collect();
585        intermediates.extend(after_sbox.iter().cloned());
586        let mut next_s = alloc::vec![zero; n];
587        for (i, next_s_i) in next_s.iter_mut().enumerate().take(n) {
588            for (j, &after_sbox_j) in after_sbox.iter().enumerate().take(n) {
589                *next_s_i += params.mds_matrix[i][j] * after_sbox_j;
590            }
591        }
592        intermediates.extend(next_s.iter().cloned());
593        s = next_s;
594    }
595    (s, intermediates)
596}
597
598/// Convert bytes to PoseidonField elements
599///
600/// This is a helper function to consistently convert byte slices to PoseidonField
601/// (`Complex<Mersenne31>`) elements. Each byte is converted to a field element.
602///
603/// # Arguments
604///
605/// * `bytes` - Slice of bytes to convert
606///
607/// # Returns
608///
609/// Vector of PoseidonField elements
610pub fn bytes_to_poseidon_field(bytes: &[u8]) -> Vec<PoseidonField> {
611    use lib_q_stark_field::extension::Complex;
612    use lib_q_stark_mersenne31::Mersenne31;
613    bytes
614        .iter()
615        .map(|b| Complex::<Mersenne31>::from(Mersenne31::new(*b as u32)))
616        .collect()
617}
618
619/// Decode the first 8 bytes of an Identity Token (IT) to the expected public value.
620/// The IT is the first 16 bytes of the encoding of the Poseidon hash output; the first 8 bytes
621/// encode one `Complex<Mersenne31>` (4 bytes real + 4 bytes imag, little-endian).
622pub fn it_bytes_to_public_value<F: Field + BasedVectorSpace<Mersenne31>>(it: &[u8; 16]) -> F {
623    use lib_q_stark_field::extension::Complex;
624    use lib_q_stark_mersenne31::Mersenne31;
625    let mut real_bytes = [0u8; 4];
626    let mut imag_bytes = [0u8; 4];
627    real_bytes.copy_from_slice(&it[0..4]);
628    imag_bytes.copy_from_slice(&it[4..8]);
629    let real = Mersenne31::new(u32::from_le_bytes(real_bytes));
630    let imag = Mersenne31::new(u32::from_le_bytes(imag_bytes));
631    let c = Complex::new_complex(real, imag);
632    poseidon_to_field::<F>(&c)
633}
634
#[cfg(test)]
mod tests {
    use lib_q_stark_field::extension::Complex;

    use super::*;

    #[test]
    fn test_validate_trace_dimensions_valid() {
        assert!(validate_trace_dimensions(8, 16).is_ok());
        assert!(validate_trace_dimensions(1, 1).is_ok());
        assert!(validate_trace_dimensions(100, 1024).is_ok());
        // Limits are inclusive.
        assert!(validate_trace_dimensions(MAX_TRACE_WIDTH, MAX_TRACE_HEIGHT).is_ok());
    }

    #[test]
    fn test_validate_trace_dimensions_zero_width() {
        let result = validate_trace_dimensions(0, 16);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
    }

    #[test]
    fn test_validate_trace_dimensions_zero_height() {
        let result = validate_trace_dimensions(8, 0);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
    }

    #[test]
    fn test_validate_trace_dimensions_not_power_of_two() {
        let result = validate_trace_dimensions(8, 15);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
    }

    #[test]
    fn test_validate_trace_dimensions_exceeds_max_width() {
        assert_eq!(
            validate_trace_dimensions(MAX_TRACE_WIDTH + 1, 16),
            Err(AirError::ExceedsMaxSize {
                parameter: "width".into(),
                max: MAX_TRACE_WIDTH,
                actual: MAX_TRACE_WIDTH + 1,
            })
        );
    }

    #[test]
    fn test_validate_trace_dimensions_exceeds_max_height() {
        // A power of two above the limit must be rejected for size, not shape.
        assert_eq!(
            validate_trace_dimensions(8, MAX_TRACE_HEIGHT * 2),
            Err(AirError::ExceedsMaxSize {
                parameter: "height".into(),
                max: MAX_TRACE_HEIGHT,
                actual: MAX_TRACE_HEIGHT * 2,
            })
        );
    }

    #[test]
    fn test_next_power_of_two() {
        assert_eq!(next_power_of_two(0), 1);
        assert_eq!(next_power_of_two(1), 1);
        assert_eq!(next_power_of_two(2), 2);
        assert_eq!(next_power_of_two(3), 4);
        assert_eq!(next_power_of_two(5), 8);
        assert_eq!(next_power_of_two(16), 16);
    }

    #[test]
    fn test_air_error_display() {
        let err = AirError::InvalidDimensions {
            reason: "test".into(),
        };
        assert!(err.to_string().contains("Invalid AIR dimensions"));

        let err = AirError::ExceedsMaxSize {
            parameter: "width".into(),
            max: 100,
            actual: 200,
        };
        assert!(err.to_string().contains("exceeds maximum"));

        assert_eq!(
            AirError::FriRoundCountMismatch.to_string(),
            "FRI commit-phase step count does not match number of rounds"
        );
    }

    #[test]
    fn test_merkle_root_bytes_roundtrip() {
        let root: PoseidonField = Complex::new_complex(Mersenne31::new(5), Mersenne31::new(9));
        let bytes = merkle_root_to_bytes(&root);
        assert!(
            bytes[8..].iter().all(|&b| b == 0),
            "root encoding must be zero-padded to 32 bytes"
        );
        let decoded = merkle_root_from_bytes(&bytes).expect("32 bytes is long enough");
        assert_eq!(decoded, root);
    }

    #[test]
    fn test_merkle_root_from_bytes_too_short() {
        let result = merkle_root_from_bytes(&[0u8; 7]);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }
}