1use std::collections::BTreeMap;
24use std::fmt::Write as _;
25
26use crate::hir::{HirModule, HirNode, HirOp};
27use crate::lir::{LirBufferPlan, LirModule, LirViewAlias};
28use crate::mir::MirModule;
29use crate::phase::Phase;
30use crate::pretty::{header_line, op_kinds_line, pretty_print};
31use crate::{Graph, NodeId};
32
33pub fn inspect_hir(hir: &HirModule) -> String {
35 let mut out = String::new();
36 writeln!(
37 out,
38 "hir @{} ({} nodes, {} outputs, fusion={:?})",
39 hir.name,
40 hir.len(),
41 hir.outputs.len(),
42 hir.fusion_policy,
43 )
44 .unwrap();
45 writeln!(out, "{}", hir_op_kinds_line(hir)).unwrap();
46 writeln!(out).unwrap();
47
48 let mut tag_w = 0usize;
49 for node in hir.nodes() {
50 let t = hir_node_tag(node);
51 tag_w = tag_w.max(t.len());
52 }
53
54 for node in hir.nodes() {
55 let tag = hir_node_tag(node);
56 write!(out, " {tag:<width$} = ", width = tag_w).unwrap();
57 write!(out, "{}", format_hir_op(&node.op)).unwrap();
58 if !node.inputs.is_empty() {
59 write!(out, "(").unwrap();
60 for (i, inp) in node.inputs.iter().enumerate() {
61 if i > 0 {
62 write!(out, ", ").unwrap();
63 }
64 write!(out, "{inp}").unwrap();
65 }
66 write!(out, ")").unwrap();
67 }
68 write!(out, " : {}", node.shape).unwrap();
69 if hir.outputs.contains(&node.id) {
70 write!(out, " ← output").unwrap();
71 }
72 writeln!(out).unwrap();
73 }
74 if !hir.outputs.is_empty() {
75 write!(out, " return ").unwrap();
76 for (i, o) in hir.outputs.iter().enumerate() {
77 if i > 0 {
78 write!(out, ", ").unwrap();
79 }
80 write!(out, "{o}").unwrap();
81 }
82 writeln!(out).unwrap();
83 }
84 out
85}
86
/// Renders a MIR module dump without a diff section.
///
/// Equivalent to calling [`inspect_mir_with_diff`] with `before` set to `None`.
pub fn inspect_mir(mir: &MirModule) -> String {
    inspect_mir_with_diff(mir, None)
}
91
92pub fn inspect_mir_with_diff(mir: &MirModule, before: Option<&MirModule>) -> String {
94 let g = mir.as_graph();
95 let mut out = String::new();
96 writeln!(out, "mir @{} {{", mir.name()).unwrap();
97 if let Some(b) = before {
98 writeln!(out).unwrap();
99 out.push_str(&inspect_graph_diff(b.as_graph(), g));
100 writeln!(out).unwrap();
101 writeln!(out, "--- graph ---").unwrap();
102 }
103 writeln!(out).unwrap();
104 out.push_str(&pretty_print(g));
105 if !out.ends_with('\n') {
106 out.push('\n');
107 }
108 write!(out, "}}").unwrap();
109 out
110}
111
/// Summarizes the difference between two MIR modules.
///
/// Delegates to [`inspect_graph_diff`] on the graphs returned by `as_graph()`.
pub fn inspect_mir_diff(before: &MirModule, after: &MirModule) -> String {
    inspect_graph_diff(before.as_graph(), after.as_graph())
}
116
117pub fn inspect_graph_diff(before: &Graph, after: &Graph) -> String {
119 use std::collections::BTreeMap;
120
121 let mut out = String::new();
122 writeln!(
123 out,
124 " diff: {} → {} nodes ({} → {} outputs)",
125 before.len(),
126 after.len(),
127 before.outputs.len(),
128 after.outputs.len(),
129 )
130 .unwrap();
131
132 let count_kinds = |g: &Graph| {
133 let mut h: BTreeMap<String, i32> = BTreeMap::new();
134 for n in g.nodes() {
135 *h.entry(format!("{:?}", n.op.kind())).or_insert(0) += 1;
136 }
137 h
138 };
139 let b = count_kinds(before);
140 let a = count_kinds(after);
141 let mut keys: Vec<String> = b.keys().chain(a.keys()).cloned().collect();
142 keys.sort();
143 keys.dedup();
144 let mut changes = Vec::new();
145 for k in keys {
146 let d = a.get(&k).copied().unwrap_or(0) - b.get(&k).copied().unwrap_or(0);
147 if d != 0 {
148 changes.push(format!("{k}{d:+}"));
149 }
150 }
151 if !changes.is_empty() {
152 writeln!(out, " op delta: {}", changes.join(", ")).unwrap();
153 }
154 out
155}
156
157pub fn inspect_lir(lir: &LirModule) -> String {
159 let mut out = String::new();
160 writeln!(out, "lir @{} {{", lir.name()).unwrap();
161 writeln!(out, " fingerprint: {:016x}", lir.fingerprint().0).unwrap();
162 writeln!(out).unwrap();
163 out.push_str(&inspect_buffer_plan(&lir.buffers));
164 if !lir.buffers.phases.is_empty() {
165 writeln!(out).unwrap();
166 out.push_str(&inspect_phases(&lir.buffers));
167 }
168 if !lir.buffers.io.inputs.is_empty() || !lir.buffers.io.params.is_empty() {
169 writeln!(out).unwrap();
170 out.push_str(&inspect_io_manifest(&lir.buffers));
171 }
172 writeln!(out).unwrap();
173 writeln!(out, "--- mir ---").unwrap();
174 out.push_str(&pretty_print(lir.as_graph()));
175 if !out.ends_with('\n') {
176 out.push('\n');
177 }
178 write!(out, "}}").unwrap();
179 out
180}
181
/// Renders a graph by returning the result of `pretty_print(g)` unchanged.
pub fn inspect_graph(g: &Graph) -> String {
    pretty_print(g)
}
186
187pub fn inspect_hir_stats(hir: &HirModule) -> String {
189 format!(
190 "hir @{} ({} nodes, {} outputs, fusion={:?})\n{}",
191 hir.name,
192 hir.len(),
193 hir.outputs.len(),
194 hir.fusion_policy,
195 hir_op_kinds_line(hir),
196 )
197}
198
199pub fn inspect_mir_stats(mir: &MirModule) -> String {
201 let g = mir.as_graph();
202 format!(
203 "mir @{} — {}\n{}",
204 mir.name(),
205 header_line(g),
206 op_kinds_line(g),
207 )
208}
209
210pub fn inspect_buffer_plan(plan: &LirBufferPlan) -> String {
212 let mut out = String::new();
213 let saved = plan.bytes_saved();
214 let naive = plan.total_unshared_bytes();
215 writeln!(
216 out,
217 " arena: {} bytes (saved {} vs {} naive, align={})",
218 plan.arena_size, saved, naive, plan.alignment,
219 )
220 .unwrap();
221 writeln!(
222 out,
223 " schedule: {} nodes, {} views",
224 plan.schedule.len(),
225 plan.view_aliases.len(),
226 )
227 .unwrap();
228 if !plan.dynamic_symbols.is_empty() {
229 let syms: Vec<String> = plan
230 .dynamic_symbols
231 .iter()
232 .map(|s| format!("?{s}"))
233 .collect();
234 writeln!(out, " dynamic: {}", syms.join(", ")).unwrap();
235 }
236 writeln!(out).unwrap();
237 writeln!(out, " # offset\tsize\tnode").unwrap();
238
239 let mut rows: Vec<(usize, usize, NodeId)> = plan
240 .assignments
241 .iter()
242 .map(|(id, slot)| (slot.offset, slot.size, *id))
243 .collect();
244 rows.sort_by_key(|(off, _, _)| *off);
245 for (off, sz, id) in rows {
246 let sched = plan
247 .schedule
248 .iter()
249 .position(|&n| n == id)
250 .map(|i| format!(" sched={i}"))
251 .unwrap_or_default();
252 let view = plan
253 .view_aliases
254 .get(&id)
255 .map(|LirViewAlias { root, byte_offset }| format!(" view→{root}+{byte_offset}"))
256 .unwrap_or_default();
257 let phase = plan
258 .phases
259 .get(id)
260 .map(|p| format!(" {p:?}"))
261 .unwrap_or_default();
262 writeln!(out, " {off}\t{sz}\t{id}{sched}{view}{phase}").unwrap();
263 }
264 out
265}
266
267fn inspect_phases(plan: &LirBufferPlan) -> String {
268 let mut out = String::from(" phases:\n");
269 for phase in [Phase::Prologue, Phase::SteadyState, Phase::Epilogue] {
270 let nodes = plan.nodes_in_phase(phase);
271 if !nodes.is_empty() {
272 writeln!(out, " {phase:?}: {nodes:?}").unwrap();
273 }
274 }
275 out
276}
277
278fn inspect_io_manifest(plan: &LirBufferPlan) -> String {
279 let mut out = String::from(" io:\n");
280 for (name, id) in &plan.io.inputs {
281 writeln!(out, " input \"{name}\" → {id}").unwrap();
282 }
283 for (name, id) in &plan.io.params {
284 writeln!(out, " param \"{name}\" → {id}").unwrap();
285 }
286 if !plan.io.outputs.is_empty() {
287 write!(out, " outputs: {:?}", plan.io.outputs).unwrap();
288 out.push('\n');
289 }
290 out
291}
292
293fn hir_op_kinds_line(hir: &HirModule) -> String {
294 let mut hist: BTreeMap<String, usize> = BTreeMap::new();
295 for node in hir.nodes() {
296 *hist.entry(hir_op_kind(&node.op)).or_insert(0) += 1;
297 }
298 let parts: Vec<String> = hist.into_iter().map(|(k, c)| format!("{k}={c}")).collect();
299 format!(" block ops: {}", parts.join(", "))
300}
301
302fn hir_op_kind(op: &HirOp) -> String {
303 match op {
304 HirOp::Input { .. } => "Input".into(),
305 HirOp::Param { .. } => "Param".into(),
306 HirOp::Constant { .. } => "Constant".into(),
307 HirOp::Linear { .. } => "Linear".into(),
308 HirOp::LinearFused { .. } => "LinearFused".into(),
309 HirOp::SharedLinearPair { .. } => "SharedLinearPair".into(),
310 HirOp::SwiGLU => "SwiGLU".into(),
311 HirOp::ResidualRmsNorm { .. } => "ResidualRmsNorm".into(),
312 HirOp::Attention { .. } => "Attention".into(),
313 HirOp::DepthwiseConv1dCausal { .. } => "DepthwiseConv1dCausal".into(),
314 HirOp::DequantMatMul { .. } => "DequantMatMul".into(),
315 HirOp::GatedDeltaNet { .. } => "GatedDeltaNet".into(),
316 HirOp::Lstm { .. } => "Lstm".into(),
317 HirOp::RoPE { .. } => "RoPE".into(),
318 HirOp::RmsNorm { .. } => "RmsNorm".into(),
319 HirOp::Mir(_) => "Mir".into(),
320 HirOp::LlamaDecoderBlock { .. } => "LlamaDecoderBlock".into(),
321 HirOp::Qwen35MtpHead { .. } => "Qwen35MtpHead".into(),
322 }
323}
324
325fn hir_node_tag(node: &HirNode) -> String {
326 let label: Option<String> = match &node.op {
327 HirOp::Input { name } => Some(format!("input \"{name}\"")),
328 HirOp::Param { name } => Some(format!("param \"{name}\"")),
329 _ => node.name.as_deref().map(|s| format!("\"{s}\"")),
330 };
331 match label {
332 Some(s) => format!("{} [{s}]", node.id),
333 None => format!("{}", node.id),
334 }
335}
336
337fn format_hir_op(op: &HirOp) -> String {
338 match op {
339 HirOp::Input { name } => format!("input(\"{name}\")"),
340 HirOp::Param { name } => format!("param(\"{name}\")"),
341 HirOp::Constant { data } => format!("constant({} bytes)", data.len()),
342 HirOp::Linear {
343 activation,
344 has_bias,
345 } => {
346 let mut s = String::from("linear");
347 if *has_bias {
348 s.push_str("+bias");
349 }
350 if let Some(act) = activation {
351 write!(s, "+{act:?}").unwrap();
352 }
353 s
354 }
355 HirOp::LinearFused { activation } => match activation {
356 Some(act) => format!("linear_fused({act:?})"),
357 None => "linear_fused".into(),
358 },
359 HirOp::SharedLinearPair { slot } => format!("shared_linear_pair(out={slot})"),
360 HirOp::SwiGLU => "swiglu_ffn".into(),
361 HirOp::ResidualRmsNorm { eps } => format!("residual_rms_norm(eps={eps})"),
362 HirOp::Attention {
363 num_heads,
364 head_dim,
365 mask,
366 } => format!("attention(heads={num_heads}, dim={head_dim}, mask={mask:?})"),
367 HirOp::DepthwiseConv1dCausal { kernel_size } => {
368 format!("depthwise_conv1d_causal(k={kernel_size})")
369 }
370 HirOp::DequantMatMul { scheme } => format!("dequant_matmul({scheme})"),
371 HirOp::GatedDeltaNet {
372 state_size,
373 carry_state,
374 } => {
375 if *carry_state {
376 format!("gated_delta_net(n={state_size},carry)")
377 } else {
378 format!("gated_delta_net(n={state_size})")
379 }
380 }
381 HirOp::Lstm {
382 hidden_size,
383 num_layers,
384 bidirectional,
385 ..
386 } => {
387 let dir = if *bidirectional { "bi" } else { "uni" };
388 format!("lstm(h={hidden_size},layers={num_layers},{dir})")
389 }
390 HirOp::RoPE { head_dim, n_rot } => format!("rope(d={head_dim}, n_rot={n_rot})"),
391 HirOp::RmsNorm { eps } => format!("rms_norm(eps={eps})"),
392 HirOp::LlamaDecoderBlock {
393 num_heads,
394 head_dim,
395 num_kv_heads,
396 eps,
397 mask,
398 } => format!(
399 "llama_decoder_block(heads={num_heads}, dim={head_dim}, kv={num_kv_heads}, eps={eps}, mask={mask:?})"
400 ),
401 HirOp::Qwen35MtpHead {
402 num_heads,
403 head_dim,
404 mtp_vocab,
405 ..
406 } => format!("qwen35_mtp_head(heads={num_heads}, dim={head_dim}, vocab={mtp_vocab})"),
407 HirOp::Mir(inner) => format!("mir({inner})"),
408 }
409}
410
impl HirModule {
    /// Method form of [`inspect_hir`]: returns the textual dump of this module.
    pub fn inspect(&self) -> String {
        inspect_hir(self)
    }
}
419
impl MirModule {
    /// Method form of [`inspect_mir`]: returns the textual dump of this module.
    pub fn inspect(&self) -> String {
        inspect_mir(self)
    }
}
426
impl LirModule {
    /// Method form of [`inspect_lir`]: returns the textual dump of this module.
    pub fn inspect(&self) -> String {
        inspect_lir(self)
    }
}
433
// Smoke tests for the inspect helpers: each builds a small module and checks
// that the rendered text contains the expected markers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DType;
    use crate::Shape;

    // Shorthand for an F32 shape with the given dimensions.
    fn f32_shape(d: &[usize]) -> Shape {
        Shape::new(d, DType::F32)
    }

    // The HIR dump shows the module header, the op mnemonic, the output marker
    // and the fusion policy.
    #[test]
    fn inspect_hir_includes_blocks_and_outputs() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 128]));
        let w = hir.param("w", f32_shape(&[128, 128]));
        let h = hir.linear(x, w, None, None, f32_shape(&[2, 128]));
        hir.outputs = vec![h];

        let text = inspect_hir(&hir);
        assert!(text.contains("hir @layer"));
        assert!(text.contains("linear"));
        assert!(text.contains("← output"));
        assert!(text.contains("fusion=Direct"));
    }

    // The MIR dump wraps the pretty-printed graph inside the `mir @<name>` header.
    #[test]
    fn inspect_mir_wraps_pretty_print() {
        let mut hir = HirModule::new("m");
        let x = hir.input("x", f32_shape(&[4]));
        hir.outputs = vec![x];
        let mir = hir.lower_to_mir().expect("lower");

        let text = inspect_mir(&mir);
        assert!(text.contains("mir @m"));
        assert!(text.contains("graph @m"));
        assert!(text.contains("input(\"x\")"));
    }

    // A name given through `named` shows up in the HIR dump.
    #[test]
    fn named_block_appears_in_hir_dump() {
        let mut hir = HirModule::new("layer");
        let x = hir.input("x", f32_shape(&[2, 8]));
        let w = hir.param("w", f32_shape(&[8, 8]));
        let out = hir.named("layer0.ffn", |h| {
            h.linear(x, w, None, None, f32_shape(&[2, 8]))
        });
        hir.outputs = vec![out];

        let text = inspect_hir(&hir);
        assert!(text.contains("layer0.ffn"));
    }

    // After lowering, the MIR dump still mentions the HIR origin (`hir=h…`)
    // and the block name.
    #[test]
    fn provenance_survives_lower() {
        let mut hir = HirModule::new("m");
        let x = hir.input("x", f32_shape(&[2, 8]));
        let w = hir.param("w", f32_shape(&[8, 8]));
        let out = hir.named("block", |h| h.linear(x, w, None, None, f32_shape(&[2, 8])));
        hir.outputs = vec![out];

        let mir = hir.lower_to_mir().expect("lower");
        let text = inspect_mir(&mir);
        assert!(text.contains("hir=h"));
        assert!(text.contains("block"));
    }

    // The LIR dump includes the header, arena summary, fingerprint and the
    // graph separator for a hand-built single-buffer plan.
    #[test]
    fn inspect_lir_includes_buffer_plan() {
        use crate::lir::{LirBufferPlan, LirBufferSlot, LirIoManifest};

        let mut hir = HirModule::new("l");
        let x = hir.input("x", f32_shape(&[4]));
        hir.outputs = vec![x];
        let mir = hir.lower_to_mir().expect("lower");
        // One 16-byte buffer at offset 0 for node 0, which is also the only
        // scheduled node and the only named input.
        let plan = LirBufferPlan {
            arena_size: 16,
            assignments: [(
                NodeId(0),
                LirBufferSlot {
                    offset: 0,
                    size: 16,
                },
            )]
            .into_iter()
            .collect(),
            schedule: vec![NodeId(0)],
            io: LirIoManifest {
                inputs: vec![("x".into(), NodeId(0))],
                ..Default::default()
            },
            ..Default::default()
        };
        let lir = LirModule::new(mir, plan);

        let text = inspect_lir(&lir);
        assert!(text.contains("lir @l"));
        assert!(text.contains("arena: 16 bytes"));
        assert!(text.contains("fingerprint:"));
        assert!(text.contains("--- mir ---"));
    }
}