use miden_core::field::PrimeCharacteristicRing;
use miden_crypto::stark::air::{LiftedAirBuilder, WindowAccess};
use crate::{
Felt, MainTraceRow,
constraints::{
bus::indices::V_WIRING,
chiplets::selectors::ace_chiplet_flag,
tagging::{
TagGroup, TaggingAirBuilderExt, ids::TAG_WIRING_BUS_BASE, tagged_assert_zero_ext,
},
},
trace::{
Challenges,
chiplets::ace::{
CLK_IDX, CTX_IDX, ID_0_IDX, ID_1_IDX, ID_2_IDX, M_0_IDX, M_1_IDX, SELECTOR_BLOCK_IDX,
V_0_0_IDX, V_0_1_IDX, V_1_0_IDX, V_1_1_IDX, V_2_0_IDX, V_2_1_IDX,
},
},
};
/// Offset added to every ACE-local column index to address it inside `MainTraceRow::chiplets`
/// (see `load_ace_col`). Columns `chiplets[0..4]` are read as the chiplet selectors `s0..s3`
/// in this file.
const ACE_OFFSET: usize = 4;

/// Base tag id for the wiring-bus constraint group.
const WIRING_BUS_BASE_ID: usize = TAG_WIRING_BUS_BASE;

/// Tag name attached to the wiring-bus transition constraint.
const WIRING_BUS_NAME: &str = "chiplets.bus.wiring.transition";

/// Names of all constraints in the wiring-bus tag group (exactly one).
const WIRING_BUS_NAMES: [&str; 1] = [WIRING_BUS_NAME];

/// Tag group under which `enforce_wiring_bus_constraint` emits its constraint.
const WIRING_BUS_TAGS: TagGroup = TagGroup {
    names: &WIRING_BUS_NAMES,
    base: WIRING_BUS_BASE_ID,
};
pub fn enforce_wiring_bus_constraint<AB>(
builder: &mut AB,
local: &MainTraceRow<AB::Var>,
_next: &MainTraceRow<AB::Var>,
challenges: &Challenges<AB::ExprEF>,
) where
AB: TaggingAirBuilderExt<F = Felt>,
{
let (v_local, v_next) = {
let aux = builder.permutation();
let aux_local = aux.current_slice();
let aux_next = aux.next_slice();
(aux_local[V_WIRING], aux_next[V_WIRING])
};
let s0: AB::Expr = local.chiplets[0].clone().into();
let s1: AB::Expr = local.chiplets[1].clone().into();
let s2: AB::Expr = local.chiplets[2].clone().into();
let s3: AB::Expr = local.chiplets[3].clone().into();
let ace_flag = ace_chiplet_flag(s0, s1, s2, s3);
let sblock: AB::Expr = load_ace_col::<AB>(local, SELECTOR_BLOCK_IDX);
let is_eval = sblock.clone();
let is_read = AB::Expr::ONE - sblock;
let clk: AB::Expr = load_ace_col::<AB>(local, CLK_IDX);
let ctx: AB::Expr = load_ace_col::<AB>(local, CTX_IDX);
let wire_0 = load_ace_wire::<AB>(local, ID_0_IDX, V_0_0_IDX, V_0_1_IDX);
let wire_1 = load_ace_wire::<AB>(local, ID_1_IDX, V_1_0_IDX, V_1_1_IDX);
let wire_2 = load_ace_wire::<AB>(local, ID_2_IDX, V_2_0_IDX, V_2_1_IDX);
let m0: AB::Expr = load_ace_col::<AB>(local, M_0_IDX);
let m1: AB::Expr = load_ace_col::<AB>(local, M_1_IDX);
let wire_0: AB::ExprEF = encode_wire::<AB>(challenges, &clk, &ctx, &wire_0);
let wire_1: AB::ExprEF = encode_wire::<AB>(challenges, &clk, &ctx, &wire_1);
let wire_2: AB::ExprEF = encode_wire::<AB>(challenges, &clk, &ctx, &wire_2);
let v_local_ef: AB::ExprEF = v_local.into();
let v_next_ef: AB::ExprEF = v_next.into();
let delta = v_next_ef.clone() - v_local_ef.clone();
let read_terms =
wire_1.clone() * wire_2.clone() * m0.clone() + wire_0.clone() * wire_2.clone() * m1;
let eval_terms = wire_1.clone() * wire_2.clone() * m0
- wire_0.clone() * wire_2.clone()
- wire_0.clone() * wire_1.clone();
let read_gate = ace_flag.clone() * is_read;
let eval_gate = ace_flag * is_eval;
let common_den = wire_0.clone() * wire_1.clone() * wire_2.clone();
let rhs = read_terms * read_gate + eval_terms * eval_gate;
let wiring_constraint = delta * common_den - rhs;
let mut idx = 0;
tagged_assert_zero_ext(builder, &WIRING_BUS_TAGS, &mut idx, wiring_constraint);
}
/// The three per-row trace values that describe one ACE wire.
///
/// In this file they are loaded from a matching `ID_*` / `V_*_0` / `V_*_1` column triple by
/// `load_ace_wire`.
struct AceWire<Expr> {
    /// Wire identifier (`ID_*` column).
    id: Expr,
    /// First value component (`V_*_0` column).
    v0: Expr,
    /// Second value component (`V_*_1` column).
    v1: Expr,
}
/// Reads one wire's `id`, `v0` and `v1` columns from `row`.
///
/// All three indices are ACE-local and are translated through `load_ace_col`.
fn load_ace_wire<AB>(
    row: &MainTraceRow<AB::Var>,
    id_idx: usize,
    v0_idx: usize,
    v1_idx: usize,
) -> AceWire<AB::Expr>
where
    AB: LiftedAirBuilder<F = Felt>,
{
    let column = |idx: usize| load_ace_col::<AB>(row, idx);
    let id = column(id_idx);
    let v0 = column(v0_idx);
    let v1 = column(v1_idx);
    AceWire { id, v0, v1 }
}
fn encode_wire<AB>(
challenges: &Challenges<AB::ExprEF>,
clk: &AB::Expr,
ctx: &AB::Expr,
wire: &AceWire<AB::Expr>,
) -> AB::ExprEF
where
AB: LiftedAirBuilder<F = Felt>,
{
challenges.encode([clk.clone(), ctx.clone(), wire.id.clone(), wire.v0.clone(), wire.v1.clone()])
}
/// Loads ACE-local column `ace_col_idx` from `row` as a base-field expression.
///
/// The index is shifted by `ACE_OFFSET` before indexing `row.chiplets`, skipping the leading
/// chiplet selector columns.
fn load_ace_col<AB>(row: &MainTraceRow<AB::Var>, ace_col_idx: usize) -> AB::Expr
where
    AB: LiftedAirBuilder<F = Felt>,
{
    let column = &row.chiplets[ace_col_idx + ACE_OFFSET];
    column.clone().into()
}