bb-compiler 0.3.4

Compiler pipeline for the bytesandbrains framework — Compiler driver, CompilerPass trait, canonical pass list, BuildError.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
//! `infer_peer_classes` - stamp every NodeProto with the **class of
//! Node** it runs on.
//!
//! Runs in [`runner::run_pipeline`](super::runner::run_pipeline)
//! after `expand_ops` and before `synthesize_wire_recvs`. The result feeds
//! [`partition_by_wire_ops`](super::partition_by_wire_ops::partition_by_wire_ops) - partitions
//! are now defined by `home_class`, not by `module_instance` chains.
//!
//! ## Algorithm
//!
//! 1. Seed every function input's `home` to [`SELF_CLASS`].
//! 2. Walk nodes in declaration order. For each NodeProto:
//!    - `wire.Send` re-homes its `data` output to the **destination
//!      class** (taken from the peer input's `peer_class` tag).
//!      The send itself runs on its payload's home class; the
//!      `handle` output stays with the sender. A self-send
//!      (`dest_class == payload_home`) needs no special casing: the
//!      destination class simply equals the sender's class.
//!    - Every other op inherits its home from its data inputs. All
//!      data inputs (i.e. inputs that aren't PEER_ID values) must
//!      agree on a home; otherwise [`CompileError::CrossClassDataflow`].
//!      Peer-id inputs are **ambient** - they don't constrain the
//!      consuming op's home class.
//! 3. The home is stamped on the NodeProto as [`HOME_CLASS_KEY`]
//!    metadata for downstream passes.
//!
//! ## Self-send semantics
//!
//! When a `wire.Send`'s destination class equals its sender's home
//! class, both the send and the synthesized recv land in the same
//! partition at the partition pass. The runtime side is N physical
//! instances of one class talking to each other (e.g. gossip peers).

use std::collections::HashMap;

use crate::error::CompileError;
use crate::partition_by_wire_ops::WIRE_DOMAIN;
use bb_ir::peer_class::{
    home_class_of_node, peer_class_of_node, peer_class_of_value_info, HOME_CLASS_KEY,
    PEER_CLASS_KEY, SELF_CLASS,
};
use bb_ir::proto::onnx::{type_proto, GraphProto, StringStringEntryProto, TypeProto};

