bb_compiler/refine_polymorphic_value_info.rs
1//! `refine_polymorphic_value_info` — narrows the placeholder
2//! `TYPE_TENSOR` denotation stamped by the DSL recorder on every
3//! Contract-method NodeProto to each bound concrete's actual
4//! `Storage::TYPE`.
5//!
6//! Runs BEFORE `run_pipeline` (and therefore BEFORE `type_solver`)
7//! so the solver walks the narrowed denotations, not the placeholder
8//! ones. Because this pass needs access to `BindingSpec` it lives in
9//! `Compiler::compile()` alongside `validate_all_slots_bound`, not
10//! inside the canonical runner pipeline (which has no binding
11//! context). `validate_all_slots_bound` runs AFTER the pipeline to
12//! confirm every bound slot was used.
13//!
14//! Pass order: `refine_polymorphic_value_info` → `run_pipeline`
15//! (containing `type_solver`) → `validate_all_slots_bound`.
16//!
17
18use std::collections::HashMap;
19
20use crate::artifact::BindingSpec;
21use crate::error::CompileError;
22use bb_ir::proto::onnx::{ModelProto, NodeProto};
23use bb_ir::syscall_ids::{OP_PASS_THROUGH, OP_WIRE_SEND, SYSCALL_DOMAIN, WIRE_DOMAIN};
24use bb_ir::types::TypeNode;
25
26/// Walk every Contract-method NodeProto in every `FunctionProto` and
27/// refine its `value_info` denotation from the polymorphic
28/// `TYPE_TENSOR` placeholder to the bound concrete's actual
29/// `Storage::TYPE`.
30///
31/// Only nodes carrying both `ai.bytesandbrains.required_trait` **and**
32/// `ai.bytesandbrains.slot_id` metadata are considered Contract-method
33/// nodes. Nodes without these metadata entries are silently skipped.
34///
35/// If the resolved `BindingSlot` has empty `storage_types` (e.g. a
36/// hand-implemented concrete that didn't use `#[derive(bb::<Role>)]`),
37/// the denotation is left unchanged — no error is returned. This is the
38/// documented graceful-degradation path.
39pub(crate) fn refine_polymorphic_value_info(
40 model: &mut ModelProto,
41 spec: &BindingSpec,
42) -> Result<(), CompileError> {
43 // Collect (output_value_name → narrowed TypeNode) pairs in a
44 // first pass over functions, then apply them in a second pass
45 // to avoid borrowing `model` mutably while reading it.
46 let mut refinements: Vec<(String, &'static TypeNode)> = Vec::new();
47
48 for function in model.functions.iter() {
49 for node in function.node.iter() {
50 // Nodes without required_trait and without slot_id are
51 // not Contract-method nodes — silently skip them.
52 let has_required_trait =
53 metadata_value(node, "ai.bytesandbrains.required_trait").is_some();
54 let has_slot_id = metadata_value(node, "ai.bytesandbrains.slot_id").is_some();
55
56 if !has_required_trait && !has_slot_id {
57 continue;
58 }
59
60 // A node with slot_id but no required_trait is malformed IR —
61 // the DSL recorder always stamps both together.
62 if has_slot_id && !has_required_trait {
63 return Err(CompileError::MissingRequiredTraitMetadata {
64 node: node.name.clone(),
65 });
66 }
67
68 // At this point has_required_trait is true.
69 let required_trait =
70 metadata_value(node, "ai.bytesandbrains.required_trait").expect("checked above");
71
72 // Guard: two or more slots sharing a role means
73 // lookup_by_role would silently pick the first and apply
74 // the wrong concrete's storage type to nodes belonging to
75 // the other slot. Surface this as a hard error when the
76 // model actually contains a node referencing the
77 // ambiguous role. (Models with no Contract nodes for a
78 // given role are unaffected even if the spec has
79 // duplicate role bindings.)
80 {
81 let matching: Vec<&str> = spec
82 .slots
83 .iter()
84 .filter(|s| s.role == required_trait)
85 .map(|s| s.slot.as_str())
86 .collect();
87 if matching.len() > 1 {
88 return Err(CompileError::AmbiguousRoleBinding {
89 role: required_trait.to_string(),
90 slot_names: matching.iter().map(|s| s.to_string()).collect(),
91 });
92 }
93 }
94
95 // slot_id must be present and parseable on a Contract-method
96 // node; if it isn't, surface an error — the DSL recorder
97 // always stamps it.
98 let slot_id_str =
99 metadata_value(node, "ai.bytesandbrains.slot_id").ok_or_else(|| {
100 CompileError::InvalidSlotId {
101 node: node.name.clone(),
102 value: String::new(),
103 }
104 })?;
105 // `_slot_id` is validated for malformed-IR detection;
106 // lookup uses `lookup_by_role` until slot_id-keyed
107 // BindingSpec lookup is added.
108 let _slot_id: u32 = slot_id_str
109 .parse()
110 .map_err(|_| CompileError::InvalidSlotId {
111 node: node.name.clone(),
112 value: slot_id_str.to_string(),
113 })?;
114
115 // Look up the binding slot by its role (the `required_trait`
116 // string is the role-runtime identifier).
117 let slot =
118 spec.lookup_by_role(required_trait)
119 .ok_or_else(|| CompileError::UnknownSlotId {
120 node: node.name.clone(),
121 slot_id: _slot_id,
122 })?;
123
124 // Map required_trait → port label; CodecRuntime also needs
125 // the per-node `codec.port` metadata. `None` means this
126 // role carries no storage-typed port (e.g. PeerSelector).
127 let Some(port) = port_name_for_trait(required_trait, node)? else {
128 continue;
129 };
130
131 // Resolve the storage TypeNode. An empty `storage_types`
132 // (graceful degradation) returns `None`; we skip quietly.
133 let Some(narrowed) = slot.storage_type_opt(port) else {
134 continue;
135 };
136
137 for output in &node.output {
138 refinements.push((output.clone(), narrowed));
139 }
140 }
141 }
142
143 propagate_through_value_preserving_ops(model, &mut refinements);
144
145 apply_refinements(model, &refinements);
146 Ok(())
147}
148
149/// Forward closure walk: `PassThrough` (syscall) and `wire.Send`
150/// preserve their input value's type on their first output (the
151/// renamed port / re-published value). Once the Contract-method
152/// upstream of one of these ops has been refined, the refined type
153/// must travel forward so the port name's `value_info` stops
154/// resolving as abstract `TYPE_TENSOR`.
155///
156/// Without this propagation the type solver's strict mode rejects
157/// `g.net_out(name, peers, role_method_output)` because the port
158/// name's `value_info` keeps the recorder's placeholder denotation.
159fn propagate_through_value_preserving_ops(
160 model: &ModelProto,
161 refinements: &mut Vec<(String, &'static TypeNode)>,
162) {
163 if refinements.is_empty() {
164 return;
165 }
166 let mut by_name: HashMap<String, &'static TypeNode> =
167 refinements.iter().map(|(n, t)| (n.clone(), *t)).collect();
168
169 loop {
170 let mut added = false;
171 for function in model.functions.iter() {
172 for node in function.node.iter() {
173 if !is_value_preserving(node) {
174 continue;
175 }
176 let Some(input_name) = node.input.first() else {
177 continue;
178 };
179 let Some(&narrowed) = by_name.get(input_name) else {
180 continue;
181 };
182 let Some(output_name) = node.output.first() else {
183 continue;
184 };
185 if by_name.contains_key(output_name) {
186 continue;
187 }
188 by_name.insert(output_name.clone(), narrowed);
189 refinements.push((output_name.clone(), narrowed));
190 added = true;
191 }
192 }
193 if !added {
194 break;
195 }
196 }
197}
198
199/// Whether a NodeProto carries its first input's type to its first
200/// output unchanged. Covers the recorder-emitted PassThrough (the
201/// idempotent `g.output` re-name) and `wire.Send` (which republishes
202/// the value under the port name on the sender partition).
203fn is_value_preserving(node: &NodeProto) -> bool {
204 matches!(
205 (node.domain.as_str(), node.op_type.as_str()),
206 (SYSCALL_DOMAIN, OP_PASS_THROUGH) | (WIRE_DOMAIN, OP_WIRE_SEND)
207 )
208}
209
210/// Map a `required_trait` string to the port label used in
211/// `BindingSlot.storage_types`. For `CodecRuntime` the port is
212/// read from `ai.bytesandbrains.codec.port` metadata on the node.
213///
214/// Returns `Ok(None)` for roles that carry no storage-typed port
215/// (e.g. `PeerSelectorRuntime`). The caller skips refinement for
216/// those nodes.
217fn port_name_for_trait(
218 required_trait: &str,
219 node: &NodeProto,
220) -> Result<Option<&'static str>, CompileError> {
221 match required_trait {
222 "IndexRuntime" => Ok(Some("vector")),
223 "AggregatorRuntime" => Ok(Some("element")),
224 "ModelRuntime" => Ok(Some("tensor")),
225 "DataSourceRuntime" => Ok(Some("sample")),
226 "BackendRuntime" => Ok(Some("tensor")),
227 "PeerSelectorRuntime" => Ok(None), // peer selectors have no storage-typed port
228 "CodecRuntime" => {
229 let port_meta =
230 metadata_value(node, "ai.bytesandbrains.codec.port").ok_or_else(|| {
231 CompileError::MissingCodecPortMetadata {
232 node: node.name.clone(),
233 }
234 })?;
235 match port_meta {
236 "in" => Ok(Some("in")),
237 "out" => Ok(Some("out")),
238 other => Err(CompileError::InvalidCodecPort {
239 node: node.name.clone(),
240 value: other.to_string(),
241 }),
242 }
243 }
244 _ => Err(CompileError::UnknownRoleRuntime {
245 node: node.name.clone(),
246 role: required_trait.to_string(),
247 }),
248 }
249}
250
251/// Stamp the collected refinements onto the model's
252/// `FunctionProto.value_info` entries in place.
253fn apply_refinements(model: &mut ModelProto, refinements: &[(String, &'static TypeNode)]) {
254 if refinements.is_empty() {
255 return;
256 }
257 for function in model.functions.iter_mut() {
258 for vi in function.value_info.iter_mut() {
259 if let Some((_, narrowed)) = refinements.iter().find(|(name, _)| *name == vi.name) {
260 if let Some(ref mut t) = vi.r#type {
261 t.denotation = narrowed.denotation.to_string();
262 }
263 }
264 }
265 }
266}
267
268/// Read a metadata value from a `NodeProto.metadata_props` list.
269fn metadata_value<'a>(node: &'a NodeProto, key: &str) -> Option<&'a str> {
270 node.metadata_props
271 .iter()
272 .find(|p| p.key == key)
273 .map(|p| p.value.as_str())
274}
275