Skip to main content

rlx_ir/
graph.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! The computation graph — a DAG of typed tensor operations.
//!
//! Graphs are append-only during construction (like SSA). Nodes reference
//! inputs by [`NodeId`], forming a directed acyclic graph. The graph
//! owns all nodes and provides traversal, printing, and validation.
21
22use crate::{Op, Shape};
23
24use crate::provenance::NodeOrigin;
25
/// Handle that names one [`Node`] inside a [`Graph`].
///
/// The wrapped `u32` is the node's position in the graph's node list.
/// IDs are stable: an index is never reused for a different node.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(pub u32);
30
31impl std::fmt::Display for NodeId {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "%{}", self.0)
34    }
35}
36
37/// A single node in the computation graph.
38#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
39#[derive(Debug, Clone)]
40pub struct Node {
41    pub id: NodeId,
42    /// The operation this node performs.
43    pub op: Op,
44    /// Input node IDs (operands). Order matches `Op::num_inputs()`.
45    pub inputs: Vec<NodeId>,
46    /// Output tensor shape (computed at construction time).
47    pub shape: Shape,
48    /// Human-readable name for debugging.
49    pub name: Option<String>,
50    /// Cross-stage provenance (HIR block, fusion pass, …).
51    pub origin: Option<NodeOrigin>,
52}
53
54/// A computation graph — the core IR data structure.
55///
56/// # Example
57/// ```
58/// use rlx_ir::*;
59///
60/// let mut g = Graph::new("bert_layer");
61///
62/// // Inputs
63/// let x = g.input("hidden", Shape::new(&[4, 15, 384], DType::F32));
64/// let w = g.param("qkv_weight", Shape::new(&[384, 1152], DType::F32));
65/// let b = g.param("qkv_bias", Shape::new(&[1152], DType::F32));
66///
67/// // QKV projection: matmul + bias
68/// let mm = g.matmul(x, w, Shape::new(&[4, 15, 1152], DType::F32));
69/// let qkv = g.binary(op::BinaryOp::Add, mm, b, Shape::new(&[4, 15, 1152], DType::F32));
70///
71/// assert_eq!(g.len(), 5);
72/// println!("{g}");
73/// ```
74#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
75#[derive(Clone, Debug)]
76pub struct Graph {
77    pub name: String,
78    nodes: Vec<Node>,
79    /// Output node IDs (the graph's results).
80    pub outputs: Vec<NodeId>,
81}
82
83// Subgraph equality is structural: same name, same node count, same outputs.
84// Full deep equality would require comparing every node and is rarely useful;
85// this gives Op derives `PartialEq` cheap structural comparison.
86impl PartialEq for Graph {
87    fn eq(&self, other: &Self) -> bool {
88        self.name == other.name
89            && self.nodes.len() == other.nodes.len()
90            && self.outputs == other.outputs
91    }
92}
93
94impl Graph {
95    pub fn new(name: impl Into<String>) -> Self {
96        Self {
97            name: name.into(),
98            nodes: Vec::new(),
99            outputs: Vec::new(),
100        }
101    }
102
103    /// Number of nodes in the graph.
104    pub fn len(&self) -> usize {
105        self.nodes.len()
106    }
107    pub fn is_empty(&self) -> bool {
108        self.nodes.is_empty()
109    }
110
111    /// Get a node by ID.
112    pub fn node(&self, id: NodeId) -> &Node {
113        &self.nodes[id.0 as usize]
114    }
115
116    /// Iterate all nodes in topological order (insertion order = topo order).
117    pub fn nodes(&self) -> &[Node] {
118        &self.nodes
119    }
120
121    /// Get the shape of a node's output.
122    pub fn shape(&self, id: NodeId) -> &Shape {
123        &self.nodes[id.0 as usize].shape
124    }
125
126    /// Set the graph outputs.
127    pub fn set_outputs(&mut self, outputs: Vec<NodeId>) {
128        self.outputs = outputs;
129    }
130
131    /// Replace the input list of a node in place. Used by post-
132    /// construction passes (`quant_propagate`, `dce`, etc.) that
133    /// rewire consumers without inserting new nodes.
134    /// Caller is responsible for shape consistency — this does no
135    /// re-inference.
136    pub fn set_inputs(&mut self, id: NodeId, inputs: Vec<NodeId>) {
137        self.nodes[id.0 as usize].inputs = inputs;
138    }
139
140    pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
141        &mut self.nodes[id.0 as usize]
142    }
143
144    pub fn nodes_mut(&mut self) -> &mut [Node] {
145        &mut self.nodes
146    }
147
148    // ── Node constructors ───────────────────────────────────────
149
150    /// Append a node to the graph. `pub(crate)` so per-op builder
151    /// files in `rlx_ir::ops::*` can call it (plan #53).
152    /// Append a node for backend graph slicing (e.g. TPU HLO segments).
153    pub fn append_node(
154        &mut self,
155        op: Op,
156        inputs: Vec<NodeId>,
157        shape: Shape,
158        name: Option<String>,
159    ) -> NodeId {
160        self.push(op, inputs, shape, name)
161    }
162
163    pub(crate) fn push(
164        &mut self,
165        op: Op,
166        inputs: Vec<NodeId>,
167        shape: Shape,
168        name: Option<String>,
169    ) -> NodeId {
170        self.push_ext(op, inputs, shape, name, None)
171    }
172
173    pub(crate) fn push_ext(
174        &mut self,
175        op: Op,
176        inputs: Vec<NodeId>,
177        shape: Shape,
178        name: Option<String>,
179        origin: Option<NodeOrigin>,
180    ) -> NodeId {
181        let id = NodeId(self.nodes.len() as u32);
182        self.nodes.push(Node {
183            id,
184            op,
185            inputs,
186            shape,
187            name,
188            origin,
189        });
190        id
191    }
192
193    // Per-op builders moved to `crate::ops::*` (plan #53).
194    // Adding new op families = drop a new file in `ops/`, no edits here.
195
196    // ── Analysis helpers ────────────────────────────────────────
197
198    /// Find all nodes that use a given node's output.
199    pub fn users(&self, id: NodeId) -> Vec<NodeId> {
200        self.nodes
201            .iter()
202            .filter(|n| n.inputs.contains(&id))
203            .map(|n| n.id)
204            .collect()
205    }
206
207    /// Count how many nodes use a given node's output.
208    pub fn use_count(&self, id: NodeId) -> usize {
209        self.nodes.iter().filter(|n| n.inputs.contains(&id)).count()
210    }
211
212    /// Topological order (already guaranteed by construction — just node indices).
213    pub fn topo_order(&self) -> impl Iterator<Item = NodeId> + '_ {
214        (0..self.nodes.len()).map(|i| NodeId(i as u32))
215    }
216
217    /// Reverse topological order (outputs first).
218    pub fn reverse_topo(&self) -> impl Iterator<Item = NodeId> + '_ {
219        (0..self.nodes.len()).rev().map(|i| NodeId(i as u32))
220    }
221
222    // ── HIR / MIR / LIR pipeline (higher-order DX) ─────────────────
223
224    /// Fusion-first model definition at HIR level.
225    ///
226    /// Returns a [`GraphModule`] at HIR stage; call [`GraphModule::lower`]
227    /// or pass to [`rlx_opt::CompilePipeline::compile_module`].
228    pub fn define(
229        name: impl Into<String>,
230        build: impl FnOnce(&mut crate::hir::HirModule) -> crate::hir::HirNodeId,
231    ) -> crate::GraphModule {
232        crate::GraphModule::define(name, build)
233    }
234
235    /// Start an empty HIR-stage [`GraphModule`].
236    pub fn hir(name: impl Into<String>) -> crate::GraphModule {
237        crate::GraphModule::hir(name)
238    }
239
240    /// Wrap this MIR graph in a [`GraphModule`] for pipeline operations.
241    pub fn module(self) -> crate::GraphModule {
242        crate::GraphModule::from_graph(self)
243    }
244
245    /// Lower a HIR module to a MIR graph.
246    pub fn from_hir(hir: crate::hir::HirModule) -> Result<Self, crate::hir::LowerError> {
247        hir.lower_to_mir().map(|m| m.into_graph())
248    }
249
250    /// View as [`MirModule`].
251    pub fn to_mir(self) -> crate::MirModule {
252        crate::MirModule::from_graph(self)
253    }
254
255    /// Extract the MIR graph from optimized LIR.
256    pub fn from_lir(lir: crate::LirModule) -> Self {
257        lir.into_graph()
258    }
259
260    /// Annotated text dump ([`inspect_graph`]).
261    pub fn inspect(&self) -> String {
262        crate::inspect_graph(self)
263    }
264
265    /// True if any node shape uses a [`Dim::Dynamic`] symbol.
266    pub fn has_dynamic_dims(&self) -> bool {
267        crate::dynamic::has_dynamic_dims(self)
268    }
269
270    /// All dynamic symbols referenced in this graph.
271    pub fn dynamic_symbols(&self) -> Vec<u32> {
272        crate::dynamic::collect_dynamic_symbols(self)
273    }
274
275    /// Specialize symbolic dims to concrete sizes.
276    pub fn bind(&self, bindings: &crate::DimBinding) -> Self {
277        crate::dynamic::bind_graph(self, bindings)
278    }
279
280    /// Stage-aware dump when wrapped in [`GraphModule`].
281    pub fn inspect_module(module: &crate::GraphModule) -> String {
282        module.inspect()
283    }
284}
285
286/// Pretty-print the graph in a readable IR format.
287impl std::fmt::Display for Graph {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        writeln!(f, "graph @{} {{", self.name)?;
290        for node in &self.nodes {
291            write!(f, "  {} = {}", node.id, node.op)?;
292            if !node.inputs.is_empty() {
293                write!(f, "(")?;
294                for (i, inp) in node.inputs.iter().enumerate() {
295                    if i > 0 {
296                        write!(f, ", ")?;
297                    }
298                    write!(f, "{inp}")?;
299                }
300                write!(f, ")")?;
301            }
302            writeln!(f, " : {}", node.shape)?;
303        }
304        if !self.outputs.is_empty() {
305            write!(f, "  return ")?;
306            for (i, o) in self.outputs.iter().enumerate() {
307                if i > 0 {
308                    write!(f, ", ")?;
309                }
310                write!(f, "{o}")?;
311            }
312            writeln!(f)?;
313        }
314        writeln!(f, "}}")
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        DType,
        op::{Activation, BinaryOp},
    };

    /// Builds `gelu(x @ weight + bias)` and checks the node count, the use
    /// counts, and the textual dump.
    #[test]
    fn build_simple_graph() {
        let mut graph = Graph::new("test");

        // Leaves: one activation input followed by two parameters.
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let weight = graph.param("weight", Shape::new(&[384, 1536], DType::F32));
        let bias = graph.param("bias", Shape::new(&[1536], DType::F32));

        // Linear layer followed by GELU.
        let product = graph.matmul(input, weight, Shape::new(&[4, 15, 1536], DType::F32));
        let biased = graph.binary(
            BinaryOp::Add,
            product,
            bias,
            Shape::new(&[4, 15, 1536], DType::F32),
        );
        let activated = graph.activation(
            Activation::Gelu,
            biased,
            Shape::new(&[4, 15, 1536], DType::F32),
        );

        graph.set_outputs(vec![activated]);

        assert_eq!(graph.len(), 6);
        assert_eq!(graph.use_count(product), 1); // matmul feeds the add
        assert_eq!(graph.use_count(input), 1); // x feeds the matmul

        let dump = graph.to_string();
        for fragment in ["matmul(%0, %1)", "Gelu(%4)", "return %5"] {
            assert!(dump.contains(fragment), "missing `{fragment}` in:\n{dump}");
        }
    }

    /// Build a (simplified) BERT layer to verify the IR can represent real models.
    #[test]
    fn bert_layer_graph() {
        let mut graph = Graph::new("bert_layer");
        let dtype = DType::F32;
        let hidden = 384;
        let intermediate = 1536;

        // Input activations
        let x = graph.input("hidden", Shape::new(&[4, 15, hidden], dtype));

        // Fused QKV projection
        let qkv_weight = graph.param("qkv.weight", Shape::new(&[hidden, 3 * hidden], dtype));
        let qkv_bias = graph.param("qkv.bias", Shape::new(&[3 * hidden], dtype));
        let qkv_mm = graph.matmul(x, qkv_weight, Shape::new(&[4, 15, 3 * hidden], dtype));
        let _qkv = graph.binary(
            BinaryOp::Add,
            qkv_mm,
            qkv_bias,
            Shape::new(&[4, 15, 3 * hidden], dtype),
        );

        // (would split Q/K/V, attention, out_proj here — simplified)

        // Feed-forward: up-projection, bias, GELU
        let up_weight = graph.param("ffn.weight", Shape::new(&[hidden, intermediate], dtype));
        let up_bias = graph.param("ffn.bias", Shape::new(&[intermediate], dtype));
        let up_mm = graph.matmul(x, up_weight, Shape::new(&[4, 15, intermediate], dtype));
        let up_biased = graph.binary(
            BinaryOp::Add,
            up_mm,
            up_bias,
            Shape::new(&[4, 15, intermediate], dtype),
        );
        let up_act = graph.activation(
            Activation::Gelu,
            up_biased,
            Shape::new(&[4, 15, intermediate], dtype),
        );

        // Feed-forward: down-projection back to the hidden size
        let down_weight = graph.param(
            "ffn_out.weight",
            Shape::new(&[intermediate, hidden], dtype),
        );
        let ffn_out = graph.matmul(up_act, down_weight, Shape::new(&[4, 15, hidden], dtype));

        graph.set_outputs(vec![ffn_out]);

        assert!(graph.len() > 10);
        println!("{graph}");
    }
}