use std::collections::{BTreeMap, HashMap};
use crate::error::CompileError;
use crate::partition_by_wire_ops::WireEdge;
use bb_ir::proto::onnx::{GraphProto, NodeProto, StringStringEntryProto};
pub use bb_ir::keys::{
dest_suffix_attribute, BATCH_GROUP_KEY, DEST_SITE_NAME_PREFIX, DEST_SUFFIX_ATTR_PREFIX,
TRIGGER_DENOTATION, WIRE_TRANSPORT_KEY,
};
/// Annotates the wire-crossing `Send`/`Recv` nodes of `sub_graph` with transport metadata.
///
/// For each edge in `wire_edges`, every node of `sub_graph` whose outputs include
/// `edge.value_name` is tagged through its `metadata_props`:
///
/// * `Send` and `Recv` nodes get [`WIRE_TRANSPORT_KEY`] (`"trigger_only"` or `"data"`,
///   chosen by `wire_transport`) and [`BATCH_GROUP_KEY`] (the edge's batch-group id).
/// * `Send` nodes additionally get the key `DEST_SITE_NAME_PREFIX + edge.value_name`.
///   Its value is the first output of `edge.recv_node`, or `edge.value_name` when that
///   node has no outputs.
///
/// Batch-group ids are dense and start at 0. There is one id per distinct
/// `(producer_role, consumer_role)` pair, numbered in the order the pairs first appear
/// in `wire_edges`. Existing metadata entries with the same key are overwritten, so when
/// several edges name the same value, the last such edge determines the stored values.
///
/// # Errors
///
/// Currently always returns `Ok(())`.
pub fn analyze_wire_edges(
    sub_graph: &mut GraphProto,
    wire_edges: &[WireEdge],
) -> Result<(), CompileError> {
    // Maps each value name to its non-empty type denotation, collected from value_info,
    // inputs and outputs. Only those fields are borrowed, so `sub_graph.node` can still
    // be mutated below.
    let denotation_by_name: HashMap<&str, &str> = sub_graph
        .value_info
        .iter()
        .chain(sub_graph.input.iter())
        .chain(sub_graph.output.iter())
        .filter_map(|v| {
            let denot = v.r#type.as_ref()?.denotation.as_str();
            if denot.is_empty() {
                None
            } else {
                Some((v.name.as_str(), denot))
            }
        })
        .collect();

    // Keyed by role names borrowed from `wire_edges`, which outlive this map. Looking up
    // an already-seen pair therefore allocates nothing.
    let mut batch_groups: BTreeMap<(&str, &str), u32> = BTreeMap::new();
    let mut next_batch_id: u32 = 0;

    for edge in wire_edges {
        let transport = wire_transport(&denotation_by_name, sub_graph, &edge.value_name);

        let batch_id = *batch_groups
            .entry((edge.producer_role.as_str(), edge.consumer_role.as_str()))
            .or_insert_with(|| {
                let id = next_batch_id;
                next_batch_id += 1;
                id
            });
        let batch_str = batch_id.to_string();

        let recv_site_name: &str = edge
            .recv_node
            .output
            .first()
            .unwrap_or(&edge.value_name);
        let dest_site_key = format!("{DEST_SITE_NAME_PREFIX}{}", edge.value_name);

        for node in sub_graph.node.iter_mut() {
            if !node.output.iter().any(|o| *o == edge.value_name) {
                continue;
            }
            match node.op_type.as_str() {
                "Send" => {
                    set_metadata(&mut node.metadata_props, WIRE_TRANSPORT_KEY, transport);
                    set_metadata(&mut node.metadata_props, BATCH_GROUP_KEY, &batch_str);
                    set_metadata(&mut node.metadata_props, &dest_site_key, recv_site_name);
                }
                "Recv" => {
                    set_metadata(&mut node.metadata_props, WIRE_TRANSPORT_KEY, transport);
                    set_metadata(&mut node.metadata_props, BATCH_GROUP_KEY, &batch_str);
                }
                _ => {}
            }
        }
    }
    Ok(())
}

/// Chooses the wire transport for `value_name`.
///
/// If some node in `sub_graph` consumes the value, the result is `"trigger_only"` when
/// every consumer-input denotation equals [`TRIGGER_DENOTATION`]. If nothing consumes
/// it, the result is `"trigger_only"` when the value's own denotation equals
/// [`TRIGGER_DENOTATION`]. Every other case yields `"data"`.
fn wire_transport(
    denotation_by_name: &HashMap<&str, &str>,
    sub_graph: &GraphProto,
    value_name: &str,
) -> &'static str {
    let consumer_port_denots =
        consumer_input_denotations(denotation_by_name, sub_graph, value_name);
    if consumer_port_denots.is_empty() {
        match denotation_by_name.get(value_name) {
            Some(d) if *d == TRIGGER_DENOTATION => "trigger_only",
            _ => "data",
        }
    } else if consumer_port_denots
        .iter()
        .all(|d| *d == TRIGGER_DENOTATION)
    {
        "trigger_only"
    } else {
        "data"
    }
}
/// Looks up the destination-suffix attribute that `node` carries for `input_name`.
///
/// The attribute name searched for is `DEST_SUFFIX_ATTR_PREFIX` immediately followed by
/// `input_name`. If it is present, its `s` payload is returned as a byte slice borrowed
/// from `node`. If several attributes share that name, the first one wins. Returns
/// `None` when no attribute matches.
pub fn dest_suffix_attr<'a>(node: &'a NodeProto, input_name: &str) -> Option<&'a [u8]> {
    let wanted = format!("{DEST_SUFFIX_ATTR_PREFIX}{input_name}");
    for attr in &node.attribute {
        if attr.name == wanted {
            return Some(attr.s.as_slice());
        }
    }
    None
}
/// Returns one denotation per occurrence of `value_name` among the inputs of the nodes
/// in `sub_graph`, in node order and then input order.
///
/// Each entry is looked up in `denotation_by_name` under the matching input name, which
/// equals `value_name`. All entries are therefore identical: the value's denotation, or
/// `""` when it has none. An empty result means no node in `sub_graph` consumes
/// `value_name`.
///
/// NOTE(review): no per-consumer-port denotation is consulted here; confirm whether
/// that is intended.
fn consumer_input_denotations<'a>(
    denotation_by_name: &HashMap<&'a str, &'a str>,
    sub_graph: &'a GraphProto,
    value_name: &str,
) -> Vec<&'a str> {
    sub_graph
        .node
        .iter()
        .flat_map(|node| node.input.iter())
        .filter(|consumed| consumed.as_str() == value_name)
        .map(|consumed| {
            denotation_by_name
                .get(consumed.as_str())
                .copied()
                .unwrap_or("")
        })
        .collect()
}
/// Sets `key` to `value` in a metadata-prop list.
///
/// If an entry with `key` already exists, the first such entry has its value replaced.
/// Otherwise a new entry is appended at the end.
fn set_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
    let slot = props.iter().position(|entry| entry.key == key);
    match slot {
        Some(idx) => props[idx].value = value.to_string(),
        None => props.push(StringStringEntryProto {
            key: key.to_string(),
            value: value.to_string(),
        }),
    }
}