use miden_core::{chiplets::hasher::Hasher, field::PrimeCharacteristicRing};
use miden_crypto::stark::air::LiftedAirBuilder;
use super::periodic::{
P_ARK_EXT_START, P_ARK_INT, P_CYCLE_ROW_0, P_IS_EXTERNAL, P_IS_INTERNAL, STATE_WIDTH,
};
use crate::{
Felt,
constraints::tagging::{TagGroup, TaggingAirBuilderExt, tagged_assert_zeros},
};
// Namespace strings handed to `tagged_assert_zeros` together with each constraint group emitted
// in this file.
// NOTE(review): presumably these serve as human-readable constraint labels; confirm against
// `constraints::tagging`.

/// Namespace for the constraints gated by the `P_CYCLE_ROW_0` periodic column.
const PERM_INIT_NAMESPACE: &str = "chiplets.hasher.permutation.init";
/// Namespace for the constraints gated by the `P_IS_EXTERNAL` periodic column.
const PERM_EXT_NAMESPACE: &str = "chiplets.hasher.permutation.external";
/// Namespace for the constraints gated by the `P_IS_INTERNAL` periodic column.
const PERM_INT_NAMESPACE: &str = "chiplets.hasher.permutation.internal";
/// Namespace for the capacity-preservation constraints gated by the `f_abp` flag.
const ABP_CAP_NAMESPACE: &str = "chiplets.hasher.abp.capacity";

// Per-constraint name tables: every entry of a table is the same namespace string, repeated once
// per constraint in the group (`STATE_WIDTH` entries for the permutation groups, 4 entries for
// the capacity group).
const PERM_INIT_NAMES: [&str; STATE_WIDTH] = [PERM_INIT_NAMESPACE; STATE_WIDTH];
const PERM_EXT_NAMES: [&str; STATE_WIDTH] = [PERM_EXT_NAMESPACE; STATE_WIDTH];
const PERM_INT_NAMES: [&str; STATE_WIDTH] = [PERM_INT_NAMESPACE; STATE_WIDTH];
const ABP_CAP_NAMES: [&str; 4] = [ABP_CAP_NAMESPACE; 4];

// Tag groups pairing a base id declared in the parent module with the matching name table.
// NOTE(review): `base` looks like the id of the first constraint in the group, with later
// constraints numbered consecutively from it; confirm against `TagGroup` in
// `constraints::tagging`.

