use std::collections::HashMap;
use std::fmt::Write;
use crate::ir::{BasicBlock, Instruction, MemorySpace, Operand, PtxFunction, PtxModule};
/// Rendering options shared by every [`PtxExplorer`] view.
///
/// The default configuration produces plain (uncolored) output sized for a
/// 120-column terminal with every optional annotation switched off.
#[derive(Debug, Clone)]
// The flags are independent on/off display toggles rather than encoded state.
#[allow(clippy::struct_excessive_bools)]
pub struct ExplorerConfig {
    /// Wrap output in ANSI color escape sequences.
    pub use_color: bool,
    /// Target line width used for padding and bar sizing.
    pub max_width: usize,
    /// Prefix each rendered instruction with its 1-based body index.
    pub show_line_numbers: bool,
    /// Register-type display toggle (not read by the renderers in this section).
    pub show_register_types: bool,
    /// Append a `// ~N cycles` latency estimate to rendered instructions.
    pub show_instruction_latency: bool,
}

impl Default for ExplorerConfig {
    fn default() -> Self {
        let max_width = 120;
        Self {
            max_width,
            use_color: false,
            show_line_numbers: false,
            show_register_types: false,
            show_instruction_latency: false,
        }
    }
}
/// Coarse instruction families used for coloring, mix statistics and
/// throughput estimates.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InstructionCategory {
    /// ALU work: integer/float math, bit manipulation, shifts, `setp`.
    Arithmetic,
    /// Loads, stores, async/bulk copies, atomics, texture/surface and matrix moves.
    Memory,
    /// Branches, labels and `ret`.
    Control,
    /// Barriers, fences, mbarrier operations and warp-collective helpers.
    Synchronization,
    /// WMMA / MMA / WGMMA / tcgen05 matrix instructions.
    TensorCore,
    /// Special-register moves, parameter loads, comments, raw text, pragmas.
    Special,
    /// Type conversions (`cvt`).
    Conversion,
}

