bb_compiler/stamp_compilation_metadata.rs
1//! `stamp_compilation_metadata` — the final pass that turns each
2//! per-partition `ModelProto` into a complete artifact by writing
3//! the compilation passport + slot binding table into
4//! `metadata_props`. Install reads these stamps; the typed
5//! `BindingSpec` Rust struct never crosses the compile boundary
6//! after this pass runs.
7//!
8//! Per-partition: each partition's root function name is the
9//! target identifier. Bindings are stamped against that target.
10//! A future multi-partition merge step concatenates the partitions
11//! into one ModelProto with each target as a distinct entry in
12//! `functions[]`; this pass is partition-local so the merge can
13//! happen at any granularity.
14//!
15//! Slot-id resolution: walk the partition's NodeProtos for the
16//! `(REQUIRED_TRAIT_KEY, SLOT_ID_KEY)` pair stamped by the
17//! recorder's placeholder DSL methods. For each `BindingSlot`,
18//! find the slot_id whose role matches; dep-only slots (no
19//! NodeProto references) encode as `-1`.
20
21use bb_ir::keys::{
22 binding_key, encode_binding_value, stamp_model_metadata, COMPILED_CURRENT_VERSION,
23 COMPILED_KEY, RECV_SLOT_ID_KEY, REQUIRED_TRAIT_KEY, SLOT_ID_KEY,
24};
25use bb_ir::proto::onnx::{ModelProto, NodeProto, StringStringEntryProto};
26
27use crate::artifact::BindingSpec;
28
29/// Stamp `bb.compiled = "v1"` on the model + one binding entry per
30/// slot under the per-target prefix. Each partition's root function
31/// is its target; pass each partition's `target_name` alongside its
32/// model. The pass mutates `model.metadata_props` in place and is
33/// idempotent (re-stamping with the same values is a no-op).
34///
35/// Also walks every function for `wire.Recv` NodeProtos whose
36/// payload output feeds a role NodeProto (carrying [`SLOT_ID_KEY`])
37/// and stamps [`RECV_SLOT_ID_KEY`] on the Recv node's
38/// `metadata_props`. The install path reads this to populate
39/// `GraphSlot::recv_site_to_slot_id` so `decode_typed_fill` can route
40/// backend-bound tensor fills through the bound backend's
41/// `materialize_from_wire`. Recv nodes whose payload does not flow
42/// into a role NodeProto are left unstamped (framework-carrier path).
43pub(crate) fn stamp_compilation_metadata(
44 model: &mut ModelProto,
45 bindings: &BindingSpec,
46 target_name: &str,
47) {
48 // 1) Compilation passport — present on every compiled model.
49 stamp_model_metadata(model, COMPILED_KEY, COMPILED_CURRENT_VERSION);
50
51 // 2) Build a (role → slot_id) map by walking the model's
52 // NodeProtos for the recorder's (required_trait, slot_id)
53 // placeholder stamps. The compiler-internal BindingSpec lists
54 // every slot the install path must fill; bindings whose role
55 // doesn't appear in any NodeProto are dep-only and encode -1.
56 let role_to_slot_id = collect_role_slot_ids(model);
57
58 // 3) Stamp one binding entry per BindingSlot.
59 for slot in &bindings.slots {
60 let role_canon = canonical_role(&slot.role);
61 let slot_id_or_neg1 = role_to_slot_id
62 .iter()
63 .find(|(role, _)| canonical_role(role) == role_canon)
64 .map(|(_, id)| *id as i64)
65 .unwrap_or(-1);
66 let key = binding_key(target_name, &slot.slot);
67 let value = encode_binding_value(&role_canon, &slot.concrete_type_name, slot_id_or_neg1);
68 stamp_model_metadata(model, &key, &value);
69 }
70
71 // 4) Stamp `RECV_SLOT_ID_KEY` on each `wire.Recv` whose payload
72 // output flows into a role NodeProto. The install pass reads
73 // this to map the Recv's allocated `NodeSiteId` onto the
74 // role's `slot_id`.
75 stamp_recv_slot_ids(model);
76}
77
78/// Walk every function in `model.functions`, find `wire.Recv` nodes,
79/// and stamp `RECV_SLOT_ID_KEY` on each whose first output flows
80/// into a role NodeProto's input. The Recv's first output is its
81/// payload site; we look up consumers by name match.
82fn stamp_recv_slot_ids(model: &mut ModelProto) {
83 for function in &mut model.functions {
84 // Build a (payload_name → recv_index) map for every Recv,
85 // and a (consumer_input_name → slot_id) map by scanning role
86 // NodeProtos.
87 let mut recv_indices: Vec<(usize, String)> = Vec::new();
88 let mut consumer_slot_ids: Vec<(String, u32)> = Vec::new();
89 for (idx, node) in function.node.iter().enumerate() {
90 if is_wire_recv(node) {
91 if let Some(payload) = node.output.first() {
92 if !payload.is_empty() {
93 recv_indices.push((idx, payload.clone()));
94 }
95 }
96 continue;
97 }
98 let Some(slot_id) =
99 metadata_value(node, SLOT_ID_KEY).and_then(|v| v.parse::<u32>().ok())
100 else {
101 continue;
102 };
103 for input in &node.input {
104 if !input.is_empty() {
105 consumer_slot_ids.push((input.clone(), slot_id));
106 }
107 }
108 }
109 for (recv_idx, payload_name) in recv_indices {
110 // Pick the first downstream consumer's slot_id. If the
111 // same payload feeds multiple role NodeProtos with
112 // distinct slot_ids, the compiler would have rejected the
113 // graph upstream (one Recv site == one destination
114 // binding); the first hit is therefore the only hit on
115 // any valid input.
116 let Some(slot_id) = consumer_slot_ids
117 .iter()
118 .find(|(name, _)| name == &payload_name)
119 .map(|(_, id)| *id)
120 else {
121 continue;
122 };
123 stamp_node_metadata(
124 &mut function.node[recv_idx],
125 RECV_SLOT_ID_KEY,
126 &slot_id.to_string(),
127 );
128 }
129 }
130}
131
132fn is_wire_recv(node: &NodeProto) -> bool {
133 node.domain == "ai.bytesandbrains.wire" && node.op_type == "Recv"
134}
135
136fn stamp_node_metadata(node: &mut NodeProto, key: &str, value: &str) {
137 if let Some(existing) = node.metadata_props.iter_mut().find(|p| p.key == key) {
138 existing.value = value.to_string();
139 return;
140 }
141 node.metadata_props.push(StringStringEntryProto {
142 key: key.to_string(),
143 value: value.to_string(),
144 });
145}
146
147/// Walk every NodeProto in every function for the recorder's
148/// `(REQUIRED_TRAIT_KEY, SLOT_ID_KEY)` pair; return the distinct
149/// `(required_trait, slot_id)` pairs seen. Order is deterministic
150/// (insertion order over the walk).
151fn collect_role_slot_ids(model: &ModelProto) -> Vec<(String, u32)> {
152 let mut out: Vec<(String, u32)> = Vec::new();
153 for function in &model.functions {
154 for node in &function.node {
155 let Some(role) = metadata_value(node, REQUIRED_TRAIT_KEY) else {
156 continue;
157 };
158 let Some(slot_id) =
159 metadata_value(node, SLOT_ID_KEY).and_then(|v| v.parse::<u32>().ok())
160 else {
161 continue;
162 };
163 if !out.iter().any(|(r, id)| r == role && *id == slot_id) {
164 out.push((role.to_string(), slot_id));
165 }
166 }
167 }
168 out
169}
170
171fn metadata_value<'a>(node: &'a NodeProto, key: &str) -> Option<&'a str> {
172 node.metadata_props
173 .iter()
174 .find(|p| p.key == key)
175 .map(|p| p.value.as_str())
176}
177
/// Normalize a role name to the canonical Contract role identifier.
///
/// `BindingSlot.role` carries the engine-side trait name
/// (`"BackendRuntime"`, `"IndexRuntime"`, …), while the install path
/// and the runtime `RuntimeResourceRef::dependency` lookups use the
/// PascalCase identifier without the `Runtime` suffix. Binding values
/// are stamped in canonical form so install does not have to
/// re-normalize per slot.
///
/// Exactly one trailing `"Runtime"` is removed when present; any
/// other input is returned unchanged (as an owned copy).
fn canonical_role(role: &str) -> String {
    match role.strip_suffix("Runtime") {
        Some(stem) => stem.to_owned(),
        None => role.to_owned(),
    }
}
187
188/// Test-only helper: stamp a model with a synthetic
189/// `BindingSpec` so test fixtures can drive `install()` without
190/// running the full compile pipeline. The first FunctionProto's
191/// name is used as the `target`. Bindings are `(slot, role,
192/// TYPE_NAME)` triples; the role string is the canonical
193/// PascalCase identifier (`"Backend"`, `"Index"`, etc.) — the
194/// stamp pass canonicalizes either form.
195///
196/// Drives through the same stamp path the compiler uses, so tests
197/// exercise the real encoding.
198pub fn stamp_for_test(model: &mut ModelProto, bindings: &[(&str, &str, &str)]) {
199 let target = model
200 .functions
201 .first()
202 .map(|f| f.name.clone())
203 .unwrap_or_default();
204 let mut spec = BindingSpec::new();
205 for (slot, role, type_name) in bindings {
206 spec.push(slot.to_string(), role.to_string(), type_name.to_string());
207 }
208 stamp_compilation_metadata(model, &spec, &target);
209}
210