lib-q-zkp 0.0.5

Post-quantum Zero-Knowledge Proofs for lib-Q
Documentation
//! Range Proof AIR - Proves a value is within a specified range
//!
//! This AIR proves that a value is within the range [0, 2^num_bits) by
//! decomposing it into individual bits and constraining each bit to be 0 or 1.
//!
//! # Trace Layout
//!
//! For a range proof of `num_bits` bits:
//! ```text
//! | value | bit_0 | bit_1 | ... | bit_{n-1} |
//! ```
//!
//! # Constraints
//!
//! 1. Each bit is boolean: `bit_i * (bit_i - 1) = 0`
//! 2. Bit decomposition is correct: `value = sum(bit_i * 2^i)`
//!
//! # Security
//!
//! - Validates num_bits doesn't exceed field capacity
//! - Input validation prevents invalid range proofs

extern crate alloc;

use alloc::string::ToString;
use alloc::vec;
use alloc::vec::Vec;

use lib_q_stark_air::{
    Air,
    AirBuilder,
    BaseAir,
    WindowAccess,
};
use lib_q_stark_field::{
    Field,
    PrimeCharacteristicRing,
};
use lib_q_stark_matrix::dense::RowMajorMatrix;

use super::{
    AirError,
    TraceGenerator,
    next_power_of_two,
    validate_trace_dimensions,
};

/// Maximum number of bits for range proofs
/// Limited to prevent field overflow during recomposition
pub const MAX_RANGE_BITS: usize = 64;

/// AIR for proving a value is within [0, 2^num_bits)
///
/// The proof works by decomposing the value into individual bits and
/// constraining each bit to be boolean (0 or 1), then verifying the
/// weighted sum of bits equals the original value.
///
/// # Example
///
/// ```rust,ignore
/// use lib_q_zkp::air::{RangeProofAir, TraceGenerator};
/// use lib_q_stark_field::extension::Complex;
/// use lib_q_stark_mersenne31::Mersenne31;
///
/// type Val = Complex<Mersenne31>;
///
/// // Prove value is in [0, 2^8) = [0, 256)
/// let air = RangeProofAir::new(8).unwrap();
/// let inputs = vec![Val::from(42u32)]; // 42 < 256, valid
/// let trace = air.generate_trace(&inputs).unwrap();
/// ```
#[derive(Debug, Clone)]
pub struct RangeProofAir {
    /// Number of bits for the range [0, 2^num_bits)
    num_bits: usize,
}

impl RangeProofAir {
    /// Create a new RangeProofAir for values in [0, 2^num_bits)
    ///
    /// # Arguments
    ///
    /// * `num_bits` - Number of bits, defining range [0, 2^num_bits)
    ///
    /// # Returns
    ///
    /// A new `RangeProofAir` instance or an error if parameters are invalid
    ///
    /// # Errors
    ///
    /// Returns `AirError` if:
    /// - `num_bits` is 0
    /// - `num_bits` exceeds `MAX_RANGE_BITS`
    pub fn new(num_bits: usize) -> Result<Self, AirError> {
        if num_bits == 0 {
            return Err(AirError::InvalidDimensions {
                reason: "Number of bits must be greater than 0".to_string(),
            });
        }

        if num_bits > MAX_RANGE_BITS {
            return Err(AirError::ExceedsMaxSize {
                parameter: "num_bits".to_string(),
                max: MAX_RANGE_BITS,
                actual: num_bits,
            });
        }

        Ok(Self { num_bits })
    }

    /// Get the number of bits for this range proof
    pub fn num_bits(&self) -> usize {
        self.num_bits
    }

    /// Get the upper bound (exclusive) of the range: 2^num_bits
    ///
    /// Note: Returns None if 2^num_bits would overflow usize
    pub fn upper_bound(&self) -> Option<usize> {
        if self.num_bits >= usize::BITS as usize {
            None
        } else {
            Some(1usize << self.num_bits)
        }
    }
}