/// Tag group used for the `STATE_WIDTH` "init" permutation constraints.
const PERM_INIT_TAGS: TagGroup = TagGroup {
base: super::HASHER_PERM_INIT_BASE_ID,
names: &PERM_INIT_NAMES,
};
/// Tag group used for the `STATE_WIDTH` "external" permutation constraints.
const PERM_EXT_TAGS: TagGroup = TagGroup {
base: super::HASHER_PERM_EXT_BASE_ID,
names: &PERM_EXT_NAMES,
};
/// Tag group used for the `STATE_WIDTH` "internal" permutation constraints.
const PERM_INT_TAGS: TagGroup = TagGroup {
base: super::HASHER_PERM_INT_BASE_ID,
names: &PERM_INT_NAMES,
};
/// Tag group used for the 4 capacity-preservation constraints.
const ABP_CAP_TAGS: TagGroup = TagGroup {
base: super::HASHER_ABP_BASE_ID,
names: &ABP_CAP_NAMES,
};
/// Emits the transition constraints for a single step of the hasher permutation.
///
/// Three groups of `STATE_WIDTH` constraints are registered, in this order. Each constraint has
/// the form `gate * (h_next[i] - expected[i]) == 0` for lane `i`:
///
/// 1. **init** – gate is `hasher_flag * periodic[P_CYCLE_ROW_0]`; the expected next state is the
///    external linear layer applied directly to `h`.
/// 2. **external** – gate is `hasher_flag * periodic[P_IS_EXTERNAL]`; every lane has its own
///    round constant added and is raised to the 7th power before the external linear layer.
/// 3. **internal** – gate is `hasher_flag * periodic[P_IS_INTERNAL]`; only lane 0 has a round
///    constant added and is raised to the 7th power, then the internal linear layer is applied.
///
/// Every group is registered through `tagged_assert_zeros` with its own tag group, its own
/// namespace string and a counter that starts at 0.
///
/// # Arguments
///
/// * `builder` - constraint builder receiving the assertions.
/// * `hasher_flag` - expression multiplied into every gate.
/// * `h` - state lanes of the current row.
/// * `h_next` - state lanes of the next row.
/// * `periodic` - periodic column values; read at `P_CYCLE_ROW_0`, `P_IS_EXTERNAL`,
///   `P_IS_INTERNAL`, `P_ARK_INT` and `P_ARK_EXT_START..P_ARK_EXT_START + STATE_WIDTH`.
///
/// # Panics
///
/// Panics if `periodic` is too short for any of the indices listed above.
pub fn enforce_permutation_steps<AB>(
    builder: &mut AB,
    hasher_flag: AB::Expr,
    h: &[AB::Expr; STATE_WIDTH],
    h_next: &[AB::Expr; STATE_WIDTH],
    periodic: &[AB::PeriodicVar],
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    // --- periodic selectors and round constants ---------------------------------------------
    let on_init_row: AB::Expr = periodic[P_CYCLE_ROW_0].into();
    let on_external_row: AB::Expr = periodic[P_IS_EXTERNAL].into();
    let on_internal_row: AB::Expr = periodic[P_IS_INTERNAL].into();

    let external_rc: [AB::Expr; STATE_WIDTH] =
        core::array::from_fn(|lane| periodic[P_ARK_EXT_START + lane].into());
    let internal_rc: AB::Expr = periodic[P_ARK_INT].into();

    // --- expected next state for each kind of row -------------------------------------------
    // Init row: linear layer only.
    let next_if_init = apply_matmul_external::<AB>(h);

    // External row: add the per-lane constant, apply x^7 to every lane, then the linear layer.
    let after_full_sbox: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|lane| {
        (h[lane].clone() + external_rc[lane].clone()).exp_const_u64::<7>()
    });
    let next_if_external = apply_matmul_external::<AB>(&after_full_sbox);

    // Internal row: only lane 0 receives the constant and x^7; the other lanes pass through
    // unchanged into the internal linear layer.
    let after_partial_sbox: [AB::Expr; STATE_WIDTH] = core::array::from_fn(|lane| {
        if lane == 0 {
            (h[0].clone() + internal_rc.clone()).exp_const_u64::<7>()
        } else {
            h[lane].clone()
        }
    });
    let next_if_internal = apply_matmul_internal::<AB>(&after_partial_sbox);

    // --- init constraints -------------------------------------------------------------------
    let init_gate = hasher_flag.clone() * on_init_row;
    let init_constraints = core::array::from_fn::<_, STATE_WIDTH, _>(|lane| {
        init_gate.clone() * (h_next[lane].clone() - next_if_init[lane].clone())
    });
    let mut init_idx = 0;
    tagged_assert_zeros(
        builder,
        &PERM_INIT_TAGS,
        &mut init_idx,
        PERM_INIT_NAMESPACE,
        init_constraints,
    );

    // --- external-round constraints ---------------------------------------------------------
    let external_gate = hasher_flag.clone() * on_external_row;
    let external_constraints = core::array::from_fn::<_, STATE_WIDTH, _>(|lane| {
        external_gate.clone() * (h_next[lane].clone() - next_if_external[lane].clone())
    });
    let mut external_idx = 0;
    tagged_assert_zeros(
        builder,
        &PERM_EXT_TAGS,
        &mut external_idx,
        PERM_EXT_NAMESPACE,
        external_constraints,
    );

    // --- internal-round constraints ---------------------------------------------------------
    let internal_gate = hasher_flag * on_internal_row;
    let internal_constraints = core::array::from_fn::<_, STATE_WIDTH, _>(|lane| {
        internal_gate.clone() * (h_next[lane].clone() - next_if_internal[lane].clone())
    });
    let mut internal_idx = 0;
    tagged_assert_zeros(
        builder,
        &PERM_INT_TAGS,
        &mut internal_idx,
        PERM_INT_NAMESPACE,
        internal_constraints,
    );
}
/// Emits the four constraints that keep the capacity lanes unchanged across a row transition.
///
/// For every lane `i` in `0..4` the asserted expression is
/// `hasher_flag * f_abp * (h_cap_next[i] - h_cap[i])`, registered through `tagged_assert_zeros`
/// with `ABP_CAP_TAGS`, `ABP_CAP_NAMESPACE` and a counter that starts at 0.
///
/// # Arguments
///
/// * `builder` - constraint builder receiving the assertions.
/// * `hasher_flag` - first factor of the gate.
/// * `f_abp` - second factor of the gate.
/// * `h_cap` - capacity lanes of the current row.
/// * `h_cap_next` - capacity lanes of the next row.
pub fn enforce_abp_capacity_preservation<AB>(
    builder: &mut AB,
    hasher_flag: AB::Expr,
    f_abp: AB::Expr,
    h_cap: &[AB::Expr; 4],
    h_cap_next: &[AB::Expr; 4],
) where
    AB: TaggingAirBuilderExt<F = Felt>,
{
    let selector = hasher_flag * f_abp;

    // One constraint per capacity lane: next value minus current value, scaled by the gate.
    let lane_constraints = core::array::from_fn::<_, 4, _>(|lane| {
        selector.clone() * (h_cap_next[lane].clone() - h_cap[lane].clone())
    });

    let mut tag_idx = 0;
    tagged_assert_zeros(
        builder,
        &ABP_CAP_TAGS,
        &mut tag_idx,
        ABP_CAP_NAMESPACE,
        lane_constraints,
    );
}
/// Applies the external linear layer to a 12-lane state, symbolically.
///
/// The state is cut into three consecutive 4-lane blocks and `matmul_m4` is applied to each.
/// Output lane `4 * k + j` is then `block_k[j] + (block_0[j] + block_1[j] + block_2[j])`, i.e.
/// each transformed block is added to the lane-wise sum of all three transformed blocks.
/// Written as a block matrix, this is the circulant `circ(2 * M4, M4, M4)`, where `M4` is the
/// matrix implemented by `matmul_m4`.
fn apply_matmul_external<AB: LiftedAirBuilder<F = Felt>>(
    state: &[AB::Expr; STATE_WIDTH],
) -> [AB::Expr; STATE_WIDTH] {
    // The 3 x 4 block layout below is only valid for a 12-lane state.
    const _: () = assert!(STATE_WIDTH == 12);

    // Transform each 4-lane block independently.
    let blocks: [[AB::Expr; 4]; 3] = core::array::from_fn(|block| {
        matmul_m4::<AB>(&core::array::from_fn(|lane| state[4 * block + lane].clone()))
    });

    // Lane-wise sum across the three transformed blocks.
    let lane_sums: [AB::Expr; 4] = core::array::from_fn(|lane| {
        blocks[0][lane].clone() + blocks[1][lane].clone() + blocks[2][lane].clone()
    });

    core::array::from_fn(|idx| {
        let (block, lane) = (idx / 4, idx % 4);
        blocks[block][lane].clone() + lane_sums[lane].clone()
    })
}
/// Multiplies a 4-element vector of expressions by a fixed 4x4 matrix using additions and
/// doublings only.
///
/// Expanding the intermediate sums below gives, for input `[a, b, c, d]`:
///
/// ```text
/// out[0] = 5a + 7b +  c + 3d
/// out[1] = 4a + 6b +  c +  d
/// out[2] =  a + 3b + 5c + 7d
/// out[3] =  a +  b + 4c + 6d
/// ```
fn matmul_m4<AB: LiftedAirBuilder<F = Felt>>(input: &[AB::Expr; 4]) -> [AB::Expr; 4] {
    let [x0, x1, x2, x3] = input;

    let front_pair = x0.clone() + x1.clone(); // a + b
    let back_pair = x2.clone() + x3.clone(); // c + d
    let front_mix = x1.clone() + x1.clone() + back_pair.clone(); // 2b + c + d
    let back_mix = x3.clone() + x3.clone() + front_pair.clone(); // a + b + 2d

    // 4 * (c + d) + back_mix
    let row3 = back_pair.clone().double() + back_pair.clone().double() + back_mix.clone();
    // 4 * (a + b) + front_mix
    let row1 = front_pair.clone().double() + front_pair.clone().double() + front_mix.clone();

    let row0 = back_mix.clone() + row1.clone();
    let row2 = front_mix + row3.clone();

    [row0, row1, row2, row3]
}
/// Applies the internal linear layer to the state, symbolically.
///
/// Output lane `i` is `state[i] * Hasher::MAT_DIAG[i] + (state[0] + state[1] + ...)`, i.e. a
/// per-lane diagonal scaling plus the sum of all lanes.
///
/// # Panics
///
/// Panics if the state has no lanes (`STATE_WIDTH == 0`).
fn apply_matmul_internal<AB: LiftedAirBuilder<F = Felt>>(
    state: &[AB::Expr; STATE_WIDTH],
) -> [AB::Expr; STATE_WIDTH] {
    // Accumulate left to right, seeded with the first lane.
    let (head, tail) = state.split_first().expect("STATE_WIDTH > 0");
    let lane_total: AB::Expr = tail.iter().fold(head.clone(), |acc, lane| acc + lane.clone());

    core::array::from_fn(|lane| {
        let diagonal = AB::Expr::from(Hasher::MAT_DIAG[lane]);
        state[lane].clone() * diagonal + lane_total.clone()
    })
}