Skip to main content

oxicuda_ptx/
tui_explorer.rs

1//! Visual PTX explorer for terminal-based PTX analysis.
2//!
3//! This module provides a set of rendering and analysis utilities that produce
4//! formatted ASCII/ANSI text output for inspecting PTX intermediate representation.
5//! It analyzes [`PtxModule`], [`PtxFunction`], [`Instruction`], and [`BasicBlock`]
6//! structures, producing pretty-printed PTX code, control flow graphs, register
7//! lifetime timelines, instruction mix bar charts, and more.
8//!
9//! This is **not** a live TUI application -- it requires no TUI framework. All
10//! output is returned as plain `String` values suitable for printing to a terminal
11//! or writing to a file.
12//!
13//! # Example
14//!
15//! ```
16//! use oxicuda_ptx::tui_explorer::{ExplorerConfig, PtxExplorer};
17//! use oxicuda_ptx::ir::{PtxFunction, PtxType};
18//!
19//! let config = ExplorerConfig::default();
20//! let explorer = PtxExplorer::new(config);
21//! let func = PtxFunction::new("my_kernel");
22//! let output = explorer.render_function(&func);
23//! assert!(!output.is_empty());
24//! ```
25
26use std::collections::HashMap;
27use std::fmt::Write;
28
29use crate::ir::{BasicBlock, Instruction, MemorySpace, Operand, PtxFunction, PtxModule};
30
31// ---------------------------------------------------------------------------
32// Configuration
33// ---------------------------------------------------------------------------
34
/// Rendering options consumed by the PTX explorer.
///
/// The [`Default`] implementation produces plain (uncolored) output that is
/// 120 columns wide with every optional annotation switched off.
#[derive(Clone, Debug)]
#[allow(clippy::struct_excessive_bools)] // independent on/off display toggles
pub struct ExplorerConfig {
    /// Emit ANSI color escape codes in the rendered text.
    pub use_color: bool,
    /// Maximum output width, measured in columns.
    pub max_width: usize,
    /// Prefix each instruction with its line number.
    pub show_line_numbers: bool,
    /// Annotate registers with their types.
    pub show_register_types: bool,
    /// Append the estimated latency to each instruction.
    pub show_instruction_latency: bool,
}

impl Default for ExplorerConfig {
    /// Returns the baseline configuration: no color, 120 columns, and no
    /// line numbers, register types, or latency annotations.
    fn default() -> Self {
        Self {
            max_width: 120,
            use_color: false,
            show_line_numbers: false,
            show_register_types: false,
            show_instruction_latency: false,
        }
    }
}
62
63// ---------------------------------------------------------------------------
64// Instruction categorisation
65// ---------------------------------------------------------------------------
66
/// Coarse classification of a PTX instruction, used for statistics and for
/// choosing a highlight color.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum InstructionCategory {
    /// Arithmetic / math operations (add, mul, fma, etc.).
    Arithmetic,
    /// Memory operations (load, store, cp.async, atom, etc.).
    Memory,
    /// Control flow operations (branch, label, return).
    Control,
    /// Synchronization primitives (bar.sync, fence, mbarrier, etc.).
    Synchronization,
    /// Tensor Core operations (wmma, mma, wgmma).
    TensorCore,
    /// Special operations (mov.special, load.param, comment, raw, pragma).
    Special,
    /// Type conversion operations (cvt).
    Conversion,
}

