miden-air 0.23.0

Algebraic intermediate representation of Miden VM processor
Documentation
//! Permutation chiplet state transition constraints.
//!
//! This module enforces the Poseidon2 permutation constraints for the permutation sub-chiplet.
//! The permutation operates on a 16-row cycle with five types of steps:
//!
//! - **Row 0 (init+ext1)**: Merged init linear layer + first external round
//! - **Rows 1-3, 12-14 (external)**: Single external round: add RCs, S-box^7, M_E
//! - **Rows 4-10 (packed internal)**: 3 internal rounds packed per row using witnesses as S-box
//!   outputs
//! - **Row 11 (int+ext)**: Last internal round + first trailing external round
//! - **Row 15 (boundary)**: No step constraint (cycle boundary, final permutation state)
//!
//! ## Poseidon2 Parameters
//!
//! - State width: 12 field elements
//! - External rounds: 8 (4 initial + 4 terminal)
//! - Internal rounds: 22
//! - S-box: x^7

use miden_core::{chiplets::hasher::Hasher, field::PrimeCharacteristicRing};
use miden_crypto::stark::air::AirBuilder;

use crate::{
    MidenAirBuilder,
    constraints::chiplets::columns::{HasherPeriodicCols, PermutationCols},
    trace::chiplets::hasher::STATE_WIDTH,
};

// CONSTRAINT HELPERS
// ================================================================================================

