Skip to main content

rlx_ir/
inspect.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Text exporters for inspecting HIR / MIR / LIR during lowering.
17//!
18//! Use [`inspect_hir`], [`inspect_mir`], and [`inspect_lir`] to dump
19//! each pipeline stage as human-readable text (similar to LLVM `-print-*`
20//! flags). [`inspect_graph`] is the MIR body formatter shared by MIR and
21//! LIR dumps.
22
23use std::collections::BTreeMap;
24use std::fmt::Write as _;
25
26use crate::hir::{HirModule, HirNode, HirOp};
27use crate::lir::{LirBufferPlan, LirModule, LirViewAlias};
28use crate::mir::MirModule;
29use crate::phase::Phase;
30use crate::pretty::{header_line, op_kinds_line, pretty_print};
31use crate::{Graph, NodeId};
32
33/// Annotated HIR module dump.
34pub fn inspect_hir(hir: &HirModule) -> String {
35    let mut out = String::new();
36    writeln!(
37        out,
38        "hir @{} ({} nodes, {} outputs, fusion={:?})",
39        hir.name,
40        hir.len(),
41        hir.outputs.len(),
42        hir.fusion_policy,
43    )
44    .unwrap();
45    writeln!(out, "{}", hir_op_kinds_line(hir)).unwrap();
46    writeln!(out).unwrap();
47
48    let mut tag_w = 0usize;
49    for node in hir.nodes() {
50        let t = hir_node_tag(node);
51        tag_w = tag_w.max(t.len());
52    }
53
54    for node in hir.nodes() {
55        let tag = hir_node_tag(node);
56        write!(out, "  {tag:<width$} = ", width = tag_w).unwrap();
57        write!(out, "{}", format_hir_op(&node.op)).unwrap();
58        if !node.inputs.is_empty() {
59            write!(out, "(").unwrap();
60            for (i, inp) in node.inputs.iter().enumerate() {
61                if i > 0 {
62                    write!(out, ", ").unwrap();
63                }
64                write!(out, "{inp}").unwrap();
65            }
66            write!(out, ")").unwrap();
67        }
68        write!(out, " : {}", node.shape).unwrap();
69        if hir.outputs.contains(&node.id) {
70            write!(out, "  ← output").unwrap();
71        }
72        writeln!(out).unwrap();
73    }
74    if !hir.outputs.is_empty() {
75        write!(out, "  return ").unwrap();
76        for (i, o) in hir.outputs.iter().enumerate() {
77            if i > 0 {
78                write!(out, ", ").unwrap();
79            }
80            write!(out, "{o}").unwrap();
81        }
82        writeln!(out).unwrap();
83    }
84    out
85}
86
87/// Annotated MIR module dump (optimized tensor DAG).
88pub fn inspect_mir(mir: &MirModule) -> String {
89    inspect_mir_with_diff(mir, None)
90}
91
92/// MIR dump with optional fusion diff against a pre-optimize snapshot.
93pub fn inspect_mir_with_diff(mir: &MirModule, before: Option<&MirModule>) -> String {
94    let g = mir.as_graph();
95    let mut out = String::new();
96    writeln!(out, "mir @{} {{", mir.name()).unwrap();
97    if let Some(b) = before {
98        writeln!(out).unwrap();
99        out.push_str(&inspect_graph_diff(b.as_graph(), g));
100        writeln!(out).unwrap();
101        writeln!(out, "--- graph ---").unwrap();
102    }
103    writeln!(out).unwrap();
104    out.push_str(&pretty_print(g));
105    if !out.ends_with('\n') {
106        out.push('\n');
107    }
108    write!(out, "}}").unwrap();
109    out
110}
111
112/// Diff two MIR snapshots (typically pre/post fusion).
113pub fn inspect_mir_diff(before: &MirModule, after: &MirModule) -> String {
114    inspect_graph_diff(before.as_graph(), after.as_graph())
115}
116
117/// Summarize graph changes between pipeline stages.
118pub fn inspect_graph_diff(before: &Graph, after: &Graph) -> String {
119    use std::collections::BTreeMap;
120
121    let mut out = String::new();
122    writeln!(
123        out,
124        "  diff: {} → {} nodes ({} → {} outputs)",
125        before.len(),
126        after.len(),
127        before.outputs.len(),
128        after.outputs.len(),
129    )
130    .unwrap();
131
132    let count_kinds = |g: &Graph| {
133        let mut h: BTreeMap<String, i32> = BTreeMap::new();
134        for n in g.nodes() {
135            *h.entry(format!("{:?}", n.op.kind())).or_insert(0) += 1;
136        }
137        h
138    };
139    let b = count_kinds(before);
140    let a = count_kinds(after);
141    let mut keys: Vec<String> = b.keys().chain(a.keys()).cloned().collect();
142    keys.sort();
143    keys.dedup();
144    let mut changes = Vec::new();
145    for k in keys {
146        let d = a.get(&k).copied().unwrap_or(0) - b.get(&k).copied().unwrap_or(0);
147        if d != 0 {
148            changes.push(format!("{k}{d:+}"));
149        }
150    }
151    if !changes.is_empty() {
152        writeln!(out, "  op delta: {}", changes.join(", ")).unwrap();
153    }
154    out
155}
156
157/// Annotated LIR dump: optimized MIR + buffer plan + schedule.
158pub fn inspect_lir(lir: &LirModule) -> String {
159    let mut out = String::new();
160    writeln!(out, "lir @{} {{", lir.name()).unwrap();
161    writeln!(out, "  fingerprint: {:016x}", lir.fingerprint().0).unwrap();
162    writeln!(out).unwrap();
163    out.push_str(&inspect_buffer_plan(&lir.buffers));
164    if !lir.buffers.phases.is_empty() {
165        writeln!(out).unwrap();
166        out.push_str(&inspect_phases(&lir.buffers));
167    }
168    if !lir.buffers.io.inputs.is_empty() || !lir.buffers.io.params.is_empty() {
169        writeln!(out).unwrap();
170        out.push_str(&inspect_io_manifest(&lir.buffers));
171    }
172    writeln!(out).unwrap();
173    writeln!(out, "--- mir ---").unwrap();
174    out.push_str(&pretty_print(lir.as_graph()));
175    if !out.ends_with('\n') {
176        out.push('\n');
177    }
178    write!(out, "}}").unwrap();
179    out
180}
181
182/// Annotated graph dump (MIR body). Alias for [`pretty_print`].
183pub fn inspect_graph(g: &Graph) -> String {
184    pretty_print(g)
185}
186
187/// One-line HIR summary (header + op histogram).
188pub fn inspect_hir_stats(hir: &HirModule) -> String {
189    format!(
190        "hir @{} ({} nodes, {} outputs, fusion={:?})\n{}",
191        hir.name,
192        hir.len(),
193        hir.outputs.len(),
194        hir.fusion_policy,
195        hir_op_kinds_line(hir),
196    )
197}
198
199/// One-line MIR summary.
200pub fn inspect_mir_stats(mir: &MirModule) -> String {
201    let g = mir.as_graph();
202    format!(
203        "mir @{} — {}\n{}",
204        mir.name(),
205        header_line(g),
206        op_kinds_line(g),
207    )
208}
209
210/// Buffer plan section for LIR inspection.
211pub fn inspect_buffer_plan(plan: &LirBufferPlan) -> String {
212    let mut out = String::new();
213    let saved = plan.bytes_saved();
214    let naive = plan.total_unshared_bytes();
215    writeln!(
216        out,
217        "  arena: {} bytes (saved {} vs {} naive, align={})",
218        plan.arena_size, saved, naive, plan.alignment,
219    )
220    .unwrap();
221    writeln!(
222        out,
223        "  schedule: {} nodes, {} views",
224        plan.schedule.len(),
225        plan.view_aliases.len(),
226    )
227    .unwrap();
228    if !plan.dynamic_symbols.is_empty() {
229        let syms: Vec<String> = plan
230            .dynamic_symbols
231            .iter()
232            .map(|s| format!("?{s}"))
233            .collect();
234        writeln!(out, "  dynamic: {}", syms.join(", ")).unwrap();
235    }
236    writeln!(out).unwrap();
237    writeln!(out, "  # offset\tsize\tnode").unwrap();
238
239    let mut rows: Vec<(usize, usize, NodeId)> = plan
240        .assignments
241        .iter()
242        .map(|(id, slot)| (slot.offset, slot.size, *id))
243        .collect();
244    rows.sort_by_key(|(off, _, _)| *off);
245    for (off, sz, id) in rows {
246        let sched = plan
247            .schedule
248            .iter()
249            .position(|&n| n == id)
250            .map(|i| format!(" sched={i}"))
251            .unwrap_or_default();
252        let view = plan
253            .view_aliases
254            .get(&id)
255            .map(|LirViewAlias { root, byte_offset }| format!(" view→{root}+{byte_offset}"))
256            .unwrap_or_default();
257        let phase = plan
258            .phases
259            .get(id)
260            .map(|p| format!(" {p:?}"))
261            .unwrap_or_default();
262        writeln!(out, "  {off}\t{sz}\t{id}{sched}{view}{phase}").unwrap();
263    }
264    out
265}
266
267fn inspect_phases(plan: &LirBufferPlan) -> String {
268    let mut out = String::from("  phases:\n");
269    for phase in [Phase::Prologue, Phase::SteadyState, Phase::Epilogue] {
270        let nodes = plan.nodes_in_phase(phase);
271        if !nodes.is_empty() {
272            writeln!(out, "    {phase:?}: {nodes:?}").unwrap();
273        }
274    }
275    out
276}
277
278fn inspect_io_manifest(plan: &LirBufferPlan) -> String {
279    let mut out = String::from("  io:\n");
280    for (name, id) in &plan.io.inputs {
281        writeln!(out, "    input \"{name}\" → {id}").unwrap();
282    }
283    for (name, id) in &plan.io.params {
284        writeln!(out, "    param \"{name}\" → {id}").unwrap();
285    }
286    if !plan.io.outputs.is_empty() {
287        write!(out, "    outputs: {:?}", plan.io.outputs).unwrap();
288        out.push('\n');
289    }
290    out
291}
292
293fn hir_op_kinds_line(hir: &HirModule) -> String {
294    let mut hist: BTreeMap<String, usize> = BTreeMap::new();
295    for node in hir.nodes() {
296        *hist.entry(hir_op_kind(&node.op)).or_insert(0) += 1;
297    }
298    let parts: Vec<String> = hist.into_iter().map(|(k, c)| format!("{k}={c}")).collect();
299    format!("  block ops: {}", parts.join(", "))
300}
301
302fn hir_op_kind(op: &HirOp) -> String {
303    match op {
304        HirOp::Input { .. } => "Input".into(),
305        HirOp::Param { .. } => "Param".into(),
306        HirOp::Constant { .. } => "Constant".into(),
307        HirOp::Linear { .. } => "Linear".into(),
308        HirOp::LinearFused { .. } => "LinearFused".into(),
309        HirOp::SharedLinearPair { .. } => "SharedLinearPair".into(),
310        HirOp::SwiGLU => "SwiGLU".into(),
311        HirOp::ResidualRmsNorm { .. } => "ResidualRmsNorm".into(),
312        HirOp::Attention { .. } => "Attention".into(),
313        HirOp::DepthwiseConv1dCausal { .. } => "DepthwiseConv1dCausal".into(),
314        HirOp::DequantMatMul { .. } => "DequantMatMul".into(),
315        HirOp::GatedDeltaNet { .. } => "GatedDeltaNet".into(),
316        HirOp::Lstm { .. } => "Lstm".into(),
317        HirOp::RoPE { .. } => "RoPE".into(),
318        HirOp::RmsNorm { .. } => "RmsNorm".into(),
319        HirOp::Mir(_) => "Mir".into(),
320        HirOp::LlamaDecoderBlock { .. } => "LlamaDecoderBlock".into(),
321        HirOp::Qwen35MtpHead { .. } => "Qwen35MtpHead".into(),
322    }
323}
324
325fn hir_node_tag(node: &HirNode) -> String {
326    let label: Option<String> = match &node.op {
327        HirOp::Input { name } => Some(format!("input \"{name}\"")),
328        HirOp::Param { name } => Some(format!("param \"{name}\"")),
329        _ => node.name.as_deref().map(|s| format!("\"{s}\"")),
330    };
331    match label {
332        Some(s) => format!("{} [{s}]", node.id),
333        None => format!("{}", node.id),
334    }
335}
336
337fn format_hir_op(op: &HirOp) -> String {
338    match op {
339        HirOp::Input { name } => format!("input(\"{name}\")"),
340        HirOp::Param { name } => format!("param(\"{name}\")"),
341        HirOp::Constant { data } => format!("constant({} bytes)", data.len()),
342        HirOp::Linear {
343            activation,
344            has_bias,
345        } => {
346            let mut s = String::from("linear");
347            if *has_bias {
348                s.push_str("+bias");
349            }
350            if let Some(act) = activation {
351                write!(s, "+{act:?}").unwrap();
352            }
353            s
354        }
355        HirOp::LinearFused { activation } => match activation {
356            Some(act) => format!("linear_fused({act:?})"),
357            None => "linear_fused".into(),
358        },
359        HirOp::SharedLinearPair { slot } => format!("shared_linear_pair(out={slot})"),
360        HirOp::SwiGLU => "swiglu_ffn".into(),
361        HirOp::ResidualRmsNorm { eps } => format!("residual_rms_norm(eps={eps})"),
362        HirOp::Attention {
363            num_heads,
364            head_dim,
365            mask,
366        } => format!("attention(heads={num_heads}, dim={head_dim}, mask={mask:?})"),
367        HirOp::DepthwiseConv1dCausal { kernel_size } => {
368            format!("depthwise_conv1d_causal(k={kernel_size})")
369        }
370        HirOp::DequantMatMul { scheme } => format!("dequant_matmul({scheme})"),
371        HirOp::GatedDeltaNet {
372            state_size,
373            carry_state,
374        } => {
375            if *carry_state {
376                format!("gated_delta_net(n={state_size},carry)")
377            } else {
378                format!("gated_delta_net(n={state_size})")
379            }
380        }
381        HirOp::Lstm {
382            hidden_size,
383            num_layers,
384            bidirectional,
385            ..
386        } => {
387            let dir = if *bidirectional { "bi" } else { "uni" };
388            format!("lstm(h={hidden_size},layers={num_layers},{dir})")
389        }
390        HirOp::RoPE { head_dim, n_rot } => format!("rope(d={head_dim}, n_rot={n_rot})"),
391        HirOp::RmsNorm { eps } => format!("rms_norm(eps={eps})"),
392        HirOp::LlamaDecoderBlock {
393            num_heads,
394            head_dim,
395            num_kv_heads,
396            eps,
397            mask,
398        } => format!(
399            "llama_decoder_block(heads={num_heads}, dim={head_dim}, kv={num_kv_heads}, eps={eps}, mask={mask:?})"
400        ),
401        HirOp::Qwen35MtpHead {
402            num_heads,
403            head_dim,
404            mtp_vocab,
405            ..
406        } => format!("qwen35_mtp_head(heads={num_heads}, dim={head_dim}, vocab={mtp_vocab})"),
407        HirOp::Mir(inner) => format!("mir({inner})"),
408    }
409}
410
411// ── convenience methods on pipeline types ───────────────────────────────
412
413impl HirModule {
414    /// Text dump for inspection. Alias for [`inspect_hir`].
415    pub fn inspect(&self) -> String {
416        inspect_hir(self)
417    }
418}
419
420impl MirModule {
421    /// Text dump for inspection. Alias for [`inspect_mir`].
422    pub fn inspect(&self) -> String {
423        inspect_mir(self)
424    }
425}
426
427impl LirModule {
428    /// Text dump for inspection. Alias for [`inspect_lir`].
429    pub fn inspect(&self) -> String {
430        inspect_lir(self)
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;
    use crate::Shape;

    /// Builds an `F32` shape with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// Fails (printing the whole dump) unless `dump` contains every needle.
    fn assert_contains_all(dump: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(dump.contains(needle), "missing {needle:?} in dump:\n{dump}");
        }
    }

    #[test]
    fn inspect_hir_includes_blocks_and_outputs() {
        let mut module = HirModule::new("layer");
        let input = module.input("x", f32_shape(&[2, 128]));
        let weight = module.param("w", f32_shape(&[128, 128]));
        let y = module.linear(input, weight, None, None, f32_shape(&[2, 128]));
        module.outputs = vec![y];

        let dump = inspect_hir(&module);
        assert_contains_all(&dump, &["hir @layer", "linear", "← output", "fusion=Direct"]);
    }

    #[test]
    fn inspect_mir_wraps_pretty_print() {
        let mut module = HirModule::new("m");
        let input = module.input("x", f32_shape(&[4]));
        module.outputs = vec![input];
        let lowered = module.lower_to_mir().expect("lower");

        let dump = inspect_mir(&lowered);
        assert_contains_all(&dump, &["mir @m", "graph @m", "input(\"x\")"]);
    }

    #[test]
    fn named_block_appears_in_hir_dump() {
        let mut module = HirModule::new("layer");
        let input = module.input("x", f32_shape(&[2, 8]));
        let weight = module.param("w", f32_shape(&[8, 8]));
        let y = module.named("layer0.ffn", |b| {
            b.linear(input, weight, None, None, f32_shape(&[2, 8]))
        });
        module.outputs = vec![y];

        let dump = inspect_hir(&module);
        assert_contains_all(&dump, &["layer0.ffn"]);
    }

    #[test]
    fn provenance_survives_lower() {
        let mut module = HirModule::new("m");
        let input = module.input("x", f32_shape(&[2, 8]));
        let weight = module.param("w", f32_shape(&[8, 8]));
        let y = module.named("block", |b| {
            b.linear(input, weight, None, None, f32_shape(&[2, 8]))
        });
        module.outputs = vec![y];

        let lowered = module.lower_to_mir().expect("lower");
        let dump = inspect_mir(&lowered);
        assert_contains_all(&dump, &["hir=h", "block"]);
    }

    #[test]
    fn inspect_lir_includes_buffer_plan() {
        use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

        let mut module = HirModule::new("l");
        let input = module.input("x", f32_shape(&[4]));
        module.outputs = vec![input];
        let lowered = module.lower_to_mir().expect("lower");

        let slot = LirBufferSlot {
            offset: 0,
            size: 16,
        };
        let plan = LirBufferPlan {
            arena_size: 16,
            assignments: [(NodeId(0), slot)].into_iter().collect(),
            schedule: vec![NodeId(0)],
            io: LirIoManifest {
                inputs: vec![("x".into(), NodeId(0))],
                ..Default::default()
            },
            ..Default::default()
        };
        let lir = LirModule::new(lowered, plan);

        let dump = inspect_lir(&lir);
        assert_contains_all(
            &dump,
            &["lir @l", "arena: 16 bytes", "fingerprint:", "--- mir ---"],
        );
    }
}