Skip to main content

rlx_compile/
compiler.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! HIR → MIR → LIR compiler pipeline.
17//!
18//! Explicit staging for the RLX compiler:
19//!
20//! ```text
21//! HIR (blocks)  ──lower──▶  MIR (tensor DAG)  ──opt──▶  MIR  ──plan──▶  LIR
22//! ```
23//!
24//! Backends consume [`CompileResult`] / [`LirModule`] (optimized MIR +
25//! buffer plan + fusion report) and lower to device-specific thunks.
26
27use rlx_ir::dynamic::collect_dynamic_symbols;
28use rlx_ir::hir::HirModule;
29use rlx_ir::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest, LirModule, LirViewAlias};
30use rlx_ir::mir::MirModule;
31use rlx_ir::phase::derive_phases;
32use rlx_ir::{Graph, GraphModule, GraphStage};
33
34use crate::DeadCodeElimination;
35use crate::debug_assert_graph;
36use crate::fusion_pipeline::{
37    FusionOptions, FusionTarget, fusion_limits_for_target, fusion_passes_for_supported,
38    supported_for_target,
39};
40use crate::legalize::{format_legalize_error, legalize_for_backend};
41use crate::memory::{self, MemoryPlan};
42use crate::rewrite::rewrite_for_backend_with_config;
43use rlx_fusion::fusion_report::FusionReport;
44use rlx_fusion::pass::run_passes;
45use rlx_fusion::{clip_elementwise_regions, with_fusion_limits};
46use rlx_ir::OpKind;
47use rlx_ir::logical_kernel::KernelDispatchConfig;
48
49/// End-to-end compiler output: optimized LIR + fusion diagnostics.
50#[derive(Debug, Clone)]
51pub struct CompileResult {
52    pub lir: LirModule,
53    pub fusion: FusionReport,
54}
55
56impl CompileResult {
57    pub fn has_dynamic_dims(&self) -> bool {
58        self.lir.has_dynamic_dims()
59    }
60
61    pub fn dynamic_symbols(&self) -> &[u32] {
62        self.lir.dynamic_symbols()
63    }
64
65    /// Re-plan buffers after binding symbolic dims to concrete sizes.
66    pub fn specialize(&self, pipeline: &CompilePipeline, binding: &rlx_ir::DimBinding) -> Self {
67        Self {
68            lir: pipeline.specialize_lir(&self.lir, binding),
69            fusion: self.fusion.clone(),
70        }
71    }
72}
73
74/// End-to-end compiler pipeline configuration.
75#[derive(Debug, Clone, Copy)]
76pub struct CompilePipeline {
77    pub target: FusionTarget,
78    pub opts: FusionOptions,
79    pub arena_alignment: usize,
80    /// When true, [`compile_hir`] / [`compile_graph`] panic if fusion
81    /// diagnostics report missed block-level patterns.
82    pub assert_fusion_clean: bool,
83    /// Backend op claim set. When `Some` and non-empty, fusion passes
84    /// are gated on these kinds and the optimized graph is legalized
85    /// afterward. When `None`, [`supported_for_target`] is used.
86    pub supported_ops: Option<&'static [OpKind]>,
87    /// Native vs common IR lowering for logical kernels (see `rlx_ir::logical_kernel`).
88    pub kernel_dispatch: KernelDispatchConfig,
89}
90
91impl Default for CompilePipeline {
92    fn default() -> Self {
93        Self {
94            target: FusionTarget::Cpu,
95            opts: FusionOptions::for_cpu(),
96            arena_alignment: 64,
97            assert_fusion_clean: false,
98            supported_ops: None,
99            kernel_dispatch: KernelDispatchConfig::from_env(),
100        }
101    }
102}
103
104impl CompilePipeline {
105    pub fn new(target: FusionTarget) -> Self {
106        let mut opts = match target {
107            FusionTarget::Cpu => FusionOptions::for_cpu(),
108            FusionTarget::Metal => FusionOptions::from_metal_env(),
109            _ => FusionOptions::default(),
110        };
111        opts.fusion_limits = fusion_limits_for_target(target);
112        Self {
113            target,
114            opts,
115            ..Self::default()
116        }
117    }
118
119    pub fn with_assert_fusion_clean(mut self, assert: bool) -> Self {
120        self.assert_fusion_clean = assert;
121        self
122    }
123
124    /// HIR → MIR (block lowering only).
125    pub fn lower_hir(hir: HirModule) -> Result<MirModule, rlx_ir::hir::LowerError> {
126        let mir = hir.lower_to_mir()?;
127        debug_assert_graph!(mir.as_graph(), "hir→mir");
128        Ok(mir)
129    }
130
131    /// Optional cleanup before fusion (DCE + control-flow lowering).
132    pub fn preprocess_mir(mir: MirModule) -> MirModule {
133        use rlx_fusion::pass::Pass as _;
134        let graph = rlx_fusion::control_flow::LowerControlFlow.run(mir.into_graph());
135        let graph = DeadCodeElimination.run(graph);
136        MirModule::from_graph(graph)
137    }
138
139    pub fn with_supported_ops(mut self, ops: &'static [OpKind]) -> Self {
140        self.supported_ops = Some(ops);
141        self
142    }
143
144    pub fn with_kernel_dispatch(
145        mut self,
146        policy: rlx_ir::logical_kernel::KernelDispatchPolicy,
147    ) -> Self {
148        self.kernel_dispatch.policy = policy;
149        self
150    }
151
152    pub fn with_kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
153        self.kernel_dispatch = config;
154        self
155    }
156
157    fn effective_supported(&self) -> &'static [OpKind] {
158        self.supported_ops
159            .unwrap_or_else(|| supported_for_target(self.target))
160    }
161
162    fn backend_name(&self) -> &'static str {
163        match self.target {
164            FusionTarget::Cpu => "cpu",
165            FusionTarget::Metal => "metal",
166            FusionTarget::Mlx => "mlx",
167            FusionTarget::Wgpu => "wgpu",
168            FusionTarget::Cuda => "cuda",
169            FusionTarget::Rocm => "rocm",
170            FusionTarget::Tpu => "tpu",
171        }
172    }
173
174    /// Run fusion + cleanup passes on MIR, returning fusion diagnostics.
175    pub fn optimize_with_report(&self, mir: MirModule) -> (MirModule, FusionReport) {
176        let before = mir.as_graph().clone();
177        let passes = fusion_passes_for_supported(self.effective_supported(), self.opts);
178        let limits = self.opts.fusion_limits;
179        let graph = with_fusion_limits(limits, || run_passes(mir.into_graph(), &passes, false));
180        let graph = clip_elementwise_regions(graph, limits);
181        debug_assert_graph!(&graph, "fusion");
182        let graph = self.legalize_after_fusion(graph);
183        debug_assert_graph!(&graph, "legalize");
184        let mir = MirModule::from_graph(graph);
185        let fusion = FusionReport::analyze(&before, mir.as_graph());
186        (mir, fusion)
187    }
188
189    /// Rewrite / legalize fused IR against the backend op claim set.
190    /// Runs when [`supported_ops`](Self::supported_ops) is set (including
191    /// auto-wiring from [`Backend::supported_ops`] in [`crate::stages::pipeline_for`]).
192    pub(crate) fn legalize_after_fusion(&self, graph: Graph) -> Graph {
193        let Some(supported) = self.supported_ops else {
194            if self.kernel_dispatch.force_common_kinds.is_empty()
195                && self.kernel_dispatch.policy
196                    == rlx_ir::logical_kernel::KernelDispatchPolicy::PreferNative
197            {
198                return graph;
199            }
200            return rewrite_for_backend_with_config(graph, &[], self.kernel_dispatch);
201        };
202        if supported.is_empty() {
203            return graph;
204        }
205        let graph = rewrite_for_backend_with_config(graph, supported, self.kernel_dispatch);
206        if let Err(errors) = legalize_for_backend(&graph, supported) {
207            panic!("{}", format_legalize_error(self.backend_name(), &errors));
208        }
209        graph
210    }
211
212    /// Run fusion + cleanup passes on MIR.
213    pub fn optimize(&self, mir: MirModule) -> MirModule {
214        self.optimize_with_report(mir).0
215    }
216
217    /// MIR → LIR (memory plan + schedule + phases + I/O manifest).
218    pub fn plan_lir(&self, mir: MirModule) -> LirModule {
219        self.plan_lir_with_options(mir, memory::MemoryPlanOptions::default())
220    }
221
222    /// MIR → LIR with explicit boundary allocation policy.
223    pub fn plan_lir_with_options(
224        &self,
225        mir: MirModule,
226        opts: memory::MemoryPlanOptions,
227    ) -> LirModule {
228        let graph = mir.as_graph().clone();
229        let plan = memory::plan_memory_with_options(&graph, self.arena_alignment, opts);
230        LirModule::new(
231            mir,
232            lir_buffer_plan_from_memory(&graph, &plan, self.arena_alignment),
233        )
234    }
235
236    /// Bind symbolic dims and re-run buffer planning on specialized MIR.
237    pub fn specialize_lir(&self, lir: &LirModule, binding: &rlx_ir::DimBinding) -> LirModule {
238        use rlx_ir::dynamic::{
239            bind_graph, sync_concat_shapes, sync_graph_shapes, sync_narrow_ops, sync_reshape_ops,
240        };
241        let mut bound = bind_graph(lir.as_graph(), binding);
242        sync_reshape_ops(&mut bound);
243        sync_concat_shapes(&mut bound);
244        sync_narrow_ops(&mut bound);
245        sync_graph_shapes(&mut bound);
246        debug_assert_graph!(&bound, "specialize");
247        self.plan_lir(MirModule::from_graph(bound))
248    }
249
250    fn finish(&self, mir: MirModule, fusion: FusionReport) -> CompileResult {
251        debug_assert_graph!(mir.as_graph(), "pre-lir");
252        if self.assert_fusion_clean && !fusion.missed.is_empty() {
253            panic!(
254                "fusion contract violated: {} missed patterns\n{fusion}",
255                fusion.missed.len()
256            );
257        }
258        CompileResult {
259            lir: self.plan_lir(mir),
260            fusion,
261        }
262    }
263
264    /// HIR → LIR in one call with fusion report.
265    pub fn compile_hir(&self, hir: HirModule) -> Result<CompileResult, rlx_ir::hir::LowerError> {
266        if rlx_ir::env::var("RLX_IR_DUMP").is_some() {
267            let name = hir.name.clone();
268            let dump = crate::inspect::inspect_pipeline(self, hir.clone())?;
269            crate::inspect::maybe_dump_pipeline(&dump, &name);
270        }
271        let mir = Self::lower_hir(hir)?;
272        let (mir, fusion) = self.optimize_with_report(mir);
273        Ok(self.finish(mir, fusion))
274    }
275
276    /// Legacy MIR entry: optimize + plan with fusion report.
277    pub fn compile_mir(&self, mir: MirModule) -> CompileResult {
278        let (mir, fusion) = self.optimize_with_report(mir);
279        self.finish(mir, fusion)
280    }
281
282    /// Legacy entry: optimize an existing graph and plan buffers.
283    pub fn compile_graph(&self, graph: Graph) -> CompileResult {
284        self.compile_mir(MirModule::from_graph(graph))
285    }
286
287    /// Unified entry for [`GraphModule`] at any pipeline stage.
288    pub fn compile_module(
289        &self,
290        module: GraphModule,
291    ) -> Result<CompileResult, rlx_ir::hir::LowerError> {
292        match module.stage() {
293            GraphStage::Hir => {
294                let hir = module
295                    .into_hir()
296                    .expect("GraphModule stage() / into_hir mismatch");
297                self.compile_hir(hir)
298            }
299            GraphStage::Mir => {
300                let mir = module.into_mir()?;
301                Ok(self.compile_mir(mir))
302            }
303            GraphStage::Lir => Ok(CompileResult {
304                lir: module
305                    .into_lir()
306                    .expect("GraphModule stage() / into_lir mismatch"),
307                fusion: FusionReport::default(),
308            }),
309        }
310    }
311}
312
313impl From<&MemoryPlan> for LirBufferPlan {
314    fn from(plan: &MemoryPlan) -> Self {
315        LirBufferPlan {
316            arena_size: plan.arena_size,
317            assignments: plan
318                .assignments
319                .iter()
320                .map(|(id, slot)| {
321                    (
322                        *id,
323                        LirBufferSlot {
324                            offset: slot.offset,
325                            size: slot.size,
326                        },
327                    )
328                })
329                .collect(),
330            schedule: plan.schedule.clone(),
331            ..Default::default()
332        }
333    }
334}
335
336impl From<&LirBufferPlan> for MemoryPlan {
337    fn from(plan: &LirBufferPlan) -> Self {
338        MemoryPlan {
339            arena_size: plan.arena_size,
340            assignments: plan
341                .assignments
342                .iter()
343                .map(|(id, slot)| {
344                    (
345                        *id,
346                        memory::BufferSlot {
347                            offset: slot.offset,
348                            size: slot.size,
349                        },
350                    )
351                })
352                .collect(),
353            schedule: plan.schedule.clone(),
354        }
355    }
356}
357
358pub(crate) fn lir_buffer_plan_from_memory(
359    graph: &Graph,
360    plan: &MemoryPlan,
361    alignment: usize,
362) -> LirBufferPlan {
363    let view_aliases = memory::collect_view_aliases(graph)
364        .into_iter()
365        .map(|(id, (root, byte_offset))| (id, LirViewAlias { root, byte_offset }))
366        .collect();
367    LirBufferPlan {
368        arena_size: plan.arena_size,
369        assignments: plan
370            .assignments
371            .iter()
372            .map(|(id, slot)| {
373                (
374                    *id,
375                    LirBufferSlot {
376                        offset: slot.offset,
377                        size: slot.size,
378                    },
379                )
380            })
381            .collect(),
382        schedule: plan.schedule.clone(),
383        view_aliases,
384        phases: derive_phases(graph),
385        io: LirIoManifest::collect(graph),
386        alignment,
387        dynamic_symbols: collect_dynamic_symbols(graph),
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Op;
    use rlx_ir::Shape;
    use rlx_ir::hir::FusionPolicy;

    /// Shorthand for an `f32` shape with dims `d`.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    #[test]
    fn pipeline_hir_to_lir() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let b = hir.param("b", f32_shape(&[128]));
        let h = hir.linear(x, w, Some(b), None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        assert!(result.lir.mir.len() <= 5);
        assert!(result.lir.arena_size() > 0);
        assert!(result.lir.buffers.bytes_saved() <= result.lir.buffers.total_unshared_bytes());
        // NOTE(review): the second disjunct is already guaranteed by the first
        // assertion above, so this line cannot fail; tighten it once the
        // expected fusion count for this graph is pinned down.
        assert!(result.fusion.fused_matmul_bias_act >= 1 || result.lir.mir.len() <= 5);
    }

    #[test]
    fn direct_hir_swiglu_emits_fused_op() {
        let mut hir = HirModule::new("ffn");
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        let g = result.lir.mir.as_graph();
        assert!(
            g.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "direct HIR SwiGLU should lower to FusedSwiGLU"
        );
        assert!(result.fusion.missed_matmul_bias_act() == 0 || result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn compile_module_from_graph_define() {
        let module = GraphModule::define("ffn", |m| {
            let x = m.input("x", f32_shape(&[2, 64]));
            let w = m.param("w", f32_shape(&[64, 64]));
            m.linear(x, w, None, None, f32_shape(&[2, 64]))
        });
        assert_eq!(module.stage(), GraphStage::Hir);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_module(module).expect("compile_module");
        assert!(result.lir.arena_size() > 0);
    }

    #[test]
    fn fusable_policy_leaves_room_for_passes() {
        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let mir = CompilePipeline::lower_hir(hir).expect("lower");
        let g = mir.as_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        assert_eq!(g.len(), 9);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_mir(mir);
        assert!(result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn lir_plan_includes_phases_io_and_fingerprint() {
        use rlx_ir::phase::Phase;

        let mut hir = HirModule::new("stream");
        let x = hir.input("x", f32_shape(&[1, 8]));
        let w = hir.param("w", f32_shape(&[8, 4]));
        let mm = hir.linear(x, w, None, None, f32_shape(&[1, 4]));
        hir.set_outputs(vec![mm]);

        let result = CompilePipeline::new(FusionTarget::Cpu)
            .compile_hir(hir)
            .expect("compile");
        assert!(!result.lir.buffers.phases.is_empty());
        // Check the manifest size before indexing, so a missing input fails
        // with a clear assertion instead of an index-out-of-bounds panic.
        assert_eq!(result.lir.buffers.io.inputs.len(), 1);
        let input_id = result.lir.buffers.io.inputs[0].1;
        assert_eq!(
            result.lir.buffers.phases.get(input_id),
            Some(Phase::Prologue)
        );
        assert_eq!(result.lir.fingerprint(), result.lir.fingerprint());
        assert_eq!(result.lir.buffers.alignment, 64);
    }

    #[test]
    fn dynamic_graph_compiles_and_specializes() {
        use rlx_ir::DimBinding;
        use rlx_ir::infer::GraphExt as _;
        use rlx_ir::sym;

        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        assert!(result.has_dynamic_dims());
        assert!(result.lir.buffers.dynamic_symbols.contains(&sym::SEQ));

        let bound = result.specialize(&pipe, &DimBinding::batch_seq(2, 16));
        assert!(bound.lir.is_fully_static());
        assert!(bound.lir.arena_size() > 0);
    }
}