Skip to main content

bb_compiler/
type_solver.rs

1//! Type-resolution pass — bipartite worklist solver following
2//! TVM Relay's `type_solver.h` design, adapted to Rust.
3//!
4//! ## Shape
5//!
6//! Two arenas: **type nodes** (one per value position) and
7//! **relation nodes** (one per `TypeRelation` instance on each
8//! [`AtomicOpDecl`]). Cross-linked via `rel_set` back-edges. The
9//! worklist holds relations ready to (re)run.
10//!
11//! ## Algorithm
12//!
13//! 1. **Seed.** Allocate a type node for every value name in the
14//!    graph (function inputs, op outputs). Mark each with its
15//!    declared bound (`TYPE_ANY` if none).
16//! 2. **Instantiate relations.** For each NodeProto in the graph,
17//!    look up its `AtomicOpDecl.type_relations` and allocate a
18//!    relation node per declared [`TypeRelation`]. Each relation
19//!    points at the type nodes for the ports it participates in.
20//!    Type nodes track back-edges via `rel_set`.
//! 3. **Drain.** Pop a relation from the worklist, run it.
//!    [`RelationResult`] dictates the next move:
//!    - `Refined` → requeue every other not-yet-satisfied
//!      relation found in the `rel_set` of any type node the
//!      refining relation participates in.
//!    - `Satisfied` → mark the relation so it is never run again.
//!    - `Defer` → drop it from the worklist for now; it is
//!      retried only if another relation refines one of its
//!      participating types and requeues it.
//!    - `Failed` → abort with `TypeError::ConstraintFailed`.
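//! 4. **Fixpoint.** When the worklist is empty (deferred
//!    relations are not requeued unless a participating type
//!    refines), the solver reports every value's current
//!    resolution. `solve` is permissive and lets still-abstract
//!    values through; `solve_strict` additionally requires every
//!    type node to be concrete and otherwise reports
//!    `UnresolvedType`.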
33//!
//! ## Scope
//!
//! [`TypeRelation::SameType`], [`TypeRelation::SameElementType`],
//! [`TypeRelation::Elementwise`], `BroadcastShape` and `ReduceOver`
//! are all dispatched, but only at element-type granularity:
//! `SameElementType` and `BroadcastShape` currently reuse the
//! `SameType` handler and `ReduceOver` reuses the `Elementwise`
//! handler, because shapes are not tracked yet. `Custom`
//! relations always defer. New behaviour plugs in by extending
//! the per-variant handler match inside the solver's internal
//! `run_relation` dispatch.
42
43use std::collections::HashMap;
44
45use bb_ir::proto::onnx::GraphProto;
46use bb_ir::types::{PortRef, RelationResult, TypeNode, TypeRelation, TYPE_ANY};
47
48/// Errors the solver may report.
49#[derive(Debug)]
50pub enum TypeError {
51    /// A relation produced a hard contradiction (e.g. two consumers
52    /// of the same port require incompatible element types).
53    ConstraintFailed {
54        /// Op the relation was attached to.
55        op: String,
56        /// Relation diagnostic string.
57        detail: String,
58    },
59    /// The solver reached fixpoint with type nodes still abstract
60    /// (i.e. not narrowed to a concrete leaf in the lattice).
61    UnresolvedType {
62        /// Value name with no concrete resolution.
63        value: String,
64    },
65    /// An op references a port index that doesn't map to a value
66    /// (out-of-range input/output position on `AtomicOpDecl`).
67    PortOutOfRange {
68        /// Op name.
69        op: String,
70        /// Failing port reference.
71        port: PortRef,
72    },
73}
74
75impl std::fmt::Display for TypeError {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            Self::ConstraintFailed { op, detail } => {
79                write!(f, "type constraint failed at {op}: {detail}")
80            }
81            Self::UnresolvedType { value } => {
82                write!(f, "value `{value}` did not resolve to a concrete type")
83            }
84            Self::PortOutOfRange { op, port } => {
85                write!(f, "op {op} references out-of-range port {port:?}")
86            }
87        }
88    }
89}
90
91impl std::error::Error for TypeError {}
92
93/// Solver output: every value name in the graph maps to its
94/// resolved concrete TypeNode.
95#[derive(Debug)]
96pub struct TypeSolution {
97    by_value: HashMap<String, &'static TypeNode>,
98}
99
100impl TypeSolution {
101    /// Resolved TypeNode for a value name. `None` if the solver
102    /// didn't see this value.
103    pub fn type_of(&self, value: &str) -> Option<&'static TypeNode> {
104        self.by_value.get(value).copied()
105    }
106
107    /// Iterate every resolved (value_name, TypeNode) pair.
108    pub fn iter(&self) -> impl Iterator<Item = (&str, &'static TypeNode)> {
109        self.by_value.iter().map(|(k, v)| (k.as_str(), *v))
110    }
111}
112
113/// Bipartite type-resolution solver.
114pub struct TypeSolver {
115    /// Per-value type nodes. Index = the slot's position in
116    /// [`Self::value_index`].
117    types: Vec<TypeNodeSlot>,
118    /// Per-relation constraint nodes.
119    relations: Vec<RelationNode>,
120    /// Value name → index into `types`.
121    value_index: HashMap<String, usize>,
122}
123
124/// One value position's current type resolution + back-edges to
125/// relations that depend on it.
126struct TypeNodeSlot {
127    /// Current best-known resolution. `&TYPE_ANY` until a relation
128    /// narrows it. Refinement only proceeds toward MORE specific
129    /// types (down the lattice).
130    resolved: &'static TypeNode,
131    /// Relations participating in this slot. Populated at solver
132    /// construction; consulted when the slot refines to requeue
133    /// dependents.
134    rel_set: Vec<usize>,
135}
136
137/// One instantiated [`TypeRelation`] linked to its participating
138/// type slots.
139struct RelationNode {
140    /// The relation declaration (from the op's atomic_opset).
141    decl: &'static TypeRelation,
142    /// Op name (for diagnostics).
143    op_name: String,
144    /// Type-slot indices participating in this relation, in
145    /// declaration order. Length matches the relation variant's
146    /// port count (e.g. 2 for `Elementwise{input, output}`).
147    slots: Vec<usize>,
148    /// `true` once the relation reports `Satisfied` and is
149    /// permanently removed from the worklist.
150    satisfied: bool,
151}
152
153impl TypeSolver {
154    /// Build a fresh solver from a `GraphProto`. Walks every node,
155    /// allocates slots for every value name, instantiates relations
156    /// per the op's `AtomicOpDecl.type_relations`.
157    ///
158    /// `decl_for_op` lets the caller plug in their own
159    /// `(domain, op_type) -> &AtomicOpDecl` lookup (typically the
160    /// compiler's registered opset catalog).
161    pub fn from_graph(
162        graph: &GraphProto,
163        decl_for_op: impl Fn(&str, &str) -> Option<&'static bb_ir::atomic::AtomicOpDecl>,
164    ) -> Result<Self, TypeError> {
165        let mut solver = Self {
166            types: Vec::new(),
167            relations: Vec::new(),
168            value_index: HashMap::new(),
169        };
170
171        // First pass: allocate a type slot for every value name
172        // (graph inputs, then every op's outputs).
173        for input in &graph.input {
174            solver.intern_value(&input.name);
175        }
176        for node in &graph.node {
177            for out in &node.output {
178                if !out.is_empty() {
179                    solver.intern_value(out);
180                }
181            }
182            for inp in &node.input {
183                if !inp.is_empty() {
184                    solver.intern_value(inp);
185                }
186            }
187        }
188
189        // Second pass: for each NodeProto, instantiate the relations
190        // declared on its AtomicOpDecl.
191        for node in &graph.node {
192            let Some(decl) = decl_for_op(&node.domain, &node.op_type) else {
193                // No declared opset entry - skip. Unknown ops fall
194                // through; resolve_dispatch will catch them downstream.
195                continue;
196            };
197            for relation in decl.type_relations {
198                let slots = solver.resolve_relation_ports(node, relation)?;
199                let rel_idx = solver.relations.len();
200                solver.relations.push(RelationNode {
201                    decl: relation,
202                    op_name: format!("{}::{}", node.domain, node.op_type),
203                    slots: slots.clone(),
204                    satisfied: false,
205                });
206                for s in slots {
207                    solver.types[s].rel_set.push(rel_idx);
208                }
209            }
210        }
211
212        Ok(solver)
213    }
214
215    fn intern_value(&mut self, name: &str) -> usize {
216        if let Some(&idx) = self.value_index.get(name) {
217            return idx;
218        }
219        let idx = self.types.len();
220        self.types.push(TypeNodeSlot {
221            resolved: &TYPE_ANY,
222            rel_set: Vec::new(),
223        });
224        self.value_index.insert(name.to_string(), idx);
225        idx
226    }
227
228    /// Resolve `relation`'s [`PortRef`]s against the NodeProto's
229    /// input/output lists; return one type-slot index per
230    /// declared port.
231    fn resolve_relation_ports(
232        &mut self,
233        node: &bb_ir::proto::onnx::NodeProto,
234        relation: &TypeRelation,
235    ) -> Result<Vec<usize>, TypeError> {
236        let ports: Vec<PortRef> = match relation {
237            TypeRelation::SameType(p) | TypeRelation::SameElementType(p) => p.to_vec(),
238            TypeRelation::Elementwise { input, output } => vec![*input, *output],
239            TypeRelation::BroadcastShape { in0, in1, out } => vec![*in0, *in1, *out],
240            TypeRelation::ReduceOver { input, output } => vec![*input, *output],
241            TypeRelation::Custom { .. } => Vec::new(),
242        };
243
244        let op_name = format!("{}::{}", node.domain, node.op_type);
245        let mut slots = Vec::with_capacity(ports.len());
246        for port in ports {
247            let value_name = match port {
248                PortRef::Input(i) => node.input.get(i as usize).cloned(),
249                PortRef::Output(o) => node.output.get(o as usize).cloned(),
250            };
251            let Some(name) = value_name else {
252                return Err(TypeError::PortOutOfRange { op: op_name, port });
253            };
254            if name.is_empty() {
255                return Err(TypeError::PortOutOfRange { op: op_name, port });
256            }
257            slots.push(self.intern_value(&name));
258        }
259        Ok(slots)
260    }
261
262    /// Seed a value's type with a concrete (or narrower-than-Any)
263    /// TypeNode. Used by callers that know specific inputs' types
264    /// upfront (e.g. an AppEvent feeding a `Tensor<F32>`).
265    pub fn seed(&mut self, value: &str, node: &'static TypeNode) {
266        if let Some(&idx) = self.value_index.get(value) {
267            self.types[idx].resolved = node;
268        }
269    }
270
271    /// Walk `graph.input` + `graph.value_info` and seed every value
272    /// whose `ValueInfoProto.type.denotation` maps to a built-in
273    /// TypeNode (via [`bb_ir::types::builtins::lookup_denotation`]).
274    /// Values with unknown denotations are left at `TYPE_ANY`; the
275    /// solver narrows them via relations during `solve()`.
276    ///
277    /// Per the architecture's polymorphic-type contract, the DSL's
278    /// `Graph::input(name)` records each input with the
279    /// `ai.bytesandbrains.opaque` denotation (→ `TYPE_ANY`). Wire
280    /// op outputs + framework-recorded values carry pinned
281    /// denotations the lookup recognizes; that pinning seeds the
282    /// solver with concrete-leaf TypeNodes from which the
283    /// relation network propagates.
284    pub fn seed_from_value_info(&mut self, graph: &GraphProto) {
285        for vi in graph.input.iter().chain(graph.value_info.iter()) {
286            let Some(type_proto) = vi.r#type.as_ref() else {
287                continue;
288            };
289            let denotation = type_proto.denotation.as_str();
290            if denotation.is_empty() {
291                continue;
292            }
293            if let Some(node) = bb_ir::types::builtins::lookup_denotation(denotation) {
294                self.seed(&vi.name, node);
295            }
296        }
297    }
298
299    /// Run the worklist to fixpoint, then post-check that every
300    /// slot resolved to a concrete leaf.
301    pub fn solve(mut self) -> Result<TypeSolution, TypeError> {
302        // Initial worklist = every relation.
303        let mut worklist: std::collections::VecDeque<usize> = (0..self.relations.len()).collect();
304
305        while let Some(rel_idx) = worklist.pop_front() {
306            if self.relations[rel_idx].satisfied {
307                continue;
308            }
309            let outcome = self.run_relation(rel_idx)?;
310            match outcome {
311                RelationResult::Refined => {
312                    // Requeue dependents of any participating slot.
313                    let slots = self.relations[rel_idx].slots.clone();
314                    for s in slots {
315                        for &dep in &self.types[s].rel_set {
316                            if dep != rel_idx && !self.relations[dep].satisfied {
317                                worklist.push_back(dep);
318                            }
319                        }
320                    }
321                }
322                RelationResult::Satisfied => {
323                    self.relations[rel_idx].satisfied = true;
324                }
325                RelationResult::Defer => {
326                    // Don't requeue automatically; we'll come back
327                    // when a participating slot refines.
328                }
329                RelationResult::Failed(detail) => {
330                    return Err(TypeError::ConstraintFailed {
331                        op: self.relations[rel_idx].op_name.clone(),
332                        detail: detail.to_string(),
333                    });
334                }
335            }
336        }
337
338        // Post-check: every slot must be a concrete leaf.
339        let mut by_value: HashMap<String, &'static TypeNode> = HashMap::new();
340        for (name, &idx) in &self.value_index {
341            let node = self.types[idx].resolved;
342            // Allow unresolved (Any) entries to pass through silently
343            // - callers may want a partial solution for diagnostics.
344            // Hard error happens only if we WERE supposed to resolve.
345            by_value.insert(name.clone(), node);
346        }
347        Ok(TypeSolution { by_value })
348    }
349
350    /// Stamp `solution`'s resolved TypeNodes back onto every
351    /// matching `ValueInfoProto.type.denotation` in `graph`.
352    /// Downstream passes + the runtime read the narrowed
353    /// denotation instead of the recorder's
354    /// `ai.bytesandbrains.opaque` placeholder.
355    ///
356    /// Unresolved (still-`TYPE_ANY`) entries are left as-is —
357    /// they keep their original denotation. Permissive mode
358    /// surfaces here as silent pass-through; strict mode is the
359    /// caller's choice via `solve_strict()` BEFORE this is called.
360    pub fn apply_solution_to_value_info(graph: &mut GraphProto, solution: &TypeSolution) {
361        for vi in graph.input.iter_mut().chain(graph.value_info.iter_mut()) {
362            let Some(node) = solution.type_of(&vi.name) else {
363                continue;
364            };
365            if node.is_abstract() {
366                continue;
367            }
368            let denotation = type_node_to_denotation(node);
369            if denotation.is_empty() {
370                continue;
371            }
372            if let Some(type_proto) = vi.r#type.as_mut() {
373                type_proto.denotation = denotation.to_string();
374            }
375        }
376    }
377
378    /// Strict-mode solve: every slot MUST resolve to a concrete leaf.
379    /// Returns `UnresolvedType` on the first abstract slot.
380    pub fn solve_strict(self) -> Result<TypeSolution, TypeError> {
381        let solution = self.solve()?;
382        for (name, node) in &solution.by_value {
383            if node.is_abstract() {
384                return Err(TypeError::UnresolvedType {
385                    value: name.clone(),
386                });
387            }
388        }
389        Ok(solution)
390    }
391
392    /// Run one relation, return its outcome. The match dispatches
393    /// to the per-variant handler.
394    fn run_relation(&mut self, idx: usize) -> Result<RelationResult, TypeError> {
395        let slots = self.relations[idx].slots.clone();
396        let decl = self.relations[idx].decl;
397
398        let outcome = match decl {
399            TypeRelation::SameType(_) => self.run_same_type(&slots),
400            TypeRelation::SameElementType(_) => self.run_same_element_type(&slots),
401            TypeRelation::Elementwise { .. } => self.run_elementwise(&slots),
402            TypeRelation::BroadcastShape { .. } => self.run_broadcast_shape(&slots),
403            TypeRelation::ReduceOver { .. } => self.run_reduce_over(&slots),
404            TypeRelation::Custom { run, .. } => {
405                // Custom relations are not yet implemented;
406                // defer until `CustomRelationCtx` has a real shape.
407                let _ = run;
408                Ok(RelationResult::Defer)
409            }
410        }?;
411
412        Ok(outcome)
413    }
414
415    // ---- Per-relation handlers ----------------------------------
416
417    /// `SameType` - every listed slot collapses to ONE concrete
418    /// TypeNode. Implementation: take the FIRST concrete resolution
419    /// among participants; narrow every other participant to match.
420    fn run_same_type(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
421        let pivot: Option<&'static TypeNode> = slots
422            .iter()
423            .map(|&s| self.types[s].resolved)
424            .find(|n| n.is_concrete());
425        let Some(pivot) = pivot else {
426            return Ok(RelationResult::Defer);
427        };
428        let mut refined = false;
429        for &s in slots {
430            let cur = self.types[s].resolved;
431            if std::ptr::eq(cur, pivot) {
432                continue;
433            }
434            // Allow refinement if the current bound is abstract +
435            // pivot is a subtype.
436            if cur.is_abstract() && pivot.is_subtype_of(cur) {
437                self.types[s].resolved = pivot;
438                refined = true;
439            } else {
440                return Ok(RelationResult::Failed(
441                    "SameType: incompatible concrete types",
442                ));
443            }
444        }
445        Ok(if refined {
446            RelationResult::Refined
447        } else {
448            RelationResult::Satisfied
449        })
450    }
451
452    /// `SameElementType` — every Tensor-typed slot shares an
453    /// element type. Currently treated as `SameType` (shape not yet
454    /// tracked); will tighten once explicit shape constraints land.
455    fn run_same_element_type(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
456        self.run_same_type(slots)
457    }
458
459    /// `Elementwise` - output's TypeNode equals input's. Shape
460    /// preserved (when shape tracking lands).
461    fn run_elementwise(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
462        // slots[0] = input, slots[1] = output
463        let inp = self.types[slots[0]].resolved;
464        let out = self.types[slots[1]].resolved;
465        if inp.is_concrete() && std::ptr::eq(inp, out) {
466            return Ok(RelationResult::Satisfied);
467        }
468        if inp.is_concrete() && out.is_abstract() && inp.is_subtype_of(out) {
469            self.types[slots[1]].resolved = inp;
470            return Ok(RelationResult::Refined);
471        }
472        if out.is_concrete() && inp.is_abstract() && out.is_subtype_of(inp) {
473            self.types[slots[0]].resolved = out;
474            return Ok(RelationResult::Refined);
475        }
476        if inp.is_concrete() && out.is_concrete() && !std::ptr::eq(inp, out) {
477            return Ok(RelationResult::Failed("Elementwise: input != output"));
478        }
479        Ok(RelationResult::Defer)
480    }
481
482    /// `BroadcastShape` — element types unify, output's shape is
483    /// the broadcast of the two inputs'. Currently defers to
484    /// element-type unification only (shape tracking is not yet
485    /// implemented).
486    fn run_broadcast_shape(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
487        // slots[0] = in0, slots[1] = in1, slots[2] = out
488        self.run_same_element_type(&[slots[0], slots[1], slots[2]])
489    }
490
491    /// `ReduceOver` - output's element type = input's element type.
492    fn run_reduce_over(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
493        self.run_elementwise(slots)
494    }
495}
496
497/// Inverse of [`bb_ir::types::builtins::lookup_denotation`] — map
498/// a built-in `TypeNode` back to the canonical denotation string
499/// the DSL records on `ValueInfoProto.denotation`. Returns the
500/// empty string for nodes without a known denotation (custom
501/// types extending the lattice via inventory submission can carry
502/// their own denotations; this helper covers the framework
503/// canon).
504fn type_node_to_denotation(node: &'static TypeNode) -> &'static str {
505    use bb_ir::types::builtins as B;
506    if std::ptr::eq(node, &B::TYPE_TENSOR_F32) {
507        return "ai.bytesandbrains.tensor.f32";
508    }
509    if std::ptr::eq(node, &B::TYPE_TENSOR_F64) {
510        return "ai.bytesandbrains.tensor.f64";
511    }
512    if std::ptr::eq(node, &B::TYPE_TENSOR_F16) {
513        return "ai.bytesandbrains.tensor.f16";
514    }
515    if std::ptr::eq(node, &B::TYPE_TENSOR_U8) {
516        return "ai.bytesandbrains.tensor.u8";
517    }
518    if std::ptr::eq(node, &B::TYPE_TENSOR_I32) {
519        return "ai.bytesandbrains.tensor.i32";
520    }
521    if std::ptr::eq(node, &B::TYPE_SCALAR_F32) {
522        return "bb.f32";
523    }
524    if std::ptr::eq(node, &B::TYPE_SCALAR_F64) {
525        return "bb.f64";
526    }
527    if std::ptr::eq(node, &B::TYPE_SCALAR_F16) {
528        return "bb.f16";
529    }
530    if std::ptr::eq(node, &B::TYPE_SCALAR_U8) {
531        return "bb.u8";
532    }
533    if std::ptr::eq(node, &B::TYPE_SCALAR_I32) {
534        return "bb.i32";
535    }
536    if std::ptr::eq(node, &B::TYPE_PEER_ID) {
537        return "bb.peer_id";
538    }
539    if std::ptr::eq(node, &B::TYPE_PEER_ID_VEC) {
540        return "bb.peer_id_vec";
541    }
542    if std::ptr::eq(node, &B::TYPE_TRIGGER) {
543        return "bb.trigger";
544    }
545    if std::ptr::eq(node, &B::TYPE_WIRE_REQ_ID) {
546        return "bb.wire_req_id";
547    }
548    ""
549}
550