pub mod flags;
use flags::ControllerFlags;
use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::AirBuilder;
use crate::{
MainCols, MidenAirBuilder,
constraints::{
chiplets::{columns::ControllerCols, selectors::ChipletFlags},
utils::BoolNot,
},
};
/// Emits every constraint on the controller columns for one pair of adjacent rows.
///
/// The constraints are expressed over the `ControllerCols` view of the current and next
/// rows together with the row-type flags derived from them by [`ControllerFlags::new`].
/// Each assertion is gated by one of the chiplet-level flags in `chiplet`, so it only
/// binds where that flag is non-zero.
///
/// # Parameters
/// - `builder`: constraint builder that receives the assertions.
/// - `local`: main-trace columns of the current row; only `local.controller()` is read.
/// - `next`: main-trace columns of the next row; only `next.controller()` is read.
/// - `chiplet`: chiplet-level gating flags. This function reads `is_active`,
///   `is_transition` and `is_last`; their exact definitions live outside this file.
///
/// The function returns nothing; its only effect is the assertions pushed into `builder`.
/// The assertions are emitted in a fixed order, so reordering statements changes the
/// order in which the builder sees them.
pub fn enforce_controller_constraints<AB>(
builder: &mut AB,
local: &MainCols<AB::Var>,
next: &MainCols<AB::Var>,
chiplet: &ChipletFlags<AB::Expr>,
) where
AB: MidenAirBuilder,
{
// Controller views of the current and next rows, and the row-type flags built from them.
let cols: &ControllerCols<AB::Var> = local.controller();
let cols_next: &ControllerCols<AB::Var> = next.controller();
let rows = ControllerFlags::<AB::Expr>::new(cols, cols_next);
// First row of the trace: `is_active * is_input` must equal one, i.e. the row must be
// both an active controller row and an input row.
builder
.when_first_row()
.assert_one(chiplet.is_active.clone() * rows.is_input.clone());
// On active rows the selectors s0, s1, s2 and `is_boundary` must each be 0 or 1.
builder
.when(chiplet.is_active.clone())
.assert_bools([cols.s0, cols.s1, cols.s2]);
builder.when(chiplet.is_active.clone()).assert_bool(cols.is_boundary);
// Row sequencing: an output row is never followed by another output row.
builder
.when(chiplet.is_transition.clone())
.when(rows.is_output.clone())
.assert_zero(rows.is_output_next.clone());
// Row sequencing: once a padding row appears, the next row is padding as well.
builder
.when(chiplet.is_transition.clone())
.when(rows.is_padding.clone())
.assert_one(rows.is_padding_next.clone());
// Padding rows carry `is_boundary = 0` and `direction_bit = 0`.
builder
.when(chiplet.is_active.clone())
.when(rows.is_padding.clone())
.assert_zeros([cols.is_boundary, cols.direction_bit]);
// The row flagged `is_last` must not be an input row ...
builder.when(chiplet.is_last.clone()).assert_zero(rows.is_input.clone());
// ... and if it is an output row, it must have `is_boundary = 1`.
builder
.when(chiplet.is_last.clone())
.when(rows.is_output.clone())
.assert_one(cols.is_boundary);
// Row sequencing: every input row is immediately followed by an output row.
builder
.when(chiplet.is_transition.clone())
.when(rows.is_input.clone())
.assert_one(rows.is_output_next.clone());
// Sponge input rows carry `node_index = 0` and `direction_bit = 0`.
builder
.when(chiplet.is_active.clone())
.when(rows.is_sponge_input.clone())
.assert_zeros([cols.node_index, cols.direction_bit]);
{
// Capacity carry-over: when the next row is a sponge input row with
// `is_boundary = 0`, elements 0..4 of its capacity must equal the current row's.
let is_boundary_next: AB::Expr = cols_next.is_boundary.into();
let gate = chiplet.is_transition.clone()
* rows.is_sponge_input_next.clone()
* is_boundary_next.not();
let cap = cols.capacity();
let cap_next = cols_next.capacity();
let builder = &mut builder.when(gate);
for i in 0..4 {
builder.assert_eq(cap_next[i], cap[i]);
}
}
{
// Merkle input rows.
// NOTE(review): this gate uses `is_active` rather than `is_transition` although it
// reads `cols_next`; presumably sound because the `is_last` row is constrained above
// not to be an input row — confirm that `is_merkle_input` implies `is_input`.
let gate = chiplet.is_active.clone() * rows.is_merkle_input.clone();
let builder = &mut builder.when(gate);
let node_index_next: AB::Expr = cols_next.node_index.into();
// node_index = 2 * node_index_next + direction_bit, so the next row holds the index
// with `direction_bit` removed as its lowest bit.
let idx_expected = node_index_next.double() + cols.direction_bit;
builder.assert_eq(cols.node_index, idx_expected);
builder.assert_bool(cols.direction_bit);
// Every capacity element is zero on a Merkle input row.
builder.assert_zeros(cols.capacity());
}
// 1 - is_boundary for the current row (assuming `BoolNot::not` is the boolean
// complement — defined in `utils`, not visible here).
let not_boundary: AB::Expr = cols.is_boundary.into().not();
// Output row with `is_boundary = 0` followed by a Merkle input row: `node_index` is
// carried over unchanged into the next row.
// NOTE(review): gated on `is_active` while reading `cols_next`; presumably sound because
// an output row on the `is_last` row is forced to have `is_boundary = 1` above — confirm.
builder
.when(chiplet.is_active.clone())
.when(rows.is_output.clone())
.when(not_boundary.clone())
.when(rows.is_merkle_input_next.clone())
.assert_eq(cols_next.node_index, cols.node_index);
{
// Same condition as the constraint just above, folded into a single gate expression.
let gate = chiplet.is_active.clone()
* rows.is_output.clone()
* not_boundary
* rows.is_merkle_input_next.clone();
let builder = &mut builder.when(gate);
// The output row's `direction_bit` must equal that of the following Merkle input row.
builder.assert_eq(cols.direction_bit, cols_next.direction_bit);
let b: AB::Expr = cols.direction_bit.into();
let rate0_curr = cols.rate0();
let rate0_next = cols_next.rate0();
let rate1_next = cols_next.rate1();
// The current row's rate0 (elements 0..4) must appear in the next row's rate0 when
// b = 0 and in the next row's rate1 when b = 1:
// rate0 = rate0' + b * (rate1' - rate0').
for j in 0..4 {
builder.assert_eq(
rate0_curr[j],
rate0_next[j] + b.clone() * (rate1_next[j] - rate0_next[j]),
);
}
}
// `mrupdate_id` is advanced by `is_mv_input_next * is_boundary_next` on each transition:
// with both factors boolean, it goes up by one when the next row is an MV input row with
// `is_boundary = 1` and is unchanged otherwise.
let mrupdate_id: AB::Expr = cols.mrupdate_id.into();
// Last use of `rows.is_mv_input_next`, so the field is moved rather than cloned.
let mv_start_next = rows.is_mv_input_next * cols_next.is_boundary;
builder
.when(chiplet.is_transition.clone())
.assert_eq(cols_next.mrupdate_id, mrupdate_id + mv_start_next);
// Rows flagged `is_hout` carry `node_index = 0` and `direction_bit = 0`.
// NOTE(review): unlike the `is_sout` constraint below, this one is not gated on
// `is_boundary`. If `is_hout` also covers non-boundary output rows that precede a Merkle
// input row, it would force their `node_index`/`direction_bit` to zero and conflict with
// the carry-over constraints above — confirm against the `is_hout` definition in `flags`.
builder
.when(chiplet.is_active.clone())
.when(rows.is_hout.clone())
.assert_zeros([cols.node_index, cols.direction_bit]);
// Rows flagged `is_sout` with `is_boundary = 1` carry `direction_bit = 0`.
builder
.when(chiplet.is_active.clone())
.when(rows.is_sout)
.when(cols.is_boundary)
.assert_zero(cols.direction_bit);
}