impl InstructionCategory {
    /// Short human-readable name used in reports; note that
    /// [`InstructionCategory::Synchronization`] is abbreviated to `"Sync"`.
    #[must_use]
    pub const fn label(self) -> &'static str {
        match self {
            Self::Synchronization => "Sync",
            Self::Arithmetic => "Arithmetic",
            Self::Memory => "Memory",
            Self::Control => "Control",
            Self::TensorCore => "TensorCore",
            Self::Special => "Special",
            Self::Conversion => "Conversion",
        }
    }

    /// ANSI foreground-color escape used when colored output is enabled.
    const fn ansi_color(self) -> &'static str {
        match self {
            // green
            Self::Arithmetic => "\x1b[32m",
            // blue
            Self::Memory => "\x1b[34m",
            // yellow
            Self::Control => "\x1b[33m",
            // magenta
            Self::Synchronization => "\x1b[35m",
            // cyan
            Self::TensorCore => "\x1b[36m",
            // bright black / gray
            Self::Special => "\x1b[90m",
            // white
            Self::Conversion => "\x1b[37m",
        }
    }
}
/// Per-instruction summary produced by [`analyze_instruction`].
#[derive(Debug, Clone)]
pub struct InstructionInfo {
    /// The instruction's emitted PTX text.
    pub instruction: String,
    /// Coarse family the instruction belongs to.
    pub category: InstructionCategory,
    /// Hard-coded heuristic latency estimate in cycles (not a measurement).
    pub latency_cycles: u32,
    /// Heuristic per-SM throughput figure chosen by category.
    pub throughput_per_sm: f64,
    /// Names of the registers the instruction reads, in operand order.
    pub registers_read: Vec<String>,
    /// Names of the registers the instruction writes (zero or one entry).
    pub registers_written: Vec<String>,
}
/// Live range of one register within a function body.
///
/// Positions are indices into the function's instruction list, so labels,
/// comments and other non-executing entries occupy slots too.
#[derive(Debug, Clone)]
pub struct RegisterLifetime {
    /// Register name (e.g. `%r1`).
    pub register: String,
    /// PTX type string inferred from the first writing instruction, or `"?"`
    /// when no type could be inferred.
    pub reg_type: String,
    /// Index of the first instruction that writes the register.
    pub first_def: usize,
    /// Index of the last instruction that reads or writes the register.
    pub last_use: usize,
    /// Number of read occurrences (writes are not counted).
    pub num_uses: usize,
}
/// Instruction counts per [`InstructionCategory`] for one function.
#[derive(Debug, Clone)]
pub struct InstructionMix {
    /// Count per category; categories that never occur are absent.
    pub counts: HashMap<InstructionCategory, usize>,
    /// Total number of body entries counted (the sum of `counts`).
    pub total: usize,
}
/// Memory-traffic tally for one function, bucketed by memory space.
#[derive(Debug, Clone)]
pub struct MemoryReport {
    /// Global-space loads (async/TMA copies also count here).
    pub global_loads: usize,
    /// Global-space stores.
    pub global_stores: usize,
    /// Shared-space loads.
    pub shared_loads: usize,
    /// Shared-space stores (async/TMA copies also count here).
    pub shared_stores: usize,
    /// Local-space loads.
    pub local_loads: usize,
    /// Local-space stores.
    pub local_stores: usize,
    /// Heuristic in `[0, 1]`: the fraction of counted memory operations that
    /// touch global memory, all of which are optimistically treated as
    /// coalesced. `1.0` when there are no memory operations.
    pub coalescing_score: f64,
}
/// Summary of how a second version of a function differs from the first.
#[derive(Debug, Clone)]
pub struct DiffReport {
    /// Net growth in instruction count (0 if the new body is not longer).
    pub added_instructions: usize,
    /// Net shrinkage in instruction count (0 if the new body is not shorter).
    pub removed_instructions: usize,
    /// Number of positionally-compared basic blocks that differ, including
    /// blocks present in only one version.
    pub changed_blocks: usize,
    /// Unique registers in the new version minus unique registers in the old.
    pub register_delta: i32,
}
/// Heuristic complexity figures for one kernel.
#[derive(Debug, Clone)]
pub struct ComplexityMetrics {
    /// Number of body entries (labels and comments included).
    pub instruction_count: usize,
    /// Number of branch instructions.
    pub branch_count: usize,
    /// Number of branches whose target label appears at or before the branch.
    pub loop_count: usize,
    /// Peak number of simultaneously live registers (first write to last use).
    pub max_register_pressure: usize,
    /// Occupancy estimate in percent, derived from `max_register_pressure`
    /// using hard-coded limits (65536 registers, 64 warps per SM).
    /// NOTE(review): real limits vary by architecture.
    pub estimated_occupancy_pct: f64,
    /// Arithmetic / memory instruction ratio; `INFINITY` when there is
    /// arithmetic but no memory traffic, `0.0` when there is neither.
    pub arithmetic_intensity: f64,
}
/// Maps an instruction to its coarse [`InstructionCategory`].
///
/// The match is exhaustive with no wildcard arm, so adding an `Instruction`
/// variant forces a categorization decision here at compile time.
///
/// Notable groupings: `SetP` counts as arithmetic; `Redux` and `ElectSync` are
/// grouped with synchronization; `Label` and `Return` are control flow; and
/// `MovSpecial`, `LoadParam`, `Comment`, `Raw`, `Pragma` and `Setmaxnreg` are
/// "special".
const fn categorize_instruction(inst: &Instruction) -> InstructionCategory {
    match inst {
        Instruction::Add { .. }
        | Instruction::Sub { .. }
        | Instruction::Mul { .. }
        | Instruction::Mad { .. }
        | Instruction::MadLo { .. }
        | Instruction::MadHi { .. }
        | Instruction::MadWide { .. }
        | Instruction::Fma { .. }
        | Instruction::Neg { .. }
        | Instruction::Abs { .. }
        | Instruction::Min { .. }
        | Instruction::Max { .. }
        | Instruction::Brev { .. }
        | Instruction::Clz { .. }
        | Instruction::Popc { .. }
        | Instruction::Bfind { .. }
        | Instruction::Bfe { .. }
        | Instruction::Bfi { .. }
        | Instruction::Shl { .. }
        | Instruction::Shr { .. }
        | Instruction::Div { .. }
        | Instruction::Rem { .. }
        | Instruction::And { .. }
        | Instruction::Or { .. }
        | Instruction::Xor { .. }
        | Instruction::Rcp { .. }
        | Instruction::Rsqrt { .. }
        | Instruction::Sqrt { .. }
        | Instruction::Ex2 { .. }
        | Instruction::Lg2 { .. }
        | Instruction::Sin { .. }
        | Instruction::Cos { .. }
        | Instruction::Dp4a { .. }
        | Instruction::Dp2a { .. }
        | Instruction::SetP { .. } => InstructionCategory::Arithmetic,
        Instruction::Load { .. }
        | Instruction::Store { .. }
        | Instruction::CpAsync { .. }
        | Instruction::CpAsyncCommit
        | Instruction::CpAsyncWait { .. }
        | Instruction::Atom { .. }
        | Instruction::AtomCas { .. }
        | Instruction::Red { .. }
        | Instruction::TmaLoad { .. }
        | Instruction::Tex1d { .. }
        | Instruction::Tex2d { .. }
        | Instruction::Tex3d { .. }
        | Instruction::SurfLoad { .. }
        | Instruction::SurfStore { .. }
        | Instruction::Stmatrix { .. }
        | Instruction::CpAsyncBulk { .. }
        | Instruction::Ldmatrix { .. } => InstructionCategory::Memory,
        Instruction::Branch { .. } | Instruction::Label(_) | Instruction::Return => {
            InstructionCategory::Control
        }
        Instruction::BarSync { .. }
        | Instruction::BarArrive { .. }
        | Instruction::FenceAcqRel { .. }
        | Instruction::FenceProxy { .. }
        | Instruction::MbarrierInit { .. }
        | Instruction::MbarrierArrive { .. }
        | Instruction::MbarrierWait { .. }
        | Instruction::ElectSync { .. }
        | Instruction::Griddepcontrol { .. }
        | Instruction::Redux { .. }
        | Instruction::BarrierCluster
        | Instruction::FenceCluster => InstructionCategory::Synchronization,
        Instruction::Wmma { .. }
        | Instruction::Mma { .. }
        | Instruction::Wgmma { .. }
        | Instruction::Tcgen05Mma { .. } => InstructionCategory::TensorCore,
        Instruction::Cvt { .. } => InstructionCategory::Conversion,
        // Entries that are bookkeeping or pass-through text rather than
        // computation, plus special-register / parameter access.
        Instruction::MovSpecial { .. }
        | Instruction::LoadParam { .. }
        | Instruction::Comment(_)
        | Instruction::Raw(_)
        | Instruction::Pragma(_)
        | Instruction::Setmaxnreg { .. } => InstructionCategory::Special,
    }
}
/// Rough latency estimate, in cycles, for a single instruction.
///
/// The values are coarse hard-coded heuristics, not measurements for any
/// particular GPU. For example, simple ALU ops are 4, `Div`/`Rem` are 32, and
/// every `Load`/`Store` is 200 regardless of memory space.
/// NOTE(review): shared-memory accesses are typically far cheaper than 200.
///
/// `Label`, `Return`, `Comment`, `Raw`, `Pragma` and `Setmaxnreg` report 0.
/// The function renderer skips the latency annotation for a 0 estimate.
// Arms are grouped by instruction family for readability, so several arms
// intentionally share the same value.
#[allow(clippy::match_same_arms)]
const fn estimate_latency(inst: &Instruction) -> u32 {
    match inst {
        Instruction::Add { .. }
        | Instruction::Sub { .. }
        | Instruction::Neg { .. }
        | Instruction::Abs { .. }
        | Instruction::Min { .. }
        | Instruction::Max { .. }
        | Instruction::And { .. }
        | Instruction::Or { .. }
        | Instruction::Xor { .. }
        | Instruction::Shl { .. }
        | Instruction::Shr { .. }
        | Instruction::SetP { .. } => 4,
        Instruction::Mul { .. }
        | Instruction::Mad { .. }
        | Instruction::MadLo { .. }
        | Instruction::MadHi { .. }
        | Instruction::MadWide { .. }
        | Instruction::Fma { .. } => 4,
        Instruction::Div { .. } | Instruction::Rem { .. } => 32,
        Instruction::Rcp { .. } | Instruction::Rsqrt { .. } | Instruction::Sqrt { .. } => 8,
        Instruction::Ex2 { .. }
        | Instruction::Lg2 { .. }
        | Instruction::Sin { .. }
        | Instruction::Cos { .. } => 8,
        Instruction::Brev { .. }
        | Instruction::Clz { .. }
        | Instruction::Popc { .. }
        | Instruction::Bfind { .. }
        | Instruction::Bfe { .. }
        | Instruction::Bfi { .. } => 4,
        Instruction::Dp4a { .. } | Instruction::Dp2a { .. } => 8,
        Instruction::Load { .. } => 200,
        Instruction::Store { .. } => 200,
        Instruction::CpAsync { .. } => 200,
        Instruction::CpAsyncCommit | Instruction::CpAsyncWait { .. } => 4,
        Instruction::Atom { .. } | Instruction::AtomCas { .. } | Instruction::Red { .. } => 200,
        Instruction::TmaLoad { .. } | Instruction::CpAsyncBulk { .. } => 200,
        Instruction::Tex1d { .. } | Instruction::Tex2d { .. } | Instruction::Tex3d { .. } => 200,
        Instruction::SurfLoad { .. } | Instruction::SurfStore { .. } => 200,
        Instruction::Stmatrix { .. } => 32,
        Instruction::Ldmatrix { .. } => 20,
        Instruction::Branch { .. } => 8,
        Instruction::Label(_) | Instruction::Return => 0,
        Instruction::BarSync { .. }
        | Instruction::BarArrive { .. }
        | Instruction::FenceAcqRel { .. }
        | Instruction::FenceProxy { .. }
        | Instruction::MbarrierInit { .. }
        | Instruction::MbarrierArrive { .. }
        | Instruction::MbarrierWait { .. }
        | Instruction::ElectSync { .. }
        | Instruction::Griddepcontrol { .. }
        | Instruction::Redux { .. }
        | Instruction::BarrierCluster
        | Instruction::FenceCluster => 16,
        Instruction::Wmma { .. } => 32,
        Instruction::Mma { .. } => 16,
        Instruction::Wgmma { .. } => 64,
        Instruction::Tcgen05Mma { .. } => 64,
        Instruction::Cvt { .. } => 4,
        Instruction::MovSpecial { .. } | Instruction::LoadParam { .. } => 4,
        Instruction::Comment(_) | Instruction::Raw(_) | Instruction::Pragma(_) => 0,
        Instruction::Setmaxnreg { .. } => 0,
    }
}
/// Rough per-SM throughput figure for `inst`.
///
/// The figure depends only on the instruction's [`InstructionCategory`]. The
/// constants are coarse heuristics rather than per-architecture measurements.
const fn estimate_throughput(inst: &Instruction) -> f64 {
    let category = categorize_instruction(inst);
    match category {
        InstructionCategory::TensorCore => 1.0,
        InstructionCategory::Synchronization => 16.0,
        InstructionCategory::Arithmetic => 64.0,
        InstructionCategory::Memory
        | InstructionCategory::Control
        | InstructionCategory::Special
        | InstructionCategory::Conversion => 32.0,
    }
}
/// Names of the registers `inst` reads, in operand order.
///
/// A register operand contributes its own name, and an address operand
/// contributes its base register. Immediates and other operand kinds
/// contribute nothing. A store also reads its source register, and a
/// predicated branch reads its guard predicate.
///
/// NOTE(review): variants not matched below (e.g. `Bfe`, `Bfi`, `Dp4a`,
/// `Atom`, tensor-core ops) report no reads. Dependency and lifetime analyses
/// therefore under-count for them.
fn registers_read(inst: &Instruction) -> Vec<String> {
    /// Register named by a single operand, if it names one.
    fn operand_register(op: &Operand) -> Option<String> {
        match op {
            Operand::Register(reg) => Some(reg.name.clone()),
            Operand::Address { base, .. } => Some(base.name.clone()),
            _ => None,
        }
    }

    /// Register names of `ops`, preserving their order.
    fn names_of(ops: &[&Operand]) -> Vec<String> {
        ops.iter().copied().filter_map(operand_register).collect()
    }

    match inst {
        Instruction::Add { a, b, .. }
        | Instruction::Sub { a, b, .. }
        | Instruction::Mul { a, b, .. }
        | Instruction::Min { a, b, .. }
        | Instruction::Max { a, b, .. }
        | Instruction::Div { a, b, .. }
        | Instruction::Rem { a, b, .. }
        | Instruction::And { a, b, .. }
        | Instruction::Or { a, b, .. }
        | Instruction::Xor { a, b, .. }
        | Instruction::SetP { a, b, .. } => names_of(&[a, b]),
        Instruction::Mad { a, b, c, .. }
        | Instruction::MadLo { a, b, c, .. }
        | Instruction::MadHi { a, b, c, .. }
        | Instruction::MadWide { a, b, c, .. }
        | Instruction::Fma { a, b, c, .. } => names_of(&[a, b, c]),
        Instruction::Neg { src, .. }
        | Instruction::Abs { src, .. }
        | Instruction::Brev { src, .. }
        | Instruction::Clz { src, .. }
        | Instruction::Popc { src, .. }
        | Instruction::Bfind { src, .. }
        | Instruction::Cvt { src, .. }
        | Instruction::Rcp { src, .. }
        | Instruction::Rsqrt { src, .. }
        | Instruction::Sqrt { src, .. }
        | Instruction::Ex2 { src, .. }
        | Instruction::Lg2 { src, .. }
        | Instruction::Sin { src, .. }
        | Instruction::Cos { src, .. } => names_of(&[src]),
        Instruction::Shl { src, amount, .. } | Instruction::Shr { src, amount, .. } => {
            names_of(&[src, amount])
        }
        Instruction::Load { addr, .. } => names_of(&[addr]),
        Instruction::Store { addr, src, .. } => {
            let mut names = names_of(&[addr]);
            names.push(src.name.clone());
            names
        }
        Instruction::Branch {
            predicate: Some((pred, _)),
            ..
        } => vec![pred.name.clone()],
        _ => Vec::new(),
    }
}
/// Names of the registers `inst` writes: the destination register of the
/// listed value-producing variants, or nothing for every other instruction.
fn registers_written(inst: &Instruction) -> Vec<String> {
    let destination = match inst {
        Instruction::Add { dst, .. }
        | Instruction::Sub { dst, .. }
        | Instruction::Mul { dst, .. }
        | Instruction::Mad { dst, .. }
        | Instruction::MadLo { dst, .. }
        | Instruction::MadHi { dst, .. }
        | Instruction::MadWide { dst, .. }
        | Instruction::Fma { dst, .. }
        | Instruction::Neg { dst, .. }
        | Instruction::Abs { dst, .. }
        | Instruction::Min { dst, .. }
        | Instruction::Max { dst, .. }
        | Instruction::Brev { dst, .. }
        | Instruction::Clz { dst, .. }
        | Instruction::Popc { dst, .. }
        | Instruction::Bfind { dst, .. }
        | Instruction::Bfe { dst, .. }
        | Instruction::Bfi { dst, .. }
        | Instruction::Shl { dst, .. }
        | Instruction::Shr { dst, .. }
        | Instruction::Div { dst, .. }
        | Instruction::Rem { dst, .. }
        | Instruction::And { dst, .. }
        | Instruction::Or { dst, .. }
        | Instruction::Xor { dst, .. }
        | Instruction::SetP { dst, .. }
        | Instruction::Load { dst, .. }
        | Instruction::Cvt { dst, .. }
        | Instruction::Atom { dst, .. }
        | Instruction::AtomCas { dst, .. }
        | Instruction::MovSpecial { dst, .. }
        | Instruction::LoadParam { dst, .. }
        | Instruction::Rcp { dst, .. }
        | Instruction::Rsqrt { dst, .. }
        | Instruction::Sqrt { dst, .. }
        | Instruction::Ex2 { dst, .. }
        | Instruction::Lg2 { dst, .. }
        | Instruction::Sin { dst, .. }
        | Instruction::Cos { dst, .. }
        | Instruction::Dp4a { dst, .. }
        | Instruction::Dp2a { dst, .. }
        | Instruction::Tex1d { dst, .. }
        | Instruction::Tex2d { dst, .. }
        | Instruction::Tex3d { dst, .. }
        | Instruction::SurfLoad { dst, .. }
        | Instruction::Redux { dst, .. }
        | Instruction::ElectSync { dst, .. } => Some(dst),
        _ => None,
    };
    destination.map(|reg| reg.name.clone()).into_iter().collect()
}
/// ANSI escape that clears all color and style attributes.
const ANSI_RESET: &str = "\x1b[0m";
/// ANSI escape that switches on bold text.
const ANSI_BOLD: &str = "\x1b[1m";

