extern crate alloc;
use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;
use lib_q_stark_air::{
Air,
AirBuilder,
BaseAir,
WindowAccess,
};
use lib_q_stark_field::{
Field,
PrimeCharacteristicRing,
};
use lib_q_stark_matrix::dense::RowMajorMatrix;
use super::{
AirError,
TraceGenerator,
next_power_of_two,
validate_trace_dimensions,
};
/// Largest number of bits accepted by [`RangeProofAir::new`].
///
/// Equal to the bit width of `u64` (64).
pub const MAX_RANGE_BITS: usize = u64::BITS as usize;
/// AIR proving that a value fits in `num_bits` bits by exposing its binary
/// decomposition in the same trace row.
#[derive(Debug, Clone)]
pub struct RangeProofAir {
    /// Number of bit columns; [`RangeProofAir::new`] only accepts `1..=MAX_RANGE_BITS`.
    num_bits: usize,
}

impl RangeProofAir {
    /// Creates a range-proof AIR for values decomposed into `num_bits` bits.
    ///
    /// # Errors
    ///
    /// - [`AirError::InvalidDimensions`] if `num_bits` is zero.
    /// - [`AirError::ExceedsMaxSize`] if `num_bits` is greater than [`MAX_RANGE_BITS`].
    pub fn new(num_bits: usize) -> Result<Self, AirError> {
        match num_bits {
            0 => Err(AirError::InvalidDimensions {
                reason: "Number of bits must be greater than 0".to_string(),
            }),
            n if n > MAX_RANGE_BITS => Err(AirError::ExceedsMaxSize {
                parameter: "num_bits".to_string(),
                max: MAX_RANGE_BITS,
                actual: n,
            }),
            _ => Ok(Self { num_bits }),
        }
    }

    /// Returns the number of bits each value is decomposed into.
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    /// Returns the exclusive upper bound `2^num_bits` as a `usize`.
    ///
    /// Returns `None` when `num_bits` is at least the bit width of `usize`,
    /// because the bound is then not representable.
    pub fn upper_bound(&self) -> Option<usize> {
        // `checked_shl` yields `None` exactly when the shift is >= usize::BITS.
        u32::try_from(self.num_bits)
            .ok()
            .and_then(|shift| 1usize.checked_shl(shift))
    }
}
impl<F: Field> BaseAir<F> for RangeProofAir {
    /// Trace width: one value column followed by `num_bits` bit columns.
    fn width(&self) -> usize {
        self.num_bits + 1
    }
}
impl<AB: AirBuilder> Air<AB> for RangeProofAir
where
    AB::F: Field,
{
    /// Emits the range-proof constraints for the current row.
    ///
    /// Column 0 holds the value and columns `1..=num_bits` hold its bits, least
    /// significant first. Each bit column is constrained to be boolean, and the
    /// value must equal `sum(bit_i * 2^i)`.
    ///
    /// The recomposition is evaluated in the field. For a prime field of order
    /// `p <= 2^num_bits`, every element has such a decomposition, so the
    /// constraint only bounds the value when `2^num_bits` is below the field's
    /// characteristic.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let row = main.current_slice();
        let (value, bit_columns) = (row[0], &row[1..]);

        // Running power of two: 2^0, 2^1, ... — the same constants as from_u64(1 << i).
        let mut weight = AB::F::ONE;
        let mut sum = AB::Expr::ZERO;
        for &bit in bit_columns {
            builder.assert_bool(bit);
            sum += bit * weight;
            weight = weight + weight;
        }
        builder.assert_zero(value - sum);
    }
}
/// Input to the range-proof trace generator: the values to range-check, one per row.
pub type RangeProofInput<F> = Vec<F>;

