// bb_compiler/infer_peer_classes.rs

//! `infer_peer_classes` - stamp every NodeProto with the **class of
//! Node** it runs on.
//!
//! Runs in [`runner::run_pipeline`](super::runner::run_pipeline)
//! after `expand_ops` and before `synthesize_wire_recvs`. The result feeds
//! [`partition_by_wire_ops`](super::partition_by_wire_ops::partition_by_wire_ops) - partitions
//! are now defined by `home_class`, not by `module_instance` chains.
//!
//! ## Algorithm
//!
//! 1. Seed every function input's `home` to [`SELF_CLASS`].
//! 2. Walk nodes in declaration order. For each NodeProto:
//!    - `wire.Send` re-homes its `data` output to the **destination
//!      class** (taken from the peer input's `peer_class` tag).
//!      The send itself runs on its payload's home class; the
//!      `handle` output stays with the sender. The self-send case
//!      (`dest_class == payload_home`) needs no special handling: the
//!      data output simply lands on `dest_class`.
//!    - `wire.Recv` homes its outputs on the destination class of the
//!      `wire.Send` that shares its `wire_id` metadata (falling back to
//!      [`SELF_CLASS`] when no matching Send is found).
//!    - Every other op inherits its home from its data inputs. All
//!      data inputs (i.e. inputs that aren't PEER_ID values) must
//!      agree on a home; otherwise [`CompileError::CrossClassDataflow`].
//!      Peer-id inputs are **ambient** - they don't constrain the
//!      consuming op's home class.
//! 3. The home is stamped on the NodeProto as [`HOME_CLASS_KEY`]
//!    metadata for downstream passes.
//!
//! ## Self-send semantics
//!
//! When a `wire.Send`'s destination class equals its sender's home
//! class, both the send and the synthesized recv land in the same
//! partition at the partition pass. The runtime side is N physical
//! instances of one class talking to each other (e.g. gossip peers).

use std::collections::{HashMap, HashSet};

use crate::error::CompileError;
use crate::partition_by_wire_ops::WIRE_DOMAIN;
use bb_ir::peer_class::{
    home_class_of_node, peer_class_of_node, peer_class_of_value_info, HOME_CLASS_KEY,
    PEER_CLASS_KEY, SELF_CLASS,
};
use bb_ir::proto::onnx::{type_proto, GraphProto, StringStringEntryProto, TypeProto};