/// Wraps `text` in the `color` escape followed by a reset when `use_color`
/// is set; otherwise returns an owned copy of `text` unchanged.
fn colorize(text: &str, color: &str, use_color: bool) -> String {
    if !use_color {
        return text.to_owned();
    }
    let mut painted = String::with_capacity(color.len() + text.len() + ANSI_RESET.len());
    painted.push_str(color);
    painted.push_str(text);
    painted.push_str(ANSI_RESET);
    painted
}
/// Text renderer for PTX functions and modules.
///
/// Besides a plain listing it exposes the control-flow-graph, register
/// lifetime, instruction-mix and per-block dependency views. All output is
/// shaped by the [`ExplorerConfig`] supplied at construction.
#[derive(Debug, Clone)]
pub struct PtxExplorer {
    config: ExplorerConfig,
}

impl PtxExplorer {
    /// Creates an explorer that renders according to `config`.
    #[must_use]
    pub const fn new(config: ExplorerConfig) -> Self {
        Self { config }
    }

    /// Renders `func` as an `.entry` listing: header, parameter list, then
    /// one line per body entry.
    ///
    /// Honors `show_line_numbers` (a 1-based index prefix) and
    /// `show_instruction_latency` (a trailing `// ~N cycles` comment for
    /// entries with a non-zero estimate). It also honors `use_color` (bold
    /// header, per-category instruction colors).
    #[must_use]
    pub fn render_function(&self, func: &PtxFunction) -> String {
        let mut out = String::new();
        let header = format!(".entry {} (", func.name);
        let _ = writeln!(
            out,
            "{}",
            colorize(&header, ANSI_BOLD, self.config.use_color)
        );
        for (i, (name, ty)) in func.params.iter().enumerate() {
            let comma = if i + 1 < func.params.len() { "," } else { "" };
            let _ = writeln!(out, " .param {} {}{}", ty.as_ptx_str(), name, comma);
        }
        let _ = writeln!(out, ")");
        let _ = writeln!(out, "{{");
        for (idx, inst) in func.body.iter().enumerate() {
            let emitted = inst.emit();
            let line = if self.config.show_line_numbers {
                format!("{:>4} {}", idx + 1, emitted)
            } else {
                format!(" {emitted}")
            };
            let line = if self.config.show_instruction_latency {
                self.annotate_latency(line, estimate_latency(inst))
            } else {
                line
            };
            let color = categorize_instruction(inst).ansi_color();
            let _ = writeln!(out, "{}", colorize(&line, color, self.config.use_color));
        }
        let _ = writeln!(out, "}}");
        out
    }

    /// Appends `// ~{latency} cycles` to `line`. The comment is right-aligned
    /// so the full line ends at `max_width` when it fits. Returns `line`
    /// unchanged when `latency` is 0.
    fn annotate_latency(&self, line: String, latency: u32) -> String {
        if latency == 0 {
            return line;
        }
        let comment = format!("// ~{latency} cycles");
        // Keep at least two spaces between code and comment, so a line too long
        // to be padded never runs straight into its annotation.
        let gap = self
            .config
            .max_width
            .saturating_sub(line.len() + comment.len())
            .max(2);
        format!("{line}{}{comment}", " ".repeat(gap))
    }

    /// Renders every function in `module` after the `.version`, `.target` and
    /// `.address_size` header lines, separating functions with blank lines.
    #[must_use]
    pub fn render_module(&self, module: &PtxModule) -> String {
        let mut out = String::new();
        let _ = writeln!(out, ".version {}", module.version);
        let _ = writeln!(out, ".target {}", module.target);
        let _ = writeln!(out, ".address_size {}", module.address_size);
        let _ = writeln!(out);
        for func in &module.functions {
            out.push_str(&self.render_function(func));
            let _ = writeln!(out);
        }
        out
    }

    /// Splits `func` into basic blocks and renders them as an ASCII CFG.
    #[must_use]
    pub fn render_cfg(&self, func: &PtxFunction) -> String {
        let blocks = split_into_blocks(&func.body);
        let renderer = CfgRenderer;
        renderer.render(&blocks)
    }

    /// Renders a per-register lifetime timeline sized to `max_width`.
    #[must_use]
    pub fn render_register_lifetime(&self, func: &PtxFunction) -> String {
        let analyzer = RegisterLifetimeAnalyzer;
        let lifetimes = analyzer.analyze(func);
        RegisterLifetimeAnalyzer::render_timeline(&lifetimes, self.config.max_width)
    }

    /// Renders a per-category instruction bar chart sized to `max_width`.
    #[must_use]
    pub fn render_instruction_mix(&self, func: &PtxFunction) -> String {
        let analyzer = InstructionMixAnalyzer;
        let mix = analyzer.analyze(func);
        InstructionMixAnalyzer::render_bar_chart(&mix, self.config.max_width)
    }