/// Walk `graph.node` and stamp `HOME_CLASS_KEY` on each NodeProto.
/// Pure per COMPILER.md §3.2.
///
/// The pass first stamps `peer_class` metadata on graph inputs that
/// feed a `wire.Send` peer slot, then walks the nodes in declaration
/// order:
///
/// - `wire.Send` runs on its first payload's home class. Its
///   destination class comes from the peer input (the last input),
///   falling back to `"@default"` when the peer value has no class.
/// - `wire.Recv` runs on the destination class recorded by the Send
///   with the same wire id, falling back to [`SELF_CLASS`].
/// - Every other op inherits the single home shared by its data
///   inputs; peer-id inputs are ignored.
///
/// Nodes that already carry a home stamp keep their metadata
/// untouched, but their outputs, peer-class outputs and wire-id
/// pairing are still recorded, so unstamped nodes downstream of them
/// resolve against the stamped home instead of silently falling back
/// to the defaults.
///
/// # Errors
///
/// Returns [`CompileError::CrossClassDataflow`] when an unstamped
/// non-wire op has data inputs living on more than one home class.
pub fn infer_peer_classes(graph: &mut GraphProto) -> Result<(), CompileError> {
    // Phase B - stamp `peer_class` metadata on every graph input
    // whose dataflow ultimately reaches a `wire.Send`'s peer slot.
    // The trace walks backward from each `wire.Send`'s peer input
    // through allow-listed pass-through ops until it reaches a graph
    // input or a non-pass-through producer.
    stamp_peer_class_on_inputs_feeding_wire_sends(graph);

    // value_name → home class.
    let mut home: HashMap<String, String> = HashMap::new();

    // wire_id → destination class for the matched Send. Populated
    // when each Send is processed; consulted when the paired Recv
    // is processed so both halves agree on the destination class.
    let mut wire_id_to_dest_class: HashMap<String, String> = HashMap::new();

    // Pre-scan function inputs: every input is on @self; inputs
    // carrying a peer class additionally seed `peer_class_of_value`
    // so a `wire.Send` reading that input can find its destination.
    let mut peer_class_of_value: HashMap<String, String> = HashMap::new();
    let mut peer_id_value_names: std::collections::HashSet<String> =
        std::collections::HashSet::new();
    for vi in &graph.input {
        home.insert(vi.name.clone(), SELF_CLASS.to_string());
        if value_info_is_peer_id(vi) {
            peer_id_value_names.insert(vi.name.clone());
        }
        if let Some(class) = peer_class_of_value_info(vi) {
            peer_class_of_value.insert(vi.name.clone(), class.to_string());
        }
    }
    for vi in &graph.value_info {
        if value_info_is_peer_id(vi) {
            peer_id_value_names.insert(vi.name.clone());
        }
        if let Some(class) = peer_class_of_value_info(vi) {
            // Graph-input entries win over value_info entries.
            peer_class_of_value
                .entry(vi.name.clone())
                .or_insert_with(|| class.to_string());
        }
    }

    // Walk nodes in declaration order. The runner only feeds us
    // topologically ordered functions; we don't re-sort.
    for node in graph.node.iter_mut() {
        // A node that already carries a home keeps its stamp, but we
        // still record what it produces: skipping it outright would
        // leave its outputs without a home and its wire id without a
        // destination, so unstamped consumers would mis-infer.
        let preset_home: Option<String> = home_class_of_node(node).map(|h| h.to_string());
        let already_stamped = preset_home.is_some();

        // Record dynamically-produced peer outputs (peer-sampling,
        // gossip neighbor selection) BEFORE handling the node so a
        // wire.Send referencing one of these outputs finds it.
        if let Some(class) = peer_class_of_node(node) {
            for out in &node.output {
                if !out.is_empty() {
                    peer_class_of_value
                        .entry(out.clone())
                        .or_insert_with(|| class.to_string());
                }
            }
        }

        let is_wire_send = node.domain == WIRE_DOMAIN && node.op_type == "Send";
        let is_wire_recv = node.domain == WIRE_DOMAIN && node.op_type == "Recv";

        if is_wire_send {
            // wire.Send signature is `(payload_0, ..., payload_{N-1},
            // peer)`: the peer is the LAST input and the payloads
            // precede it. The sender's home is taken from the first
            // payload; the destination from the peer input.
            let payload_name = node.input.first().cloned().unwrap_or_default();
            let peer_input = node.input.last().cloned().unwrap_or_default();
            let payload_home = match preset_home {
                Some(stamped) => stamped,
                None => home
                    .get(&payload_name)
                    .cloned()
                    .unwrap_or_else(|| SELF_CLASS.to_string()),
            };
            // If the peer source has no class annotation we fall back
            // to `@default` so naming downstream stays stable.
            let dest_class = peer_class_of_value
                .get(&peer_input)
                .cloned()
                .unwrap_or_else(|| "@default".to_string());

            // Record wire_id → dest_class so the paired Recv lifts
            // its outputs into the same partition.
            if let Some(wire_id) = read_wire_id(node) {
                wire_id_to_dest_class.insert(wire_id, dest_class.clone());
            }

            // New DSL shape: Send.output == [handle], which stays
            // with the sender. Legacy shape: [data, handle], where
            // output[0] is the data lifted to the destination class.
            let first_class = if node.output.len() >= 2 {
                &dest_class
            } else {
                &payload_home
            };
            if let Some(first_out) = node.output.first().filter(|o| !o.is_empty()) {
                home.insert(first_out.clone(), first_class.clone());
            }
            if let Some(handle_out) = node.output.get(1).filter(|o| !o.is_empty()) {
                home.insert(handle_out.clone(), payload_home.clone());
            }
            if !already_stamped {
                stamp_home(node, &payload_home);
            }
            continue;
        }

        if is_wire_recv {
            // wire.Recv outputs flow into downstream ops on the
            // destination class. Match the destination class via the
            // wire id shared with the paired Send.
            let dest_class = match preset_home {
                Some(stamped) => stamped,
                None => read_wire_id(node)
                    .and_then(|wid| wire_id_to_dest_class.get(&wid).cloned())
                    .unwrap_or_else(|| SELF_CLASS.to_string()),
            };
            for out in &node.output {
                if !out.is_empty() {
                    home.insert(out.clone(), dest_class.clone());
                }
            }
            if !already_stamped {
                stamp_home(node, &dest_class);
            }
            continue;
        }

        // Non-wire ops: collect data-input homes. peer_id inputs are
        // ambient routing metadata, not dataflow - they don't
        // constrain home. Inputs with no recorded home are ignored.
        let node_home = match preset_home {
            Some(stamped) => stamped,
            None => {
                let mut input_homes: Vec<String> = node
                    .input
                    .iter()
                    .filter(|input| !input.is_empty() && !peer_id_value_names.contains(*input))
                    .filter_map(|input| home.get(input).cloned())
                    .collect();
                // `dedup` collapses consecutive repeats only, which is
                // enough here: the list shrinks to one entry exactly
                // when every input agrees, and input order is kept so
                // the error reports the first conflict.
                input_homes.dedup();
                match input_homes.len() {
                    0 => SELF_CLASS.to_string(),
                    1 => input_homes.remove(0),
                    _ => {
                        return Err(CompileError::CrossClassDataflow {
                            node_name: node.name.clone(),
                            home_a: input_homes[0].clone(),
                            home_b: input_homes[1].clone(),
                        });
                    }
                }
            }
        };
        for out in &node.output {
            if !out.is_empty() {
                home.insert(out.clone(), node_home.clone());
            }
        }
        if !already_stamped {
            stamp_home(node, &node_home);
        }
    }

    Ok(())
}

