use std::collections::HashMap;
#[cfg(feature = "prover")]
use std::iter::zip;
use itertools::{chain, Itertools};
use serde::{Deserialize, Serialize};
use stwo_constraint_framework::preprocessed_columns::PreProcessedColumnId;
use super::bitwise_xor::BitwiseXor;
use super::blake::{BlakeSigma, N_BLAKE_SIGMA_COLS};
use super::pedersen::{PedersenPoints, PEDERSEN_TABLE_N_COLUMNS};
use super::poseidon::{PoseidonRoundKeys, N_WORDS as POSEIDON_N_WORDS};
#[cfg(feature = "prover")]
use super::simd_prelude::*;
/// Bit widths for which bitwise-XOR lookup tables are generated (3 columns each).
const XOR_N_BITS: [u32; 5] = [4, 7, 8, 9, 10];
/// Largest `Seq` column log-size in the canonical trace.
pub const MAX_SEQUENCE_LOG_SIZE: u32 = 25;
/// Smallest `Seq` column log-size in every variant.
pub const MIN_SEQUENCE_LOG_SIZE: u32 = 4;
/// Largest `Seq` column log-size in the small variant.
pub const SMALL_MAX_SEQUENCE_LOG_SIZE: u32 = 20;
/// Expected total cell count (sum of `2^log_size` over all columns) of the
/// canonical trace; asserted when the columns are built.
pub const CANONICAL_SIZE: u32 = 543100528;
/// Expected total cell count of the canonical trace without the Pedersen table.
pub const CANONICAL_WITHOUT_PEDERSEN_SIZE: u32 = 73338480;
/// Expected total cell count of the small variant.
pub const CANONICAL_SMALL: u32 = 10161776;
/// Selects which set of preprocessed columns to build.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum PreProcessedTraceVariant {
    /// All columns, including the full (log-size 18) Pedersen points table.
    Canonical,
    /// All columns except the Pedersen points table.
    CanonicalWithoutPedersen,
    /// Reduced variant: `Seq` columns capped at `SMALL_MAX_SEQUENCE_LOG_SIZE`
    /// and a smaller (log-size 9) Pedersen points table.
    CanonicalSmall,
}
impl PreProcessedTraceVariant {
    /// Builds the preprocessed trace corresponding to this variant.
    pub fn to_preprocessed_trace(&self) -> PreProcessedTrace {
        match self {
            Self::Canonical => PreProcessedTrace::canonical(),
            Self::CanonicalWithoutPedersen => PreProcessedTrace::canonical_without_pedersen(),
            Self::CanonicalSmall => PreProcessedTrace::canonical_small(),
        }
    }
}
/// A column of the preprocessed (constant) trace.
///
/// Implementors expose their log-size and a unique id; with the `prover`
/// feature enabled they can also materialize the column as SIMD data.
pub trait PreProcessedColumn: Send + Sync {
    /// Returns the `vec_row`-th packed (SIMD) vector of the column.
    #[cfg(feature = "prover")]
    fn packed_at(&self, vec_row: usize) -> PackedM31;
    /// Log2 of the number of rows in the column.
    fn log_size(&self) -> u32;
    /// Unique identifier of the column.
    fn id(&self) -> PreProcessedColumnId;
    /// Generates the full column as a circle evaluation over the SIMD backend.
    #[cfg(feature = "prover")]
    fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>;
}
/// The full preprocessed trace: a set of constant columns plus an index from
/// column id to position.
pub struct PreProcessedTrace {
    // Columns, sorted by ascending `log_size` (see the `canonical*` builders).
    pub columns: Vec<Box<dyn PreProcessedColumn>>,
    // Maps each column's id to its index in `columns`.
    pub column_indices: HashMap<PreProcessedColumnId, usize>,
    // The variant this trace was built from.
    pub variant: PreProcessedTraceVariant,
}
impl PreProcessedTrace {
    /// Maps each column's id to its index in `columns`.
    fn get_column_indices(
        columns: &[Box<dyn PreProcessedColumn>],
    ) -> HashMap<PreProcessedColumnId, usize> {
        columns
            .iter()
            .enumerate()
            .map(|(i, column)| (column.id(), i))
            .collect()
    }

    /// Asserts that the total cell count (sum of `2^log_size` per column)
    /// matches `expected`. `name` is used in the panic message.
    fn assert_size(columns: &[Box<dyn PreProcessedColumn>], expected: u32, name: &str) {
        assert_eq!(
            columns.iter().map(|col| 1 << col.log_size()).sum::<u32>(),
            expected,
            "{name} preprocessed trace has unexpected size"
        );
    }

