stwo-cairo-common 1.0.0

Common types and utilities shared across Stwo Cairo crates
Documentation
use std::collections::HashMap;
#[cfg(feature = "prover")]
use std::iter::zip;

use itertools::{chain, Itertools};
use stwo_constraint_framework::preprocessed_columns::PreProcessedColumnId;

use super::bitwise_xor::BitwiseXor;
use super::blake::{BlakeSigma, N_BLAKE_SIGMA_COLS};
use super::pedersen::{PedersenPoints, PEDERSEN_TABLE_N_COLUMNS};
use super::poseidon::{PoseidonRoundKeys, N_WORDS as POSEIDON_N_WORDS};
#[cfg(feature = "prover")]
use super::simd_prelude::*;

/// The `n_bits` values for which `BitwiseXor` columns (three per value) are added to the
/// canonical preprocessed trace.
const XOR_N_BITS: [u32; 5] = [4, 7, 8, 9, 10];

/// Inclusive bounds on the log sizes of the `Seq` columns in the canonical preprocessed trace.
/// These sequences are used by every builtin for a read of the memory.
pub const MAX_SEQUENCE_LOG_SIZE: u32 = 25;
pub const MIN_SEQUENCE_LOG_SIZE: u32 = 4;

/// Total number of trace cells (the sum of `2^log_size` over all columns) in the canonical
/// preprocessed trace.
pub const CANONICAL_SIZE: u32 = 543_100_528;
/// Total number of trace cells in the canonical preprocessed trace without the Pedersen points.
pub const CANONICAL_WITHOUT_PEDERSEN_SIZE: u32 = 73_338_480;

/// A single column of the preprocessed trace.
pub trait PreProcessedColumn: Send + Sync {
    /// Identifier of the column; `PreProcessedTrace` indexes its columns by this id.
    fn id(&self) -> PreProcessedColumnId;

    /// Log2 of the number of rows in the column.
    fn log_size(&self) -> u32;

    /// Returns one packed vector of the column's values. In the implementations in this file
    /// it covers rows `vec_row * N_LANES .. (vec_row + 1) * N_LANES`.
    #[cfg(feature = "prover")]
    fn packed_at(&self, vec_row: usize) -> PackedM31;

    /// Generates the whole column as a SIMD-backend circle evaluation.
    #[cfg(feature = "prover")]
    fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>;
}

/// A collection of preprocessed columns, whose values are publicly acknowledged and independent of
/// the proof.
///
/// The canonical traces, generated by `PreProcessedTrace::canonical()` (or by
/// `PreProcessedTrace::canonical_without_pedersen()` for programs that do not use Pedersen), are
/// the only ones allowed to be used in proving Cairo programs, as their commitments are known to
/// the verifier.
pub struct PreProcessedTrace {
    /// The columns, in trace order.
    pub columns: Vec<Box<dyn PreProcessedColumn>>,
    /// Maps each column's id to its index in `columns`.
    pub column_indices: HashMap<PreProcessedColumnId, usize>,
}
impl PreProcessedTrace {
    /// Wraps `columns`, keeping their order, and indexes them by id.
    ///
    /// # Panics
    ///
    /// Panics if two columns share the same id. Without this check the later column would
    /// silently shadow the earlier one in `column_indices`, making the earlier one unreachable
    /// through `get_column`.
    fn from_columns(columns: Vec<Box<dyn PreProcessedColumn>>) -> Self {
        let mut column_indices = HashMap::with_capacity(columns.len());

        for (i, column) in columns.iter().enumerate() {
            let previous = column_indices.insert(column.id(), i);
            assert!(
                previous.is_none(),
                "Duplicate preprocessed column {:?}",
                column.id()
            );
        }

        Self {
            columns,
            column_indices,
        }
    }

