use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{AirBuilder, LiftedAirBuilder};
use crate::{
MainTraceRow,
constraints::{
op_flags::OpFlags,
tagging::{
TaggingAirBuilderExt,
ids::{TAG_STACK_OVERFLOW_BASE, TAG_STACK_OVERFLOW_COUNT},
},
},
trace::{
decoder::{IS_CALL_FLAG_COL_IDX, IS_SYSCALL_FLAG_COL_IDX},
stack::{B0_COL_IDX, B1_COL_IDX},
},
};
/// Tag id of the first stack-overflow constraint.
///
/// Each constraint in this module is tagged with `STACK_OVERFLOW_BASE_ID + i`, where `i` is the
/// index of its human-readable name in [`STACK_OVERFLOW_NAMES`].
const STACK_OVERFLOW_BASE_ID: usize = TAG_STACK_OVERFLOW_BASE;
/// Human-readable names of the stack-overflow constraints, indexed by tag offset from
/// `STACK_OVERFLOW_BASE_ID`.
///
/// The order is significant: each `tagged(STACK_OVERFLOW_BASE_ID + i, STACK_OVERFLOW_NAMES[i], ..)`
/// call in this module relies on entry `i` describing the constraint it tags.
const STACK_OVERFLOW_NAMES: [&str; TAG_STACK_OVERFLOW_COUNT] = [
    // 0..=3: boundary constraints on `b0` (stack depth) and `b1` (overflow address).
    "stack.overflow.depth.first_row",
    "stack.overflow.depth.last_row",
    "stack.overflow.addr.first_row",
    "stack.overflow.addr.last_row",
    // 4: depth update across a transition.
    "stack.overflow.depth.transition",
    // 5: overflow-flag consistency.
    // NOTE(review): this constraint is asserted without a transition selector (every row),
    // despite the `.transition` suffix in its name.
    "stack.overflow.flag.transition",
    // 6: overflow address update on right shifts.
    "stack.overflow.addr.transition",
    // 7: zero-insertion into s15 on left shifts with an empty overflow table.
    "stack.overflow.zero_insert.transition",
];
/// Enforces all stack-overflow-table constraints on the main trace.
///
/// Boundary constraints enforced here:
/// - the stack depth column `b0` equals 16 on the first and the last row;
/// - the overflow address column `b1` equals 0 on the first and the last row.
///
/// The remaining constraints are delegated to `enforce_stack_depth_constraints`
/// (depth transition), `enforce_overflow_flag_constraints` (overflow flag vs. depth) and
/// `enforce_overflow_index_constraints` (overflow address / zero insertion on shifts).
pub fn enforce_main<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: LiftedAirBuilder,
{
    let sixteen: AB::Expr = AB::Expr::from_u16(16);
    // Bind the columns once with an explicit target type instead of relying on `.into()`
    // being inferred through the `Sub` operator at each use site.
    let depth: AB::Expr = local.stack[B0_COL_IDX].clone().into();
    let overflow_addr: AB::Expr = local.stack[B1_COL_IDX].clone().into();

    builder.tagged(STACK_OVERFLOW_BASE_ID, STACK_OVERFLOW_NAMES[0], |builder| {
        builder.when_first_row().assert_zero(depth.clone() - sixteen.clone());
    });
    builder.tagged(STACK_OVERFLOW_BASE_ID + 1, STACK_OVERFLOW_NAMES[1], |builder| {
        builder.when_last_row().assert_zero(depth - sixteen);
    });
    // `b1` must be zero at both ends; assert the column directly (subtracting ZERO is a no-op).
    builder.tagged(STACK_OVERFLOW_BASE_ID + 2, STACK_OVERFLOW_NAMES[2], |builder| {
        builder.when_first_row().assert_zero(overflow_addr.clone());
    });
    builder.tagged(STACK_OVERFLOW_BASE_ID + 3, STACK_OVERFLOW_NAMES[3], |builder| {
        builder.when_last_row().assert_zero(overflow_addr);
    });

    enforce_stack_depth_constraints(builder, local, next, op_flags);
    enforce_overflow_flag_constraints(builder, local, op_flags);
    enforce_overflow_index_constraints(builder, local, next, op_flags);
}
/// Enforces the transition constraint that governs the stack depth column `b0`.
///
/// The single constraint is
/// `(b0' - b0) * g + f_shl * f_ov - f_shr + f_cl * (b0' - 16) = 0`, where
/// - `f_cl = f_call + f_dyncall + f_syscall`,
/// - `g = 1 - f_cl - f_end * (h_call + h_syscall)` masks off the plain depth delta on
///   CALL/DYNCALL/SYSCALL and on an `END` whose decoder call/syscall flag is set.
///
/// With `g = 1` this forces the depth to grow by one on a right shift and to drop by `f_ov`
/// on a left shift. On CALL/DYNCALL/SYSCALL the next depth is forced to 16.
fn enforce_stack_depth_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: LiftedAirBuilder,
{
    let b0: AB::Expr = local.stack[B0_COL_IDX].clone().into();
    let b0_next: AB::Expr = next.stack[B0_COL_IDX].clone().into();

    // CALL, DYNCALL and SYSCALL: the next depth is reset to 16 below.
    let f_call_like = op_flags.call() + op_flags.dyncall() + op_flags.syscall();

    // An END whose decoder flags mark it as closing a CALL/DYNCALL or SYSCALL block.
    let h_call: AB::Expr = local.decoder[IS_CALL_FLAG_COL_IDX].clone().into();
    let h_syscall: AB::Expr = local.decoder[IS_SYSCALL_FLAG_COL_IDX].clone().into();
    let f_call_like_end = op_flags.end() * (h_call + h_syscall);

    // 1 for every other operation, 0 when one of the flags above is active.
    let gate = AB::Expr::ONE - f_call_like.clone() - f_call_like_end;

    let constraint = (b0_next.clone() - b0) * gate
        + op_flags.left_shift() * op_flags.overflow()
        - op_flags.right_shift()
        + f_call_like * (b0_next - AB::Expr::from_u16(16));

    builder.tagged(STACK_OVERFLOW_BASE_ID + 4, STACK_OVERFLOW_NAMES[4], |builder| {
        builder.when_transition().assert_zero(constraint);
    });
}
/// Ties the overflow flag `f_ov` to the stack depth `b0`.
///
/// Enforces `(1 - f_ov) * (b0 - 16) = 0`. No row selector is used, so the constraint applies on
/// every row. Whenever the depth differs from 16, `f_ov` is forced to equal 1.
fn enforce_overflow_flag_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: LiftedAirBuilder,
{
    let excess_depth =
        Into::<AB::Expr>::into(local.stack[B0_COL_IDX].clone()) - AB::Expr::from_u16(16);
    let not_overflow = AB::Expr::ONE - op_flags.overflow();

    builder.tagged(STACK_OVERFLOW_BASE_ID + 5, STACK_OVERFLOW_NAMES[5], |builder| {
        builder.assert_zero(not_overflow * excess_depth);
    });
}
/// Enforces the two transition constraints that accompany stack shifts.
///
/// - Right shift: the next overflow address `b1'` must equal the current clock cycle `clk`,
///   i.e. `(b1' - clk) * f_shr = 0`.
/// - Left shift with the overflow flag clear: the next value of stack slot 15 must be zero,
///   i.e. `(1 - f_ov) * f_shl * s15' = 0`.
fn enforce_overflow_index_constraints<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
) where
    AB: LiftedAirBuilder,
{
    // Index of the deepest visible stack slot (s15).
    const LAST_VISIBLE_SLOT: usize = 15;

    let b1_next: AB::Expr = next.stack[B1_COL_IDX].clone().into();
    let clk: AB::Expr = local.clk.clone().into();
    let s15_next: AB::Expr = next.stack[LAST_VISIBLE_SLOT].clone().into();

    let addr_on_right_shift = (b1_next - clk) * op_flags.right_shift();
    builder.tagged(STACK_OVERFLOW_BASE_ID + 6, STACK_OVERFLOW_NAMES[6], |builder| {
        builder.when_transition().assert_zero(addr_on_right_shift);
    });

    let no_overflow = AB::Expr::ONE - op_flags.overflow();
    let zero_fill_on_left_shift = no_overflow * op_flags.left_shift() * s15_next;
    builder.tagged(STACK_OVERFLOW_BASE_ID + 7, STACK_OVERFLOW_NAMES[7], |builder| {
        builder.when_transition().assert_zero(zero_fill_on_left_shift);
    });
}