    /// Columns common to every variant: bitwise XOR tables, range checks,
    /// Poseidon round keys and Blake sigma constants. (`Seq` and the Pedersen
    /// points table are variant-specific and are added by the callers.)
    fn shared_columns() -> impl Iterator<Item = Box<dyn PreProcessedColumn>> {
        let bitwise_xor = XOR_N_BITS.into_iter().flat_map(|n_bits| {
            (0..3).map(move |col_index| {
                Box::new(BitwiseXor::new(n_bits, col_index)) as Box<dyn PreProcessedColumn>
            })
        });
        let range_check = gen_range_check_columns();
        let poseidon_keys = (0..POSEIDON_N_WORDS)
            .map(|x| Box::new(PoseidonRoundKeys::new(x)) as Box<dyn PreProcessedColumn>);
        let blake_sigma = (0..N_BLAKE_SIGMA_COLS)
            .map(|x| Box::new(BlakeSigma::new(x)) as Box<dyn PreProcessedColumn>);
        chain!(bitwise_xor, range_check, poseidon_keys, blake_sigma)
    }

    /// Builds the full canonical trace (all columns, including Pedersen).
    pub fn canonical() -> Self {
        let columns = Self::canonical_columns();
        let column_indices = Self::get_column_indices(&columns);
        Self {
            columns,
            column_indices,
            variant: PreProcessedTraceVariant::Canonical,
        }
    }

    fn canonical_columns() -> Vec<Box<dyn PreProcessedColumn>> {
        let canonical_without_pedersen = Self::canonical_without_pedersen_columns();
        let pedersen_points = (0..PEDERSEN_TABLE_N_COLUMNS)
            .map(|x| Box::new(PedersenPoints::<18>::new(x)) as Box<dyn PreProcessedColumn>);
        // Stable sort by log_size keeps the chain order among equal sizes.
        let columns = chain!(canonical_without_pedersen, pedersen_points)
            .sorted_by_key(|column| column.log_size())
            .collect_vec();
        Self::assert_size(&columns, CANONICAL_SIZE, "Canonical");
        columns
    }

    /// Builds the canonical trace without the Pedersen points table.
    pub fn canonical_without_pedersen() -> Self {
        let columns = Self::canonical_without_pedersen_columns();
        let column_indices = Self::get_column_indices(&columns);
        Self {
            columns,
            column_indices,
            variant: PreProcessedTraceVariant::CanonicalWithoutPedersen,
        }
    }

    fn canonical_without_pedersen_columns() -> Vec<Box<dyn PreProcessedColumn>> {
        let seq = (MIN_SEQUENCE_LOG_SIZE..=MAX_SEQUENCE_LOG_SIZE)
            .map(|x| Box::new(Seq::new(x)) as Box<dyn PreProcessedColumn>);
        let columns = chain!(seq, Self::shared_columns())
            .sorted_by_key(|column| column.log_size())
            .collect_vec();
        Self::assert_size(
            &columns,
            CANONICAL_WITHOUT_PEDERSEN_SIZE,
            "Canonical without pedersen",
        );
        columns
    }

    /// Builds the small variant (capped `Seq` sizes, log-size-9 Pedersen table).
    pub fn canonical_small() -> Self {
        let columns = Self::canonical_small_columns();
        let column_indices = Self::get_column_indices(&columns);
        Self {
            columns,
            column_indices,
            variant: PreProcessedTraceVariant::CanonicalSmall,
        }
    }

    fn canonical_small_columns() -> Vec<Box<dyn PreProcessedColumn>> {
        let seq = (MIN_SEQUENCE_LOG_SIZE..=SMALL_MAX_SEQUENCE_LOG_SIZE)
            .map(|x| Box::new(Seq::new(x)) as Box<dyn PreProcessedColumn>);
        let pedersen_points = (0..PEDERSEN_TABLE_N_COLUMNS)
            .map(|x| Box::new(PedersenPoints::<9>::new(x)) as Box<dyn PreProcessedColumn>);
        let columns = chain!(seq, Self::shared_columns(), pedersen_points)
            .sorted_by_key(|column| column.log_size())
            .collect_vec();
        Self::assert_size(&columns, CANONICAL_SMALL, "Canonical small");
        columns
    }

    /// Log-sizes of all columns, in column order.
    pub fn log_sizes(&self) -> Vec<u32> {
        self.columns.iter().map(|c| c.log_size()).collect()
    }

