use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{AirBuilder, LiftedAirBuilder};
use crate::{
Felt, MainTraceRow,
constraints::tagging::{TaggingAirBuilderExt, ids::TAG_CHIPLETS_BASE},
};
/// Tag ID of the first chiplet-selector constraint.
///
/// The constraints tagged in `enforce_chiplet_selectors` use consecutive IDs starting at this
/// value: entry `k` of `CHIPLET_SELECTORS_NAMES` is paired with ID
/// `CHIPLET_SELECTORS_BASE_ID + k`.
const CHIPLET_SELECTORS_BASE_ID: usize = TAG_CHIPLETS_BASE;
/// Tag names for the chiplet-selector constraints, indexed by offset from
/// `CHIPLET_SELECTORS_BASE_ID`.
///
/// Entries 0-4 name the binary constraints on `s0..s4`, and entries 5-9 name the stability
/// constraints on `s0..s4`. This is the same order in which `enforce_chiplet_selectors`
/// emits them.
const CHIPLET_SELECTORS_NAMES: [&str; 10] = [
// Binary constraints on s0..s4 (offsets 0-4).
"chiplets.selectors.s0.binary",
"chiplets.selectors.s1.binary",
"chiplets.selectors.s2.binary",
"chiplets.selectors.s3.binary",
"chiplets.selectors.s4.binary",
// Stability constraints on s0..s4 (offsets 5-9).
"chiplets.selectors.s0.stability",
"chiplets.selectors.s1.stability",
"chiplets.selectors.s2.stability",
"chiplets.selectors.s3.stability",
"chiplets.selectors.s4.stability",
];
/// Enforces the constraints on the five chiplet selector columns `s0..s4`.
///
/// Ten tagged constraints are emitted. Constraint `k` (for `k = 0..10`) uses tag ID
/// `CHIPLET_SELECTORS_BASE_ID + k` and name `CHIPLET_SELECTORS_NAMES[k]`, in this order:
///
/// - `k = 0..5`, binary:
///   - `s0 * (s0 - 1) = 0` is enforced unconditionally.
///   - For `i >= 1`, `s_i * (s_i - 1) = 0` is enforced under `when(...)` guards on each of
///     `s0, ..., s_{i-1}`.
/// - `k = 5..10`, stability:
///   - `s_i' - s_i = 0`, where `'` denotes the next row.
///   - It is enforced under `when_transition()` and `when(...)` guards on each of
///     `s0, ..., s_i` of the current row.
///
/// With binary selectors, the stability constraints mean that once `s0..s_i` all equal one on
/// a row, `s_i` is still one on the following row.
pub fn enforce_chiplet_selectors<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
) where
    AB: LiftedAirBuilder<F = Felt>,
{
    // Selector values on the current row and on the next row, lifted to expressions.
    let [s0, s1, s2, s3, s4]: [AB::Expr; 5] =
        core::array::from_fn(|i| local.chiplets[i].clone().into());
    let [s0_next, s1_next, s2_next, s3_next, s4_next]: [AB::Expr; 5] =
        core::array::from_fn(|i| next.chiplets[i].clone().into());

    // `x * (x - 1)`: vanishes exactly when `x` is 0 or 1.
    let binary = |x: &AB::Expr| x.clone() * (x.clone() - AB::Expr::ONE);
    // `x' - x`: vanishes exactly when the column keeps its value across the transition.
    let unchanged = |after: &AB::Expr, before: &AB::Expr| after.clone() - before.clone();

    // ---- Binary constraints (tag offsets 0..=4) ----
    builder.tagged(CHIPLET_SELECTORS_BASE_ID, CHIPLET_SELECTORS_NAMES[0], |b| {
        b.assert_zero(binary(&s0));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 1, CHIPLET_SELECTORS_NAMES[1], |b| {
        b.when(s0.clone()).assert_zero(binary(&s1));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 2, CHIPLET_SELECTORS_NAMES[2], |b| {
        b.when(s0.clone()).when(s1.clone()).assert_zero(binary(&s2));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 3, CHIPLET_SELECTORS_NAMES[3], |b| {
        b.when(s0.clone())
            .when(s1.clone())
            .when(s2.clone())
            .assert_zero(binary(&s3));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 4, CHIPLET_SELECTORS_NAMES[4], |b| {
        b.when(s0.clone())
            .when(s1.clone())
            .when(s2.clone())
            .when(s3.clone())
            .assert_zero(binary(&s4));
    });

    // ---- Stability constraints (tag offsets 5..=9), transition rows only ----
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 5, CHIPLET_SELECTORS_NAMES[5], |b| {
        b.when_transition()
            .when(s0.clone())
            .assert_zero(unchanged(&s0_next, &s0));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 6, CHIPLET_SELECTORS_NAMES[6], |b| {
        b.when_transition()
            .when(s0.clone())
            .when(s1.clone())
            .assert_zero(unchanged(&s1_next, &s1));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 7, CHIPLET_SELECTORS_NAMES[7], |b| {
        b.when_transition()
            .when(s0.clone())
            .when(s1.clone())
            .when(s2.clone())
            .assert_zero(unchanged(&s2_next, &s2));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 8, CHIPLET_SELECTORS_NAMES[8], |b| {
        b.when_transition()
            .when(s0.clone())
            .when(s1.clone())
            .when(s2.clone())
            .when(s3.clone())
            .assert_zero(unchanged(&s3_next, &s3));
    });
    builder.tagged(CHIPLET_SELECTORS_BASE_ID + 9, CHIPLET_SELECTORS_NAMES[9], |b| {
        b.when_transition()
            .when(s0.clone())
            .when(s1.clone())
            .when(s2.clone())
            .when(s3.clone())
            .when(s4.clone())
            .assert_zero(unchanged(&s4_next, &s4));
    });
}
/// Computes the bitwise chiplet flag `s0 * (1 - s1)`.
///
/// For binary selector values the result is one exactly when `s0 = 1` and `s1 = 0`, and zero
/// otherwise.
#[inline]
pub fn bitwise_chiplet_flag<E: PrimeCharacteristicRing>(s0: E, s1: E) -> E {
    let s1_unset = E::ONE - s1;
    s0 * s1_unset
}
/// Computes the memory chiplet flag `s0 * s1 * (1 - s2)`.
///
/// For binary selector values the result is one exactly when `s0 = s1 = 1` and `s2 = 0`, and
/// zero otherwise.
#[inline]
pub fn memory_chiplet_flag<E: PrimeCharacteristicRing>(s0: E, s1: E, s2: E) -> E {
    let prefix_set = s0 * s1;
    let s2_unset = E::ONE - s2;
    prefix_set * s2_unset
}
/// Computes the ACE chiplet flag `s0 * s1 * s2 * (1 - s3)`.
///
/// For binary selector values the result is one exactly when `s0 = s1 = s2 = 1` and `s3 = 0`,
/// and zero otherwise.
#[inline]
pub fn ace_chiplet_flag<E: PrimeCharacteristicRing>(s0: E, s1: E, s2: E, s3: E) -> E {
    let prefix_set = s0 * s1 * s2;
    let s3_unset = E::ONE - s3;
    prefix_set * s3_unset
}
/// Computes the kernel ROM chiplet flag `s0 * s1 * s2 * s3 * (1 - s4)`.
///
/// For binary selector values the result is one exactly when `s0 = s1 = s2 = s3 = 1` and
/// `s4 = 0`, and zero otherwise.
#[inline]
pub fn kernel_rom_chiplet_flag<E: PrimeCharacteristicRing>(s0: E, s1: E, s2: E, s3: E, s4: E) -> E {
    let prefix_set = s0 * s1 * s2 * s3;
    let s4_unset = E::ONE - s4;
    prefix_set * s4_unset
}