Skip to main content

rlx_ir/
inspect.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Text exporters for inspecting HIR / MIR / LIR during lowering.
//!
//! Use [`inspect_hir`], [`inspect_mir`], and [`inspect_lir`] to dump each
//! pipeline stage as human-readable text (similar to LLVM's `-print-*`
//! flags). [`inspect_graph`] exposes the MIR body formatter
//! ([`pretty_print`]) that both the MIR and LIR dumps embed.
22
23use std::collections::BTreeMap;
24use std::fmt::Write as _;
25
26use crate::hir::{HirModule, HirNode, HirOp};
27use crate::lir::{LirBufferPlan, LirModule, LirViewAlias};
28use crate::mir::MirModule;
29use crate::phase::Phase;
30use crate::pretty::{header_line, op_kinds_line, pretty_print};
31use crate::{Graph, NodeId};
32
33/// Annotated HIR module dump.
34pub fn inspect_hir(hir: &HirModule) -> String {
35    let mut out = String::new();
36    writeln!(
37        out,
38        "hir @{} ({} nodes, {} outputs, fusion={:?})",
39        hir.name,
40        hir.len(),
41        hir.outputs.len(),
42        hir.fusion_policy,
43    )
44    .unwrap();
45    writeln!(out, "{}", hir_op_kinds_line(hir)).unwrap();
46    writeln!(out).unwrap();
47
48    let mut tag_w = 0usize;
49    for node in hir.nodes() {
50        let t = hir_node_tag(node);
51        tag_w = tag_w.max(t.len());
52    }
53
54    for node in hir.nodes() {
55        let tag = hir_node_tag(node);
56        write!(out, "  {tag:<width$} = ", width = tag_w).unwrap();
57        write!(out, "{}", format_hir_op(&node.op)).unwrap();
58        if !node.inputs.is_empty() {
59            write!(out, "(").unwrap();
60            for (i, inp) in node.inputs.iter().enumerate() {
61                if i > 0 {
62                    write!(out, ", ").unwrap();
63                }
64                write!(out, "{inp}").unwrap();
65            }
66            write!(out, ")").unwrap();
67        }
68        write!(out, " : {}", node.shape).unwrap();
69        if hir.outputs.contains(&node.id) {
70            write!(out, "  ← output").unwrap();
71        }
72        writeln!(out).unwrap();
73    }
74    if !hir.outputs.is_empty() {
75        write!(out, "  return ").unwrap();
76        for (i, o) in hir.outputs.iter().enumerate() {
77            if i > 0 {
78                write!(out, ", ").unwrap();
79            }
80            write!(out, "{o}").unwrap();
81        }
82        writeln!(out).unwrap();
83    }
84    out
85}
86
87/// Annotated MIR module dump (optimized tensor DAG).
88pub fn inspect_mir(mir: &MirModule) -> String {
89    inspect_mir_with_diff(mir, None)
90}
91
92/// MIR dump with optional fusion diff against a pre-optimize snapshot.
93pub fn inspect_mir_with_diff(mir: &MirModule, before: Option<&MirModule>) -> String {
94    let g = mir.as_graph();
95    let mut out = String::new();
96    writeln!(out, "mir @{} {{", mir.name()).unwrap();
97    if let Some(b) = before {
98        writeln!(out).unwrap();
99        out.push_str(&inspect_graph_diff(b.as_graph(), g));
100        writeln!(out).unwrap();
101        writeln!(out, "--- graph ---").unwrap();
102    }
103    writeln!(out).unwrap();
104    out.push_str(&pretty_print(g));
105    if !out.ends_with('\n') {
106        out.push('\n');
107    }
108    write!(out, "}}").unwrap();
109    out
110}
111
112/// Diff two MIR snapshots (typically pre/post fusion).
113pub fn inspect_mir_diff(before: &MirModule, after: &MirModule) -> String {
114    inspect_graph_diff(before.as_graph(), after.as_graph())
115}
116
117/// Summarize graph changes between pipeline stages.
118pub fn inspect_graph_diff(before: &Graph, after: &Graph) -> String {
119    use std::collections::BTreeMap;
120
121    let mut out = String::new();
122    writeln!(
123        out,
124        "  diff: {} → {} nodes ({} → {} outputs)",
125        before.len(),
126        after.len(),
127        before.outputs.len(),
128        after.outputs.len(),
129    )
130    .unwrap();
131
132    let count_kinds = |g: &Graph| {
133        let mut h: BTreeMap<String, i32> = BTreeMap::new();
134        for n in g.nodes() {
135            *h.entry(format!("{:?}", n.op.kind())).or_insert(0) += 1;
136        }
137        h
138    };
139    let b = count_kinds(before);
140    let a = count_kinds(after);
141    let mut keys: Vec<String> = b.keys().chain(a.keys()).cloned().collect();
142    keys.sort();
143    keys.dedup();
144    let mut changes = Vec::new();
145    for k in keys {
146        let d = a.get(&k).copied().unwrap_or(0) - b.get(&k).copied().unwrap_or(0);
147        if d != 0 {
148            changes.push(format!("{k}{d:+}"));
149        }
150    }
151    if !changes.is_empty() {
152        writeln!(out, "  op delta: {}", changes.join(", ")).unwrap();
153    }
154    out
155}
156
157/// Annotated LIR dump: optimized MIR + buffer plan + schedule.
158pub fn inspect_lir(lir: &LirModule) -> String {
159    let mut out = String::new();
160    writeln!(out, "lir @{} {{", lir.name()).unwrap();
161    writeln!(out, "  fingerprint: {:016x}", lir.fingerprint().0).unwrap();
162    writeln!(out).unwrap();
163    out.push_str(&inspect_buffer_plan(&lir.buffers));
164    if !lir.buffers.phases.is_empty() {
165        writeln!(out).unwrap();
166        out.push_str(&inspect_phases(&lir.buffers));
167    }
168    if !lir.buffers.io.inputs.is_empty() || !lir.buffers.io.params.is_empty() {
169        writeln!(out).unwrap();
170        out.push_str(&inspect_io_manifest(&lir.buffers));
171    }
172    writeln!(out).unwrap();
173    writeln!(out, "--- mir ---").unwrap();
174    out.push_str(&pretty_print(lir.as_graph()));
175    if !out.ends_with('\n') {
176        out.push('\n');
177    }
178    write!(out, "}}").unwrap();
179    out
180}
181
182/// Annotated graph dump (MIR body). Alias for [`pretty_print`].
183pub fn inspect_graph(g: &Graph) -> String {
184    pretty_print(g)
185}
186
187/// One-line HIR summary (header + op histogram).
188pub fn inspect_hir_stats(hir: &HirModule) -> String {
189    format!(
190        "hir @{} ({} nodes, {} outputs, fusion={:?})\n{}",
191        hir.name,
192        hir.len(),
193        hir.outputs.len(),
194        hir.fusion_policy,
195        hir_op_kinds_line(hir),
196    )
197}
198
199/// One-line MIR summary.
200pub fn inspect_mir_stats(mir: &MirModule) -> String {
201    let g = mir.as_graph();
202    format!(
203        "mir @{} — {}\n{}",
204        mir.name(),
205        header_line(g),
206        op_kinds_line(g),
207    )
208}
209
210/// Buffer plan section for LIR inspection.
211pub fn inspect_buffer_plan(plan: &LirBufferPlan) -> String {
212    let mut out = String::new();
213    let saved = plan.bytes_saved();
214    let naive = plan.total_unshared_bytes();
215    writeln!(
216        out,
217        "  arena: {} bytes (saved {} vs {} naive, align={})",
218        plan.arena_size, saved, naive, plan.alignment,
219    )
220    .unwrap();
221    writeln!(
222        out,
223        "  schedule: {} nodes, {} views",
224        plan.schedule.len(),
225        plan.view_aliases.len(),
226    )
227    .unwrap();
228    if !plan.dynamic_symbols.is_empty() {
229        let syms: Vec<String> = plan
230            .dynamic_symbols
231            .iter()
232            .map(|s| format!("?{s}"))
233            .collect();
234        writeln!(out, "  dynamic: {}", syms.join(", ")).unwrap();
235    }
236    writeln!(out).unwrap();
237    writeln!(out, "  # offset\tsize\tnode").unwrap();
238
239    let mut rows: Vec<(usize, usize, NodeId)> = plan
240        .assignments
241        .iter()
242        .map(|(id, slot)| (slot.offset, slot.size, *id))
243        .collect();
244    rows.sort_by_key(|(off, _, _)| *off);
245    for (off, sz, id) in rows {
246        let sched = plan
247            .schedule
248            .iter()
249            .position(|&n| n == id)
250            .map(|i| format!(" sched={i}"))
251            .unwrap_or_default();
252        let view = plan
253            .view_aliases
254            .get(&id)
255            .map(|LirViewAlias { root, byte_offset }| format!(" view→{root}+{byte_offset}"))
256            .unwrap_or_default();
257        let phase = plan
258            .phases
259            .get(id)
260            .map(|p| format!(" {p:?}"))
261            .unwrap_or_default();
262        writeln!(out, "  {off}\t{sz}\t{id}{sched}{view}{phase}").unwrap();
263    }
264    out
265}
266
267fn inspect_phases(plan: &LirBufferPlan) -> String {
268    let mut out = String::from("  phases:\n");
269    for phase in [Phase::Prologue, Phase::SteadyState, Phase::Epilogue] {
270        let nodes = plan.nodes_in_phase(phase);
271        if !nodes.is_empty() {
272            writeln!(out, "    {phase:?}: {nodes:?}").unwrap();
273        }
274    }
275    out
276}
277
278fn inspect_io_manifest(plan: &LirBufferPlan) -> String {
279    let mut out = String::from("  io:\n");
280    for (name, id) in &plan.io.inputs {
281        writeln!(out, "    input \"{name}\" → {id}").unwrap();
282    }
283    for (name, id) in &plan.io.params {
284        writeln!(out, "    param \"{name}\" → {id}").unwrap();
285    }
286    if !plan.io.outputs.is_empty() {
287        write!(out, "    outputs: {:?}", plan.io.outputs).unwrap();
288        out.push('\n');
289    }
290    out
291}
292
293fn hir_op_kinds_line(hir: &HirModule) -> String {
294    let mut hist: BTreeMap<String, usize> = BTreeMap::new();
295    for node in hir.nodes() {
296        *hist.entry(hir_op_kind(&node.op)).or_insert(0) += 1;
297    }
298    let parts: Vec<String> = hist.into_iter().map(|(k, c)| format!("{k}={c}")).collect();
299    format!("  block ops: {}", parts.join(", "))
300}
301
302fn hir_op_kind(op: &HirOp) -> String {
303    match op {
304        HirOp::Input { .. } => "Input".into(),
305        HirOp::Param { .. } => "Param".into(),
306        HirOp::Constant { .. } => "Constant".into(),
307        HirOp::Linear { .. } => "Linear".into(),
308        HirOp::LinearFused { .. } => "LinearFused".into(),
309        HirOp::SharedLinearPair { .. } => "SharedLinearPair".into(),
310        HirOp::SwiGLU => "SwiGLU".into(),
311        HirOp::ResidualRmsNorm { .. } => "ResidualRmsNorm".into(),
312        HirOp::Attention { .. } => "Attention".into(),
313        HirOp::DepthwiseConv1dCausal { .. } => "DepthwiseConv1dCausal".into(),
314        HirOp::DequantMatMul { .. } => "DequantMatMul".into(),
315        HirOp::GatedDeltaNet { .. } => "GatedDeltaNet".into(),
316        HirOp::RoPE { .. } => "RoPE".into(),
317        HirOp::RmsNorm { .. } => "RmsNorm".into(),
318        HirOp::Mir(_) => "Mir".into(),
319        HirOp::LlamaDecoderBlock { .. } => "LlamaDecoderBlock".into(),
320        HirOp::Qwen35MtpHead { .. } => "Qwen35MtpHead".into(),
321    }
322}
323
324fn hir_node_tag(node: &HirNode) -> String {
325    let label: Option<String> = match &node.op {
326        HirOp::Input { name } => Some(format!("input \"{name}\"")),
327        HirOp::Param { name } => Some(format!("param \"{name}\"")),
328        _ => node.name.as_deref().map(|s| format!("\"{s}\"")),
329    };
330    match label {
331        Some(s) => format!("{} [{s}]", node.id),
332        None => format!("{}", node.id),
333    }
334}
335
336fn format_hir_op(op: &HirOp) -> String {
337    match op {
338        HirOp::Input { name } => format!("input(\"{name}\")"),
339        HirOp::Param { name } => format!("param(\"{name}\")"),
340        HirOp::Constant { data } => format!("constant({} bytes)", data.len()),
341        HirOp::Linear {
342            activation,
343            has_bias,
344        } => {
345            let mut s = String::from("linear");
346            if *has_bias {
347                s.push_str("+bias");
348            }
349            if let Some(act) = activation {
350                write!(s, "+{act:?}").unwrap();
351            }
352            s
353        }
354        HirOp::LinearFused { activation } => match activation {
355            Some(act) => format!("linear_fused({act:?})"),
356            None => "linear_fused".into(),
357        },
358        HirOp::SharedLinearPair { slot } => format!("shared_linear_pair(out={slot})"),
359        HirOp::SwiGLU => "swiglu_ffn".into(),
360        HirOp::ResidualRmsNorm { eps } => format!("residual_rms_norm(eps={eps})"),
361        HirOp::Attention {
362            num_heads,
363            head_dim,
364            mask,
365        } => format!("attention(heads={num_heads}, dim={head_dim}, mask={mask:?})"),
366        HirOp::DepthwiseConv1dCausal { kernel_size } => {
367            format!("depthwise_conv1d_causal(k={kernel_size})")
368        }
369        HirOp::DequantMatMul { scheme } => format!("dequant_matmul({scheme})"),
370        HirOp::GatedDeltaNet {
371            state_size,
372            carry_state,
373        } => {
374            if *carry_state {
375                format!("gated_delta_net(n={state_size},carry)")
376            } else {
377                format!("gated_delta_net(n={state_size})")
378            }
379        }
380        HirOp::RoPE { head_dim, n_rot } => format!("rope(d={head_dim}, n_rot={n_rot})"),
381        HirOp::RmsNorm { eps } => format!("rms_norm(eps={eps})"),
382        HirOp::LlamaDecoderBlock {
383            num_heads,
384            head_dim,
385            num_kv_heads,
386            eps,
387            mask,
388        } => format!(
389            "llama_decoder_block(heads={num_heads}, dim={head_dim}, kv={num_kv_heads}, eps={eps}, mask={mask:?})"
390        ),
391        HirOp::Qwen35MtpHead {
392            num_heads,
393            head_dim,
394            mtp_vocab,
395            ..
396        } => format!("qwen35_mtp_head(heads={num_heads}, dim={head_dim}, vocab={mtp_vocab})"),
397        HirOp::Mir(inner) => format!("mir({inner})"),
398    }
399}
400
401// ── convenience methods on pipeline types ───────────────────────────────
402
403impl HirModule {
404    /// Text dump for inspection. Alias for [`inspect_hir`].
405    pub fn inspect(&self) -> String {
406        inspect_hir(self)
407    }
408}
409
410impl MirModule {
411    /// Text dump for inspection. Alias for [`inspect_mir`].
412    pub fn inspect(&self) -> String {
413        inspect_mir(self)
414    }
415}
416
417impl LirModule {
418    /// Text dump for inspection. Alias for [`inspect_lir`].
419    pub fn inspect(&self) -> String {
420        inspect_lir(self)
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DType, Shape};

    /// Builds an `F32` shape with the given dimensions.
    fn f32_shape(dims: &[usize]) -> Shape {
        Shape::new(dims, DType::F32)
    }

    /// Asserts that every needle occurs in `text`, printing the whole dump on failure.
    fn assert_contains_all(text: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(
                text.contains(needle),
                "expected {needle:?} in dump:\n{text}"
            );
        }
    }

    #[test]
    fn inspect_hir_includes_blocks_and_outputs() {
        let mut module = HirModule::new("layer");
        let x = module.input("x", f32_shape(&[2, 128]));
        let w = module.param("w", f32_shape(&[128, 128]));
        let y = module.linear(x, w, None, None, f32_shape(&[2, 128]));
        module.outputs = vec![y];

        let dump = inspect_hir(&module);
        assert_contains_all(
            &dump,
            &["hir @layer", "linear", "← output", "fusion=Direct"],
        );
    }

    #[test]
    fn inspect_mir_wraps_pretty_print() {
        let mut module = HirModule::new("m");
        let x = module.input("x", f32_shape(&[4]));
        module.outputs = vec![x];
        let mir = module.lower_to_mir().expect("lower");

        let dump = inspect_mir(&mir);
        assert_contains_all(&dump, &["mir @m", "graph @m", "input(\"x\")"]);
    }

    #[test]
    fn named_block_appears_in_hir_dump() {
        let mut module = HirModule::new("layer");
        let x = module.input("x", f32_shape(&[2, 8]));
        let w = module.param("w", f32_shape(&[8, 8]));
        let ffn = module.named("layer0.ffn", |h| {
            h.linear(x, w, None, None, f32_shape(&[2, 8]))
        });
        module.outputs = vec![ffn];

        let dump = inspect_hir(&module);
        assert_contains_all(&dump, &["layer0.ffn"]);
    }

    #[test]
    fn provenance_survives_lower() {
        let mut module = HirModule::new("m");
        let x = module.input("x", f32_shape(&[2, 8]));
        let w = module.param("w", f32_shape(&[8, 8]));
        let blk = module.named("block", |h| h.linear(x, w, None, None, f32_shape(&[2, 8])));
        module.outputs = vec![blk];

        let mir = module.lower_to_mir().expect("lower");
        let dump = inspect_mir(&mir);
        assert_contains_all(&dump, &["hir=h", "block"]);
    }

    #[test]
    fn inspect_lir_includes_buffer_plan() {
        use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

        let mut module = HirModule::new("l");
        let x = module.input("x", f32_shape(&[4]));
        module.outputs = vec![x];
        let mir = module.lower_to_mir().expect("lower");

        let slot = LirBufferSlot {
            offset: 0,
            size: 16,
        };
        let plan = LirBufferPlan {
            arena_size: 16,
            assignments: [(NodeId(0), slot)].into_iter().collect(),
            schedule: vec![NodeId(0)],
            io: LirIoManifest {
                inputs: vec![("x".into(), NodeId(0))],
                ..Default::default()
            },
            ..Default::default()
        };
        let lir = LirModule::new(mir, plan);

        let dump = inspect_lir(&lir);
        assert_contains_all(
            &dump,
            &["lir @l", "arena: 16 bytes", "fingerprint:", "--- mir ---"],
        );
    }
}