    /// Returns the column with the given id.
    ///
    /// # Panics
    /// If no column with this id exists in the trace.
    pub fn get_column(&self, id: &PreProcessedColumnId) -> &dyn PreProcessedColumn {
        self.columns[*self
            .column_indices
            .get(id)
            .unwrap_or_else(|| panic!("Missing preprocessed column {id:?}"))]
        .as_ref()
    }

    /// Whether a column with the given id exists in the trace.
    pub fn has_column(&self, id: &PreProcessedColumnId) -> bool {
        self.column_indices.contains_key(id)
    }

    /// Ids of all columns, in column order.
    pub fn ids(&self) -> Vec<PreProcessedColumnId> {
        self.columns.iter().map(|c| c.id()).collect()
    }
}
fn gen_range_check_columns() -> Vec<Box<dyn PreProcessedColumn>> {
let range_check_4_3_col_0 = RangeCheck::new([4, 3], 0);
let range_check_4_3_col_1 = RangeCheck::new([4, 3], 1);
let range_check_4_4_col_0 = RangeCheck::new([4, 4], 0);
let range_check_4_4_col_1 = RangeCheck::new([4, 4], 1);
let range_check_9_9_col_0 = RangeCheck::new([9, 9], 0);
let range_check_9_9_col_1 = RangeCheck::new([9, 9], 1);
let range_check_7_2_5_col_0 = RangeCheck::new([7, 2, 5], 0);
let range_check_7_2_5_col_1 = RangeCheck::new([7, 2, 5], 1);
let range_check_7_2_5_col_2 = RangeCheck::new([7, 2, 5], 2);
let range_check_3_6_6_3_col_0 = RangeCheck::new([3, 6, 6, 3], 0);
let range_check_3_6_6_3_col_1 = RangeCheck::new([3, 6, 6, 3], 1);
let range_check_3_6_6_3_col_2 = RangeCheck::new([3, 6, 6, 3], 2);
let range_check_3_6_6_3_col_3 = RangeCheck::new([3, 6, 6, 3], 3);
let range_check_4_4_4_4_col_0 = RangeCheck::new([4, 4, 4, 4], 0);
let range_check_4_4_4_4_col_1 = RangeCheck::new([4, 4, 4, 4], 1);
let range_check_4_4_4_4_col_2 = RangeCheck::new([4, 4, 4, 4], 2);
let range_check_4_4_4_4_col_3 = RangeCheck::new([4, 4, 4, 4], 3);
let range_check_3_3_3_3_3_col_0 = RangeCheck::new([3, 3, 3, 3, 3], 0);
let range_check_3_3_3_3_3_col_1 = RangeCheck::new([3, 3, 3, 3, 3], 1);
let range_check_3_3_3_3_3_col_2 = RangeCheck::new([3, 3, 3, 3, 3], 2);
let range_check_3_3_3_3_3_col_3 = RangeCheck::new([3, 3, 3, 3, 3], 3);
let range_check_3_3_3_3_3_col_4 = RangeCheck::new([3, 3, 3, 3, 3], 4);
vec![
Box::new(range_check_4_3_col_0),
Box::new(range_check_4_3_col_1),
Box::new(range_check_4_4_col_0),
Box::new(range_check_4_4_col_1),
Box::new(range_check_9_9_col_0),
Box::new(range_check_9_9_col_1),
Box::new(range_check_7_2_5_col_0),
Box::new(range_check_7_2_5_col_1),
Box::new(range_check_7_2_5_col_2),
Box::new(range_check_3_6_6_3_col_0),
Box::new(range_check_3_6_6_3_col_1),
Box::new(range_check_3_6_6_3_col_2),
Box::new(range_check_3_6_6_3_col_3),
Box::new(range_check_4_4_4_4_col_0),
Box::new(range_check_4_4_4_4_col_1),
Box::new(range_check_4_4_4_4_col_2),
Box::new(range_check_4_4_4_4_col_3),
Box::new(range_check_3_3_3_3_3_col_0),
Box::new(range_check_3_3_3_3_3_col_1),
Box::new(range_check_3_3_3_3_3_col_2),
Box::new(range_check_3_3_3_3_3_col_3),
Box::new(range_check_3_3_3_3_3_col_4),
]
}
/// The sequence column `0, 1, 2, ..., 2^log_size - 1`.
#[derive(Debug, Clone)]
pub struct Seq {
    // Log2 of the column length.
    pub log_size: u32,
}
impl Seq {
    /// Creates a sequence column of length `2^log_size`.
    pub const fn new(log_size: u32) -> Self {
        Self { log_size }
    }
}
impl PreProcessedColumn for Seq {
    fn log_size(&self) -> u32 {
        self.log_size
    }
    /// Returns lanes `[vec_row * N_LANES, (vec_row + 1) * N_LANES)` of the sequence.
    #[cfg(feature = "prover")]
    fn packed_at(&self, vec_row: usize) -> PackedM31 {
        // SAFETY assumption: `SIMD_ENUMERATION_0` holds the lane enumeration
        // `0..N_LANES`, all valid M31 values — TODO confirm at its definition.
        PackedM31::broadcast(M31::from(vec_row * N_LANES))
            + unsafe { PackedM31::from_simd_unchecked(SIMD_ENUMERATION_0) }
    }
    /// Materializes the sequence `0..2^log_size` as a circle evaluation.
    #[cfg(feature = "prover")]
    fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
        let col = Col::<SimdBackend, BaseField>::from_iter(
            (0..(1 << self.log_size)).map(BaseField::from),
        );
        CircleEvaluation::new(CanonicCoset::new(self.log_size).circle_domain(), col)
    }
    fn id(&self) -> PreProcessedColumnId {
        PreProcessedColumnId {
            id: format!("seq_{}", self.log_size),
        }
    }
}
/// Splits each SIMD lane of `value` into `N` bit segments.
///
/// `n_bits_per_segment[0]` is the width of the most-significant segment, and
/// the returned array is in the same (most-significant-first) order.
#[cfg(feature = "prover")]
pub fn partition_into_bit_segments<const N: usize>(
    mut value: Simd<u32, N_LANES>,
    n_bits_per_segment: [u32; N],
) -> [Simd<u32, N_LANES>; N] {
    let mut segments = [Simd::splat(0); N];
    // Walk the segments from last (least significant) to first, peeling the
    // low `segment_n_bits` bits off `value` at each step.
    for (segment, segment_n_bits) in zip(&mut segments, n_bits_per_segment).rev() {
        let mask = Simd::splat((1 << segment_n_bits) - 1);
        *segment = value & mask;
        value >>= segment_n_bits;
    }
    segments
}
/// Generates the columns of the enumeration `0..2^sum_bits`, where each value
/// is split into `N` bit segments (most-significant segment first).
///
/// Returns one vector of packed values per segment, each of length
/// `2^(sum_bits - LOG_N_LANES)`.
///
/// # Panics
/// If the total bit count does not fit in an M31, or is smaller than
/// `LOG_N_LANES` (the enumeration is generated one SIMD vector at a time).
#[cfg(feature = "prover")]
pub fn generate_partitioned_enumeration<const N: usize>(
    n_bits_per_segment: [u32; N],
) -> [Vec<PackedM31>; N] {
    let sum_bits = n_bits_per_segment.iter().sum::<u32>();
    assert!(sum_bits < MODULUS_BITS);
    // Guard the subtraction below against underflow.
    assert!(sum_bits >= LOG_N_LANES);
    let mut res = std::array::from_fn(|_| vec![]);
    for vec_row in 0..1 << (sum_bits - LOG_N_LANES) {
        let value = SIMD_ENUMERATION_0 + Simd::splat(vec_row * N_LANES as u32);
        let segments = partition_into_bit_segments(value, n_bits_per_segment);
        for i in 0..N {
            // SAFETY: `sum_bits < MODULUS_BITS`, so every segment value is a
            // valid M31.
            res[i].push(unsafe { PackedM31::from_simd_unchecked(segments[i]) });
        }
    }
    res
}
/// One column of a range-check table: the enumeration of all `sum(ranges)`-bit
/// values, partitioned into `N` bit segments, of which this column holds
/// segment `column_idx`.
pub struct RangeCheck<const N: usize> {
    // Bit width of each segment, most significant first.
    ranges: [u32; N],
    // Which segment this column represents (< N).
    column_idx: usize,
}
impl<const N: usize> RangeCheck<N> {
    /// Creates the `column_idx`-th column of the partition described by `ranges`.
    ///
    /// # Panics
    /// If any segment width is zero, or `column_idx >= N`.
    pub fn new(ranges: [u32; N], column_idx: usize) -> Self {
        assert!(ranges.iter().all(|&r| r > 0));
        assert!(column_idx < N);
        Self { ranges, column_idx }
    }
}
impl<const N: usize> PreProcessedColumn for RangeCheck<N> {
    /// One row per value of `sum(ranges)` bits.
    fn log_size(&self) -> u32 {
        self.ranges.iter().sum()
    }
    /// Returns the `vec_row`-th SIMD vector: the bits of the enumeration value
    /// that belong to segment `column_idx`.
    #[cfg(feature = "prover")]
    fn packed_at(&self, vec_row: usize) -> PackedM31 {
        // Number of bits to the right of (less significant than) this segment.
        let shift: u32 = self.ranges[(self.column_idx + 1)..].iter().sum();
        let mask = Simd::splat((1 << self.ranges[self.column_idx]) - 1);
        let simd_result =
            ((SIMD_ENUMERATION_0 + Simd::splat((vec_row * N_LANES) as u32)) >> shift) & mask;
        // SAFETY: the masked value has at most `ranges[column_idx]` bits —
        // assumed to be a valid M31.
        unsafe { PackedM31::from_simd_unchecked(simd_result) }
    }
    #[cfg(feature = "prover")]
    fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
        let partitions = generate_partitioned_enumeration(self.ranges);
        let column = partitions.into_iter().nth(self.column_idx).unwrap();
        CircleEvaluation::new(
            CanonicCoset::new(self.log_size()).circle_domain(),
            BaseColumn::from_simd(column),
        )
    }
    fn id(&self) -> PreProcessedColumnId {
        let ranges = self.ranges.iter().join("_");
        PreProcessedColumnId {
            // NOTE: removed a redundant `.to_string()` — `format!` already
            // yields a `String`.
            id: format!("range_check_{}_column_{}", ranges, self.column_idx),
        }
    }
}
pub fn testing_preprocessed_tree(max_log_size: u32) -> PreProcessedTrace {
let canonical = PreProcessedTrace::canonical_columns();
let columns: Vec<Box<dyn PreProcessedColumn>> = canonical
.into_iter()
.filter(|c| c.log_size() <= max_log_size)
.collect();
let column_indices = PreProcessedTrace::get_column_indices(&columns);
PreProcessedTrace {
columns,
column_indices,
variant: PreProcessedTraceVariant::Canonical,
}
}
#[cfg(test)]
pub mod tests {
    use super::*;
    const LOG_SIZE: u32 = 8;
    use stwo::prover::backend::Column;
    /// The canonical builders sort columns by `log_size`; verify the order.
    #[test]
    fn test_columns_are_in_ascending_order() {
        let preprocessed_trace = PreProcessedTrace::canonical();
        let columns = preprocessed_trace.columns;
        assert!(columns
            .windows(2)
            .all(|w| w[0].log_size() <= w[1].log_size()));
    }
    /// `Seq::gen_column_simd` must produce the identity sequence 0..2^LOG_SIZE.
    #[cfg(feature = "prover")]
    #[test]
    fn test_gen_seq() {
        let seq = Seq::new(LOG_SIZE).gen_column_simd();
        for i in 0..(1 << LOG_SIZE) {
            assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
        }
    }
    /// `Seq::packed_at`, concatenated over all vec rows, must equal the sequence.
    #[cfg(feature = "prover")]
    #[test]
    fn test_packed_at_seq() {
        let seq = Seq::new(LOG_SIZE);
        let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));
        let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
            seq.packed_at(i).to_array()
        })
        .concat();
        assert_eq!(packed_seq, expected_seq);
    }
    /// For the [3, 1] partition: column 0 holds the high 3 bits, column 1 the
    /// low bit, over the 4-bit enumeration 0..16.
    #[cfg(feature = "prover")]
    #[test]
    fn test_range_check_gen_column_simd() {
        let ranges = [3, 1];
        let expected_0 = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7].map(M31);
        let expected_1 = [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1].map(M31);
        let col_0 = RangeCheck::new(ranges, 0);
        let col_1 = RangeCheck::new(ranges, 1);
        let col_0_simd = col_0.gen_column_simd().to_cpu().to_vec();
        let col_1_simd = col_1.gen_column_simd().to_cpu().to_vec();
        assert_eq!(col_0_simd, expected_0);
        assert_eq!(col_1_simd, expected_1);
    }
    /// Column ids embed the ranges and the column index.
    #[test]
    fn test_range_check_id() {
        let ranges = [1, 2, 3, 4];
        let range_check = RangeCheck::new(ranges, 2);
        let id = range_check.id();
        assert_eq!(id.id, "range_check_1_2_3_4_column_2");
    }
}