    /// Lists the read-after-write dependencies inside `block`.
    ///
    /// Each read of a register is linked to the most recent earlier
    /// instruction in the block that wrote it. An instruction that names the
    /// same source register more than once produces a single edge.
    #[must_use]
    pub fn render_dependency_graph(&self, block: &BasicBlock) -> String {
        let mut out = String::new();
        let label = block.label.as_deref().unwrap_or("(unnamed)");
        let _ = writeln!(out, "Dependency graph for block: {label}");
        let _ = writeln!(out, "{}", "-".repeat(40));
        let mut last_writer: HashMap<String, usize> = HashMap::new();
        let mut edges: Vec<(usize, usize, String)> = Vec::new();
        for (idx, inst) in block.instructions.iter().enumerate() {
            // Deduplicate per instruction (e.g. `add r1, r0, r0`), keeping order.
            let mut reads: Vec<String> = Vec::new();
            for reg in registers_read(inst) {
                if !reads.contains(&reg) {
                    reads.push(reg);
                }
            }
            // Reads are resolved before this instruction's own writes, so
            // `add r1, r1, 1` depends on the previous writer of r1.
            for reg in reads {
                if let Some(&writer_idx) = last_writer.get(&reg) {
                    edges.push((writer_idx, idx, reg));
                }
            }
            for reg in registers_written(inst) {
                last_writer.insert(reg, idx);
            }
        }
        if edges.is_empty() {
            let _ = writeln!(out, "(no data dependencies)");
        } else {
            for (from, to, reg) in &edges {
                let from_text = block
                    .instructions
                    .get(*from)
                    .map_or_else(|| "?".to_string(), |i| truncate_emit(i, 40));
                let to_text = block
                    .instructions
                    .get(*to)
                    .map_or_else(|| "?".to_string(), |i| truncate_emit(i, 40));
                let _ = writeln!(out, "[{from}] {from_text}");
                let _ = writeln!(out, " --({reg})--> [{to}] {to_text}");
            }
        }
        out
    }
}
/// Collects every per-instruction estimate this module produces for `inst`.
///
/// This covers its emitted PTX text, its category, the heuristic latency and
/// throughput, and the names of the registers it reads and writes.
#[must_use]
pub fn analyze_instruction(inst: &Instruction) -> InstructionInfo {
    let category = categorize_instruction(inst);
    let latency_cycles = estimate_latency(inst);
    let throughput_per_sm = estimate_throughput(inst);
    InstructionInfo {
        registers_read: registers_read(inst),
        registers_written: registers_written(inst),
        instruction: inst.emit(),
        category,
        latency_cycles,
        throughput_per_sm,
    }
}
pub struct CfgRenderer;
impl CfgRenderer {
#[must_use]
pub fn render(&self, blocks: &[BasicBlock]) -> String {
if blocks.is_empty() {
return "(empty CFG)\n".to_string();
}
let mut out = String::new();
let _ = writeln!(out, "Control Flow Graph");
let _ = writeln!(out, "==================");
let _ = writeln!(out);
let mut label_to_idx: HashMap<&str, usize> = HashMap::new();
for (idx, blk) in blocks.iter().enumerate() {
if let Some(ref label) = blk.label {
label_to_idx.insert(label.as_str(), idx);
}
}
let mut edges: Vec<(usize, usize)> = Vec::new();
for (idx, blk) in blocks.iter().enumerate() {
for inst in &blk.instructions {
if let Instruction::Branch { target, .. } = inst {
if let Some(&target_idx) = label_to_idx.get(target.as_str()) {
edges.push((idx, target_idx));
}
}
}
let is_terminal = blk.instructions.last().is_some_and(|i| {
matches!(
i,
Instruction::Return
| Instruction::Branch {
predicate: None,
..
}
)
});
if !is_terminal && idx + 1 < blocks.len() {
edges.push((idx, idx + 1));
}
}
for (idx, blk) in blocks.iter().enumerate() {
let label = blk.label.as_deref().unwrap_or("(entry)");
let box_content = format!("B{idx}: {label} ({} insts)", blk.instructions.len());
let box_width = box_content.len() + 4;
let border = "+".to_string() + &"-".repeat(box_width - 2) + "+";
let _ = writeln!(out, "{border}");
let _ = writeln!(out, "| {box_content} |");
let _ = writeln!(out, "{border}");
let outgoing: Vec<&(usize, usize)> = edges.iter().filter(|(f, _)| *f == idx).collect();
for (_, to) in outgoing {
let target_label = blocks
.get(*to)
.and_then(|b| b.label.as_deref())
.unwrap_or("(next)");
let _ = writeln!(out, " |");
let _ = writeln!(out, " +--> B{to}: {target_label}");
}
let _ = writeln!(out);
}
out
}
}
/// Computes and renders first-write-to-last-use register live ranges.
pub struct RegisterLifetimeAnalyzer;

impl RegisterLifetimeAnalyzer {
    /// Returns one [`RegisterLifetime`] per register that `func` writes,
    /// sorted by first definition and then by name.
    ///
    /// `last_use` is the last body index that reads or writes the register.
    /// `num_uses` counts read occurrences only. Registers that are read but
    /// never written are not reported.
    #[must_use]
    pub fn analyze(&self, func: &PtxFunction) -> Vec<RegisterLifetime> {
        let mut first_defs: HashMap<String, (usize, String)> = HashMap::new();
        let mut last_uses: HashMap<String, usize> = HashMap::new();
        let mut use_counts: HashMap<String, usize> = HashMap::new();
        for (idx, inst) in func.body.iter().enumerate() {
            for reg in registers_written(inst) {
                first_defs
                    .entry(reg.clone())
                    .or_insert_with(|| (idx, Self::infer_type(inst, &reg)));
                last_uses.insert(reg, idx);
            }
            for reg in registers_read(inst) {
                last_uses.insert(reg.clone(), idx);
                *use_counts.entry(reg).or_insert(0) += 1;
            }
        }
        let mut lifetimes: Vec<RegisterLifetime> = first_defs
            .into_iter()
            .map(|(reg, (def_idx, reg_type))| {
                let last = last_uses.get(&reg).copied().unwrap_or(def_idx);
                let uses = use_counts.get(&reg).copied().unwrap_or(0);
                RegisterLifetime {
                    register: reg,
                    reg_type,
                    first_def: def_idx,
                    last_use: last,
                    num_uses: uses,
                }
            })
            .collect();
        // Register names are unique, so an unstable sort is still deterministic;
        // comparing by reference avoids cloning a String per comparison.
        lifetimes.sort_unstable_by(|a, b| {
            a.first_def
                .cmp(&b.first_def)
                .then_with(|| a.register.cmp(&b.register))
        });
        lifetimes
    }

    /// Renders `lifetimes` as a table of `#` bars scaled to `max_width`.
    ///
    /// Each bar covers a register's `[first_def, last_use]` span relative to
    /// the largest `last_use`. Every register gets at least one `#` cell.
    /// Returns `"(no registers)\n"` for an empty slice.
    #[must_use]
    pub fn render_timeline(lifetimes: &[RegisterLifetime], max_width: usize) -> String {
        if lifetimes.is_empty() {
            return "(no registers)\n".to_string();
        }
        let mut out = String::new();
        let _ = writeln!(out, "Register Lifetimes");
        let _ = writeln!(out, "==================");
        let _ = writeln!(out);
        let max_inst = lifetimes
            .iter()
            .map(|l| l.last_use)
            .max()
            .unwrap_or(0)
            .max(1);
        let name_col_width = lifetimes
            .iter()
            .map(|l| l.register.len())
            .max()
            .unwrap_or(4)
            .max(4);
        let type_col_width = lifetimes
            .iter()
            .map(|l| l.reg_type.len())
            .max()
            .unwrap_or(4)
            .max(4);
        let bar_width = max_width
            .saturating_sub(name_col_width + type_col_width + 10)
            .max(10);
        let _ = writeln!(
            out,
            "{:>nw$} {:>tw$} Lifetime",
            "Reg",
            "Type",
            nw = name_col_width,
            tw = type_col_width
        );
        let _ = writeln!(
            out,
            "{} {} {}",
            "-".repeat(name_col_width),
            "-".repeat(type_col_width),
            "-".repeat(bar_width),
        );
        for lt in lifetimes {
            // A register first defined at `max_inst` would scale to `bar_width`
            // and draw nothing; clamp so it still gets the final cell.
            let start_pos = ((lt.first_def * bar_width) / max_inst).min(bar_width - 1);
            let end_pos = ((lt.last_use * bar_width) / max_inst)
                .max(start_pos + 1)
                .min(bar_width);
            let mut bar = vec![' '; bar_width];
            for ch in bar.iter_mut().take(end_pos).skip(start_pos) {
                *ch = '#';
            }
            let bar_str: String = bar.into_iter().collect();
            let _ = writeln!(
                out,
                "{:>nw$} {:>tw$} {bar_str} (uses: {})",
                lt.register,
                lt.reg_type,
                lt.num_uses,
                nw = name_col_width,
                tw = type_col_width,
            );
        }
        out
    }

