1use rlx_ir::dynamic::collect_dynamic_symbols;
28use rlx_ir::hir::HirModule;
29use rlx_ir::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest, LirModule, LirViewAlias};
30use rlx_ir::mir::MirModule;
31use rlx_ir::phase::derive_phases;
32use rlx_ir::{Graph, GraphModule, GraphStage};
33
34use crate::DeadCodeElimination;
35use crate::debug_assert_graph;
36use crate::fusion_pipeline::{
37 FusionOptions, FusionTarget, fusion_limits_for_target, fusion_passes_for_supported,
38 supported_for_target,
39};
40use crate::legalize::{format_legalize_error, legalize_for_backend};
41use crate::memory::{self, MemoryPlan};
42use crate::rewrite::rewrite_for_backend_with_config;
43use rlx_fusion::fusion_report::FusionReport;
44use rlx_fusion::pass::run_passes;
45use rlx_fusion::{clip_elementwise_regions, with_fusion_limits};
46use rlx_ir::OpKind;
47use rlx_ir::logical_kernel::KernelDispatchConfig;
48
49#[derive(Debug, Clone)]
51pub struct CompileResult {
52 pub lir: LirModule,
53 pub fusion: FusionReport,
54}
55
56impl CompileResult {
57 pub fn has_dynamic_dims(&self) -> bool {
58 self.lir.has_dynamic_dims()
59 }
60
61 pub fn dynamic_symbols(&self) -> &[u32] {
62 self.lir.dynamic_symbols()
63 }
64
65 pub fn specialize(&self, pipeline: &CompilePipeline, binding: &rlx_ir::DimBinding) -> Self {
67 Self {
68 lir: pipeline.specialize_lir(&self.lir, binding),
69 fusion: self.fusion.clone(),
70 }
71 }
72}
73
74#[derive(Debug, Clone, Copy)]
76pub struct CompilePipeline {
77 pub target: FusionTarget,
78 pub opts: FusionOptions,
79 pub arena_alignment: usize,
80 pub assert_fusion_clean: bool,
83 pub supported_ops: Option<&'static [OpKind]>,
87 pub kernel_dispatch: KernelDispatchConfig,
89}
90
91impl Default for CompilePipeline {
92 fn default() -> Self {
93 Self {
94 target: FusionTarget::Cpu,
95 opts: FusionOptions::for_cpu(),
96 arena_alignment: 64,
97 assert_fusion_clean: false,
98 supported_ops: None,
99 kernel_dispatch: KernelDispatchConfig::from_env(),
100 }
101 }
102}
103
104impl CompilePipeline {
105 pub fn new(target: FusionTarget) -> Self {
106 let mut opts = match target {
107 FusionTarget::Cpu => FusionOptions::for_cpu(),
108 FusionTarget::Metal => FusionOptions::from_metal_env(),
109 _ => FusionOptions::default(),
110 };
111 opts.fusion_limits = fusion_limits_for_target(target);
112 Self {
113 target,
114 opts,
115 ..Self::default()
116 }
117 }
118
119 pub fn with_assert_fusion_clean(mut self, assert: bool) -> Self {
120 self.assert_fusion_clean = assert;
121 self
122 }
123
124 pub fn lower_hir(hir: HirModule) -> Result<MirModule, rlx_ir::hir::LowerError> {
126 let mir = hir.lower_to_mir()?;
127 debug_assert_graph!(mir.as_graph(), "hir→mir");
128 Ok(mir)
129 }
130
131 pub fn preprocess_mir(mir: MirModule) -> MirModule {
133 use rlx_fusion::pass::Pass as _;
134 let graph = rlx_fusion::control_flow::LowerControlFlow.run(mir.into_graph());
135 let graph = DeadCodeElimination.run(graph);
136 MirModule::from_graph(graph)
137 }
138
139 pub fn with_supported_ops(mut self, ops: &'static [OpKind]) -> Self {
140 self.supported_ops = Some(ops);
141 self
142 }
143
144 pub fn with_kernel_dispatch(
145 mut self,
146 policy: rlx_ir::logical_kernel::KernelDispatchPolicy,
147 ) -> Self {
148 self.kernel_dispatch.policy = policy;
149 self
150 }
151
152 pub fn with_kernel_dispatch_config(mut self, config: KernelDispatchConfig) -> Self {
153 self.kernel_dispatch = config;
154 self
155 }
156
157 fn effective_supported(&self) -> &'static [OpKind] {
158 self.supported_ops
159 .unwrap_or_else(|| supported_for_target(self.target))
160 }
161
162 fn backend_name(&self) -> &'static str {
163 match self.target {
164 FusionTarget::Cpu => "cpu",
165 FusionTarget::Metal => "metal",
166 FusionTarget::Mlx => "mlx",
167 FusionTarget::Wgpu => "wgpu",
168 FusionTarget::Cuda => "cuda",
169 FusionTarget::Rocm => "rocm",
170 FusionTarget::Tpu => "tpu",
171 }
172 }
173
174 pub fn optimize_with_report(&self, mir: MirModule) -> (MirModule, FusionReport) {
176 let before = mir.as_graph().clone();
177 let passes = fusion_passes_for_supported(self.effective_supported(), self.opts);
178 let limits = self.opts.fusion_limits;
179 let graph = with_fusion_limits(limits, || run_passes(mir.into_graph(), &passes, false));
180 let graph = clip_elementwise_regions(graph, limits);
181 debug_assert_graph!(&graph, "fusion");
182 let graph = self.legalize_after_fusion(graph);
183 debug_assert_graph!(&graph, "legalize");
184 let mir = MirModule::from_graph(graph);
185 let fusion = FusionReport::analyze(&before, mir.as_graph());
186 (mir, fusion)
187 }
188
189 pub(crate) fn legalize_after_fusion(&self, graph: Graph) -> Graph {
193 let Some(supported) = self.supported_ops else {
194 if self.kernel_dispatch.force_common_kinds.is_empty()
195 && self.kernel_dispatch.policy
196 == rlx_ir::logical_kernel::KernelDispatchPolicy::PreferNative
197 {
198 return graph;
199 }
200 return rewrite_for_backend_with_config(graph, &[], self.kernel_dispatch);
201 };
202 if supported.is_empty() {
203 return graph;
204 }
205 let graph = rewrite_for_backend_with_config(graph, supported, self.kernel_dispatch);
206 if let Err(errors) = legalize_for_backend(&graph, supported) {
207 panic!("{}", format_legalize_error(self.backend_name(), &errors));
208 }
209 graph
210 }
211
212 pub fn optimize(&self, mir: MirModule) -> MirModule {
214 self.optimize_with_report(mir).0
215 }
216
217 pub fn plan_lir(&self, mir: MirModule) -> LirModule {
219 self.plan_lir_with_options(mir, memory::MemoryPlanOptions::default())
220 }
221
222 pub fn plan_lir_with_options(
224 &self,
225 mir: MirModule,
226 opts: memory::MemoryPlanOptions,
227 ) -> LirModule {
228 let graph = mir.as_graph().clone();
229 let plan = memory::plan_memory_with_options(&graph, self.arena_alignment, opts);
230 LirModule::new(
231 mir,
232 lir_buffer_plan_from_memory(&graph, &plan, self.arena_alignment),
233 )
234 }
235
236 pub fn specialize_lir(&self, lir: &LirModule, binding: &rlx_ir::DimBinding) -> LirModule {
238 use rlx_ir::dynamic::{
239 bind_graph, sync_concat_shapes, sync_graph_shapes, sync_narrow_ops, sync_reshape_ops,
240 };
241 let mut bound = bind_graph(lir.as_graph(), binding);
242 sync_reshape_ops(&mut bound);
243 sync_concat_shapes(&mut bound);
244 sync_narrow_ops(&mut bound);
245 sync_graph_shapes(&mut bound);
246 debug_assert_graph!(&bound, "specialize");
247 self.plan_lir(MirModule::from_graph(bound))
248 }
249
250 fn finish(&self, mir: MirModule, fusion: FusionReport) -> CompileResult {
251 debug_assert_graph!(mir.as_graph(), "pre-lir");
252 if self.assert_fusion_clean && !fusion.missed.is_empty() {
253 panic!(
254 "fusion contract violated: {} missed patterns\n{fusion}",
255 fusion.missed.len()
256 );
257 }
258 CompileResult {
259 lir: self.plan_lir(mir),
260 fusion,
261 }
262 }
263
264 pub fn compile_hir(&self, hir: HirModule) -> Result<CompileResult, rlx_ir::hir::LowerError> {
266 if rlx_ir::env::var("RLX_IR_DUMP").is_some() {
267 let name = hir.name.clone();
268 let dump = crate::inspect::inspect_pipeline(self, hir.clone())?;
269 crate::inspect::maybe_dump_pipeline(&dump, &name);
270 }
271 let mir = Self::lower_hir(hir)?;
272 let (mir, fusion) = self.optimize_with_report(mir);
273 Ok(self.finish(mir, fusion))
274 }
275
276 pub fn compile_mir(&self, mir: MirModule) -> CompileResult {
278 let (mir, fusion) = self.optimize_with_report(mir);
279 self.finish(mir, fusion)
280 }
281
282 pub fn compile_graph(&self, graph: Graph) -> CompileResult {
284 self.compile_mir(MirModule::from_graph(graph))
285 }
286
287 pub fn compile_module(
289 &self,
290 module: GraphModule,
291 ) -> Result<CompileResult, rlx_ir::hir::LowerError> {
292 match module.stage() {
293 GraphStage::Hir => {
294 let hir = module
295 .into_hir()
296 .expect("GraphModule stage() / into_hir mismatch");
297 self.compile_hir(hir)
298 }
299 GraphStage::Mir => {
300 let mir = module.into_mir()?;
301 Ok(self.compile_mir(mir))
302 }
303 GraphStage::Lir => Ok(CompileResult {
304 lir: module
305 .into_lir()
306 .expect("GraphModule stage() / into_lir mismatch"),
307 fusion: FusionReport::default(),
308 }),
309 }
310 }
311}
312
313impl From<&MemoryPlan> for LirBufferPlan {
314 fn from(plan: &MemoryPlan) -> Self {
315 LirBufferPlan {
316 arena_size: plan.arena_size,
317 assignments: plan
318 .assignments
319 .iter()
320 .map(|(id, slot)| {
321 (
322 *id,
323 LirBufferSlot {
324 offset: slot.offset,
325 size: slot.size,
326 },
327 )
328 })
329 .collect(),
330 schedule: plan.schedule.clone(),
331 ..Default::default()
332 }
333 }
334}
335
336impl From<&LirBufferPlan> for MemoryPlan {
337 fn from(plan: &LirBufferPlan) -> Self {
338 MemoryPlan {
339 arena_size: plan.arena_size,
340 assignments: plan
341 .assignments
342 .iter()
343 .map(|(id, slot)| {
344 (
345 *id,
346 memory::BufferSlot {
347 offset: slot.offset,
348 size: slot.size,
349 },
350 )
351 })
352 .collect(),
353 schedule: plan.schedule.clone(),
354 }
355 }
356}
357
358pub(crate) fn lir_buffer_plan_from_memory(
359 graph: &Graph,
360 plan: &MemoryPlan,
361 alignment: usize,
362) -> LirBufferPlan {
363 let view_aliases = memory::collect_view_aliases(graph)
364 .into_iter()
365 .map(|(id, (root, byte_offset))| (id, LirViewAlias { root, byte_offset }))
366 .collect();
367 LirBufferPlan {
368 arena_size: plan.arena_size,
369 assignments: plan
370 .assignments
371 .iter()
372 .map(|(id, slot)| {
373 (
374 *id,
375 LirBufferSlot {
376 offset: slot.offset,
377 size: slot.size,
378 },
379 )
380 })
381 .collect(),
382 schedule: plan.schedule.clone(),
383 view_aliases,
384 phases: derive_phases(graph),
385 io: LirIoManifest::collect(graph),
386 alignment,
387 dynamic_symbols: collect_dynamic_symbols(graph),
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::DType;
    use rlx_ir::Op;
    use rlx_ir::Shape;
    use rlx_ir::hir::FusionPolicy;

    /// Shorthand for an `F32` shape with the given dims.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    #[test]
    fn pipeline_hir_to_lir() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let b = hir.param("b", f32_shape(&[128]));
        let h = hir.linear(x, w, Some(b), None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        assert!(result.lir.mir.len() <= 5);
        assert!(result.lir.arena_size() > 0);
        assert!(result.lir.buffers.bytes_saved() <= result.lir.buffers.total_unshared_bytes());
        // NOTE(review): the right-hand side of this `||` is already asserted
        // above, so this check cannot fail as written. Consider tightening it
        // to `fused_matmul_bias_act >= 1` if that fusion is expected here —
        // TODO confirm.
        assert!(result.fusion.fused_matmul_bias_act >= 1 || result.lir.mir.len() <= 5);
    }

    #[test]
    fn direct_hir_swiglu_emits_fused_op() {
        let mut hir = HirModule::new("ffn");
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_hir(hir).expect("compile");
        let g = result.lir.mir.as_graph();
        assert!(
            g.nodes()
                .iter()
                .any(|n| matches!(n.op, Op::FusedSwiGLU { .. })),
            "direct HIR SwiGLU should lower to FusedSwiGLU"
        );
        assert!(result.fusion.missed_matmul_bias_act() == 0 || result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn compile_module_from_graph_define() {
        let module = GraphModule::define("ffn", |m| {
            let x = m.input("x", f32_shape(&[2, 64]));
            let w = m.param("w", f32_shape(&[64, 64]));
            m.linear(x, w, None, None, f32_shape(&[2, 64]))
        });
        assert_eq!(module.stage(), GraphStage::Hir);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_module(module).expect("compile_module");
        assert!(result.lir.arena_size() > 0);
    }

    #[test]
    fn fusable_policy_leaves_room_for_passes() {
        let mut hir = HirModule::new("ffn").with_fusion_policy(FusionPolicy::Fusable);
        let x = hir.input("x", f32_shape(&[4, 768]));
        let up_w = hir.param("up", f32_shape(&[768, 2048]));
        let gate_w = hir.param("gate", f32_shape(&[768, 2048]));
        let down_w = hir.param("down", f32_shape(&[2048, 768]));
        let out = hir.swiglu_ffn(x, up_w, gate_w, down_w, f32_shape(&[4, 768]));
        hir.outputs = vec![out];

        let mir = CompilePipeline::lower_hir(hir).expect("lower");
        let g = mir.as_graph();
        assert!(g.nodes().iter().any(|n| matches!(n.op, Op::MatMul)));
        assert_eq!(g.len(), 9);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_mir(mir);
        assert!(result.fusion.fused_swiglu >= 1);
    }

    #[test]
    fn lir_plan_includes_phases_io_and_fingerprint() {
        use rlx_ir::phase::Phase;

        let mut hir = HirModule::new("stream");
        let x = hir.input("x", f32_shape(&[1, 8]));
        let w = hir.param("w", f32_shape(&[8, 4]));
        let mm = hir.linear(x, w, None, None, f32_shape(&[1, 4]));
        hir.set_outputs(vec![mm]);

        let result = CompilePipeline::new(FusionTarget::Cpu)
            .compile_hir(hir)
            .expect("compile");
        assert!(!result.lir.buffers.phases.is_empty());
        // Check the input count before indexing, so a missing input fails
        // with an assertion message rather than an index-out-of-bounds panic.
        assert_eq!(result.lir.buffers.io.inputs.len(), 1);
        let input_id = result.lir.buffers.io.inputs[0].1;
        assert_eq!(
            result.lir.buffers.phases.get(input_id),
            Some(Phase::Prologue)
        );
        // Only checks that fingerprinting the same module instance is stable.
        assert_eq!(result.lir.fingerprint(), result.lir.fingerprint());
        assert_eq!(result.lir.buffers.alignment, 64);
    }

    #[test]
    fn dynamic_graph_compiles_and_specializes() {
        use rlx_ir::DimBinding;
        use rlx_ir::infer::GraphExt as _;
        use rlx_ir::sym;

        let mut g = Graph::new("dyn");
        let x = g.input("x", Shape::batch_seq_2d(sym::BATCH, sym::SEQ, DType::F32));
        let w = g.param("w", Shape::new(&[4, 8], DType::F32));
        let y = g.mm(x, w);
        g.set_outputs(vec![y]);

        let pipe = CompilePipeline::new(FusionTarget::Cpu);
        let result = pipe.compile_graph(g);
        assert!(result.has_dynamic_dims());
        assert!(result.lir.buffers.dynamic_symbols.contains(&sym::SEQ));

        let bound = result.specialize(&pipe, &DimBinding::batch_seq(2, 16));
        assert!(bound.lir.is_fully_static());
        assert!(bound.lir.arena_size() > 0);
    }
}