use core::{array, borrow::Borrow};
use miden_core::field::PrimeCharacteristicRing;
use crate::{
constraints::{
chiplets::columns::PeriodicCols,
lookup::{
chiplet_air::{ChipletBusContext, ChipletLookupBuilder},
messages::{
AceInitMsg, BitwiseMsg, BusId, HasherMsg, HasherPayload, KernelRomMsg,
MemoryResponseMsg,
},
},
utils::BoolNot,
},
lookup::{Deg, LookupBatch, LookupColumn, LookupGroup},
};
/// Per-row interaction bound for the chiplet response column.
///
/// NOTE(review): the meaning is taken from the name. It matches the kernel-ROM batch in
/// `emit_chiplet_responses`, which emits two interactions (`kernel_rom_init` and
/// `kernel_rom_call`) under a single gate. Confirm how callers use this bound before changing it.
pub(in crate::constraints::lookup) const MAX_INTERACTIONS_PER_ROW: usize = 2;
#[allow(clippy::too_many_lines)]
pub(in crate::constraints::lookup) fn emit_chiplet_responses<LB>(
builder: &mut LB,
ctx: &ChipletBusContext<LB>,
) where
LB: ChipletLookupBuilder,
{
let local = ctx.local;
let next = ctx.next;
let k_transition: LB::Expr = {
let periodic: &PeriodicCols<LB::PeriodicVar> = builder.periodic_values().borrow();
periodic.bitwise.k_transition.into()
};
let ctrl = local.controller();
let ctrl_next = next.controller();
let bw = local.bitwise();
let mem = local.memory();
let ace = local.ace();
let krom = local.kernel_rom();
let hs0: LB::Expr = ctrl.s0.into();
let hs1: LB::Expr = ctrl.s1.into();
let hs2: LB::Expr = ctrl.s2.into();
let is_boundary: LB::Expr = ctrl.is_boundary.into();
let not_hs0 = hs0.not();
let not_hs1 = hs1.not();
let not_hs2 = hs2.not();
let state: [LB::Var; 12] = ctrl.state;
let rate_0: [LB::Var; 4] = array::from_fn(|i| ctrl.state[i]);
let rate_1: [LB::Var; 4] = array::from_fn(|i| ctrl.state[4 + i]);
let controller_flag = ctx.chiplet_active.controller.clone();
let f_sponge_start: LB::Expr = controller_flag.clone()
* hs0.clone()
* not_hs1.clone()
* not_hs2.clone()
* is_boundary.clone();
let f_sponge_respan: LB::Expr = controller_flag.clone()
* hs0.clone()
* not_hs1.clone()
* not_hs2.clone()
* is_boundary.not();
let f_mp: LB::Expr =
controller_flag.clone() * hs0.clone() * not_hs1.clone() * hs2.clone() * is_boundary.clone();
let f_mv: LB::Expr =
controller_flag.clone() * hs0.clone() * hs1.clone() * not_hs2.clone() * is_boundary.clone();
let f_mu: LB::Expr = controller_flag.clone() * hs0 * hs1 * hs2.clone() * is_boundary.clone();
let f_hout: LB::Expr = controller_flag.clone() * not_hs0.clone() * not_hs1.clone() * not_hs2;
let f_sout: LB::Expr = controller_flag * not_hs0 * not_hs1 * hs2 * is_boundary;
let is_bitwise_responding: LB::Expr = ctx.chiplet_active.bitwise.clone() * k_transition.not();
let is_ace_init: LB::Expr = ctx.chiplet_active.ace.clone() * ace.s_start.into();
let clk_plus_one: LB::Expr = local.system.clk.into() + LB::Expr::ONE;
let full_state = || -> [LB::Expr; 12] { state.map(Into::into) };
let full_rate = || -> [LB::Expr; 8] {
array::from_fn(|i| if i < 4 { rate_0[i].into() } else { rate_1[i - 4].into() })
};
builder.next_column(
|col| {
col.group(
"chiplet_responses",
|g| {
g.add(
"sponge_start",
f_sponge_start,
|| HasherMsg {
kind: BusId::HasherLinearHashInit,
addr: clk_plus_one.clone(),
node_index: LB::Expr::ZERO,
payload: HasherPayload::State(full_state()),
},
Deg { v: 5, u: 6 },
);
g.add(
"sponge_respan",
f_sponge_respan,
|| HasherMsg {
kind: BusId::HasherAbsorption,
addr: clk_plus_one.clone(),
node_index: LB::Expr::ZERO,
payload: HasherPayload::Rate(full_rate()),
},
Deg { v: 5, u: 6 },
);
for (name, flag, kind) in [
("mp_verify_input", f_mp, BusId::HasherMerkleVerifyInit),
("mr_update_old_input", f_mv, BusId::HasherMerkleOldInit),
("mr_update_new_input", f_mu, BusId::HasherMerkleNewInit),
] {
g.add(
name,
flag,
|| {
let addr = clk_plus_one.clone();
let node_index: LB::Expr = ctrl.node_index.into();
let bit: LB::Expr =
node_index.clone() - ctrl_next.node_index.into().double();
let one_minus_bit = bit.not();
let word: [LB::Expr; 4] = array::from_fn(|i| {
one_minus_bit.clone() * rate_0[i].into()
+ bit.clone() * rate_1[i].into()
});
HasherMsg {
kind,
addr,
node_index,
payload: HasherPayload::Word(word),
}
},
Deg { v: 5, u: 7 },
);
}
g.add(
"hout",
f_hout,
|| {
let addr = clk_plus_one.clone();
let node_index: LB::Expr = ctrl.node_index.into();
let word: [LB::Expr; 4] = rate_0.map(LB::Expr::from);
HasherMsg {
kind: BusId::HasherReturnHash,
addr,
node_index,
payload: HasherPayload::Word(word),
}
},
Deg { v: 4, u: 5 },
);
g.add(
"sout",
f_sout,
|| HasherMsg {
kind: BusId::HasherReturnState,
addr: clk_plus_one.clone(),
node_index: LB::Expr::ZERO,
payload: HasherPayload::State(full_state()),
},
Deg { v: 5, u: 6 },
);
g.add(
"bitwise",
is_bitwise_responding,
|| {
let bw_op: LB::Expr = bw.op_flag.into();
BitwiseMsg {
op: bw_op,
a: bw.a.into(),
b: bw.b.into(),
result: bw.output.into(),
}
},
Deg { v: 3, u: 4 },
);
g.add(
"memory",
ctx.chiplet_active.memory.clone(),
|| {
let mem_is_read: LB::Expr = mem.is_read.into();
let is_word: LB::Expr = mem.is_word.into();
let mem_idx0: LB::Expr = mem.idx0.into();
let mem_idx1: LB::Expr = mem.idx1.into();
let addr = mem.word_addr.into()
+ mem_idx1.clone() * LB::Expr::from_u16(2)
+ mem_idx0.clone();
let word: [LB::Expr; 4] = mem.values.map(LB::Expr::from);
let element = word[0].clone() * mem_idx0.not() * mem_idx1.not()
+ word[1].clone() * mem_idx0.clone() * mem_idx1.not()
+ word[2].clone() * mem_idx0.not() * mem_idx1.clone()
+ word[3].clone() * mem_idx0 * mem_idx1;
MemoryResponseMsg {
is_read: mem_is_read,
ctx: mem.ctx.into(),
addr,
clk: mem.clk.into(),
is_word,
element,
word,
}
},
Deg { v: 3, u: 7 },
);
g.add(
"ace_init",
is_ace_init,
|| {
let num_eval = ace.read().num_eval.into() + LB::Expr::ONE;
let num_read = ace.id_0.into() + LB::Expr::ONE - num_eval.clone();
AceInitMsg {
clk: ace.clk.into(),
ctx: ace.ctx.into(),
ptr: ace.ptr.into(),
num_read,
num_eval,
}
},
Deg { v: 5, u: 6 },
);
let kernel_gate = ctx.chiplet_active.kernel_rom.clone();
g.batch(
"kernel_rom",
kernel_gate,
|b| {
let krom_mult: LB::Expr = krom.multiplicity.into();
let digest: [LB::Expr; 4] = krom.root.map(LB::Expr::from);
b.remove(
"kernel_rom_init",
KernelRomMsg::init(digest.clone()),
Deg { v: 5, u: 6 },
);
b.insert(
"kernel_rom_call",
krom_mult,
KernelRomMsg::call(digest),
Deg { v: 6, u: 6 },
);
},
Deg { v: 7, u: 7 }, );
},
Deg { v: 7, u: 7 },
);
},
Deg { v: 7, u: 7 },
);
}