Skip to main content

bb_compiler/
analyze_wire_edges.rs

//! `analyze_wire_edges` — classify each cross-Node edge as `data`
//! or `trigger_only` and group sends in the same cycle scope for
//! batching.
4
5use std::collections::{BTreeMap, HashMap};
6
7use crate::error::CompileError;
8use crate::partition_by_wire_ops::WireEdge;
9use bb_ir::proto::onnx::{GraphProto, NodeProto, StringStringEntryProto};
10
11// IR-level metadata keys + helpers live in `bb_ir::keys` — single
12// source of truth across DSL → compiler → runtime. This pass uses
13// them directly.
14pub use bb_ir::keys::{
15    dest_suffix_attribute, BATCH_GROUP_KEY, DEST_SITE_NAME_PREFIX, DEST_SUFFIX_ATTR_PREFIX,
16    TRIGGER_DENOTATION, WIRE_TRANSPORT_KEY,
17};
18
19/// Per-edge classification + per-cycle batching. Pure.
20///
21/// Writes the classification metadata directly onto the matching
22/// `sub_graph.node` NodeProtos (Send + Recv pairs). The
23/// `wire_edges` slice drives iteration — identifying which edges
24/// exist and pairing producer/consumer roles — but the pass treats
25/// it as read-only: the `WireEdge.send_node` / `WireEdge.recv_node`
26/// clones are discarded by downstream passes, so writing them is
27/// a no-op.
28pub fn analyze_wire_edges(
29    sub_graph: &mut GraphProto,
30    wire_edges: &[WireEdge],
31) -> Result<(), CompileError> {
32    let denotation_by_name: HashMap<&str, &str> = sub_graph
33        .value_info
34        .iter()
35        .chain(sub_graph.input.iter())
36        .chain(sub_graph.output.iter())
37        .filter_map(|v| {
38            let denot = v.r#type.as_ref()?.denotation.as_str();
39            if denot.is_empty() {
40                None
41            } else {
42                Some((v.name.as_str(), denot))
43            }
44        })
45        .collect();
46
47    let mut batch_groups: BTreeMap<(String, String), u32> = BTreeMap::new();
48    let mut next_batch_id: u32 = 0;
49
50    // Classification rule: if EVERY
51    // downstream consumer's input-port type is `bb.trigger`, mark
52    // the edge `trigger_only`; otherwise `data`. We walk the
53    // sub-graph's nodes to find each consumer of the edge value,
54    // then resolve that consumer's input-port type via the
55    // value_info denotation map. Empty consumer set defaults to
56    // `data` (conservative - preserves payload bytes for
57    // out-of-graph receivers).
58    for edge in wire_edges {
59        let consumer_port_denots: Vec<&str> =
60            consumer_input_denotations(&denotation_by_name, sub_graph, &edge.value_name);
61
62        let transport = if consumer_port_denots.is_empty() {
63            // No in-sub-graph consumer found - fall back to the
64            // edge value's declared denotation (preserves prior
65            // behavior for edges that exit the sub-graph entirely).
66            match denotation_by_name.get(edge.value_name.as_str()) {
67                Some(d) if *d == TRIGGER_DENOTATION => "trigger_only",
68                _ => "data",
69            }
70        } else if consumer_port_denots
71            .iter()
72            .all(|d| *d == TRIGGER_DENOTATION)
73        {
74            "trigger_only"
75        } else {
76            "data"
77        };
78
79        let key = (edge.producer_role.clone(), edge.consumer_role.clone());
80        let batch_id = *batch_groups.entry(key).or_insert_with(|| {
81            let id = next_batch_id;
82            next_batch_id += 1;
83            id
84        });
85        let batch_str = batch_id.to_string();
86
87        // Stamp the deferred recv-site name on the producer Send
88        // NodeProto. Node's install path resolves each entry
89        // to a `NodeSiteId` against the consumer's installed graph
90        // and rewrites the Send NodeProto with a canonical
91        // `dest_suffix.<input>` Address-bytes attribute.
92        let recv_site_name = edge
93            .recv_node
94            .output
95            .first()
96            .cloned()
97            .unwrap_or_else(|| edge.value_name.clone());
98        let dest_site_key = format!("{DEST_SITE_NAME_PREFIX}{}", edge.value_name);
99
100        for node in sub_graph.node.iter_mut() {
101            let matches_value = node.output.iter().any(|o| o == &edge.value_name);
102            if !matches_value {
103                continue;
104            }
105            if node.op_type == "Send" {
106                set_metadata(&mut node.metadata_props, WIRE_TRANSPORT_KEY, transport);
107                set_metadata(&mut node.metadata_props, BATCH_GROUP_KEY, &batch_str);
108                set_metadata(&mut node.metadata_props, &dest_site_key, &recv_site_name);
109            } else if node.op_type == "Recv" {
110                set_metadata(&mut node.metadata_props, WIRE_TRANSPORT_KEY, transport);
111                set_metadata(&mut node.metadata_props, BATCH_GROUP_KEY, &batch_str);
112            }
113        }
114    }
115
116    Ok(())
117}
118
119/// Look up a per-input `dest_suffix.<name>` attribute on the given
120/// NodeProto. Returns the canonical Address byte encoding stamped by
121/// Node's install-time resolver. Used by the wire syscall to
122/// populate each `SlotFill.dest_suffix` at dispatch time.
123pub fn dest_suffix_attr<'a>(node: &'a NodeProto, input_name: &str) -> Option<&'a [u8]> {
124    let key = format!("{DEST_SUFFIX_ATTR_PREFIX}{input_name}");
125    node.attribute
126        .iter()
127        .find(|a| a.name == key)
128        .map(|a| a.s.as_slice())
129}
130
131/// Walk every NodeProto in `sub_graph` and, for each one that
132/// consumes `value_name` on any of its input ports, return the
133/// per-port denotation as declared in `denotation_by_name`. Empty
134/// when no consumer references the value. (ONNX typically declares
135/// one type per value name, but the per-port walk lets us be
136/// explicit about §9.1's "every downstream consumer" rule.)
137fn consumer_input_denotations<'a>(
138    denotation_by_name: &HashMap<&'a str, &'a str>,
139    sub_graph: &'a GraphProto,
140    value_name: &str,
141) -> Vec<&'a str> {
142    let mut out: Vec<&str> = Vec::new();
143    for node in &sub_graph.node {
144        for input in &node.input {
145            if input == value_name {
146                if let Some(d) = denotation_by_name.get(input.as_str()) {
147                    out.push(*d);
148                } else {
149                    // Consumer with no declared type → default to
150                    // data so we keep the payload bytes available.
151                    out.push("");
152                }
153            }
154        }
155    }
156    out
157}
158
159fn set_metadata(props: &mut Vec<StringStringEntryProto>, key: &str, value: &str) {
160    if let Some(existing) = props.iter_mut().find(|p| p.key == key) {
161        existing.value = value.to_string();
162        return;
163    }
164    props.push(StringStringEntryProto {
165        key: key.to_string(),
166        value: value.to_string(),
167    });
168}
169