impl<F: Field> TraceGenerator<F, RangeProofInput<F>> for RangeProofAir {
    /// Builds the trace with one row per input and a power-of-two height.
    ///
    /// Each row is `[value, bit_0, ..., bit_{num_bits-1}]`, with bits least
    /// significant first. Rows beyond the inputs stay all-zero; that satisfies
    /// `eval`, since 0 equals the sum of zero bits.
    ///
    /// # Errors
    ///
    /// - [`AirError::InvalidInput`] if `inputs` is empty or a value cannot be
    ///   decomposed into `num_bits` bits.
    /// - Any error reported by `validate_trace_dimensions`.
    fn generate_trace(&self, inputs: &RangeProofInput<F>) -> Result<RowMajorMatrix<F>, AirError> {
        if inputs.is_empty() {
            return Err(AirError::InvalidInput {
                reason: "Input list cannot be empty".to_string(),
            });
        }

        let height = next_power_of_two(inputs.len());
        let width = self.num_bits + 1;
        validate_trace_dimensions(width, height)?;

        let mut cells = vec![F::ZERO; height * width];
        // Assumes next_power_of_two(n) >= n, so every input gets its own row.
        for (row, &input) in cells.chunks_exact_mut(width).zip(inputs.iter()) {
            let bits = decompose_to_bits::<F>(input, self.num_bits)?;
            row[0] = input;
            row[1..].copy_from_slice(&bits);
        }

        Ok(RowMajorMatrix::new(cells, width))
    }

    /// Publishes the input values, in order, as the proof's public values.
    fn public_values(&self, inputs: &RangeProofInput<F>) -> Vec<F> {
        inputs.clone()
    }
}
/// Decomposes `value` into `num_bits` bits, least significant first.
///
/// Each step records a bit for the running remainder and replaces the
/// remainder with `(remainder - bit) / 2`. After `num_bits` steps the
/// remainder must be zero, otherwise the value is rejected.
///
/// # Errors
///
/// Returns [`AirError::InvalidInput`] if 2 is not invertible in `F`
/// (characteristic 2). It also returns that error if the remainder is non-zero
/// after `num_bits` steps.
fn decompose_to_bits<F: Field>(value: F, num_bits: usize) -> Result<Vec<F>, AirError> {
    // In characteristic 2, TWO == ZERO; report an error instead of letting
    // `inverse()` fail on zero.
    let two_inv = F::TWO.try_inverse().ok_or_else(|| AirError::InvalidInput {
        reason: "Cannot decompose into bits: 2 is not invertible in this field".to_string(),
    })?;

    let mut bits = Vec::with_capacity(num_bits);
    let mut remainder = value;
    for _ in 0..num_bits {
        // NOTE(review): this parity test is a field identity, not a parity check.
        // Whenever 2 is invertible, `r * 2^-1 + r * 2^-1 == r` holds for every r.
        // So every bit is recorded as ZERO, and the final remainder is
        // `value * 2^-num_bits`, which is zero only for `value == 0`.
        //
        // As a result, any non-zero input, even one in range, is rejected below
        // with the "non-zero remainder" error. A correct decomposition needs the
        // canonical integer representative of `value` (e.g. a
        // `PrimeField64::as_canonical_u64`-style accessor). The `F: Field` bound
        // used here does not expose one. TODO: confirm what lib_q_stark_field
        // provides and tighten the bound.
        let half = remainder * two_inv;
        let bit = if half + half == remainder { F::ZERO } else { F::ONE };
        bits.push(bit);
        remainder = (remainder - bit) * two_inv;
    }

    if remainder != F::ZERO {
        return Err(AirError::InvalidInput {
            reason: "Value exceeds range: decomposition has non-zero remainder".to_string(),
        });
    }
    Ok(bits)
}
#[cfg(test)]
mod tests {
    use alloc::vec;

    use lib_q_stark_air::BaseAir;
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_matrix::Matrix;
    use lib_q_stark_mersenne31::Mersenne31;

    use super::*;

    /// Field used by these tests: the `Complex` extension of Mersenne31.
    type TestField = Complex<Mersenne31>;

    /// Embeds a small integer into the test field.
    fn felt(n: u32) -> TestField {
        TestField::from(Mersenne31::new(n))
    }

    #[test]
    fn test_range_proof_air_new_valid() {
        let air = RangeProofAir::new(8);
        assert!(air.is_ok());
        assert_eq!(air.unwrap().num_bits(), 8);
    }

