use std::collections::HashMap;
use crate::error::CompileError;
use crate::partition_by_wire_ops::WIRE_DOMAIN;
use bb_ir::peer_class::{
home_class_of_node, peer_class_of_node, peer_class_of_value_info, HOME_CLASS_KEY,
PEER_CLASS_KEY, SELF_CLASS,
};
use bb_ir::proto::onnx::{type_proto, GraphProto, StringStringEntryProto, TypeProto};
pub fn infer_peer_classes(graph: &mut GraphProto) -> Result<(), CompileError> {
stamp_peer_class_on_inputs_feeding_wire_sends(graph);
let mut home: HashMap<String, String> = HashMap::new();
let mut wire_id_to_dest_class: HashMap<String, String> = HashMap::new();
let mut peer_class_of_value: HashMap<String, String> = HashMap::new();
let mut peer_id_value_names: std::collections::HashSet<String> =
std::collections::HashSet::new();
for vi in &graph.input {
home.insert(vi.name.clone(), SELF_CLASS.to_string());
if value_info_is_peer_id(vi) {
peer_id_value_names.insert(vi.name.clone());
}
if let Some(class) = peer_class_of_value_info(vi) {
peer_class_of_value.insert(vi.name.clone(), class.to_string());
}
}
for vi in &graph.value_info {
if value_info_is_peer_id(vi) {
peer_id_value_names.insert(vi.name.clone());
}
if let Some(class) = peer_class_of_value_info(vi) {
peer_class_of_value
.entry(vi.name.clone())
.or_insert_with(|| class.to_string());
}
}
for node in graph.node.iter_mut() {
if home_class_of_node(node).is_some() {
continue;
}
if let Some(class) = peer_class_of_node(node) {
for out in &node.output {
if !out.is_empty() {
peer_class_of_value
.entry(out.clone())
.or_insert_with(|| class.to_string());
}
}
}
let is_wire_send = node.domain == WIRE_DOMAIN && node.op_type == "Send";
let is_wire_recv = node.domain == WIRE_DOMAIN && node.op_type == "Recv";
if is_wire_send {
let payload_name = node.input.first().cloned().unwrap_or_default();
let peer_input = node.input.last().cloned().unwrap_or_default();
let payload_home = home
.get(&payload_name)
.cloned()
.unwrap_or_else(|| SELF_CLASS.to_string());
let dest_class = peer_class_of_value
.get(&peer_input)
.cloned()
.unwrap_or_else(|| "@default".to_string());
if let Some(wire_id) = read_wire_id(node) {
wire_id_to_dest_class.insert(wire_id, dest_class.clone());
}
if let Some(first_out) = node.output.first() {
if !first_out.is_empty() {
let class = if node.output.len() >= 2 {
dest_class.clone()
} else {
payload_home.clone()
};
home.insert(first_out.clone(), class);
}
}
if let Some(handle_out) = node.output.get(1) {
if !handle_out.is_empty() {
home.insert(handle_out.clone(), payload_home.clone());
}
}
stamp_home(node, &payload_home);
continue;
}
if is_wire_recv {
let dest_class = read_wire_id(node)
.and_then(|wid| wire_id_to_dest_class.get(&wid).cloned())
.unwrap_or_else(|| SELF_CLASS.to_string());
for out in &node.output {
if !out.is_empty() {
home.insert(out.clone(), dest_class.clone());
}
}
stamp_home(node, &dest_class);
continue;
}
let mut input_homes: Vec<String> = Vec::new();
for input in &node.input {
if input.is_empty() {
continue;
}
if peer_id_value_names.contains(input) {
continue;
}
if let Some(h) = home.get(input) {
input_homes.push(h.clone());
}
}
input_homes.dedup();
let node_home = match input_homes.len() {
0 => SELF_CLASS.to_string(),
1 => input_homes.remove(0),
_ => {
return Err(CompileError::CrossClassDataflow {
node_name: node.name.clone(),
home_a: input_homes[0].clone(),
home_b: input_homes[1].clone(),
});
}
};
for out in &node.output {
if !out.is_empty() {
home.insert(out.clone(), node_home.clone());
}
}
stamp_home(node, &node_home);
}
Ok(())
}
/// Tags the graph-level values that feed the peer operand of a wire `Send`.
///
/// For each wire `Send`, the last input (the peer operand) is traced back via
/// `trace_peer_source`. Every traced root that is declared in `graph.input` or
/// `graph.value_info` receives a `PEER_CLASS_KEY` metadata entry whose value is the root's own
/// name, unless it already has one.
fn stamp_peer_class_on_inputs_feeding_wire_sends(graph: &mut GraphProto) {
    let producers = build_producer_map(graph);
    let mut roots: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    let peer_operands = graph
        .node
        .iter()
        .filter(|n| n.domain == WIRE_DOMAIN && n.op_type == "Send")
        .filter_map(|n| n.input.last())
        .filter(|name| !name.is_empty());
    for peer_operand in peer_operands {
        trace_peer_source(peer_operand, &producers, &graph.node, &mut roots, &mut seen);
    }

    if roots.is_empty() {
        return;
    }

    let targets = graph
        .input
        .iter_mut()
        .chain(graph.value_info.iter_mut())
        .filter(|vi| roots.contains(&vi.name));
    for vi in targets {
        // Leave an existing peer-class annotation alone.
        if vi.metadata_props.iter().all(|p| p.key != PEER_CLASS_KEY) {
            let own_name = vi.name.clone();
            vi.metadata_props.push(StringStringEntryProto {
                key: PEER_CLASS_KEY.to_string(),
                value: own_name,
            });
        }
    }
}
/// Walks backwards from the value `name` through peer-preserving producers and collects the
/// graph-level sources it originates from.
///
/// * A name already in `visited` is not walked again. `visited` may be shared across calls.
/// * A name with no producer in `producers` is a root and is added to `input_roots`.
/// * A name produced by an op rejected by `is_peer_pass_through` ends the walk without adding
///   a root.
/// * Otherwise the walk continues into the producer's operands that carry the traced data
///   (see `peer_data_operands`).
fn trace_peer_source(
    name: &str,
    producers: &HashMap<String, usize>,
    nodes: &[bb_ir::proto::onnx::NodeProto],
    input_roots: &mut std::collections::HashSet<String>,
    visited: &mut std::collections::HashSet<String>,
) {
    if !visited.insert(name.to_string()) {
        return;
    }
    let Some(&idx) = producers.get(name) else {
        input_roots.insert(name.to_string());
        return;
    };
    let producer = &nodes[idx];
    if !is_peer_pass_through(producer) {
        return;
    }
    for input in peer_data_operands(producer).iter().filter(|i| !i.is_empty()) {
        trace_peer_source(input, producers, nodes, input_roots, visited);
    }
}