/// For every `wire.Send`, trace its peer-slot input (the last input)
/// backward through the allow-listed pass-through ops and stamp
/// `PEER_CLASS_KEY` on each graph input the trace ends at.
///
/// The stamp's value is the input's own name. Entries in both
/// `graph.input` and `graph.value_info` are considered, and an entry
/// that already carries `PEER_CLASS_KEY` is left as it is. When the
/// trace stops at a non-pass-through producer nothing is stamped for
/// that branch; the main pass reads such a producer's `peer_class`
/// metadata on its own.
///
/// This lets a graph input flow through ops like `Identity`, `Slice`,
/// `Gather`, `Squeeze`, `Unsqueeze`, or `Concat` before reaching a
/// `wire.Send`'s peer slot (e.g. picking the first N peers of a view,
/// or concatenating two peer subsets).
fn stamp_peer_class_on_inputs_feeding_wire_sends(graph: &mut GraphProto) {
    let producer_of = build_producer_map(graph);

    let mut roots: std::collections::HashSet<String> = std::collections::HashSet::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    // Peer slot of every Send that actually names one.
    let peer_slots = graph
        .node
        .iter()
        .filter(|n| n.domain == WIRE_DOMAIN && n.op_type == "Send")
        .filter_map(|n| n.input.last())
        .filter(|slot| !slot.is_empty());
    for slot in peer_slots {
        trace_peer_source(slot, &producer_of, &graph.node, &mut roots, &mut seen);
    }

    if roots.is_empty() {
        return;
    }

    let traced_entries = graph
        .input
        .iter_mut()
        .chain(graph.value_info.iter_mut())
        .filter(|vi| roots.contains(&vi.name));
    for vi in traced_entries {
        // Never add a second stamp next to an existing one.
        let unstamped = vi.metadata_props.iter().all(|p| p.key != PEER_CLASS_KEY);
        if unstamped {
            let entry = StringStringEntryProto {
                key: PEER_CLASS_KEY.to_string(),
                value: vi.name.clone(),
            };
            vi.metadata_props.push(entry);
        }
    }
}

/// Trace a value name backward through producers, collecting any
/// graph-input ancestors reached via the allow-listed pass-through
/// ops. Non-pass-through producers terminate the walk (their
/// output may still carry `peer_class` metadata; the main pass
/// reads that separately via [`peer_class_of_node`]).
///
/// Only operands that carry the peer value are followed: every
/// operand of `Concat`, and the first (data) operand of the other
/// pass-through ops. The remaining operands of `Gather`, `Slice`,
/// `Squeeze` and `Unsqueeze` are indices / ranges / axes; following
/// them would mark an unrelated graph input as a peer source.
///
/// * `name` - value to start from.
/// * `producers` - `value_name → index into nodes` of the producing node.
/// * `nodes` - the node list the indices in `producers` refer to.
/// * `input_roots` - receives every reached name that no node produces.
/// * `visited` - names already traced; shared across calls so each
///   value is walked at most once.
fn trace_peer_source(
    name: &str,
    producers: &HashMap<String, usize>,
    nodes: &[bb_ir::proto::onnx::NodeProto],
    input_roots: &mut std::collections::HashSet<String>,
    visited: &mut std::collections::HashSet<String>,
) {
    // Explicit worklist instead of recursion so a long pass-through
    // chain cannot exhaust the call stack. Visit order is irrelevant:
    // the results are sets.
    let mut pending: Vec<&str> = vec![name];
    while let Some(current) = pending.pop() {
        if !visited.insert(current.to_string()) {
            continue;
        }
        let Some(&idx) = producers.get(current) else {
            // Not produced by any node in this graph — it's a graph
            // input (or a function arg). Mark it for stamping.
            input_roots.insert(current.to_string());
            continue;
        };
        let producer = &nodes[idx];
        if !is_peer_pass_through(producer) {
            continue;
        }
        // `Concat` merges all of its operands; for every other
        // pass-through op only operand 0 is the data being forwarded.
        let peer_operands = if producer.op_type == "Concat" {
            producer.input.len()
        } else {
            1
        };
        pending.extend(
            producer
                .input
                .iter()
                .take(peer_operands)
                .filter(|operand| !operand.is_empty())
                .map(String::as_str),
        );
    }
}

