Skip to main content

rlx_compile/
inspect.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pipeline inspection — dump every HIR / MIR / LIR stage as text.
17
18use std::fmt::Write as _;
19
20use rlx_ir::hir::{HirModule, LowerError};
21use rlx_ir::{inspect_graph_diff, inspect_hir, inspect_lir, inspect_mir, inspect_mir_stats};
22
23use crate::compiler::{CompilePipeline, CompileResult};
24use rlx_fusion::fusion_report::FusionReport;
25
/// Text dump of each compiler pipeline stage.
///
/// Built by [`inspect_pipeline`] (every stage populated) or by
/// [`inspect_compiled`] (`mir_lowered` and `mir_diff` left empty). The
/// [`Display`](std::fmt::Display) impl renders the stages one after another,
/// each under its own banner.
///
/// All fields are plain text, so the type is cheap to default-construct and
/// can be compared directly (e.g. to check that two compiles dump identically).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PipelineInspect {
    /// Text of the input HIR module.
    pub hir: String,
    /// MIR immediately after lowering from HIR, before optimization.
    pub mir_lowered: String,
    /// Graph diff between the lowered MIR and the optimized MIR.
    pub mir_diff: String,
    /// Optimized MIR, as carried by the planned LIR program.
    pub mir_optimized: String,
    /// LIR produced by planning.
    pub lir: String,
    /// Fusion report; when built by `inspect_pipeline` it is followed by
    /// statistics on the optimized MIR.
    pub fusion: String,
}
36
37impl std::fmt::Display for PipelineInspect {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        writeln_section(f, "HIR", &self.hir)?;
40        writeln_section(f, "MIR (lowered)", &self.mir_lowered)?;
41        if !self.mir_diff.is_empty() {
42            writeln_section(f, "MIR (fusion diff)", &self.mir_diff)?;
43        }
44        writeln_section(f, "MIR (optimized)", &self.mir_optimized)?;
45        writeln_section(f, "FUSION", &self.fusion)?;
46        writeln_section(f, "LIR", &self.lir)
47    }
48}
49
50fn writeln_section(f: &mut std::fmt::Formatter<'_>, title: &str, body: &str) -> std::fmt::Result {
51    let mut header = String::new();
52    banner(&mut header, title);
53    write!(f, "{header}{body}")?;
54    if !body.ends_with('\n') {
55        writeln!(f)?;
56    }
57    Ok(())
58}
59
60/// Inspect every lowering stage for `hir` through `pipeline`.
61pub fn inspect_pipeline(
62    pipeline: &CompilePipeline,
63    hir: HirModule,
64) -> Result<PipelineInspect, LowerError> {
65    let hir_text = inspect_hir(&hir);
66    let mir_raw = CompilePipeline::lower_hir(hir)?;
67    let mir_lowered = inspect_mir(&mir_raw);
68    let mir_before = mir_raw.clone();
69    let (mir_opt, fusion) = pipeline.optimize_with_report(mir_raw);
70    let mir_diff = inspect_graph_diff(mir_before.as_graph(), mir_opt.as_graph());
71    let fusion_text = format!("{}\n{}", fusion, inspect_mir_stats(&mir_opt));
72    let lir = pipeline.plan_lir(mir_opt);
73    Ok(PipelineInspect {
74        hir: hir_text,
75        mir_lowered,
76        mir_diff,
77        mir_optimized: inspect_mir(&lir.mir),
78        lir: inspect_lir(&lir),
79        fusion: fusion_text,
80    })
81}
82
83/// Inspect a completed [`CompileResult`] plus the original HIR text.
84pub fn inspect_compiled(hir_text: &str, result: &CompileResult) -> PipelineInspect {
85    PipelineInspect {
86        hir: hir_text.to_string(),
87        mir_lowered: String::new(),
88        mir_diff: String::new(),
89        mir_optimized: inspect_mir(&result.lir.mir),
90        lir: inspect_lir(&result.lir),
91        fusion: format!("{}", result.fusion),
92    }
93}
94
95/// Write a full pipeline dump when `RLX_IR_DUMP` is set (path prefix or directory).
96pub fn maybe_dump_pipeline(dump: &PipelineInspect, module_name: &str) {
97    let Some(path) = rlx_ir::env::var("RLX_IR_DUMP") else {
98        return;
99    };
100    let target = if path.ends_with('/') || path.ends_with('\\') {
101        format!("{path}{module_name}.ir.txt")
102    } else {
103        path
104    };
105    if let Err(e) = std::fs::write(&target, dump.to_string()) {
106        eprintln!("[rlx] RLX_IR_DUMP write failed ({target}): {e}");
107    } else {
108        eprintln!("[rlx] wrote IR dump to {target}");
109    }
110}
111
112/// Fusion report only (post-optimize diagnostics).
113pub fn inspect_fusion(report: &FusionReport) -> String {
114    format!("{report}")
115}
116
/// Append a three-line banner for `title` to `out`, one `\n` after each line.
///
/// The middle line is `══ {title} ══`. The rules above and below are drawn
/// with `═` to the same `char` count as the middle line, so the box lines up
/// for non-ASCII titles too. The original sized them from the title's byte
/// length plus 4, which always left them 2 characters short.
fn banner(out: &mut String, title: &str) {
    let middle = format!("══ {title} ══");
    let rule = "═".repeat(middle.chars().count());
    for line in [rule.as_str(), middle.as_str(), rule.as_str()] {
        out.push_str(line);
        out.push('\n');
    }
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::{DType, Shape};

    /// Shorthand for an `F32` shape with the given dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// Builds a one-layer `probe` module, inspects it through the default
    /// pipeline, and checks that every stage is populated and that the
    /// rendered dump carries the stage banners.
    #[test]
    fn inspect_pipeline_covers_all_stages() {
        let mut module = HirModule::new("probe");
        let input = module.input("x", f32_shape(&[2, 64]));
        let weight = module.param("w", f32_shape(&[64, 64]));
        let out = module.linear(input, weight, None, None, f32_shape(&[2, 64]));
        module.outputs = vec![out];

        let pipeline = CompilePipeline::default();
        let dump = inspect_pipeline(&pipeline, module).expect("inspect");

        let stage_checks = [
            (&dump.hir, "hir @probe"),
            (&dump.mir_lowered, "mir @probe"),
            (&dump.mir_optimized, "mir @probe"),
            (&dump.lir, "lir @probe"),
            (&dump.fusion, "nodes="),
        ];
        for (text, needle) in stage_checks {
            assert!(text.contains(needle), "missing {needle:?} in:\n{text}");
        }

        let rendered = dump.to_string();
        for banner_line in ["══ HIR ══", "══ LIR ══"] {
            assert!(rendered.contains(banner_line), "missing banner {banner_line:?}");
        }
    }
}
153}