Skip to main content

rlx_compile/
compiler.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! HIR → MIR → LIR compiler pipeline.
//!
//! Explicit staging for the RLX compiler:
//!
//! ```text
//! HIR (blocks)  ──lower──▶  MIR (tensor DAG)  ──opt──▶  MIR  ──plan──▶  LIR
//! ```
//!
//! Backends consume [`CompileResult`] / [`LirModule`] (optimized MIR +
//! buffer plan + fusion report) and lower to device-specific thunks.
26
27use rlx_ir::dynamic::collect_dynamic_symbols;
28use rlx_ir::hir::HirModule;
29use rlx_ir::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest, LirModule, LirViewAlias};
30use rlx_ir::mir::MirModule;
31use rlx_ir::phase::derive_phases;
32use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
33
34use crate::DeadCodeElimination;
35use crate::debug_assert_graph;
36use crate::fusion_pipeline::{
37    FusionOptions, FusionTarget, fusion_limits_for_target, fusion_passes_for_supported,
38    supported_for_target,
39};
40use crate::legalize::{format_legalize_error, legalize_for_backend};
41use crate::memory::{self, MemoryPlan};
42use crate::rewrite::rewrite_for_backend_with_config;
43use rlx_fusion::fusion_report::FusionReport;
44use rlx_fusion::pass::run_passes;
45use rlx_fusion::{clip_elementwise_regions, with_fusion_limits};
46use rlx_ir::OpKind;
47use rlx_ir::logical_kernel::KernelDispatchConfig;
48
49/// End-to-end compiler output: optimized LIR + fusion diagnostics.
50#[derive(Debug, Clone)]
51pub struct CompileResult {
52    pub lir: LirModule,
53    pub fusion: FusionReport,
54}
55
56impl CompileResult {
57    pub fn has_dynamic_dims(&self) -> bool {
58        self.lir.has_dynamic_dims()
59    }
60
61    pub fn dynamic_symbols(&self) -> &[u32] {
62        self.lir.dynamic_symbols()
63    }
64
65    /// Re-plan buffers after binding symbolic dims to concrete sizes.
66    pub fn specialize(&self, pipeline: &CompilePipeline, binding: &rlx_ir::DimBinding) -> Self {
67        Self {
68            lir: pipeline.specialize_lir(&self.lir, binding),
69            fusion: self.fusion.clone(),
70        }
71    }
72}
73
74/// End-to-end compiler pipeline configuration.
75#[derive(Debug, Clone, Copy)]
76pub struct CompilePipeline {
77    pub target: FusionTarget,
78    pub opts: FusionOptions,
79    pub arena_alignment: usize,
80    /// When true, [`compile_hir`] / [`compile_graph`] panic if fusion
81    /// diagnostics report missed block-level patterns.
82    pub assert_fusion_clean: bool,
83    /// Backend op claim set. When `Some` and non-empty, fusion passes
84    /// are gated on these kinds and the optimized graph is legalized
85    /// afterward. When `None`, [`supported_for_target`] is used.
86    pub supported_ops: Option<&'static [OpKind]>,
87    /// Native vs common IR lowering for logical kernels (see `rlx_ir::logical_kernel`).
88    pub kernel_dispatch: KernelDispatchConfig,
89}
90
91impl Default for CompilePipeline {
92    fn default() -> Self {
93        Self {
94            target: FusionTarget::Cpu,
95            opts: FusionOptions::for_cpu(),
96            arena_alignment: 64,
97            assert_fusion_clean: false,
98            supported_ops: None,
99            kernel_dispatch: KernelDispatchConfig::from_env(),
100        }
101    }
102}
103
104fn lstm_y_shape(x: &rlx_ir::Shape, hidden_size: usize, bidirectional: bool) -> rlx_ir::Shape {
105    let dirs = if bidirectional { 2 } else { 1 };
106    if x.rank() == 3 {
107        let seq = x.dim(0).unwrap_static();
108        let batch = x.dim(1).unwrap_static().max(1);
109        return rlx_ir::Shape::new(&[seq, dirs, batch, hidden_size], x.dtype());
110    }
111    rlx_ir::Shape::new(&[1, dirs, 1, hidden_size], x.dtype())
112}
113
114/// `sync_graph_shapes` can collapse `[seq,1,C]` LSTM inputs to `[1,1,C]`; restore seq.
115fn fix_import_lstm_x_shape(x: &rlx_ir::Shape) -> rlx_ir::Shape {
116    if x.rank() != 3 {
117        return x.clone();
118    }
119    let d0 = x.dim(0).unwrap_static();
120    let d1 = x.dim(1).unwrap_static();
121    let d2 = x.dim(2).unwrap_static();
122    if d0 == 1 && d1 <= 1 && (d2 == 640 || d2 == 512) {
123        let seq = std::env::var("RLX_ONNX_SEQUENCE_LENGTH")
124            .ok()
125            .and_then(|s| s.parse().ok())
126            .unwrap_or(128);
127        return rlx_ir::Shape::new(&[seq, d1.max(1), d2], x.dtype());
128    }
129    x.clone()
130}
131
132fn fix_lstm_output_shapes(graph: &mut Graph) {
133    use rlx_ir::Op;
134    let ids: Vec<NodeId> = graph.nodes().iter().map(|n| n.id).collect();
135    for id in ids {
136        let node = graph.node(id).clone();
137        let Op::Custom { name, attrs, .. } = &node.op else {
138            continue;
139        };
140        if !name.contains("LSTM") {
141            continue;
142        }
143        let hidden_size = if attrs.len() >= 4 {
144            u32::from_le_bytes(attrs[0..4].try_into().unwrap()) as usize
145        } else {
146            256
147        };
148        let bidirectional = attrs.len() > 4 && attrs[4] != 0;
149        let x_id = node.inputs[0];
150        let x = fix_import_lstm_x_shape(&graph.node(x_id).shape);
151        graph.node_mut(x_id).shape = x.clone();
152        graph.node_mut(id).shape = lstm_y_shape(&x, hidden_size, bidirectional);
153    }
154}
155
156/// `sync_graph_shapes` can collapse `[1, seq, C]` activations to `[1, 1, C]`
157/// when seq>1; restore from `RLX_ONNX_SEQUENCE_LENGTH` and propagate once.
158///
159/// Only runs when `RLX_ONNX_SEQUENCE_LENGTH` is set explicitly — decode graphs such as
160/// Qwen3 talker use legitimate `[1, 1, H]` hidden states and must not be expanded.
161fn fix_import_sequence_axis(graph: &mut Graph) {
162    let Ok(seq_str) = std::env::var("RLX_ONNX_SEQUENCE_LENGTH") else {
163        return;
164    };
165    let seq: usize = match seq_str.parse() {
166        Ok(s) if s > 1 => s,
167        _ => return,
168    };
169    for id in graph.nodes().iter().map(|n| n.id).collect::<Vec<_>>() {
170        let node = graph.node(id);
171        if node.shape.rank() != 3 {
172            continue;
173        }
174        let dims: Vec<_> = node
175            .shape
176            .dims()
177            .iter()
178            .map(|d| d.unwrap_static())
179            .collect();
180        if dims[0] == 1 && dims[1] == 1 && dims[2] >= 64 {
181            graph.node_mut(id).shape = rlx_ir::Shape::new(&[1, seq, dims[2]], node.shape.dtype());
182        }
183    }
184    for id in graph.topo_order().collect::<Vec<_>>() {
185        let node = graph.node(id).clone();
186        if let Some(shape) = rlx_ir::infer_shape::infer_output_shape(graph, &node) {
187            graph.node_mut(id).shape = shape;
188        }
189    }
190}
191
192impl CompilePipeline {
193    pub fn new(target: FusionTarget) -> Self {
194        let mut opts = match target {
195            FusionTarget::Cpu => FusionOptions::for_cpu(),
196            FusionTarget::Metal => FusionOptions::for_metal(),
197            FusionTarget::Wgpu => FusionOptions::for_wgpu(),
198            _ => FusionOptions::default(),
199        };
200        opts.fusion_limits = fusion_limits_for_target(target);
201        Self {
202            target,
203            opts,
204            ..Self::default()
205        }
206    }
207
208    pub fn with_assert_fusion_clean(mut self, assert: bool) -> Self {
209        self.assert_fusion_clean = assert;
210        self
211    }
212
213    /// HIR → MIR (block lowering only).
214    pub fn lower_hir(hir: HirModule) -> Result<MirModule, rlx_ir::hir::LowerError> {
215        let mut mir = hir.lower_to_mir()?;
216        rlx_ir::dynamic::sync_graph_shapes(mir.as_graph_mut());
217        debug_assert_graph!(mir.as_graph(), "hir→mir");
218        Ok(mir)
219    }
220
221    /// Optional cleanup before fusion (DCE + control-flow lowering).
222    pub fn preprocess_mir(mir: MirModule) -> MirModule {
223        use rlx_fusion::pass::Pass as _;
224        let graph = rlx_fusion::control_flow::LowerControlFlow.run(mir.into_graph());
225        let graph = DeadCodeElimination.run(graph);
226        MirModule::from_graph(graph)
227    }
228
229    pub fn with_supported_ops(mut self, ops: &'static [OpKind]) -> Self {
230        self.supported_ops = Some(ops);
231        self
232    }
233
234    pub fn with_kernel_dispatch(
235        mut self,
236        policy: rlx_ir::logical_kernel::KernelDispatchPolicy,
237    ) -> Self {
238        self.kernel_dispatch.policy = policy;
239        self
240    }
241
242    pub fn with_kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
243        self.kernel_dispatch = config;
244        self
245    }
246
247    fn effective_supported(&self) -> &'static [OpKind] {
248        self.supported_ops
249            .unwrap_or_else(|| supported_for_target(self.target))
250    }
251
252    fn backend_name(&self) -> &'static str {
253        match self.target {
254            FusionTarget::Cpu => "cpu",
255            FusionTarget::Metal => "metal",
256            FusionTarget::Mlx => "mlx",
257            FusionTarget::Wgpu => "wgpu",
258            FusionTarget::Cuda => "cuda",
259            FusionTarget::Rocm => "rocm",
260            FusionTarget::Tpu => "tpu",
261        }
262    }
263
264    /// Run fusion + cleanup passes on MIR, returning fusion diagnostics.
265    pub fn optimize_with_report(&self, mir: MirModule) -> (MirModule, FusionReport) {
266        let before = mir.as_graph().clone();
267        let passes =
268            fusion_passes_for_supported(self.effective_supported(), self.opts, self.target);
269        let limits = self.opts.fusion_limits;
270        let graph = with_fusion_limits(limits, || run_passes(mir.into_graph(), &passes, false));
271        let graph = clip_elementwise_regions(graph, limits);
272        debug_assert_graph!(&graph, "fusion");
273        let mut graph = self.legalize_after_fusion(graph);
274        rlx_ir::dynamic::sync_graph_shapes(&mut graph);
275        fix_import_sequence_axis(&mut graph);
276        fix_lstm_output_shapes(&mut graph);
277        debug_assert_graph!(&graph, "legalize");
278        let mir = MirModule::from_graph(graph);
279        let fusion = FusionReport::analyze(&before, mir.as_graph());
280        (mir, fusion)
281    }
282
283    /// Rewrite / legalize fused IR against the backend op claim set.
284    /// Runs when [`supported_ops`](Self::supported_ops) is set (including
285    /// auto-wiring from [`Backend::supported_ops`] in [`crate::stages::pipeline_for`]).
286    pub(crate) fn legalize_after_fusion(&self, graph: Graph) -> Graph {
287        let Some(supported) = self.supported_ops else {
288            if self.kernel_dispatch.force_common_kinds.is_empty()
289                && self.kernel_dispatch.policy
290                    == rlx_ir::logical_kernel::KernelDispatchPolicy::PreferNative
291            {
292                return graph;
293            }
294            return rewrite_for_backend_with_config(graph, &[], self.kernel_dispatch);
295        };
296        if supported.is_empty() {
297            return graph;
298        }
299        let graph = rewrite_for_backend_with_config(graph, supported, self.kernel_dispatch);
300        if let Err(errors) = legalize_for_backend(&graph, supported) {
301            panic!("{}", format_legalize_error(self.backend_name(), &errors));
302        }
303        graph
304    }
305
306    /// Run fusion + cleanup passes on MIR.
307    pub fn optimize(&self, mir: MirModule) -> MirModule {
308        self.optimize_with_report(mir).0
309    }
310
311    /// MIR → LIR (memory plan + schedule + phases + I/O manifest).
312    pub fn plan_lir(&self, mir: MirModule) -> LirModule {
313        self.plan_lir_with_options(mir, memory::MemoryPlanOptions::default())
314    }
315
316    /// MIR → LIR with explicit boundary allocation policy.
317    pub fn plan_lir_with_options(
318        &self,
319        mir: MirModule,
320        opts: memory::MemoryPlanOptions,
321    ) -> LirModule {
322        let graph = mir.as_graph().clone();
323        let plan = memory::plan_memory_with_options(&graph, self.arena_alignment, opts);
324        LirModule::new(
325            mir,
326            lir_buffer_plan_from_memory(&graph, &plan, self.arena_alignment),
327        )
328    }
329
330    /// Bind symbolic dims and re-run buffer planning on specialized MIR.
331    pub fn specialize_lir(&self, lir: &LirModule, binding: &rlx_ir::DimBinding) -> LirModule {
332        use rlx_ir::dynamic::{
333            bind_graph, sync_concat_shapes, sync_expand_ops, sync_graph_shapes, sync_narrow_ops,
334            sync_reshape_ops,
335        };
336        let mut bound = bind_graph(lir.as_graph(), binding);
337        sync_reshape_ops(&mut bound);
338        sync_concat_shapes(&mut bound);
339        sync_narrow_ops(&mut bound);
340        sync_expand_ops(&mut bound);
341        sync_graph_shapes(&mut bound);
342        debug_assert_graph!(&bound, "specialize");
343        self.plan_lir(MirModule::from_graph(bound))
344    }
345
346    fn finish(&self, mir: MirModule, fusion: FusionReport) -> CompileResult {
347        debug_assert_graph!(mir.as_graph(), "pre-lir");
348        if self.assert_fusion_clean && !fusion.missed.is_empty() {
349            panic!(
350                "fusion contract violated: {} missed patterns\n{fusion}",
351                fusion.missed.len()
352            );
353        }
354        CompileResult {
355            lir: self.plan_lir(mir),
356            fusion,
357        }
358    }
359
360    /// HIR → LIR in one call with fusion report.
361    pub fn compile_hir(&self, hir: HirModule) -> Result<CompileResult, rlx_ir::hir::LowerError> {
362        if rlx_ir::env::var("RLX_IR_DUMP").is_some() {
363            let name = hir.name.clone();
364            let dump = crate::inspect::inspect_pipeline(self, hir.clone())?;
365            crate::inspect::maybe_dump_pipeline(&dump, &name);
366        }
367        let mir = Self::lower_hir(hir)?;
368        let (mir, fusion) = self.optimize_with_report(mir);
369        Ok(self.finish(mir, fusion))
370    }
371
372    /// Legacy MIR entry: optimize + plan with fusion report.
373    pub fn compile_mir(&self, mir: MirModule) -> CompileResult {
374        let (mir, fusion) = self.optimize_with_report(mir);
375        self.finish(mir, fusion)
376    }
377
378    /// Legacy entry: optimize an existing graph and plan buffers.
379    pub fn compile_graph(&self, graph: Graph) -> CompileResult {
380        self.compile_mir(MirModule::from_graph(graph))
381    }
382
383    /// Unified entry for [`GraphModule`] at any pipeline stage.
384    pub fn compile_module(
385        &self,
386        module: GraphModule,
387    ) -> Result<CompileResult, rlx_ir::hir::LowerError> {
388        match module.stage() {
389            GraphStage::Hir => {
390                let hir = module
391                    .into_hir()
392                    .expect("GraphModule stage() / into_hir mismatch");
393                self.compile_hir(hir)
394            }
395            GraphStage::Mir => {
396                let mir = module.into_mir()?;
397                Ok(self.compile_mir(mir))
398            }
399            GraphStage::Lir => Ok(CompileResult {
400                lir: module
401                    .into_lir()
402                    .expect("GraphModule stage() / into_lir mismatch"),
403                fusion: FusionReport::default(),
404            }),
405        }
406    }
407}
408
409impl From<&MemoryPlan> for LirBufferPlan {
410    fn from(plan: &MemoryPlan) -> Self {
411        LirBufferPlan {
412            arena_size: plan.arena_size,
413            assignments: plan
414                .assignments
415                .iter()
416                .map(|(id, slot)| {
417                    (
418                        *id,
419                        LirBufferSlot {
420                            offset: slot.offset,
421                            size: slot.size,
422                        },
423                    )
424                })
425                .collect(),
426            schedule: plan.schedule.clone(),
427            ..Default::default()
428        }
429    }
430}
431
432impl From<&LirBufferPlan> for MemoryPlan {
433    fn from(plan: &LirBufferPlan) -> Self {
434        MemoryPlan {
435            arena_size: plan.arena_size,
436            assignments: plan
437                .assignments
438                .iter()
439                .map(|(id, slot)| {
440                    (
441                        *id,
442                        memory::BufferSlot {
443                            offset: slot.offset,
444                            size: slot.size,
445                        },
446                    )
447                })
448                .collect(),
449            schedule: plan.schedule.clone(),
450        }
451    }
452}
453
454pub(crate) fn lir_buffer_plan_from_memory(
455    graph: &Graph,
456    plan: &MemoryPlan,
457    alignment: usize,
458) -> LirBufferPlan {
459    let view_aliases = memory::collect_view_aliases(graph)
460        .into_iter()
461        .map(|(id, (root, byte_offset))| (id, LirViewAlias { root, byte_offset }))
462        .collect();
463    LirBufferPlan {
464        arena_size: plan.arena_size,
465        assignments: plan
466            .assignments
467            .iter()
468            .map(|(id, slot)| {
469                (
470                    *id,
471                    LirBufferSlot {
472                        offset: slot.offset,
473                        size: slot.size,
474                    },
475                )
476            })
477            .collect(),
478        schedule: plan.schedule.clone(),
479        view_aliases,
480        phases: derive_phases(graph),
481        io: LirIoManifest::collect(graph),
482        alignment,
483        dynamic_symbols: collect_dynamic_symbols(graph),
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Op;
    use rlx_ir::Shape;
    use rlx_ir::hir::FusionPolicy;

    /// Shorthand for an `F32` shape with the given static dims.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    #[test]
    fn pipeline_hir_to_lir() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let b = hir.param("b", f32_shape(&[128]));
        let h = hir.linear(x, w, Some(b), None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        assert!(result.lir.mir.len() <= 5);
        assert!(result.lir.arena_size() > 0);
        assert!(result.lir.buffers.bytes_saved() <= result.lir.buffers.total_unshared_bytes());
        // NOTE(review): given the `len() <= 5` assert above, this disjunction is
        // always true, so it does not actually pin matmul+bias+act fusion.
        assert!(result.fusion.fused_matmul_bias_act >= 1 || result.lir.mir.len() <= 5);
    }

    #[test]
    fn direct_hir_swiglu_emits_fused_op() {
        let mut hir = HirModule::new("ffn");
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        let g = result.lir.mir.as_graph();
        assert!(
            g.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "direct HIR SwiGLU should lower to FusedSwiGLU"
        );
        assert!(result.fusion.missed_matmul_bias_act() == 0 || result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn compile_module_from_graph_define() {
        let module = GraphModule::define("ffn", |m| {
            let x = m.input("x", f32_shape(&[2, 64]));
            let w = m.param("w", f32_shape(&[64, 64]));
            m.linear(x, w, None, None, f32_shape(&[2, 64]))
        });
        assert_eq!(module.stage(), GraphStage::Hir);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_module(module).expect("compile_module");
        assert!(result.lir.arena_size() > 0);
    }

    #[test]
    fn fusable_policy_leaves_room_for_passes() {
        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let mir = CompilePipeline::lower_hir(hir).expect("lower");
        let g = mir.as_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        assert_eq!(g.len(), 9);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_mir(mir);
        assert!(result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn lir_plan_includes_phases_io_and_fingerprint() {
        use rlx_ir::phase::Phase;

        let mut hir = HirModule::new("stream");
        let x = hir.input("x", f32_shape(&[1, 8]));
        let w = hir.param("w", f32_shape(&[8, 4]));
        let mm = hir.linear(x, w, None, None, f32_shape(&[1, 4]));
        hir.set_outputs(vec![mm]);

        let result = CompilePipeline::new(FusionTarget::Cpu)
            .compile_hir(hir)
            .expect("compile");
        assert!(!result.lir.buffers.phases.is_empty());
        let input_id = result.lir.buffers.io.inputs[0].1;
        assert_eq!(
            result.lir.buffers.phases.get(input_id),
            Some(Phase::Prologue)
        );
        assert_eq!(result.lir.buffers.io.inputs.len(), 1);
        assert_eq!(result.lir.fingerprint(), result.lir.fingerprint());
        assert_eq!(result.lir.buffers.alignment, 64);
    }

    #[test]
    fn decode_hidden_shape_not_expanded_without_env() {
        // Qwen3 talker decode uses [1, 1, H] hidden states; must not be expanded to
        // [1, RLX_ONNX_SEQUENCE_LENGTH, H] unless that env is set explicitly.
        // The expansion is env-driven, so this check is only meaningful when the
        // variable is absent from the test process; otherwise the test would fail
        // spuriously on a developer machine or CI job that exports it.
        if std::env::var_os("RLX_ONNX_SEQUENCE_LENGTH").is_some() {
            eprintln!("skipping: RLX_ONNX_SEQUENCE_LENGTH is set in the environment");
            return;
        }
        let mut g = Graph::new("decode_out");
        let x = g.input("x", f32_shape(&[1, 1, 1024]));
        g.set_outputs(vec![x]);
        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        let graph = result.lir.mir.as_graph();
        let out = graph.node(graph.outputs[0]);
        assert_eq!(out.shape.dims()[1].unwrap_static(), 1);
        assert_eq!(out.shape.num_elements(), Some(1024));
    }

    #[test]
    fn dynamic_graph_compiles_and_specializes() {
        use rlx_ir::DimBinding;
        use rlx_ir::infer::GraphExt as _;
        use rlx_ir::sym;

        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        assert!(result.has_dynamic_dims());
        assert!(result.lir.buffers.dynamic_symbols.contains(&sym::SEQ));

        let bound = result.specialize(&pipe, &DimBinding::batch_seq(2, 16));
        assert!(bound.lir.is_fully_static());
        assert!(bound.lir.arena_size() > 0);
    }
}