Skip to main content

rlx_compile/
compiler.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! HIR → MIR → LIR compiler pipeline.
17//!
18//! Explicit staging for the RLX compiler:
19//!
20//! ```text
21//! HIR (blocks)  ──lower──▶  MIR (tensor DAG)  ──opt──▶  MIR  ──plan──▶  LIR
22//! ```
23//!
//! Backends consume [`CompileResult`] — a [`LirModule`] (optimized MIR +
//! buffer plan) together with a fusion report — and lower it to device-specific thunks.
26
27use rlx_ir::dynamic::collect_dynamic_symbols;
28use rlx_ir::hir::HirModule;
29use rlx_ir::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest, LirModule, LirViewAlias};
30use rlx_ir::mir::MirModule;
31use rlx_ir::phase::derive_phases;
32use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
33
34use crate::DeadCodeElimination;
35use crate::debug_assert_graph;
36use crate::fusion_pipeline::{
37    FusionOptions, FusionTarget, fusion_limits_for_target, fusion_passes_for_supported,
38    supported_for_target,
39};
40use crate::fusion_target::with_fusion_target;
41use crate::legalize::{format_legalize_error, legalize_for_backend};
42use crate::memory::{self, MemoryPlan};
43use crate::rewrite::rewrite_for_backend_with_config;
44use rlx_fusion::fusion_report::FusionReport;
45use rlx_fusion::pass::run_passes;
46use rlx_fusion::{clip_elementwise_regions, with_fusion_limits};
47use rlx_ir::OpKind;
48use rlx_ir::logical_kernel::KernelDispatchConfig;
49
50/// End-to-end compiler output: optimized LIR + fusion diagnostics.
51#[derive(Debug, Clone)]
52pub struct CompileResult {
53    pub lir: LirModule,
54    pub fusion: FusionReport,
55}
56
57impl CompileResult {
58    pub fn has_dynamic_dims(&self) -> bool {
59        self.lir.has_dynamic_dims()
60    }
61
62    pub fn dynamic_symbols(&self) -> &[u32] {
63        self.lir.dynamic_symbols()
64    }
65
66    /// Re-plan buffers after binding symbolic dims to concrete sizes.
67    pub fn specialize(&self, pipeline: &CompilePipeline, binding: &rlx_ir::DimBinding) -> Self {
68        Self {
69            lir: pipeline.specialize_lir(&self.lir, binding),
70            fusion: self.fusion.clone(),
71        }
72    }
73}
74
75/// End-to-end compiler pipeline configuration.
76#[derive(Debug, Clone, Copy)]
77pub struct CompilePipeline {
78    pub target: FusionTarget,
79    pub opts: FusionOptions,
80    pub arena_alignment: usize,
81    /// When true, [`compile_hir`] / [`compile_graph`] panic if fusion
82    /// diagnostics report missed block-level patterns.
83    pub assert_fusion_clean: bool,
84    /// Backend op claim set. When `Some` and non-empty, fusion passes
85    /// are gated on these kinds and the optimized graph is legalized
86    /// afterward. When `None`, [`supported_for_target`] is used.
87    pub supported_ops: Option<&'static [OpKind]>,
88    /// Native vs common IR lowering for logical kernels (see `rlx_ir::logical_kernel`).
89    pub kernel_dispatch: KernelDispatchConfig,
90}
91
92impl Default for CompilePipeline {
93    fn default() -> Self {
94        Self {
95            target: FusionTarget::Cpu,
96            opts: FusionOptions::for_cpu(),
97            arena_alignment: 64,
98            assert_fusion_clean: false,
99            supported_ops: None,
100            kernel_dispatch: KernelDispatchConfig::from_env(),
101        }
102    }
103}
104
105fn lstm_y_shape(x: &rlx_ir::Shape, hidden_size: usize, bidirectional: bool) -> rlx_ir::Shape {
106    let dirs = if bidirectional { 2 } else { 1 };
107    if x.rank() == 3 {
108        let seq = x.dim(0).unwrap_static();
109        let batch = x.dim(1).unwrap_static().max(1);
110        return rlx_ir::Shape::new(&[seq, dirs, batch, hidden_size], x.dtype());
111    }
112    rlx_ir::Shape::new(&[1, dirs, 1, hidden_size], x.dtype())
113}
114
115/// `sync_graph_shapes` can collapse `[seq,1,C]` LSTM inputs to `[1,1,C]`; restore seq.
116fn fix_import_lstm_x_shape(x: &rlx_ir::Shape) -> rlx_ir::Shape {
117    if x.rank() != 3 {
118        return x.clone();
119    }
120    let d0 = x.dim(0).unwrap_static();
121    let d1 = x.dim(1).unwrap_static();
122    let d2 = x.dim(2).unwrap_static();
123    if d0 == 1 && d1 <= 1 && (d2 == 640 || d2 == 512) {
124        let seq = std::env::var("RLX_ONNX_SEQUENCE_LENGTH")
125            .ok()
126            .and_then(|s| s.parse().ok())
127            .unwrap_or(128);
128        return rlx_ir::Shape::new(&[seq, d1.max(1), d2], x.dtype());
129    }
130    x.clone()
131}
132
133fn fix_lstm_output_shapes(graph: &mut Graph) {
134    use rlx_ir::Op;
135    let ids: Vec<NodeId> = graph.nodes().iter().map(|n| n.id).collect();
136    for id in ids {
137        let node = graph.node(id).clone();
138        let Op::Custom { name, attrs, .. } = &node.op else {
139            continue;
140        };
141        if !name.contains("LSTM") {
142            continue;
143        }
144        let hidden_size = if attrs.len() >= 4 {
145            u32::from_le_bytes(attrs[0..4].try_into().unwrap()) as usize
146        } else {
147            256
148        };
149        let bidirectional = attrs.len() > 4 && attrs[4] != 0;
150        let x_id = node.inputs[0];
151        let x = fix_import_lstm_x_shape(&graph.node(x_id).shape);
152        graph.node_mut(x_id).shape = x.clone();
153        graph.node_mut(id).shape = lstm_y_shape(&x, hidden_size, bidirectional);
154    }
155}
156
157/// `sync_graph_shapes` can collapse `[1, seq, C]` activations to `[1, 1, C]`
158/// when seq>1; restore from `RLX_ONNX_SEQUENCE_LENGTH` and propagate once.
159///
160/// Only runs when `RLX_ONNX_SEQUENCE_LENGTH` is set explicitly — decode graphs such as
161/// Qwen3 talker use legitimate `[1, 1, H]` hidden states and must not be expanded.
162fn fix_import_sequence_axis(graph: &mut Graph) {
163    let Ok(seq_str) = std::env::var("RLX_ONNX_SEQUENCE_LENGTH") else {
164        return;
165    };
166    let seq: usize = match seq_str.parse() {
167        Ok(s) if s > 1 => s,
168        _ => return,
169    };
170    for id in graph.nodes().iter().map(|n| n.id).collect::<Vec<_>>() {
171        let node = graph.node(id);
172        if node.shape.rank() != 3 {
173            continue;
174        }
175        let dims: Vec<_> = node
176            .shape
177            .dims()
178            .iter()
179            .map(|d| d.unwrap_static())
180            .collect();
181        if dims[0] == 1 && dims[1] == 1 && dims[2] >= 64 {
182            graph.node_mut(id).shape = rlx_ir::Shape::new(&[1, seq, dims[2]], node.shape.dtype());
183        }
184    }
185    for id in graph.topo_order().collect::<Vec<_>>() {
186        let node = graph.node(id).clone();
187        if let Some(shape) = rlx_ir::infer_shape::infer_output_shape(graph, &node) {
188            graph.node_mut(id).shape = shape;
189        }
190    }
191}
192
193impl CompilePipeline {
194    pub fn new(target: FusionTarget) -> Self {
195        let mut opts = match target {
196            FusionTarget::Cpu => FusionOptions::for_cpu(),
197            FusionTarget::Metal => FusionOptions::for_metal(),
198            FusionTarget::Wgpu => FusionOptions::for_wgpu(),
199            _ => FusionOptions::default(),
200        };
201        opts.fusion_limits = fusion_limits_for_target(target);
202        Self {
203            target,
204            opts,
205            ..Self::default()
206        }
207    }
208
209    pub fn with_assert_fusion_clean(mut self, assert: bool) -> Self {
210        self.assert_fusion_clean = assert;
211        self
212    }
213
214    /// HIR → MIR (block lowering only).
215    pub fn lower_hir(hir: HirModule) -> Result<MirModule, rlx_ir::hir::LowerError> {
216        let mut mir = hir.lower_to_mir()?;
217        rlx_ir::dynamic::sync_graph_shapes(mir.as_graph_mut());
218        debug_assert_graph!(mir.as_graph(), "hir→mir");
219        Ok(mir)
220    }
221
222    /// Optional cleanup before fusion (DCE + control-flow lowering).
223    pub fn preprocess_mir(mir: MirModule) -> MirModule {
224        use rlx_fusion::pass::Pass as _;
225        let graph = rlx_fusion::control_flow::LowerControlFlow.run(mir.into_graph());
226        let graph = DeadCodeElimination.run(graph);
227        MirModule::from_graph(graph)
228    }
229
230    pub fn with_supported_ops(mut self, ops: &'static [OpKind]) -> Self {
231        self.supported_ops = Some(ops);
232        self
233    }
234
235    pub fn with_kernel_dispatch(
236        mut self,
237        policy: rlx_ir::logical_kernel::KernelDispatchPolicy,
238    ) -> Self {
239        self.kernel_dispatch.policy = policy;
240        self
241    }
242
243    pub fn with_kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
244        self.kernel_dispatch = config;
245        self
246    }
247
248    fn effective_supported(&self) -> &'static [OpKind] {
249        self.supported_ops
250            .unwrap_or_else(|| supported_for_target(self.target))
251    }
252
253    fn backend_name(&self) -> &'static str {
254        match self.target {
255            FusionTarget::Cpu => "cpu",
256            FusionTarget::Metal => "metal",
257            FusionTarget::Mlx => "mlx",
258            FusionTarget::Wgpu => "wgpu",
259            FusionTarget::Cuda => "cuda",
260            FusionTarget::Rocm => "rocm",
261            FusionTarget::Tpu => "tpu",
262        }
263    }
264
265    /// Run fusion + cleanup passes on MIR, returning fusion diagnostics.
266    pub fn optimize_with_report(&self, mir: MirModule) -> (MirModule, FusionReport) {
267        let need_fusion_diff = self.assert_fusion_clean || rlx_ir::env::flag("RLX_FUSION_REPORT");
268        let before = need_fusion_diff.then(|| mir.as_graph().clone());
269        let passes =
270            fusion_passes_for_supported(self.effective_supported(), self.opts, self.target);
271        let limits = self.opts.fusion_limits;
272        let graph = with_fusion_target(self.target, || {
273            with_fusion_limits(limits, || run_passes(mir.into_graph(), &passes, false))
274        });
275        let graph = clip_elementwise_regions(graph, limits);
276        debug_assert_graph!(&graph, "fusion");
277        let mut graph = self.legalize_after_fusion(graph);
278        rlx_ir::dynamic::sync_graph_shapes(&mut graph);
279        fix_import_sequence_axis(&mut graph);
280        fix_lstm_output_shapes(&mut graph);
281        debug_assert_graph!(&graph, "legalize");
282        let mir = MirModule::from_graph(graph);
283        let fusion = if let Some(ref before) = before {
284            rlx_fusion::FusionReport::analyze(before, mir.as_graph())
285        } else {
286            rlx_fusion::FusionReport::scan(mir.as_graph())
287        };
288        (mir, fusion)
289    }
290
291    /// Rewrite / legalize fused IR against the backend op claim set.
292    /// Runs when [`supported_ops`](Self::supported_ops) is set (including
293    /// auto-wiring from [`Backend::supported_ops`] in [`crate::stages::pipeline_for`]).
294    pub(crate) fn legalize_after_fusion(&self, graph: Graph) -> Graph {
295        let Some(supported) = self.supported_ops else {
296            if self.kernel_dispatch.force_common_kinds.is_empty()
297                && self.kernel_dispatch.policy
298                    == rlx_ir::logical_kernel::KernelDispatchPolicy::PreferNative
299            {
300                return graph;
301            }
302            return rewrite_for_backend_with_config(graph, &[], self.kernel_dispatch);
303        };
304        if supported.is_empty() {
305            return graph;
306        }
307        let graph = rewrite_for_backend_with_config(graph, supported, self.kernel_dispatch);
308        if let Err(errors) = legalize_for_backend(&graph, supported) {
309            panic!("{}", format_legalize_error(self.backend_name(), &errors));
310        }
311        graph
312    }
313
314    /// Run fusion + cleanup passes on MIR.
315    pub fn optimize(&self, mir: MirModule) -> MirModule {
316        self.optimize_with_report(mir).0
317    }
318
319    /// MIR → LIR (memory plan + schedule + phases + I/O manifest).
320    pub fn plan_lir(&self, mir: MirModule) -> LirModule {
321        self.plan_lir_with_options(mir, memory::MemoryPlanOptions::default())
322    }
323
324    /// MIR → LIR with explicit boundary allocation policy.
325    pub fn plan_lir_with_options(
326        &self,
327        mir: MirModule,
328        opts: memory::MemoryPlanOptions,
329    ) -> LirModule {
330        let graph = mir.as_graph();
331        let plan = memory::plan_memory_with_options(graph, self.arena_alignment, opts);
332        let buffers = lir_buffer_plan_from_memory(graph, &plan, self.arena_alignment);
333        LirModule::new(mir, buffers)
334    }
335
336    /// Bind symbolic dims and re-run buffer planning on specialized MIR.
337    pub fn specialize_lir(&self, lir: &LirModule, binding: &rlx_ir::DimBinding) -> LirModule {
338        use rlx_ir::dynamic::{
339            bind_graph, sync_concat_shapes, sync_expand_ops, sync_graph_shapes, sync_narrow_ops,
340            sync_reshape_ops,
341        };
342        let mut bound = bind_graph(lir.as_graph(), binding);
343        sync_reshape_ops(&mut bound);
344        sync_concat_shapes(&mut bound);
345        sync_narrow_ops(&mut bound);
346        sync_expand_ops(&mut bound);
347        sync_graph_shapes(&mut bound);
348        debug_assert_graph!(&bound, "specialize");
349        self.plan_lir(MirModule::from_graph(bound))
350    }
351
352    fn finish(&self, mir: MirModule, fusion: FusionReport) -> CompileResult {
353        debug_assert_graph!(mir.as_graph(), "pre-lir");
354        if self.assert_fusion_clean && !fusion.missed.is_empty() {
355            panic!(
356                "fusion contract violated: {} missed patterns\n{fusion}",
357                fusion.missed.len()
358            );
359        }
360        CompileResult {
361            lir: self.plan_lir(mir),
362            fusion,
363        }
364    }
365
366    /// HIR → LIR in one call with fusion report.
367    pub fn compile_hir(&self, hir: HirModule) -> Result<CompileResult, rlx_ir::hir::LowerError> {
368        if rlx_ir::env::var("RLX_IR_DUMP").is_some() {
369            let name = hir.name.clone();
370            let dump = crate::inspect::inspect_pipeline(self, hir.clone())?;
371            crate::inspect::maybe_dump_pipeline(&dump, &name);
372        }
373        let mir = Self::lower_hir(hir)?;
374        let (mir, fusion) = self.optimize_with_report(mir);
375        Ok(self.finish(mir, fusion))
376    }
377
378    /// Legacy MIR entry: optimize + plan with fusion report.
379    pub fn compile_mir(&self, mir: MirModule) -> CompileResult {
380        let (mir, fusion) = self.optimize_with_report(mir);
381        self.finish(mir, fusion)
382    }
383
384    /// Legacy entry: optimize an existing graph and plan buffers.
385    pub fn compile_graph(&self, graph: Graph) -> CompileResult {
386        self.compile_mir(MirModule::from_graph(graph))
387    }
388
389    /// Unified entry for [`GraphModule`] at any pipeline stage.
390    pub fn compile_module(
391        &self,
392        module: GraphModule,
393    ) -> Result<CompileResult, rlx_ir::hir::LowerError> {
394        match module.stage() {
395            GraphStage::Hir => {
396                let hir = module
397                    .into_hir()
398                    .expect("GraphModule stage() / into_hir mismatch");
399                self.compile_hir(hir)
400            }
401            GraphStage::Mir => {
402                let mir = module.into_mir()?;
403                Ok(self.compile_mir(mir))
404            }
405            GraphStage::Lir => Ok(CompileResult {
406                lir: module
407                    .into_lir()
408                    .expect("GraphModule stage() / into_lir mismatch"),
409                fusion: FusionReport::default(),
410            }),
411        }
412    }
413}
414
415impl From<&MemoryPlan> for LirBufferPlan {
416    fn from(plan: &MemoryPlan) -> Self {
417        LirBufferPlan {
418            arena_size: plan.arena_size,
419            assignments: plan
420                .assignments
421                .iter()
422                .map(|(id, slot)| {
423                    (
424                        *id,
425                        LirBufferSlot {
426                            offset: slot.offset,
427                            size: slot.size,
428                        },
429                    )
430                })
431                .collect(),
432            schedule: plan.schedule.clone(),
433            ..Default::default()
434        }
435    }
436}
437
438impl From<&LirBufferPlan> for MemoryPlan {
439    fn from(plan: &LirBufferPlan) -> Self {
440        MemoryPlan {
441            arena_size: plan.arena_size,
442            assignments: plan
443                .assignments
444                .iter()
445                .map(|(id, slot)| {
446                    (
447                        *id,
448                        memory::BufferSlot {
449                            offset: slot.offset,
450                            size: slot.size,
451                        },
452                    )
453                })
454                .collect(),
455            schedule: plan.schedule.clone(),
456        }
457    }
458}
459
460pub(crate) fn lir_buffer_plan_from_memory(
461    graph: &Graph,
462    plan: &MemoryPlan,
463    alignment: usize,
464) -> LirBufferPlan {
465    let view_aliases = memory::collect_view_aliases(graph)
466        .into_iter()
467        .map(|(id, (root, byte_offset))| (id, LirViewAlias { root, byte_offset }))
468        .collect();
469    LirBufferPlan {
470        arena_size: plan.arena_size,
471        assignments: plan
472            .assignments
473            .iter()
474            .map(|(id, slot)| {
475                (
476                    *id,
477                    LirBufferSlot {
478                        offset: slot.offset,
479                        size: slot.size,
480                    },
481                )
482            })
483            .collect(),
484        schedule: plan.schedule.clone(),
485        view_aliases,
486        phases: derive_phases(graph),
487        io: LirIoManifest::collect(graph),
488        alignment,
489        dynamic_symbols: collect_dynamic_symbols(graph),
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Op;
    use rlx_ir::Shape;
    use rlx_ir::hir::FusionPolicy;

    /// Shorthand for an `F32` shape with static dims `d`.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// A single `linear` layer compiles end to end into a small LIR with a
    /// non-empty arena and consistent buffer-sharing statistics.
    #[test]
    fn pipeline_hir_to_lir() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let b = hir.param("b", f32_shape(&[128]));
        let h = hir.linear(x, w, Some(b), None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        assert!(result.lir.mir.len() <= 5);
        assert!(result.lir.arena_size() > 0);
        assert!(result.lir.buffers.bytes_saved() <= result.lir.buffers.total_unshared_bytes());
        // NOTE(review): given the `len() <= 5` assertion above, this disjunction
        // always holds; it only documents the expected matmul+bias+act fusion.
        assert!(result.fusion.fused_matmul_bias_act >= 1 || result.lir.mir.len() <= 5);
    }

    /// A HIR-level SwiGLU FFN lowers directly to `Op::FusedSwiGLU`.
    #[test]
    fn direct_hir_swiglu_emits_fused_op() {
        let mut hir = HirModule::new("ffn");
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        let g = result.lir.mir.as_graph();
        assert!(
            g.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "direct HIR SwiGLU should lower to FusedSwiGLU"
        );
        assert!(result.fusion.missed_matmul_bias_act() == 0 || result.fusion.fused_swiglu >= 1);
    }

    /// `GraphModule::define` produces a HIR-stage module that `compile_module`
    /// routes through the HIR path.
    #[test]
    fn compile_module_from_graph_define() {
        let module = GraphModule::define("ffn", |m| {
            let x = m.input("x", f32_shape(&[2, 64]));
            let w = m.param("w", f32_shape(&[64, 64]));
            m.linear(x, w, None, None, f32_shape(&[2, 64]))
        });
        assert_eq!(module.stage(), GraphStage::Hir);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_module(module).expect("compile_module");
        assert!(result.lir.arena_size() > 0);
    }

    /// With `FusionPolicy::Fusable`, lowering keeps primitive ops so the fusion
    /// passes (not the lowering) produce the SwiGLU fusion.
    #[test]
    fn fusable_policy_leaves_room_for_passes() {
        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let mir = CompilePipeline::lower_hir(hir).expect("lower");
        let g = mir.as_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        assert_eq!(g.len(), 9);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_mir(mir);
        assert!(result.fusion.fused_swiglu >= 1);
    }

    /// The LIR buffer plan records phases, the I/O manifest, the alignment,
    /// and a fingerprint that is deterministic across calls.
    #[test]
    fn lir_plan_includes_phases_io_and_fingerprint() {
        use rlx_ir::phase::Phase;

        let mut hir = HirModule::new("stream");
        let x = hir.input("x", f32_shape(&[1, 8]));
        let w = hir.param("w", f32_shape(&[8, 4]));
        let mm = hir.linear(x, w, None, None, f32_shape(&[1, 4]));
        hir.set_outputs(vec![mm]);

        let result = CompilePipeline::new(FusionTarget::Cpu)
            .compile_hir(hir)
            .expect("compile");
        assert!(!result.lir.buffers.phases.is_empty());
        // Check the input count before indexing so a regression reports a
        // failed assertion rather than an out-of-bounds panic.
        assert_eq!(result.lir.buffers.io.inputs.len(), 1);
        let input_id = result.lir.buffers.io.inputs[0].1;
        assert_eq!(
            result.lir.buffers.phases.get(input_id),
            Some(Phase::Prologue)
        );
        assert_eq!(result.lir.fingerprint(), result.lir.fingerprint());
        assert_eq!(result.lir.buffers.alignment, 64);
    }

    #[test]
    fn decode_hidden_shape_not_expanded_without_env() {
        // Qwen3 talker decode uses [1, 1, H] hidden states; must not be expanded to
        // [1, RLX_ONNX_SEQUENCE_LENGTH, H] unless that env is set explicitly.
        // This test pins the "env unset" behavior; if the ambient environment
        // sets the override, expansion is expected, so skip instead of failing.
        if std::env::var_os("RLX_ONNX_SEQUENCE_LENGTH").is_some() {
            return;
        }
        let mut g = Graph::new("decode_out");
        let x = g.input("x", f32_shape(&[1, 1, 1024]));
        g.set_outputs(vec![x]);
        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        let out = result
            .lir
            .mir
            .as_graph()
            .node(result.lir.mir.as_graph().outputs[0]);
        assert_eq!(out.shape.dims()[1].unwrap_static(), 1);
        assert_eq!(out.shape.num_elements(), Some(1024));
    }

    /// A graph with symbolic batch/seq dims compiles with dynamic dims recorded
    /// and specializes to a fully static LIR.
    #[test]
    fn dynamic_graph_compiles_and_specializes() {
        use rlx_ir::DimBinding;
        use rlx_ir::infer::GraphExt as _;
        use rlx_ir::sym;

        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        assert!(result.has_dynamic_dims());
        assert!(result.lir.buffers.dynamic_symbols.contains(&sym::SEQ));

        let bound = result.specialize(&pipe, &DimBinding::batch_seq(2, 16));
        assert!(bound.lir.is_fully_static());
        assert!(bound.lir.arena_size() > 0);
    }
}