bb_compiler/infer_peer_classes.rs
1//! `infer_peer_classes` - stamp every NodeProto with the **class of
2//! Node** it runs on.
3//!
4//! Runs in [`runner::run_pipeline`](super::runner::run_pipeline)
5//! after `expand_ops` and before `synthesize_wire_recvs`. The result feeds
6//! [`partition_by_wire_ops`](super::partition_by_wire_ops::partition_by_wire_ops) - partitions
7//! are now defined by `home_class`, not by `module_instance` chains.
8//!
9//! ## Algorithm
10//!
11//! 1. Seed every function input's `home` to [`SELF_CLASS`].
//! 2. Walk nodes in declaration order. For each NodeProto:
//!    - `wire.Send` runs on its payload's home class. When it has a
//!      `data` output (outputs `[data, handle]`), that output is re-homed
//!      to the **destination class**, taken from the `peer_class` tag of
//!      the peer input (the Send's last input), falling back to
//!      `@default` when the peer carries no tag. The `handle` output
//!      stays with the sender. In the self-send case
//!      (`dest_class == payload_home`) the data simply stays on that
//!      one class.
//!    - `wire.Recv` has no graph inputs; it and its outputs are homed on
//!      the destination class of the `wire.Send` that shares its wire id.
//!    - Every other op inherits its home from its data inputs. All
//!      data inputs (i.e. inputs that aren't PEER_ID values) must
//!      agree on a home; otherwise [`CompileError::CrossClassDataflow`].
//!      Peer-id inputs are **ambient** - they don't constrain the
//!      consuming op's home class.
23//! 3. The home is stamped on the NodeProto as [`HOME_CLASS_KEY`]
24//! metadata for downstream passes.
25//!
26//! ## Self-send semantics
27//!
28//! When a `wire.Send`'s destination class equals its sender's home
29//! class, both the send and the synthesized recv land in the same
30//! partition at the partition pass. The runtime side is N physical
31//! instances of one class talking to each other (e.g. gossip peers).
32
33use std::collections::HashMap;
34
35use crate::error::CompileError;
36use crate::partition_by_wire_ops::WIRE_DOMAIN;
37use bb_ir::peer_class::{
38 home_class_of_node, peer_class_of_node, peer_class_of_value_info, HOME_CLASS_KEY,
39 PEER_CLASS_KEY, SELF_CLASS,
40};
41use bb_ir::proto::onnx::{type_proto, GraphProto, StringStringEntryProto, TypeProto};
42
43/// Walk `graph.node` and stamp `HOME_CLASS_KEY` on each NodeProto.
44/// Pure.
45pub fn infer_peer_classes(graph: &mut GraphProto) -> Result<(), CompileError> {
46 // Compile-time peer-class trace: for every wire.Send peer
47 // input, walk backward through allow-listed pass-through ops
48 // (Identity, Slice, Gather, Concat, Squeeze, Unsqueeze). Graph
49 // inputs reached along that walk get the `peer_class =
50 // <input_name>` stamp; non-pass-through producers stop the
51 // trace (their own peer_class metadata, if any, drives routing).
52 stamp_peer_class_on_inputs_feeding_wire_sends(graph);
53
54 // value_name → home class.
55 let mut home: HashMap<String, String> = HashMap::new();
56
57 // wire_id → destination class for the matched Send. Populated
58 // when each Send is processed; consulted when the paired Recv
59 // is processed (Recv outputs + home_class lift to the same
60 // destination class so the partitioner cuts cleanly).
61 let mut wire_id_to_dest_class: HashMap<String, String> = HashMap::new();
62
63 // Pre-scan function inputs: every input is on @self; PEER_ID
64 // inputs additionally seed `peer_class[input_name] = <class>` so
65 // a `wire.Send` reading that input can find its destination class.
66 let mut peer_class_of_value: HashMap<String, String> = HashMap::new();
67 let mut peer_id_value_names: std::collections::HashSet<String> =
68 std::collections::HashSet::new();
69 for vi in &graph.input {
70 home.insert(vi.name.clone(), SELF_CLASS.to_string());
71 if value_info_is_peer_id(vi) {
72 peer_id_value_names.insert(vi.name.clone());
73 }
74 if let Some(class) = peer_class_of_value_info(vi) {
75 peer_class_of_value.insert(vi.name.clone(), class.to_string());
76 }
77 }
78 for vi in &graph.value_info {
79 if value_info_is_peer_id(vi) {
80 peer_id_value_names.insert(vi.name.clone());
81 }
82 if let Some(class) = peer_class_of_value_info(vi) {
83 peer_class_of_value
84 .entry(vi.name.clone())
85 .or_insert_with(|| class.to_string());
86 }
87 }
88
89 // Walk nodes in declaration order. The runner only feeds us
90 // topologically ordered functions; we don't re-sort.
91 for node in graph.node.iter_mut() {
92 // Skip nodes that already carry an inferred home (idempotent
93 // re-runs return the same stamps).
94 if home_class_of_node(node).is_some() {
95 continue;
96 }
97
98 // Record dynamically-produced peer outputs (peer-sampling,
99 // gossip neighbor selection) BEFORE handling the node so a
100 // wire.Send referencing one of these outputs finds it.
101 if let Some(class) = peer_class_of_node(node) {
102 for out in &node.output {
103 if !out.is_empty() {
104 peer_class_of_value
105 .entry(out.clone())
106 .or_insert_with(|| class.to_string());
107 }
108 }
109 }
110
111 let is_wire_send = node.domain == WIRE_DOMAIN && node.op_type == "Send";
112 let is_wire_recv = node.domain == WIRE_DOMAIN && node.op_type == "Recv";
113 if is_wire_send {
114 // wire.Send signature is (payload_0, ..., payload_{N-1}, peer):
115 // the peer is the LAST input, payloads precede it.
116 // Reading the last input lets multi-input wires (hierarchical
117 // FedAvg, GlobalRegistry Announce, gossip disseminate) infer
118 // the right destination class.
119 //
120 // Fallback to `@default` when the peer source carries no class
121 // annotation so naming downstream stays stable.
122 let payload_name = node.input.first().cloned().unwrap_or_default();
123 let peer_input = node.input.last().cloned().unwrap_or_default();
124 let payload_home = home
125 .get(&payload_name)
126 .cloned()
127 .unwrap_or_else(|| SELF_CLASS.to_string());
128 let dest_class = peer_class_of_value
129 .get(&peer_input)
130 .cloned()
131 .unwrap_or_else(|| "@default".to_string());
132
133 // Record wire_id → dest_class so the paired Recv lifts
134 // its outputs into the same partition.
135 if let Some(wire_id) = read_wire_id(node) {
136 wire_id_to_dest_class.insert(wire_id, dest_class.clone());
137 }
138
139 // Send output arity disambiguates the shape:
140 // len==1 → [handle]; output[0] stays with the sender.
141 // len>=2 → [data, handle]; output[0] is the data lifted
142 // to dest_class (carried by the paired Recv on
143 // the single-output variant).
144 if let Some(first_out) = node.output.first() {
145 if !first_out.is_empty() {
146 let class = if node.output.len() >= 2 {
147 dest_class.clone()
148 } else {
149 payload_home.clone()
150 };
151 home.insert(first_out.clone(), class);
152 }
153 }
154 if let Some(handle_out) = node.output.get(1) {
155 if !handle_out.is_empty() {
156 home.insert(handle_out.clone(), payload_home.clone());
157 }
158 }
159 stamp_home(node, &payload_home);
160 continue;
161 }
162 if is_wire_recv {
163 // wire.Recv carries no graph inputs; its outputs flow
164 // into downstream user ops on the destination class.
165 // Match the destination class via the wire_id metadata
166 // the DSL stamped on both halves of the pair.
167 let dest_class = read_wire_id(node)
168 .and_then(|wid| wire_id_to_dest_class.get(&wid).cloned())
169 .unwrap_or_else(|| SELF_CLASS.to_string());
170 for out in &node.output {
171 if !out.is_empty() {
172 home.insert(out.clone(), dest_class.clone());
173 }
174 }
175 stamp_home(node, &dest_class);
176 continue;
177 }
178
179 // Non-send ops: collect data-input homes. peer_id inputs are
180 // ambient routing metadata, not dataflow - they don't
181 // constrain home.
182 let mut input_homes: Vec<String> = Vec::new();
183 for input in &node.input {
184 if input.is_empty() {
185 continue;
186 }
187 if peer_id_value_names.contains(input) {
188 continue;
189 }
190 if let Some(h) = home.get(input) {
191 input_homes.push(h.clone());
192 }
193 }
194 // Dedup while preserving order so the error message points at
195 // the first conflict, not a sorted permutation.
196 input_homes.dedup();
197 let node_home = match input_homes.len() {
198 0 => SELF_CLASS.to_string(),
199 1 => input_homes.remove(0),
200 _ => {
201 return Err(CompileError::CrossClassDataflow {
202 node_name: node.name.clone(),
203 home_a: input_homes[0].clone(),
204 home_b: input_homes[1].clone(),
205 });
206 }
207 };
208 for out in &node.output {
209 if !out.is_empty() {
210 home.insert(out.clone(), node_home.clone());
211 }
212 }
213 stamp_home(node, &node_home);
214 }
215
216 Ok(())
217}
218
219/// Walk `wire.Send` ops; for each peer-slot input value, trace
220/// backward through allow-listed pass-through ops until reaching
221/// either a graph input (stamp it) or a non-pass-through producer
222/// (the producing op's `peer_class` metadata, if any, drives the
223/// destination class downstream — no input-side stamp needed).
224///
225/// Peer values commonly flow through structural ops (`Identity`,
226/// `Slice`, `Gather`, `Squeeze`, `Unsqueeze`, `Concat`) before
227/// reaching a `wire.Send`'s peer slot — picking the first N peers
228/// of a view or concatenating two peer subsets. The trace tolerates
229/// those so the graph-input source still gets stamped.
230fn stamp_peer_class_on_inputs_feeding_wire_sends(graph: &mut GraphProto) {
231 let producers = build_producer_map(graph);
232
233 let mut input_roots: std::collections::HashSet<String> = std::collections::HashSet::new();
234 let mut visited: std::collections::HashSet<String> = std::collections::HashSet::new();
235
236 for node in &graph.node {
237 if node.domain != WIRE_DOMAIN || node.op_type != "Send" {
238 continue;
239 }
240 let Some(peer_input) = node.input.last() else {
241 continue;
242 };
243 if peer_input.is_empty() {
244 continue;
245 }
246 trace_peer_source(
247 peer_input,
248 &producers,
249 &graph.node,
250 &mut input_roots,
251 &mut visited,
252 );
253 }
254
255 if input_roots.is_empty() {
256 return;
257 }
258
259 for vi in graph.input.iter_mut().chain(graph.value_info.iter_mut()) {
260 if !input_roots.contains(&vi.name) {
261 continue;
262 }
263 let already = vi.metadata_props.iter().any(|p| p.key == PEER_CLASS_KEY);
264 if !already {
265 vi.metadata_props.push(StringStringEntryProto {
266 key: PEER_CLASS_KEY.to_string(),
267 value: vi.name.clone(),
268 });
269 }
270 }
271}
272
273/// Trace a value name backward through producers, collecting any
274/// graph-input ancestors reached via the allow-listed pass-through
275/// ops. Non-pass-through producers terminate the walk (their
276/// output may still carry `peer_class` metadata; the main pass
277/// reads that separately via [`peer_class_of_node`]).
278fn trace_peer_source(
279 name: &str,
280 producers: &HashMap<String, usize>,
281 nodes: &[bb_ir::proto::onnx::NodeProto],
282 input_roots: &mut std::collections::HashSet<String>,
283 visited: &mut std::collections::HashSet<String>,
284) {
285 if !visited.insert(name.to_string()) {
286 return;
287 }
288 if let Some(&idx) = producers.get(name) {
289 let producer = &nodes[idx];
290 if !is_peer_pass_through(producer) {
291 return;
292 }
293 for input in &producer.input {
294 if input.is_empty() {
295 continue;
296 }
297 trace_peer_source(input, producers, nodes, input_roots, visited);
298 }
299 return;
300 }
301 // Not produced by any node in this graph — it's a graph input
302 // (or a function arg). Mark it for stamping.
303 input_roots.insert(name.to_string());
304}
305
306/// Build `value_name → producing_node_index` over `graph.node`.
307/// Empty output names are skipped.
308fn build_producer_map(graph: &GraphProto) -> HashMap<String, usize> {
309 let mut m = HashMap::new();
310 for (i, node) in graph.node.iter().enumerate() {
311 for out in &node.output {
312 if out.is_empty() {
313 continue;
314 }
315 m.insert(out.clone(), i);
316 }
317 }
318 m
319}
320
321/// Conservative allow-list of ops whose output preserves the
322/// peer-class semantics of their inputs. Adding to this list is a
323/// deliberate act: a new entry says "if this op's input is a graph
324/// input feeding a `wire.Send`'s peer slot, the graph input itself
325/// is the peer source." Ops that produce peer values from non-peer
326/// inputs (e.g. `PeerSelector::Sample`) are NOT pass-throughs —
327/// their `peer_class` metadata already drives destination routing.
328fn is_peer_pass_through(node: &bb_ir::proto::onnx::NodeProto) -> bool {
329 matches!(
330 (node.domain.as_str(), node.op_type.as_str()),
331 ("ai.onnx", "Identity")
332 | ("ai.onnx", "Slice")
333 | ("ai.onnx", "Gather")
334 | ("ai.onnx", "Concat")
335 | ("ai.onnx", "Squeeze")
336 | ("ai.onnx", "Unsqueeze")
337 )
338}
339
340/// Pull the [`bb_ir::keys::WIRE_ID_KEY`] metadata stamp from a wire
341/// op NodeProto. Used to pair Send and Recv NodeProtos the DSL
342/// `Graph::wire` emits together.
343fn read_wire_id(node: &bb_ir::proto::onnx::NodeProto) -> Option<String> {
344 node.metadata_props
345 .iter()
346 .find(|p| p.key == bb_ir::keys::WIRE_ID_KEY)
347 .map(|p| p.value.clone())
348}
349
350/// Stamp the `HOME_CLASS_KEY` metadata onto a NodeProto, replacing
351/// any existing stamp (idempotent re-runs preserve the same value).
352fn stamp_home(node: &mut bb_ir::proto::onnx::NodeProto, home: &str) {
353 if let Some(existing) = node
354 .metadata_props
355 .iter_mut()
356 .find(|p| p.key == HOME_CLASS_KEY)
357 {
358 existing.value = home.to_string();
359 return;
360 }
361 node.metadata_props.push(StringStringEntryProto {
362 key: HOME_CLASS_KEY.to_string(),
363 value: home.to_string(),
364 });
365}
366
367/// Returns `true` when a ValueInfoProto carries the `peer_class`
368/// metadata stamp from `Graph::input(name, &TYPE_PEER_ID)`. We use the
369/// presence of `PEER_CLASS_KEY` as the signal rather than the TypeNode
370/// denotation, because the compiler doesn't have access to the
371/// `TypeNode` static after the graph crosses the recording boundary.
372///
373/// Accept both `bb.peer_id` (single recipient) and
374/// `bb.peer_id_vec` (broadcast multi-peer recipient) denotations
375/// so peer-vec values don't get misclassified as non-peer data.
376fn value_info_is_peer_id(vi: &bb_ir::proto::onnx::ValueInfoProto) -> bool {
377 if vi.metadata_props.iter().any(|p| p.key == PEER_CLASS_KEY) {
378 return true;
379 }
380 // Fall back to the type's denotation for plain `Output<PeerId>`
381 // / `Output<Vec<PeerId>>` values that didn't go through
382 // `Graph::input` (hand-built fixtures, replayed ModelProto
383 // bodies).
384 matches!(&vi.r#type, Some(TypeProto { value: Some(type_proto::Value::TensorType(_)), denotation, .. })
385 if denotation == "bb.peer_id" || denotation == "bb.peer_id_vec")
386}
387