/// Emits the Poseidon2 transition constraints for the 16-row packed permutation cycle.
///
/// Every constraint is multiplied by `perm_gate` (expected to be the caller's
/// "this row belongs to a permutation segment" flag) and by a periodic step-type
/// selector, so at most one step relation is active on any given row.
///
/// ## Step relations
///
/// - **Init+ext1 (row 0)**: `h' = M_E(S(M_E(h) + ark))`, gated degree 9.
/// - **Single external (rows 1-3, 12-14)**: `h' = M_E(S(h + ark))`, gated degree 9.
/// - **Packed 3x internal (rows 4-10)**: three witness equations (gated degree 9) plus an affine
///   next-state relation (gated degree 3).
/// - **Int+ext (row 11)**: one witness equation and `h' = M_E(S(y + ark))`, both gated degree 9.
/// - **Boundary (row 15)**: no next-state relation is emitted.
///
/// The witness columns `w[0]`, `w[1]`, `w[2]` carry the S-box outputs of the internal rounds
/// processed on a row; witnesses that a row does not use are constrained to zero.
///
/// ## Parameters
///
/// - `builder`: constraint sink.
/// - `perm_gate`: multiplicative gate applied to every constraint emitted here.
/// - `cols` / `cols_next`: permutation columns of the current and next row.
/// - `periodic`: periodic columns supplying the step selectors and round constants.
pub fn enforce_permutation_steps<AB>(
    builder: &mut AB,
    perm_gate: AB::Expr,
    cols: &PermutationCols<AB::Var>,
    cols_next: &PermutationCols<AB::Var>,
    periodic: &HasherPeriodicCols<AB::PeriodicVar>,
) where
    AB: MidenAirBuilder,
{
    // Lift the trace cells read by this function into expressions.
    let state: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| cols.state[i].into());
    let next_state: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| cols_next.state[i].into());
    let witnesses: [AB::Expr; 3] = core::array::from_fn(|i| cols.witnesses[i].into());

    // Periodic step-type selectors.
    let sel_init_ext: AB::Expr = periodic.is_init_ext.into();
    let sel_ext: AB::Expr = periodic.is_ext.into();
    let sel_packed_int: AB::Expr = periodic.is_packed_int.into();
    let sel_int_ext: AB::Expr = periodic.is_int_ext.into();

    // Round constants shared by all step types.
    let round_consts: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| periodic.ark[i].into());

    // Field constants lifted to expressions for the generic math helpers.
    let diag: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| Hasher::MAT_DIAG[i].into());
    let last_int_const: AB::Expr = Hasher::ARK_INT[21].into();

    // ---------------------------------------------------------------------------------------------
    // Step 0: zero out witnesses that the current row does not use.
    //
    // - w0 must vanish unless the row is packed-int or int+ext.
    // - w1 and w2 must vanish unless the row is packed-int.
    // - On packed-int rows all three witnesses are pinned by the witness equations in step 3.
    //
    // These are mostly defensive: they keep permutation rows inert when the witness columns
    // are shared, and avoid accidental coupling with controller-side selector logic.
    //
    // Degrees: gate = perm_gate(1) * selector complement(1) = 2; constraint = gate(2) * w(1) = 3.
    // ---------------------------------------------------------------------------------------------
    let not_packed_int = AB::Expr::ONE - sel_packed_int.clone();
    builder
        .when(perm_gate.clone() * (not_packed_int.clone() - sel_int_ext.clone()))
        .assert_zero(witnesses[0].clone());
    {
        let mut gated = builder.when(perm_gate.clone() * not_packed_int);
        for unused in &witnesses[1..] {
            gated.assert_zero(unused.clone());
        }
    }

    // ---------------------------------------------------------------------------------------------
    // Step 1: init + first external round (row 0): h' = M_E(S(M_E(h) + ark)).
    //
    // Degrees: gate = perm_gate(1) * is_init_ext(1) = 2; constraint = gate(2) * sbox(7) = 9.
    // ---------------------------------------------------------------------------------------------
    {
        let expected = apply_init_plus_ext(&state, &round_consts);

        let mut gated = builder.when(perm_gate.clone() * sel_init_ext);
        for (actual, want) in next_state.iter().zip(expected.iter()) {
            gated.assert_eq(actual.clone(), want.clone());
        }
    }

    // ---------------------------------------------------------------------------------------------
    // Step 2: single external round (rows 1-3, 12-14): h' = M_E(S(h + ark)).
    //
    // Degrees: gate = perm_gate(1) * is_ext(1) = 2; constraint = gate(2) * sbox(7) = 9.
    // ---------------------------------------------------------------------------------------------
    {
        // Add the round constant and raise to the 7th power, lane by lane.
        let sboxed: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|i| {
            (state[i].clone() + round_consts[i].clone()).exp_const_u64::<7>()
        });
        let expected = apply_matmul_external(&sboxed);

        let mut gated = builder.when(perm_gate.clone() * sel_ext);
        for (actual, want) in next_state.iter().zip(expected.iter()) {
            gated.assert_eq(actual.clone(), want.clone());
        }
    }

    // ---------------------------------------------------------------------------------------------
    // Step 3: three packed internal rounds (rows 4-10).
    //
    // Degrees: gate = perm_gate(1) * is_packed_int(1) = 2;
    //          witness equations = gate(2) * sbox(7) = 9;
    //          next-state relation = gate(2) * affine(1) = 3.
    // ---------------------------------------------------------------------------------------------
    {
        // On packed-int rows the first three `ark` lanes carry the internal round constants.
        let internal_consts: [AB::Expr; 3] = core::array::from_fn(|i| round_consts[i].clone());
        let (expected, witness_checks) =
            apply_packed_internals(&state, &witnesses, &internal_consts, &diag);

        let mut gated = builder.when(perm_gate.clone() * sel_packed_int);
        // Three witness equations first ...
        for check in witness_checks {
            gated.assert_zero(check);
        }
        // ... followed by the twelve next-state equations.
        for (actual, want) in next_state.iter().zip(expected.iter()) {
            gated.assert_eq(actual.clone(), want.clone());
        }
    }

    // ---------------------------------------------------------------------------------------------
    // Step 4: last internal round merged with an external round (row 11).
    //
    // The internal round constant is the hard-coded `Hasher::ARK_INT[21]`; the external round
    // constants come from the periodic `ark` columns.
    //
    // Degrees: gate = perm_gate(1) * is_int_ext(1) = 2;
    //          witness equation = gate(2) * sbox(7) = 9;
    //          next-state relation = gate(2) * sbox(7) = 9.
    // ---------------------------------------------------------------------------------------------
    {
        let (expected, witness_check) =
            apply_internal_plus_ext(&state, &witnesses[0], last_int_const, &round_consts, &diag);

        let mut gated = builder.when(perm_gate * sel_int_ext);
        // One witness equation, then the twelve next-state equations.
        gated.assert_zero(witness_check);
        for (actual, want) in next_state.iter().zip(expected.iter()) {
            gated.assert_eq(actual.clone(), want.clone());
        }
    }
}

