bb_ir/types/relations.rs
1//! Type relations on ops.
2//!
//! Every op declares its type contract via composable
//! [`TypeRelation`]s — a small set of high-coverage predicates
//! (`SameType`, `SameElementType`, `Elementwise`, `BroadcastShape`,
//! `ReduceOver`) plus a `Custom` escape hatch. The compiler's
//! TypeSolver walks the graph, instantiates each relation as a
//! constraint node, and resolves every value's TypeNode with a
//! bipartite worklist over type nodes and relation nodes, modeled on
//! TVM Relay's type solver.
//!
//! Coverage strategy: a library of trait-style predicates (in the
//! manner of MLIR op traits) handles ~90% of ops. The remaining ~10%
//! (`Reshape`, `Gather`, `Concat`, and anything else with structural
//! type effects) use `Custom`. Adding a new op means declaring its
//! `type_relations` in `atomic_opset()`.
15
16use super::TypeNode;
17
/// Identifies a single slot in the surrounding op's port lists.
///
/// The `u8` payload is a position: into `AtomicOpDecl.inputs` for
/// [`PortRef::Input`], or into `AtomicOpDecl.outputs` for
/// [`PortRef::Output`].
#[derive(Hash, Eq, PartialEq, Debug, Copy, Clone)]
pub enum PortRef {
    /// The surrounding op's `inputs[index]`.
    Input(u8),
    /// The surrounding op's `outputs[index]`.
    Output(u8),
}
/// Outcome of running a relation against the solver's current type
/// nodes.
///
/// The solver's worklist treats each variant differently:
/// `Refined` requeues dependents, `Satisfied` removes the relation,
/// `Defer` parks it for later, and `Failed` aborts the solve. Dropping
/// a result therefore silently loses a scheduling decision, which is
/// why the type is `#[must_use]`.
///
/// Every payload is `Copy` + `Eq`, so results can be copied freely and
/// compared directly (e.g. with `assert_eq!` in solver tests).
#[must_use = "the solver's worklist must act on every relation outcome"]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RelationResult {
    /// Made progress narrowing one or more type variables. Requeue
    /// any relations sharing those types.
    Refined,
    /// Constraint fully satisfied. Remove from the worklist.
    Satisfied,
    /// Insufficient information today. Come back when something
    /// else refines the participating types.
    Defer,
    /// Hard contradiction. The solver propagates the diagnostic
    /// back as a build error. The payload is a static diagnostic
    /// message.
    Failed(&'static str),
}
45}
46
47/// One type relation declared on an op. The TypeSolver instantiates
48/// each as a constraint node linked to its participating type
49/// variables via back-edges.
50#[derive(Debug)]
51pub enum TypeRelation {
52 /// All listed ports share the SAME concrete TypeNode. Implements
53 /// Julia's "diagonal variable" rule - a port declared `Tensor`
54 /// that participates in `SameType([in0, in1, out0])` collapses
55 /// to ONE element type across all three positions, regardless
56 /// of the bound's permissiveness.
57 SameType(&'static [PortRef]),
58
59 /// All listed Tensor-typed ports share the same ELEMENT type.
60 /// Shapes may differ (broadcasting is a separate concern).
61 /// `Add(x: Tensor, y: Tensor) -> Tensor` uses this.
62 SameElementType(&'static [PortRef]),
63
64 /// The output is the broadcast of two tensor inputs. Composes
65 /// with `SameElementType` to express `Add` / `Mul` / `Sub` /
66 /// `Div` fully.
67 BroadcastShape {
68 /// First broadcast operand.
69 in0: PortRef,
70 /// Second broadcast operand.
71 in1: PortRef,
72 /// Output (shape = broadcast(in0.shape, in1.shape)).
73 out: PortRef,
74 },
75
76 /// Output preserves the input's TypeNode entirely. Used by
77 /// element-wise unary ops (`Sqrt`, `Neg`, `Abs`, `Relu`, etc.):
78 /// shape preserved, element type preserved.
79 Elementwise {
80 /// Input.
81 input: PortRef,
82 /// Output.
83 output: PortRef,
84 },
85
86 /// Output is a reduction over the input: same element type,
87 /// reduced shape (driven by op attributes like `axes`).
88 /// `ReduceSum` / `ReduceMean` / `ReduceMax` use this.
89 ReduceOver {
90 /// Input tensor being reduced.
91 input: PortRef,
92 /// Output tensor (lower rank or same rank with size-1 axes).
93 output: PortRef,
94 },
95
96 /// Escape hatch for ops that don't fit a predicate. The custom
97 /// function receives the current TypeNodes for participating
98 /// ports and returns a [`RelationResult`].
99 ///
100 /// Use sparingly - `Reshape`, `Gather`, `Concat`, `Cast`, and
101 /// any op with attribute-driven type changes need this.
102 Custom {
103 /// Stable identifier for diagnostics.
104 name: &'static str,
105 /// Solver entry point. Receives the participating ports'
106 /// current resolutions (`Option<&TypeNode>`); narrows them
107 /// or returns `Failed`.
108 run: fn(&CustomRelationCtx<'_>) -> RelationResult,
109 },
110}
111
112/// Context passed to a `Custom` relation's `run` function. Borrows
113/// from the solver; exposes a read-only view of each participating
114/// port's current type resolution. Concrete shape lands when the
115/// TypeSolver (T4) materializes.
116#[derive(Debug)]
117pub struct CustomRelationCtx<'a> {
118 /// Solver-allocated handles for the ports this relation touches,
119 /// paired with their current best-known TypeNode (None = still
120 /// unresolved).
121 pub ports: &'a [(PortRef, Option<&'static TypeNode>)],
122}