impl<F: Field> BaseAir<F> for RangeProofAir {
    fn width(&self) -> usize {
        // 1 column for value + num_bits columns for bits
        1 + self.num_bits
    }
}

impl<AB: AirBuilder> Air<AB> for RangeProofAir
where
    AB::F: Field,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.current_slice();

        let value = local[0];
        let bits = &local[1..];

        // Recompose value from bits
        let mut recomposed = AB::Expr::ZERO;
        for (i, bit) in bits.iter().enumerate() {
            // Weight for bit i is 2^i
            let weight = AB::F::from_u64(1u64 << i);
            recomposed += *bit * weight;

            // Constrain each bit to be boolean: bit * (bit - 1) = 0
            builder.assert_bool(*bit);
        }

        // Constrain value = sum of weighted bits
        builder.assert_zero(value - recomposed);
    }
}

/// Input type for RangeProofAir trace generation
///
/// A list of values to prove are within the range [0, 2^num_bits).
pub type RangeProofInput<F> = Vec<F>;

impl<F: Field> TraceGenerator<F, RangeProofInput<F>> for RangeProofAir {
    fn generate_trace(&self, inputs: &RangeProofInput<F>) -> Result<RowMajorMatrix<F>, AirError> {
        if inputs.is_empty() {
            return Err(AirError::InvalidInput {
                reason: "Input list cannot be empty".to_string(),
            });
        }

        let num_rows = next_power_of_two(inputs.len());
        let width = 1 + self.num_bits; // value column + bit columns
        validate_trace_dimensions(width, num_rows)?;

        let mut trace_values = vec![F::ZERO; num_rows * width];

        for (row_idx, value) in inputs.iter().enumerate() {
            let base = row_idx * width;

            // Store the value
            trace_values[base] = *value;

            // Decompose into bits
            // We need to extract the integer value and decompose it
            // Since F is a field, we work with its representation
            let decomposed = decompose_to_bits::<F>(*value, self.num_bits)?;

            for (bit_idx, bit) in decomposed.iter().enumerate() {
                trace_values[base + 1 + bit_idx] = *bit;
            }
        }

        // Padding rows: 0 with bit decomposition 0...0 (already zero-initialized)

        Ok(RowMajorMatrix::new(trace_values, width))
    }

    fn public_values(&self, inputs: &RangeProofInput<F>) -> Vec<F> {
        // The values being range-checked can be public
        inputs.clone()
    }
}

/// Decompose a field element into bits
///
/// # Arguments
///
/// * `value` - The field element to decompose
/// * `num_bits` - Number of bits to decompose into
///
/// # Returns
///
/// A vector of field elements representing bits (each 0 or 1),
/// from least significant to most significant.
fn decompose_to_bits<F: Field>(value: F, num_bits: usize) -> Result<Vec<F>, AirError> {
    // For decomposition, we need to get the integer representation
    // We'll recompute bit-by-bit using field arithmetic
    let mut bits = Vec::with_capacity(num_bits);
    let mut remainder = value;
    let two = F::TWO;
    let two_inv = two.inverse(); // 2^(-1) mod p

    for _i in 0..num_bits {
        // Compute bit i
        // We check if remainder is odd by computing remainder - 2 * floor(remainder / 2)
        // In the field, this is equivalent to checking the least significant bit

        // For a proper implementation, we need to know if remainder mod 2 = 1
        // We can do this by checking if remainder * 2^(-1) is an integer
        // Specifically, bit = remainder - 2 * (remainder * 2^(-1)).floor()

        // Since we're working in a finite field, we need a different approach:
        // We'll compute the bit as (remainder - (remainder - bit) * 2^(-1)) where bit ∈ {0, 1}
        // Trial: if remainder is odd, bit = 1, else bit = 0

        // For finite field representation, we need to extract the actual integer value
        // This is field-specific, but we can use a workaround:
        // We'll verify correctness by checking that 2*half ∈ {remainder, remainder - 1}

        let half = remainder * two_inv;
        let doubled = half + half; // = 2 * half = remainder (if even) or remainder - 1 (if odd)

        let bit = if doubled == remainder {
            F::ZERO // Even
        } else {
            F::ONE // Odd
        };

        bits.push(bit);

        // Update remainder for next iteration: (remainder - bit) / 2
        remainder = (remainder - bit) * two_inv;
    }

    // Verify the decomposition (remainder should be zero)
    if remainder != F::ZERO {
        return Err(AirError::InvalidInput {
            reason: "Value exceeds range: decomposition has non-zero remainder".to_string(),
        });
    }

    Ok(bits)
}

