extern crate alloc;
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use lib_q_stark_air::{
Air,
AirBuilder,
BaseAir,
WindowAccess,
};
use lib_q_stark_field::{
Field,
PrimeCharacteristicRing,
};
use lib_q_stark_matrix::dense::RowMajorMatrix;
use super::{
AirError,
TraceGenerator,
next_power_of_two,
validate_trace_dimensions,
};
/// Kind of transaction encoded by a [`TransactionAir`].
///
/// The explicit discriminants equal the tag value that trace generation writes into
/// trace column 0 for each variant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransactionType {
    /// Tag value 0.
    Payment = 0,
    /// Tag value 1.
    ContractCall = 1,
    /// Tag value 2.
    StateUpdate = 2,
}
/// Signature mode a [`TransactionAir`] is configured with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SignatureMode {
    /// Presumably ML-DSA (FIPS 204) signatures — TODO confirm against callers.
    MlDsa,
    /// No signature mode selected.
    None,
}
/// Maximum number of `transaction_data` bytes accepted by trace generation; this is
/// also the number of trace columns reserved for those bytes (8 KiB).
pub const MAX_TRANSACTION_SIZE: usize = 8 * 1024;
/// AIR for a single transaction, parameterized by transaction type and signature mode.
///
/// NOTE(review): `sig_mode` is stored and exposed through a getter, but neither trace
/// generation nor constraint evaluation in this file reads it — confirm that is intended.
#[derive(Clone, Debug)]
pub struct TransactionAir {
    /// Transaction type; its tag value is what trace generation writes into column 0.
    tx_type: TransactionType,
    /// Signature mode selected at construction.
    sig_mode: SignatureMode,
}
impl TransactionAir {
pub fn new(tx_type: TransactionType, sig_mode: SignatureMode) -> Self {
Self { tx_type, sig_mode }
}
pub fn tx_type(&self) -> TransactionType {
self.tx_type
}
pub fn sig_mode(&self) -> SignatureMode {
self.sig_mode
}
}
impl<F: Field> BaseAir<F> for TransactionAir {
    /// Number of trace columns. The value does not depend on the field `F`.
    ///
    /// The column layout is:
    /// - 1 transaction-type tag column;
    /// - `MAX_TRANSACTION_SIZE` transaction-data columns;
    /// - 4096 signature columns;
    /// - 256 + 128 trailing columns.
    fn width(&self) -> usize {
        const TAG_COLUMNS: usize = 1;
        const SIGNATURE_COLUMNS: usize = 4096;
        // NOTE(review): the trace generator in this file never writes these trailing
        // columns (they stay zero); their intended purpose is not visible here — confirm.
        const TRAILING_COLUMNS: usize = 256 + 128;

        TAG_COLUMNS + MAX_TRANSACTION_SIZE + SIGNATURE_COLUMNS + TRAILING_COLUMNS
    }
}
impl<AB: AirBuilder> Air<AB> for TransactionAir
where
    AB::F: Field,
{
    /// Evaluates the transaction constraints on the current row.
    ///
    /// The only constraint is that column 0 (the transaction-type tag) lies in
    /// `{0, 1, 2}`. It is enforced by asserting that the degree-3 product
    /// `t * (t - 1) * (t - 2)` is zero.
    ///
    /// NOTE(review): the tag is not tied to `self.tx_type`, and no other column is
    /// constrained here — confirm this is intended.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let row = main.current_slice();
        let tag: AB::Expr = row[0].into();

        let one = AB::Expr::from(AB::F::ONE);
        let two = one.clone() + one.clone();

        // Each factor vanishes for exactly one admissible tag value.
        let tag_minus_one = tag.clone() - one;
        let tag_minus_two = tag.clone() - two;
        builder.assert_zero(tag * tag_minus_one * tag_minus_two);
    }
}
/// Raw inputs from which a transaction trace is generated.
#[derive(Clone, Debug)]
pub struct TransactionInput {
    /// Transaction bytes. Trace generation rejects more than `MAX_TRANSACTION_SIZE` of them.
    pub transaction_data: Vec<u8>,
    /// Signatures as raw byte strings. Trace generation lays them out back to back.
    pub signatures: Vec<Vec<u8>>,
}
impl
TraceGenerator<
lib_q_stark_field::extension::Complex<lib_q_stark_mersenne31::Mersenne31>,
TransactionInput,
> for TransactionAir
{
fn generate_trace(
&self,
inputs: &TransactionInput,
) -> Result<
RowMajorMatrix<lib_q_stark_field::extension::Complex<lib_q_stark_mersenne31::Mersenne31>>,
AirError,
> {
use lib_q_stark_field::extension::Complex;
use lib_q_stark_field::integers::QuotientMap;
use lib_q_stark_mersenne31::Mersenne31;
type Val = Complex<Mersenne31>;
if inputs.transaction_data.len() > MAX_TRANSACTION_SIZE {
return Err(AirError::ExceedsMaxSize {
parameter: "transaction_data".to_string(),
max: MAX_TRANSACTION_SIZE,
actual: inputs.transaction_data.len(),
});
}
let trace_width = {
use lib_q_stark_field::extension::Complex;
use lib_q_stark_mersenne31::Mersenne31;
type Val = Complex<Mersenne31>;
<Self as BaseAir<Val>>::width(self)
};
let trace_height = 1;
let num_rows_padded = next_power_of_two(trace_height);
validate_trace_dimensions(trace_width, num_rows_padded)?;
let mut trace_values = vec![Val::ZERO; num_rows_padded * trace_width];
let base = 0;
let tx_type_byte = match self.tx_type {
TransactionType::Payment => 0u8,
TransactionType::ContractCall => 1u8,
TransactionType::StateUpdate => 2u8,
};
trace_values[base] =
Val::from_prime_subfield(<Mersenne31 as QuotientMap<u8>>::from_int(tx_type_byte));
for (i, byte) in inputs.transaction_data.iter().enumerate() {
if i < MAX_TRANSACTION_SIZE {
trace_values[base + 1 + i] =
Val::from_prime_subfield(<Mersenne31 as QuotientMap<u8>>::from_int(*byte));
}
}
let sig_start = 1 + MAX_TRANSACTION_SIZE;
let mut sig_col = sig_start;
for sig in &inputs.signatures {
for (i, byte) in sig.iter().enumerate() {
if sig_col + i < trace_width && sig_col + i < sig_start + 4096 {
trace_values[base + sig_col + i] =
Val::from_prime_subfield(<Mersenne31 as QuotientMap<u8>>::from_int(*byte));
}
}
sig_col += sig.len();
}
Ok(RowMajorMatrix::new(trace_values, trace_width))
}
fn public_values(
&self,
_inputs: &TransactionInput,
) -> Vec<lib_q_stark_field::extension::Complex<lib_q_stark_mersenne31::Mersenne31>> {
Vec::new()
}
}
#[cfg(test)]
mod tests {
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_matrix::Matrix;
    use lib_q_stark_mersenne31::Mersenne31;

