use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{ExtensionBuilder, LiftedAirBuilder, WindowAccess};
use crate::{
MainTraceRow,
constraints::tagging::{TaggingAirBuilderExt, ids::TAG_RANGE_BUS_BASE},
trace::{
CHIPLET_S0_COL_IDX, CHIPLET_S1_COL_IDX, CHIPLET_S2_COL_IDX, CHIPLETS_OFFSET,
RANGE_CHECK_TRACE_OFFSET, chiplets, decoder, range,
},
};
/// Index (within `MainTraceRow::decoder`) of the first of the four consecutive decoder helper
/// columns whose values are sent to the range-checker bus as stack lookups.
const STACK_LOOKUP_BASE: usize = decoder::USER_OP_HELPERS_OFFSET;
/// Decoder column holding operation bit 4.
const OP_BIT_4_COL_IDX: usize = decoder::OP_BITS_RANGE.start + 4;
/// Decoder column holding operation bit 5.
const OP_BIT_5_COL_IDX: usize = decoder::OP_BITS_RANGE.start + 5;
/// Decoder column holding operation bit 6.
const OP_BIT_6_COL_IDX: usize = decoder::OP_BITS_RANGE.start + 6;
// The chiplet and range column constants below are absolute main-trace indices. Subtracting the
// segment offset turns them into indices into the per-segment slices `local.chiplets` /
// `local.range`.
/// Chiplet selector `s0`, relative to the chiplets segment.
const CHIPLET_S0_IDX: usize = CHIPLET_S0_COL_IDX - CHIPLETS_OFFSET;
/// Chiplet selector `s1`, relative to the chiplets segment.
const CHIPLET_S1_IDX: usize = CHIPLET_S1_COL_IDX - CHIPLETS_OFFSET;
/// Chiplet selector `s2`, relative to the chiplets segment.
const CHIPLET_S2_IDX: usize = CHIPLET_S2_COL_IDX - CHIPLETS_OFFSET;
/// Memory chiplet column `d0` (a value looked up on the range bus), relative to the chiplets segment.
const MEMORY_D0_IDX: usize = chiplets::MEMORY_D0_COL_IDX - CHIPLETS_OFFSET;
/// Memory chiplet column `d1` (a value looked up on the range bus), relative to the chiplets segment.
const MEMORY_D1_IDX: usize = chiplets::MEMORY_D1_COL_IDX - CHIPLETS_OFFSET;
/// Range checker multiplicity column `m`, relative to the range-check segment.
const RANGE_M_COL_IDX: usize = range::M_COL_IDX - RANGE_CHECK_TRACE_OFFSET;
/// Range checker value column `v`, relative to the range-check segment.
const RANGE_V_COL_IDX: usize = range::V_COL_IDX - RANGE_CHECK_TRACE_OFFSET;
/// Human-readable name attached (together with `TAG_RANGE_BUS_BASE`) to the bus constraint.
const RANGE_BUS_NAME: &str = "range.bus.transition";
/// Enforces the transition constraint of the range checker's LogUp bus column `b_range`.
///
/// Every value `x` that touches the bus enters through the shifted denominator `alpha + x`,
/// where `alpha` is the first permutation challenge. Between two consecutive rows the running
/// column must satisfy
///
/// ```text
/// b' = b + m / (alpha + v)
///        - f_u32 * sum_{i=0..4} 1 / (alpha + h_i)
///        - f_mem * sum_{i=0..2} 1 / (alpha + d_i)
/// ```
///
/// where
/// - `v`, `m` are the range checker's value and multiplicity columns,
/// - `h_0..h_3` are the four decoder helper columns starting at `STACK_LOOKUP_BASE`,
/// - `d_0`, `d_1` are the memory chiplet's `d0`/`d1` columns,
/// - `f_u32 = op6 * (1 - op5) * (1 - op4)` is the decoder operation flag gating the stack
///   requests,
/// - `f_mem = s0 * s1 * (1 - s2)` is the chiplet selector combination gating the memory requests.
///
/// Writing `R = alpha + v`, `S = prod(alpha + h_i)` and `M = prod(alpha + d_i)`, multiplying
/// through by `R * S * M` gives the polynomial identity that is actually asserted (degree 9):
///
/// ```text
/// (b' - b) * R * S * M
///     - m * S * M
///     + f_u32 * R * M * sum_i prod_{j != i} (alpha + h_j)
///     + f_mem * R * S * ((alpha + d_0) + (alpha + d_1)) = 0
/// ```
///
/// The constraint is applied on transition rows only and is registered under
/// `TAG_RANGE_BUS_BASE` / `RANGE_BUS_NAME`.
pub fn enforce_bus<AB>(builder: &mut AB, local: &MainTraceRow<AB::Var>)
where
    AB: LiftedAirBuilder,
{
    // Running bus column on the current and next row, lifted to extension expressions.
    let aux = builder.permutation();
    let b_local: AB::ExprEF = aux.current_slice()[range::B_RANGE_COL_IDX].into();
    let b_next: AB::ExprEF = aux.next_slice()[range::B_RANGE_COL_IDX].into();

    // The single challenge used to shift every bus value.
    let alpha: AB::ExprEF = builder.permutation_randomness()[0].into();

    // Small accessors so each denominator / flag reads as one call.
    let shifted = |value: AB::Expr| -> AB::ExprEF { alpha.clone() + value };
    let decoder_bit = |idx: usize| -> AB::Expr { local.decoder[idx].clone().into() };
    let chiplet_sel = |idx: usize| -> AB::Expr { local.chiplets[idx].clone().into() };

    // Range checker side: the response value and how many times it is consumed.
    let range_den = shifted(local.range[RANGE_V_COL_IDX].clone().into());
    let multiplicity: AB::Expr = local.range[RANGE_M_COL_IDX].clone().into();

    // Stack-side requests: four consecutive decoder helper columns.
    let h0 = shifted(local.decoder[STACK_LOOKUP_BASE].clone().into());
    let h1 = shifted(local.decoder[STACK_LOOKUP_BASE + 1].clone().into());
    let h2 = shifted(local.decoder[STACK_LOOKUP_BASE + 2].clone().into());
    let h3 = shifted(local.decoder[STACK_LOOKUP_BASE + 3].clone().into());

    // Memory-side requests: the memory chiplet's d0 / d1 columns.
    let d0 = shifted(local.chiplets[MEMORY_D0_IDX].clone().into());
    let d1 = shifted(local.chiplets[MEMORY_D1_IDX].clone().into());

    // Stack denominator S and the sum of its leave-one-out products:
    // h1h2h3 + h0h2h3 + h0h1h3 + h0h1h2 == h0h1 * (h2 + h3) + h2h3 * (h0 + h1).
    let h01 = h0.clone() * h1.clone();
    let h23 = h2.clone() * h3.clone();
    let stack_den = h01.clone() * h23.clone();
    let stack_cofactors = h01 * (h2 + h3) + h23 * (h0 + h1);

    // Memory denominator M; with two factors the leave-one-out sum is simply d1 + d0.
    let mem_den = d0.clone() * d1.clone();
    let mem_cofactors = d0 + d1;

    // Request gates: op bits (6, 5, 4) == (1, 0, 0) and chiplet selectors (1, 1, 0).
    let f_u32 = decoder_bit(OP_BIT_6_COL_IDX)
        * (AB::Expr::ONE - decoder_bit(OP_BIT_5_COL_IDX))
        * (AB::Expr::ONE - decoder_bit(OP_BIT_4_COL_IDX));
    let f_mem = chiplet_sel(CHIPLET_S0_IDX)
        * chiplet_sel(CHIPLET_S1_IDX)
        * (AB::Expr::ONE - chiplet_sel(CHIPLET_S2_IDX));

    // Cleared-denominator LogUp relation (see the doc comment for the derivation).
    let common_den = range_den.clone() * stack_den.clone() * mem_den.clone();
    let constraint = (b_next - b_local) * common_den
        - stack_den.clone() * mem_den.clone() * multiplicity
        + range_den.clone() * mem_den * f_u32 * stack_cofactors
        + range_den * stack_den * f_mem * mem_cofactors;

    builder.tagged(TAG_RANGE_BUS_BASE, RANGE_BUS_NAME, |builder| {
        builder.when_transition().assert_zero_ext(constraint);
    });
}