extern crate alloc;
use alloc::vec::Vec;
use lib_q_poseidon::{
Poseidon128,
PoseidonField,
PoseidonParams,
};
use lib_q_stark_air::{
AirBuilder,
WindowAccess,
};
use lib_q_stark_field::{
BasedVectorSpace,
Field,
};
use lib_q_stark_mersenne31::Mersenne31;
use super::{
AirError,
poseidon_to_field,
};
/// Borrowed view of the Poseidon tables that every round constraint reads.
///
/// Bundled so the round helpers take one argument instead of three.
struct PoseidonRoundParams<'p> {
    /// Width of the permutation state (number of lanes per round).
    n: usize,
    /// Flat list of additive round constants; each round consumes `n` of them.
    round_constants: &'p [PoseidonField],
    /// MDS matrix, indexed as `mds[output_lane][input_lane]`.
    mds: &'p [Vec<PoseidonField>],
}
/// AIR gadget that emits constraints for a Poseidon permutation laid out
/// across consecutive trace columns.
///
/// The gadget owns its permutation parameters: state width, round counts,
/// round constants and the MDS matrix.
#[derive(Clone, Debug)]
pub struct PoseidonGadget {
    /// Permutation parameters; `new` populates this from `Poseidon128::params()`.
    params: PoseidonParams,
}
impl PoseidonGadget {
pub const COLUMNS_PER_HASH: usize = (8 + 56) * (5 * 3);
pub fn new() -> Self {
Self {
params: Poseidon128::params(),
}
}
pub fn params(&self) -> &PoseidonParams {
&self.params
}
pub fn constrain<AB: AirBuilder>(
&self,
builder: &mut AB,
left: AB::Expr,
right: AB::Expr,
output: AB::Expr,
intermediate_start_col: usize,
) -> Result<(), AirError>
where
AB::F: Field + BasedVectorSpace<Mersenne31>,
{
use lib_q_stark_field::PrimeCharacteristicRing;
let n = self.params.state_width;
let zero_expr = AB::Expr::from(<AB::F as PrimeCharacteristicRing>::ZERO);
let initial_state: Vec<AB::Expr> = (0..n)
.map(|i| {
if i == 0 {
left.clone()
} else if i == 1 {
right.clone()
} else {
zero_expr.clone()
}
})
.collect();
self.constrain_full_state(builder, &initial_state, output, intermediate_start_col)
}
pub fn constrain_full_state<AB: AirBuilder>(
&self,
builder: &mut AB,
initial_state: &[AB::Expr],
output: AB::Expr,
intermediate_start_col: usize,
) -> Result<(), AirError>
where
AB::F: Field + BasedVectorSpace<Mersenne31>,
{
let main = builder.main();
let local = main.current_slice();
let full_rounds = self.params.full_rounds;
let partial_rounds = self.params.partial_rounds;
let full_rounds_half = full_rounds / 2;
let n = self.params.state_width;
if initial_state.len() != n {
return Err(AirError::InvalidDimensions {
reason: alloc::format!(
"initial_state must have {} elements, got {}",
n,
initial_state.len()
),
});
}
let mds = &self.params.mds_matrix;
let round_constants = &self.params.round_constants;
let mut state: Vec<AB::Expr> = initial_state.to_vec();
let mut round_const_idx = 0;
let mut intermediate_col = intermediate_start_col;
let round_params = PoseidonRoundParams {
round_constants,
mds,
n,
};
for _ in 0..full_rounds_half {
state = self.constrain_full_round(
builder,
&state,
&round_params,
&mut round_const_idx,
local,
&mut intermediate_col,
)?;
}
for _ in 0..partial_rounds {
state = self.constrain_partial_round(
builder,
&state,
&round_params,
&mut round_const_idx,
local,
&mut intermediate_col,
)?;
}
for _ in 0..full_rounds_half {
state = self.constrain_full_round(
builder,
&state,
&round_params,
&mut round_const_idx,
local,
&mut intermediate_col,
)?;
}
builder.assert_eq(state[0].clone(), output);
Ok(())
}
fn constrain_full_round<AB: AirBuilder>(
&self,
builder: &mut AB,
state: &[AB::Expr],
params: &PoseidonRoundParams<'_>,
round_const_idx: &mut usize,
local: &[AB::Var],
intermediate_col: &mut usize,
) -> Result<Vec<AB::Expr>, AirError>
where
AB::F: Field + BasedVectorSpace<Mersenne31>,
{
use lib_q_stark_field::PrimeCharacteristicRing;
let zero = AB::Expr::from(<AB::F as PrimeCharacteristicRing>::ZERO);
let n = params.n;
for i in 0..n {
let rc_field =
poseidon_to_field::<AB::F>(¶ms.round_constants[*round_const_idx + i]);
let expected = state[i].clone() + AB::Expr::from(rc_field);
builder.assert_eq(local[*intermediate_col + i].into(), expected);
}
*round_const_idx += n;
*intermediate_col += n;
for i in 0..n {
let arc_val = local[*intermediate_col - n + i];
let arc_sq = arc_val * arc_val;
let arc_quad = arc_sq.clone() * arc_sq.clone();
let expected_sbox = arc_quad * arc_val;
builder.assert_eq(local[*intermediate_col + i].into(), expected_sbox);
}
*intermediate_col += n;
let mut next_state: Vec<AB::Expr> = (0..n).map(|_| zero.clone()).collect();
for i in 0..n {
for j in 0..n {
let sbox_val = local[*intermediate_col - n + j];
let mds_field = poseidon_to_field::<AB::F>(¶ms.mds[i][j]);
next_state[i] = next_state[i].clone() + AB::Expr::from(mds_field) * sbox_val;
}
builder.assert_eq(local[*intermediate_col + i].into(), next_state[i].clone());
}
*intermediate_col += n;
Ok(next_state)
}
fn constrain_partial_round<AB: AirBuilder>(
&self,
builder: &mut AB,
state: &[AB::Expr],
params: &PoseidonRoundParams<'_>,
round_const_idx: &mut usize,
local: &[AB::Var],
intermediate_col: &mut usize,
) -> Result<Vec<AB::Expr>, AirError>
where
AB::F: Field + BasedVectorSpace<Mersenne31>,
{
use lib_q_stark_field::PrimeCharacteristicRing;
let zero = AB::Expr::from(<AB::F as PrimeCharacteristicRing>::ZERO);
let n = params.n;
for i in 0..n {
let rc_field =
poseidon_to_field::<AB::F>(¶ms.round_constants[*round_const_idx + i]);
let expected = state[i].clone() + AB::Expr::from(rc_field);
builder.assert_eq(local[*intermediate_col + i].into(), expected);
}
*round_const_idx += n;
*intermediate_col += n;
let arc_val_0 = local[*intermediate_col - n];
let arc_sq_0 = arc_val_0 * arc_val_0;
let arc_quad_0 = arc_sq_0.clone() * arc_sq_0.clone();
let expected_sbox_0 = arc_quad_0 * arc_val_0;
builder.assert_eq(local[*intermediate_col].into(), expected_sbox_0);
for i in 1..n {
builder.assert_eq(
local[*intermediate_col + i].into(),
local[*intermediate_col - n + i].into(),
);
}
*intermediate_col += n;
let mut next_state: Vec<AB::Expr> = (0..n).map(|_| zero.clone()).collect();
for i in 0..n {
for j in 0..n {
let sbox_val = local[*intermediate_col - n + j];
let mds_field = poseidon_to_field::<AB::F>(¶ms.mds[i][j]);
next_state[i] = next_state[i].clone() + AB::Expr::from(mds_field) * sbox_val;
}
builder.assert_eq(local[*intermediate_col + i].into(), next_state[i].clone());
}
*intermediate_col += n;
Ok(next_state)
}
}
impl Default for PoseidonGadget {
fn default() -> Self {
Self::new()
}
}