1use rlx_ir::dynamic::collect_dynamic_symbols;
28use rlx_ir::hir::HirModule;
29use rlx_ir::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest, LirModule, LirViewAlias};
30use rlx_ir::mir::MirModule;
31use rlx_ir::phase::derive_phases;
32use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
33
34use crate::DeadCodeElimination;
35use crate::debug_assert_graph;
36use crate::fusion_pipeline::{
37 FusionOptions, FusionTarget, fusion_limits_for_target, fusion_passes_for_supported,
38 supported_for_target,
39};
40use crate::legalize::{format_legalize_error, legalize_for_backend};
41use crate::memory::{self, MemoryPlan};
42use crate::rewrite::rewrite_for_backend_with_config;
43use rlx_fusion::fusion_report::FusionReport;
44use rlx_fusion::pass::run_passes;
45use rlx_fusion::{clip_elementwise_regions, with_fusion_limits};
46use rlx_ir::OpKind;
47use rlx_ir::logical_kernel::KernelDispatchConfig;
48
49#[derive(Debug, Clone)]
51pub struct CompileResult {
52 pub lir: LirModule,
53 pub fusion: FusionReport,
54}
55
56impl CompileResult {
57 pub fn has_dynamic_dims(&self) -> bool {
58 self.lir.has_dynamic_dims()
59 }
60
61 pub fn dynamic_symbols(&self) -> &[u32] {
62 self.lir.dynamic_symbols()
63 }
64
65 pub fn specialize(&self, pipeline: &CompilePipeline, binding: &rlx_ir::DimBinding) -> Self {
67 Self {
68 lir: pipeline.specialize_lir(&self.lir, binding),
69 fusion: self.fusion.clone(),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Copy)]
76pub struct CompilePipeline {
77 pub target: FusionTarget,
78 pub opts: FusionOptions,
79 pub arena_alignment: usize,
80 pub assert_fusion_clean: bool,
83 pub supported_ops: Option<&'static [OpKind]>,
87 pub kernel_dispatch: KernelDispatchConfig,
89}
90
91impl Default for CompilePipeline {
92 fn default() -> Self {
93 Self {
94 target: FusionTarget::Cpu,
95 opts: FusionOptions::for_cpu(),
96 arena_alignment: 64,
97 assert_fusion_clean: false,
98 supported_ops: None,
99 kernel_dispatch: KernelDispatchConfig::from_env(),
100 }
101 }
102}
103
104fn lstm_y_shape(x: &rlx_ir::Shape, hidden_size: usize, bidirectional: bool) -> rlx_ir::Shape {
105 let dirs = if bidirectional { 2 } else { 1 };
106 if x.rank() == 3 {
107 let seq = x.dim(0).unwrap_static();
108 let batch = x.dim(1).unwrap_static().max(1);
109 return rlx_ir::Shape::new(&[seq, dirs, batch, hidden_size], x.dtype());
110 }
111 rlx_ir::Shape::new(&[1, dirs, 1, hidden_size], x.dtype())
112}
113
114fn fix_import_lstm_x_shape(x: &rlx_ir::Shape) -> rlx_ir::Shape {
116 if x.rank() != 3 {
117 return x.clone();
118 }
119 let d0 = x.dim(0).unwrap_static();
120 let d1 = x.dim(1).unwrap_static();
121 let d2 = x.dim(2).unwrap_static();
122 if d0 == 1 && d1 <= 1 && (d2 == 640 || d2 == 512) {
123 let seq = std::env::var("RLX_ONNX_SEQUENCE_LENGTH")
124 .ok()
125 .and_then(|s| s.parse().ok())
126 .unwrap_or(128);
127 return rlx_ir::Shape::new(&[seq, d1.max(1), d2], x.dtype());
128 }
129 x.clone()
130}
131
132fn fix_lstm_output_shapes(graph: &mut Graph) {
133 use rlx_ir::Op;
134 let ids: Vec<NodeId> = graph.nodes().iter().map(|n| n.id).collect();
135 for id in ids {
136 let node = graph.node(id).clone();
137 let Op::Custom { name, attrs, .. } = &node.op else {
138 continue;
139 };
140 if !name.contains("LSTM") {
141 continue;
142 }
143 let hidden_size = if attrs.len() >= 4 {
144 u32::from_le_bytes(attrs[0..4].try_into().unwrap()) as usize
145 } else {
146 256
147 };
148 let bidirectional = attrs.len() > 4 && attrs[4] != 0;
149 let x_id = node.inputs[0];
150 let x = fix_import_lstm_x_shape(&graph.node(x_id).shape);
151 graph.node_mut(x_id).shape = x.clone();
152 graph.node_mut(id).shape = lstm_y_shape(&x, hidden_size, bidirectional);
153 }
154}
155
156fn fix_import_sequence_axis(graph: &mut Graph) {
162 let Ok(seq_str) = std::env::var("RLX_ONNX_SEQUENCE_LENGTH") else {
163 return;
164 };
165 let seq: usize = match seq_str.parse() {
166 Ok(s) if s > 1 => s,
167 _ => return,
168 };
169 for id in graph.nodes().iter().map(|n| n.id).collect::<Vec<_>>() {
170 let node = graph.node(id);
171 if node.shape.rank() != 3 {
172 continue;
173 }
174 let dims: Vec<_> = node
175 .shape
176 .dims()
177 .iter()
178 .map(|d| d.unwrap_static())
179 .collect();
180 if dims[0] == 1 && dims[1] == 1 && dims[2] >= 64 {
181 graph.node_mut(id).shape = rlx_ir::Shape::new(&[1, seq, dims[2]], node.shape.dtype());
182 }
183 }
184 for id in graph.topo_order().collect::<Vec<_>>() {
185 let node = graph.node(id).clone();
186 if let Some(shape) = rlx_ir::infer_shape::infer_output_shape(graph, &node) {
187 graph.node_mut(id).shape = shape;
188 }
189 }
190}
191
192impl CompilePipeline {
193 pub fn new(target: FusionTarget) -> Self {
194 let mut opts = match target {
195 FusionTarget::Cpu => FusionOptions::for_cpu(),
196 FusionTarget::Metal => FusionOptions::for_metal(),
197 FusionTarget::Wgpu => FusionOptions::for_wgpu(),
198 _ => FusionOptions::default(),
199 };
200 opts.fusion_limits = fusion_limits_for_target(target);
201 Self {
202 target,
203 opts,
204 ..Self::default()
205 }
206 }
207
208 pub fn with_assert_fusion_clean(mut self, assert: bool) -> Self {
209 self.assert_fusion_clean = assert;
210 self
211 }
212
213 pub fn lower_hir(hir: HirModule) -> Result<MirModule, rlx_ir::hir::LowerError> {
215 let mut mir = hir.lower_to_mir()?;
216 rlx_ir::dynamic::sync_graph_shapes(mir.as_graph_mut());
217 debug_assert_graph!(mir.as_graph(), "hir→mir");
218 Ok(mir)
219 }
220
221 pub fn preprocess_mir(mir: MirModule) -> MirModule {
223 use rlx_fusion::pass::Pass as _;
224 let graph = rlx_fusion::control_flow::LowerControlFlow.run(mir.into_graph());
225 let graph = DeadCodeElimination.run(graph);
226 MirModule::from_graph(graph)
227 }
228
229 pub fn with_supported_ops(mut self, ops: &'static [OpKind]) -> Self {
230 self.supported_ops = Some(ops);
231 self
232 }
233
234 pub fn with_kernel_dispatch(
235 mut self,
236 policy: rlx_ir::logical_kernel::KernelDispatchPolicy,
237 ) -> Self {
238 self.kernel_dispatch.policy = policy;
239 self
240 }
241
242 pub fn with_kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
243 self.kernel_dispatch = config;
244 self
245 }
246
247 fn effective_supported(&self) -> &'static [OpKind] {
248 self.supported_ops
249 .unwrap_or_else(|| supported_for_target(self.target))
250 }
251
252 fn backend_name(&self) -> &'static str {
253 match self.target {
254 FusionTarget::Cpu => "cpu",
255 FusionTarget::Metal => "metal",
256 FusionTarget::Mlx => "mlx",
257 FusionTarget::Wgpu => "wgpu",
258 FusionTarget::Cuda => "cuda",
259 FusionTarget::Rocm => "rocm",
260 FusionTarget::Tpu => "tpu",
261 }
262 }
263
264 pub fn optimize_with_report(&self, mir: MirModule) -> (MirModule, FusionReport) {
266 let before = mir.as_graph().clone();
267 let passes =
268 fusion_passes_for_supported(self.effective_supported(), self.opts, self.target);
269 let limits = self.opts.fusion_limits;
270 let graph = with_fusion_limits(limits, || run_passes(mir.into_graph(), &passes, false));
271 let graph = clip_elementwise_regions(graph, limits);
272 debug_assert_graph!(&graph, "fusion");
273 let mut graph = self.legalize_after_fusion(graph);
274 rlx_ir::dynamic::sync_graph_shapes(&mut graph);
275 fix_import_sequence_axis(&mut graph);
276 fix_lstm_output_shapes(&mut graph);
277 debug_assert_graph!(&graph, "legalize");
278 let mir = MirModule::from_graph(graph);
279 let fusion = FusionReport::analyze(&before, mir.as_graph());
280 (mir, fusion)
281 }
282
283 pub(crate) fn legalize_after_fusion(&self, graph: Graph) -> Graph {
287 let Some(supported) = self.supported_ops else {
288 if self.kernel_dispatch.force_common_kinds.is_empty()
289 && self.kernel_dispatch.policy
290 == rlx_ir::logical_kernel::KernelDispatchPolicy::PreferNative
291 {
292 return graph;
293 }
294 return rewrite_for_backend_with_config(graph, &[], self.kernel_dispatch);
295 };
296 if supported.is_empty() {
297 return graph;
298 }
299 let graph = rewrite_for_backend_with_config(graph, supported, self.kernel_dispatch);
300 if let Err(errors) = legalize_for_backend(&graph, supported) {
301 panic!("{}", format_legalize_error(self.backend_name(), &errors));
302 }
303 graph
304 }
305
306 pub fn optimize(&self, mir: MirModule) -> MirModule {
308 self.optimize_with_report(mir).0
309 }
310
311 pub fn plan_lir(&self, mir: MirModule) -> LirModule {
313 self.plan_lir_with_options(mir, memory::MemoryPlanOptions::default())
314 }
315
316 pub fn plan_lir_with_options(
318 &self,
319 mir: MirModule,
320 opts: memory::MemoryPlanOptions,
321 ) -> LirModule {
322 let graph = mir.as_graph().clone();
323 let plan = memory::plan_memory_with_options(&graph, self.arena_alignment, opts);
324 LirModule::new(
325 mir,
326 lir_buffer_plan_from_memory(&graph, &plan, self.arena_alignment),
327 )
328 }
329
330 pub fn specialize_lir(&self, lir: &LirModule, binding: &rlx_ir::DimBinding) -> LirModule {
332 use rlx_ir::dynamic::{
333 bind_graph, sync_concat_shapes, sync_expand_ops, sync_graph_shapes, sync_narrow_ops,
334 sync_reshape_ops,
335 };
336 let mut bound = bind_graph(lir.as_graph(), binding);
337 sync_reshape_ops(&mut bound);
338 sync_concat_shapes(&mut bound);
339 sync_narrow_ops(&mut bound);
340 sync_expand_ops(&mut bound);
341 sync_graph_shapes(&mut bound);
342 debug_assert_graph!(&bound, "specialize");
343 self.plan_lir(MirModule::from_graph(bound))
344 }
345
346 fn finish(&self, mir: MirModule, fusion: FusionReport) -> CompileResult {
347 debug_assert_graph!(mir.as_graph(), "pre-lir");
348 if self.assert_fusion_clean && !fusion.missed.is_empty() {
349 panic!(
350 "fusion contract violated: {} missed patterns\n{fusion}",
351 fusion.missed.len()
352 );
353 }
354 CompileResult {
355 lir: self.plan_lir(mir),
356 fusion,
357 }
358 }
359
360 pub fn compile_hir(&self, hir: HirModule) -> Result<CompileResult, rlx_ir::hir::LowerError> {
362 if rlx_ir::env::var("RLX_IR_DUMP").is_some() {
363 let name = hir.name.clone();
364 let dump = crate::inspect::inspect_pipeline(self, hir.clone())?;
365 crate::inspect::maybe_dump_pipeline(&dump, &name);
366 }
367 let mir = Self::lower_hir(hir)?;
368 let (mir, fusion) = self.optimize_with_report(mir);
369 Ok(self.finish(mir, fusion))
370 }
371
372 pub fn compile_mir(&self, mir: MirModule) -> CompileResult {
374 let (mir, fusion) = self.optimize_with_report(mir);
375 self.finish(mir, fusion)
376 }
377
378 pub fn compile_graph(&self, graph: Graph) -> CompileResult {
380 self.compile_mir(MirModule::from_graph(graph))
381 }
382
383 pub fn compile_module(
385 &self,
386 module: GraphModule,
387 ) -> Result<CompileResult, rlx_ir::hir::LowerError> {
388 match module.stage() {
389 GraphStage::Hir => {
390 let hir = module
391 .into_hir()
392 .expect("GraphModule stage() / into_hir mismatch");
393 self.compile_hir(hir)
394 }
395 GraphStage::Mir => {
396 let mir = module.into_mir()?;
397 Ok(self.compile_mir(mir))
398 }
399 GraphStage::Lir => Ok(CompileResult {
400 lir: module
401 .into_lir()
402 .expect("GraphModule stage() / into_lir mismatch"),
403 fusion: FusionReport::default(),
404 }),
405 }
406 }
407}
408
409impl From<&MemoryPlan> for LirBufferPlan {
410 fn from(plan: &MemoryPlan) -> Self {
411 LirBufferPlan {
412 arena_size: plan.arena_size,
413 assignments: plan
414 .assignments
415 .iter()
416 .map(|(id, slot)| {
417 (
418 *id,
419 LirBufferSlot {
420 offset: slot.offset,
421 size: slot.size,
422 },
423 )
424 })
425 .collect(),
426 schedule: plan.schedule.clone(),
427 ..Default::default()
428 }
429 }
430}
431
432impl From<&LirBufferPlan> for MemoryPlan {
433 fn from(plan: &LirBufferPlan) -> Self {
434 MemoryPlan {
435 arena_size: plan.arena_size,
436 assignments: plan
437 .assignments
438 .iter()
439 .map(|(id, slot)| {
440 (
441 *id,
442 memory::BufferSlot {
443 offset: slot.offset,
444 size: slot.size,
445 },
446 )
447 })
448 .collect(),
449 schedule: plan.schedule.clone(),
450 }
451 }
452}
453
454pub(crate) fn lir_buffer_plan_from_memory(
455 graph: &Graph,
456 plan: &MemoryPlan,
457 alignment: usize,
458) -> LirBufferPlan {
459 let view_aliases = memory::collect_view_aliases(graph)
460 .into_iter()
461 .map(|(id, (root, byte_offset))| (id, LirViewAlias { root, byte_offset }))
462 .collect();
463 LirBufferPlan {
464 arena_size: plan.arena_size,
465 assignments: plan
466 .assignments
467 .iter()
468 .map(|(id, slot)| {
469 (
470 *id,
471 LirBufferSlot {
472 offset: slot.offset,
473 size: slot.size,
474 },
475 )
476 })
477 .collect(),
478 schedule: plan.schedule.clone(),
479 view_aliases,
480 phases: derive_phases(graph),
481 io: LirIoManifest::collect(graph),
482 alignment,
483 dynamic_symbols: collect_dynamic_symbols(graph),
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Op;
    use rlx_ir::Shape;
    use rlx_ir::hir::FusionPolicy;

    /// Shorthand for an `f32` shape with the given static dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// Fresh CPU pipeline used by every test.
    fn cpu_pipeline() -> CompilePipeline {
        CompilePipeline::new(FusionTarget::Cpu)
    }

    /// Adds a `[4, 768] -> [4, 768]` SwiGLU FFN (inner width 2048) to `hir`
    /// and makes it the module's only output.
    fn with_swiglu_ffn(mut hir: HirModule) -> HirModule {
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up = hir.param("up", f32_shape(&[768, 2048]));
        let gate = hir.param("gate", f32_shape(&[768, 2048]));
        let down = hir.param("down", f32_shape(&[2048, 768]));
        let ffn_out = hir.swiglu_ffn(x, up, gate, down, f32_shape(&[4, 768]));
        hir.outputs = vec![ffn_out];
        hir
    }

    #[test]
    fn pipeline_hir_to_lir() {
        let mut hir = HirModule::new("layer");
        let input = hir.input("x", f32_shape(&[2, 128]));
        let weight = hir.param("w", f32_shape(&[128, 128]));
        let bias = hir.param("b", f32_shape(&[128]));
        let hidden = hir.linear(input, weight, Some(bias), None, f32_shape(&[2, 128]));
        hir.outputs = vec![hidden];

        let result = cpu_pipeline().compile_hir(hir).expect("compile");
        let lir = &result.lir;
        let node_count = lir.mir.len();
        assert!(node_count <= 5);
        assert!(lir.arena_size() > 0);
        let buffers = &lir.buffers;
        assert!(buffers.bytes_saved() <= buffers.total_unshared_bytes());
        assert!(result.fusion.fused_matmul_bias_act >= 1 || node_count <= 5);
    }

    #[test]
    fn direct_hir_swiglu_emits_fused_op() {
        let hir = with_swiglu_ffn(HirModule::new("ffn"));

        let result = cpu_pipeline().compile_hir(hir).expect("compile");
        let has_fused_swiglu = result
            .lir
            .mir
            .as_graph()
            .nodes()
            .iter()
            .any(|node| matches!(node.op, Op::FusedSwiGLU { .. }));
        assert!(
            has_fused_swiglu,
            "direct HIR SwiGLU should lower to FusedSwiGLU"
        );
        let fusion = &result.fusion;
        assert!(fusion.missed_matmul_bias_act() == 0 || fusion.fused_swiglu >= 1);
    }

    #[test]
    fn compile_module_from_graph_define() {
        let module = GraphModule::define("ffn", |builder| {
            let x = builder.input("x", f32_shape(&[2, 64]));
            let w = builder.param("w", f32_shape(&[64, 64]));
            builder.linear(x, w, None, None, f32_shape(&[2, 64]))
        });
        assert_eq!(module.stage(), GraphStage::Hir);

        let compiled = cpu_pipeline()
            .compile_module(module)
            .expect("compile_module");
        assert!(compiled.lir.arena_size() > 0);
    }

    #[test]
    fn fusable_policy_leaves_room_for_passes() {
        let hir =
            with_swiglu_ffn(HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable));

        let mir = CompilePipeline::lower_hir(hir).expect("lower");
        {
            let graph = mir.as_graph();
            assert!(graph.nodes().iter().any(|node| matches!(node.op, Op::MatMul)));
            assert_eq!(graph.len(), 9);
        }

        let result = cpu_pipeline().compile_mir(mir);
        assert!(result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn lir_plan_includes_phases_io_and_fingerprint() {
        use rlx_ir::phase::Phase;

        let mut hir = HirModule::new("stream");
        let x = hir.input("x", f32_shape(&[1, 8]));
        let w = hir.param("w", f32_shape(&[8, 4]));
        let projected = hir.linear(x, w, None, None, f32_shape(&[1, 4]));
        hir.set_outputs(vec![projected]);

        let lir = cpu_pipeline().compile_hir(hir).expect("compile").lir;
        let buffers = &lir.buffers;
        assert!(!buffers.phases.is_empty());
        let input_id = buffers.io.inputs[0].1;
        assert_eq!(buffers.phases.get(input_id), Some(Phase::Prologue));
        assert_eq!(buffers.io.inputs.len(), 1);
        assert_eq!(lir.fingerprint(), lir.fingerprint());
        assert_eq!(buffers.alignment, 64);
    }

    #[test]
    fn decode_hidden_shape_not_expanded_without_env() {
        // A `[1, 1, 1024]` decode-step tensor must keep its unit middle axis
        // when the sequence-length override is not in effect.
        // NOTE(review): this test assumes RLX_ONNX_SEQUENCE_LENGTH is unset in
        // the test environment.
        let mut graph = Graph::new("decode_out");
        let x = graph.input("x", f32_shape(&[1, 1, 1024]));
        graph.set_outputs(vec![x]);

        let result = cpu_pipeline().compile_graph(graph);
        let lowered = result.lir.mir.as_graph();
        let out = lowered.node(lowered.outputs[0]);
        assert_eq!(out.shape.dims()[1].unwrap_static(), 1);
        assert_eq!(out.shape.num_elements(), Some(1024));
    }

    #[test]
    fn dynamic_graph_compiles_and_specializes() {
        use rlx_ir::DimBinding;
        use rlx_ir::infer::GraphExt as _;
        use rlx_ir::sym;

        let mut graph = Graph::new("dyn");
        let x = graph.input("x", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        let w = graph.param("w", Shape::new(&[4, 8], DType::F32));
        let y = graph.mm(x, w);
        graph.set_outputs(vec![y]);

        let pipe = cpu_pipeline();
        let compiled = pipe.compile_graph(graph);
        assert!(compiled.has_dynamic_dims());
        assert!(compiled.lir.buffers.dynamic_symbols.contains(&sym::SEQ));

        let specialized = compiled.specialize(&pipe, &DimBinding::batch_seq(2, 16));
        assert!(specialized.lir.is_fully_static());
        assert!(specialized.lir.arena_size() > 0);
    }
}