1use rlx_ir::dynamic::collect_dynamic_symbols;
28use rlx_ir::hir::HirModule;
29use rlx_ir::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest, LirModule, LirViewAlias};
30use rlx_ir::mir::MirModule;
31use rlx_ir::phase::derive_phases;
32use rlx_ir::{Graph, GraphModule, GraphStage, NodeId};
33
34use crate::DeadCodeElimination;
35use crate::debug_assert_graph;
36use crate::fusion_pipeline::{
37 FusionOptions, FusionTarget, fusion_limits_for_target, fusion_passes_for_supported,
38 supported_for_target,
39};
40use crate::fusion_target::with_fusion_target;
41use crate::legalize::{format_legalize_error, legalize_for_backend};
42use crate::memory::{self, MemoryPlan};
43use crate::rewrite::rewrite_for_backend_with_config;
44use rlx_fusion::fusion_report::FusionReport;
45use rlx_fusion::pass::run_passes;
46use rlx_fusion::{clip_elementwise_regions, with_fusion_limits};
47use rlx_ir::OpKind;
48use rlx_ir::logical_kernel::KernelDispatchConfig;
49
/// Output of a full compile: the memory-planned LIR module plus the fusion report
/// produced while optimizing the MIR graph.
#[derive(Debug, Clone)]
pub struct CompileResult {
    /// Planned LIR module: the optimized MIR graph together with its buffer plan.
    pub lir: LirModule,
    /// Fusion report comparing the graph before and after optimization
    /// (`FusionReport::default()` when the input module was already at the LIR stage).
    pub fusion: FusionReport,
}
56
57impl CompileResult {
58 pub fn has_dynamic_dims(&self) -> bool {
59 self.lir.has_dynamic_dims()
60 }
61
62 pub fn dynamic_symbols(&self) -> &[u32] {
63 self.lir.dynamic_symbols()
64 }
65
66 pub fn specialize(&self, pipeline: &CompilePipeline, binding: &rlx_ir::DimBinding) -> Self {
68 Self {
69 lir: pipeline.specialize_lir(&self.lir, binding),
70 fusion: self.fusion.clone(),
71 }
72 }
73}
74
/// Configuration for lowering, optimizing and memory-planning a graph for one backend.
#[derive(Debug, Clone, Copy)]
pub struct CompilePipeline {
    /// Backend being compiled for. It selects default fusion options/limits in `new`,
    /// the default supported-op set, and the backend name in legalization errors.
    pub target: FusionTarget,
    /// Fusion options passed to `fusion_passes_for_supported`. `opts.fusion_limits` is
    /// also used by `with_fusion_limits` and `clip_elementwise_regions`.
    pub opts: FusionOptions,
    /// Byte alignment used by memory planning and recorded in the LIR buffer plan
    /// (64 by default).
    pub arena_alignment: usize,
    /// When `true`, compilation panics if the fusion report lists any missed patterns.
    pub assert_fusion_clean: bool,
    /// Explicit supported-op override.
    ///
    /// `None` means fusion uses `supported_for_target(target)`, and post-fusion legalization
    /// only runs backend rewrites when the kernel-dispatch config is non-default.
    /// `Some(ops)` drives both fusion and legalization; an empty slice skips legalization.
    pub supported_ops: Option<&'static [OpKind]>,
    /// Kernel dispatch configuration forwarded to `rewrite_for_backend_with_config`
    /// (initialized from the environment by `Default`).
    pub kernel_dispatch: KernelDispatchConfig,
}
91
92impl Default for CompilePipeline {
93 fn default() -> Self {
94 Self {
95 target: FusionTarget::Cpu,
96 opts: FusionOptions::for_cpu(),
97 arena_alignment: 64,
98 assert_fusion_clean: false,
99 supported_ops: None,
100 kernel_dispatch: KernelDispatchConfig::from_env(),
101 }
102 }
103}
104
105fn lstm_y_shape(x: &rlx_ir::Shape, hidden_size: usize, bidirectional: bool) -> rlx_ir::Shape {
106 let dirs = if bidirectional { 2 } else { 1 };
107 if x.rank() == 3 {
108 let seq = x.dim(0).unwrap_static();
109 let batch = x.dim(1).unwrap_static().max(1);
110 return rlx_ir::Shape::new(&[seq, dirs, batch, hidden_size], x.dtype());
111 }
112 rlx_ir::Shape::new(&[1, dirs, 1, hidden_size], x.dtype())
113}
114
115fn fix_import_lstm_x_shape(x: &rlx_ir::Shape) -> rlx_ir::Shape {
117 if x.rank() != 3 {
118 return x.clone();
119 }
120 let d0 = x.dim(0).unwrap_static();
121 let d1 = x.dim(1).unwrap_static();
122 let d2 = x.dim(2).unwrap_static();
123 if d0 == 1 && d1 <= 1 && (d2 == 640 || d2 == 512) {
124 let seq = std::env::var("RLX_ONNX_SEQUENCE_LENGTH")
125 .ok()
126 .and_then(|s| s.parse().ok())
127 .unwrap_or(128);
128 return rlx_ir::Shape::new(&[seq, d1.max(1), d2], x.dtype());
129 }
130 x.clone()
131}
132
133fn fix_lstm_output_shapes(graph: &mut Graph) {
134 use rlx_ir::Op;
135 let ids: Vec<NodeId> = graph.nodes().iter().map(|n| n.id).collect();
136 for id in ids {
137 let node = graph.node(id).clone();
138 let Op::Custom { name, attrs, .. } = &node.op else {
139 continue;
140 };
141 if !name.contains("LSTM") {
142 continue;
143 }
144 let hidden_size = if attrs.len() >= 4 {
145 u32::from_le_bytes(attrs[0..4].try_into().unwrap()) as usize
146 } else {
147 256
148 };
149 let bidirectional = attrs.len() > 4 && attrs[4] != 0;
150 let x_id = node.inputs[0];
151 let x = fix_import_lstm_x_shape(&graph.node(x_id).shape);
152 graph.node_mut(x_id).shape = x.clone();
153 graph.node_mut(id).shape = lstm_y_shape(&x, hidden_size, bidirectional);
154 }
155}
156
157fn fix_import_sequence_axis(graph: &mut Graph) {
163 let Ok(seq_str) = std::env::var("RLX_ONNX_SEQUENCE_LENGTH") else {
164 return;
165 };
166 let seq: usize = match seq_str.parse() {
167 Ok(s) if s > 1 => s,
168 _ => return,
169 };
170 for id in graph.nodes().iter().map(|n| n.id).collect::<Vec<_>>() {
171 let node = graph.node(id);
172 if node.shape.rank() != 3 {
173 continue;
174 }
175 let dims: Vec<_> = node
176 .shape
177 .dims()
178 .iter()
179 .map(|d| d.unwrap_static())
180 .collect();
181 if dims[0] == 1 && dims[1] == 1 && dims[2] >= 64 {
182 graph.node_mut(id).shape = rlx_ir::Shape::new(&[1, seq, dims[2]], node.shape.dtype());
183 }
184 }
185 for id in graph.topo_order().collect::<Vec<_>>() {
186 let node = graph.node(id).clone();
187 if let Some(shape) = rlx_ir::infer_shape::infer_output_shape(graph, &node) {
188 graph.node_mut(id).shape = shape;
189 }
190 }
191}
192
193impl CompilePipeline {
194 pub fn new(target: FusionTarget) -> Self {
195 let mut opts = match target {
196 FusionTarget::Cpu => FusionOptions::for_cpu(),
197 FusionTarget::Metal => FusionOptions::for_metal(),
198 FusionTarget::Wgpu => FusionOptions::for_wgpu(),
199 _ => FusionOptions::default(),
200 };
201 opts.fusion_limits = fusion_limits_for_target(target);
202 Self {
203 target,
204 opts,
205 ..Self::default()
206 }
207 }
208
209 pub fn with_assert_fusion_clean(mut self, assert: bool) -> Self {
210 self.assert_fusion_clean = assert;
211 self
212 }
213
214 pub fn lower_hir(hir: HirModule) -> Result<MirModule, rlx_ir::hir::LowerError> {
216 let mut mir = hir.lower_to_mir()?;
217 rlx_ir::dynamic::sync_graph_shapes(mir.as_graph_mut());
218 debug_assert_graph!(mir.as_graph(), "hir→mir");
219 Ok(mir)
220 }
221
222 pub fn preprocess_mir(mir: MirModule) -> MirModule {
224 use rlx_fusion::pass::Pass as _;
225 let graph = rlx_fusion::control_flow::LowerControlFlow.run(mir.into_graph());
226 let graph = DeadCodeElimination.run(graph);
227 MirModule::from_graph(graph)
228 }
229
230 pub fn with_supported_ops(mut self, ops: &'static [OpKind]) -> Self {
231 self.supported_ops = Some(ops);
232 self
233 }
234
235 pub fn with_kernel_dispatch(
236 mut self,
237 policy: rlx_ir::logical_kernel::KernelDispatchPolicy,
238 ) -> Self {
239 self.kernel_dispatch.policy = policy;
240 self
241 }
242
243 pub fn with_kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
244 self.kernel_dispatch = config;
245 self
246 }
247
248 fn effective_supported(&self) -> &'static [OpKind] {
249 self.supported_ops
250 .unwrap_or_else(|| supported_for_target(self.target))
251 }
252
253 fn backend_name(&self) -> &'static str {
254 match self.target {
255 FusionTarget::Cpu => "cpu",
256 FusionTarget::Metal => "metal",
257 FusionTarget::Mlx => "mlx",
258 FusionTarget::Wgpu => "wgpu",
259 FusionTarget::Cuda => "cuda",
260 FusionTarget::Rocm => "rocm",
261 FusionTarget::Tpu => "tpu",
262 }
263 }
264
265 pub fn optimize_with_report(&self, mir: MirModule) -> (MirModule, FusionReport) {
267 let before = mir.as_graph().clone();
268 let passes =
269 fusion_passes_for_supported(self.effective_supported(), self.opts, self.target);
270 let limits = self.opts.fusion_limits;
271 let graph = with_fusion_target(self.target, || {
272 with_fusion_limits(limits, || run_passes(mir.into_graph(), &passes, false))
273 });
274 let graph = clip_elementwise_regions(graph, limits);
275 debug_assert_graph!(&graph, "fusion");
276 let mut graph = self.legalize_after_fusion(graph);
277 rlx_ir::dynamic::sync_graph_shapes(&mut graph);
278 fix_import_sequence_axis(&mut graph);
279 fix_lstm_output_shapes(&mut graph);
280 debug_assert_graph!(&graph, "legalize");
281 let mir = MirModule::from_graph(graph);
282 let fusion = FusionReport::analyze(&before, mir.as_graph());
283 (mir, fusion)
284 }
285
286 pub(crate) fn legalize_after_fusion(&self, graph: Graph) -> Graph {
290 let Some(supported) = self.supported_ops else {
291 if self.kernel_dispatch.force_common_kinds.is_empty()
292 && self.kernel_dispatch.policy
293 == rlx_ir::logical_kernel::KernelDispatchPolicy::PreferNative
294 {
295 return graph;
296 }
297 return rewrite_for_backend_with_config(graph, &[], self.kernel_dispatch);
298 };
299 if supported.is_empty() {
300 return graph;
301 }
302 let graph = rewrite_for_backend_with_config(graph, supported, self.kernel_dispatch);
303 if let Err(errors) = legalize_for_backend(&graph, supported) {
304 panic!("{}", format_legalize_error(self.backend_name(), &errors));
305 }
306 graph
307 }
308
309 pub fn optimize(&self, mir: MirModule) -> MirModule {
311 self.optimize_with_report(mir).0
312 }
313
314 pub fn plan_lir(&self, mir: MirModule) -> LirModule {
316 self.plan_lir_with_options(mir, memory::MemoryPlanOptions::default())
317 }
318
319 pub fn plan_lir_with_options(
321 &self,
322 mir: MirModule,
323 opts: memory::MemoryPlanOptions,
324 ) -> LirModule {
325 let graph = mir.as_graph().clone();
326 let plan = memory::plan_memory_with_options(&graph, self.arena_alignment, opts);
327 LirModule::new(
328 mir,
329 lir_buffer_plan_from_memory(&graph, &plan, self.arena_alignment),
330 )
331 }
332
333 pub fn specialize_lir(&self, lir: &LirModule, binding: &rlx_ir::DimBinding) -> LirModule {
335 use rlx_ir::dynamic::{
336 bind_graph, sync_concat_shapes, sync_expand_ops, sync_graph_shapes, sync_narrow_ops,
337 sync_reshape_ops,
338 };
339 let mut bound = bind_graph(lir.as_graph(), binding);
340 sync_reshape_ops(&mut bound);
341 sync_concat_shapes(&mut bound);
342 sync_narrow_ops(&mut bound);
343 sync_expand_ops(&mut bound);
344 sync_graph_shapes(&mut bound);
345 debug_assert_graph!(&bound, "specialize");
346 self.plan_lir(MirModule::from_graph(bound))
347 }
348
349 fn finish(&self, mir: MirModule, fusion: FusionReport) -> CompileResult {
350 debug_assert_graph!(mir.as_graph(), "pre-lir");
351 if self.assert_fusion_clean && !fusion.missed.is_empty() {
352 panic!(
353 "fusion contract violated: {} missed patterns\n{fusion}",
354 fusion.missed.len()
355 );
356 }
357 CompileResult {
358 lir: self.plan_lir(mir),
359 fusion,
360 }
361 }
362
363 pub fn compile_hir(&self, hir: HirModule) -> Result<CompileResult, rlx_ir::hir::LowerError> {
365 if rlx_ir::env::var("RLX_IR_DUMP").is_some() {
366 let name = hir.name.clone();
367 let dump = crate::inspect::inspect_pipeline(self, hir.clone())?;
368 crate::inspect::maybe_dump_pipeline(&dump, &name);
369 }
370 let mir = Self::lower_hir(hir)?;
371 let (mir, fusion) = self.optimize_with_report(mir);
372 Ok(self.finish(mir, fusion))
373 }
374
375 pub fn compile_mir(&self, mir: MirModule) -> CompileResult {
377 let (mir, fusion) = self.optimize_with_report(mir);
378 self.finish(mir, fusion)
379 }
380
381 pub fn compile_graph(&self, graph: Graph) -> CompileResult {
383 self.compile_mir(MirModule::from_graph(graph))
384 }
385
386 pub fn compile_module(
388 &self,
389 module: GraphModule,
390 ) -> Result<CompileResult, rlx_ir::hir::LowerError> {
391 match module.stage() {
392 GraphStage::Hir => {
393 let hir = module
394 .into_hir()
395 .expect("GraphModule stage() / into_hir mismatch");
396 self.compile_hir(hir)
397 }
398 GraphStage::Mir => {
399 let mir = module.into_mir()?;
400 Ok(self.compile_mir(mir))
401 }
402 GraphStage::Lir => Ok(CompileResult {
403 lir: module
404 .into_lir()
405 .expect("GraphModule stage() / into_lir mismatch"),
406 fusion: FusionReport::default(),
407 }),
408 }
409 }
410}
411
412impl From<&MemoryPlan> for LirBufferPlan {
413 fn from(plan: &MemoryPlan) -> Self {
414 LirBufferPlan {
415 arena_size: plan.arena_size,
416 assignments: plan
417 .assignments
418 .iter()
419 .map(|(id, slot)| {
420 (
421 *id,
422 LirBufferSlot {
423 offset: slot.offset,
424 size: slot.size,
425 },
426 )
427 })
428 .collect(),
429 schedule: plan.schedule.clone(),
430 ..Default::default()
431 }
432 }
433}
434
435impl From<&LirBufferPlan> for MemoryPlan {
436 fn from(plan: &LirBufferPlan) -> Self {
437 MemoryPlan {
438 arena_size: plan.arena_size,
439 assignments: plan
440 .assignments
441 .iter()
442 .map(|(id, slot)| {
443 (
444 *id,
445 memory::BufferSlot {
446 offset: slot.offset,
447 size: slot.size,
448 },
449 )
450 })
451 .collect(),
452 schedule: plan.schedule.clone(),
453 }
454 }
455}
456
457pub(crate) fn lir_buffer_plan_from_memory(
458 graph: &Graph,
459 plan: &MemoryPlan,
460 alignment: usize,
461) -> LirBufferPlan {
462 let view_aliases = memory::collect_view_aliases(graph)
463 .into_iter()
464 .map(|(id, (root, byte_offset))| (id, LirViewAlias { root, byte_offset }))
465 .collect();
466 LirBufferPlan {
467 arena_size: plan.arena_size,
468 assignments: plan
469 .assignments
470 .iter()
471 .map(|(id, slot)| {
472 (
473 *id,
474 LirBufferSlot {
475 offset: slot.offset,
476 size: slot.size,
477 },
478 )
479 })
480 .collect(),
481 schedule: plan.schedule.clone(),
482 view_aliases,
483 phases: derive_phases(graph),
484 io: LirIoManifest::collect(graph),
485 alignment,
486 dynamic_symbols: collect_dynamic_symbols(graph),
487 }
488}
489
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Op;
    use rlx_ir::Shape;
    use rlx_ir::hir::FusionPolicy;

    /// Shorthand for an `F32` shape with the given static dims.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    #[test]
    fn pipeline_hir_to_lir() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let b = hir.param("b", f32_shape(&[128]));
        let h = hir.linear(x, w, Some(b), None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        assert!(result.lir.mir.len() <= 5);
        assert!(result.lir.arena_size() > 0);
        assert!(result.lir.buffers.bytes_saved() <= result.lir.buffers.total_unshared_bytes());
        assert!(result.fusion.fused_matmul_bias_act >= 1 || result.lir.mir.len() <= 5);
    }

    #[test]
    fn direct_hir_swiglu_emits_fused_op() {
        let mut hir = HirModule::new("ffn");
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        let g = result.lir.mir.as_graph();
        assert!(
            g.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "direct HIR SwiGLU should lower to FusedSwiGLU"
        );
        assert!(result.fusion.missed_matmul_bias_act() == 0 || result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn compile_module_from_graph_define() {
        let module = GraphModule::define("ffn", |m| {
            let x = m.input("x", f32_shape(&[2, 64]));
            let w = m.param("w", f32_shape(&[64, 64]));
            m.linear(x, w, None, None, f32_shape(&[2, 64]))
        });
        assert_eq!(module.stage(), GraphStage::Hir);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_module(module).expect("compile_module");
        assert!(result.lir.arena_size() > 0);
    }

    #[test]
    fn fusable_policy_leaves_room_for_passes() {
        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let mir = CompilePipeline::lower_hir(hir).expect("lower");
        let g = mir.as_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        assert_eq!(g.len(), 9);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_mir(mir);
        assert!(result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn lir_plan_includes_phases_io_and_fingerprint() {
        use rlx_ir::phase::Phase;

        let mut hir = HirModule::new("stream");
        let x = hir.input("x", f32_shape(&[1, 8]));
        let w = hir.param("w", f32_shape(&[8, 4]));
        let mm = hir.linear(x, w, None, None, f32_shape(&[1, 4]));
        hir.set_outputs(vec![mm]);

        let result = CompilePipeline::new(FusionTarget::Cpu)
            .compile_hir(hir)
            .expect("compile");
        assert!(!result.lir.buffers.phases.is_empty());
        // Check the input count before indexing so a regression fails with a clear
        // assertion message rather than an index-out-of-bounds panic.
        assert_eq!(result.lir.buffers.io.inputs.len(), 1);
        let input_id = result.lir.buffers.io.inputs[0].1;
        assert_eq!(
            result.lir.buffers.phases.get(input_id),
            Some(Phase::Prologue)
        );
        assert_eq!(result.lir.fingerprint(), result.lir.fingerprint());
        assert_eq!(result.lir.buffers.alignment, 64);
    }

    #[test]
    fn decode_hidden_shape_not_expanded_without_env() {
        // `fix_import_sequence_axis` deliberately widens [1, 1, D] nodes when
        // RLX_ONNX_SEQUENCE_LENGTH is set, so this test only applies when it is unset.
        if std::env::var_os("RLX_ONNX_SEQUENCE_LENGTH").is_some() {
            return;
        }
        let mut g = Graph::new("decode_out");
        let x = g.input("x", f32_shape(&[1, 1, 1024]));
        g.set_outputs(vec![x]);
        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        let out = result
            .lir
            .mir
            .as_graph()
            .node(result.lir.mir.as_graph().outputs[0]);
        assert_eq!(out.shape.dims()[1].unwrap_static(), 1);
        assert_eq!(out.shape.num_elements(), Some(1024));
    }

    #[test]
    fn dynamic_graph_compiles_and_specializes() {
        use rlx_ir::DimBinding;
        use rlx_ir::infer::GraphExt as _;
        use rlx_ir::sym;

        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        assert!(result.has_dynamic_dims());
        assert!(result.lir.buffers.dynamic_symbols.contains(&sym::SEQ));

        let bound = result.specialize(&pipe, &DimBinding::batch_seq(2, 16));
        assert!(bound.lir.is_fully_static());
        assert!(bound.lir.arena_size() > 0);
    }
}