    /// Best-effort PTX type string for the register written by `inst`.
    ///
    /// Most instructions report their operation type. Some operation types do
    /// not describe the destination: `setp` writes a predicate, and
    /// `mad.wide` writes a double-width result. For those, and for listed
    /// instructions without an operation type, the destination register's own
    /// declared type is used. Anything else yields `"?"`. `_reg` is unused
    /// because every handled instruction writes exactly one destination.
    fn infer_type(inst: &Instruction, _reg: &str) -> String {
        match inst {
            Instruction::Add { ty, .. }
            | Instruction::Sub { ty, .. }
            | Instruction::Mul { ty, .. }
            | Instruction::Min { ty, .. }
            | Instruction::Max { ty, .. }
            | Instruction::Neg { ty, .. }
            | Instruction::Abs { ty, .. }
            | Instruction::Div { ty, .. }
            | Instruction::Rem { ty, .. }
            | Instruction::And { ty, .. }
            | Instruction::Or { ty, .. }
            | Instruction::Xor { ty, .. }
            | Instruction::Shl { ty, .. }
            | Instruction::Shr { ty, .. }
            | Instruction::Load { ty, .. }
            | Instruction::Brev { ty, .. }
            | Instruction::Clz { ty, .. }
            | Instruction::Popc { ty, .. }
            | Instruction::Bfind { ty, .. }
            | Instruction::Bfe { ty, .. }
            | Instruction::Bfi { ty, .. }
            | Instruction::Rcp { ty, .. }
            | Instruction::Rsqrt { ty, .. }
            | Instruction::Sqrt { ty, .. }
            | Instruction::Ex2 { ty, .. }
            | Instruction::Lg2 { ty, .. }
            | Instruction::Sin { ty, .. }
            | Instruction::Cos { ty, .. }
            | Instruction::Tex1d { ty, .. }
            | Instruction::Tex2d { ty, .. }
            | Instruction::Tex3d { ty, .. }
            | Instruction::SurfLoad { ty, .. }
            | Instruction::Atom { ty, .. }
            | Instruction::AtomCas { ty, .. }
            | Instruction::Mad { ty, .. }
            | Instruction::Fma { ty, .. }
            | Instruction::LoadParam { ty, .. } => ty.as_ptx_str().to_string(),
            Instruction::MadLo { typ, .. } | Instruction::MadHi { typ, .. } => {
                typ.as_ptx_str().to_string()
            }
            Instruction::Cvt { dst_ty, .. } => dst_ty.as_ptx_str().to_string(),
            // `setp`'s `ty` is the comparison type and `mad.wide`'s source type
            // is half the destination width, so use the destination's own type.
            Instruction::SetP { dst, .. }
            | Instruction::MadWide { dst, .. }
            | Instruction::MovSpecial { dst, .. }
            | Instruction::Dp4a { dst, .. }
            | Instruction::Dp2a { dst, .. }
            | Instruction::Redux { dst, .. }
            | Instruction::ElectSync { dst, .. } => dst.ty.as_ptx_str().to_string(),
            _ => "?".to_string(),
        }
    }
}
/// Counts instructions per [`InstructionCategory`] and renders a bar chart.
pub struct InstructionMixAnalyzer;

impl InstructionMixAnalyzer {
    /// Counts every entry of `func.body` (labels and comments included) by
    /// category.
    #[must_use]
    pub fn analyze(&self, func: &PtxFunction) -> InstructionMix {
        let mut counts: HashMap<InstructionCategory, usize> = HashMap::new();
        for inst in &func.body {
            *counts.entry(categorize_instruction(inst)).or_insert(0) += 1;
        }
        InstructionMix {
            total: func.body.len(),
            counts,
        }
    }

    /// Renders `mix` as a horizontal bar chart sized to `width`.
    ///
    /// Categories are listed by descending count, with ties broken by label
    /// so the output does not depend on `HashMap` iteration order. Returns
    /// `"(no instructions)\n"` when `mix.total` is 0.
    #[must_use]
    pub fn render_bar_chart(mix: &InstructionMix, width: usize) -> String {
        if mix.total == 0 {
            return "(no instructions)\n".to_string();
        }
        let mut out = String::new();
        let _ = writeln!(out, "Instruction Mix");
        let _ = writeln!(out, "===============");
        let _ = writeln!(out);
        let label_width = 12_usize;
        let bar_width = width.saturating_sub(label_width + 20).max(10);
        let mut categories: Vec<(InstructionCategory, usize)> =
            mix.counts.iter().map(|(&cat, &cnt)| (cat, cnt)).collect();
        // Labels are unique, so this ordering is total and deterministic.
        categories.sort_unstable_by(|(cat_a, cnt_a), (cat_b, cnt_b)| {
            cnt_b.cmp(cnt_a).then_with(|| cat_a.label().cmp(cat_b.label()))
        });
        for (cat, count) in categories {
            #[allow(clippy::cast_precision_loss)]
            let fraction = count as f64 / mix.total as f64;
            let pct = fraction * 100.0;
            #[allow(
                clippy::cast_precision_loss,
                clippy::cast_possible_truncation,
                clippy::cast_sign_loss
            )]
            // Clamp in case a hand-built mix reports a count above `total`.
            let filled = ((fraction * bar_width as f64) as usize).min(bar_width);
            let bar: String = "#".repeat(filled) + &" ".repeat(bar_width - filled);
            let _ = writeln!(
                out,
                "{:<lw$} [{bar}] {count:>4} ({pct:>5.1}%)",
                cat.label(),
                lw = label_width,
            );
        }
        let _ = writeln!(out);
        let _ = writeln!(out, "Total: {} instructions", mix.total);
        out
    }
}
/// Counts memory operations per space and derives a coarse coalescing score.
pub struct MemoryAccessPattern;

impl MemoryAccessPattern {
    /// Tallies the memory traffic of `func.body`.
    ///
    /// `Load`/`Store` entries are bucketed by memory space (global, shared,
    /// local). Other spaces count toward the total but not toward any bucket.
    /// `CpAsync` and `TmaLoad` each count as one global load plus one shared
    /// store.
    ///
    /// `coalescing_score` is the fraction of counted operations that touch
    /// global memory, each of which is optimistically assumed to be
    /// coalesced. No address analysis is done. The score stays `1.0` when
    /// there are no memory operations.
    #[must_use]
    pub fn analyze(func: &PtxFunction) -> MemoryReport {
        let mut report = MemoryReport {
            global_loads: 0,
            global_stores: 0,
            shared_loads: 0,
            shared_stores: 0,
            local_loads: 0,
            local_stores: 0,
            coalescing_score: 1.0,
        };
        let mut total_mem_ops = 0_usize;
        let mut likely_coalesced = 0_usize;
        for inst in &func.body {
            let (space, is_store) = match inst {
                Instruction::Load { space, .. } => (space, false),
                Instruction::Store { space, .. } => (space, true),
                Instruction::CpAsync { .. } | Instruction::TmaLoad { .. } => {
                    // Global -> shared copy: one read of global, one write of shared.
                    total_mem_ops += 1;
                    likely_coalesced += 1;
                    report.global_loads += 1;
                    report.shared_stores += 1;
                    continue;
                }
                _ => continue,
            };
            total_mem_ops += 1;
            match (space, is_store) {
                (MemorySpace::Global, false) => {
                    report.global_loads += 1;
                    likely_coalesced += 1;
                }
                (MemorySpace::Global, true) => {
                    report.global_stores += 1;
                    likely_coalesced += 1;
                }
                (MemorySpace::Shared, false) => report.shared_loads += 1,
                (MemorySpace::Shared, true) => report.shared_stores += 1,
                (MemorySpace::Local, false) => report.local_loads += 1,
                (MemorySpace::Local, true) => report.local_stores += 1,
                _ => {}
            }
        }
        if total_mem_ops > 0 {
            #[allow(clippy::cast_precision_loss)]
            let score = likely_coalesced as f64 / total_mem_ops as f64;
            report.coalescing_score = score;
        }
        report
    }
}
/// Structural comparison between two versions of a PTX function.
pub struct PtxDiff;

impl PtxDiff {
    /// Summarizes how `b` differs from `a`.
    ///
    /// The added/removed counts reflect only the net change in body length,
    /// not a per-instruction diff. Blocks are compared positionally. The
    /// register delta is the change in the number of distinct register names.
    #[must_use]
    pub fn diff(a: &PtxFunction, b: &PtxFunction) -> DiffReport {
        let (old_len, new_len) = (a.body.len(), b.body.len());
        let changed_blocks =
            count_changed_blocks(&split_into_blocks(&a.body), &split_into_blocks(&b.body));
        let old_regs = count_unique_registers(&a.body);
        let new_regs = count_unique_registers(&b.body);
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        let register_delta = new_regs as i32 - old_regs as i32;
        DiffReport {
            added_instructions: new_len.saturating_sub(old_len),
            removed_instructions: old_len.saturating_sub(new_len),
            changed_blocks,
            register_delta,
        }
    }