43/// Walk `graph.node` and stamp `HOME_CLASS_KEY` on each NodeProto.
44/// Pure.
45pub fn infer_peer_classes(graph: &mut GraphProto) -> Result<(), CompileError> {
46    // Compile-time peer-class trace: for every wire.Send peer
47    // input, walk backward through allow-listed pass-through ops
48    // (Identity, Slice, Gather, Concat, Squeeze, Unsqueeze). Graph
49    // inputs reached along that walk get the `peer_class =
50    // <input_name>` stamp; non-pass-through producers stop the
51    // trace (their own peer_class metadata, if any, drives routing).
52    stamp_peer_class_on_inputs_feeding_wire_sends(graph);
53
54    // value_name → home class.
55    let mut home: HashMap<String, String> = HashMap::new();
56
57    // wire_id → destination class for the matched Send. Populated
58    // when each Send is processed; consulted when the paired Recv
59    // is processed (Recv outputs + home_class lift to the same
60    // destination class so the partitioner cuts cleanly).
61    let mut wire_id_to_dest_class: HashMap<String, String> = HashMap::new();
62
63    // Pre-scan function inputs: every input is on @self; PEER_ID
64    // inputs additionally seed `peer_class[input_name] = <class>` so
65    // a `wire.Send` reading that input can find its destination class.
66    let mut peer_class_of_value: HashMap<String, String> = HashMap::new();
67    let mut peer_id_value_names: std::collections::HashSet<String> =
68        std::collections::HashSet::new();
69    for vi in &graph.input {
70        home.insert(vi.name.clone(), SELF_CLASS.to_string());
71        if value_info_is_peer_id(vi) {
72            peer_id_value_names.insert(vi.name.clone());
73        }
74        if let Some(class) = peer_class_of_value_info(vi) {
75            peer_class_of_value.insert(vi.name.clone(), class.to_string());
76        }
77    }
78    for vi in &graph.value_info {
79        if value_info_is_peer_id(vi) {
80            peer_id_value_names.insert(vi.name.clone());
81        }
82        if let Some(class) = peer_class_of_value_info(vi) {
83            peer_class_of_value
84                .entry(vi.name.clone())
85                .or_insert_with(|| class.to_string());
86        }
87    }
88
89    // Walk nodes in declaration order. The runner only feeds us
90    // topologically ordered functions; we don't re-sort.
91    for node in graph.node.iter_mut() {
92        // Skip nodes that already carry an inferred home (idempotent
93        // re-runs return the same stamps).
94        if home_class_of_node(node).is_some() {
95            continue;
96        }
97
98        // Record dynamically-produced peer outputs (peer-sampling,
99        // gossip neighbor selection) BEFORE handling the node so a
100        // wire.Send referencing one of these outputs finds it.
101        if let Some(class) = peer_class_of_node(node) {
102            for out in &node.output {
103                if !out.is_empty() {
104                    peer_class_of_value
105                        .entry(out.clone())
106                        .or_insert_with(|| class.to_string());
107                }
108            }
109        }
110
111        let is_wire_send = node.domain == WIRE_DOMAIN && node.op_type == "Send";
112        let is_wire_recv = node.domain == WIRE_DOMAIN && node.op_type == "Recv";
113        if is_wire_send {
114            // wire.Send signature is (payload_0, ..., payload_{N-1}, peer):
115            // the peer is the LAST input, payloads precede it.
116            // Reading the last input lets multi-input wires (hierarchical
117            // FedAvg, GlobalRegistry Announce, gossip disseminate) infer
118            // the right destination class.
119            //
120            // Fallback to `@default` when the peer source carries no class
121            // annotation so naming downstream stays stable.
122            let payload_name = node.input.first().cloned().unwrap_or_default();
123            let peer_input = node.input.last().cloned().unwrap_or_default();
124            let payload_home = home
125                .get(&payload_name)
126                .cloned()
127                .unwrap_or_else(|| SELF_CLASS.to_string());
128            let dest_class = peer_class_of_value
129                .get(&peer_input)
130                .cloned()
131                .unwrap_or_else(|| "@default".to_string());
132
133            // Record wire_id → dest_class so the paired Recv lifts
134            // its outputs into the same partition.
135            if let Some(wire_id) = read_wire_id(node) {
136                wire_id_to_dest_class.insert(wire_id, dest_class.clone());
137            }
138
139            // Send output arity disambiguates the shape:
140            //   len==1 → [handle]; output[0] stays with the sender.
141            //   len>=2 → [data, handle]; output[0] is the data lifted
142            //            to dest_class (carried by the paired Recv on
143            //            the single-output variant).
144            if let Some(first_out) = node.output.first() {
145                if !first_out.is_empty() {
146                    let class = if node.output.len() >= 2 {
147                        dest_class.clone()
148                    } else {
149                        payload_home.clone()
150                    };
151                    home.insert(first_out.clone(), class);
152                }
153            }
154            if let Some(handle_out) = node.output.get(1) {
155                if !handle_out.is_empty() {
156                    home.insert(handle_out.clone(), payload_home.clone());
157                }
158            }
159            stamp_home(node, &payload_home);
160            continue;
161        }
162        if is_wire_recv {
163            // wire.Recv carries no graph inputs; its outputs flow
164            // into downstream user ops on the destination class.
165            // Match the destination class via the wire_id metadata
166            // the DSL stamped on both halves of the pair.
167            let dest_class = read_wire_id(node)
168                .and_then(|wid| wire_id_to_dest_class.get(&wid).cloned())
169                .unwrap_or_else(|| SELF_CLASS.to_string());
170            for out in &node.output {
171                if !out.is_empty() {
172                    home.insert(out.clone(), dest_class.clone());
173                }
174            }
175            stamp_home(node, &dest_class);
176            continue;
177        }
178
179        // Non-send ops: collect data-input homes. peer_id inputs are
180        // ambient routing metadata, not dataflow - they don't
181        // constrain home.
182        let mut input_homes: Vec<String> = Vec::new();
183        for input in &node.input {
184            if input.is_empty() {
185                continue;
186            }
187            if peer_id_value_names.contains(input) {
188                continue;
189            }
190            if let Some(h) = home.get(input) {
191                input_homes.push(h.clone());
192            }
193        }
194        // Dedup while preserving order so the error message points at
195        // the first conflict, not a sorted permutation.
196        input_homes.dedup();
197        let node_home = match input_homes.len() {
198            0 => SELF_CLASS.to_string(),
199            1 => input_homes.remove(0),
200            _ => {
201                return Err(CompileError::CrossClassDataflow {
202                    node_name: node.name.clone(),
203                    home_a: input_homes[0].clone(),
204                    home_b: input_homes[1].clone(),
205                });
206            }
207        };
208        for out in &node.output {
209            if !out.is_empty() {
210                home.insert(out.clone(), node_home.clone());
211            }
212        }
213        stamp_home(node, &node_home);
214    }
215
216    Ok(())
217}
218
219/// Walk `wire.Send` ops; for each peer-slot input value, trace
220/// backward through allow-listed pass-through ops until reaching
221/// either a graph input (stamp it) or a non-pass-through producer
222/// (the producing op's `peer_class` metadata, if any, drives the
223/// destination class downstream — no input-side stamp needed).
224///
225/// Peer values commonly flow through structural ops (`Identity`,
226/// `Slice`, `Gather`, `Squeeze`, `Unsqueeze`, `Concat`) before
227/// reaching a `wire.Send`'s peer slot — picking the first N peers
228/// of a view or concatenating two peer subsets. The trace tolerates
229/// those so the graph-input source still gets stamped.
230fn stamp_peer_class_on_inputs_feeding_wire_sends(graph: &mut GraphProto) {
231    let producers = build_producer_map(graph);
232
233    let mut input_roots: std::collections::HashSet<String> = std::collections::HashSet::new();
234    let mut visited: std::collections::HashSet<String> = std::collections::HashSet::new();
235
236    for node in &graph.node {
237        if node.domain != WIRE_DOMAIN || node.op_type != "Send" {
238            continue;
239        }
240        let Some(peer_input) = node.input.last() else {
241            continue;
242        };
243        if peer_input.is_empty() {
244            continue;
245        }
246        trace_peer_source(
247            peer_input,
248            &producers,
249            &graph.node,
250            &mut input_roots,
251            &mut visited,
252        );
253    }
254
255    if input_roots.is_empty() {
256        return;
257    }
258
259    for vi in graph.input.iter_mut().chain(graph.value_info.iter_mut()) {
260        if !input_roots.contains(&vi.name) {
261            continue;
262        }
263        let already = vi.metadata_props.iter().any(|p| p.key == PEER_CLASS_KEY);
264        if !already {
265            vi.metadata_props.push(StringStringEntryProto {
266                key: PEER_CLASS_KEY.to_string(),
267                value: vi.name.clone(),
268            });
269        }
270    }
271}
272
273/// Trace a value name backward through producers, collecting any
274/// graph-input ancestors reached via the allow-listed pass-through
275/// ops. Non-pass-through producers terminate the walk (their
276/// output may still carry `peer_class` metadata; the main pass
277/// reads that separately via [`peer_class_of_node`]).
278fn trace_peer_source(
279    name: &str,
280    producers: &HashMap<String, usize>,
281    nodes: &[bb_ir::proto::onnx::NodeProto],
282    input_roots: &mut std::collections::HashSet<String>,
283    visited: &mut std::collections::HashSet<String>,
284) {
285    if !visited.insert(name.to_string()) {
286        return;
287    }
288    if let Some(&idx) = producers.get(name) {
289        let producer = &nodes[idx];
290        if !is_peer_pass_through(producer) {
291            return;
292        }
293        for input in &producer.input {
294            if input.is_empty() {
295                continue;
296            }
297            trace_peer_source(input, producers, nodes, input_roots, visited);
298        }
299        return;
300    }
301    // Not produced by any node in this graph — it's a graph input
302    // (or a function arg). Mark it for stamping.
303    input_roots.insert(name.to_string());
304}
305
306/// Build `value_name → producing_node_index` over `graph.node`.
307/// Empty output names are skipped.
308fn build_producer_map(graph: &GraphProto) -> HashMap<String, usize> {
309    let mut m = HashMap::new();
310    for (i, node) in graph.node.iter().enumerate() {
311        for out in &node.output {
312            if out.is_empty() {
313                continue;
314            }
315            m.insert(out.clone(), i);
316        }
317    }
318    m
319}
320
321/// Conservative allow-list of ops whose output preserves the
322/// peer-class semantics of their inputs. Adding to this list is a
323/// deliberate act: a new entry says "if this op's input is a graph
324/// input feeding a `wire.Send`'s peer slot, the graph input itself
325/// is the peer source." Ops that produce peer values from non-peer
326/// inputs (e.g. `PeerSelector::Sample`) are NOT pass-throughs —
327/// their `peer_class` metadata already drives destination routing.
328fn is_peer_pass_through(node: &bb_ir::proto::onnx::NodeProto) -> bool {
329    matches!(
330        (node.domain.as_str(), node.op_type.as_str()),
331        ("ai.onnx", "Identity")
332            | ("ai.onnx", "Slice")
333            | ("ai.onnx", "Gather")
334            | ("ai.onnx", "Concat")
335            | ("ai.onnx", "Squeeze")
336            | ("ai.onnx", "Unsqueeze")
337    )
338}
339
340/// Pull the [`bb_ir::keys::WIRE_ID_KEY`] metadata stamp from a wire
341/// op NodeProto. Used to pair Send and Recv NodeProtos the DSL
342/// `Graph::wire` emits together.
343fn read_wire_id(node: &bb_ir::proto::onnx::NodeProto) -> Option<String> {
344    node.metadata_props
345        .iter()
346        .find(|p| p.key == bb_ir::keys::WIRE_ID_KEY)
347        .map(|p| p.value.clone())
348}
349
350/// Stamp the `HOME_CLASS_KEY` metadata onto a NodeProto, replacing
351/// any existing stamp (idempotent re-runs preserve the same value).
352fn stamp_home(node: &mut bb_ir::proto::onnx::NodeProto, home: &str) {
353    if let Some(existing) = node
354        .metadata_props
355        .iter_mut()
356        .find(|p| p.key == HOME_CLASS_KEY)
357    {
358        existing.value = home.to_string();
359        return;
360    }
361    node.metadata_props.push(StringStringEntryProto {
362        key: HOME_CLASS_KEY.to_string(),
363        value: home.to_string(),
364    });
365}
366
367/// Returns `true` when a ValueInfoProto carries the `peer_class`
368/// metadata stamp from `Graph::input(name, &TYPE_PEER_ID)`. We use the
369/// presence of `PEER_CLASS_KEY` as the signal rather than the TypeNode
370/// denotation, because the compiler doesn't have access to the
371/// `TypeNode` static after the graph crosses the recording boundary.
372///
373/// Accept both `bb.peer_id` (single recipient) and
374/// `bb.peer_id_vec` (broadcast multi-peer recipient) denotations
375/// so peer-vec values don't get misclassified as non-peer data.
376fn value_info_is_peer_id(vi: &bb_ir::proto::onnx::ValueInfoProto) -> bool {
377    if vi.metadata_props.iter().any(|p| p.key == PEER_CLASS_KEY) {
378        return true;
379    }
380    // Fall back to the type's denotation for plain `Output<PeerId>`
381    // / `Output<Vec<PeerId>>` values that didn't go through
382    // `Graph::input` (hand-built fixtures, replayed ModelProto
383    // bodies).
384    matches!(&vi.r#type, Some(TypeProto { value: Some(type_proto::Value::TensorType(_)), denotation, .. })
385        if denotation == "bb.peer_id" || denotation == "bb.peer_id_vec")
386}
387