/// Operands of a pass-through node that carry the value being traced.
///
/// `Identity` and `Concat` forward every input. For any other op (`Slice`, `Gather`,
/// `Squeeze`, `Unsqueeze`, ...) only the first, data, operand is traced. Trailing operands such
/// as indices, starts/ends/steps or axes only select from the data and must not be reported as
/// peer sources.
fn peer_data_operands(node: &bb_ir::proto::onnx::NodeProto) -> &[String] {
    match node.op_type.as_str() {
        "Identity" | "Concat" => node.input.as_slice(),
        _ => node.input.get(..1).unwrap_or(&[]),
    }
}
/// Maps every non-empty node output name to the index (in `graph.node`) of its producer.
///
/// If several nodes list the same output name, the one appearing last wins.
fn build_producer_map(graph: &GraphProto) -> HashMap<String, usize> {
    graph
        .node
        .iter()
        .enumerate()
        .flat_map(|(idx, n)| {
            n.output
                .iter()
                .filter(|o| !o.is_empty())
                .map(move |o| (o.clone(), idx))
        })
        .collect()
}
/// Returns `true` for default-domain ONNX ops through which a peer id can be traced:
/// `Identity`, `Slice`, `Gather`, `Concat`, `Squeeze` and `Unsqueeze`.
///
/// The ONNX IR treats the empty domain and `"ai.onnx"` as the same default operator set, and
/// exporters commonly emit `""`. Both spellings are therefore accepted.
fn is_peer_pass_through(node: &bb_ir::proto::onnx::NodeProto) -> bool {
    matches!(node.domain.as_str(), "" | "ai.onnx")
        && matches!(
            node.op_type.as_str(),
            "Identity" | "Slice" | "Gather" | "Concat" | "Squeeze" | "Unsqueeze"
        )
}
/// Returns the value of the first `WIRE_ID_KEY` metadata entry on `node`, if any.
fn read_wire_id(node: &bb_ir::proto::onnx::NodeProto) -> Option<String> {
    for entry in &node.metadata_props {
        if entry.key == bb_ir::keys::WIRE_ID_KEY {
            return Some(entry.value.clone());
        }
    }
    None
}
/// Sets `node`'s `HOME_CLASS_KEY` metadata to `home`.
///
/// Overwrites the first existing `HOME_CLASS_KEY` entry, or appends a new entry if none exists.
fn stamp_home(node: &mut bb_ir::proto::onnx::NodeProto, home: &str) {
    let slot = node
        .metadata_props
        .iter()
        .position(|p| p.key == HOME_CLASS_KEY);
    match slot {
        Some(i) => node.metadata_props[i].value = home.to_string(),
        None => node.metadata_props.push(StringStringEntryProto {
            key: HOME_CLASS_KEY.to_string(),
            value: home.to_string(),
        }),
    }
}
/// Reports whether `vi` describes a peer-id value.
///
/// That is the case if it already carries a `PEER_CLASS_KEY` metadata entry, or if it is a
/// tensor type whose denotation is `"bb.peer_id"` or `"bb.peer_id_vec"`.
fn value_info_is_peer_id(vi: &bb_ir::proto::onnx::ValueInfoProto) -> bool {
    let already_stamped = vi.metadata_props.iter().any(|p| p.key == PEER_CLASS_KEY);
    already_stamped
        || match &vi.r#type {
            Some(TypeProto {
                value: Some(type_proto::Value::TensorType(_)),
                denotation,
                ..
            }) => matches!(denotation.as_str(), "bb.peer_id" | "bb.peer_id_vec"),
            _ => false,
        }
}