use core::array;
use miden_core::field::PrimeCharacteristicRing;
use crate::{
constraints::{
lookup::{
chiplet_air::{ChipletBusContext, ChipletLookupBuilder},
messages::{MemoryMsg, RangeMsg, SiblingBit, SiblingMsg},
},
utils::BoolNot,
},
lookup::{Deg, LookupBatch, LookupColumn, LookupGroup},
trace::{
CHIPLETS_OFFSET,
chiplets::{
MEMORY_WORD_ADDR_HI_COL_IDX, MEMORY_WORD_ADDR_LO_COL_IDX,
ace::{ACE_INSTRUCTION_ID1_OFFSET, ACE_INSTRUCTION_ID2_OFFSET},
},
},
};
/// Maximum number of lookup interactions per row, as the constant's name states; its consumers
/// live outside this file chunk.
///
/// NOTE(review): in `emit_hash_kernel_table` the largest per-row count appears to be the five
/// range-check removals of the `memory_range_checks` batch (`mem_d0`, `mem_d1`, `mem_w0`,
/// `mem_w1`, `mem_w1_mul4`). The sibling and ACE interactions each contribute at most one per
/// row, assuming the controller, ACE and memory chiplet flags are never active on the same row.
/// Confirm against the users of this constant before changing either side.
pub(in crate::constraints::lookup) const MAX_INTERACTIONS_PER_ROW: usize = 5;
/// Declares one lookup column (via `builder.next_column`) containing the `"sibling_ace_memory"`
/// group.
///
/// The group combines three families of interactions. Each family is gated by a different
/// flag from `ctx.chiplet_active`:
///
/// - **Sibling table** (`controller` flag):
///   - `SiblingMsg`s are *added* on rows selected by `f_mv_all = controller * s0 * s1 * (1 - s2)`.
///   - They are *removed* on rows selected by `f_mu_all = controller * s0 * s1 * s2`.
///   - Each message carries `mrupdate_id`, `node_index`, and a 4-element word. The word is
///     `state[4..8]` when the path bit is 0 and `state[0..4]` when it is 1.
/// - **ACE memory reads** (`ace` flag): both messages are *removed*.
///   - A `MemoryMsg::read_word` when `s_block` is 0.
///   - A `MemoryMsg::read_element` carrying an encoded instruction when `s_block` is 1.
/// - **Memory range checks** (`memory` flag): a single batch that *removes* `RangeMsg`s for
///   `d0`, `d1`, the two word-address columns `w0`/`w1`, and `4 * w1`.
///
/// NOTE(review): despite the function name, no kernel-table interaction is emitted in the
/// visible code.
///
/// The `Deg { v, u }` values are passed through to the builder unchanged.
/// NOTE(review): their precise meaning (presumably degree bounds for the numerator and
/// denominator contributions) is defined by the `lookup` builder traits, not visible here.
///
/// # Parameters
/// - `builder`: lookup builder that receives the new column.
/// - `ctx`: chiplet bus context providing the `local`/`next` rows and the per-chiplet activity
///   flags (`controller`, `ace`, `memory`).
// The whole column is declared as one nested builder-closure chain, so the function is long by
// construction rather than by accumulating unrelated logic.
#[allow(clippy::too_many_lines)]
pub(in crate::constraints::lookup) fn emit_hash_kernel_table<LB>(
    builder: &mut LB,
    ctx: &ChipletBusContext<LB>,
) where
    LB: ChipletLookupBuilder,
{
    // Current and next rows, plus their hasher-controller views.
    let local = ctx.local;
    let next = ctx.next;
    let ctrl = local.controller();
    let ctrl_next = next.controller();

    // Controller selector columns.
    let hs0: LB::Expr = ctrl.s0.into();
    let hs1: LB::Expr = ctrl.s1.into();
    let hs2: LB::Expr = ctrl.s2.into();

    // Sibling-table row flags. Both require the controller flag and s0 = s1 = 1, and they
    // differ only in s2 (MU: s2, MV: 1 - s2), so at most one is nonzero when s2 is binary.
    let controller_flag = ctx.chiplet_active.controller.clone();
    let f_mu_all: LB::Expr = controller_flag.clone() * hs0.clone() * hs1.clone() * hs2.clone();
    let f_mv_all: LB::Expr = controller_flag * hs0 * hs1 * hs2.not();

    // The two 4-element halves of the controller state: `state[0..4]` and `state[4..8]`.
    let rate_0: [LB::Var; 4] = array::from_fn(|i| ctrl.state[i]);
    let rate_1: [LB::Var; 4] = array::from_fn(|i| ctrl.state[4 + i]);
    let mrupdate_id = ctrl.mrupdate_id;
    let node_index = ctrl.node_index;

    // Path bit: `node_index - 2 * node_index_next`. This equals the low bit of `node_index`
    // when the next row holds `node_index >> 1`.
    // NOTE(review): that relation, and the booleanity of `bit`, are presumably enforced by the
    // hasher constraints, which are not visible here.
    let node_index_next: LB::Expr = ctrl_next.node_index.into();
    let bit: LB::Expr = node_index.into() - node_index_next.double();
    let one_minus_bit: LB::Expr = bit.not();

    // ACE chiplet: `s_block` selects between the word-read rows (0) and the eval rows (1), so
    // the two flags below are mutually exclusive when `s_block` is binary.
    let ace = local.ace();
    let block_sel: LB::Expr = ace.s_block.into();
    let is_ace_row = ctx.chiplet_active.ace.clone();
    let f_ace_read: LB::Expr = is_ace_row.clone() * block_sel.not();
    let f_ace_eval: LB::Expr = is_ace_row * block_sel;
    let ace_clk = ace.clk;
    let ace_ctx = ace.ctx;
    let ace_ptr = ace.ptr;
    let ace_v0 = ace.v_0;
    let ace_v1 = ace.v_1;
    let ace_id_1 = ace.id_1;
    let ace_id_2 = ace.eval().id_2;
    let ace_eval_op = ace.eval_op;

    // Memory chiplet. The word-address columns are read by global column index. The
    // subtraction reflects that `local.chiplets` is indexed relative to `CHIPLETS_OFFSET`.
    let mem_active = ctx.chiplet_active.memory.clone();
    let mem = local.memory();
    let mem_d0 = mem.d0;
    let mem_d1 = mem.d1;
    let mem_w0 = local.chiplets[MEMORY_WORD_ADDR_LO_COL_IDX - CHIPLETS_OFFSET];
    let mem_w1 = local.chiplets[MEMORY_WORD_ADDR_HI_COL_IDX - CHIPLETS_OFFSET];

    builder.next_column(
        |col| {
            col.group(
                "sibling_ace_memory",
                |g| {
                    // Sibling table: one entry per (operation, path bit) pair. Each gate is the
                    // row flag times `1 - bit` or `bit`, so (assuming `bit` is binary) at most
                    // one of the four entries is active on a row. MV entries are added and MU
                    // entries are removed.
                    for (op_name, is_add, f_all, bit_tag, bit_gate) in [
                        (
                            "sibling_mv_b0",
                            true,
                            f_mv_all.clone(),
                            SiblingBit::Zero,
                            one_minus_bit.clone(),
                        ),
                        ("sibling_mv_b1", true, f_mv_all, SiblingBit::One, bit.clone()),
                        (
                            "sibling_mu_b0",
                            false,
                            f_mu_all.clone(),
                            SiblingBit::Zero,
                            one_minus_bit,
                        ),
                        ("sibling_mu_b1", false, f_mu_all, SiblingBit::One, bit),
                    ] {
                        let gate = f_all * bit_gate;
                        let build = move || {
                            let mrupdate_id: LB::Expr = mrupdate_id.into();
                            let node_index: LB::Expr = node_index.into();
                            // Bit 0 sends `state[4..8]`; bit 1 sends `state[0..4]`.
                            let h = match bit_tag {
                                SiblingBit::Zero => array::from_fn(|i| rate_1[i].into()),
                                SiblingBit::One => array::from_fn(|i| rate_0[i].into()),
                            };
                            SiblingMsg { bit: bit_tag, mrupdate_id, node_index, h }
                        };
                        if is_add {
                            g.add(op_name, gate, build, Deg { v: 5, u: 6 });
                        } else {
                            g.remove(op_name, gate, build, Deg { v: 5, u: 6 });
                        }
                    }
                    // ACE word read. The word is the two components of `v_0` followed by the
                    // two components of `v_1`.
                    g.remove(
                        "ace_mem_read_word",
                        f_ace_read,
                        move || {
                            let clk = ace_clk.into();
                            let ctx = ace_ctx.into();
                            let addr = ace_ptr.into();
                            let word = [
                                ace_v0.0.into(),
                                ace_v0.1.into(),
                                ace_v1.0.into(),
                                ace_v1.1.into(),
                            ];
                            MemoryMsg::read_word(ctx, addr, clk, word)
                        },
                        Deg { v: 5, u: 6 },
                    );
                    // ACE element read. The element is the encoded instruction
                    //   id_1 + id_2 * ACE_INSTRUCTION_ID1_OFFSET
                    //        + (eval_op + 1) * ACE_INSTRUCTION_ID2_OFFSET.
                    // NOTE(review): this must match the encoding used when the instruction
                    // is written to memory; confirm against the ACE chiplet definition.
                    g.remove(
                        "ace_mem_eval_element",
                        f_ace_eval,
                        move || {
                            let clk = ace_clk.into();
                            let ctx = ace_ctx.into();
                            let addr = ace_ptr.into();
                            let id_1: LB::Expr = ace_id_1.into();
                            let id_2: LB::Expr = ace_id_2.into();
                            let eval_op: LB::Expr = ace_eval_op.into();
                            let element = id_1
                                + id_2 * LB::Expr::from(ACE_INSTRUCTION_ID1_OFFSET)
                                + (eval_op + LB::Expr::ONE)
                                    * LB::Expr::from(ACE_INSTRUCTION_ID2_OFFSET);
                            MemoryMsg::read_element(ctx, addr, clk, element)
                        },
                        Deg { v: 5, u: 6 },
                    );
                    // Memory range checks, emitted as one batch gated by the memory flag.
                    // Checking both `w1` and `4 * w1` presumably bounds `w1` more tightly than
                    // a single range check would.
                    // NOTE(review): confirm the intended bound against the range-checker width.
                    g.batch(
                        "memory_range_checks",
                        mem_active,
                        move |b| {
                            b.remove(
                                "mem_d0",
                                RangeMsg { value: mem_d0.into() },
                                Deg { v: 3, u: 4 },
                            );
                            b.remove(
                                "mem_d1",
                                RangeMsg { value: mem_d1.into() },
                                Deg { v: 3, u: 4 },
                            );
                            let w0: LB::Expr = mem_w0.into();
                            let w1: LB::Expr = mem_w1.into();
                            let w1_mul4 = w1.clone() * LB::Expr::from_u16(4);
                            b.remove("mem_w0", RangeMsg { value: w0 }, Deg { v: 3, u: 4 });
                            b.remove("mem_w1", RangeMsg { value: w1 }, Deg { v: 3, u: 4 });
                            b.remove(
                                "mem_w1_mul4",
                                RangeMsg { value: w1_mul4 },
                                Deg { v: 3, u: 4 },
                            );
                        },
                        Deg { v: 7, u: 8 },
                    );
                },
                Deg { v: 7, u: 8 },
            );
        },
        Deg { v: 7, u: 8 },
    );
}