1use std::collections::BTreeMap;
24use std::fmt::Write as _;
25
26use crate::hir::{HirModule, HirNode, HirOp};
27use crate::lir::{LirBufferPlan, LirModule, LirViewAlias};
28use crate::mir::MirModule;
29use crate::phase::Phase;
30use crate::pretty::{header_line, op_kinds_line, pretty_print};
31use crate::{Graph, NodeId};
32
33pub fn inspect_hir(hir: &HirModule) -> String {
35 let mut out = String::new();
36 writeln!(
37 out,
38 "hir @{} ({} nodes, {} outputs, fusion={:?})",
39 hir.name,
40 hir.len(),
41 hir.outputs.len(),
42 hir.fusion_policy,
43 )
44 .unwrap();
45 writeln!(out, "{}", hir_op_kinds_line(hir)).unwrap();
46 writeln!(out).unwrap();
47
48 let mut tag_w = 0usize;
49 for node in hir.nodes() {
50 let t = hir_node_tag(node);
51 tag_w = tag_w.max(t.len());
52 }
53
54 for node in hir.nodes() {
55 let tag = hir_node_tag(node);
56 write!(out, " {tag:<width$} = ", width = tag_w).unwrap();
57 write!(out, "{}", format_hir_op(&node.op)).unwrap();
58 if !node.inputs.is_empty() {
59 write!(out, "(").unwrap();
60 for (i, inp) in node.inputs.iter().enumerate() {
61 if i > 0 {
62 write!(out, ", ").unwrap();
63 }
64 write!(out, "{inp}").unwrap();
65 }
66 write!(out, ")").unwrap();
67 }
68 write!(out, " : {}", node.shape).unwrap();
69 if hir.outputs.contains(&node.id) {
70 write!(out, " ← output").unwrap();
71 }
72 writeln!(out).unwrap();
73 }
74 if !hir.outputs.is_empty() {
75 write!(out, " return ").unwrap();
76 for (i, o) in hir.outputs.iter().enumerate() {
77 if i > 0 {
78 write!(out, ", ").unwrap();
79 }
80 write!(out, "{o}").unwrap();
81 }
82 writeln!(out).unwrap();
83 }
84 out
85}
86
/// Renders a dump of a MIR module without a diff section.
///
/// Equivalent to calling [`inspect_mir_with_diff`] with `before` set to `None`.
pub fn inspect_mir(mir: &MirModule) -> String {
    inspect_mir_with_diff(mir, None)
}
91
92pub fn inspect_mir_with_diff(mir: &MirModule, before: Option<&MirModule>) -> String {
94 let g = mir.as_graph();
95 let mut out = String::new();
96 writeln!(out, "mir @{} {{", mir.name()).unwrap();
97 if let Some(b) = before {
98 writeln!(out).unwrap();
99 out.push_str(&inspect_graph_diff(b.as_graph(), g));
100 writeln!(out).unwrap();
101 writeln!(out, "--- graph ---").unwrap();
102 }
103 writeln!(out).unwrap();
104 out.push_str(&pretty_print(g));
105 if !out.ends_with('\n') {
106 out.push('\n');
107 }
108 write!(out, "}}").unwrap();
109 out
110}
111
/// Renders the diff summary between two MIR modules.
///
/// Delegates to [`inspect_graph_diff`] on the graphs of `before` and `after`.
pub fn inspect_mir_diff(before: &MirModule, after: &MirModule) -> String {
    inspect_graph_diff(before.as_graph(), after.as_graph())
}
116
117pub fn inspect_graph_diff(before: &Graph, after: &Graph) -> String {
119 use std::collections::BTreeMap;
120
121 let mut out = String::new();
122 writeln!(
123 out,
124 " diff: {} → {} nodes ({} → {} outputs)",
125 before.len(),
126 after.len(),
127 before.outputs.len(),
128 after.outputs.len(),
129 )
130 .unwrap();
131
132 let count_kinds = |g: &Graph| {
133 let mut h: BTreeMap<String, i32> = BTreeMap::new();
134 for n in g.nodes() {
135 *h.entry(format!("{:?}", n.op.kind())).or_insert(0) += 1;
136 }
137 h
138 };
139 let b = count_kinds(before);
140 let a = count_kinds(after);
141 let mut keys: Vec<String> = b.keys().chain(a.keys()).cloned().collect();
142 keys.sort();
143 keys.dedup();
144 let mut changes = Vec::new();
145 for k in keys {
146 let d = a.get(&k).copied().unwrap_or(0) - b.get(&k).copied().unwrap_or(0);
147 if d != 0 {
148 changes.push(format!("{k}{d:+}"));
149 }
150 }
151 if !changes.is_empty() {
152 writeln!(out, " op delta: {}", changes.join(", ")).unwrap();
153 }
154 out
155}
156
157pub fn inspect_lir(lir: &LirModule) -> String {
159 let mut out = String::new();
160 writeln!(out, "lir @{} {{", lir.name()).unwrap();
161 writeln!(out, " fingerprint: {:016x}", lir.fingerprint().0).unwrap();
162 writeln!(out).unwrap();
163 out.push_str(&inspect_buffer_plan(&lir.buffers));
164 if !lir.buffers.phases.is_empty() {
165 writeln!(out).unwrap();
166 out.push_str(&inspect_phases(&lir.buffers));
167 }
168 if !lir.buffers.io.inputs.is_empty() || !lir.buffers.io.params.is_empty() {
169 writeln!(out).unwrap();
170 out.push_str(&inspect_io_manifest(&lir.buffers));
171 }
172 writeln!(out).unwrap();
173 writeln!(out, "--- mir ---").unwrap();
174 out.push_str(&pretty_print(lir.as_graph()));
175 if !out.ends_with('\n') {
176 out.push('\n');
177 }
178 write!(out, "}}").unwrap();
179 out
180}
181
/// Renders a graph for inspection; returns exactly what `pretty_print` produces.
pub fn inspect_graph(g: &Graph) -> String {
    pretty_print(g)
}
186
187pub fn inspect_hir_stats(hir: &HirModule) -> String {
189 format!(
190 "hir @{} ({} nodes, {} outputs, fusion={:?})\n{}",
191 hir.name,
192 hir.len(),
193 hir.outputs.len(),
194 hir.fusion_policy,
195 hir_op_kinds_line(hir),
196 )
197}
198
199pub fn inspect_mir_stats(mir: &MirModule) -> String {
201 let g = mir.as_graph();
202 format!(
203 "mir @{} — {}\n{}",
204 mir.name(),
205 header_line(g),
206 op_kinds_line(g),
207 )
208}
209
210pub fn inspect_buffer_plan(plan: &LirBufferPlan) -> String {
212 let mut out = String::new();
213 let saved = plan.bytes_saved();
214 let naive = plan.total_unshared_bytes();
215 writeln!(
216 out,
217 " arena: {} bytes (saved {} vs {} naive, align={})",
218 plan.arena_size, saved, naive, plan.alignment,
219 )
220 .unwrap();
221 writeln!(
222 out,
223 " schedule: {} nodes, {} views",
224 plan.schedule.len(),
225 plan.view_aliases.len(),
226 )
227 .unwrap();
228 if !plan.dynamic_symbols.is_empty() {
229 let syms: Vec<String> = plan
230 .dynamic_symbols
231 .iter()
232 .map(|s| format!("?{s}"))
233 .collect();
234 writeln!(out, " dynamic: {}", syms.join(", ")).unwrap();
235 }
236 writeln!(out).unwrap();
237 writeln!(out, " # offset\tsize\tnode").unwrap();
238
239 let mut rows: Vec<(usize, usize, NodeId)> = plan
240 .assignments
241 .iter()
242 .map(|(id, slot)| (slot.offset, slot.size, *id))
243 .collect();
244 rows.sort_by_key(|(off, _, _)| *off);
245 for (off, sz, id) in rows {
246 let sched = plan
247 .schedule
248 .iter()
249 .position(|&n| n == id)
250 .map(|i| format!(" sched={i}"))
251 .unwrap_or_default();
252 let view = plan
253 .view_aliases
254 .get(&id)
255 .map(|LirViewAlias { root, byte_offset }| format!(" view→{root}+{byte_offset}"))
256 .unwrap_or_default();
257 let phase = plan
258 .phases
259 .get(id)
260 .map(|p| format!(" {p:?}"))
261 .unwrap_or_default();
262 writeln!(out, " {off}\t{sz}\t{id}{sched}{view}{phase}").unwrap();
263 }
264 out
265}
266
267fn inspect_phases(plan: &LirBufferPlan) -> String {
268 let mut out = String::from(" phases:\n");
269 for phase in [Phase::Prologue, Phase::SteadyState, Phase::Epilogue] {
270 let nodes = plan.nodes_in_phase(phase);
271 if !nodes.is_empty() {
272 writeln!(out, " {phase:?}: {nodes:?}").unwrap();
273 }
274 }
275 out
276}
277
278fn inspect_io_manifest(plan: &LirBufferPlan) -> String {
279 let mut out = String::from(" io:\n");
280 for (name, id) in &plan.io.inputs {
281 writeln!(out, " input \"{name}\" → {id}").unwrap();
282 }
283 for (name, id) in &plan.io.params {
284 writeln!(out, " param \"{name}\" → {id}").unwrap();
285 }
286 if !plan.io.outputs.is_empty() {
287 write!(out, " outputs: {:?}", plan.io.outputs).unwrap();
288 out.push('\n');
289 }
290 out
291}
292
293fn hir_op_kinds_line(hir: &HirModule) -> String {
294 let mut hist: BTreeMap<String, usize> = BTreeMap::new();
295 for node in hir.nodes() {
296 *hist.entry(hir_op_kind(&node.op)).or_insert(0) += 1;
297 }
298 let parts: Vec<String> = hist.into_iter().map(|(k, c)| format!("{k}={c}")).collect();
299 format!(" block ops: {}", parts.join(", "))
300}
301
302fn hir_op_kind(op: &HirOp) -> String {
303 match op {
304 HirOp::Input { .. } => "Input".into(),
305 HirOp::Param { .. } => "Param".into(),
306 HirOp::Constant { .. } => "Constant".into(),
307 HirOp::Linear { .. } => "Linear".into(),
308 HirOp::LinearFused { .. } => "LinearFused".into(),
309 HirOp::SharedLinearPair { .. } => "SharedLinearPair".into(),
310 HirOp::SwiGLU => "SwiGLU".into(),
311 HirOp::ResidualRmsNorm { .. } => "ResidualRmsNorm".into(),
312 HirOp::Attention { .. } => "Attention".into(),
313 HirOp::DepthwiseConv1dCausal { .. } => "DepthwiseConv1dCausal".into(),
314 HirOp::DequantMatMul { .. } => "DequantMatMul".into(),
315 HirOp::GatedDeltaNet { .. } => "GatedDeltaNet".into(),
316 HirOp::RoPE { .. } => "RoPE".into(),
317 HirOp::RmsNorm { .. } => "RmsNorm".into(),
318 HirOp::Mir(_) => "Mir".into(),
319 HirOp::LlamaDecoderBlock { .. } => "LlamaDecoderBlock".into(),
320 HirOp::Qwen35MtpHead { .. } => "Qwen35MtpHead".into(),
321 }
322}
323
324fn hir_node_tag(node: &HirNode) -> String {
325 let label: Option<String> = match &node.op {
326 HirOp::Input { name } => Some(format!("input \"{name}\"")),
327 HirOp::Param { name } => Some(format!("param \"{name}\"")),
328 _ => node.name.as_deref().map(|s| format!("\"{s}\"")),
329 };
330 match label {
331 Some(s) => format!("{} [{s}]", node.id),
332 None => format!("{}", node.id),
333 }
334}
335
336fn format_hir_op(op: &HirOp) -> String {
337 match op {
338 HirOp::Input { name } => format!("input(\"{name}\")"),
339 HirOp::Param { name } => format!("param(\"{name}\")"),
340 HirOp::Constant { data } => format!("constant({} bytes)", data.len()),
341 HirOp::Linear {
342 activation,
343 has_bias,
344 } => {
345 let mut s = String::from("linear");
346 if *has_bias {
347 s.push_str("+bias");
348 }
349 if let Some(act) = activation {
350 write!(s, "+{act:?}").unwrap();
351 }
352 s
353 }
354 HirOp::LinearFused { activation } => match activation {
355 Some(act) => format!("linear_fused({act:?})"),
356 None => "linear_fused".into(),
357 },
358 HirOp::SharedLinearPair { slot } => format!("shared_linear_pair(out={slot})"),
359 HirOp::SwiGLU => "swiglu_ffn".into(),
360 HirOp::ResidualRmsNorm { eps } => format!("residual_rms_norm(eps={eps})"),
361 HirOp::Attention {
362 num_heads,
363 head_dim,
364 mask,
365 } => format!("attention(heads={num_heads}, dim={head_dim}, mask={mask:?})"),
366 HirOp::DepthwiseConv1dCausal { kernel_size } => {
367 format!("depthwise_conv1d_causal(k={kernel_size})")
368 }
369 HirOp::DequantMatMul { scheme } => format!("dequant_matmul({scheme})"),
370 HirOp::GatedDeltaNet {
371 state_size,
372 carry_state,
373 } => {
374 if *carry_state {
375 format!("gated_delta_net(n={state_size},carry)")
376 } else {
377 format!("gated_delta_net(n={state_size})")
378 }
379 }
380 HirOp::RoPE { head_dim, n_rot } => format!("rope(d={head_dim}, n_rot={n_rot})"),
381 HirOp::RmsNorm { eps } => format!("rms_norm(eps={eps})"),
382 HirOp::LlamaDecoderBlock {
383 num_heads,
384 head_dim,
385 num_kv_heads,
386 eps,
387 mask,
388 } => format!(
389 "llama_decoder_block(heads={num_heads}, dim={head_dim}, kv={num_kv_heads}, eps={eps}, mask={mask:?})"
390 ),
391 HirOp::Qwen35MtpHead {
392 num_heads,
393 head_dim,
394 mtp_vocab,
395 ..
396 } => format!("qwen35_mtp_head(heads={num_heads}, dim={head_dim}, vocab={mtp_vocab})"),
397 HirOp::Mir(inner) => format!("mir({inner})"),
398 }
399}
400
impl HirModule {
    /// Method form of [`inspect_hir`]: returns the textual dump of this module.
    pub fn inspect(&self) -> String {
        inspect_hir(self)
    }
}
409
impl MirModule {
    /// Method form of [`inspect_mir`]: returns the textual dump of this module.
    pub fn inspect(&self) -> String {
        inspect_mir(self)
    }
}
416
impl LirModule {
    /// Method form of [`inspect_lir`]: returns the textual dump of this module.
    pub fn inspect(&self) -> String {
        inspect_lir(self)
    }
}
423
// Smoke tests for the inspection helpers: each builds a small module and
// checks that the rendered text contains the expected fragments.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;
    use crate::Shape;

    /// Shorthand for an `F32` shape with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    /// The HIR dump shows the header, the op, the output marker and the
    /// fusion policy.
    #[test]
    fn inspect_hir_includes_blocks_and_outputs() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let h = hir.linear(x, w, None, None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let text = inspect_hir(&hir);
        assert!(text.contains("hir @layer"));
        assert!(text.contains("linear"));
        assert!(text.contains("← output"));
        assert!(text.contains("fusion=Direct"));
    }

    /// The MIR dump wraps the pretty-printed graph in a `mir @name` block.
    #[test]
    fn inspect_mir_wraps_pretty_print() {
        let mut hir = HirModule::new("m");
        let x = hir.input("x", f32_shape(&[4]));
        hir.outputs = vec![x];
        let mir = hir.lower_to_mir().expect("lower");

        let text = inspect_mir(&mir);
        assert!(text.contains("mir @m"));
        assert!(text.contains("graph @m"));
        assert!(text.contains("input(\"x\")"));
    }

    /// A name given through `named` shows up in the HIR dump.
    #[test]
    fn named_block_appears_in_hir_dump() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 8]));
        let w = hir.param("w", f32_shape(&[8, 8]));
        let out = hir.named("layer0.ffn", |h| {
            h.linear(x, w, None, None, f32_shape(&[2, 8]))
        });
        hir.outputs = vec![out];

        let text = inspect_hir(&hir);
        assert!(text.contains("layer0.ffn"));
    }

    /// After lowering a named HIR block, the MIR dump still contains an
    /// `hir=h…` fragment and the block name.
    #[test]
    fn provenance_survives_lower() {
        let mut hir = HirModule::new("m");
        let x = hir.input("x", f32_shape(&[2, 8]));
        let w = hir.param("w", f32_shape(&[8, 8]));
        let out = hir.named("block", |h| h.linear(x, w, None, None, f32_shape(&[2, 8])));
        hir.outputs = vec![out];

        let mir = hir.lower_to_mir().expect("lower");
        let text = inspect_mir(&mir);
        assert!(text.contains("hir=h"));
        assert!(text.contains("block"));
    }

    /// The LIR dump contains the header, the arena summary, the fingerprint
    /// and the graph marker.
    #[test]
    fn inspect_lir_includes_buffer_plan() {
        use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

        let mut hir = HirModule::new("l");
        let x = hir.input("x", f32_shape(&[4]));
        hir.outputs = vec![x];
        let mir = hir.lower_to_mir().expect("lower");
        // Hand-built plan with a single 16-byte slot for `NodeId(0)`.
        // NOTE(review): presumably `NodeId(0)` is the lowered input node — confirm.
        let plan = LirBufferPlan {
            arena_size: 16,
            assignments: [(
                NodeId(0),
                LirBufferSlot {
                    offset: 0,
                    size: 16,
                },
            )]
            .into_iter()
            .collect(),
            schedule: vec![NodeId(0)],
            io: LirIoManifest {
                inputs: vec![("x".into(), NodeId(0))],
                ..Default::default()
            },
            ..Default::default()
        };
        let lir = LirModule::new(mir, plan);

        let text = inspect_lir(&lir);
        assert!(text.contains("lir @l"));
        assert!(text.contains("arena: 16 bytes"));
        assert!(text.contains("fingerprint:"));
        assert!(text.contains("--- mir ---"));
    }
}