    /// Generates a canonical preprocessed trace. Used in proving Generic Cairo code & Starknet
    /// blocks.
    ///
    /// # Panics
    ///
    /// Panics if the total number of cells differs from `CANONICAL_SIZE`.
    pub fn canonical() -> Self {
        let canonical_without_pedersen = Self::canonical_without_pedersen().columns;
        let pedersen_points = (0..PEDERSEN_TABLE_N_COLUMNS)
            .map(|x| Box::new(PedersenPoints::<18>::new(x)) as Box<dyn PreProcessedColumn>);

        // `sorted_by_key` is stable, so columns of equal log size keep the order in which they
        // are chained here; changing that order changes the trace.
        let columns = chain!(canonical_without_pedersen, pedersen_points)
            .sorted_by_key(|column| column.log_size())
            .collect_vec();

        assert_eq!(
            columns.iter().map(|col| 1 << col.log_size()).sum::<u32>(),
            CANONICAL_SIZE,
            "Canonical preprocessed trace has unexpected size"
        );

        Self::from_columns(columns)
    }

    /// Generates a canonical preprocessed trace without the `Pedersen` points. Used in proving
    /// programs that do not use `Pedersen` hash, e.g. the recursive verifier.
    ///
    /// # Panics
    ///
    /// Panics if the total number of cells differs from `CANONICAL_WITHOUT_PEDERSEN_SIZE`.
    pub fn canonical_without_pedersen() -> Self {
        let seq = (MIN_SEQUENCE_LOG_SIZE..=MAX_SEQUENCE_LOG_SIZE)
            .map(|x| Box::new(Seq::new(x)) as Box<dyn PreProcessedColumn>);
        let bitwise_xor = XOR_N_BITS
            .map(|n_bits| {
                (0..3).map(move |col_index| {
                    Box::new(BitwiseXor::new(n_bits, col_index)) as Box<dyn PreProcessedColumn>
                })
            })
            .into_iter()
            .flatten();
        let range_check = gen_range_check_columns();
        let poseidon_keys = (0..POSEIDON_N_WORDS)
            .map(|x| Box::new(PoseidonRoundKeys::new(x)) as Box<dyn PreProcessedColumn>);
        let blake_sigma = (0..N_BLAKE_SIGMA_COLS)
            .map(|x| Box::new(BlakeSigma::new(x)) as Box<dyn PreProcessedColumn>);

        // Stable sort: equal log sizes keep the chaining order below.
        let columns = chain!(seq, bitwise_xor, range_check, poseidon_keys, blake_sigma)
            .sorted_by_key(|column| column.log_size())
            .collect_vec();

        assert_eq!(
            columns.iter().map(|col| 1 << col.log_size()).sum::<u32>(),
            CANONICAL_WITHOUT_PEDERSEN_SIZE,
            "Canonical without pedersen preprocessed trace has unexpected size"
        );

        Self::from_columns(columns)
    }

    /// Returns the log size of every column, in trace order.
    pub fn log_sizes(&self) -> Vec<u32> {
        self.columns.iter().map(|c| c.log_size()).collect()
    }

    /// Returns the column with the given id.
    ///
    /// # Panics
    ///
    /// Panics if the trace has no column with that id (see `has_column`).
    pub fn get_column(&self, id: &PreProcessedColumnId) -> &dyn PreProcessedColumn {
        let index = *self
            .column_indices
            .get(id)
            .unwrap_or_else(|| panic!("Missing preprocessed column {id:?}"));
        self.columns[index].as_ref()
    }

    /// Returns whether the trace contains a column with the given id.
    pub fn has_column(&self, id: &PreProcessedColumnId) -> bool {
        self.column_indices.contains_key(id)
    }

    /// Returns the id of every column, in trace order.
    pub fn ids(&self) -> Vec<PreProcessedColumnId> {
        self.columns.iter().map(|c| c.id()).collect()
    }
}

