bb_compiler/
analyze_wire_edges.rs1use std::collections::{BTreeMap, HashMap};
6
7use crate::error::CompileError;
8use crate::partition_by_wire_ops::WireEdge;
9use bb_ir::proto::onnx::{GraphProto, NodeProto, StringStringEntryProto};
10
11pub use bb_ir::keys::{
15 dest_suffix_attribute, BATCH_GROUP_KEY, DEST_SITE_NAME_PREFIX, DEST_SUFFIX_ATTR_PREFIX,
16 TRIGGER_DENOTATION, WIRE_TRANSPORT_KEY,
17};
18
19pub fn analyze_wire_edges(
29 sub_graph: &mut GraphProto,
30 wire_edges: &[WireEdge],
31) -> Result<(), CompileError> {
32 let denotation_by_name: HashMap<&str, &str> = sub_graph
33 .value_info
34 .iter()
35 .chain(sub_graph.input.iter())
36 .chain(sub_graph.output.iter())
37 .filter_map(|v| {
38 let denot = v.r#type.as_ref()?.denotation.as_str();
39 if denot.is_empty() {
40 None
41 } else {
42 Some((v.name.as_str(), denot))
43 }
44 })
45 .collect();
46
47 let mut batch_groups: BTreeMap<(String, String), u32> = BTreeMap::new();
48 let mut next_batch_id: u32 = 0;
49
50 for edge in wire_edges {
59 let consumer_port_denots: Vec<&str> =
60 consumer_input_denotations(&denotation_by_name, sub_graph, &edge.value_name);
61
62 let transport = if consumer_port_denots.is_empty() {
63 match denotation_by_name.get(edge.value_name.as_str()) {
67 Some(d) if *d == TRIGGER_DENOTATION => "trigger_only",
68 _ => "data",
69 }
70 } else if consumer_port_denots
71 .iter()
72 .all(|d| *d == TRIGGER_DENOTATION)
73 {
74 "trigger_only"
75 } else {
76 "data"
77 };
78
79 let key = (edge.producer_role.clone(), edge.consumer_role.clone());
80 let batch_id = *batch_groups.entry(key).or_insert_with(|| {
81 let id = next_batch_id;
82 next_batch_id += 1;
83 id
84 });
85 let batch_str = batch_id.to_string();
86
87 let recv_site_name = edge
93 .recv_node
94 .output
95 .first()
96 .cloned()
97 .unwrap_or_else(|| edge.value_name.clone());
98 let dest_site_key = format!("{DEST_SITE_NAME_PREFIX}{}", edge.value_name);
99
100 for node in sub_graph.node.iter_mut() {
101 let matches_value = node.output.iter().any(|o| o == &edge.value_name);
102 if !matches_value {
103 continue;
104 }
105 if node.op_type == "Send" {
106 set_metadata(&mut node.metadata_props, WIRE_TRANSPORT_KEY, transport);
107 set_metadata(&mut node.metadata_props, BATCH_GROUP_KEY, &batch_str);
108 set_metadata(&mut node.metadata_props, &dest_site_key, &recv_site_name);
109 } else if node.op_type == "Recv" {
110 set_metadata(&mut node.metadata_props, WIRE_TRANSPORT_KEY, transport);
111 set_metadata(&mut node.metadata_props, BATCH_GROUP_KEY, &batch_str);
112 }
113 }
114 }
115
116 Ok(())
117}
118
119pub fn dest_suffix_attr<'a>(node: &'a NodeProto, input_name: &str) -> Option<&'a [u8]> {
124 let key = format!("{DEST_SUFFIX_ATTR_PREFIX}{input_name}");
125 node.attribute
126 .iter()
127 .find(|a| a.name == key)
128 .map(|a| a.s.as_slice())
129}
130
131fn consumer_input_denotations<'a>(
138 denotation_by_name: &HashMap<&'a str, &'a str>,
139 sub_graph: &'a GraphProto,
140 value_name: &str,
141) -> Vec<&'a str> {
142 let mut out: Vec<&str> = Vec::new();
143 for node in &sub_graph.node {
144 for input in &node.input {
145 if input == value_name {
146 if let Some(d) = denotation_by_name.get(input.as_str()) {
147 out.push(*d);
148 } else {
149 out.push("");
152 }
153 }
154 }
155 }
156 out
157}
158
159fn set_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
160 if let Some(existing) = props.iter_mut().find(|p| p.key == key) {
161 existing.value = value.to_string();
162 return;
163 }
164 props.push(StringStringEntryProto {
165 key: key.to_string(),
166 value: value.to_string(),
167 });
168}
169