bb_compiler/partition_by_wire_ops.rs
//! Slice the recorded function into per-BB-Node sub-graphs bounded
//! by wire ops.
//!
//! Wire ops (NodeProtos under domain `ai.bytesandbrains.wire`: the
//! user-authored `Send` and the compiler-synthesized `Recv`) are the
//! partition boundary. Two NodeProtos belong to the same partition
//! iff there is a dataflow path between them that does NOT cross a
//! wire op.
//!
//! Wire ops attach to the partition on their data side: Send-flavored
//! ops join the partition of their data-input producers, Recv-flavored
//! ops join the partition of their data-output consumers.
//!
//! This pass does not trace reachability itself: it groups nodes by
//! the home class that `infer_peer_classes` stamps on each NodeProto
//! (unstamped nodes fall back to `SELF_CLASS`). It also pairs every
//! `Send` with the `Recv` synthesized for it into `wire_edges`; the
//! downstream `analyze_wire_edges` pass reads those per-edge records
//! to classify each cross-partition edge on the per-role sub-graphs.
19
20use std::collections::{BTreeMap, HashMap};
21
22use crate::error::CompileError;
23use crate::synthesize_wire_recvs::SYNTHESIZED_FROM_KEY;
24use bb_ir::peer_class::{home_class_of_node, SELF_CLASS};
25use bb_ir::proto::onnx::{GraphProto, NodeProto, ValueInfoProto};
26
/// Domain that marks a NodeProto as a wire op: any node whose `domain`
/// equals this string is treated as a partition boundary. Users author
/// only `wire.Send`; every `wire.Recv` is generated by
/// [`super::synthesize_wire_recvs::synthesize_wire_recvs`].
pub const WIRE_DOMAIN: &str = "ai.bytesandbrains.wire";

/// `op_type`s (within [`WIRE_DOMAIN`]) that identify a send-side wire
/// op. The wire-collapse design exposes a single user-facing `Send` op
/// (see [`crate::syscall::wire`]); downstream passes match a node's
/// `op_type` against this list to find send NodeProtos.
pub const SEND_OP_TYPES: &[&str] = &["Send"];