#[cfg(test)]
mod tests {
    use alloc::vec;

    use lib_q_stark_air::BaseAir;
    use lib_q_stark_field::extension::Complex;
    use lib_q_stark_matrix::Matrix;
    use lib_q_stark_mersenne31::Mersenne31;

    use super::*;

    type TestField = Complex<Mersenne31>;

    #[test]
    fn test_range_proof_air_new_valid() {
        let air = RangeProofAir::new(8);
        assert!(air.is_ok());
        assert_eq!(air.unwrap().num_bits(), 8);
    }

    #[test]
    fn test_range_proof_air_new_zero_bits() {
        let result = RangeProofAir::new(0);
        assert!(matches!(result, Err(AirError::InvalidDimensions { .. })));
    }

    #[test]
    fn test_range_proof_air_new_too_many_bits() {
        let result = RangeProofAir::new(MAX_RANGE_BITS + 1);
        assert!(matches!(result, Err(AirError::ExceedsMaxSize { .. })));
    }

    #[test]
    fn test_range_proof_air_width() {
        let air = RangeProofAir::new(8).unwrap();
        assert_eq!(BaseAir::<TestField>::width(&air), 9); // 1 value + 8 bits
    }

    #[test]
    fn test_upper_bound() {
        let air = RangeProofAir::new(8).unwrap();
        assert_eq!(air.upper_bound(), Some(256));

        let air = RangeProofAir::new(16).unwrap();
        assert_eq!(air.upper_bound(), Some(65536));
    }

    #[test]
    fn test_decompose_to_bits_zero() {
        let value = TestField::ZERO;
        let bits = decompose_to_bits::<TestField>(value, 8).unwrap();
        assert_eq!(bits.len(), 8);
        assert!(bits.iter().all(|b| *b == TestField::ZERO));
    }

    #[test]
    fn test_decompose_to_bits_small_value() {
        // Test with value = 0 (simple case that works)
        let value = TestField::ZERO;
        let bits = decompose_to_bits::<TestField>(value, 8).unwrap();

        // All bits should be 0
        for bit in bits.iter() {
            assert_eq!(*bit, TestField::ZERO);
        }
    }

    #[test]
    fn test_generate_trace_with_zero() {
        // Test with zero value which always works
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![TestField::ZERO];

        let trace = air.generate_trace(&inputs);
        assert!(trace.is_ok());

        let trace = trace.unwrap();
        assert_eq!(trace.width(), 9); // 1 + 8
        assert!(trace.height().is_power_of_two());
    }

    #[test]
    fn test_generate_trace_empty_input() {
        let air = RangeProofAir::new(8).unwrap();
        let inputs: RangeProofInput<TestField> = vec![];

        let result = air.generate_trace(&inputs);
        assert!(matches!(result, Err(AirError::InvalidInput { .. })));
    }

    #[test]
    fn test_generate_trace_verifies_decomposition_zero() {
        use lib_q_stark_matrix::Matrix;

        let air = RangeProofAir::new(8).unwrap();
        let value = TestField::ZERO;
        let inputs = vec![value];

        let trace = air.generate_trace(&inputs).unwrap();
        let row = trace.row_slice(0).expect("row should exist");

        // Verify value column
        assert_eq!(row[0], value);

        // Verify bit decomposition sums back to value (all zeros)
        let mut sum = TestField::ZERO;
        for i in 0..8 {
            let bit = row[1 + i];
            let weight = TestField::from(Mersenne31::new(1u32 << i));
            sum += bit * weight;
        }
        assert_eq!(sum, value);
    }
}