use std::collections::{BTreeMap, HashMap};
use crate::error::CompileError;
use crate::synthesize_wire_recvs::SYNTHESIZED_FROM_KEY;
use bb_ir::peer_class::{home_class_of_node, SELF_CLASS};
use bb_ir::proto::onnx::{GraphProto, NodeProto, ValueInfoProto};
/// ONNX node `domain` that identifies the wire `Send`/`Recv` ops matched by this module.
pub const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";
/// `op_type`s (within [`WIRE_DOMAIN`]) treated as the sending end of a wire edge.
pub const SEND_OP_TYPES: &[&str] = &["Send"];
/// `op_type`s (within [`WIRE_DOMAIN`]) treated as the receiving end of a wire edge.
pub const RECV_OP_TYPES: &[&str] = &["Recv"];
/// Result of [`partition_by_wire_ops`]: the per-role sub-graphs plus the wire
/// edges that connect them.
#[derive(Debug, Default)]
pub struct NetworkAnalysis {
    /// One sub-graph per role, keyed by each node's home peer class (or
    /// [`SELF_CLASS`] for nodes that report none). `BTreeMap` iterates in
    /// sorted key order, so traversal over roles is deterministic.
    pub per_role: BTreeMap<String, GraphProto>,
    /// Every wire `Recv` node that could be paired with the `Send` it was
    /// synthesized from, in the order the `Recv` nodes appear in the graph.
    pub wire_edges: Vec<WireEdge>,
}
/// A wire `Send` node paired with the `Recv` node synthesized from it,
/// as discovered by [`partition_by_wire_ops`].
#[derive(Debug, Clone)]
pub struct WireEdge {
    /// Role of the `Send` node: its home peer class, or [`SELF_CLASS`] if none.
    pub producer_role: String,
    /// Role of the `Recv` node: its home peer class, or [`SELF_CLASS`] if none.
    pub consumer_role: String,
    /// Name of the `Recv` node's first output; empty if the node has no outputs.
    pub value_name: String,
    /// Copy of the sending node.
    pub send_node: NodeProto,
    /// Copy of the receiving node.
    pub recv_node: NodeProto,
}
pub fn partition_by_wire_ops(graph: &GraphProto) -> Result<NetworkAnalysis, CompileError> {
let mut per_role: BTreeMap<String, GraphProto> = BTreeMap::new();
for node in &graph.node {
let class = home_class_of_node(node)
.map(str::to_string)
.unwrap_or_else(|| SELF_CLASS.to_string());
per_role.entry(class).or_default().node.push(node.clone());
}
let value_info_by_name: HashMap<&str, &ValueInfoProto> = graph
.value_info
.iter()
.map(|v| (v.name.as_str(), v))
.collect();
let input_by_name: HashMap<&str, &ValueInfoProto> =
graph.input.iter().map(|v| (v.name.as_str(), v)).collect();
let output_by_name: HashMap<&str, &ValueInfoProto> =
graph.output.iter().map(|v| (v.name.as_str(), v)).collect();
for sub in per_role.values_mut() {
let mut referenced: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
for node in &sub.node {
for inp in &node.input {
if !inp.is_empty() {
referenced.insert(inp.clone());
}
}
}
for name in &referenced {
if let Some(&vi) = input_by_name.get(name.as_str()) {
sub.input.push(vi.clone());
}
if let Some(&vi) = value_info_by_name.get(name.as_str()) {
sub.value_info.push(vi.clone());
}
}
let mut produced_here: std::collections::BTreeSet<String> =
std::collections::BTreeSet::new();
for node in &sub.node {
for out in &node.output {
if !out.is_empty() {
produced_here.insert(out.clone());
}
}
}
for name in &produced_here {
if let Some(&vi) = output_by_name.get(name.as_str()) {
sub.output.push(vi.clone());
}
}
}
let wire_edges = discover_wire_edges(graph);
Ok(NetworkAnalysis {
per_role,
wire_edges,
})
}
/// Pairs every wire `Recv` node with the `Send` node it was synthesized from.
///
/// `Send` nodes (domain [`WIRE_DOMAIN`], op type in [`SEND_OP_TYPES`]) are
/// indexed by the number that `parse_send_sentinel_idx` extracts from their
/// first output. If several `Send`s share an index, the last one in graph order
/// wins. A `Recv` node (op type in [`RECV_OP_TYPES`]) names its source through
/// the first metadata property keyed `SYNTHESIZED_FROM_KEY`. The `Recv` is
/// skipped in any of these cases:
///
/// * the property is missing;
/// * the property value is not a `usize`;
/// * the value matches no indexed `Send`.
///
/// Edges are returned in the order the `Recv` nodes appear in `graph.node`.
fn discover_wire_edges(graph: &GraphProto) -> Vec<WireEdge> {
    let role_of = |node: &NodeProto| -> String {
        home_class_of_node(node).map_or_else(|| SELF_CLASS.to_string(), str::to_string)
    };

    let sends: HashMap<usize, &NodeProto> = graph
        .node
        .iter()
        .filter(|n| n.domain == WIRE_DOMAIN && SEND_OP_TYPES.contains(&n.op_type.as_str()))
        .filter_map(|n| parse_send_sentinel_idx(n).map(|idx| (idx, n)))
        .collect();

    graph
        .node
        .iter()
        .filter(|n| n.domain == WIRE_DOMAIN && RECV_OP_TYPES.contains(&n.op_type.as_str()))
        .filter_map(|recv| {
            let send_idx = recv
                .metadata_props
                .iter()
                .find(|prop| prop.key == SYNTHESIZED_FROM_KEY)?
                .value
                .parse::<usize>()
                .ok()?;
            let send = *sends.get(&send_idx)?;
            Some(WireEdge {
                producer_role: role_of(send),
                consumer_role: role_of(recv),
                value_name: recv.output.first().cloned().unwrap_or_default(),
                send_node: send.clone(),
                recv_node: recv.clone(),
            })
        })
        .collect()
}
/// Extracts the index encoded after the last `__send_sentinel_` marker in a
/// `Send` node's first output name.
///
/// Returns `None` in any of these cases:
///
/// * the node has no outputs;
/// * the marker is absent;
/// * the text following the marker does not parse as a `usize`.
fn parse_send_sentinel_idx(send: &NodeProto) -> Option<usize> {
    const SENTINEL_MARKER: &str = "__send_sentinel_";
    let (_, suffix) = send.output.first()?.rsplit_once(SENTINEL_MARKER)?;
    suffix.parse::<usize>().ok()
}