use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::AirBuilder;
use crate::{
MainCols, MidenAirBuilder,
constraints::{constants::*, op_flags::OpFlags, utils::BoolNot},
};
/// Enforces all main-trace constraints on the stack depth column `b0` and the overflow
/// address column `b1`.
///
/// Constraints are emitted in this fixed order:
/// 1. boundary: `b0 = 16` on the first row, then on the last row;
/// 2. boundary: `b1 = 0` on the first row, then on the last row;
/// 3. transition: the stack-depth update rule (`enforce_stack_depth_constraints`);
/// 4. every row: `b0 = 16` whenever the overflow flag is off;
/// 5. transition: the overflow-address / last-item rules
///    (`enforce_overflow_index_constraints`).
pub fn enforce_main<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    let depth = local.stack.b0;
    let overflow_addr = local.stack.b1;

    // The trace must start and finish with exactly 16 stack items and an empty overflow
    // address.
    builder.when_first_row().assert_eq(depth, F_16);
    builder.when_last_row().assert_eq(depth, F_16);
    builder.when_first_row().assert_zero(overflow_addr);
    builder.when_last_row().assert_zero(overflow_addr);

    enforce_stack_depth_constraints(builder, local, next, op_flags);

    // If the overflow flag is zero, the depth is pinned to 16. This is not gated on
    // transitions, so it also applies to the last row.
    // NOTE(review): presumably `overflow()` is derived from `b0 - 16` and a helper
    // inverse column; confirm in `OpFlags`.
    let overflow_off = op_flags.overflow().not();
    builder.when(overflow_off).assert_eq(depth, F_16);

    enforce_overflow_index_constraints(builder, local, next, op_flags);
}
/// Enforces the transition rule for the stack depth column `b0`.
///
/// A single transition constraint is emitted:
///
/// ```text
/// (b0' - b0) * (1 - f_ctx - f_end * (h6 + h7))
///     + f_shl * f_ov
///     - f_shr
///     + f_ctx * (b0' - 16) = 0
/// ```
///
/// The symbols are:
/// - `f_ctx = f_call + f_dyncall + f_syscall`.
/// - `h6` and `h7` are `decoder.hasher_state[6]` and `[7]` of the current row. They are
///   used as flags for "ending a CALL/DYNCALL" and "ending a SYSCALL".
/// - `f_shl`, `f_shr` and `f_ov` are the left-shift, right-shift and overflow flags.
///
/// Assuming the operation flags are mutually exclusive, the rule means:
/// - A context-starting operation resets the next depth to 16.
/// - A left shift with overflow decreases the depth by one.
/// - A right shift increases it by one.
/// - Other ops, apart from the gated END cases, keep it unchanged.
fn enforce_stack_depth_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    let depth: AB::Expr = local.stack.b0.into();
    let depth_next: AB::Expr = next.stack.b0.into();

    // CALL / DYNCALL / SYSCALL: operations that start a new execution context.
    let ctx_start = op_flags.call() + op_flags.dyncall() + op_flags.syscall();

    // END of a block that was entered via CALL/DYNCALL (h6) or SYSCALL (h7).
    let ends_call_block = local.decoder.hasher_state[6];
    let ends_syscall_block = local.decoder.hasher_state[7];
    let ctx_end = op_flags.end() * (ends_call_block + ends_syscall_block);

    // The plain "depth delta" term is switched off for context switches, whose depth is
    // governed separately (reset to 16 on entry, unconstrained here on exit).
    let plain_step = AB::Expr::ONE - ctx_start.clone() - ctx_end;
    let delta_term = (depth_next.clone() - depth) * plain_step;

    let shl_term = op_flags.left_shift() * op_flags.overflow();
    let shr_term = op_flags.right_shift();
    let ctx_reset_term = ctx_start * (depth_next - F_16);

    let constraint = delta_term + shl_term - shr_term + ctx_reset_term;
    builder.when_transition().assert_zero(constraint);
}
/// Enforces the transition constraints tying the overflow address column `b1` and the
/// deepest tracked stack item to shift operations.
///
/// Two transition constraints are emitted, in order:
/// 1. On a right shift, the next row's `b1` equals the current row's clock `clk`.
/// 2. On a left shift while the overflow flag is off, the next row's stack item at
///    position 15 is zero.
///    NOTE(review): presumably position 15 is the deepest of the directly-tracked stack
///    columns; confirm against `MainCols::stack`.
fn enforce_overflow_index_constraints<AB>(
    builder: &mut AB,
    local: &MainCols<AB::Var>,
    next: &MainCols<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: MidenAirBuilder,
{
    let current_clk = local.system.clk;
    let next_overflow_addr = next.stack.b1;
    let next_deepest_item = next.stack.get(15);

    // A right shift records the current clock as the new overflow address.
    builder
        .when_transition()
        .when(op_flags.right_shift())
        .assert_eq(next_overflow_addr, current_clk);

    // A left shift with no overflow entries to pull from must shift a zero into the
    // deepest slot. The two flags are combined into one product condition.
    let shl_without_overflow = op_flags.overflow().not() * op_flags.left_shift();
    builder
        .when_transition()
        .when(shl_without_overflow)
        .assert_zero(next_deepest_item);
}