/// Builds every column of the canonical range-check tables.
///
/// Tables are emitted in a fixed order (4_3, 4_4, 9_9, 7_2_5, 3_6_6_3, 4_4_4_4, 3_3_3_3_3). Each
/// table's columns appear in ascending column index. The caller `canonical_without_pedersen`
/// stable-sorts the result by log size, so this order decides where equally sized columns land.
fn gen_range_check_columns() -> Vec<Box<dyn PreProcessedColumn>> {
    /// Boxes columns `0..N` of the range-check table with segment widths `ranges`.
    fn table_columns<const N: usize>(
        ranges: [u32; N],
    ) -> impl Iterator<Item = Box<dyn PreProcessedColumn>> {
        (0..N).map(move |column_idx| {
            Box::new(RangeCheck::new(ranges, column_idx)) as Box<dyn PreProcessedColumn>
        })
    }

    let mut columns = Vec::new();
    columns.extend(table_columns([4, 3]));
    columns.extend(table_columns([4, 4]));
    columns.extend(table_columns([9, 9]));
    columns.extend(table_columns([7, 2, 5]));
    columns.extend(table_columns([3, 6, 6, 3]));
    columns.extend(table_columns([4, 4, 4, 4]));
    columns.extend(table_columns([3, 3, 3, 3, 3]));
    columns
}

/// The enumeration column `0, 1, 2, ..., 2^log_size - 1`.
#[derive(Debug, Clone)]
pub struct Seq {
    pub log_size: u32,
}

impl Seq {
    /// Creates the enumeration column with `2^log_size` rows.
    pub const fn new(log_size: u32) -> Self {
        Seq { log_size }
    }
}

impl PreProcessedColumn for Seq {
    fn id(&self) -> PreProcessedColumnId {
        let id = format!("seq_{}", self.log_size);
        PreProcessedColumnId { id }
    }

    fn log_size(&self) -> u32 {
        self.log_size
    }

    /// Row `r` holds the value `r`, so lane `l` of vector row `vec_row` holds
    /// `vec_row * N_LANES + l`.
    #[cfg(feature = "prover")]
    fn packed_at(&self, vec_row: usize) -> PackedM31 {
        let row_start = PackedM31::broadcast(M31::from(vec_row * N_LANES));
        // SAFETY: `SIMD_ENUMERATION_0` holds the lane indices `0..N_LANES`, all of which are
        // far below the field modulus.
        let lane_offsets = unsafe { PackedM31::from_simd_unchecked(SIMD_ENUMERATION_0) };
        row_start + lane_offsets
    }

    #[cfg(feature = "prover")]
    fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
        let domain = CanonicCoset::new(self.log_size).circle_domain();
        let values = (0..(1 << self.log_size)).map(BaseField::from);
        CircleEvaluation::new(domain, Col::<SimdBackend, BaseField>::from_iter(values))
    }
}

/// Partitions every lane of `value` into `N` bit segments, most significant segment first.
///
/// Segments are cut from the least significant end. The last segment receives the lowest
/// `n_bits_per_segment[N - 1]` bits, the one before it the next bits up, and so on. Bits above
/// the total width `n_bits_per_segment.iter().sum()` are discarded.
///
/// For example, per lane, `partition_into_bit_segments(0b110101010, [3, 4, 2])` yields
/// `[0b110, 0b1010, 0b10]`.
///
/// # Arguments
///
/// * `value` - The packed values to split.
/// * `n_bits_per_segment` - The width in bits of each segment, ordered from the most
///   significant segment to the least significant one. Each width is expected to be below 32;
///   otherwise computing its mask (`(1 << width) - 1`) overflows.
///
/// # Returns
///
/// `N` packed vectors, where entry `i` holds segment `i` of every lane of `value`.
#[cfg(feature = "prover")]
pub fn partition_into_bit_segments<const N: usize>(
    mut value: Simd<u32, N_LANES>,
    n_bits_per_segment: [u32; N],
) -> [Simd<u32, N_LANES>; N] {
    let mut segments = [Simd::splat(0); N];
    // Walk the segments from least to most significant, peeling bits off the bottom of `value`.
    for (segment, segment_n_bits) in zip(&mut segments, n_bits_per_segment).rev() {
        let mask = Simd::splat((1 << segment_n_bits) - 1);
        *segment = value & mask;
        value >>= segment_n_bits;
    }
    segments
}

