Skip to main content

bb_compiler/
refine_polymorphic_value_info.rs

1//! `refine_polymorphic_value_info` — narrows the placeholder
2//! `TYPE_TENSOR` denotation stamped by the DSL recorder on every
3//! Contract-method NodeProto to each bound concrete's actual
4//! `Storage::TYPE`.
5//!
6//! Runs BEFORE `run_pipeline` (and therefore BEFORE `type_solver`)
7//! so the solver walks the narrowed denotations, not the placeholder
8//! ones. Because this pass needs access to `BindingSpec` it lives in
9//! `Compiler::compile()` alongside `validate_all_slots_bound`, not
10//! inside the canonical runner pipeline (which has no binding
11//! context). `validate_all_slots_bound` runs AFTER the pipeline to
12//! confirm every bound slot was used.
13//!
14//! Pass order: `refine_polymorphic_value_info` → `run_pipeline`
15//! (containing `type_solver`) → `validate_all_slots_bound`.
16//!
17
18use std::collections::HashMap;
19
20use crate::artifact::BindingSpec;
21use crate::error::CompileError;
22use bb_ir::proto::onnx::{ModelProto, NodeProto};
23use bb_ir::syscall_ids::{OP_PASS_THROUGH, OP_WIRE_SEND, SYSCALL_DOMAIN, WIRE_DOMAIN};
24use bb_ir::types::TypeNode;
25
26/// Walk every Contract-method NodeProto in every `FunctionProto` and
27/// refine its `value_info` denotation from the polymorphic
28/// `TYPE_TENSOR` placeholder to the bound concrete's actual
29/// `Storage::TYPE`.
30///
31/// Only nodes carrying both `ai.bytesandbrains.required_trait` **and**
32/// `ai.bytesandbrains.slot_id` metadata are considered Contract-method
33/// nodes. Nodes without these metadata entries are silently skipped.
34///
35/// If the resolved `BindingSlot` has empty `storage_types` (e.g. a
36/// hand-implemented concrete that didn't use `#[derive(bb::<Role>)]`),
37/// the denotation is left unchanged — no error is returned. This is the
38/// documented graceful-degradation path.
39pub(crate) fn refine_polymorphic_value_info(
40    model: &mut ModelProto,
41    spec: &BindingSpec,
42) -> Result<(), CompileError> {
43    // Collect (output_value_name → narrowed TypeNode) pairs in a
44    // first pass over functions, then apply them in a second pass
45    // to avoid borrowing `model` mutably while reading it.
46    let mut refinements: Vec<(String, &'static TypeNode)> = Vec::new();
47
48    for function in model.functions.iter() {
49        for node in function.node.iter() {
50            // Nodes without required_trait and without slot_id are
51            // not Contract-method nodes — silently skip them.
52            let has_required_trait =
53                metadata_value(node, "ai.bytesandbrains.required_trait").is_some();
54            let has_slot_id = metadata_value(node, "ai.bytesandbrains.slot_id").is_some();
55
56            if !has_required_trait && !has_slot_id {
57                continue;
58            }
59
60            // A node with slot_id but no required_trait is malformed IR —
61            // the DSL recorder always stamps both together.
62            if has_slot_id && !has_required_trait {
63                return Err(CompileError::MissingRequiredTraitMetadata {
64                    node: node.name.clone(),
65                });
66            }
67
68            // At this point has_required_trait is true.
69            let required_trait =
70                metadata_value(node, "ai.bytesandbrains.required_trait").expect("checked above");
71
72            // Guard: two or more slots sharing a role means
73            // lookup_by_role would silently pick the first and apply
74            // the wrong concrete's storage type to nodes belonging to
75            // the other slot. Surface this as a hard error when the
76            // model actually contains a node referencing the
77            // ambiguous role. (Models with no Contract nodes for a
78            // given role are unaffected even if the spec has
79            // duplicate role bindings.)
80            {
81                let matching: Vec<&str> = spec
82                    .slots
83                    .iter()
84                    .filter(|s| s.role == required_trait)
85                    .map(|s| s.slot.as_str())
86                    .collect();
87                if matching.len() > 1 {
88                    return Err(CompileError::AmbiguousRoleBinding {
89                        role: required_trait.to_string(),
90                        slot_names: matching.iter().map(|s| s.to_string()).collect(),
91                    });
92                }
93            }
94
95            // slot_id must be present and parseable on a Contract-method
96            // node; if it isn't, surface an error — the DSL recorder
97            // always stamps it.
98            let slot_id_str =
99                metadata_value(node, "ai.bytesandbrains.slot_id").ok_or_else(|| {
100                    CompileError::InvalidSlotId {
101                        node: node.name.clone(),
102                        value: String::new(),
103                    }
104                })?;
105            // `_slot_id` is validated for malformed-IR detection;
106            // lookup uses `lookup_by_role` until slot_id-keyed
107            // BindingSpec lookup is added.
108            let _slot_id: u32 = slot_id_str
109                .parse()
110                .map_err(|_| CompileError::InvalidSlotId {
111                    node: node.name.clone(),
112                    value: slot_id_str.to_string(),
113                })?;
114
115            // Look up the binding slot by its role (the `required_trait`
116            // string is the role-runtime identifier).
117            let slot =
118                spec.lookup_by_role(required_trait)
119                    .ok_or_else(|| CompileError::UnknownSlotId {
120                        node: node.name.clone(),
121                        slot_id: _slot_id,
122                    })?;
123
124            // Map required_trait → port label; CodecRuntime also needs
125            // the per-node `codec.port` metadata. `None` means this
126            // role carries no storage-typed port (e.g. PeerSelector).
127            let Some(port) = port_name_for_trait(required_trait, node)? else {
128                continue;
129            };
130
131            // Resolve the storage TypeNode. An empty `storage_types`
132            // (graceful degradation) returns `None`; we skip quietly.
133            let Some(narrowed) = slot.storage_type_opt(port) else {
134                continue;
135            };
136
137            for output in &node.output {
138                refinements.push((output.clone(), narrowed));
139            }
140        }
141    }
142
143    propagate_through_value_preserving_ops(model, &mut refinements);
144
145    apply_refinements(model, &refinements);
146    Ok(())
147}
148
149/// Forward closure walk: `PassThrough` (syscall) and `wire.Send`
150/// preserve their input value's type on their first output (the
151/// renamed port / re-published value). Once the Contract-method
152/// upstream of one of these ops has been refined, the refined type
153/// must travel forward so the port name's `value_info` stops
154/// resolving as abstract `TYPE_TENSOR`.
155///
156/// Without this propagation the type solver's strict mode rejects
157/// `g.net_out(name, peers, role_method_output)` because the port
158/// name's `value_info` keeps the recorder's placeholder denotation.
159fn propagate_through_value_preserving_ops(
160    model: &ModelProto,
161    refinements: &mut Vec<(String, &'static TypeNode)>,
162) {
163    if refinements.is_empty() {
164        return;
165    }
166    let mut by_name: HashMap<String, &'static TypeNode> =
167        refinements.iter().map(|(n, t)| (n.clone(), *t)).collect();
168
169    loop {
170        let mut added = false;
171        for function in model.functions.iter() {
172            for node in function.node.iter() {
173                if !is_value_preserving(node) {
174                    continue;
175                }
176                let Some(input_name) = node.input.first() else {
177                    continue;
178                };
179                let Some(&narrowed) = by_name.get(input_name) else {
180                    continue;
181                };
182                let Some(output_name) = node.output.first() else {
183                    continue;
184                };
185                if by_name.contains_key(output_name) {
186                    continue;
187                }
188                by_name.insert(output_name.clone(), narrowed);
189                refinements.push((output_name.clone(), narrowed));
190                added = true;
191            }
192        }
193        if !added {
194            break;
195        }
196    }
197}
198
199/// Whether a NodeProto carries its first input's type to its first
200/// output unchanged. Covers the recorder-emitted PassThrough (the
201/// idempotent `g.output` re-name) and `wire.Send` (which republishes
202/// the value under the port name on the sender partition).
203fn is_value_preserving(node: &NodeProto) -> bool {
204    matches!(
205        (node.domain.as_str(), node.op_type.as_str()),
206        (SYSCALL_DOMAIN, OP_PASS_THROUGH) | (WIRE_DOMAIN, OP_WIRE_SEND)
207    )
208}
209
210/// Map a `required_trait` string to the port label used in
211/// `BindingSlot.storage_types`. For `CodecRuntime` the port is
212/// read from `ai.bytesandbrains.codec.port` metadata on the node.
213///
214/// Returns `Ok(None)` for roles that carry no storage-typed port
215/// (e.g. `PeerSelectorRuntime`). The caller skips refinement for
216/// those nodes.
217fn port_name_for_trait(
218    required_trait: &str,
219    node: &NodeProto,
220) -> Result<Option<&'static str>, CompileError> {
221    match required_trait {
222        "IndexRuntime" => Ok(Some("vector")),
223        "AggregatorRuntime" => Ok(Some("element")),
224        "ModelRuntime" => Ok(Some("tensor")),
225        "DataSourceRuntime" => Ok(Some("sample")),
226        "BackendRuntime" => Ok(Some("tensor")),
227        "PeerSelectorRuntime" => Ok(None), // peer selectors have no storage-typed port
228        "CodecRuntime" => {
229            let port_meta =
230                metadata_value(node, "ai.bytesandbrains.codec.port").ok_or_else(|| {
231                    CompileError::MissingCodecPortMetadata {
232                        node: node.name.clone(),
233                    }
234                })?;
235            match port_meta {
236                "in" => Ok(Some("in")),
237                "out" => Ok(Some("out")),
238                other => Err(CompileError::InvalidCodecPort {
239                    node: node.name.clone(),
240                    value: other.to_string(),
241                }),
242            }
243        }
244        _ => Err(CompileError::UnknownRoleRuntime {
245            node: node.name.clone(),
246            role: required_trait.to_string(),
247        }),
248    }
249}
250
251/// Stamp the collected refinements onto the model's
252/// `FunctionProto.value_info` entries in place.
253fn apply_refinements(model: &mut ModelProto, refinements: &[(String, &'static TypeNode)]) {
254    if refinements.is_empty() {
255        return;
256    }
257    for function in model.functions.iter_mut() {
258        for vi in function.value_info.iter_mut() {
259            if let Some((_, narrowed)) = refinements.iter().find(|(name, _)| *name == vi.name) {
260                if let Some(ref mut t) = vi.r#type {
261                    t.denotation = narrowed.denotation.to_string();
262                }
263            }
264        }
265    }
266}
267
268/// Read a metadata value from a `NodeProto.metadata_props` list.
269fn metadata_value<'a>(node: &'a NodeProto, key: &str) -> Option<&'a str> {
270    node.metadata_props
271        .iter()
272        .find(|p| p.key == key)
273        .map(|p| p.value.as_str())
274}
275