use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{ExtensionBuilder, LiftedAirBuilder, WindowAccess};
use crate::{
MainTraceRow,
constraints::{
bus::indices::P1_STACK,
op_flags::OpFlags,
tagging::{TaggingAirBuilderExt, ids::TAG_STACK_OVERFLOW_BUS_BASE},
},
trace::{
Challenges,
decoder::HASHER_STATE_RANGE,
stack::{B0_COL_IDX, B1_COL_IDX, H0_COL_IDX},
},
};
/// Tag id used when tagging the stack overflow bus transition constraint emitted by
/// `enforce_bus`; equal to `TAG_STACK_OVERFLOW_BUS_BASE`.
const STACK_OVERFLOW_BUS_ID: usize = TAG_STACK_OVERFLOW_BUS_BASE;
/// Human-readable name passed together with `STACK_OVERFLOW_BUS_ID` when tagging that constraint.
const STACK_OVERFLOW_BUS_NAME: &str = "stack.overflow.bus.transition";
/// Enforces the transition constraint of the stack overflow bus.
///
/// The bus lives in the auxiliary (permutation) column `P1_STACK`, written `p1` below. Between the
/// current and the next row it must satisfy
///
/// ```text
/// p1' * request = p1 * response
/// ```
///
/// where
/// - `response = encode(clk, s15, b1)` when the right-shift flag is set, and `1` otherwise;
/// - `request = encode(b1, s15', b1')` when the left-shift flag is set and the overflow is
///   non-empty, `encode(b1, s15', h5)` (decoder hasher-state element 5) when the `DYNCALL` flag
///   is set and the overflow is non-empty, and `1` otherwise.
///
/// "Non-empty overflow" is the expression `(b0 - 16) * h0`. It evaluates to `1` whenever
/// `b0 != 16` only if `h0` holds the inverse of `b0 - 16`. That is assumed to be enforced by
/// another constraint (NOTE(review): confirm).
///
/// The selector arithmetic assumes all flags are boolean. It also assumes the left-shift and
/// `DYNCALL` flags are never set on the same row, because the two request terms are summed
/// (NOTE(review): confirm against `OpFlags`).
///
/// The constraint is applied only on transition rows, and is tagged with `STACK_OVERFLOW_BUS_ID`
/// and `STACK_OVERFLOW_BUS_NAME`.
///
/// # Parameters
/// - `builder`: AIR builder that receives the constraint.
/// - `local` / `next`: main-trace values of the current and the next row.
/// - `op_flags`: operation-flag expressions (right shift, left shift, dyncall).
/// - `challenges`: random challenges used to compress tuples via `Challenges::encode`.
pub fn enforce_bus<AB>(
    builder: &mut AB,
    local: &MainTraceRow<AB::Var>,
    next: &MainTraceRow<AB::Var>,
    op_flags: &OpFlags<AB::Expr>,
    challenges: &Challenges<AB::ExprEF>,
) where
    AB: LiftedAirBuilder,
{
    // Copy the running-product values out so the borrow of `builder` ends with this block.
    let (p1_local, p1_next) = {
        let aux = builder.permutation();
        let aux_local = aux.current_slice();
        let aux_next = aux.next_slice();
        (aux_local[P1_STACK], aux_next[P1_STACK])
    };

    // Current-row main-trace values.
    let clk: AB::Expr = local.clk.clone().into();
    let s15: AB::Expr = local.stack[15].clone().into();
    let b0: AB::Expr = local.stack[B0_COL_IDX].clone().into();
    let b1: AB::Expr = local.stack[B1_COL_IDX].clone().into();
    let h0: AB::Expr = local.stack[H0_COL_IDX].clone().into();
    let hasher_state_5: AB::Expr = local.decoder[HASHER_STATE_RANGE.start + 5].clone().into();

    // Next-row main-trace values.
    let s15_next: AB::Expr = next.stack[15].clone().into();
    let b1_next: AB::Expr = next.stack[B1_COL_IDX].clone().into();

    // `(b0 - 16) * h0`: acts as the "overflow is non-empty" selector, provided `h0` is the
    // inverse of `b0 - 16` (assumed to be enforced elsewhere).
    let is_non_empty_overflow: AB::Expr = (b0 - AB::Expr::from_u16(16)) * h0;

    let right_shift = op_flags.right_shift();
    let left_flag: AB::Expr = op_flags.left_shift() * is_non_empty_overflow.clone();
    let dyncall_flag: AB::Expr = op_flags.dyncall() * is_non_empty_overflow;

    // Response side: the encoded (clk, s15, b1) tuple on a right shift, 1 otherwise.
    let response_row = challenges.encode([clk, s15, b1.clone()]);
    let response: AB::ExprEF =
        response_row * right_shift.clone() + (AB::ExprEF::ONE - right_shift);

    // Request side: one of the two encoded tuples when its flag is active, 1 otherwise.
    // The sum of flags must be taken before the flags are moved into the products below.
    let request_row_left = challenges.encode([b1.clone(), s15_next.clone(), b1_next]);
    let request_row_dyncall = challenges.encode([b1, s15_next, hasher_state_5]);
    let request_flag_sum: AB::Expr = left_flag.clone() + dyncall_flag.clone();
    let request: AB::ExprEF = request_row_left * left_flag
        + request_row_dyncall * dyncall_flag
        + (AB::ExprEF::ONE - request_flag_sum);

    // p1' * request - p1 * response == 0 on every transition.
    let lhs = p1_next.into() * request;
    let rhs = p1_local.into() * response;
    builder.tagged(STACK_OVERFLOW_BUS_ID, STACK_OVERFLOW_BUS_NAME, |builder| {
        builder.when_transition().assert_zero_ext(lhs - rhs);
    });
}