    use super::*;

    type Val = Complex<Mersenne31>;

    /// First signature column: after the tag column and the transaction-data columns.
    const SIG_START: usize = 1 + MAX_TRANSACTION_SIZE;

    /// Number of trace columns reserved for signature bytes.
    const SIG_WINDOW: usize = 4096;

    #[test]
    fn test_transaction_air_creation() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::MlDsa);
        assert_eq!(air.tx_type(), TransactionType::Payment);
        assert_eq!(air.sig_mode(), SignatureMode::MlDsa);
    }

    #[test]
    fn test_transaction_trace_generation() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::MlDsa);
        let input = TransactionInput {
            transaction_data: vec![1, 2, 3, 4],
            signatures: vec![vec![5, 6, 7, 8]],
        };
        let trace = air.generate_trace(&input);
        assert!(trace.is_ok());
    }

    #[test]
    fn test_transaction_public_values_empty() {
        let input = TransactionInput {
            transaction_data: vec![1, 2, 3, 4],
            signatures: vec![vec![5, 6, 7, 8]],
        };
        for &tx_type in &[
            TransactionType::Payment,
            TransactionType::ContractCall,
            TransactionType::StateUpdate,
        ] {
            for &sig_mode in &[SignatureMode::MlDsa, SignatureMode::None] {
                let air = TransactionAir::new(tx_type, sig_mode);
                assert!(
                    air.public_values(&input).is_empty(),
                    "Expected empty public values for {tx_type:?} / {sig_mode:?}",
                );
            }
        }
    }

    #[test]
    fn test_transaction_trace_generation_rejects_oversized_transaction_data() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::MlDsa);
        let input = TransactionInput {
            transaction_data: vec![0u8; MAX_TRANSACTION_SIZE + 1],
            signatures: vec![],
        };
        let result = air.generate_trace(&input);
        assert!(matches!(
            &result,
            Err(AirError::ExceedsMaxSize { max, actual, .. })
                if *max == MAX_TRANSACTION_SIZE && *actual == MAX_TRANSACTION_SIZE + 1
        ));
    }

    #[test]
    fn test_transaction_trace_accepts_data_of_exactly_max_size() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::MlDsa);
        let input = TransactionInput {
            transaction_data: vec![7u8; MAX_TRANSACTION_SIZE],
            signatures: vec![],
        };
        let trace = air
            .generate_trace(&input)
            .expect("max-size transaction data must be accepted");
        assert_eq!(trace.get(0, MAX_TRANSACTION_SIZE), Some(Val::from_u32(7)));
        assert_eq!(trace.get(0, SIG_START), Some(Val::ZERO));
    }

    #[test]
    fn test_transaction_trace_sets_tx_type_column_for_all_variants() {
        let input = TransactionInput {
            transaction_data: vec![],
            signatures: vec![],
        };
        let cases = [
            (TransactionType::Payment, 0u32),
            (TransactionType::ContractCall, 1u32),
            (TransactionType::StateUpdate, 2u32),
        ];
        for (tx_type, expected) in cases {
            let air = TransactionAir::new(tx_type, SignatureMode::None);
            let trace = air.generate_trace(&input).expect("trace");
            assert_eq!(trace.get(0, 0), Some(Val::from_u32(expected)));
        }
    }

    #[test]
    fn test_transaction_trace_places_data_bytes_after_tag_column() {
        let air = TransactionAir::new(TransactionType::ContractCall, SignatureMode::None);
        let input = TransactionInput {
            transaction_data: vec![9, 8, 7],
            signatures: vec![],
        };
        let trace = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.get(0, 0), Some(Val::from_u32(1)));
        assert_eq!(trace.get(0, 1), Some(Val::from_u32(9)));
        assert_eq!(trace.get(0, 2), Some(Val::from_u32(8)));
        assert_eq!(trace.get(0, 3), Some(Val::from_u32(7)));
        assert_eq!(trace.get(0, 4), Some(Val::ZERO));
    }

    #[test]
    fn test_transaction_trace_writes_signature_bytes_and_stops_at_signature_window() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::MlDsa);
        let input = TransactionInput {
            transaction_data: vec![],
            signatures: vec![vec![5u8; 5000]],
        };
        let trace = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.get(0, SIG_START), Some(Val::from_u32(5)));
        assert_eq!(trace.get(0, SIG_START + SIG_WINDOW - 1), Some(Val::from_u32(5)));
        assert_eq!(trace.get(0, SIG_START + SIG_WINDOW), Some(Val::ZERO));
    }

    #[test]
    fn test_transaction_trace_concatenates_multiple_signatures() {
        let air = TransactionAir::new(TransactionType::StateUpdate, SignatureMode::MlDsa);
        let input = TransactionInput {
            transaction_data: vec![],
            signatures: vec![vec![1, 2], vec![], vec![3]],
        };
        let trace = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.get(0, SIG_START), Some(Val::from_u32(1)));
        assert_eq!(trace.get(0, SIG_START + 1), Some(Val::from_u32(2)));
        assert_eq!(trace.get(0, SIG_START + 2), Some(Val::from_u32(3)));
        assert_eq!(trace.get(0, SIG_START + 3), Some(Val::ZERO));
    }

    #[test]
    fn test_transaction_trace_truncates_signatures_spanning_window_end() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::MlDsa);
        let input = TransactionInput {
            transaction_data: vec![],
            signatures: vec![vec![1u8; SIG_WINDOW - 1], vec![2u8; 10]],
        };
        let trace = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.get(0, SIG_START + SIG_WINDOW - 2), Some(Val::from_u32(1)));
        assert_eq!(trace.get(0, SIG_START + SIG_WINDOW - 1), Some(Val::from_u32(2)));
        assert_eq!(trace.get(0, SIG_START + SIG_WINDOW), Some(Val::ZERO));
    }

    #[test]
    fn test_transaction_trace_width_matches_air_width() {
        let air = TransactionAir::new(TransactionType::Payment, SignatureMode::None);
        let input = TransactionInput {
            transaction_data: vec![1],
            signatures: vec![vec![2]],
        };
        let expected = 1 + MAX_TRANSACTION_SIZE + SIG_WINDOW + 256 + 128;
        assert_eq!(<TransactionAir as BaseAir<Val>>::width(&air), expected);
        let trace = air.generate_trace(&input).expect("trace");
        assert_eq!(trace.width(), expected);
        assert_eq!(trace.get(0, expected - 1), Some(Val::ZERO));
    }
}