    /// Formats `report` as a short multi-line summary. A non-negative register
    /// delta is shown with an explicit `+` sign.
    #[must_use]
    pub fn render_diff(report: &DiffReport) -> String {
        let sign = if report.register_delta < 0 { "" } else { "+" };
        let lines = [
            "PTX Diff Report".to_string(),
            "===============".to_string(),
            String::new(),
            format!("Added instructions: +{}", report.added_instructions),
            format!("Removed instructions: -{}", report.removed_instructions),
            format!("Changed blocks: {}", report.changed_blocks),
            format!("Register delta: {sign}{}", report.register_delta),
        ];
        let mut out = String::new();
        for line in &lines {
            out.push_str(line);
            out.push('\n');
        }
        out
    }
}
/// Heuristic complexity scoring for a kernel.
pub struct KernelComplexityScore;

impl KernelComplexityScore {
    /// Computes [`ComplexityMetrics`] for `func`.
    ///
    /// Occupancy uses hard-coded limits: warps = 65536 / (32 × pressure),
    /// capped at 64 and reported as a percentage of 64. A pressure of 0
    /// reports 100%. Arithmetic intensity is the ratio of arithmetic to memory
    /// instructions. It is `INFINITY` when there is arithmetic but no memory
    /// traffic, and `0.0` when there is neither.
    #[must_use]
    pub fn analyze(func: &PtxFunction) -> ComplexityMetrics {
        let (branch_count, arith_count, mem_count) = func.body.iter().fold(
            (0_usize, 0_usize, 0_usize),
            |(branches, arith, mem), inst| match categorize_instruction(inst) {
                InstructionCategory::Control if matches!(inst, Instruction::Branch { .. }) => {
                    (branches + 1, arith, mem)
                }
                InstructionCategory::Arithmetic => (branches, arith + 1, mem),
                InstructionCategory::Memory => (branches, arith, mem + 1),
                _ => (branches, arith, mem),
            },
        );
        let max_register_pressure = compute_max_register_pressure(&func.body);
        #[allow(clippy::cast_precision_loss)]
        let estimated_occupancy_pct = if max_register_pressure == 0 {
            100.0
        } else {
            let warps = (65536_f64 / (32.0 * max_register_pressure as f64)).min(64.0);
            warps / 64.0 * 100.0
        };
        #[allow(clippy::cast_precision_loss)]
        let arithmetic_intensity = match (arith_count, mem_count) {
            (0, 0) => 0.0,
            (_, 0) => f64::INFINITY,
            (arith, mem) => arith as f64 / mem as f64,
        };
        ComplexityMetrics {
            instruction_count: func.body.len(),
            branch_count,
            loop_count: count_back_edges(&func.body),
            max_register_pressure,
            estimated_occupancy_pct,
            arithmetic_intensity,
        }
    }
}
/// Splits `body` into basic blocks at every `Label`.
///
/// Each label starts a new block that carries the label's name. The label
/// entry itself is not kept among the instructions. Instructions before the
/// first label form an unlabeled block. A block is emitted whenever it has a
/// label or at least one instruction, so consecutive labels produce empty
/// labeled blocks. Blocks are not split after branches.
fn split_into_blocks(body: &[Instruction]) -> Vec<BasicBlock> {
    /// Pushes the pending block unless it has neither label nor instructions.
    fn flush(blocks: &mut Vec<BasicBlock>, label: Option<String>, instructions: Vec<Instruction>) {
        if label.is_some() || !instructions.is_empty() {
            blocks.push(BasicBlock {
                label,
                instructions,
            });
        }
    }

    let mut blocks = Vec::new();
    let mut label = None;
    let mut pending = Vec::new();
    for inst in body {
        match inst {
            Instruction::Label(name) => {
                flush(&mut blocks, label.take(), std::mem::take(&mut pending));
                label = Some(name.clone());
            }
            other => pending.push(other.clone()),
        }
    }
    flush(&mut blocks, label, pending);
    blocks
}
/// Counts positionally-paired blocks that differ between `a` and `b`.
///
/// Two blocks match when their labels, instruction counts and emitted
/// instruction texts are all equal. Every block without a counterpart at the
/// same index in the other slice also counts as changed.
fn count_changed_blocks(a: &[BasicBlock], b: &[BasicBlock]) -> usize {
    let same = |x: &BasicBlock, y: &BasicBlock| {
        x.label == y.label
            && x.instructions.len() == y.instructions.len()
            && x
                .instructions
                .iter()
                .zip(&y.instructions)
                .all(|(p, q)| p.emit() == q.emit())
    };
    let differing_pairs = a.iter().zip(b).filter(|&(x, y)| !same(x, y)).count();
    differing_pairs + a.len().abs_diff(b.len())
}
/// Number of distinct register names that `body` reads or writes.
fn count_unique_registers(body: &[Instruction]) -> usize {
    body.iter()
        .flat_map(|inst| {
            registers_read(inst)
                .into_iter()
                .chain(registers_written(inst))
        })
        .collect::<std::collections::HashSet<String>>()
        .len()
}
/// Counts branches whose target label appears at or before the branch in
/// `body`. These are treated as loop back edges.
///
/// If a label is defined more than once, its last position is used. Branches
/// to unknown labels are ignored.
fn count_back_edges(body: &[Instruction]) -> usize {
    let label_positions: HashMap<&str, usize> = body
        .iter()
        .enumerate()
        .filter_map(|(idx, inst)| match inst {
            Instruction::Label(name) => Some((name.as_str(), idx)),
            _ => None,
        })
        .collect();
    body.iter()
        .enumerate()
        .filter(|&(idx, inst)| match inst {
            Instruction::Branch { target, .. } => label_positions
                .get(target.as_str())
                .is_some_and(|&pos| pos <= idx),
            _ => false,
        })
        .count()
}
/// Peak number of registers live at any single body index.
///
/// A register is live on the closed interval from its first write to its
/// last read or write. Registers that are read but never written are ignored.
/// Returns 0 for an empty body.
///
/// The intervals are swept once, in O(instructions + registers) time,
/// instead of re-scanning every interval at every index.
fn compute_max_register_pressure(body: &[Instruction]) -> usize {
    if body.is_empty() {
        return 0;
    }
    let mut first_def: HashMap<String, usize> = HashMap::new();
    let mut last_use: HashMap<String, usize> = HashMap::new();
    for (idx, inst) in body.iter().enumerate() {
        for r in registers_written(inst) {
            first_def.entry(r.clone()).or_insert(idx);
            last_use.insert(r, idx);
        }
        for r in registers_read(inst) {
            last_use.insert(r, idx);
        }
    }
    // opens[i]: intervals starting at i; closes[i]: intervals ending at i.
    let mut opens = vec![0_usize; body.len()];
    let mut closes = vec![0_usize; body.len()];
    for (reg, &start) in &first_def {
        let end = last_use.get(reg).copied().unwrap_or(start);
        if end < start {
            // Cannot happen (the first write also updates `last_use`), but an
            // inverted interval would never be live, so skip it.
            continue;
        }
        opens[start] += 1;
        closes[end] += 1;
    }
    let mut live = 0_usize;
    let mut max_live = 0_usize;
    for (opened, closed) in opens.iter().zip(&closes) {
        live += opened;
        max_live = max_live.max(live);
        // Intervals are closed, so the ones ending here stop counting only
        // after this index has been measured.
        live -= closed;
    }
    max_live
}
/// Returns `inst.emit()`, shortened with a trailing `...` when it is longer
/// than `max_len` characters.
///
/// Length is counted in `char`s and the cut is made on a char boundary.
/// Non-ASCII text (for example in comments or raw PTX) therefore cannot cause
/// a slicing panic. When `max_len` is below 3 the result is just `"..."`,
/// which is longer than `max_len`.
fn truncate_emit(inst: &Instruction, max_len: usize) -> String {
    let s = inst.emit();
    if s.chars().count() <= max_len {
        return s;
    }
    let keep = max_len.saturating_sub(3);
    // Byte offset of the first dropped char; always Some since the string has
    // more than `max_len >= keep` chars.
    let cut = s
        .char_indices()
        .nth(keep)
        .map_or(s.len(), |(byte_idx, _)| byte_idx);
    format!("{}...", &s[..cut])
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{
BasicBlock, CacheQualifier, CmpOp, ImmValue, Instruction, MemorySpace, Operand,
PtxFunction, PtxModule, PtxType, Register, RoundingMode, SpecialReg, VectorWidth,
};
fn make_reg(name: &str, ty: PtxType) -> Register {
Register {
name: name.to_string(),
ty,
}
}
fn make_operand_reg(name: &str, ty: PtxType) -> Operand {
Operand::Register(make_reg(name, ty))
}
fn make_simple_function() -> PtxFunction {
let mut func = PtxFunction::new("test_kernel");
func.add_param("a_ptr", PtxType::U64);
func.add_param("n", PtxType::U32);
func.push(Instruction::LoadParam {
ty: PtxType::U64,
dst: make_reg("%rd0", PtxType::U64),
param_name: "a_ptr".to_string(),
});
func.push(Instruction::MovSpecial {
dst: make_reg("%r0", PtxType::U32),
special: SpecialReg::TidX,
});
func.push(Instruction::Add {
ty: PtxType::U32,
dst: make_reg("%r1", PtxType::U32),
a: make_operand_reg("%r0", PtxType::U32),
b: Operand::Immediate(ImmValue::U32(1)),
});
func.push(Instruction::Load {
space: MemorySpace::Global,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
dst: make_reg("%f0", PtxType::F32),
addr: Operand::Address {
base: make_reg("%rd0", PtxType::U64),
offset: None,
},
});
func.push(Instruction::Fma {
rnd: RoundingMode::Rn,
ty: PtxType::F32,
dst: make_reg("%f1", PtxType::F32),
a: make_operand_reg("%f0", PtxType::F32),
b: Operand::Immediate(ImmValue::F32(2.0)),
c: Operand::Immediate(ImmValue::F32(1.0)),
});
func.push(Instruction::Store {
space: MemorySpace::Global,
qualifier: CacheQualifier::None,
vec: VectorWidth::V1,
ty: PtxType::F32,
addr: Operand::Address {
base: make_reg("%rd0", PtxType::U64),
offset: None,
},
src: make_reg("%f1", PtxType::F32),
});
func.push(Instruction::Return);
func
}
fn make_branching_function() -> PtxFunction {
let mut func = PtxFunction::new("branch_kernel");
func.push(Instruction::MovSpecial {
dst: make_reg("%r0", PtxType::U32),
special: SpecialReg::TidX,
});
func.push(Instruction::SetP {
cmp: CmpOp::Lt,
ty: PtxType::U32,
dst: make_reg("%p0", PtxType::Pred),
a: make_operand_reg("%r0", PtxType::U32),
b: Operand::Immediate(ImmValue::U32(128)),
});
func.push(Instruction::Branch {
target: "skip".to_string(),
predicate: Some((make_reg("%p0", PtxType::Pred), true)),
});
func.push(Instruction::Add {
ty: PtxType::U32,
dst: make_reg("%r1", PtxType::U32),
a: make_operand_reg("%r0", PtxType::U32),
b: Operand::Immediate(ImmValue::U32(1)),
});
func.push(Instruction::Label("skip".to_string()));
func.push(Instruction::Return);
func
}
#[test]
fn test_render_empty_function() {
let config = ExplorerConfig::default();
let explorer = PtxExplorer::new(config);
let func = PtxFunction::new("empty");
let output = explorer.render_function(&func);
assert!(output.contains("empty"));
assert!(output.contains('{'));
assert!(output.contains('}'));
}
/// Rendering the branching fixture emits the kernel name plus the `setp`,
/// `bra` and `add` mnemonics.
#[test]
fn test_render_function_with_multiple_blocks() {
    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let kernel = make_branching_function();
    let rendered = explorer.render_function(&kernel);
    let expected = ["branch_kernel", "setp", "bra", "add"];
    for needle in expected {
        assert!(rendered.contains(needle), "missing `{needle}` in rendered output");
    }
}
/// The CFG of the branching fixture has a header, mentions the `skip` label,
/// and draws at least one `-->` edge.
#[test]
fn test_cfg_rendering_with_branches() {
    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let kernel = make_branching_function();
    let cfg_text = explorer.render_cfg(&kernel);
    for needle in ["Control Flow Graph", "skip", "-->"] {
        assert!(cfg_text.contains(needle), "missing `{needle}` in CFG output");
    }
}
/// Lifetime analysis of `make_simple_function` must report `%rd0`.
///
/// The test expects its first definition at instruction index 0. It is read
/// later as the base address of both the global load and the global store,
/// so its last use must come strictly after that definition.
#[test]
fn test_register_lifetime_analysis() {
    let analyzer = RegisterLifetimeAnalyzer;
    let func = make_simple_function();
    let lifetimes = analyzer.analyze(&func);
    assert!(!lifetimes.is_empty());
    // A single `expect` replaces the former `assert!(is_some())` plus
    // `expect("checked above")` pair. The failure message is unchanged and
    // there is no redundant second check.
    let rd0 = lifetimes
        .iter()
        .find(|l| l.register == "%rd0")
        .expect("should find %rd0 lifetime");
    assert_eq!(rd0.first_def, 0);
    assert!(
        rd0.last_use > rd0.first_def,
        "last_use should be after first_def"
    );
}
/// The lifetime timeline for `make_simple_function` has a header, at least one
/// `#` glyph, and a `uses:` annotation.
#[test]
fn test_register_lifetime_timeline_rendering() {
    let kernel = make_simple_function();
    let lifetimes = RegisterLifetimeAnalyzer.analyze(&kernel);
    let timeline = RegisterLifetimeAnalyzer::render_timeline(&lifetimes, 80);

    assert!(timeline.contains("Register Lifetimes"));
    // Presumably `#` is the bar glyph for a live range; this only checks presence.
    assert!(timeline.contains('#'));
    assert!(timeline.contains("uses:"));
}
/// The mix for `make_simple_function` counts every body instruction and
/// contains at least one Arithmetic, one Memory and one Special instruction.
#[test]
fn test_instruction_mix_categorization() {
    let kernel = make_simple_function();
    let mix = InstructionMixAnalyzer.analyze(&kernel);
    assert_eq!(mix.total, kernel.body.len());

    // Categories absent from the map count as zero.
    let count_of = |category: InstructionCategory| {
        mix.counts.get(&category).copied().unwrap_or(0)
    };

    assert!(
        count_of(InstructionCategory::Arithmetic) > 0,
        "should have arithmetic instructions"
    );
    assert!(
        count_of(InstructionCategory::Memory) > 0,
        "should have memory instructions"
    );
    assert!(
        count_of(InstructionCategory::Special) > 0,
        "should have special instructions"
    );
}
/// The bar chart for `make_simple_function` has a title, bar glyphs,
/// percentages and a total line.
#[test]
fn test_instruction_mix_bar_chart() {
    let kernel = make_simple_function();
    let mix = InstructionMixAnalyzer.analyze(&kernel);
    let chart = InstructionMixAnalyzer::render_bar_chart(&mix, 80);

    assert!(chart.contains("Instruction Mix"));
    assert!(chart.contains('#'));
    assert!(chart.contains('%'));
    assert!(chart.contains("Total:"));
}
/// `make_simple_function` performs exactly one global load and one global
/// store and no shared-memory traffic. Its coalescing score lies in `(0, 1]`.
#[test]
fn test_memory_access_pattern_analysis() {
    let kernel = make_simple_function();
    let report = MemoryAccessPattern::analyze(&kernel);

    let traffic = (
        report.global_loads,
        report.global_stores,
        report.shared_loads,
        report.shared_stores,
    );
    assert_eq!(traffic, (1, 1, 0, 0));

    let score = report.coalescing_score;
    assert!(score > 0.0);
    assert!(score <= 1.0);
}
/// Diffing a function against itself reports no additions, removals, changed
/// blocks or register-count change.
#[test]
fn test_ptx_diff_identical_functions() {
    let kernel = make_simple_function();
    let report = PtxDiff::diff(&kernel, &kernel);
    let summary = (
        report.added_instructions,
        report.removed_instructions,
        report.changed_blocks,
        report.register_delta,
    );
    assert_eq!(summary, (0, 0, 0, 0));
}
/// Appending a comment and an `add` that writes a new register `%r99` must show
/// up as added instructions and a positive register delta. The rendered report
/// must have its title and a `+` marker.
#[test]
fn test_ptx_diff_different_functions() {
    let baseline = make_simple_function();
    let mut extended = make_simple_function();
    extended.push(Instruction::Comment(String::from("extra")));
    extended.push(Instruction::Add {
        ty: PtxType::U32,
        dst: make_reg("%r99", PtxType::U32),
        a: Operand::Immediate(ImmValue::U32(0)),
        b: Operand::Immediate(ImmValue::U32(1)),
    });

    let report = PtxDiff::diff(&baseline, &extended);
    assert!(report.added_instructions > 0);
    assert!(report.register_delta > 0);

    let rendered = PtxDiff::render_diff(&report);
    assert!(rendered.contains("PTX Diff Report"));
    assert!(rendered.contains('+'));
}
/// Complexity metrics for the branching fixture count every instruction,
/// detect at least one branch, and report an occupancy in `(0, 100]` percent.
#[test]
fn test_kernel_complexity_scoring() {
    let kernel = make_branching_function();
    let metrics = KernelComplexityScore::analyze(&kernel);

    assert_eq!(metrics.instruction_count, kernel.body.len());
    assert!(metrics.branch_count > 0, "should detect branches");

    let occupancy = metrics.estimated_occupancy_pct;
    assert!(occupancy > 0.0);
    assert!(occupancy <= 100.0);
}
/// ANSI escape sequences appear only when `use_color` is enabled. Both
/// renderings still contain the kernel name.
#[test]
fn test_color_vs_no_color_output() {
    const ANSI_ESCAPE: &str = "\x1b[";

    let kernel = make_simple_function();
    let plain_explorer = PtxExplorer::new(ExplorerConfig {
        use_color: false,
        ..ExplorerConfig::default()
    });
    let color_explorer = PtxExplorer::new(ExplorerConfig {
        use_color: true,
        ..ExplorerConfig::default()
    });
    let plain = plain_explorer.render_function(&kernel);
    let colored = color_explorer.render_function(&kernel);

    assert!(colored.contains(ANSI_ESCAPE));
    assert!(!plain.contains(ANSI_ESCAPE));
    for output in [&plain, &colored] {
        assert!(output.contains("test_kernel"));
    }
}
/// `ExplorerConfig::default()` disables every flag and caps width at 120.
///
/// The destructuring is exhaustive, so adding a field to `ExplorerConfig` forces
/// this test to be revisited.
#[test]
fn test_config_defaults() {
    let ExplorerConfig {
        use_color,
        max_width,
        show_line_numbers,
        show_register_types,
        show_instruction_latency,
    } = ExplorerConfig::default();

    assert!(!use_color);
    assert_eq!(max_width, 120);
    assert!(!show_line_numbers);
    assert!(!show_register_types);
    assert!(!show_instruction_latency);
}
/// A 500-instruction straight-line kernel renders at least one line per
/// instruction and is fully counted by the mix and complexity analyzers.
#[test]
fn test_large_function_handling() {
    const INSTRUCTION_COUNT: usize = 500;

    let mut kernel = PtxFunction::new("big_kernel");
    for index in 0..INSTRUCTION_COUNT {
        // Each add writes a distinct `%f<index>` register from two immediates.
        kernel.push(Instruction::Add {
            ty: PtxType::F32,
            dst: make_reg(&format!("%f{index}"), PtxType::F32),
            a: Operand::Immediate(ImmValue::F32(1.0)),
            b: Operand::Immediate(ImmValue::F32(2.0)),
        });
    }

    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let rendered = explorer.render_function(&kernel);
    assert!(rendered.lines().count() > INSTRUCTION_COUNT);

    let mix = InstructionMixAnalyzer.analyze(&kernel);
    assert_eq!(mix.total, INSTRUCTION_COUNT);

    let metrics = KernelComplexityScore::analyze(&kernel);
    assert_eq!(metrics.instruction_count, INSTRUCTION_COUNT);
}
/// With `show_line_numbers` enabled, the output contains space-delimited line
/// numbers 1 and 2.
#[test]
fn test_line_number_rendering() {
    let explorer = PtxExplorer::new(ExplorerConfig {
        show_line_numbers: true,
        ..ExplorerConfig::default()
    });
    let kernel = make_simple_function();
    let rendered = explorer.render_function(&kernel);
    for line_number in [" 1 ", " 2 "] {
        assert!(
            rendered.contains(line_number),
            "missing line number `{line_number}`"
        );
    }
}
/// A module with both fixtures renders the `.version 8.5` and `.target sm_80`
/// directives and both kernel names.
#[test]
fn test_render_module() {
    let mut module = PtxModule::new("sm_80");
    for function in [make_simple_function(), make_branching_function()] {
        module.add_function(function);
    }

    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let rendered = explorer.render_module(&module);
    let expected = [".version 8.5", ".target sm_80", "test_kernel", "branch_kernel"];
    for needle in expected {
        assert!(rendered.contains(needle), "missing `{needle}` in module output");
    }
}
/// In a block where `%f0` feeds both adds and `%f1` feeds the second, the
/// dependency graph names the block, draws `-->` edges and mentions `%f0`.
#[test]
fn test_dependency_graph() {
    let mut block = BasicBlock::with_label("test_block");
    let instructions = [
        Instruction::LoadParam {
            ty: PtxType::F32,
            dst: make_reg("%f0", PtxType::F32),
            param_name: String::from("x"),
        },
        Instruction::Add {
            ty: PtxType::F32,
            dst: make_reg("%f1", PtxType::F32),
            a: make_operand_reg("%f0", PtxType::F32),
            b: Operand::Immediate(ImmValue::F32(1.0)),
        },
        Instruction::Add {
            ty: PtxType::F32,
            dst: make_reg("%f2", PtxType::F32),
            a: make_operand_reg("%f1", PtxType::F32),
            b: make_operand_reg("%f0", PtxType::F32),
        },
    ];
    for instruction in instructions {
        block.push(instruction);
    }

    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let graph = explorer.render_dependency_graph(&block);
    assert!(graph.contains("Dependency graph"));
    assert!(graph.contains("test_block"));
    assert!(graph.contains("-->"));
    assert!(graph.contains("%f0"));
}
/// Rendering the CFG of an empty function must not panic and must produce one
/// of the accepted placeholder forms.
#[test]
fn test_cfg_empty_function() {
    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let kernel = PtxFunction::new("empty_kernel");
    let cfg_text = explorer.render_cfg(&kernel);

    // NOTE(review): this disjunction accepts almost any output (including the
    // normal "Control Flow Graph" header), so it is effectively a no-panic
    // smoke test. Tighten once the expected empty-function rendering is pinned.
    let markers = ["empty CFG", "Control Flow Graph", "(entry)"];
    let acceptable =
        cfg_text.is_empty() || markers.iter().any(|&marker| cfg_text.contains(marker));
    assert!(acceptable);
}
/// A branch-free function still gets a CFG header and a `B0` block node.
#[test]
fn test_cfg_no_branch_single_block() {
    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let kernel = make_simple_function();
    let cfg_text = explorer.render_cfg(&kernel);
    for needle in ["Control Flow Graph", "B0"] {
        assert!(cfg_text.contains(needle), "missing `{needle}` in CFG output");
    }
}
/// A register written by the only instruction, and never read afterwards,
/// has a lifetime that starts and ends at instruction index 0.
#[test]
fn test_register_lifetime_single_instruction() {
    let analyzer = RegisterLifetimeAnalyzer;
    let mut func = PtxFunction::new("single");
    func.push(Instruction::Add {
        ty: PtxType::U32,
        dst: make_reg("%r0", PtxType::U32),
        a: Operand::Immediate(ImmValue::U32(1)),
        b: Operand::Immediate(ImmValue::U32(2)),
    });
    let lifetimes = analyzer.analyze(&func);
    // One `expect` carries the same failure message as the former
    // `assert!(is_some())` + `expect("checked above")` pair, without the
    // redundant second check.
    let r0 = lifetimes
        .iter()
        .find(|l| l.register == "%r0")
        .expect("should track %r0");
    assert_eq!(r0.first_def, 0);
    assert_eq!(r0.last_use, 0);
}
/// Rendering an empty lifetime list yields a "no registers" placeholder.
#[test]
fn test_register_lifetime_render_empty() {
    let no_lifetimes: &[RegisterLifetime] = &[];
    let timeline = RegisterLifetimeAnalyzer::render_timeline(no_lifetimes, 80);
    assert!(timeline.contains("no registers"));
}
/// An empty function has a zero instruction total, and its bar chart shows a
/// "no instructions" placeholder.
#[test]
fn test_instruction_mix_empty_function() {
    let kernel = PtxFunction::new("empty_kernel");
    let mix = InstructionMixAnalyzer.analyze(&kernel);
    assert_eq!(mix.total, 0);

    let chart = InstructionMixAnalyzer::render_bar_chart(&mix, 80);
    assert!(chart.contains("no instructions"));
}
/// Two adds that read only immediates have no data dependencies. The graph
/// output must say so and still name the block.
#[test]
fn test_dependency_graph_no_deps() {
    let mut block = BasicBlock::with_label("no_deps");
    // (destination register, lhs immediate, rhs immediate)
    for (dst, lhs, rhs) in [("%r0", 1, 2), ("%r1", 3, 4)] {
        block.push(Instruction::Add {
            ty: PtxType::U32,
            dst: make_reg(dst, PtxType::U32),
            a: Operand::Immediate(ImmValue::U32(lhs)),
            b: Operand::Immediate(ImmValue::U32(rhs)),
        });
    }

    let explorer = PtxExplorer::new(ExplorerConfig::default());
    let graph = explorer.render_dependency_graph(&block);
    assert!(graph.contains("no_deps"));
    assert!(graph.contains("no data dependencies"));
}
/// `CfgRenderer` given no blocks at all renders an "empty CFG" placeholder.
#[test]
fn test_cfg_renderer_empty_blocks() {
    let rendered = CfgRenderer.render(&[]);
    assert!(rendered.contains("empty CFG"));
}
}