// =============================================================================
// LINEAR ALGEBRA HELPERS
// =============================================================================

/// Applies the external linear layer M_E to a 12-lane state.
///
/// The layer is computed in two stages:
/// 1. `matmul_m4` is applied independently to each of the three consecutive 4-lane blocks.
/// 2. For every lane position `j` in `0..4`, the sum of lane `j` across the three transformed
///    blocks is added back to lane `j` of each block.
fn apply_matmul_external<E: PrimeCharacteristicRing>(state: &[E; STATE_WIDTH]) -> [E; STATE_WIDTH] {
    // The block decomposition below is only meaningful for three blocks of four lanes.
    const _: () = assert!(STATE_WIDTH == 12);

    // Stage 1: transform each 4-lane block with M4.
    let blocks: [[E; 4]; 3] = core::array::from_fn(|blk| {
        matmul_m4(core::array::from_fn(|lane| state[4 * blk + lane].clone()))
    });

    // Stage 2a: per-lane sums across the three transformed blocks.
    let cross_sums: [E; 4] = core::array::from_fn(|lane| {
        blocks[0][lane].clone() + blocks[1][lane].clone() + blocks[2][lane].clone()
    });

    // Stage 2b: output lane = transformed lane + cross-block sum for its lane position.
    core::array::from_fn(|i| blocks[i / 4][i % 4].clone() + cross_sums[i % 4].clone())
}

/// Multiplies a 4-element block by the M4 matrix of Poseidon2's external linear layer.
///
/// Written out, the result is:
///
/// ```text
/// out0 = 2*x0 + 3*x1 +   x2 +   x3
/// out1 =   x0 + 2*x1 + 3*x2 +   x3
/// out2 =   x0 +   x1 + 2*x2 + 3*x3
/// out3 = 3*x0 +   x1 +   x2 + 2*x3
/// ```
///
/// The implementation shares partial sums so that only additions and two doublings are needed.
fn matmul_m4<E>(input: [E; 4]) -> [E; 4]
where
    E: PrimeCharacteristicRing,
{
    let [x0, x1, x2, x3] = input;

    // Pair sums and the total of all four inputs.
    let low_pair = x0.clone() + x1.clone();
    let high_pair = x2.clone() + x3.clone();
    let total = low_pair.clone() + high_pair.clone();

    // `total + x1` is shared by outputs 0 and 1; `total + x3` by outputs 2 and 3.
    let total_plus_x1 = total.clone() + x1;
    let total_plus_x3 = total + x3;

    [
        total_plus_x1.clone() + low_pair,
        total_plus_x1 + x2.double(),
        total_plus_x3.clone() + high_pair,
        total_plus_x3 + x0.double(),
    ]
}

/// Applies the internal linear layer M_I to the state.
///
/// Each output lane is `state[i] * mat_diag[i] + sum(state)`, i.e. M_I is the all-ones matrix
/// plus `diag(mat_diag)`. The caller supplies `mat_diag` already lifted to the expression type
/// `E` (the constraint code passes `Hasher::MAT_DIAG`).
fn apply_matmul_internal<E>(
    state: &[E; STATE_WIDTH],
    mat_diag: &[E; STATE_WIDTH],
) -> [E; STATE_WIDTH]
where
    E: PrimeCharacteristicRing,
{
    // The same total is added to every lane, so compute it once.
    let lane_total = E::sum_array::<STATE_WIDTH>(state);

    core::array::from_fn(|lane| {
        let scaled = state[lane].clone() * mat_diag[lane].clone();
        scaled + lane_total.clone()
    })
}

// =============================================================================
// PACKED ROUND HELPERS
// =============================================================================