impl InstructionCategory {
    /// Short display name for the category.
    ///
    /// Every category maps to its variant name except
    /// [`Synchronization`](Self::Synchronization), which is shortened to
    /// `"Sync"`.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            Self::Synchronization => "Sync",
            Self::Arithmetic => "Arithmetic",
            Self::Conversion => "Conversion",
            Self::Control => "Control",
            Self::Memory => "Memory",
            Self::Special => "Special",
            Self::TensorCore => "TensorCore",
        }
    }

    /// ANSI SGR foreground-color escape sequence used to highlight
    /// instructions of this category.
    const fn ansi_color(self) -> &'static str {
        match self {
            Self::Arithmetic => "\x1b[32m",      // green
            Self::Control => "\x1b[33m",         // yellow
            Self::Memory => "\x1b[34m",          // blue
            Self::Synchronization => "\x1b[35m", // magenta
            Self::TensorCore => "\x1b[36m",      // cyan
            Self::Conversion => "\x1b[37m",      // white
            Self::Special => "\x1b[90m",         // bright black (gray)
        }
    }
}
114
115/// Detailed information about a single instruction.
116#[derive(Debug, Clone)]
117pub struct InstructionInfo {
118    /// The instruction text (PTX emission).
119    pub instruction: String,
120    /// Instruction category.
121    pub category: InstructionCategory,
122    /// Estimated latency in GPU clock cycles.
123    pub latency_cycles: u32,
124    /// Estimated throughput (instructions per SM per cycle).
125    pub throughput_per_sm: f64,
126    /// Registers read by this instruction.
127    pub registers_read: Vec<String>,
128    /// Registers written by this instruction.
129    pub registers_written: Vec<String>,
130}
131
132// ---------------------------------------------------------------------------
133// Register lifetime
134// ---------------------------------------------------------------------------
135
/// Live-range summary for one PTX register.
///
/// The indices stored here are instruction positions within a function body.
#[derive(Clone, Debug)]
pub struct RegisterLifetime {
    /// Name of the register (for example `%f0`).
    pub register: String,
    /// Type of the register, as a string (for example `.f32`).
    pub reg_type: String,
    /// Index of the instruction that first defines the register.
    pub first_def: usize,
    /// Index of the instruction that last uses the register.
    pub last_use: usize,
    /// How many times the register is used (read).
    pub num_uses: usize,
}
150
151// ---------------------------------------------------------------------------
152// Instruction mix
153// ---------------------------------------------------------------------------
154
155/// Instruction mix statistics for a function.
156#[derive(Debug, Clone)]
157pub struct InstructionMix {
158    /// Per-category instruction counts.
159    pub counts: HashMap<InstructionCategory, usize>,
160    /// Total number of instructions analysed.
161    pub total: usize,
162}
163
164// ---------------------------------------------------------------------------
165// Memory report
166// ---------------------------------------------------------------------------
167
/// Summary of a function's memory traffic, split by address space and by
/// access direction.
#[derive(Clone, Debug)]
pub struct MemoryReport {
    /// Count of load instructions targeting global memory.
    pub global_loads: usize,
    /// Count of store instructions targeting global memory.
    pub global_stores: usize,
    /// Count of load instructions targeting shared memory.
    pub shared_loads: usize,
    /// Count of store instructions targeting shared memory.
    pub shared_stores: usize,
    /// Count of load instructions targeting local memory.
    pub local_loads: usize,
    /// Count of store instructions targeting local memory.
    pub local_stores: usize,
    /// Estimated coalescing quality, from 0.0 (uncoalesced) to 1.0
    /// (perfectly coalesced).
    pub coalescing_score: f64,
}
186
187// ---------------------------------------------------------------------------
188// Diff report
189// ---------------------------------------------------------------------------
190
/// Outcome of comparing two PTX functions, referred to as A and B.
#[derive(Clone, Debug)]
pub struct DiffReport {
    /// Instructions that appear in B but not in A.
    pub added_instructions: usize,
    /// Instructions that appear in A but not in B.
    pub removed_instructions: usize,
    /// Basic blocks that are not the same in A and B.
    pub changed_blocks: usize,
    /// Total register count of B minus that of A; a positive value means B
    /// uses more registers.
    pub register_delta: i32,
}
203
204// ---------------------------------------------------------------------------
205// Complexity metrics
206// ---------------------------------------------------------------------------
207
/// Aggregate complexity figures for a single kernel.
#[derive(Clone, Debug)]
pub struct ComplexityMetrics {
    /// Number of instructions in the function body.
    pub instruction_count: usize,
    /// Number of branch instructions.
    pub branch_count: usize,
    /// Estimated number of loops (back-edges detected).
    pub loop_count: usize,
    /// Largest number of simultaneously live registers at any point.
    pub max_register_pressure: usize,
    /// Estimated occupancy as a percentage (0.0 -- 100.0).
    pub estimated_occupancy_pct: f64,
    /// Ratio of arithmetic operations to memory operations.
    pub arithmetic_intensity: f64,
}
224
225// ---------------------------------------------------------------------------
226// Helpers -- instruction classification
227// ---------------------------------------------------------------------------
228
229/// Categorise a PTX [`Instruction`] into an [`InstructionCategory`].
230const fn categorize_instruction(inst: &Instruction) -> InstructionCategory {
231    match inst {
232        // Arithmetic
233        Instruction::Add { .. }
234        | Instruction::Sub { .. }
235        | Instruction::Mul { .. }
236        | Instruction::Mad { .. }
237        | Instruction::MadLo { .. }
238        | Instruction::MadHi { .. }
239        | Instruction::MadWide { .. }
240        | Instruction::Fma { .. }
241        | Instruction::Neg { .. }
242        | Instruction::Abs { .. }
243        | Instruction::Min { .. }
244        | Instruction::Max { .. }
245        | Instruction::Brev { .. }
246        | Instruction::Clz { .. }
247        | Instruction::Popc { .. }
248        | Instruction::Bfind { .. }
249        | Instruction::Bfe { .. }
250        | Instruction::Bfi { .. }
251        | Instruction::Shl { .. }
252        | Instruction::Shr { .. }
253        | Instruction::Div { .. }
254        | Instruction::Rem { .. }
255        | Instruction::And { .. }
256        | Instruction::Or { .. }
257        | Instruction::Xor { .. }
258        | Instruction::Rcp { .. }
259        | Instruction::Rsqrt { .. }
260        | Instruction::Sqrt { .. }
261        | Instruction::Ex2 { .. }
262        | Instruction::Lg2 { .. }
263        | Instruction::Sin { .. }
264        | Instruction::Cos { .. }
265        | Instruction::Dp4a { .. }
266        | Instruction::Dp2a { .. }
267        | Instruction::SetP { .. } => InstructionCategory::Arithmetic,
268
269        // Memory
270        Instruction::Load { .. }
271        | Instruction::Store { .. }
272        | Instruction::CpAsync { .. }
273        | Instruction::CpAsyncCommit
274        | Instruction::CpAsyncWait { .. }
275        | Instruction::Atom { .. }
276        | Instruction::AtomCas { .. }
277        | Instruction::Red { .. }
278        | Instruction::TmaLoad { .. }
279        | Instruction::Tex1d { .. }
280        | Instruction::Tex2d { .. }
281        | Instruction::Tex3d { .. }
282        | Instruction::SurfLoad { .. }
283        | Instruction::SurfStore { .. }
284        | Instruction::Stmatrix { .. }
285        | Instruction::CpAsyncBulk { .. }
286        | Instruction::Ldmatrix { .. } => InstructionCategory::Memory,
287
288        // Control
289        Instruction::Branch { .. } | Instruction::Label(_) | Instruction::Return => {
290            InstructionCategory::Control
291        }
292
293        // Synchronization
294        Instruction::BarSync { .. }
295        | Instruction::BarArrive { .. }
296        | Instruction::FenceAcqRel { .. }
297        | Instruction::FenceProxy { .. }
298        | Instruction::MbarrierInit { .. }
299        | Instruction::MbarrierArrive { .. }
300        | Instruction::MbarrierWait { .. }
301        | Instruction::ElectSync { .. }
302        | Instruction::Griddepcontrol { .. }
303        | Instruction::Redux { .. }
304        | Instruction::BarrierCluster
305        | Instruction::FenceCluster => InstructionCategory::Synchronization,
306
307        // Tensor Core
308        Instruction::Wmma { .. }
309        | Instruction::Mma { .. }
310        | Instruction::Wgmma { .. }
311        | Instruction::Tcgen05Mma { .. } => InstructionCategory::TensorCore,
312
313        // Conversion
314        Instruction::Cvt { .. } => InstructionCategory::Conversion,
315
316        // Special
317        Instruction::MovSpecial { .. }
318        | Instruction::LoadParam { .. }
319        | Instruction::Comment(_)
320        | Instruction::Raw(_)
321        | Instruction::Pragma(_)
322        | Instruction::Setmaxnreg { .. } => InstructionCategory::Special,
323    }
324}
325
326/// Estimate latency in clock cycles for an instruction.
327#[allow(clippy::match_same_arms)]
328const fn estimate_latency(inst: &Instruction) -> u32 {
329    match inst {
330        // Arithmetic -- single-cycle ALU
331        Instruction::Add { .. }
332        | Instruction::Sub { .. }
333        | Instruction::Neg { .. }
334        | Instruction::Abs { .. }
335        | Instruction::Min { .. }
336        | Instruction::Max { .. }
337        | Instruction::And { .. }
338        | Instruction::Or { .. }
339        | Instruction::Xor { .. }
340        | Instruction::Shl { .. }
341        | Instruction::Shr { .. }
342        | Instruction::SetP { .. } => 4,
343
344        Instruction::Mul { .. }
345        | Instruction::Mad { .. }
346        | Instruction::MadLo { .. }
347        | Instruction::MadHi { .. }
348        | Instruction::MadWide { .. }
349        | Instruction::Fma { .. } => 4,
350
351        // Special math (multi-cycle)
352        Instruction::Div { .. } | Instruction::Rem { .. } => 32,
353        Instruction::Rcp { .. } | Instruction::Rsqrt { .. } | Instruction::Sqrt { .. } => 8,
354        Instruction::Ex2 { .. }
355        | Instruction::Lg2 { .. }
356        | Instruction::Sin { .. }
357        | Instruction::Cos { .. } => 8,
358
359        // Bit manipulation
360        Instruction::Brev { .. }
361        | Instruction::Clz { .. }
362        | Instruction::Popc { .. }
363        | Instruction::Bfind { .. }
364        | Instruction::Bfe { .. }
365        | Instruction::Bfi { .. } => 4,
366
367        // Dot product
368        Instruction::Dp4a { .. } | Instruction::Dp2a { .. } => 8,
369
370        // Memory
371        Instruction::Load { .. } => 200,
372        Instruction::Store { .. } => 200,
373        Instruction::CpAsync { .. } => 200,
374        Instruction::CpAsyncCommit | Instruction::CpAsyncWait { .. } => 4,
375        Instruction::Atom { .. } | Instruction::AtomCas { .. } | Instruction::Red { .. } => 200,
376        Instruction::TmaLoad { .. } | Instruction::CpAsyncBulk { .. } => 200,
377        Instruction::Tex1d { .. } | Instruction::Tex2d { .. } | Instruction::Tex3d { .. } => 200,
378        Instruction::SurfLoad { .. } | Instruction::SurfStore { .. } => 200,
379        Instruction::Stmatrix { .. } => 32,
380        Instruction::Ldmatrix { .. } => 20,
381
382        // Control
383        Instruction::Branch { .. } => 8,
384        Instruction::Label(_) | Instruction::Return => 0,
385
386        // Synchronization
387        Instruction::BarSync { .. }
388        | Instruction::BarArrive { .. }
389        | Instruction::FenceAcqRel { .. }
390        | Instruction::FenceProxy { .. }
391        | Instruction::MbarrierInit { .. }
392        | Instruction::MbarrierArrive { .. }
393        | Instruction::MbarrierWait { .. }
394        | Instruction::ElectSync { .. }
395        | Instruction::Griddepcontrol { .. }
396        | Instruction::Redux { .. }
397        | Instruction::BarrierCluster
398        | Instruction::FenceCluster => 16,
399
400        // Tensor Core
401        Instruction::Wmma { .. } => 32,
402        Instruction::Mma { .. } => 16,
403        Instruction::Wgmma { .. } => 64,
404        Instruction::Tcgen05Mma { .. } => 64,
405
406        // Conversion
407        Instruction::Cvt { .. } => 4,
408
409        // Special / meta
410        Instruction::MovSpecial { .. } | Instruction::LoadParam { .. } => 4,
411        Instruction::Comment(_) | Instruction::Raw(_) | Instruction::Pragma(_) => 0,
412        Instruction::Setmaxnreg { .. } => 0,
413    }
414}
415
416/// Estimate throughput per SM per cycle for an instruction.
417const fn estimate_throughput(inst: &Instruction) -> f64 {
418    match categorize_instruction(inst) {
419        InstructionCategory::Arithmetic => 64.0,
420        InstructionCategory::Memory
421        | InstructionCategory::Control
422        | InstructionCategory::Special
423        | InstructionCategory::Conversion => 32.0,
424        InstructionCategory::Synchronization => 16.0,
425        InstructionCategory::TensorCore => 1.0,
426    }
427}
428
429/// Extract register names that an instruction reads.
430fn registers_read(inst: &Instruction) -> Vec<String> {
431    let mut regs = Vec::new();
432    let mut push_operand = |op: &Operand| match op {
433        Operand::Register(r) => regs.push(r.name.clone()),
434        Operand::Address { base, .. } => regs.push(base.name.clone()),
435        _ => {}
436    };
437
438    match inst {
439        Instruction::Add { a, b, .. }
440        | Instruction::Sub { a, b, .. }
441        | Instruction::Mul { a, b, .. }
442        | Instruction::Min { a, b, .. }
443        | Instruction::Max { a, b, .. }
444        | Instruction::Div { a, b, .. }
445        | Instruction::Rem { a, b, .. }
446        | Instruction::And { a, b, .. }
447        | Instruction::Or { a, b, .. }
448        | Instruction::Xor { a, b, .. }
449        | Instruction::SetP { a, b, .. } => {
450            push_operand(a);
451            push_operand(b);
452        }
453        Instruction::Mad { a, b, c, .. }
454        | Instruction::MadLo { a, b, c, .. }
455        | Instruction::MadHi { a, b, c, .. }
456        | Instruction::MadWide { a, b, c, .. }
457        | Instruction::Fma { a, b, c, .. } => {
458            push_operand(a);
459            push_operand(b);
460            push_operand(c);
461        }
462        Instruction::Neg { src, .. }
463        | Instruction::Abs { src, .. }
464        | Instruction::Brev { src, .. }
465        | Instruction::Clz { src, .. }
466        | Instruction::Popc { src, .. }
467        | Instruction::Bfind { src, .. }
468        | Instruction::Cvt { src, .. }
469        | Instruction::Rcp { src, .. }
470        | Instruction::Rsqrt { src, .. }
471        | Instruction::Sqrt { src, .. }
472        | Instruction::Ex2 { src, .. }
473        | Instruction::Lg2 { src, .. }
474        | Instruction::Sin { src, .. }
475        | Instruction::Cos { src, .. } => {
476            push_operand(src);
477        }
478        Instruction::Load { addr, .. } => {
479            push_operand(addr);
480        }
481        Instruction::Store { addr, src, .. } => {
482            push_operand(addr);
483            regs.push(src.name.clone());
484        }
485        Instruction::Branch {
486            predicate: Some((pred, _)),
487            ..
488        } => {
489            regs.push(pred.name.clone());
490        }
491        Instruction::Shl { src, amount, .. } | Instruction::Shr { src, amount, .. } => {
492            push_operand(src);
493            push_operand(amount);
494        }
495        _ => {}
496    }
497    regs
498}
499
500/// Extract register names that an instruction writes.
501fn registers_written(inst: &Instruction) -> Vec<String> {
502    match inst {
503        Instruction::Add { dst, .. }
504        | Instruction::Sub { dst, .. }
505        | Instruction::Mul { dst, .. }
506        | Instruction::Mad { dst, .. }
507        | Instruction::MadLo { dst, .. }
508        | Instruction::MadHi { dst, .. }
509        | Instruction::MadWide { dst, .. }
510        | Instruction::Fma { dst, .. }
511        | Instruction::Neg { dst, .. }
512        | Instruction::Abs { dst, .. }
513        | Instruction::Min { dst, .. }
514        | Instruction::Max { dst, .. }
515        | Instruction::Brev { dst, .. }
516        | Instruction::Clz { dst, .. }
517        | Instruction::Popc { dst, .. }
518        | Instruction::Bfind { dst, .. }
519        | Instruction::Bfe { dst, .. }
520        | Instruction::Bfi { dst, .. }
521        | Instruction::Shl { dst, .. }
522        | Instruction::Shr { dst, .. }
523        | Instruction::Div { dst, .. }
524        | Instruction::Rem { dst, .. }
525        | Instruction::And { dst, .. }
526        | Instruction::Or { dst, .. }
527        | Instruction::Xor { dst, .. }
528        | Instruction::SetP { dst, .. }
529        | Instruction::Load { dst, .. }
530        | Instruction::Cvt { dst, .. }
531        | Instruction::Atom { dst, .. }
532        | Instruction::AtomCas { dst, .. }
533        | Instruction::MovSpecial { dst, .. }
534        | Instruction::LoadParam { dst, .. }
535        | Instruction::Rcp { dst, .. }
536        | Instruction::Rsqrt { dst, .. }
537        | Instruction::Sqrt { dst, .. }
538        | Instruction::Ex2 { dst, .. }
539        | Instruction::Lg2 { dst, .. }
540        | Instruction::Sin { dst, .. }
541        | Instruction::Cos { dst, .. }
542        | Instruction::Dp4a { dst, .. }
543        | Instruction::Dp2a { dst, .. }
544        | Instruction::Tex1d { dst, .. }
545        | Instruction::Tex2d { dst, .. }
546        | Instruction::Tex3d { dst, .. }
547        | Instruction::SurfLoad { dst, .. }
548        | Instruction::Redux { dst, .. }
549        | Instruction::ElectSync { dst, .. } => vec![dst.name.clone()],
550        _ => Vec::new(),
551    }
552}
553
554// ---------------------------------------------------------------------------
555// ANSI helpers
556// ---------------------------------------------------------------------------
557
/// ANSI SGR sequence that resets all text attributes.
const ANSI_RESET: &str = "\x1b[0m";
/// ANSI SGR sequence that switches on bold text.
const ANSI_BOLD: &str = "\x1b[1m";

