Skip to main content

rlx_ir/
module.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! [`GraphModule`] — unified, higher-order developer-facing API over HIR / MIR / LIR.
17
18use std::ops::{Deref, DerefMut};
19
20use crate::hir::{FusionPolicy, HirModule, HirNodeId, LowerError};
21use crate::inspect::{inspect_hir, inspect_lir, inspect_mir};
22use crate::lir::LirModule;
23use crate::mir::MirModule;
24use crate::op::Activation;
25use crate::op::MaskKind;
26use crate::quant::QuantScheme;
27use crate::{Graph, NodeId, Op, Shape};
28
/// Tag naming the point in the HIR → MIR → LIR pipeline that a
/// [`GraphModule`] currently holds.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GraphStage {
    /// The module holds an HIR module.
    Hir,
    /// The module holds a MIR module.
    Mir,
    /// The module holds a LIR module.
    Lir,
}
36
37#[derive(Debug, Clone)]
38enum Stage {
39    Hir(HirModule),
40    Mir(MirModule),
41    Lir(LirModule),
42}
43
44/// Unified model module — primary builder surface above HIR/MIR/LIR.
45#[derive(Debug, Clone)]
46pub struct GraphModule {
47    stage: Stage,
48}
49
50impl GraphModule {
51    pub fn define(
52        name: impl Into<String>,
53        build: impl FnOnce(&mut HirModule) -> HirNodeId,
54    ) -> Self {
55        let mut hir = HirModule::new(name);
56        let out = build(&mut hir);
57        hir.set_outputs(vec![out]);
58        Self {
59            stage: Stage::Hir(hir),
60        }
61    }
62
63    /// Start an empty HIR-stage module (like [`Graph::new`] for MIR).
64    pub fn hir(name: impl Into<String>) -> Self {
65        Self {
66            stage: Stage::Hir(HirModule::new(name)),
67        }
68    }
69
70    /// Start an empty MIR-stage module.
71    pub fn mir(name: impl Into<String>) -> Self {
72        Self {
73            stage: Stage::Mir(MirModule::new(name)),
74        }
75    }
76
77    pub fn from_hir(hir: HirModule) -> Self {
78        Self {
79            stage: Stage::Hir(hir),
80        }
81    }
82
83    pub fn from_graph(graph: Graph) -> Self {
84        Self {
85            stage: Stage::Mir(MirModule::from_graph(graph)),
86        }
87    }
88
89    pub fn from_mir(mir: MirModule) -> Self {
90        Self {
91            stage: Stage::Mir(mir),
92        }
93    }
94
95    pub fn from_lir(lir: LirModule) -> Self {
96        Self {
97            stage: Stage::Lir(lir),
98        }
99    }
100
101    pub fn block(
102        hir: &mut HirModule,
103        name: impl Into<String>,
104        build: impl FnOnce(&mut HirModule) -> HirNodeId,
105    ) -> HirNodeId {
106        hir.named(name, build)
107    }
108
109    pub fn fusion_policy(&self) -> Option<FusionPolicy> {
110        self.as_hir().map(|h| h.fusion_policy)
111    }
112
113    pub fn with_fusion_policy(mut self, policy: FusionPolicy) -> Self {
114        if let Stage::Hir(h) = &mut self.stage {
115            h.fusion_policy = policy;
116        } else {
117            panic!("GraphModule::with_fusion_policy requires HIR stage");
118        }
119        self
120    }
121
122    /// Set graph outputs at the current stage.
123    ///
124    /// At HIR stage accepts [`HirNodeId`] values; at MIR/LIR the same
125    /// indices map to [`NodeId`] (both are insertion-order node ids).
126    pub fn set_outputs(&mut self, outputs: Vec<HirNodeId>) {
127        match &mut self.stage {
128            Stage::Hir(h) => h.set_outputs(outputs),
129            Stage::Mir(m) => m.set_outputs(outputs.into_iter().map(|h| NodeId(h.0)).collect()),
130            Stage::Lir(l) => l
131                .mir
132                .set_outputs(outputs.into_iter().map(|h| NodeId(h.0)).collect()),
133        }
134    }
135
136    pub fn set_hir_outputs(&mut self, outputs: Vec<HirNodeId>) {
137        self.set_outputs(outputs);
138    }
139
140    /// Finish HIR construction and set the module output.
141    pub fn finish_hir(mut self, output: HirNodeId) -> Self {
142        self.set_hir_outputs(vec![output]);
143        self
144    }
145
146    fn hir_mut(&mut self) -> &mut HirModule {
147        self.as_hir_mut()
148            .expect("GraphModule: HIR builder methods require HIR stage — use GraphModule::hir() or Graph::define()")
149    }
150
151    // ── HIR block builders (forward to HirModule) ─────────────────
152
153    pub fn input(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
154        match &mut self.stage {
155            Stage::Hir(h) => h.input(name, shape),
156            Stage::Mir(m) => {
157                let id = m.as_graph_mut().input(name, shape);
158                HirNodeId(id.0)
159            }
160            Stage::Lir(l) => {
161                let id = l.mir.as_graph_mut().input(name, shape);
162                HirNodeId(id.0)
163            }
164        }
165    }
166
167    pub fn param(&mut self, name: impl Into<String>, shape: Shape) -> HirNodeId {
168        match &mut self.stage {
169            Stage::Hir(h) => h.param(name, shape),
170            Stage::Mir(m) => {
171                let id = m.as_graph_mut().param(name, shape);
172                HirNodeId(id.0)
173            }
174            Stage::Lir(l) => {
175                let id = l.mir.as_graph_mut().param(name, shape);
176                HirNodeId(id.0)
177            }
178        }
179    }
180
181    pub fn linear(
182        &mut self,
183        x: HirNodeId,
184        weight: HirNodeId,
185        bias: Option<HirNodeId>,
186        activation: Option<Activation>,
187        out_shape: Shape,
188    ) -> HirNodeId {
189        self.hir_mut()
190            .linear(x, weight, bias, activation, out_shape)
191    }
192
193    pub fn linear_fused(
194        &mut self,
195        x: HirNodeId,
196        weight: HirNodeId,
197        bias: HirNodeId,
198        activation: Option<Activation>,
199        out_shape: Shape,
200    ) -> HirNodeId {
201        self.hir_mut()
202            .linear_fused(x, weight, bias, activation, out_shape)
203    }
204
205    pub fn shared_linear_pair(
206        &mut self,
207        x: HirNodeId,
208        w_first: HirNodeId,
209        w_second: HirNodeId,
210        out_shape: Shape,
211    ) -> (HirNodeId, HirNodeId) {
212        self.hir_mut()
213            .shared_linear_pair(x, w_first, w_second, out_shape)
214    }
215
216    pub fn swiglu_ffn(
217        &mut self,
218        x: HirNodeId,
219        up_w: HirNodeId,
220        gate_w: HirNodeId,
221        down_w: HirNodeId,
222        out_shape: Shape,
223    ) -> HirNodeId {
224        self.hir_mut()
225            .swiglu_ffn(x, up_w, gate_w, down_w, out_shape)
226    }
227
228    pub fn residual_rms_norm(
229        &mut self,
230        x: HirNodeId,
231        residual: HirNodeId,
232        gamma: HirNodeId,
233        beta: HirNodeId,
234        eps: f32,
235        out_shape: Shape,
236    ) -> HirNodeId {
237        self.hir_mut()
238            .residual_rms_norm(x, residual, gamma, beta, eps, out_shape)
239    }
240
241    pub fn attention(
242        &mut self,
243        q: HirNodeId,
244        k: HirNodeId,
245        v: HirNodeId,
246        mask: Option<HirNodeId>,
247        num_heads: usize,
248        head_dim: usize,
249        mask_kind: MaskKind,
250        out_shape: Shape,
251    ) -> HirNodeId {
252        self.hir_mut()
253            .attention(q, k, v, mask, num_heads, head_dim, mask_kind, out_shape)
254    }
255
256    pub fn depthwise_conv1d_causal(
257        &mut self,
258        input: HirNodeId,
259        weight: HirNodeId,
260        left_pad: HirNodeId,
261        kernel_size: usize,
262        out_shape: Shape,
263    ) -> HirNodeId {
264        self.hir_mut()
265            .depthwise_conv1d_causal(input, weight, left_pad, kernel_size, out_shape)
266    }
267
268    pub fn dequant_matmul(
269        &mut self,
270        x: HirNodeId,
271        w: HirNodeId,
272        scale: Option<HirNodeId>,
273        zp: Option<HirNodeId>,
274        scheme: QuantScheme,
275        out_shape: Shape,
276    ) -> HirNodeId {
277        self.hir_mut()
278            .dequant_matmul(x, w, scale, zp, scheme, out_shape)
279    }
280
281    pub fn gated_delta_net(
282        &mut self,
283        q: HirNodeId,
284        k: HirNodeId,
285        v: HirNodeId,
286        g: HirNodeId,
287        beta: HirNodeId,
288        state_size: usize,
289        out_shape: Shape,
290    ) -> HirNodeId {
291        self.hir_mut()
292            .gated_delta_net(q, k, v, g, beta, state_size, out_shape)
293    }
294
295    pub fn gated_delta_net_carry(
296        &mut self,
297        q: HirNodeId,
298        k: HirNodeId,
299        v: HirNodeId,
300        g: HirNodeId,
301        beta: HirNodeId,
302        state: HirNodeId,
303        state_size: usize,
304        out_shape: Shape,
305    ) -> HirNodeId {
306        self.hir_mut()
307            .gated_delta_net_carry(q, k, v, g, beta, state, state_size, out_shape)
308    }
309
310    pub fn rope(
311        &mut self,
312        x: HirNodeId,
313        cos: HirNodeId,
314        sin: HirNodeId,
315        head_dim: usize,
316        n_rot: usize,
317        out_shape: Shape,
318    ) -> HirNodeId {
319        self.hir_mut().rope(x, cos, sin, head_dim, n_rot, out_shape)
320    }
321
322    pub fn rms_norm(
323        &mut self,
324        x: HirNodeId,
325        gamma: HirNodeId,
326        beta: HirNodeId,
327        eps: f32,
328        out_shape: Shape,
329    ) -> HirNodeId {
330        self.hir_mut().rms_norm(x, gamma, beta, eps, out_shape)
331    }
332
333    pub fn hir_mir(&mut self, op: Op, inputs: Vec<HirNodeId>, shape: Shape) -> HirNodeId {
334        self.hir_mut().mir(op, inputs, shape)
335    }
336
337    pub fn named(
338        &mut self,
339        name: impl Into<String>,
340        build: impl FnOnce(&mut HirModule) -> HirNodeId,
341    ) -> HirNodeId {
342        self.hir_mut().named(name, build)
343    }
344
345    pub fn stage(&self) -> GraphStage {
346        match &self.stage {
347            Stage::Hir(_) => GraphStage::Hir,
348            Stage::Mir(_) => GraphStage::Mir,
349            Stage::Lir(_) => GraphStage::Lir,
350        }
351    }
352
353    pub fn name(&self) -> &str {
354        match &self.stage {
355            Stage::Hir(h) => &h.name,
356            Stage::Mir(m) => m.name(),
357            Stage::Lir(l) => l.name(),
358        }
359    }
360
361    pub fn lower(self) -> Result<Self, LowerError> {
362        match self.stage {
363            Stage::Hir(hir) => Ok(Self {
364                stage: Stage::Mir(hir.lower_to_mir()?),
365            }),
366            other => Ok(Self { stage: other }),
367        }
368    }
369
370    pub fn into_hir(self) -> Option<HirModule> {
371        match self.stage {
372            Stage::Hir(h) => Some(h),
373            _ => None,
374        }
375    }
376
377    pub fn into_mir(self) -> Result<MirModule, LowerError> {
378        match self.stage {
379            Stage::Hir(hir) => hir.lower_to_mir(),
380            Stage::Mir(m) => Ok(m),
381            Stage::Lir(l) => Ok(l.mir),
382        }
383    }
384
385    pub fn into_lir(self) -> Option<LirModule> {
386        match self.stage {
387            Stage::Lir(l) => Some(l),
388            _ => None,
389        }
390    }
391
392    pub fn into_graph(self) -> Result<Graph, LowerError> {
393        Ok(self.into_mir()?.into_graph())
394    }
395
396    pub fn as_hir(&self) -> Option<&HirModule> {
397        match &self.stage {
398            Stage::Hir(h) => Some(h),
399            _ => None,
400        }
401    }
402
403    pub fn as_hir_mut(&mut self) -> Option<&mut HirModule> {
404        match &mut self.stage {
405            Stage::Hir(h) => Some(h),
406            _ => None,
407        }
408    }
409
410    pub fn as_mir(&self) -> Option<&MirModule> {
411        match &self.stage {
412            Stage::Mir(m) => Some(m),
413            Stage::Lir(l) => Some(&l.mir),
414            _ => None,
415        }
416    }
417
418    pub fn as_lir(&self) -> Option<&LirModule> {
419        match &self.stage {
420            Stage::Lir(l) => Some(l),
421            _ => None,
422        }
423    }
424
425    pub fn as_graph(&self) -> Option<&Graph> {
426        match &self.stage {
427            Stage::Mir(m) => Some(m.as_graph()),
428            Stage::Lir(l) => Some(l.as_graph()),
429            Stage::Hir(_) => None,
430        }
431    }
432
433    pub fn inspect(&self) -> String {
434        match &self.stage {
435            Stage::Hir(h) => inspect_hir(h),
436            Stage::Mir(m) => inspect_mir(m),
437            Stage::Lir(l) => inspect_lir(l),
438        }
439    }
440}
441
442impl Deref for GraphModule {
443    type Target = Graph;
444
445    fn deref(&self) -> &Graph {
446        self.as_graph()
447            .expect("GraphModule: HIR stage — call lower() before accessing MIR Graph")
448    }
449}
450
451impl DerefMut for GraphModule {
452    fn deref_mut(&mut self) -> &mut Graph {
453        match &mut self.stage {
454            Stage::Mir(m) => m.as_graph_mut(),
455            Stage::Lir(l) => l.mir.as_graph_mut(),
456            Stage::Hir(_) => panic!("GraphModule: HIR stage — use as_hir_mut() or lower() first"),
457        }
458    }
459}
460
461impl From<Graph> for GraphModule {
462    fn from(graph: Graph) -> Self {
463        Self::from_graph(graph)
464    }
465}
466
467impl TryFrom<GraphModule> for Graph {
468    type Error = LowerError;
469
470    fn try_from(module: GraphModule) -> Result<Self, LowerError> {
471        module.into_graph()
472    }
473}
474
475impl From<MirModule> for GraphModule {
476    fn from(mir: MirModule) -> Self {
477        Self::from_mir(mir)
478    }
479}
480
481impl From<HirModule> for GraphModule {
482    fn from(hir: HirModule) -> Self {
483        Self::from_hir(hir)
484    }
485}
486
487impl std::fmt::Display for GraphModule {
488    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489        match &self.stage {
490            Stage::Hir(h) => write!(f, "{h}"),
491            Stage::Mir(m) => write!(f, "{m}"),
492            Stage::Lir(l) => write!(f, "lir @{}", l.name()),
493        }
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;
    use crate::Graph;
    use crate::Shape;

    /// Build an `F32` [`Shape`] with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// `define` yields an HIR-stage module that `lower` turns into MIR.
    #[test]
    fn define_lowers_to_mir_graph() {
        let hir_stage = GraphModule::define("m", |b| {
            let input = b.input("x", f32_shape(&[2, 8]));
            let weight = b.param("w", f32_shape(&[8, 8]));
            b.linear(input, weight, None, None, f32_shape(&[2, 8]))
        });
        assert_eq!(hir_stage.stage(), GraphStage::Hir);

        let mir_stage = hir_stage.lower().expect("lower");
        assert_eq!(mir_stage.stage(), GraphStage::Mir);
        assert!(mir_stage.len() >= 3);
    }

    /// A MIR-stage module accepts `input` / `set_outputs` and derefs to the graph.
    #[test]
    fn mir_module_deref_builds_graph() {
        let mut raw = GraphModule::mir("raw");
        let only_input = raw.input("x", f32_shape(&[4]));
        raw.set_outputs(vec![only_input]);
        assert_eq!(raw.len(), 1);
    }

    /// HIR block builders are reachable directly on `GraphModule`.
    #[test]
    fn hir_module_block_builders_via_graph_module() {
        use crate::quant::QuantScheme;

        let mut layer = GraphModule::hir("layer");
        let activations = layer.input("x", f32_shape(&[2, 128]));
        let weights = layer.param("w", f32_shape(&[128, 128]));
        let product = layer.dequant_matmul(
            activations,
            weights,
            None,
            None,
            QuantScheme::GgufQ4K,
            f32_shape(&[2, 128]),
        );
        layer.set_outputs(vec![product]);
        assert_eq!(layer.stage(), GraphStage::Hir);

        let lowered = layer.lower().expect("lower");
        assert_eq!(lowered.stage(), GraphStage::Mir);
        assert!(lowered.len() >= 3);
    }

    /// `Graph::hir` and `Graph::define` both start at HIR stage.
    #[test]
    fn graph_hir_entry_matches_define() {
        let from_hir_entry = Graph::hir("m");
        let from_define = Graph::define("m", |b| {
            let x = b.input("x", f32_shape(&[4]));
            b.rms_norm(x, x, x, 1e-5, f32_shape(&[4]))
        });
        assert_eq!(from_hir_entry.stage(), GraphStage::Hir);
        assert_eq!(from_define.stage(), GraphStage::Hir);
    }
}
555}