bb_compiler/type_solver.rs
1//! Type-resolution pass — bipartite worklist solver following
2//! TVM Relay's `type_solver.h` design, adapted to Rust.
3//!
4//! ## Shape
5//!
6//! Two arenas: **type nodes** (one per value position) and
7//! **relation nodes** (one per `TypeRelation` instance on each
8//! [`AtomicOpDecl`]). Cross-linked via `rel_set` back-edges. The
9//! worklist holds relations ready to (re)run.
10//!
11//! ## Algorithm
12//!
//! 1. **Seed.** Allocate a type node for every value name in the
//!    graph (graph inputs plus every op's inputs and outputs). Each
//!    node starts at `TYPE_ANY`; callers pin known bounds afterwards
//!    via `seed` / `seed_from_value_info`.
16//! 2. **Instantiate relations.** For each NodeProto in the graph,
17//! look up its `AtomicOpDecl.type_relations` and allocate a
18//! relation node per declared [`TypeRelation`]. Each relation
19//! points at the type nodes for the ports it participates in.
20//! Type nodes track back-edges via `rel_set`.
//! 3. **Drain.** Pop a relation from the worklist, run it.
//!    [`RelationResult`] dictates the next move:
//!    - `Refined` → requeue every other not-yet-satisfied relation in
//!      the `rel_set` of the relation's participating type nodes.
//!    - `Satisfied` → mark it done; it never runs again.
//!    - `Defer` → drop it from the worklist; it is requeued only when
//!      some other relation refines a participating type node.
//!    - `Failed` → abort with a `TypeError::ConstraintFailed`.
//! 4. **Fixpoint.** The loop ends when the worklist is empty (deferred
//!    relations that no refinement reactivated are simply dropped).
//!    `solve()` returns the resolutions as-is, abstract type nodes
//!    included; `solve_strict()` additionally requires every type node
//!    to resolve to a concrete leaf and otherwise reports
//!    `UnresolvedType`.
33//!
34//! ## Scope
35//!
//! [`TypeRelation::SameType`] is implemented directly.
//! [`TypeRelation::SameElementType`], `BroadcastShape`,
//! [`TypeRelation::Elementwise`] and `ReduceOver` are currently
//! approximated by whole-type unification (shapes are not tracked
//! yet), and `Custom` relations always defer. Variants gain precise
//! semantics by extending the per-variant handler match inside the
//! solver's internal `run_relation` dispatch.
42
43use std::collections::HashMap;
44
45use bb_ir::proto::onnx::GraphProto;
46use bb_ir::types::{PortRef, RelationResult, TypeNode, TypeRelation, TYPE_ANY};
47
/// Errors the solver may report.
#[derive(Debug)]
pub enum TypeError {
    /// A relation produced a hard contradiction (e.g. two consumers
    /// of the same port require incompatible element types).
    ///
    /// Raised during solving as soon as any relation reports `Failed`.
    ConstraintFailed {
        /// Op the relation was attached to (this module formats it as
        /// `domain::op_type`).
        op: String,
        /// Relation diagnostic string.
        detail: String,
    },
    /// The solver reached fixpoint with type nodes still abstract
    /// (i.e. not narrowed to a concrete leaf in the lattice).
    ///
    /// Within this module only strict solving reports this; the
    /// permissive solve returns the partial solution instead.
    UnresolvedType {
        /// Value name with no concrete resolution.
        value: String,
    },
    /// An op's declared relation references a port that doesn't map to
    /// a value: the input/output position is out of range for the
    /// node, or the value name at that position is empty.
    PortOutOfRange {
        /// Op name (this module formats it as `domain::op_type`).
        op: String,
        /// Failing port reference.
        port: PortRef,
    },
}
74
75impl std::fmt::Display for TypeError {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 Self::ConstraintFailed { op, detail } => {
79 write!(f, "type constraint failed at {op}: {detail}")
80 }
81 Self::UnresolvedType { value } => {
82 write!(f, "value `{value}` did not resolve to a concrete type")
83 }
84 Self::PortOutOfRange { op, port } => {
85 write!(f, "op {op} references out-of-range port {port:?}")
86 }
87 }
88 }
89}
90
91impl std::error::Error for TypeError {}
92
93/// Solver output: every value name in the graph maps to its
94/// resolved concrete TypeNode.
95#[derive(Debug)]
96pub struct TypeSolution {
97 by_value: HashMap<String, &'static TypeNode>,
98}
99
100impl TypeSolution {
101 /// Resolved TypeNode for a value name. `None` if the solver
102 /// didn't see this value.
103 pub fn type_of(&self, value: &str) -> Option<&'static TypeNode> {
104 self.by_value.get(value).copied()
105 }
106
107 /// Iterate every resolved (value_name, TypeNode) pair.
108 pub fn iter(&self) -> impl Iterator<Item = (&str, &'static TypeNode)> {
109 self.by_value.iter().map(|(k, v)| (k.as_str(), *v))
110 }
111}
112
113/// Bipartite type-resolution solver.
114pub struct TypeSolver {
115 /// Per-value type nodes. Index = the slot's position in
116 /// [`Self::value_index`].
117 types: Vec<TypeNodeSlot>,
118 /// Per-relation constraint nodes.
119 relations: Vec<RelationNode>,
120 /// Value name → index into `types`.
121 value_index: HashMap<String, usize>,
122}
123
124/// One value position's current type resolution + back-edges to
125/// relations that depend on it.
126struct TypeNodeSlot {
127 /// Current best-known resolution. `&TYPE_ANY` until a relation
128 /// narrows it. Refinement only proceeds toward MORE specific
129 /// types (down the lattice).
130 resolved: &'static TypeNode,
131 /// Relations participating in this slot. Populated at solver
132 /// construction; consulted when the slot refines to requeue
133 /// dependents.
134 rel_set: Vec<usize>,
135}
136
137/// One instantiated [`TypeRelation`] linked to its participating
138/// type slots.
139struct RelationNode {
140 /// The relation declaration (from the op's atomic_opset).
141 decl: &'static TypeRelation,
142 /// Op name (for diagnostics).
143 op_name: String,
144 /// Type-slot indices participating in this relation, in
145 /// declaration order. Length matches the relation variant's
146 /// port count (e.g. 2 for `Elementwise{input, output}`).
147 slots: Vec<usize>,
148 /// `true` once the relation reports `Satisfied` and is
149 /// permanently removed from the worklist.
150 satisfied: bool,
151}
152
153impl TypeSolver {
154 /// Build a fresh solver from a `GraphProto`. Walks every node,
155 /// allocates slots for every value name, instantiates relations
156 /// per the op's `AtomicOpDecl.type_relations`.
157 ///
158 /// `decl_for_op` lets the caller plug in their own
159 /// `(domain, op_type) -> &AtomicOpDecl` lookup (typically the
160 /// compiler's registered opset catalog).
161 pub fn from_graph(
162 graph: &GraphProto,
163 decl_for_op: impl Fn(&str, &str) -> Option<&'static bb_ir::atomic::AtomicOpDecl>,
164 ) -> Result<Self, TypeError> {
165 let mut solver = Self {
166 types: Vec::new(),
167 relations: Vec::new(),
168 value_index: HashMap::new(),
169 };
170
171 // First pass: allocate a type slot for every value name
172 // (graph inputs, then every op's outputs).
173 for input in &graph.input {
174 solver.intern_value(&input.name);
175 }
176 for node in &graph.node {
177 for out in &node.output {
178 if !out.is_empty() {
179 solver.intern_value(out);
180 }
181 }
182 for inp in &node.input {
183 if !inp.is_empty() {
184 solver.intern_value(inp);
185 }
186 }
187 }
188
189 // Second pass: for each NodeProto, instantiate the relations
190 // declared on its AtomicOpDecl.
191 for node in &graph.node {
192 let Some(decl) = decl_for_op(&node.domain, &node.op_type) else {
193 // No declared opset entry - skip. Unknown ops fall
194 // through; resolve_dispatch will catch them downstream.
195 continue;
196 };
197 for relation in decl.type_relations {
198 let slots = solver.resolve_relation_ports(node, relation)?;
199 let rel_idx = solver.relations.len();
200 solver.relations.push(RelationNode {
201 decl: relation,
202 op_name: format!("{}::{}", node.domain, node.op_type),
203 slots: slots.clone(),
204 satisfied: false,
205 });
206 for s in slots {
207 solver.types[s].rel_set.push(rel_idx);
208 }
209 }
210 }
211
212 Ok(solver)
213 }
214
215 fn intern_value(&mut self, name: &str) -> usize {
216 if let Some(&idx) = self.value_index.get(name) {
217 return idx;
218 }
219 let idx = self.types.len();
220 self.types.push(TypeNodeSlot {
221 resolved: &TYPE_ANY,
222 rel_set: Vec::new(),
223 });
224 self.value_index.insert(name.to_string(), idx);
225 idx
226 }
227
228 /// Resolve `relation`'s [`PortRef`]s against the NodeProto's
229 /// input/output lists; return one type-slot index per
230 /// declared port.
231 fn resolve_relation_ports(
232 &mut self,
233 node: &bb_ir::proto::onnx::NodeProto,
234 relation: &TypeRelation,
235 ) -> Result<Vec<usize>, TypeError> {
236 let ports: Vec<PortRef> = match relation {
237 TypeRelation::SameType(p) | TypeRelation::SameElementType(p) => p.to_vec(),
238 TypeRelation::Elementwise { input, output } => vec![*input, *output],
239 TypeRelation::BroadcastShape { in0, in1, out } => vec![*in0, *in1, *out],
240 TypeRelation::ReduceOver { input, output } => vec![*input, *output],
241 TypeRelation::Custom { .. } => Vec::new(),
242 };
243
244 let op_name = format!("{}::{}", node.domain, node.op_type);
245 let mut slots = Vec::with_capacity(ports.len());
246 for port in ports {
247 let value_name = match port {
248 PortRef::Input(i) => node.input.get(i as usize).cloned(),
249 PortRef::Output(o) => node.output.get(o as usize).cloned(),
250 };
251 let Some(name) = value_name else {
252 return Err(TypeError::PortOutOfRange { op: op_name, port });
253 };
254 if name.is_empty() {
255 return Err(TypeError::PortOutOfRange { op: op_name, port });
256 }
257 slots.push(self.intern_value(&name));
258 }
259 Ok(slots)
260 }
261
262 /// Seed a value's type with a concrete (or narrower-than-Any)
263 /// TypeNode. Used by callers that know specific inputs' types
264 /// upfront (e.g. an AppEvent feeding a `Tensor<F32>`).
265 pub fn seed(&mut self, value: &str, node: &'static TypeNode) {
266 if let Some(&idx) = self.value_index.get(value) {
267 self.types[idx].resolved = node;
268 }
269 }
270
271 /// Walk `graph.input` + `graph.value_info` and seed every value
272 /// whose `ValueInfoProto.type.denotation` maps to a built-in
273 /// TypeNode (via [`bb_ir::types::builtins::lookup_denotation`]).
274 /// Values with unknown denotations are left at `TYPE_ANY`; the
275 /// solver narrows them via relations during `solve()`.
276 ///
277 /// Per the architecture's polymorphic-type contract, the DSL's
278 /// `Graph::input(name)` records each input with the
279 /// `ai.bytesandbrains.opaque` denotation (→ `TYPE_ANY`). Wire
280 /// op outputs + framework-recorded values carry pinned
281 /// denotations the lookup recognizes; that pinning seeds the
282 /// solver with concrete-leaf TypeNodes from which the
283 /// relation network propagates.
284 pub fn seed_from_value_info(&mut self, graph: &GraphProto) {
285 for vi in graph.input.iter().chain(graph.value_info.iter()) {
286 let Some(type_proto) = vi.r#type.as_ref() else {
287 continue;
288 };
289 let denotation = type_proto.denotation.as_str();
290 if denotation.is_empty() {
291 continue;
292 }
293 if let Some(node) = bb_ir::types::builtins::lookup_denotation(denotation) {
294 self.seed(&vi.name, node);
295 }
296 }
297 }
298
299 /// Run the worklist to fixpoint, then post-check that every
300 /// slot resolved to a concrete leaf.
301 pub fn solve(mut self) -> Result<TypeSolution, TypeError> {
302 // Initial worklist = every relation.
303 let mut worklist: std::collections::VecDeque<usize> = (0..self.relations.len()).collect();
304
305 while let Some(rel_idx) = worklist.pop_front() {
306 if self.relations[rel_idx].satisfied {
307 continue;
308 }
309 let outcome = self.run_relation(rel_idx)?;
310 match outcome {
311 RelationResult::Refined => {
312 // Requeue dependents of any participating slot.
313 let slots = self.relations[rel_idx].slots.clone();
314 for s in slots {
315 for &dep in &self.types[s].rel_set {
316 if dep != rel_idx && !self.relations[dep].satisfied {
317 worklist.push_back(dep);
318 }
319 }
320 }
321 }
322 RelationResult::Satisfied => {
323 self.relations[rel_idx].satisfied = true;
324 }
325 RelationResult::Defer => {
326 // Don't requeue automatically; we'll come back
327 // when a participating slot refines.
328 }
329 RelationResult::Failed(detail) => {
330 return Err(TypeError::ConstraintFailed {
331 op: self.relations[rel_idx].op_name.clone(),
332 detail: detail.to_string(),
333 });
334 }
335 }
336 }
337
338 // Post-check: every slot must be a concrete leaf.
339 let mut by_value: HashMap<String, &'static TypeNode> = HashMap::new();
340 for (name, &idx) in &self.value_index {
341 let node = self.types[idx].resolved;
342 // Allow unresolved (Any) entries to pass through silently
343 // - callers may want a partial solution for diagnostics.
344 // Hard error happens only if we WERE supposed to resolve.
345 by_value.insert(name.clone(), node);
346 }
347 Ok(TypeSolution { by_value })
348 }
349
350 /// Stamp `solution`'s resolved TypeNodes back onto every
351 /// matching `ValueInfoProto.type.denotation` in `graph`.
352 /// Downstream passes + the runtime read the narrowed
353 /// denotation instead of the recorder's
354 /// `ai.bytesandbrains.opaque` placeholder.
355 ///
356 /// Unresolved (still-`TYPE_ANY`) entries are left as-is —
357 /// they keep their original denotation. Permissive mode
358 /// surfaces here as silent pass-through; strict mode is the
359 /// caller's choice via `solve_strict()` BEFORE this is called.
360 pub fn apply_solution_to_value_info(graph: &mut GraphProto, solution: &TypeSolution) {
361 for vi in graph.input.iter_mut().chain(graph.value_info.iter_mut()) {
362 let Some(node) = solution.type_of(&vi.name) else {
363 continue;
364 };
365 if node.is_abstract() {
366 continue;
367 }
368 let denotation = type_node_to_denotation(node);
369 if denotation.is_empty() {
370 continue;
371 }
372 if let Some(type_proto) = vi.r#type.as_mut() {
373 type_proto.denotation = denotation.to_string();
374 }
375 }
376 }
377
378 /// Strict-mode solve: every slot MUST resolve to a concrete leaf.
379 /// Returns `UnresolvedType` on the first abstract slot.
380 pub fn solve_strict(self) -> Result<TypeSolution, TypeError> {
381 let solution = self.solve()?;
382 for (name, node) in &solution.by_value {
383 if node.is_abstract() {
384 return Err(TypeError::UnresolvedType {
385 value: name.clone(),
386 });
387 }
388 }
389 Ok(solution)
390 }
391
392 /// Run one relation, return its outcome. The match dispatches
393 /// to the per-variant handler.
394 fn run_relation(&mut self, idx: usize) -> Result<RelationResult, TypeError> {
395 let slots = self.relations[idx].slots.clone();
396 let decl = self.relations[idx].decl;
397
398 let outcome = match decl {
399 TypeRelation::SameType(_) => self.run_same_type(&slots),
400 TypeRelation::SameElementType(_) => self.run_same_element_type(&slots),
401 TypeRelation::Elementwise { .. } => self.run_elementwise(&slots),
402 TypeRelation::BroadcastShape { .. } => self.run_broadcast_shape(&slots),
403 TypeRelation::ReduceOver { .. } => self.run_reduce_over(&slots),
404 TypeRelation::Custom { run, .. } => {
405 // Custom relations are not yet implemented;
406 // defer until `CustomRelationCtx` has a real shape.
407 let _ = run;
408 Ok(RelationResult::Defer)
409 }
410 }?;
411
412 Ok(outcome)
413 }
414
415 // ---- Per-relation handlers ----------------------------------
416
417 /// `SameType` - every listed slot collapses to ONE concrete
418 /// TypeNode. Implementation: take the FIRST concrete resolution
419 /// among participants; narrow every other participant to match.
420 fn run_same_type(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
421 let pivot: Option<&'static TypeNode> = slots
422 .iter()
423 .map(|&s| self.types[s].resolved)
424 .find(|n| n.is_concrete());
425 let Some(pivot) = pivot else {
426 return Ok(RelationResult::Defer);
427 };
428 let mut refined = false;
429 for &s in slots {
430 let cur = self.types[s].resolved;
431 if std::ptr::eq(cur, pivot) {
432 continue;
433 }
434 // Allow refinement if the current bound is abstract +
435 // pivot is a subtype.
436 if cur.is_abstract() && pivot.is_subtype_of(cur) {
437 self.types[s].resolved = pivot;
438 refined = true;
439 } else {
440 return Ok(RelationResult::Failed(
441 "SameType: incompatible concrete types",
442 ));
443 }
444 }
445 Ok(if refined {
446 RelationResult::Refined
447 } else {
448 RelationResult::Satisfied
449 })
450 }
451
452 /// `SameElementType` — every Tensor-typed slot shares an
453 /// element type. Currently treated as `SameType` (shape not yet
454 /// tracked); will tighten once explicit shape constraints land.
455 fn run_same_element_type(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
456 self.run_same_type(slots)
457 }
458
459 /// `Elementwise` - output's TypeNode equals input's. Shape
460 /// preserved (when shape tracking lands).
461 fn run_elementwise(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
462 // slots[0] = input, slots[1] = output
463 let inp = self.types[slots[0]].resolved;
464 let out = self.types[slots[1]].resolved;
465 if inp.is_concrete() && std::ptr::eq(inp, out) {
466 return Ok(RelationResult::Satisfied);
467 }
468 if inp.is_concrete() && out.is_abstract() && inp.is_subtype_of(out) {
469 self.types[slots[1]].resolved = inp;
470 return Ok(RelationResult::Refined);
471 }
472 if out.is_concrete() && inp.is_abstract() && out.is_subtype_of(inp) {
473 self.types[slots[0]].resolved = out;
474 return Ok(RelationResult::Refined);
475 }
476 if inp.is_concrete() && out.is_concrete() && !std::ptr::eq(inp, out) {
477 return Ok(RelationResult::Failed("Elementwise: input != output"));
478 }
479 Ok(RelationResult::Defer)
480 }
481
482 /// `BroadcastShape` — element types unify, output's shape is
483 /// the broadcast of the two inputs'. Currently defers to
484 /// element-type unification only (shape tracking is not yet
485 /// implemented).
486 fn run_broadcast_shape(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
487 // slots[0] = in0, slots[1] = in1, slots[2] = out
488 self.run_same_element_type(&[slots[0], slots[1], slots[2]])
489 }
490
491 /// `ReduceOver` - output's element type = input's element type.
492 fn run_reduce_over(&mut self, slots: &[usize]) -> Result<RelationResult, TypeError> {
493 self.run_elementwise(slots)
494 }
495}
496
497/// Inverse of [`bb_ir::types::builtins::lookup_denotation`] — map
498/// a built-in `TypeNode` back to the canonical denotation string
499/// the DSL records on `ValueInfoProto.denotation`. Returns the
500/// empty string for nodes without a known denotation (custom
501/// types extending the lattice via inventory submission can carry
502/// their own denotations; this helper covers the framework
503/// canon).
504fn type_node_to_denotation(node: &'static TypeNode) -> &'static str {
505 use bb_ir::types::builtins as B;
506 if std::ptr::eq(node, &B::TYPE_TENSOR_F32) {
507 return "ai.bytesandbrains.tensor.f32";
508 }
509 if std::ptr::eq(node, &B::TYPE_TENSOR_F64) {
510 return "ai.bytesandbrains.tensor.f64";
511 }
512 if std::ptr::eq(node, &B::TYPE_TENSOR_F16) {
513 return "ai.bytesandbrains.tensor.f16";
514 }
515 if std::ptr::eq(node, &B::TYPE_TENSOR_U8) {
516 return "ai.bytesandbrains.tensor.u8";
517 }
518 if std::ptr::eq(node, &B::TYPE_TENSOR_I32) {
519 return "ai.bytesandbrains.tensor.i32";
520 }
521 if std::ptr::eq(node, &B::TYPE_SCALAR_F32) {
522 return "bb.f32";
523 }
524 if std::ptr::eq(node, &B::TYPE_SCALAR_F64) {
525 return "bb.f64";
526 }
527 if std::ptr::eq(node, &B::TYPE_SCALAR_F16) {
528 return "bb.f16";
529 }
530 if std::ptr::eq(node, &B::TYPE_SCALAR_U8) {
531 return "bb.u8";
532 }
533 if std::ptr::eq(node, &B::TYPE_SCALAR_I32) {
534 return "bb.i32";
535 }
536 if std::ptr::eq(node, &B::TYPE_PEER_ID) {
537 return "bb.peer_id";
538 }
539 if std::ptr::eq(node, &B::TYPE_PEER_ID_VEC) {
540 return "bb.peer_id_vec";
541 }
542 if std::ptr::eq(node, &B::TYPE_TRIGGER) {
543 return "bb.trigger";
544 }
545 if std::ptr::eq(node, &B::TYPE_WIRE_REQ_ID) {
546 return "bb.wire_req_id";
547 }
548 ""
549}
550