use miden_core::{chiplets::hasher::Hasher, field::PrimeCharacteristicRing};
use miden_crypto::stark::air::AirBuilder;
use crate::{
MidenAirBuilder,
constraints::chiplets::columns::{HasherPeriodicCols, PermutationCols},
trace::chiplets::hasher::STATE_WIDTH,
};
/// Enforces the permutation-step transition constraints of the hasher chiplet.
///
/// Every constraint emitted here is multiplied by `perm_gate`, so nothing is enforced on rows
/// where the gate evaluates to zero. On gated rows, the periodic selectors in `periodic` pick
/// which `state -> state_next` transition is checked:
///
/// - `is_init_ext`: the external linear layer applied to the input, followed by one external
///   round using `periodic.ark` (see [`apply_init_plus_ext`]).
/// - `is_ext`: one external round: add `periodic.ark` to every lane, raise every lane to the
///   7th power, then apply the external linear layer.
/// - `is_packed_int`: three internal rounds in a single row. `periodic.ark[0..3]` supply the
///   three internal round constants and `witnesses[0..3]` hold the lane-0 S-box outputs
///   (see [`apply_packed_internals`]).
/// - `is_int_ext`: one internal round using the constant `Hasher::ARK_INT[21]` and witness
///   `witnesses[0]`, followed by one external round using `periodic.ark`
///   (see [`apply_internal_plus_ext`]).
///
/// Witness columns that the active row type does not use are constrained to zero.
///
/// Substituting the witness for lane 0 keeps each S-box input linear in the trace columns. The
/// next internal round therefore starts from a degree-1 expression rather than a degree-7
/// polynomial.
///
/// NOTE(review): the `1 - is_packed_int - is_int_ext` gate, and the dispatch as a whole, assume
/// the periodic selectors are 0/1 and that at most one is set on any row. This function does not
/// establish that invariant. Confirm it against the periodic column definitions.
pub fn enforce_permutation_steps<AB>(
    builder: &mut AB,
    perm_gate: AB::Expr,
    cols: &PermutationCols<AB::Var>,
    cols_next: &PermutationCols<AB::Var>,
    periodic: &HasherPeriodicCols<AB::PeriodicVar>,
) where
    AB: MidenAirBuilder,
{
    // Lift the current/next state, witnesses and periodic values into builder expressions.
    let h: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| cols.state[i].into());
    let h_next: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| cols_next.state[i].into());
    let w: [AB::Expr; 3] = core::array::from_fn(|i| cols.witnesses[i].into());
    let is_init_ext: AB::Expr = periodic.is_init_ext.into();
    let is_ext: AB::Expr = periodic.is_ext.into();
    let is_packed_int: AB::Expr = periodic.is_packed_int.into();
    let is_int_ext: AB::Expr = periodic.is_int_ext.into();
    let ark: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| periodic.ark[i].into());
    // Diagonal of the internal linear layer, taken from the hasher constants.
    let mat_diag: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| Hasher::MAT_DIAG[i].into());
    // On int+ext rows the periodic `ark` lanes feed the external round, so the internal-round
    // constant comes from the fixed table instead.
    // NOTE(review): index 21 is presumably the last internal-round constant; confirm against
    // the length and layout of `Hasher::ARK_INT`.
    let ark_int_21: AB::Expr = Hasher::ARK_INT[21].into();
    // NOTE(review): the order of the assertions below fixes the order of the emitted constraints.
    // If any consumer indexes constraints positionally, do not reorder them. Confirm.
    // witnesses[0] is only used by packed-internal and internal+external rows; zero it elsewhere.
    builder
        .when(perm_gate.clone() * (AB::Expr::ONE - is_packed_int.clone() - is_int_ext.clone()))
        .assert_zero(w[0].clone());
    {
        // witnesses[1] and witnesses[2] are only used by packed-internal rows.
        let builder =
            &mut builder.when(perm_gate.clone() * (AB::Expr::ONE - is_packed_int.clone()));
        builder.assert_zero(w[1].clone());
        builder.assert_zero(w[2].clone());
    }
    {
        // Initial external linear layer + first external round.
        let expected = apply_init_plus_ext(&h, &ark);
        let builder = &mut builder.when(perm_gate.clone() * is_init_ext);
        for i in 0..STATE_WIDTH {
            builder.assert_eq(h_next[i].clone(), expected[i].clone());
        }
    }
    {
        // Single external round: add constants, x^7 on every lane, external linear layer.
        let ext_with_rc: [AB::Expr; STATE_WIDTH] =
            core::array::from_fn(|i| h[i].clone() + ark[i].clone());
        let ext_with_sbox: [AB::Expr; STATE_WIDTH] =
            core::array::from_fn(|i| ext_with_rc[i].clone().exp_const_u64::<7>());
        let expected = apply_matmul_external(&ext_with_sbox);
        let builder = &mut builder.when(perm_gate.clone() * is_ext);
        for i in 0..STATE_WIDTH {
            builder.assert_eq(h_next[i].clone(), expected[i].clone());
        }
    }
    {
        // Three packed internal rounds. Lanes 0..3 of the periodic `ark` carry the internal
        // round constants on these rows.
        let ark_int_3: [AB::Expr; 3] = core::array::from_fn(|i| ark[i].clone());
        let (expected, witness_checks) = apply_packed_internals(&h, &w, &ark_int_3, &mat_diag);
        let builder = &mut builder.when(perm_gate.clone() * is_packed_int);
        // Each witness must equal the S-box output of its round's lane-0 input.
        for wc in &witness_checks {
            builder.assert_zero(wc.clone());
        }
        for i in 0..STATE_WIDTH {
            builder.assert_eq(h_next[i].clone(), expected[i].clone());
        }
    }
    {
        // One internal round (witness in witnesses[0]) followed by one external round.
        let (expected, witness_check) =
            apply_internal_plus_ext(&h, &w[0], ark_int_21, &ark, &mat_diag);
        let builder = &mut builder.when(perm_gate * is_int_ext);
        builder.assert_zero(witness_check);
        for i in 0..STATE_WIDTH {
            builder.assert_eq(h_next[i].clone(), expected[i].clone());
        }
    }
}
/// Applies the external linear layer to a 12-lane state.
///
/// The state is split into three 4-lane blocks, and each block is multiplied by the 4x4 matrix
/// from [`matmul_m4`]. Every output lane then also receives the sum of the same lane position
/// across all three transformed blocks.
fn apply_matmul_external<E: PrimeCharacteristicRing>(state: &[E; STATE_WIDTH]) -> [E; STATE_WIDTH] {
    // The block decomposition below is hard-wired for three 4-lane blocks.
    const _: () = assert!(STATE_WIDTH == 12);

    // Transform each 4-lane block independently.
    let blocks: [[E; 4]; 3] = core::array::from_fn(|blk| {
        matmul_m4(core::array::from_fn(|j| state[4 * blk + j].clone()))
    });

    // Per-position sums across the transformed blocks.
    let lane_sums: [E; 4] = core::array::from_fn(|pos| {
        blocks[0][pos].clone() + blocks[1][pos].clone() + blocks[2][pos].clone()
    });

    core::array::from_fn(|i| blocks[i / 4][i % 4].clone() + lane_sums[i % 4].clone())
}
/// Multiplies a 4-element vector by the fixed matrix
///
/// ```text
/// [2 3 1 1]
/// [1 2 3 1]
/// [1 1 2 3]
/// [3 1 1 2]
/// ```
///
/// using only additions and doublings, sharing partial sums between rows.
fn matmul_m4<E: PrimeCharacteristicRing>(input: [E; 4]) -> [E; 4] {
    let [x0, x1, x2, x3] = input;

    let s01 = x0.clone() + x1.clone();
    let s23 = x2.clone() + x3.clone();
    let total = s01.clone() + s23.clone();

    // x0 + 2*x1 + x2 + x3, shared by rows 0 and 1.
    let total_plus_x1 = total.clone() + x1;
    // x0 + x1 + x2 + 2*x3, shared by rows 2 and 3.
    let total_plus_x3 = total + x3;

    [
        total_plus_x1.clone() + s01, // 2*x0 + 3*x1 + x2 + x3
        total_plus_x1 + x2.double(), // x0 + 2*x1 + 3*x2 + x3
        total_plus_x3.clone() + s23, // x0 + x1 + 2*x2 + 3*x3
        total_plus_x3 + x0.double(), // 3*x0 + x1 + x2 + 2*x3
    ]
}
/// Applies the internal linear layer: `out[i] = state[i] * mat_diag[i] + sum(state)`.
fn apply_matmul_internal<E: PrimeCharacteristicRing>(
    state: &[E; STATE_WIDTH],
    mat_diag: &[E; STATE_WIDTH],
) -> [E; STATE_WIDTH] {
    // Every output lane receives the sum of all input lanes.
    let total = E::sum_array::<STATE_WIDTH>(state);

    let mut out = state.clone();
    for (lane, diag) in out.iter_mut().zip(mat_diag) {
        *lane = lane.clone() * diag.clone() + total.clone();
    }
    out
}
/// Applies the external linear layer to `h`, then one external round with constants `ark_ext`.
///
/// The external round adds `ark_ext[i]` to lane `i`, raises every lane to the 7th power, and
/// applies the external linear layer again.
pub fn apply_init_plus_ext<E: PrimeCharacteristicRing>(
    h: &[E; STATE_WIDTH],
    ark_ext: &[E; STATE_WIDTH],
) -> [E; STATE_WIDTH] {
    let mut state = apply_matmul_external(h);

    // Round-constant addition and S-box, fused into one pass over the lanes.
    for (lane, rc) in state.iter_mut().zip(ark_ext) {
        *lane = (lane.clone() + rc.clone()).exp_const_u64::<7>();
    }

    apply_matmul_external(&state)
}
/// Applies three consecutive internal rounds to `h`, using `w` as the lane-0 S-box outputs.
///
/// For round `k`, the S-box input is `state[0] + ark_int[k]`. Instead of substituting the 7th
/// power, the round places the witness `w[k]` into lane 0 and then applies the internal linear
/// layer.
///
/// Returns the resulting state and, for each round, `w[k] - (state[0] + ark_int[k])^7`. The
/// caller constrains these to zero so that each witness is the true S-box output.
pub fn apply_packed_internals<E: PrimeCharacteristicRing>(
    h: &[E; STATE_WIDTH],
    w: &[E; 3],
    ark_int: &[E; 3],
    mat_diag: &[E; STATE_WIDTH],
) -> ([E; STATE_WIDTH], [E; 3]) {
    let mut current = h.clone();
    let mut checks: [E; 3] = core::array::from_fn(|_| E::ZERO);

    for ((check, witness), rc) in checks.iter_mut().zip(w).zip(ark_int) {
        // The witness must equal the S-box applied to this round's lane-0 input.
        *check = witness.clone() - (current[0].clone() + rc.clone()).exp_const_u64::<7>();
        // Continue from the witness, keeping the next round's input linear.
        current[0] = witness.clone();
        current = apply_matmul_internal(&current, mat_diag);
    }

    (current, checks)
}
/// Applies one internal round followed by one external round.
///
/// The internal round uses constant `ark_int_const` and takes `w0` as the S-box output for
/// lane 0. The external round adds `ark_ext`, raises every lane to the 7th power, and applies
/// the external linear layer.
///
/// Returns the resulting state together with `w0 - (h[0] + ark_int_const)^7`, which the caller
/// constrains to zero.
pub fn apply_internal_plus_ext<E: PrimeCharacteristicRing>(
    h: &[E; STATE_WIDTH],
    w0: &E,
    ark_int_const: E,
    ark_ext: &[E; STATE_WIDTH],
    mat_diag: &[E; STATE_WIDTH],
) -> ([E; STATE_WIDTH], E) {
    // Internal round: only lane 0 goes through the S-box, and its output is the witness.
    let witness_check = w0.clone() - (h[0].clone() + ark_int_const).exp_const_u64::<7>();
    let mut state = h.clone();
    state[0] = w0.clone();
    let mut state = apply_matmul_internal(&state, mat_diag);

    // External round: constants and S-box on every lane, then the external linear layer.
    for (lane, rc) in state.iter_mut().zip(ark_ext) {
        *lane = (lane.clone() + rc.clone()).exp_const_u64::<7>();
    }

    (apply_matmul_external(&state), witness_check)
}