/// Map every non-empty output name in `graph.node` to the index of
/// the node that produces it. If a name is produced more than once,
/// the later node's index is kept.
fn build_producer_map(graph: &GraphProto) -> HashMap<String, usize> {
    graph
        .node
        .iter()
        .enumerate()
        .flat_map(|(idx, node)| {
            node.output
                .iter()
                .filter(|name| !name.is_empty())
                .map(move |name| (name.clone(), idx))
        })
        .collect()
}

/// Returns `true` for the `ai.onnx` ops on the conservative
/// allow-list of ops whose output preserves the peer-class semantics
/// of their inputs: `Identity`, `Slice`, `Gather`, `Concat`,
/// `Squeeze` and `Unsqueeze`.
///
/// Adding to this list is a deliberate act: a new entry says "if
/// this op's input is a graph input feeding a `wire.Send`'s peer
/// slot, the graph input itself is the peer source." Ops that
/// produce peer values from non-peer inputs (e.g.
/// `PeerSelector::Sample`) are NOT pass-throughs — their
/// `peer_class` metadata already drives destination routing.
fn is_peer_pass_through(node: &bb_ir::proto::onnx::NodeProto) -> bool {
    if node.domain.as_str() != "ai.onnx" {
        return false;
    }
    matches!(
        node.op_type.as_str(),
        "Identity" | "Slice" | "Gather" | "Concat" | "Squeeze" | "Unsqueeze"
    )
}

/// Return the value of the first [`bb_ir::keys::WIRE_ID_KEY`]
/// metadata entry on `node`, or `None` when the node has no such
/// entry. Used to pair the Send and Recv NodeProtos that carry the
/// same wire id.
fn read_wire_id(node: &bb_ir::proto::onnx::NodeProto) -> Option<String> {
    for prop in &node.metadata_props {
        if prop.key == bb_ir::keys::WIRE_ID_KEY {
            return Some(prop.value.clone());
        }
    }
    None
}

/// Set the `HOME_CLASS_KEY` metadata entry of `node` to `home`.
///
/// The first existing `HOME_CLASS_KEY` entry is overwritten in
/// place; if there is none, a new entry is appended. Stamping the
/// same value twice therefore leaves the node unchanged.
fn stamp_home(node: &mut bb_ir::proto::onnx::NodeProto, home: &str) {
    let slot = node
        .metadata_props
        .iter()
        .position(|entry| entry.key == HOME_CLASS_KEY);
    match slot {
        Some(i) => node.metadata_props[i].value = home.to_string(),
        None => node.metadata_props.push(StringStringEntryProto {
            key: HOME_CLASS_KEY.to_string(),
            value: home.to_string(),
        }),
    }
}

/// Returns `true` when a ValueInfoProto describes a peer-id value.
///
/// Two signals are accepted, checked in this order:
///
/// 1. Any `PEER_CLASS_KEY` entry in the value's metadata.
/// 2. Otherwise, a tensor-typed value whose type denotation is
///    `"bb.peer_id"` (single peer recipient) or `"bb.peer_id_vec"`
///    (broadcast multi-peer recipient). This covers values that
///    carry no metadata stamp, such as hand-built fixtures or
///    replayed ModelProto bodies.
fn value_info_is_peer_id(vi: &bb_ir::proto::onnx::ValueInfoProto) -> bool {
    let has_peer_class_stamp = vi.metadata_props.iter().any(|p| p.key == PEER_CLASS_KEY);
    if has_peer_class_stamp {
        return true;
    }
    // No stamp: fall back to the denotation on a tensor type.
    match &vi.r#type {
        Some(TypeProto {
            value: Some(type_proto::Value::TensorType(_)),
            denotation,
            ..
        }) => matches!(denotation.as_str(), "bb.peer_id" | "bb.peer_id_vec"),
        _ => false,
    }
}