use stwo::core::fields::m31::M31;
use stwo_constraint_framework::preprocessed_columns::PreProcessedColumnId;
use super::poseidon_round_keys::round_keys;
use super::preprocessed_trace::PreProcessedColumn;
#[cfg(feature = "prover")]
use super::preprocessed_utils::pad;
#[cfg(feature = "prover")]
use super::simd_prelude::*;
use crate::prover_types::cpu::FELT252WIDTH27_N_WORDS;
const LOG_N_ROWS: u32 = (N_ROUNDS as u32).next_power_of_two().ilog2();
#[cfg(feature = "prover")]
const N_PACKED_ROWS: usize = (2_u32.pow(LOG_N_ROWS)) as usize / N_LANES;
pub const N_ROUNDS: usize = 35;
pub const N_FELT252WIDTH27: usize = 3;
pub const N_WORDS: usize = FELT252WIDTH27_N_WORDS * N_FELT252WIDTH27;
pub fn round_keys_m31(round: usize, col: usize) -> M31 {
assert!(col < N_WORDS);
assert!(round < N_ROUNDS);
let felt252_index = col / FELT252WIDTH27_N_WORDS;
let m31_index = col % FELT252WIDTH27_N_WORDS;
round_keys(round)[felt252_index].get_m31(m31_index)
}
#[derive(Debug)]
pub struct PoseidonRoundKeys {
#[cfg(feature = "prover")]
pub packed_keys: [PackedM31; N_PACKED_ROWS],
pub col: usize,
}
impl PoseidonRoundKeys {
pub fn new(col: usize) -> Self {
#[cfg(feature = "prover")]
let packed_keys = BaseColumn::from_iter(pad(round_keys_m31, N_ROUNDS, col)).data;
Self {
#[cfg(feature = "prover")]
packed_keys: packed_keys.try_into().unwrap(),
col,
}
}
}
impl PreProcessedColumn for PoseidonRoundKeys {
fn log_size(&self) -> u32 {
LOG_N_ROWS
}
#[cfg(feature = "prover")]
fn packed_at(&self, vec_row: usize) -> PackedM31 {
self.packed_keys[vec_row]
}
#[cfg(feature = "prover")]
fn gen_column_simd(&self) -> CircleEvaluation<SimdBackend, BaseField, BitReversedOrder> {
CircleEvaluation::new(
CanonicCoset::new(LOG_N_ROWS).circle_domain(),
BaseColumn::from_simd(self.packed_keys.to_vec()),
)
}
fn id(&self) -> PreProcessedColumnId {
PreProcessedColumnId {
id: format!("poseidon_round_keys_{}", self.col),
}
}
}
#[cfg(feature = "prover")]
#[cfg(test)]
mod tests {
use std::array::from_fn;
use stwo::prover::backend::simd::m31::N_LANES;
use super::*;
use crate::prover_types::cpu::Felt252Width27;
#[test]
fn test_packed_at_round_keys() {
for vec_row in 0..N_PACKED_ROWS {
for i in 0..N_FELT252WIDTH27 {
let packed: [[M31; N_LANES]; FELT252WIDTH27_N_WORDS] = from_fn(|c| {
PoseidonRoundKeys::new((i * FELT252WIDTH27_N_WORDS) + c)
.packed_at(vec_row)
.to_array()
});
for row_in_packed in 0..N_LANES {
let felt_limbs: [M31; FELT252WIDTH27_N_WORDS] = packed
.iter()
.map(|arr| arr[row_in_packed])
.collect::<Vec<_>>()
.try_into()
.unwrap();
let row = (vec_row * N_LANES) + row_in_packed;
if row < N_ROUNDS {
assert_eq!(Felt252Width27::from_limbs(&felt_limbs), round_keys(row)[i]);
} else {
assert_eq!(Felt252Width27::from_limbs(&felt_limbs), round_keys(0)[i]);
}
}
}
}
}
}