    #[test]
    fn test_range_proof_air_new_max_bits() {
        let air = RangeProofAir::new(MAX_RANGE_BITS);
        assert!(air.is_ok());
        assert_eq!(air.unwrap().num_bits(), MAX_RANGE_BITS);
    }

    #[test]
    fn test_range_proof_air_new_zero_bits() {
        let result = RangeProofAir::new(0);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
    }

    #[test]
    fn test_range_proof_air_new_too_many_bits() {
        let result = RangeProofAir::new(MAX_RANGE_BITS + 1);
        assert!(matches!(result, Err(AirError::ExceedsMaxSize { .. })));
    }

    #[test]
    fn test_range_proof_air_width() {
        let air = RangeProofAir::new(8).unwrap();
        assert_eq!(BaseAir::<TestField>::width(&air), 9);
    }

    #[test]
    fn test_upper_bound() {
        let air = RangeProofAir::new(8).unwrap();
        assert_eq!(air.upper_bound(), Some(256));
        let air = RangeProofAir::new(16).unwrap();
        assert_eq!(air.upper_bound(), Some(65536));
        // `usize` is at most 64 bits wide on every Rust target, so 2^64 does not fit.
        let air = RangeProofAir::new(MAX_RANGE_BITS).unwrap();
        assert_eq!(air.upper_bound(), None);
    }

    #[test]
    fn test_decompose_to_bits_zero() {
        let bits = decompose_to_bits::<TestField>(TestField::ZERO, 8).unwrap();
        assert_eq!(bits.len(), 8);
        assert!(bits.iter().all(|b| *b == TestField::ZERO));
    }

    #[test]
    #[ignore = "decompose_to_bits records every bit as 0 (x/2 + x/2 == x for all field elements), so non-zero values cannot be decomposed yet"]
    fn test_decompose_to_bits_small_value() {
        let bits = decompose_to_bits::<TestField>(felt(5), 8).unwrap();
        let expected: Vec<TestField> = [1u32, 0, 1, 0, 0, 0, 0, 0].iter().map(|&b| felt(b)).collect();
        assert_eq!(bits, expected);
    }

    #[test]
    fn test_decompose_to_bits_rejects_out_of_range() {
        // 256 = 2^8 needs nine bits.
        let result = decompose_to_bits::<TestField>(felt(256), 8);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_with_zero() {
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![TestField::ZERO];
        let trace = air.generate_trace(&inputs);
        assert!(trace.is_ok());
        let trace = trace.unwrap();
        assert_eq!(trace.width(), 9);
        assert!(trace.height().is_power_of_two());
    }

    #[test]
    fn test_generate_trace_pads_to_power_of_two() {
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![TestField::ZERO; 3];
        let trace = air.generate_trace(&inputs).unwrap();
        assert!(trace.height() >= inputs.len());
        assert!(trace.height().is_power_of_two());
        for r in 0..trace.height() {
            let row = trace.row_slice(r).expect("row should exist");
            assert!(row.iter().all(|v| *v == TestField::ZERO));
        }
    }

    #[test]
    fn test_generate_trace_empty_input() {
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![];
        let result = air.generate_trace(&inputs);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_rejects_out_of_range_input() {
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![felt(256)];
        let result = air.generate_trace(&inputs);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_verifies_decomposition_zero() {
        let air = RangeProofAir::new(8).unwrap();
        let value = TestField::ZERO;
        let inputs = vec![value];
        let trace = air.generate_trace(&inputs).unwrap();
        let row = trace.row_slice(0).expect("row should exist");
        assert_eq!(row[0], value);
        let mut sum = TestField::ZERO;
        for i in 0..8 {
            sum += row[1 + i] * felt(1u32 << i);
        }
        assert_eq!(sum, value);
    }

    #[test]
    fn test_public_values_echo_inputs() {
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![felt(3), felt(7)];
        let public =
            TraceGenerator::<TestField, RangeProofInput<TestField>>::public_values(&air, &inputs);
        assert_eq!(public, inputs);
    }
}