/// `op_type`s (within [`WIRE_DOMAIN`]) that identify a receive-side wire
/// op. `Recv` is never written by users; it is produced by
/// [`super::synthesize_wire_recvs::synthesize_wire_recvs`].
pub const RECV_OP_TYPES: &[&str] = &["Recv"];
40
41/// Output of the partition pass: per-role sub-graphs + cross-role
42/// edges. Each cross-partition edge appears as a user-authored
43/// `wire::Send` / synthesized `wire::Recv` NodeProto inside the
44/// per-role sub-graphs; `wire_edges` carries the matching
45/// per-edge metadata (producer/consumer roles, transport kind).
46#[derive(Debug, Default)]
47pub struct NetworkAnalysis {
48 /// One entry per BB-Node partition. Single-Node Modules yield
49 /// one entry; federated Modules yield one per wire-op-bounded
50 /// partition.
51 pub per_role: BTreeMap<String, GraphProto>,
52
53 /// Cross-role edges paired by sender index - one entry per
54 /// Send/Recv pair discovered after `synthesize_wire_recvs`.
55 /// [`super::analyze_wire_edges::analyze_wire_edges`] reads this
56 /// to classify each edge's transport and assign batch ids.
57 pub wire_edges: Vec<WireEdge>,
58}
59
60/// A directional cross-partition edge - produced by the compiler
61/// when a wire op pair is identified. Populated by
62/// [`super::analyze_wire_edges::analyze_wire_edges`].
63#[derive(Debug)]
64pub struct WireEdge {
65 /// Origin BB-Node role.
66 pub producer_role: String,
67
68 /// Destination BB-Node role.
69 pub consumer_role: String,
70
71 /// The value-name crossing the edge (the producer-side output
72 /// name).
73 pub value_name: String,
74
75 /// Producer-side `Send`-flavored NodeProto.
76 pub send_node: NodeProto,
77
78 /// Consumer-side `Recv`-flavored NodeProto.
79 pub recv_node: NodeProto,
80}
81
82/// Partition the graph by inferred home class. Pure.
83///
84/// After [`super::infer_peer_classes`] has stamped every NodeProto
85/// with `HOME_CLASS_KEY`, partitioning is a direct group-by on that
86/// key. Nodes lacking a home stamp fall through to
87/// [`SELF_CLASS`](super::peer_class::SELF_CLASS).
88pub fn partition_by_wire_ops(graph: &GraphProto) -> Result<NetworkAnalysis, CompileError> {
89 let mut per_role: BTreeMap<String, GraphProto> = BTreeMap::new();
90 for node in &graph.node {
91 let class = home_class_of_node(node)
92 .map(str::to_string)
93 .unwrap_or_else(|| SELF_CLASS.to_string());
94 per_role.entry(class).or_default().node.push(node.clone());
95 }
96
97 // Copy each role's referenced graph.input + value_info.
98 let value_info_by_name: HashMap<&str, &ValueInfoProto> = graph
99 .value_info
100 .iter()
101 .map(|v| (v.name.as_str(), v))
102 .collect();
103 let input_by_name: HashMap<&str, &ValueInfoProto> =
104 graph.input.iter().map(|v| (v.name.as_str(), v)).collect();
105
106 // Forward each role's slice of `graph.output` so downstream
107 // passes (analyze_wire_edges, gate-rx inserters) see the
108 // recorder-stamped "this value crosses the boundary" hints.
109 let output_by_name: HashMap<&str, &ValueInfoProto> =
110 graph.output.iter().map(|v| (v.name.as_str(), v)).collect();
111
112 for sub in per_role.values_mut() {
113 let mut referenced: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
114 for node in &sub.node {
115 for inp in &node.input {
116 if !inp.is_empty() {
117 referenced.insert(inp.clone());
118 }
119 }
120 }
121 for name in &referenced {
122 if let Some(&vi) = input_by_name.get(name.as_str()) {
123 sub.input.push(vi.clone());
124 }
125 if let Some(&vi) = value_info_by_name.get(name.as_str()) {
126 sub.value_info.push(vi.clone());
127 }
128 }
129
130 // Carry forward every top-level output produced by this
131 // sub_graph's nodes. The set of producer-side outputs is
132 // exactly the intersection of `graph.output` names and
133 // values emitted by the sub_graph's nodes.
134 let mut produced_here: std::collections::BTreeSet<String> =
135 std::collections::BTreeSet::new();
136 for node in &sub.node {
137 for out in &node.output {
138 if !out.is_empty() {
139 produced_here.insert(out.clone());
140 }
141 }
142 }
143 for name in &produced_here {
144 if let Some(&vi) = output_by_name.get(name.as_str()) {
145 sub.output.push(vi.clone());
146 }
147 }
148 }
149
150 let wire_edges = discover_wire_edges(graph);
151
152 Ok(NetworkAnalysis {
153 per_role,
154 wire_edges,
155 })
156}
157
158/// Pair Send and synthesized Recv NodeProtos into one [`WireEdge`]
159/// per cross-partition data flow. Sends carry an
160/// `output[0] = "<data>__send_sentinel_<idx>"` rename from
161/// [`super::synthesize_wire_recvs`]; synthesized Recvs carry
162/// `SYNTHESIZED_FROM_KEY = <idx>` metadata pointing back at the same
163/// Send. Matching pairs become wire edges. Sends without a paired
164/// Recv (fire-and-forget) are skipped.
165fn discover_wire_edges(graph: &GraphProto) -> Vec<WireEdge> {
166 let mut send_by_idx: HashMap<usize, &NodeProto> = HashMap::new();
167 for node in &graph.node {
168 if node.domain != WIRE_DOMAIN || !SEND_OP_TYPES.contains(&node.op_type.as_str()) {
169 continue;
170 }
171 if let Some(idx) = parse_send_sentinel_idx(node) {
172 send_by_idx.insert(idx, node);
173 }
174 }
175
176 let mut edges = Vec::new();
177 for recv in &graph.node {
178 if recv.domain != WIRE_DOMAIN || !RECV_OP_TYPES.contains(&recv.op_type.as_str()) {
179 continue;
180 }
181 let Some(send_idx) = recv
182 .metadata_props
183 .iter()
184 .find(|p| p.key == SYNTHESIZED_FROM_KEY)
185 .and_then(|p| p.value.parse::<usize>().ok())
186 else {
187 continue;
188 };
189 let Some(send) = send_by_idx.get(&send_idx) else {
190 continue;
191 };
192 let producer_role = home_class_of_node(send)
193 .map(str::to_string)
194 .unwrap_or_else(|| SELF_CLASS.to_string());
195 let consumer_role = home_class_of_node(recv)
196 .map(str::to_string)
197 .unwrap_or_else(|| SELF_CLASS.to_string());
198 let value_name = recv.output.first().cloned().unwrap_or_default();
199 edges.push(WireEdge {
200 producer_role,
201 consumer_role,
202 value_name,
203 send_node: (*send).clone(),
204 recv_node: recv.clone(),
205 });
206 }
207 edges
208}
209
210/// Strip the `__send_sentinel_<idx>` suffix from a Send's first
211/// output and return the index. Mirrors the rename
212/// [`super::synthesize_wire_recvs`] applies to every Send with a
213/// downstream consumer.
214fn parse_send_sentinel_idx(send: &NodeProto) -> Option<usize> {
215 let first = send.output.first()?;
216 let marker = "__send_sentinel_";
217 let pos = first.rfind(marker)?;
218 first[pos + marker.len()..].parse().ok()
219}
220