use bb_ir::keys::{
binding_key, encode_binding_value, stamp_model_metadata, COMPILED_CURRENT_VERSION,
COMPILED_KEY, RECV_SLOT_ID_KEY, REQUIRED_TRAIT_KEY, SLOT_ID_KEY,
};
use bb_ir::proto::onnx::{ModelProto, NodeProto, StringStringEntryProto};
use crate::artifact::BindingSpec;
/// Stamps the compilation metadata for `target_name` onto `model`.
///
/// Steps, in order:
/// 1. Writes `COMPILED_KEY = COMPILED_CURRENT_VERSION` as model-level metadata.
/// 2. For every slot in `bindings`, writes the model-level entry
///    `binding_key(target_name, slot.slot)`. Its value is
///    `encode_binding_value(canonical role, concrete type name, slot id)`.
///    The slot id comes from the first node (across all model functions)
///    that declares the same canonical role. It is `-1` when no such node exists.
/// 3. Tags wire `Recv` nodes with their consumer's slot id (see
///    `stamp_recv_slot_ids`).
pub(crate) fn stamp_compilation_metadata(
    model: &mut ModelProto,
    bindings: &BindingSpec,
    target_name: &str,
) {
    stamp_model_metadata(model, COMPILED_KEY, COMPILED_CURRENT_VERSION);

    // Canonicalise each discovered role exactly once instead of once per
    // (slot, candidate) comparison. Collection order is preserved, so the
    // "first matching role wins" rule below is unchanged.
    let canonical_slot_ids: Vec<(String, u32)> = collect_role_slot_ids(model)
        .into_iter()
        .map(|(role, id)| (canonical_role(&role), id))
        .collect();

    for slot in &bindings.slots {
        let role_canon = canonical_role(&slot.role);
        // `-1` is the sentinel the binding encoding uses for "no node in the
        // model declares this role". `u32 -> i64` is lossless, so use `From`
        // rather than an `as` cast.
        let slot_id_or_neg1 = canonical_slot_ids
            .iter()
            .find(|(role, _)| *role == role_canon)
            .map_or(-1, |(_, id)| i64::from(*id));
        let key = binding_key(target_name, &slot.slot);
        let value = encode_binding_value(&role_canon, &slot.concrete_type_name, slot_id_or_neg1);
        stamp_model_metadata(model, &key, &value);
    }

    stamp_recv_slot_ids(model);
}
/// Tags every wire `Recv` node in each model function with `RECV_SLOT_ID_KEY`.
///
/// For each `Recv` node, the value is the slot id of the first non-`Recv` node
/// (in node order) that has a `SLOT_ID_KEY` parseable as `u32` and lists the
/// `Recv`'s first output among its inputs. Matching is scoped to a single
/// function. A `Recv` node is left untouched if its first output is missing
/// or empty, or if no such consumer exists.
fn stamp_recv_slot_ids(model: &mut ModelProto) {
    for function in model.functions.iter_mut() {
        // Pass 1 (read-only): record receivers by index and slot-tagged
        // consumers by input name. Both lists are in node order.
        let mut receivers: Vec<(usize, String)> = Vec::new();
        let mut consumers: Vec<(String, u32)> = Vec::new();
        for (index, node) in function.node.iter().enumerate() {
            if is_wire_recv(node) {
                match node.output.first() {
                    Some(payload) if !payload.is_empty() => {
                        receivers.push((index, payload.clone()));
                    }
                    _ => {}
                }
                // Recv nodes are never treated as consumers.
                continue;
            }
            let slot = match metadata_value(node, SLOT_ID_KEY).map(str::parse::<u32>) {
                Some(Ok(id)) => id,
                _ => continue,
            };
            consumers.extend(
                node.input
                    .iter()
                    .filter(|name| !name.is_empty())
                    .map(|name| (name.clone(), slot)),
            );
        }

        // Pass 2 (mutating): the node list is no longer borrowed, so each
        // receiver can be updated in place by index.
        for (index, payload) in receivers {
            let matched = consumers
                .iter()
                .find_map(|(name, id)| (*name == payload).then_some(*id));
            if let Some(id) = matched {
                stamp_node_metadata(&mut function.node[index], RECV_SLOT_ID_KEY, &id.to_string());
            }
        }
    }
}
/// Returns `true` when `node` is the `Recv` operator of the
/// `ai.bytesandbrains.wire` domain.
fn is_wire_recv(node: &NodeProto) -> bool {
    matches!(
        (node.domain.as_str(), node.op_type.as_str()),
        ("ai.bytesandbrains.wire", "Recv")
    )
}
/// Upserts a node-level metadata entry.
///
/// If `key` already exists, the value of its first occurrence is overwritten.
/// Otherwise a new `key`/`value` entry is appended.
fn stamp_node_metadata(node: &mut NodeProto, key: &str, value: &str) {
    let props = &mut node.metadata_props;
    if let Some(pos) = props.iter().position(|entry| entry.key == key) {
        props[pos].value = value.to_string();
    } else {
        props.push(StringStringEntryProto {
            key: key.to_string(),
            value: value.to_string(),
        });
    }
}
/// Gathers the distinct `(required trait, slot id)` pairs declared on nodes of
/// every model function, in first-seen order.
///
/// A node contributes only if it has a `REQUIRED_TRAIT_KEY` entry and a
/// `SLOT_ID_KEY` entry that parses as `u32`. Roles are returned as written;
/// they are not canonicalised here.
fn collect_role_slot_ids(model: &ModelProto) -> Vec<(String, u32)> {
    let tagged = model
        .functions
        .iter()
        .flat_map(|function| function.node.iter())
        .filter_map(|node| {
            let role = metadata_value(node, REQUIRED_TRAIT_KEY)?;
            let slot_id = metadata_value(node, SLOT_ID_KEY)?.parse::<u32>().ok()?;
            Some((role, slot_id))
        });

    let mut pairs: Vec<(String, u32)> = Vec::new();
    for (role, slot_id) in tagged {
        // Linear dedup keeps first-seen order; callers rely on it for
        // "first match wins" lookups.
        let seen = pairs.iter().any(|(r, id)| r == role && *id == slot_id);
        if !seen {
            pairs.push((role.to_owned(), slot_id));
        }
    }
    pairs
}
/// Looks up the value of the first node metadata entry whose key equals `key`.
/// Returns `None` if no entry matches.
fn metadata_value<'a>(node: &'a NodeProto, key: &str) -> Option<&'a str> {
    node.metadata_props
        .iter()
        .find_map(|entry| (entry.key == key).then(|| entry.value.as_str()))
}
/// Normalises a role name by removing a single trailing `"Runtime"` suffix.
///
/// For example, `"FooRuntime"` becomes `"Foo"`, while `"Foo"` is returned
/// unchanged. Note that the bare `"Runtime"` becomes the empty string.
fn canonical_role(role: &str) -> String {
    match role.strip_suffix("Runtime") {
        Some(base) => base.to_owned(),
        None => role.to_owned(),
    }
}
/// Test helper: builds a `BindingSpec` from `(slot, role, concrete type)`
/// triples and runs `stamp_compilation_metadata` on `model`.
///
/// The binding target is the name of the model's first function, or the
/// empty string when the model has no functions.
pub fn stamp_for_test(model: &mut ModelProto, bindings: &[(&str, &str, &str)]) {
    let mut spec = BindingSpec::new();
    for &(slot, role, type_name) in bindings {
        spec.push(slot.to_string(), role.to_string(), type_name.to_string());
    }
    let target = match model.functions.first() {
        Some(first) => first.name.clone(),
        None => String::new(),
    };
    stamp_compilation_metadata(model, &spec, &target);
}