/// Generates the map from `0..2^sum_bits` to the corresponding values' partition segments,
/// where `sum_bits` is the total of `n_bits_per_segment`.
///
/// Entry `i` of the returned array holds segment `i` (as produced by
/// [`partition_into_bit_segments`]) of every value `0, 1, ..., 2^sum_bits - 1`, in order.
/// The values are packed `N_LANES` per vector.
///
/// # Panics
///
/// Panics if `sum_bits >= MODULUS_BITS`, or if `sum_bits < LOG_N_LANES`. In the second case
/// the enumeration would not fill a single packed vector, and the row count computation
/// would underflow.
#[cfg(feature = "prover")]
pub fn generate_partitioned_enumeration<const N: usize>(
    n_bits_per_segment: [u32; N],
) -> [Vec<PackedM31>; N] {
    let sum_bits = n_bits_per_segment.iter().sum::<u32>();
    assert!(sum_bits < MODULUS_BITS);
    assert!(
        sum_bits >= LOG_N_LANES,
        "Enumeration of {sum_bits} bits does not fill a single packed vector"
    );
    let n_vec_rows: usize = 1 << (sum_bits - LOG_N_LANES);

    let mut res: [Vec<PackedM31>; N] = std::array::from_fn(|_| Vec::with_capacity(n_vec_rows));
    for vec_row in 0..n_vec_rows {
        let value = SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32);
        let segments = partition_into_bit_segments(value, n_bits_per_segment);
        for (column, segment) in res.iter_mut().zip(segments) {
            // SAFETY: each segment is below 2^sum_bits, and sum_bits < MODULUS_BITS was asserted
            // above, so every lane is a reduced field element (assuming MODULUS_BITS is the bit
            // length of the M31 modulus).
            column.push(unsafe { PackedM31::from_simd_unchecked(segment) });
        }
    }
    res
}

/// One column of a multi-segment range-check table.
///
/// The table enumerates every value in `0..2^(sum of ranges)`. Each value is split into `N`
/// segments of `ranges[i]` bits, most significant first. Column `column_idx` holds segment
/// `column_idx` of each row's value.
pub struct RangeCheck<const N: usize> {
    ranges: [u32; N],
    column_idx: usize,
}
impl<const N: usize> RangeCheck<N> {
    /// Creates column `column_idx` of the range-check table with segment widths `ranges`.
    ///
    /// # Panics
    ///
    /// Panics if any segment width is zero or if `column_idx >= N`.
    pub fn new(ranges: [u32; N], column_idx: usize) -> Self {
        assert!(ranges.iter().all(|&r| r > 0));
        assert!(column_idx < N);
        Self { ranges, column_idx }
    }
}
impl<const N: usize> PreProcessedColumn for RangeCheck<N> {
    fn log_size(&self) -> u32 {
        self.ranges.iter().sum()
    }

    #[cfg(feature = "prover")]
    fn packed_at(&self, vec_row: usize) -> PackedM31 {
        // The segments after this one occupy the low bits, so shift them out, then mask.
        let shift: u32 = self.ranges[(self.column_idx + 1)..].iter().sum();
        let mask = Simd::splat((1 << self.ranges[self.column_idx]) - 1);
        let simd_result =
            ((SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) >> shift) & mask;
        // SAFETY: every lane is masked to `ranges[column_idx]` bits. This is assumed (not
        // checked here) to be below the bit length of the field modulus.
        unsafe { PackedM31::from_simd_unchecked(simd_result) }
    }