/// Wraps `text` in the escape sequence `color` followed by a reset when
/// `use_color` is true; otherwise returns an unmodified copy of `text`.
fn colorize(text: &str, color: &str, use_color: bool) -> String {
    if !use_color {
        return text.to_owned();
    }
    [color, text, ANSI_RESET].concat()
}
568
569// ---------------------------------------------------------------------------
570// PtxExplorer -- main analysis and rendering engine
571// ---------------------------------------------------------------------------
572
573/// Main PTX analysis and rendering engine.
574///
575/// `PtxExplorer` provides methods to render PTX functions and modules as
576/// formatted text, including syntax-highlighted code, control flow graphs,
577/// register lifetime diagrams, and instruction mix charts.
578#[derive(Debug, Clone)]
579pub struct PtxExplorer {
580    config: ExplorerConfig,
581}
582
583impl PtxExplorer {
584    /// Creates a new explorer with the given configuration.
585    #[must_use]
586    pub const fn new(config: ExplorerConfig) -> Self {
587        Self { config }
588    }
589
590    /// Renders a single PTX function as pretty-printed text.
591    ///
592    /// If `use_color` is enabled, ANSI escape codes are used for syntax
593    /// highlighting of different instruction categories.
594    #[must_use]
595    pub fn render_function(&self, func: &PtxFunction) -> String {
596        let mut out = String::new();
597        let header = format!(".entry {} (", func.name);
598        let _ = writeln!(
599            out,
600            "{}",
601            colorize(&header, ANSI_BOLD, self.config.use_color)
602        );
603
604        for (i, (name, ty)) in func.params.iter().enumerate() {
605            let comma = if i + 1 < func.params.len() { "," } else { "" };
606            let _ = writeln!(out, "    .param {} {}{}", ty.as_ptx_str(), name, comma);
607        }
608        let _ = writeln!(out, ")");
609        let _ = writeln!(out, "{{");
610
611        for (idx, inst) in func.body.iter().enumerate() {
612            let cat = categorize_instruction(inst);
613            let emitted = inst.emit();
614            let line = if self.config.show_line_numbers {
615                format!("{:>4}  {}", idx + 1, emitted)
616            } else {
617                format!("    {emitted}")
618            };
619
620            let line = if self.config.show_instruction_latency {
621                let lat = estimate_latency(inst);
622                if lat > 0 {
623                    let pad = self.config.max_width.saturating_sub(line.len()).max(2);
624                    format!("{line}{:>pad$}", format!("// ~{lat} cycles"), pad = pad)
625                } else {
626                    line
627                }
628            } else {
629                line
630            };
631
632            let _ = writeln!(
633                out,
634                "{}",
635                colorize(&line, cat.ansi_color(), self.config.use_color)
636            );
637        }
638
639        let _ = writeln!(out, "}}");
640        out
641    }
642
643    /// Renders an entire PTX module.
644    #[must_use]
645    pub fn render_module(&self, module: &PtxModule) -> String {
646        let mut out = String::new();
647        let _ = writeln!(out, ".version {}", module.version);
648        let _ = writeln!(out, ".target {}", module.target);
649        let _ = writeln!(out, ".address_size {}", module.address_size);
650        let _ = writeln!(out);
651
652        for func in &module.functions {
653            out.push_str(&self.render_function(func));
654            let _ = writeln!(out);
655        }
656        out
657    }
658
659    /// Renders a control flow graph for a function as ASCII art.
660    ///
661    /// Uses basic blocks derived from `Label` and `Branch` variants of
662    /// [`Instruction`] in the function body.
663    /// Each block is drawn as a box with its label and instruction count;
664    /// edges show branch targets.
665    #[must_use]
666    pub fn render_cfg(&self, func: &PtxFunction) -> String {
667        let blocks = split_into_blocks(&func.body);
668        let renderer = CfgRenderer;
669        renderer.render(&blocks)
670    }
671
672    /// Renders a register lifetime timeline for the function.
673    #[must_use]
674    pub fn render_register_lifetime(&self, func: &PtxFunction) -> String {
675        let analyzer = RegisterLifetimeAnalyzer;
676        let lifetimes = analyzer.analyze(func);
677        RegisterLifetimeAnalyzer::render_timeline(&lifetimes, self.config.max_width)
678    }
679
680    /// Renders an instruction mix bar chart for the function.
681    #[must_use]
682    pub fn render_instruction_mix(&self, func: &PtxFunction) -> String {
683        let analyzer = InstructionMixAnalyzer;
684        let mix = analyzer.analyze(func);
685        InstructionMixAnalyzer::render_bar_chart(&mix, self.config.max_width)
686    }
687
688    /// Renders a data dependency graph for a single basic block.
689    #[must_use]
690    pub fn render_dependency_graph(&self, block: &BasicBlock) -> String {
691        let mut out = String::new();
692        let label = block.label.as_deref().unwrap_or("(unnamed)");
693        let _ = writeln!(out, "Dependency graph for block: {label}");
694        let _ = writeln!(out, "{}", "-".repeat(40));
695
696        // Build a map from register name -> instruction index that last wrote it
697        let mut last_writer: HashMap<String, usize> = HashMap::new();
698        // edges: (from_idx, to_idx, register)
699        let mut edges: Vec<(usize, usize, String)> = Vec::new();
700
701        for (idx, inst) in block.instructions.iter().enumerate() {
702            // Check reads -- any register read that has a prior writer creates an edge
703            for reg in registers_read(inst) {
704                if let Some(&writer_idx) = last_writer.get(&reg) {
705                    edges.push((writer_idx, idx, reg));
706                }
707            }
708            // Record writes
709            for reg in registers_written(inst) {
710                last_writer.insert(reg, idx);
711            }
712        }
713
714        if edges.is_empty() {
715            let _ = writeln!(out, "(no data dependencies)");
716        } else {
717            for (from, to, reg) in &edges {
718                let from_text = block
719                    .instructions
720                    .get(*from)
721                    .map_or_else(|| "?".to_string(), |i| truncate_emit(i, 40));
722                let to_text = block
723                    .instructions
724                    .get(*to)
725                    .map_or_else(|| "?".to_string(), |i| truncate_emit(i, 40));
726                let _ = writeln!(out, "[{from}] {from_text}");
727                let _ = writeln!(out, "  --({reg})--> [{to}] {to_text}");
728            }
729        }
730        out
731    }
732}
733
734/// Analyse an instruction and return detailed information.
735#[must_use]
736pub fn analyze_instruction(inst: &Instruction) -> InstructionInfo {
737    InstructionInfo {
738        instruction: inst.emit(),
739        category: categorize_instruction(inst),
740        latency_cycles: estimate_latency(inst),
741        throughput_per_sm: estimate_throughput(inst),
742        registers_read: registers_read(inst),
743        registers_written: registers_written(inst),
744    }
745}
746
747// ---------------------------------------------------------------------------
748// CfgRenderer
749// ---------------------------------------------------------------------------
750
751/// Control flow graph renderer producing ASCII box-and-arrow diagrams.
752pub struct CfgRenderer;
753
754impl CfgRenderer {
755    /// Renders the given basic blocks as an ASCII CFG diagram.
756    ///
757    /// Each block is drawn as a bordered box containing the block label and
758    /// instruction count. Edges are drawn as arrows between blocks.
759    #[must_use]
760    pub fn render(&self, blocks: &[BasicBlock]) -> String {
761        if blocks.is_empty() {
762            return "(empty CFG)\n".to_string();
763        }
764
765        let mut out = String::new();
766        let _ = writeln!(out, "Control Flow Graph");
767        let _ = writeln!(out, "==================");
768        let _ = writeln!(out);
769
770        // Build label -> block index map
771        let mut label_to_idx: HashMap<&str, usize> = HashMap::new();
772        for (idx, blk) in blocks.iter().enumerate() {
773            if let Some(ref label) = blk.label {
774                label_to_idx.insert(label.as_str(), idx);
775            }
776        }
777
778        // Collect edges: (from_block_idx, to_block_idx)
779        let mut edges: Vec<(usize, usize)> = Vec::new();
780        for (idx, blk) in blocks.iter().enumerate() {
781            for inst in &blk.instructions {
782                if let Instruction::Branch { target, .. } = inst {
783                    if let Some(&target_idx) = label_to_idx.get(target.as_str()) {
784                        edges.push((idx, target_idx));
785                    }
786                }
787            }
788            // Fall-through edge to next block (if last instruction is not an
789            // unconditional branch or return)
790            let is_terminal = blk.instructions.last().is_some_and(|i| {
791                matches!(
792                    i,
793                    Instruction::Return
794                        | Instruction::Branch {
795                            predicate: None,
796                            ..
797                        }
798                )
799            });
800            if !is_terminal && idx + 1 < blocks.len() {
801                edges.push((idx, idx + 1));
802            }
803        }
804
805        // Draw blocks
806        for (idx, blk) in blocks.iter().enumerate() {
807            let label = blk.label.as_deref().unwrap_or("(entry)");
808            let box_content = format!("B{idx}: {label} ({} insts)", blk.instructions.len());
809            let box_width = box_content.len() + 4;
810            let border = "+".to_string() + &"-".repeat(box_width - 2) + "+";
811            let _ = writeln!(out, "{border}");
812            let _ = writeln!(out, "| {box_content} |");
813            let _ = writeln!(out, "{border}");
814
815            // Show outgoing edges from this block
816            let outgoing: Vec<&(usize, usize)> = edges.iter().filter(|(f, _)| *f == idx).collect();
817            for (_, to) in outgoing {
818                let target_label = blocks
819                    .get(*to)
820                    .and_then(|b| b.label.as_deref())
821                    .unwrap_or("(next)");
822                let _ = writeln!(out, "    |");
823                let _ = writeln!(out, "    +--> B{to}: {target_label}");
824            }
825            let _ = writeln!(out);
826        }
827        out
828    }
829}
830
831// ---------------------------------------------------------------------------
832// RegisterLifetimeAnalyzer
833// ---------------------------------------------------------------------------
834
835/// Analyzes register lifetimes (live ranges) across a function.
836pub struct RegisterLifetimeAnalyzer;
837
838impl RegisterLifetimeAnalyzer {
839    /// Analyses a PTX function and returns lifetime information for each register.
840    #[must_use]
841    pub fn analyze(&self, func: &PtxFunction) -> Vec<RegisterLifetime> {
842        let mut first_defs: HashMap<String, (usize, String)> = HashMap::new();
843        let mut last_uses: HashMap<String, usize> = HashMap::new();
844        let mut use_counts: HashMap<String, usize> = HashMap::new();
845
846        for (idx, inst) in func.body.iter().enumerate() {
847            // Written registers -- first definition
848            for reg in registers_written(inst) {
849                first_defs.entry(reg.clone()).or_insert_with(|| {
850                    let reg_type = Self::infer_type(inst, &reg);
851                    (idx, reg_type)
852                });
853                // A write is also a "use" of the register slot
854                last_uses.insert(reg, idx);
855            }
856            // Read registers
857            for reg in registers_read(inst) {
858                last_uses.insert(reg.clone(), idx);
859                *use_counts.entry(reg).or_insert(0) += 1;
860            }
861        }
862
863        let mut lifetimes: Vec<RegisterLifetime> = first_defs
864            .into_iter()
865            .map(|(reg, (def_idx, reg_type))| {
866                let last = last_uses.get(&reg).copied().unwrap_or(def_idx);
867                let uses = use_counts.get(&reg).copied().unwrap_or(0);
868                RegisterLifetime {
869                    register: reg,
870                    reg_type,
871                    first_def: def_idx,
872                    last_use: last,
873                    num_uses: uses,
874                }
875            })
876            .collect();
877
878        lifetimes.sort_by_key(|l| (l.first_def, l.register.clone()));
879        lifetimes
880    }
881
882    /// Renders a horizontal timeline of register lifetimes.
883    #[must_use]
884    pub fn render_timeline(lifetimes: &[RegisterLifetime], max_width: usize) -> String {
885        if lifetimes.is_empty() {
886            return "(no registers)\n".to_string();
887        }
888
889        let mut out = String::new();
890        let _ = writeln!(out, "Register Lifetimes");
891        let _ = writeln!(out, "==================");
892        let _ = writeln!(out);
893
894        // Determine scaling
895        let max_inst = lifetimes
896            .iter()
897            .map(|l| l.last_use)
898            .max()
899            .unwrap_or(0)
900            .max(1);
901
902        let name_col_width = lifetimes
903            .iter()
904            .map(|l| l.register.len())
905            .max()
906            .unwrap_or(4)
907            .max(4);
908        let type_col_width = lifetimes
909            .iter()
910            .map(|l| l.reg_type.len())
911            .max()
912            .unwrap_or(4)
913            .max(4);
914
915        // Available width for the bar
916        let bar_width = max_width
917            .saturating_sub(name_col_width + type_col_width + 10)
918            .max(10);
919
920        let _ = writeln!(
921            out,
922            "{:>nw$}  {:>tw$}  Lifetime",
923            "Reg",
924            "Type",
925            nw = name_col_width,
926            tw = type_col_width
927        );
928        let _ = writeln!(
929            out,
930            "{}  {}  {}",
931            "-".repeat(name_col_width),
932            "-".repeat(type_col_width),
933            "-".repeat(bar_width),
934        );
935
936        for lt in lifetimes {
937            let start_pos = (lt.first_def * bar_width) / max_inst.max(1);
938            let end_pos = (lt.last_use * bar_width) / max_inst.max(1);
939            let end_pos = end_pos.max(start_pos + 1).min(bar_width);
940
941            let mut bar = vec![' '; bar_width];
942            for ch in bar.iter_mut().take(end_pos).skip(start_pos) {
943                *ch = '#';
944            }
945            let bar_str: String = bar.into_iter().collect();
946
947            let _ = writeln!(
948                out,
949                "{:>nw$}  {:>tw$}  {bar_str}  (uses: {})",
950                lt.register,
951                lt.reg_type,
952                lt.num_uses,
953                nw = name_col_width,
954                tw = type_col_width,
955            );
956        }
957        out
958    }
959
960    /// Infer a type string for a register from the instruction that defines it.
961    fn infer_type(inst: &Instruction, _reg: &str) -> String {
962        match inst {
963            Instruction::Add { ty, .. }
964            | Instruction::Sub { ty, .. }
965            | Instruction::Mul { ty, .. }
966            | Instruction::Min { ty, .. }
967            | Instruction::Max { ty, .. }
968            | Instruction::Neg { ty, .. }
969            | Instruction::Abs { ty, .. }
970            | Instruction::Div { ty, .. }
971            | Instruction::Rem { ty, .. }
972            | Instruction::And { ty, .. }
973            | Instruction::Or { ty, .. }
974            | Instruction::Xor { ty, .. }
975            | Instruction::Shl { ty, .. }
976            | Instruction::Shr { ty, .. }
977            | Instruction::Load { ty, .. }
978            | Instruction::Brev { ty, .. }
979            | Instruction::Clz { ty, .. }
980            | Instruction::Popc { ty, .. }
981            | Instruction::Bfind { ty, .. }
982            | Instruction::Bfe { ty, .. }
983            | Instruction::Bfi { ty, .. }
984            | Instruction::Rcp { ty, .. }
985            | Instruction::Rsqrt { ty, .. }
986            | Instruction::Sqrt { ty, .. }
987            | Instruction::Ex2 { ty, .. }
988            | Instruction::Lg2 { ty, .. }
989            | Instruction::Sin { ty, .. }
990            | Instruction::Cos { ty, .. }
991            | Instruction::Tex1d { ty, .. }
992            | Instruction::Tex2d { ty, .. }
993            | Instruction::Tex3d { ty, .. }
994            | Instruction::SurfLoad { ty, .. }
995            | Instruction::Atom { ty, .. }
996            | Instruction::AtomCas { ty, .. }
997            | Instruction::Mad { ty, .. }
998            | Instruction::Fma { ty, .. }
999            | Instruction::SetP { ty, .. }
1000            | Instruction::LoadParam { ty, .. } => ty.as_ptx_str().to_string(),
1001            Instruction::MadLo { typ, .. } | Instruction::MadHi { typ, .. } => {
1002                typ.as_ptx_str().to_string()
1003            }
1004            Instruction::MadWide { src_typ, .. } => src_typ.as_ptx_str().to_string(),
1005            Instruction::Cvt { dst_ty, .. } => dst_ty.as_ptx_str().to_string(),
1006            _ => "?".to_string(),
1007        }
1008    }
1009}
1010
1011// ---------------------------------------------------------------------------
1012// InstructionMixAnalyzer
1013// ---------------------------------------------------------------------------
1014
1015/// Analyses the instruction mix (category distribution) of a PTX function.
1016pub struct InstructionMixAnalyzer;
1017
1018impl InstructionMixAnalyzer {
1019    /// Analyses a PTX function and returns instruction mix statistics.
1020    #[must_use]
1021    pub fn analyze(&self, func: &PtxFunction) -> InstructionMix {
1022        let mut counts: HashMap<InstructionCategory, usize> = HashMap::new();
1023        for inst in &func.body {
1024            *counts.entry(categorize_instruction(inst)).or_insert(0) += 1;
1025        }
1026        InstructionMix {
1027            total: func.body.len(),
1028            counts,
1029        }
1030    }
1031
1032    /// Renders a horizontal bar chart of instruction categories.
1033    #[must_use]
1034    pub fn render_bar_chart(mix: &InstructionMix, width: usize) -> String {
1035        if mix.total == 0 {
1036            return "(no instructions)\n".to_string();
1037        }
1038
1039        let mut out = String::new();
1040        let _ = writeln!(out, "Instruction Mix");
1041        let _ = writeln!(out, "===============");
1042        let _ = writeln!(out);
1043
1044        let label_width = 12_usize;
1045        let bar_width = width.saturating_sub(label_width + 20).max(10);
1046
1047        let mut categories: Vec<(InstructionCategory, usize)> =
1048            mix.counts.iter().map(|(&cat, &cnt)| (cat, cnt)).collect();
1049        categories.sort_by_key(|&(_, cnt)| std::cmp::Reverse(cnt));
1050
1051        for (cat, count) in &categories {
1052            #[allow(clippy::cast_precision_loss)]
1053            let pct = (*count as f64 / mix.total as f64) * 100.0;
1054            #[allow(
1055                clippy::cast_precision_loss,
1056                clippy::cast_possible_truncation,
1057                clippy::cast_sign_loss
1058            )]
1059            let filled = ((*count as f64 / mix.total as f64) * bar_width as f64) as usize;
1060            let bar: String = "#".repeat(filled) + &" ".repeat(bar_width.saturating_sub(filled));
1061            let _ = writeln!(
1062                out,
1063                "{:<lw$} [{bar}] {count:>4} ({pct:>5.1}%)",
1064                cat.label(),
1065                lw = label_width,
1066            );
1067        }
1068
1069        let _ = writeln!(out);
1070        let _ = writeln!(out, "Total: {} instructions", mix.total);
1071        out
1072    }
1073}
1074
1075// ---------------------------------------------------------------------------
1076// MemoryAccessPattern
1077// ---------------------------------------------------------------------------
1078
1079/// Analyses memory access patterns in a PTX function.
1080pub struct MemoryAccessPattern;
1081
1082impl MemoryAccessPattern {
1083    /// Analyses a PTX function and returns a memory access report.
1084    #[must_use]
1085    pub fn analyze(func: &PtxFunction) -> MemoryReport {
1086        let mut report = MemoryReport {
1087            global_loads: 0,
1088            global_stores: 0,
1089            shared_loads: 0,
1090            shared_stores: 0,
1091            local_loads: 0,
1092            local_stores: 0,
1093            coalescing_score: 1.0,
1094        };
1095
1096        let mut total_mem_ops = 0_usize;
1097        let mut likely_coalesced = 0_usize;
1098
1099        for inst in &func.body {
1100            match inst {
1101                Instruction::Load { space, .. } => {
1102                    total_mem_ops += 1;
1103                    match space {
1104                        MemorySpace::Global => {
1105                            report.global_loads += 1;
1106                            // Heuristic: global loads that use address + offset pattern
1107                            // are more likely to be coalesced
1108                            likely_coalesced += 1;
1109                        }
1110                        MemorySpace::Shared => report.shared_loads += 1,
1111                        MemorySpace::Local => report.local_loads += 1,
1112                        _ => {}
1113                    }
1114                }
1115                Instruction::Store { space, .. } => {
1116                    total_mem_ops += 1;
1117                    match space {
1118                        MemorySpace::Global => {
1119                            report.global_stores += 1;
1120                            likely_coalesced += 1;
1121                        }
1122                        MemorySpace::Shared => report.shared_stores += 1,
1123                        MemorySpace::Local => report.local_stores += 1,
1124                        _ => {}
1125                    }
1126                }
1127                Instruction::CpAsync { .. } | Instruction::TmaLoad { .. } => {
1128                    total_mem_ops += 1;
1129                    report.global_loads += 1;
1130                    report.shared_stores += 1;
1131                    likely_coalesced += 1;
1132                }
1133                _ => {}
1134            }
1135        }
1136
1137        if total_mem_ops > 0 {
1138            #[allow(clippy::cast_precision_loss)]
1139            {
1140                report.coalescing_score = likely_coalesced as f64 / total_mem_ops as f64;
1141            }
1142        }
1143
1144        report
1145    }
1146}
1147
1148// ---------------------------------------------------------------------------
1149// PtxDiff
1150// ---------------------------------------------------------------------------
1151
1152/// Compares two PTX functions and produces a diff report.
1153pub struct PtxDiff;
1154
1155impl PtxDiff {
1156    /// Compares two functions and returns a diff report.
1157    #[must_use]
1158    pub fn diff(a: &PtxFunction, b: &PtxFunction) -> DiffReport {
1159        let a_count = a.body.len();
1160        let b_count = b.body.len();
1161
1162        let added = b_count.saturating_sub(a_count);
1163        let removed = a_count.saturating_sub(b_count);
1164
1165        // Count blocks
1166        let a_blocks = split_into_blocks(&a.body);
1167        let b_blocks = split_into_blocks(&b.body);
1168
1169        let changed_blocks = count_changed_blocks(&a_blocks, &b_blocks);
1170
1171        // Register delta
1172        let a_regs = count_unique_registers(&a.body);
1173        let b_regs = count_unique_registers(&b.body);
1174        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
1175        let register_delta = b_regs as i32 - a_regs as i32;
1176
1177        DiffReport {
1178            added_instructions: added,
1179            removed_instructions: removed,
1180            changed_blocks,
1181            register_delta,
1182        }
1183    }
1184
1185    /// Renders a diff report as formatted text.
1186    #[must_use]
1187    pub fn render_diff(report: &DiffReport) -> String {
1188        let mut out = String::new();
1189        let _ = writeln!(out, "PTX Diff Report");
1190        let _ = writeln!(out, "===============");
1191        let _ = writeln!(out);
1192        let _ = writeln!(out, "Added instructions:   +{}", report.added_instructions);
1193        let _ = writeln!(
1194            out,
1195            "Removed instructions: -{}",
1196            report.removed_instructions
1197        );
1198        let _ = writeln!(out, "Changed blocks:       {}", report.changed_blocks);
1199
1200        let sign = if report.register_delta >= 0 { "+" } else { "" };
1201        let _ = writeln!(out, "Register delta:       {sign}{}", report.register_delta);
1202        out
1203    }
1204}
1205
1206// ---------------------------------------------------------------------------
1207// KernelComplexityScore
1208// ---------------------------------------------------------------------------
1209
1210/// Computes complexity metrics for a PTX kernel.
1211pub struct KernelComplexityScore;
1212
1213impl KernelComplexityScore {
1214    /// Analyses a PTX function and returns complexity metrics.
1215    #[must_use]
1216    pub fn analyze(func: &PtxFunction) -> ComplexityMetrics {
1217        let instruction_count = func.body.len();
1218
1219        let mut branch_count = 0_usize;
1220        let mut arith_count = 0_usize;
1221        let mut mem_count = 0_usize;
1222
1223        for inst in &func.body {
1224            match categorize_instruction(inst) {
1225                InstructionCategory::Control => {
1226                    if matches!(inst, Instruction::Branch { .. }) {
1227                        branch_count += 1;
1228                    }
1229                }
1230                InstructionCategory::Arithmetic => arith_count += 1,
1231                InstructionCategory::Memory => mem_count += 1,
1232                _ => {}
1233            }
1234        }
1235
1236        // Detect loops via back-edges: a branch to a label that appears
1237        // earlier in the instruction stream.
1238        let loop_count = count_back_edges(&func.body);
1239
1240        // Register pressure: count max live registers at any point
1241        let max_register_pressure = compute_max_register_pressure(&func.body);
1242
1243        // Estimated occupancy: rough heuristic based on register pressure.
1244        // SM has 65536 registers; a warp uses 32 * regs. Max warps per SM = 64.
1245        // occupancy = min(64, 65536 / (32 * max_regs)) / 64 * 100
1246        #[allow(clippy::cast_precision_loss)]
1247        let estimated_occupancy_pct = if max_register_pressure > 0 {
1248            let warps_per_sm = 65536_f64 / (32.0 * max_register_pressure as f64);
1249            let warps_per_sm = warps_per_sm.min(64.0);
1250            (warps_per_sm / 64.0) * 100.0
1251        } else {
1252            100.0
1253        };
1254
1255        #[allow(clippy::cast_precision_loss)]
1256        let arithmetic_intensity = if mem_count > 0 {
1257            arith_count as f64 / mem_count as f64
1258        } else if arith_count > 0 {
1259            f64::INFINITY
1260        } else {
1261            0.0
1262        };
1263
1264        ComplexityMetrics {
1265            instruction_count,
1266            branch_count,
1267            loop_count,
1268            max_register_pressure,
1269            estimated_occupancy_pct,
1270            arithmetic_intensity,
1271        }
1272    }
1273}
1274
1275// ---------------------------------------------------------------------------
1276// Internal helpers
1277// ---------------------------------------------------------------------------
1278
1279/// Split a flat instruction vector into basic blocks by scanning for Label
1280/// instructions.
1281fn split_into_blocks(body: &[Instruction]) -> Vec<BasicBlock> {
1282    if body.is_empty() {
1283        return Vec::new();
1284    }
1285
1286    let mut blocks: Vec<BasicBlock> = Vec::new();
1287    let mut current_label: Option<String> = None;
1288    let mut current_insts: Vec<Instruction> = Vec::new();
1289
1290    for inst in body {
1291        if let Instruction::Label(lbl) = inst {
1292            // Flush current block
1293            if !current_insts.is_empty() || current_label.is_some() {
1294                blocks.push(BasicBlock {
1295                    label: current_label.take(),
1296                    instructions: std::mem::take(&mut current_insts),
1297                });
1298            }
1299            current_label = Some(lbl.clone());
1300        } else {
1301            current_insts.push(inst.clone());
1302        }
1303    }
1304
1305    // Flush remaining
1306    if !current_insts.is_empty() || current_label.is_some() {
1307        blocks.push(BasicBlock {
1308            label: current_label,
1309            instructions: current_insts,
1310        });
1311    }
1312
1313    blocks
1314}
1315
1316/// Count the number of blocks that differ between two block lists.
1317fn count_changed_blocks(a: &[BasicBlock], b: &[BasicBlock]) -> usize {
1318    let max_len = a.len().max(b.len());
1319    let mut changed = 0_usize;
1320
1321    for i in 0..max_len {
1322        let a_block = a.get(i);
1323        let b_block = b.get(i);
1324        match (a_block, b_block) {
1325            (Some(ab), Some(bb)) => {
1326                if ab.label != bb.label || ab.instructions.len() != bb.instructions.len() {
1327                    changed += 1;
1328                } else {
1329                    // Compare instruction emissions
1330                    let differs = ab
1331                        .instructions
1332                        .iter()
1333                        .zip(bb.instructions.iter())
1334                        .any(|(ai, bi)| ai.emit() != bi.emit());
1335                    if differs {
1336                        changed += 1;
1337                    }
1338                }
1339            }
1340            _ => changed += 1,
1341        }
1342    }
1343    changed
1344}
1345
1346/// Count unique register names referenced in an instruction list.
1347fn count_unique_registers(body: &[Instruction]) -> usize {
1348    let mut regs = std::collections::HashSet::new();
1349    for inst in body {
1350        for r in registers_read(inst) {
1351            regs.insert(r);
1352        }
1353        for r in registers_written(inst) {
1354            regs.insert(r);
1355        }
1356    }
1357    regs.len()
1358}
1359
1360/// Count back-edges (branches to a label that appears before the branch).
1361fn count_back_edges(body: &[Instruction]) -> usize {
1362    // Record the instruction index of each label
1363    let mut label_positions: HashMap<&str, usize> = HashMap::new();
1364    for (idx, inst) in body.iter().enumerate() {
1365        if let Instruction::Label(lbl) = inst {
1366            label_positions.insert(lbl.as_str(), idx);
1367        }
1368    }
1369
1370    let mut count = 0_usize;
1371    for (idx, inst) in body.iter().enumerate() {
1372        if let Instruction::Branch { target, .. } = inst {
1373            if let Some(&lbl_idx) = label_positions.get(target.as_str()) {
1374                if lbl_idx <= idx {
1375                    count += 1;
1376                }
1377            }
1378        }
1379    }
1380    count
1381}
1382
1383/// Compute maximum number of simultaneously live registers.
1384fn compute_max_register_pressure(body: &[Instruction]) -> usize {
1385    if body.is_empty() {
1386        return 0;
1387    }
1388
1389    // Build per-register [first_def, last_use] intervals
1390    let mut first_def: HashMap<String, usize> = HashMap::new();
1391    let mut last_use: HashMap<String, usize> = HashMap::new();
1392
1393    for (idx, inst) in body.iter().enumerate() {
1394        for r in registers_written(inst) {
1395            first_def.entry(r.clone()).or_insert(idx);
1396            last_use.insert(r, idx);
1397        }
1398        for r in registers_read(inst) {
1399            last_use.insert(r, idx);
1400        }
1401    }
1402
1403    // Sweep: for each instruction index, count live registers
1404    let intervals: Vec<(usize, usize)> = first_def
1405        .iter()
1406        .map(|(reg, &def)| {
1407            let use_end = last_use.get(reg).copied().unwrap_or(def);
1408            (def, use_end)
1409        })
1410        .collect();
1411
1412    let mut max_live = 0_usize;
1413    for idx in 0..body.len() {
1414        let live = intervals
1415            .iter()
1416            .filter(|(start, end)| *start <= idx && idx <= *end)
1417            .count();
1418        if live > max_live {
1419            max_live = live;
1420        }
1421    }
1422    max_live
1423}
1424
1425/// Truncate an instruction's emitted text to a maximum length.
1426fn truncate_emit(inst: &Instruction, max_len: usize) -> String {
1427    let s = inst.emit();
1428    if s.len() > max_len {
1429        format!("{}...", &s[..max_len.saturating_sub(3)])
1430    } else {
1431        s
1432    }
1433}
1434
1435// ===========================================================================
1436// Tests
1437// ===========================================================================
1438
1439#[cfg(test)]
1440mod tests {
1441    use super::*;
1442    use crate::ir::{
1443        BasicBlock, CacheQualifier, CmpOp, ImmValue, Instruction, MemorySpace, Operand,
1444        PtxFunction, PtxModule, PtxType, Register, RoundingMode, SpecialReg, VectorWidth,
1445    };
1446
1447    // -- Test helpers -------------------------------------------------------
1448
1449    fn make_reg(name: &str, ty: PtxType) -> Register {
1450        Register {
1451            name: name.to_string(),
1452            ty,
1453        }
1454    }
1455
1456    fn make_operand_reg(name: &str, ty: PtxType) -> Operand {
1457        Operand::Register(make_reg(name, ty))
1458    }
1459
1460    fn make_simple_function() -> PtxFunction {
1461        let mut func = PtxFunction::new("test_kernel");
1462        func.add_param("a_ptr", PtxType::U64);
1463        func.add_param("n", PtxType::U32);
1464
1465        // Load param
1466        func.push(Instruction::LoadParam {
1467            ty: PtxType::U64,
1468            dst: make_reg("%rd0", PtxType::U64),
1469            param_name: "a_ptr".to_string(),
1470        });
1471
1472        // MovSpecial
1473        func.push(Instruction::MovSpecial {
1474            dst: make_reg("%r0", PtxType::U32),
1475            special: SpecialReg::TidX,
1476        });
1477
1478        // Add
1479        func.push(Instruction::Add {
1480            ty: PtxType::U32,
1481            dst: make_reg("%r1", PtxType::U32),
1482            a: make_operand_reg("%r0", PtxType::U32),
1483            b: Operand::Immediate(ImmValue::U32(1)),
1484        });
1485
1486        // Load
1487        func.push(Instruction::Load {
1488            space: MemorySpace::Global,
1489            qualifier: CacheQualifier::None,
1490            vec: VectorWidth::V1,
1491            ty: PtxType::F32,
1492            dst: make_reg("%f0", PtxType::F32),
1493            addr: Operand::Address {
1494                base: make_reg("%rd0", PtxType::U64),
1495                offset: None,
1496            },
1497        });
1498
1499        // Fma
1500        func.push(Instruction::Fma {
1501            rnd: RoundingMode::Rn,
1502            ty: PtxType::F32,
1503            dst: make_reg("%f1", PtxType::F32),
1504            a: make_operand_reg("%f0", PtxType::F32),
1505            b: Operand::Immediate(ImmValue::F32(2.0)),
1506            c: Operand::Immediate(ImmValue::F32(1.0)),
1507        });
1508
1509        // Store
1510        func.push(Instruction::Store {
1511            space: MemorySpace::Global,
1512            qualifier: CacheQualifier::None,
1513            vec: VectorWidth::V1,
1514            ty: PtxType::F32,
1515            addr: Operand::Address {
1516                base: make_reg("%rd0", PtxType::U64),
1517                offset: None,
1518            },
1519            src: make_reg("%f1", PtxType::F32),
1520        });
1521
1522        func.push(Instruction::Return);
1523        func
1524    }
1525
1526    fn make_branching_function() -> PtxFunction {
1527        let mut func = PtxFunction::new("branch_kernel");
1528
1529        func.push(Instruction::MovSpecial {
1530            dst: make_reg("%r0", PtxType::U32),
1531            special: SpecialReg::TidX,
1532        });
1533
1534        func.push(Instruction::SetP {
1535            cmp: CmpOp::Lt,
1536            ty: PtxType::U32,
1537            dst: make_reg("%p0", PtxType::Pred),
1538            a: make_operand_reg("%r0", PtxType::U32),
1539            b: Operand::Immediate(ImmValue::U32(128)),
1540        });
1541
1542        func.push(Instruction::Branch {
1543            target: "skip".to_string(),
1544            predicate: Some((make_reg("%p0", PtxType::Pred), true)),
1545        });
1546
1547        // Work block
1548        func.push(Instruction::Add {
1549            ty: PtxType::U32,
1550            dst: make_reg("%r1", PtxType::U32),
1551            a: make_operand_reg("%r0", PtxType::U32),
1552            b: Operand::Immediate(ImmValue::U32(1)),
1553        });
1554
1555        func.push(Instruction::Label("skip".to_string()));
1556
1557        func.push(Instruction::Return);
1558        func
1559    }
1560
1561    // -- Tests --------------------------------------------------------------
1562
1563    #[test]
1564    fn test_render_empty_function() {
1565        let config = ExplorerConfig::default();
1566        let explorer = PtxExplorer::new(config);
1567        let func = PtxFunction::new("empty");
1568        let output = explorer.render_function(&func);
1569        assert!(output.contains("empty"));
1570        assert!(output.contains('{'));
1571        assert!(output.contains('}'));
1572    }
1573
1574    #[test]
1575    fn test_render_function_with_multiple_blocks() {
1576        let config = ExplorerConfig::default();
1577        let explorer = PtxExplorer::new(config);
1578        let func = make_branching_function();
1579        let output = explorer.render_function(&func);
1580        assert!(output.contains("branch_kernel"));
1581        assert!(output.contains("setp"));
1582        assert!(output.contains("bra"));
1583        assert!(output.contains("add"));
1584    }
1585
1586    #[test]
1587    fn test_cfg_rendering_with_branches() {
1588        let config = ExplorerConfig::default();
1589        let explorer = PtxExplorer::new(config);
1590        let func = make_branching_function();
1591        let output = explorer.render_cfg(&func);
1592        assert!(output.contains("Control Flow Graph"));
1593        assert!(output.contains("skip"));
1594        assert!(output.contains("-->"));
1595    }
1596
1597    #[test]
1598    fn test_register_lifetime_analysis() {
1599        let analyzer = RegisterLifetimeAnalyzer;
1600        let func = make_simple_function();
1601        let lifetimes = analyzer.analyze(&func);
1602
1603        // We should have several registers
1604        assert!(!lifetimes.is_empty());
1605
1606        // %rd0 is defined first (LoadParam at index 0) and used later (Load, Store)
1607        let rd0 = lifetimes.iter().find(|l| l.register == "%rd0");
1608        assert!(rd0.is_some(), "should find %rd0 lifetime");
1609        let rd0 = rd0.expect("checked above");
1610        assert_eq!(rd0.first_def, 0);
1611        assert!(
1612            rd0.last_use > rd0.first_def,
1613            "last_use should be after first_def"
1614        );
1615    }
1616
1617    #[test]
1618    fn test_register_lifetime_timeline_rendering() {
1619        let analyzer = RegisterLifetimeAnalyzer;
1620        let func = make_simple_function();
1621        let lifetimes = analyzer.analyze(&func);
1622        let timeline = RegisterLifetimeAnalyzer::render_timeline(&lifetimes, 80);
1623        assert!(timeline.contains("Register Lifetimes"));
1624        assert!(timeline.contains('#')); // should have bar characters
1625        assert!(timeline.contains("uses:"));
1626    }
1627
1628    #[test]
1629    fn test_instruction_mix_categorization() {
1630        let analyzer = InstructionMixAnalyzer;
1631        let func = make_simple_function();
1632        let mix = analyzer.analyze(&func);
1633
1634        assert_eq!(mix.total, func.body.len());
1635
1636        // Should have arithmetic, memory, special, and control categories
1637        let arith = mix
1638            .counts
1639            .get(&InstructionCategory::Arithmetic)
1640            .copied()
1641            .unwrap_or(0);
1642        let mem = mix
1643            .counts
1644            .get(&InstructionCategory::Memory)
1645            .copied()
1646            .unwrap_or(0);
1647        let special = mix
1648            .counts
1649            .get(&InstructionCategory::Special)
1650            .copied()
1651            .unwrap_or(0);
1652        assert!(arith > 0, "should have arithmetic instructions");
1653        assert!(mem > 0, "should have memory instructions");
1654        assert!(special > 0, "should have special instructions");
1655    }
1656
1657    #[test]
1658    fn test_instruction_mix_bar_chart() {
1659        let analyzer = InstructionMixAnalyzer;
1660        let func = make_simple_function();
1661        let mix = analyzer.analyze(&func);
1662        let chart = InstructionMixAnalyzer::render_bar_chart(&mix, 80);
1663        assert!(chart.contains("Instruction Mix"));
1664        assert!(chart.contains('#')); // bar characters
1665        assert!(chart.contains('%')); // percentages
1666        assert!(chart.contains("Total:"));
1667    }
1668
1669    #[test]
1670    fn test_memory_access_pattern_analysis() {
1671        let func = make_simple_function();
1672        let report = MemoryAccessPattern::analyze(&func);
1673        assert_eq!(report.global_loads, 1);
1674        assert_eq!(report.global_stores, 1);
1675        assert_eq!(report.shared_loads, 0);
1676        assert_eq!(report.shared_stores, 0);
1677        assert!(report.coalescing_score > 0.0);
1678        assert!(report.coalescing_score <= 1.0);
1679    }
1680
1681    #[test]
1682    fn test_ptx_diff_identical_functions() {
1683        let func = make_simple_function();
1684        let report = PtxDiff::diff(&func, &func);
1685        assert_eq!(report.added_instructions, 0);
1686        assert_eq!(report.removed_instructions, 0);
1687        assert_eq!(report.changed_blocks, 0);
1688        assert_eq!(report.register_delta, 0);
1689    }
1690
1691    #[test]
1692    fn test_ptx_diff_different_functions() {
1693        let a = make_simple_function();
1694        let mut b = make_simple_function();
1695        // Add extra instructions to b
1696        b.push(Instruction::Comment("extra".to_string()));
1697        b.push(Instruction::Add {
1698            ty: PtxType::U32,
1699            dst: make_reg("%r99", PtxType::U32),
1700            a: Operand::Immediate(ImmValue::U32(0)),
1701            b: Operand::Immediate(ImmValue::U32(1)),
1702        });
1703
1704        let report = PtxDiff::diff(&a, &b);
1705        assert!(report.added_instructions > 0);
1706        assert!(report.register_delta > 0);
1707
1708        let rendered = PtxDiff::render_diff(&report);
1709        assert!(rendered.contains("PTX Diff Report"));
1710        assert!(rendered.contains('+'));
1711    }
1712
1713    #[test]
1714    fn test_kernel_complexity_scoring() {
1715        let func = make_branching_function();
1716        let metrics = KernelComplexityScore::analyze(&func);
1717        assert_eq!(metrics.instruction_count, func.body.len());
1718        assert!(metrics.branch_count > 0, "should detect branches");
1719        assert!(metrics.estimated_occupancy_pct > 0.0);
1720        assert!(metrics.estimated_occupancy_pct <= 100.0);
1721    }
1722
1723    #[test]
1724    fn test_color_vs_no_color_output() {
1725        let func = make_simple_function();
1726
1727        let no_color = PtxExplorer::new(ExplorerConfig {
1728            use_color: false,
1729            ..ExplorerConfig::default()
1730        });
1731        let with_color = PtxExplorer::new(ExplorerConfig {
1732            use_color: true,
1733            ..ExplorerConfig::default()
1734        });
1735
1736        let plain = no_color.render_function(&func);
1737        let colored = with_color.render_function(&func);
1738
1739        // Colored output should contain ANSI codes
1740        assert!(colored.contains("\x1b["));
1741        // Plain output should not
1742        assert!(!plain.contains("\x1b["));
1743        // Both should contain the function name
1744        assert!(plain.contains("test_kernel"));
1745        assert!(colored.contains("test_kernel"));
1746    }
1747
1748    #[test]
1749    fn test_config_defaults() {
1750        let config = ExplorerConfig::default();
1751        assert!(!config.use_color);
1752        assert_eq!(config.max_width, 120);
1753        assert!(!config.show_line_numbers);
1754        assert!(!config.show_register_types);
1755        assert!(!config.show_instruction_latency);
1756    }
1757
1758    #[test]
1759    fn test_large_function_handling() {
1760        let mut func = PtxFunction::new("big_kernel");
1761        // Add 500 instructions -- should not crash or truncate
1762        for i in 0..500 {
1763            func.push(Instruction::Add {
1764                ty: PtxType::F32,
1765                dst: make_reg(&format!("%f{i}"), PtxType::F32),
1766                a: Operand::Immediate(ImmValue::F32(1.0)),
1767                b: Operand::Immediate(ImmValue::F32(2.0)),
1768            });
1769        }
1770
1771        let config = ExplorerConfig::default();
1772        let explorer = PtxExplorer::new(config);
1773        let output = explorer.render_function(&func);
1774        // Should contain all 500 instructions
1775        assert!(output.lines().count() > 500);
1776
1777        let mix = InstructionMixAnalyzer.analyze(&func);
1778        assert_eq!(mix.total, 500);
1779
1780        let metrics = KernelComplexityScore::analyze(&func);
1781        assert_eq!(metrics.instruction_count, 500);
1782    }
1783
1784    #[test]
1785    fn test_line_number_rendering() {
1786        let config = ExplorerConfig {
1787            show_line_numbers: true,
1788            ..ExplorerConfig::default()
1789        };
1790        let explorer = PtxExplorer::new(config);
1791        let func = make_simple_function();
1792        let output = explorer.render_function(&func);
1793        // Should contain line numbers starting from 1
1794        assert!(output.contains("   1  "));
1795        assert!(output.contains("   2  "));
1796    }
1797
1798    #[test]
1799    fn test_render_module() {
1800        let mut module = PtxModule::new("sm_80");
1801        module.add_function(make_simple_function());
1802        module.add_function(make_branching_function());
1803
1804        let explorer = PtxExplorer::new(ExplorerConfig::default());
1805        let output = explorer.render_module(&module);
1806        assert!(output.contains(".version 8.5"));
1807        assert!(output.contains(".target sm_80"));
1808        assert!(output.contains("test_kernel"));
1809        assert!(output.contains("branch_kernel"));
1810    }
1811
1812    #[test]
1813    fn test_dependency_graph() {
1814        let mut block = BasicBlock::with_label("test_block");
1815        block.push(Instruction::LoadParam {
1816            ty: PtxType::F32,
1817            dst: make_reg("%f0", PtxType::F32),
1818            param_name: "x".to_string(),
1819        });
1820        block.push(Instruction::Add {
1821            ty: PtxType::F32,
1822            dst: make_reg("%f1", PtxType::F32),
1823            a: make_operand_reg("%f0", PtxType::F32),
1824            b: Operand::Immediate(ImmValue::F32(1.0)),
1825        });
1826        block.push(Instruction::Add {
1827            ty: PtxType::F32,
1828            dst: make_reg("%f2", PtxType::F32),
1829            a: make_operand_reg("%f1", PtxType::F32),
1830            b: make_operand_reg("%f0", PtxType::F32),
1831        });
1832
1833        let explorer = PtxExplorer::new(ExplorerConfig::default());
1834        let output = explorer.render_dependency_graph(&block);
1835        assert!(output.contains("Dependency graph"));
1836        assert!(output.contains("test_block"));
1837        assert!(output.contains("-->")); // should have dependency edges
1838        assert!(output.contains("%f0")); // register dependency
1839    }
1840
1841    #[test]
1842    fn test_cfg_empty_function() {
1843        let config = ExplorerConfig::default();
1844        let explorer = PtxExplorer::new(config);
1845        let func = PtxFunction::new("empty_kernel");
1846        let output = explorer.render_cfg(&func);
1847        // Empty function body → empty CFG message
1848        assert!(
1849            output.contains("empty CFG")
1850                || output.contains("Control Flow Graph")
1851                || output.is_empty()
1852                || output.contains("(entry)")
1853        );
1854    }
1855
1856    #[test]
1857    fn test_cfg_no_branch_single_block() {
1858        let config = ExplorerConfig::default();
1859        let explorer = PtxExplorer::new(config);
1860        let func = make_simple_function();
1861        let output = explorer.render_cfg(&func);
1862        // make_simple_function has no labels/branches → one block
1863        assert!(output.contains("Control Flow Graph"));
1864        assert!(output.contains("B0"));
1865    }
1866
1867    #[test]
1868    fn test_register_lifetime_single_instruction() {
1869        let analyzer = RegisterLifetimeAnalyzer;
1870        let mut func = PtxFunction::new("single");
1871        func.push(Instruction::Add {
1872            ty: PtxType::U32,
1873            dst: make_reg("%r0", PtxType::U32),
1874            a: Operand::Immediate(ImmValue::U32(1)),
1875            b: Operand::Immediate(ImmValue::U32(2)),
1876        });
1877        let lifetimes = analyzer.analyze(&func);
1878        // %r0 is written at index 0, last_use = 0
1879        let r0 = lifetimes.iter().find(|l| l.register == "%r0");
1880        assert!(r0.is_some(), "should track %r0");
1881        let r0 = r0.expect("checked above");
1882        assert_eq!(r0.first_def, 0);
1883        assert_eq!(r0.last_use, 0);
1884    }
1885
1886    #[test]
1887    fn test_register_lifetime_render_empty() {
1888        let rendered = RegisterLifetimeAnalyzer::render_timeline(&[], 80);
1889        assert!(rendered.contains("no registers"));
1890    }
1891
1892    #[test]
1893    fn test_instruction_mix_empty_function() {
1894        let analyzer = InstructionMixAnalyzer;
1895        let func = PtxFunction::new("empty_kernel");
1896        let mix = analyzer.analyze(&func);
1897        assert_eq!(mix.total, 0);
1898        let chart = InstructionMixAnalyzer::render_bar_chart(&mix, 80);
1899        assert!(chart.contains("no instructions"));
1900    }
1901
1902    #[test]
1903    fn test_dependency_graph_no_deps() {
1904        let mut block = BasicBlock::with_label("no_deps");
1905        // Instructions with no register dependencies between them
1906        block.push(Instruction::Add {
1907            ty: PtxType::U32,
1908            dst: make_reg("%r0", PtxType::U32),
1909            a: Operand::Immediate(ImmValue::U32(1)),
1910            b: Operand::Immediate(ImmValue::U32(2)),
1911        });
1912        block.push(Instruction::Add {
1913            ty: PtxType::U32,
1914            dst: make_reg("%r1", PtxType::U32),
1915            a: Operand::Immediate(ImmValue::U32(3)),
1916            b: Operand::Immediate(ImmValue::U32(4)),
1917        });
1918        let explorer = PtxExplorer::new(ExplorerConfig::default());
1919        let output = explorer.render_dependency_graph(&block);
1920        assert!(output.contains("no_deps"));
1921        // %r0 and %r1 are independent — no edges between them
1922        assert!(output.contains("no data dependencies"));
1923    }
1924
1925    #[test]
1926    fn test_cfg_renderer_empty_blocks() {
1927        let renderer = CfgRenderer;
1928        let output = renderer.render(&[]);
1929        assert!(output.contains("empty CFG"));
1930    }
1931}