/// Returns the expected next state for the merged "initial linear layer + first external round"
/// step.
///
/// Computes `h' = M_E(S(M_E(h) + ark_ext))`, where `S` raises every lane to the 7th power.
///
/// Both applications of M_E are linear, so the only non-linear part is a single S-box layer
/// applied to affine expressions of `h`; each output therefore has degree 7 in `h`.
pub fn apply_init_plus_ext<E: PrimeCharacteristicRing>(
    h: &[E; STATE_WIDTH],
    ark_ext: &[E; STATE_WIDTH],
) -> [E; STATE_WIDTH] {
    // Initial linear layer.
    let mixed = apply_matmul_external(h);

    // First external round: add the round constant and apply x^7 on every lane ...
    let sboxed: [E; STATE_WIDTH] = core::array::from_fn(|lane| {
        (mixed[lane].clone() + ark_ext[lane].clone()).exp_const_u64::<7>()
    });

    // ... then mix again with M_E.
    apply_matmul_external(&sboxed)
}

/// Returns the expected next state and the witness equations for three packed internal rounds.
///
/// One internal round adds a round constant to lane 0, applies the S-box (x^7) to lane 0 and
/// then applies M_I. Here the S-box output of round `k` is not computed but taken from the
/// witness `w[k]`, so the running state stays affine in `h` and `w`; only the witness equations
/// themselves contain a 7th power.
///
/// Returns `(next_state, witness_checks)`:
/// - `next_state`: the state after all three rounds (degree 1 in `h` and `w`).
/// - `witness_checks`: three expressions that must each be zero, `w[k] - (y_k[0] + ark_int[k])^7`
///   (degree 7), where `y_k` is the running state at the start of round `k`.
pub fn apply_packed_internals<E: PrimeCharacteristicRing>(
    h: &[E; STATE_WIDTH],
    w: &[E; 3],
    ark_int: &[E; 3],
    mat_diag: &[E; STATE_WIDTH],
) -> ([E; STATE_WIDTH], [E; 3]) {
    let mut running = h.clone();
    let mut witness_checks: [E; 3] = core::array::from_fn(|_| E::ZERO);

    let rounds = witness_checks.iter_mut().zip(w.iter().zip(ark_int.iter()));
    for (check, (witness, round_const)) in rounds {
        // The witness must equal the true S-box output for this round's lane-0 input.
        let sbox_input = running[0].clone() + round_const.clone();
        *check = witness.clone() - sbox_input.exp_const_u64::<7>();

        // Continue with the witness standing in for the S-box output, then mix with M_I.
        running[0] = witness.clone();
        running = apply_matmul_internal(&running, mat_diag);
    }

    (running, witness_checks)
}

/// Returns the expected next state and the witness equation for one internal round followed by
/// one external round.
///
/// This models the merged int+ext row (row 11 of the cycle). The internal round constant is
/// taken as a plain value (`ark_int_const`) instead of a periodic column: the constraint code
/// passes the fixed `Hasher::ARK_INT[21]`, which is sufficient because only one row of the
/// cycle is gated by the int+ext selector.
///
/// Returns `(next_state, witness_check)`:
/// - `next_state`: `M_E(S(M_I(h with lane 0 replaced by w0) + ark_ext))`, degree 7 in the trace
///   columns.
/// - `witness_check`: `w0 - (h[0] + ark_int_const)^7`, which must be zero (degree 7).
pub fn apply_internal_plus_ext<E: PrimeCharacteristicRing>(
    h: &[E; STATE_WIDTH],
    w0: &E,
    ark_int_const: E,
    ark_ext: &[E; STATE_WIDTH],
    mat_diag: &[E; STATE_WIDTH],
) -> ([E; STATE_WIDTH], E) {
    // Internal round, part 1: the witness must match the real S-box output on lane 0.
    let witness_check = w0.clone() - (h[0].clone() + ark_int_const).exp_const_u64::<7>();

    // Internal round, part 2: use the witness as lane 0 and mix with M_I.
    let substituted: [E; STATE_WIDTH] =
        core::array::from_fn(|lane| if lane == 0 { w0.clone() } else { h[lane].clone() });
    let after_internal = apply_matmul_internal(&substituted, mat_diag);

    // External round: add round constants, apply x^7 on every lane, mix with M_E.
    let sboxed: [E; STATE_WIDTH] = core::array::from_fn(|lane| {
        (after_internal[lane].clone() + ark_ext[lane].clone()).exp_const_u64::<7>()
    });

    (apply_matmul_external(&sboxed), witness_check)
}