Skip to main content

bb_ir/types/
relations.rs

//! Type relations on ops.
//!
//! Every op declares its type contract via composable
//! [`TypeRelation`]s — a small set of high-coverage predicates
//! (`SameType`, `SameElementType`, `Elementwise`, `BroadcastShape`,
//! `ReduceOver`) plus a `Custom` escape hatch. The compiler's
//! TypeSolver walks the graph, instantiates each relation as a
//! constraint node, and resolves every value's TypeNode via a
//! bipartite worklist (TVM Relay shape).
//!
//! Coverage strategy: a library of trait predicates (MLIR pattern)
//! handles ~90% of ops. The remaining ~10% (Reshape, Gather, Concat,
//! anything with structural type effects) use `Custom`. Adding a
//! new op = declaring its `type_relations` in `atomic_opset()`.
15
16use super::TypeNode;
17
/// Reference to a port position on an op's input/output list.
///
/// The wrapped `u8` is an index into the surrounding op's
/// `AtomicOpDecl.inputs` or `AtomicOpDecl.outputs`, depending on the
/// variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum PortRef {
    /// `inputs[index]` on the surrounding op.
    Input(u8),
    /// `outputs[index]` on the surrounding op.
    Output(u8),
}
27
/// Outcome of running a relation against the solver's current type
/// nodes.
///
/// The solver's worklist treats each variant differently:
/// `Refined` requeues dependents, `Satisfied` removes the relation,
/// `Defer` parks it for later, `Failed` aborts the solve.
///
/// The type is a plain value (`Copy` + `Eq`) so outcomes can be
/// compared directly, e.g. in solver tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RelationResult {
    /// Made progress narrowing one or more type variables. Requeue
    /// any relations sharing those types.
    Refined,
    /// Constraint fully satisfied. Remove from the worklist.
    Satisfied,
    /// Insufficient information today. Come back when something
    /// else refines the participating types.
    Defer,
    /// Hard contradiction. The solver propagates the diagnostic
    /// back as a build error.
    Failed(&'static str),
}
46
47/// One type relation declared on an op. The TypeSolver instantiates
48/// each as a constraint node linked to its participating type
49/// variables via back-edges.
50#[derive(Debug)]
51pub enum TypeRelation {
52    /// All listed ports share the SAME concrete TypeNode. Implements
53    /// Julia's "diagonal variable" rule - a port declared `Tensor`
54    /// that participates in `SameType([in0, in1, out0])` collapses
55    /// to ONE element type across all three positions, regardless
56    /// of the bound's permissiveness.
57    SameType(&'static [PortRef]),
58
59    /// All listed Tensor-typed ports share the same ELEMENT type.
60    /// Shapes may differ (broadcasting is a separate concern).
61    /// `Add(x: Tensor, y: Tensor) -> Tensor` uses this.
62    SameElementType(&'static [PortRef]),
63
64    /// The output is the broadcast of two tensor inputs. Composes
65    /// with `SameElementType` to express `Add` / `Mul` / `Sub` /
66    /// `Div` fully.
67    BroadcastShape {
68        /// First broadcast operand.
69        in0: PortRef,
70        /// Second broadcast operand.
71        in1: PortRef,
72        /// Output (shape = broadcast(in0.shape, in1.shape)).
73        out: PortRef,
74    },
75
76    /// Output preserves the input's TypeNode entirely. Used by
77    /// element-wise unary ops (`Sqrt`, `Neg`, `Abs`, `Relu`, etc.):
78    /// shape preserved, element type preserved.
79    Elementwise {
80        /// Input.
81        input: PortRef,
82        /// Output.
83        output: PortRef,
84    },
85
86    /// Output is a reduction over the input: same element type,
87    /// reduced shape (driven by op attributes like `axes`).
88    /// `ReduceSum` / `ReduceMean` / `ReduceMax` use this.
89    ReduceOver {
90        /// Input tensor being reduced.
91        input: PortRef,
92        /// Output tensor (lower rank or same rank with size-1 axes).
93        output: PortRef,
94    },
95
96    /// Escape hatch for ops that don't fit a predicate. The custom
97    /// function receives the current TypeNodes for participating
98    /// ports and returns a [`RelationResult`].
99    ///
100    /// Use sparingly - `Reshape`, `Gather`, `Concat`, `Cast`, and
101    /// any op with attribute-driven type changes need this.
102    Custom {
103        /// Stable identifier for diagnostics.
104        name: &'static str,
105        /// Solver entry point. Receives the participating ports'
106        /// current resolutions (`Option<&TypeNode>`); narrows them
107        /// or returns `Failed`.
108        run: fn(&CustomRelationCtx<'_>) -> RelationResult,
109    },
110}
111
112/// Context passed to a `Custom` relation's `run` function. Borrows
113/// from the solver; exposes a read-only view of each participating
114/// port's current type resolution. Concrete shape lands when the
115/// TypeSolver (T4) materializes.
116#[derive(Debug)]
117pub struct CustomRelationCtx<'a> {
118    /// Solver-allocated handles for the ports this relation touches,
119    /// paired with their current best-known TypeNode (None = still
120    /// unresolved).
121    pub ports: &'a [(PortRef, Option<&'static TypeNode>)],
122}