    #[cfg(feature = "prover")]
    fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
        let log_size = self.log_size();
        // Same preconditions as `generate_partitioned_enumeration`.
        assert!(log_size < MODULUS_BITS);
        assert!(
            log_size >= LOG_N_LANES,
            "Range-check table of {log_size} bits does not fill a single packed vector"
        );
        // Only this column's segment is needed. Computing it directly avoids materialising all
        // `N` partitions and then discarding `N - 1` of them.
        let column = (0..1 << (log_size - LOG_N_LANES))
            .map(|vec_row| self.packed_at(vec_row))
            .collect::<Vec<_>>();
        CircleEvaluation::new(
            CanonicCoset::new(log_size).circle_domain(),
            BaseColumn::from_simd(column),
        )
    }

    fn id(&self) -> PreProcessedColumnId {
        let ranges = self.ranges.iter().join("_");
        PreProcessedColumnId {
            id: format!("range_check_{ranges}_column_{}", self.column_idx),
        }
    }
}

/// Builds a preprocessed trace for tests: the canonical trace restricted to the columns whose log
/// size is at most `max_log_size`. Tests that rely on a larger column will therefore fail to
/// find it.
pub fn testing_preprocessed_tree(max_log_size: u32) -> PreProcessedTrace {
    let mut columns = PreProcessedTrace::canonical().columns;
    columns.retain(|column| column.log_size() <= max_log_size);
    PreProcessedTrace::from_columns(columns)
}

#[cfg(test)]
pub mod tests {
    // Only the prover-gated tests below call `Column` methods; gating the import avoids an
    // unused-import warning when the `prover` feature is off.
    #[cfg(feature = "prover")]
    use stwo::prover::backend::Column;

    use super::*;

    /// Log size of the `Seq` columns exercised by the prover tests.
    #[cfg(feature = "prover")]
    const LOG_SIZE: u32 = 8;

    /// Asserts that the columns of `trace` are sorted by non-decreasing log size.
    fn assert_ascending_log_sizes(trace: &PreProcessedTrace) {
        assert!(trace
            .columns
            .windows(2)
            .all(|w| w[0].log_size() <= w[1].log_size()));
    }

    #[test]
    fn test_columns_are_in_ascending_order() {
        assert_ascending_log_sizes(&PreProcessedTrace::canonical());
    }

    #[test]
    fn test_canonical_without_pedersen_columns_are_in_ascending_order() {
        assert_ascending_log_sizes(&PreProcessedTrace::canonical_without_pedersen());
    }

    #[test]
    fn test_canonical_column_ids_are_unique() {
        for trace in [
            PreProcessedTrace::canonical(),
            PreProcessedTrace::canonical_without_pedersen(),
        ] {
            // `column_indices` has one entry per distinct id, so a duplicate id would make it
            // smaller than the column list.
            assert_eq!(trace.column_indices.len(), trace.columns.len());
        }
    }

    #[cfg(feature = "prover")]
    #[test]
    fn test_gen_seq() {
        let seq = Seq::new(LOG_SIZE).gen_column_simd();
        for i in 0..(1 << LOG_SIZE) {
            assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
        }
    }

    #[cfg(feature = "prover")]
    #[test]
    fn test_packed_at_seq() {
        let seq = Seq::new(LOG_SIZE);
        let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));
        let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
            seq.packed_at(i).to_array()
        })
        .concat();
        assert_eq!(packed_seq, expected_seq);
    }

    #[cfg(feature = "prover")]
    #[test]
    fn test_range_check_gen_column_simd() {
        let ranges = [3, 1];
        let expected_0 = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7].map(M31);
        let expected_1 = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1].map(M31);

        let col_0 = RangeCheck::new(ranges, 0);
        let col_1 = RangeCheck::new(ranges, 1);
        let col_0_simd = col_0.gen_column_simd().to_cpu().to_vec();
        let col_1_simd = col_1.gen_column_simd().to_cpu().to_vec();

        assert_eq!(col_0_simd, expected_0);
        assert_eq!(col_1_simd, expected_1);
    }

    #[test]
    fn test_range_check_id() {
        let ranges = [1, 2, 3, 4];
        let range_check = RangeCheck::new(ranges, 2);

        let id = range_check.id();

        assert_eq!(id.id, "range_check_1_2_3_4_column_2");
    }
}