use std::collections::HashMap;
use crate::engine::dispatch_entry::OpDispatch;
use crate::ids::{NodeSiteId, OpRef};
use bb_ir::proto::onnx::FunctionProto;
/// Per-graph execution metadata derived from an ONNX-style `FunctionProto`.
///
/// Built by `GraphSlot::from_function`, or by `GraphSlot::new_for_test`, which
/// leaves every derived table empty. In `from_function`, each distinct non-empty
/// node output name is assigned a `NodeSiteId` allocated from a caller-supplied
/// counter; the maps below are keyed by those site ids.
pub struct GraphSlot {
/// The function this slot was built from.
pub function: FunctionProto,
/// One dispatch entry per node in `function.node`, in node order;
/// `from_function` initialises each entry to `OpDispatch::Unresolved`.
pub op_dispatch: Vec<OpDispatch>,
/// For each site, the nodes (as `OpRef`s) that list it as an input. A node
/// that reads the same site more than once appears once per read. Inputs not
/// produced by any node of `function` have no entry.
pub consumers: HashMap<NodeSiteId, Vec<OpRef>>,
/// Node output name -> the site allocated for it (first occurrence wins).
pub site_names: HashMap<String, NodeSiteId>,
/// Sites whose name is also listed in `function.output`, mapped back to that name.
pub top_level_outputs: HashMap<NodeSiteId, String>,
/// For `Recv` nodes in the `ai.bytesandbrains.wire` domain: payload site
/// (output 0) -> sender site (output 1).
pub recv_sender_sites: HashMap<NodeSiteId, NodeSiteId>,
/// Left empty by both constructors in this file.
/// NOTE(review): presumably keyed by `Recv` payload site and populated
/// elsewhere with a wire-type hash — confirm against callers.
pub recv_wire_type_hash: HashMap<NodeSiteId, u64>,
/// `Recv` payload site -> slot id parsed as `u32` from the node's
/// `bb_ir::keys::RECV_SLOT_ID_KEY` metadata entry (unparsable values are skipped).
pub recv_site_to_slot_id: HashMap<NodeSiteId, u32>,
/// `true` for slots built by `from_function`, `false` for `new_for_test`.
/// NOTE(review): the scheduling meaning of "entry point" is not visible here.
pub is_entry_point: bool,
}
impl GraphSlot {
#[cfg(any(test, feature = "test-components"))]
pub fn new_for_test(_name: String, function: FunctionProto) -> Self {
Self {
function,
op_dispatch: Vec::new(),
consumers: HashMap::new(),
site_names: HashMap::new(),
top_level_outputs: HashMap::new(),
recv_sender_sites: HashMap::new(),
recv_wire_type_hash: HashMap::new(),
recv_site_to_slot_id: HashMap::new(),
is_entry_point: false,
}
}
pub fn from_function(
_name: String,
function: FunctionProto,
graph_idx: u32,
next_node_site_id: &mut u64,
) -> Self {
let mut site_names: HashMap<String, NodeSiteId> = HashMap::new();
let mut consumers: HashMap<NodeSiteId, Vec<OpRef>> = HashMap::new();
let mut op_refs: Vec<OpRef> = Vec::with_capacity(function.node.len());
let mut op_dispatch: Vec<OpDispatch> = Vec::with_capacity(function.node.len());
for (idx, node) in function.node.iter().enumerate() {
let op_ref = OpRef::pack(graph_idx, idx as u32);
op_refs.push(op_ref);
op_dispatch.push(OpDispatch::Unresolved);
for out in &node.output {
if out.is_empty() {
continue;
}
site_names.entry(out.clone()).or_insert_with(|| {
let r = NodeSiteId::from(*next_node_site_id);
*next_node_site_id = next_node_site_id.saturating_add(1);
r
});
}
}
for (idx, node) in function.node.iter().enumerate() {
let consumer = op_refs[idx];
for input in &node.input {
if input.is_empty() {
continue;
}
let Some(&site) = site_names.get(input) else {
continue;
};
consumers.entry(site).or_default().push(consumer);
}
}
let mut top_level_outputs: HashMap<NodeSiteId, String> = HashMap::new();
for name in &function.output {
if let Some(&site) = site_names.get(name) {
top_level_outputs.insert(site, name.clone());
}
}
let mut recv_sender_sites: HashMap<NodeSiteId, NodeSiteId> = HashMap::new();
let mut recv_site_to_slot_id: HashMap<NodeSiteId, u32> = HashMap::new();
for node in &function.node {
if node.domain != "ai.bytesandbrains.wire" || node.op_type != "Recv" {
continue;
}
let payload = node.output.first().and_then(|n| site_names.get(n));
let sender = node.output.get(1).and_then(|n| site_names.get(n));
if let (Some(&p), Some(&s)) = (payload, sender) {
recv_sender_sites.insert(p, s);
}
if let Some(&payload_site) = payload {
let slot_id = node
.metadata_props
.iter()
.find(|kv| kv.key == bb_ir::keys::RECV_SLOT_ID_KEY)
.and_then(|kv| kv.value.parse::<u32>().ok());
if let Some(slot_id) = slot_id {
recv_site_to_slot_id.insert(payload_site, slot_id);
}
}
}
let recv_wire_type_hash: HashMap<NodeSiteId, u64> = HashMap::new();
let _ = op_refs;
Self {
function,
op_dispatch,
consumers,
site_names,
top_level_outputs,
recv_sender_sites,
recv_wire_type_hash,
recv_site_to_slot_id,
is_entry_point: true,
}
}
}