// bb_compiler/stamp_compilation_metadata.rs

//! `stamp_compilation_metadata` — the final pass that turns each
//! per-partition `ModelProto` into a complete artifact by writing
//! the compilation passport and the slot binding table into
//! `metadata_props`. Install reads these stamps; the typed
//! `BindingSpec` Rust struct never crosses the compile boundary
//! after this pass runs.
//!
//! Per-partition: each partition's root function name is the
//! target identifier, and bindings are stamped against that target.
//! A future multi-partition merge step concatenates the partitions
//! into one ModelProto with each target as a distinct entry in
//! `functions[]`; this pass is partition-local so the merge can
//! happen at any granularity.
//!
//! Slot-id resolution: walk the partition's NodeProtos for the
//! `(REQUIRED_TRAIT_KEY, SLOT_ID_KEY)` pair stamped by the
//! recorder's placeholder DSL methods. For each `BindingSlot`,
//! find the slot_id whose role matches; dep-only slots (no
//! NodeProto references) encode as `-1`. The pass also stamps
//! `RECV_SLOT_ID_KEY` on each `wire.Recv` node whose payload feeds
//! a slot-stamped NodeProto.
20
21use bb_ir::keys::{
22    binding_key, encode_binding_value, stamp_model_metadata, COMPILED_CURRENT_VERSION,
23    COMPILED_KEY, RECV_SLOT_ID_KEY, REQUIRED_TRAIT_KEY, SLOT_ID_KEY,
24};
25use bb_ir::proto::onnx::{ModelProto, NodeProto, StringStringEntryProto};
26
27use crate::artifact::BindingSpec;
28
29/// Stamp `bb.compiled = "v1"` on the model + one binding entry per
30/// slot under the per-target prefix. Each partition's root function
31/// is its target; pass each partition's `target_name` alongside its
32/// model. The pass mutates `model.metadata_props` in place and is
33/// idempotent (re-stamping with the same values is a no-op).
34///
35/// Also walks every function for `wire.Recv` NodeProtos whose
36/// payload output feeds a role NodeProto (carrying [`SLOT_ID_KEY`])
37/// and stamps [`RECV_SLOT_ID_KEY`] on the Recv node's
38/// `metadata_props`. The install path reads this to populate
39/// `GraphSlot::recv_site_to_slot_id` so `decode_typed_fill` can route
40/// backend-bound tensor fills through the bound backend's
41/// `materialize_from_wire`. Recv nodes whose payload does not flow
42/// into a role NodeProto are left unstamped (framework-carrier path).
43pub(crate) fn stamp_compilation_metadata(
44    model: &mut ModelProto,
45    bindings: &BindingSpec,
46    target_name: &str,
47) {
48    // 1) Compilation passport — present on every compiled model.
49    stamp_model_metadata(model, COMPILED_KEY, COMPILED_CURRENT_VERSION);
50
51    // 2) Build a (role → slot_id) map by walking the model's
52    //    NodeProtos for the recorder's (required_trait, slot_id)
53    //    placeholder stamps. The compiler-internal BindingSpec lists
54    //    every slot the install path must fill; bindings whose role
55    //    doesn't appear in any NodeProto are dep-only and encode -1.
56    let role_to_slot_id = collect_role_slot_ids(model);
57
58    // 3) Stamp one binding entry per BindingSlot.
59    for slot in &bindings.slots {
60        let role_canon = canonical_role(&slot.role);
61        let slot_id_or_neg1 = role_to_slot_id
62            .iter()
63            .find(|(role, _)| canonical_role(role) == role_canon)
64            .map(|(_, id)| *id as i64)
65            .unwrap_or(-1);
66        let key = binding_key(target_name, &slot.slot);
67        let value = encode_binding_value(&role_canon, &slot.concrete_type_name, slot_id_or_neg1);
68        stamp_model_metadata(model, &key, &value);
69    }
70
71    // 4) Stamp `RECV_SLOT_ID_KEY` on each `wire.Recv` whose payload
72    //    output flows into a role NodeProto. The install pass reads
73    //    this to map the Recv's allocated `NodeSiteId` onto the
74    //    role's `slot_id`.
75    stamp_recv_slot_ids(model);
76}
77
78/// Walk every function in `model.functions`, find `wire.Recv` nodes,
79/// and stamp `RECV_SLOT_ID_KEY` on each whose first output flows
80/// into a role NodeProto's input. The Recv's first output is its
81/// payload site; we look up consumers by name match.
82fn stamp_recv_slot_ids(model: &mut ModelProto) {
83    for function in &mut model.functions {
84        // Build a (payload_name → recv_index) map for every Recv,
85        // and a (consumer_input_name → slot_id) map by scanning role
86        // NodeProtos.
87        let mut recv_indices: Vec<(usize, String)> = Vec::new();
88        let mut consumer_slot_ids: Vec<(String, u32)> = Vec::new();
89        for (idx, node) in function.node.iter().enumerate() {
90            if is_wire_recv(node) {
91                if let Some(payload) = node.output.first() {
92                    if !payload.is_empty() {
93                        recv_indices.push((idx, payload.clone()));
94                    }
95                }
96                continue;
97            }
98            let Some(slot_id) =
99                metadata_value(node, SLOT_ID_KEY).and_then(|v| v.parse::<u32>().ok())
100            else {
101                continue;
102            };
103            for input in &node.input {
104                if !input.is_empty() {
105                    consumer_slot_ids.push((input.clone(), slot_id));
106                }
107            }
108        }
109        for (recv_idx, payload_name) in recv_indices {
110            // Pick the first downstream consumer's slot_id. If the
111            // same payload feeds multiple role NodeProtos with
112            // distinct slot_ids, the compiler would have rejected the
113            // graph upstream (one Recv site == one destination
114            // binding); the first hit is therefore the only hit on
115            // any valid input.
116            let Some(slot_id) = consumer_slot_ids
117                .iter()
118                .find(|(name, _)| name == &payload_name)
119                .map(|(_, id)| *id)
120            else {
121                continue;
122            };
123            stamp_node_metadata(
124                &mut function.node[recv_idx],
125                RECV_SLOT_ID_KEY,
126                &slot_id.to_string(),
127            );
128        }
129    }
130}
131
132fn is_wire_recv(node: &NodeProto) -> bool {
133    node.domain == "ai.bytesandbrains.wire" && node.op_type == "Recv"
134}
135
136fn stamp_node_metadata(node: &mut NodeProto, key: &str, value: &str) {
137    if let Some(existing) = node.metadata_props.iter_mut().find(|p| p.key == key) {
138        existing.value = value.to_string();
139        return;
140    }
141    node.metadata_props.push(StringStringEntryProto {
142        key: key.to_string(),
143        value: value.to_string(),
144    });
145}
146
147/// Walk every NodeProto in every function for the recorder's
148/// `(REQUIRED_TRAIT_KEY, SLOT_ID_KEY)` pair; return the distinct
149/// `(required_trait, slot_id)` pairs seen. Order is deterministic
150/// (insertion order over the walk).
151fn collect_role_slot_ids(model: &ModelProto) -> Vec<(String, u32)> {
152    let mut out: Vec<(String, u32)> = Vec::new();
153    for function in &model.functions {
154        for node in &function.node {
155            let Some(role) = metadata_value(node, REQUIRED_TRAIT_KEY) else {
156                continue;
157            };
158            let Some(slot_id) =
159                metadata_value(node, SLOT_ID_KEY).and_then(|v| v.parse::<u32>().ok())
160            else {
161                continue;
162            };
163            if !out.iter().any(|(r, id)| r == role && *id == slot_id) {
164                out.push((role.to_string(), slot_id));
165            }
166        }
167    }
168    out
169}
170
171fn metadata_value<'a>(node: &'a NodeProto, key: &str) -> Option<&'a str> {
172    node.metadata_props
173        .iter()
174        .find(|p| p.key == key)
175        .map(|p| p.value.as_str())
176}
177
/// Normalizes a role name to its canonical Contract form by removing one
/// trailing `"Runtime"`, e.g. `"BackendRuntime"` → `"Backend"`.
///
/// Names without that suffix are returned unchanged, so the function
/// accepts either spelling. Only a single suffix is stripped
/// (`"BackendRuntimeRuntime"` → `"BackendRuntime"`). A role that is
/// exactly `"Runtime"` canonicalizes to the empty string.
///
/// The binding values stamped by this pass use the canonical form, so
/// install does not need to re-normalize each slot.
fn canonical_role(role: &str) -> String {
    const ENGINE_SUFFIX: &str = "Runtime";
    match role.strip_suffix(ENGINE_SUFFIX) {
        Some(stem) => stem.to_owned(),
        None => role.to_owned(),
    }
}
187
188/// Test-only helper: stamp a model with a synthetic
189/// `BindingSpec` so test fixtures can drive `install()` without
190/// running the full compile pipeline. The first FunctionProto's
191/// name is used as the `target`. Bindings are `(slot, role,
192/// TYPE_NAME)` triples; the role string is the canonical
193/// PascalCase identifier (`"Backend"`, `"Index"`, etc.) — the
194/// stamp pass canonicalizes either form.
195///
196/// Drives through the same stamp path the compiler uses, so tests
197/// exercise the real encoding.
198pub fn stamp_for_test(model: &mut ModelProto, bindings: &[(&str, &str, &str)]) {
199    let target = model
200        .functions
201        .first()
202        .map(|f| f.name.clone())
203        .unwrap_or_default();
204    let mut spec = BindingSpec::new();
205    for (slot, role, type_name) in bindings {
206        spec.push(slot.to_string(), role.to_string(), type_name.to_string());
207    }
208    stamp_compilation_metadata(model, &spec, &target);
209}
210