Skip to main content

cgp/analysis/
explain.rs

1//! `cgp explain` — Static code analysis for PTX, SIMD assembly, and WGSL shaders.
2//! Spec section 2.7: wraps trueno-explain or performs inline analysis.
3//! Detects register pressure, instruction mix, and common performance pitfalls.
4
5use anyhow::Result;
6use std::path::Path;
7
8/// Analyze a PTX file for performance-relevant patterns.
9pub fn analyze_ptx(source: &str) -> PtxAnalysis {
10    let lines: Vec<&str> = source.lines().collect();
11    let total_instructions = count_ptx_instructions(&lines);
12
13    let mut counters = PtxCounters::default();
14    for line in &lines {
15        classify_ptx_line(line, &mut counters);
16    }
17    append_ptx_pressure_warnings(&mut counters);
18
19    let compute_memory_ratio = if counters.memory_ops > 0 {
20        f64::from(counters.compute_ops) / f64::from(counters.memory_ops)
21    } else {
22        f64::INFINITY
23    };
24
25    PtxAnalysis {
26        total_instructions,
27        memory_ops: counters.memory_ops,
28        compute_ops: counters.compute_ops,
29        control_ops: counters.control_ops,
30        sync_ops: counters.sync_ops,
31        shared_ops: counters.shared_ops,
32        registers_declared: counters.registers_declared,
33        has_wmma: counters.has_wmma,
34        has_fma: counters.has_fma,
35        compute_memory_ratio,
36        warnings: counters.warnings,
37    }
38}
39
/// Running totals gathered while the PTX source is scanned line by line.
///
/// `Default` yields all-zero counters, cleared flags and no warnings.
#[derive(Default)]
struct PtxCounters {
    /// Load/store instructions seen so far.
    memory_ops: u32,
    /// Arithmetic instructions seen so far.
    compute_ops: u32,
    /// Branches and predicated instructions seen so far.
    control_ops: u32,
    /// Barrier instructions seen so far.
    sync_ops: u32,
    /// Subset of `memory_ops` whose text mentions `.shared`.
    shared_ops: u32,
    /// Sum of the register counts taken from `.reg` declarations.
    registers_declared: u32,
    /// Set once any `wmma.` / `mma.` instruction text is encountered.
    has_wmma: bool,
    /// Set once an `fma.` or `mad.` instruction is encountered.
    has_fma: bool,
    /// Human-readable findings, in the order they were produced.
    warnings: Vec<String>,
}
53
/// Count the lines of a PTX listing that carry an instruction.
///
/// Blank lines, `//` comments, directives (leading `.`), braces, the
/// parentheses that delimit an entry's parameter list and label-only lines
/// (`NAME:`) are not instructions and are skipped. The result saturates at
/// `u32::MAX` if the line count does not fit in a `u32`.
fn count_ptx_instructions(lines: &[&str]) -> u32 {
    let instruction_lines = lines
        .iter()
        .filter(|line| is_ptx_instruction_line(line))
        .count();
    u32::try_from(instruction_lines).unwrap_or(u32::MAX)
}

/// Decide whether a single source line holds a PTX instruction.
fn is_ptx_instruction_line(line: &str) -> bool {
    // Look only at the code in front of a `//` comment; a comment-only line
    // therefore collapses to the empty string.
    let code = line.split("//").next().unwrap_or("").trim();
    if code.is_empty() {
        return false;
    }
    // Directives, scope braces and parameter-list delimiters are structure,
    // not executable instructions.
    let is_structural = code.starts_with(|c: char| matches!(c, '.' | '{' | '}' | '(' | ')'));
    // A line made of nothing but `label:` marks a branch target. A label
    // followed by an instruction on the same line ends with `;` and is kept.
    let is_label_only = code.ends_with(':');
    !is_structural && !is_label_only
}
71
72/// Dispatch a single PTX line to all per-category counters.
73fn classify_ptx_line(line: &str, counters: &mut PtxCounters) {
74    let trimmed = line.trim();
75    tally_register_declaration(trimmed, counters);
76    tally_memory_op(trimmed, counters);
77    tally_compute_op(trimmed, counters);
78    tally_control_op(trimmed, counters);
79    if trimmed.starts_with("bar.") {
80        counters.sync_ops += 1;
81    }
82    if trimmed.contains("wmma.") || trimmed.contains("mma.") {
83        counters.has_wmma = true;
84    }
85}
86
87/// Parse `.reg .TYPE %name<N>;` and accumulate the declared register count.
88fn tally_register_declaration(trimmed: &str, counters: &mut PtxCounters) {
89    if !trimmed.starts_with(".reg") {
90        return;
91    }
92    let Some(count_str) = trimmed.split('<').nth(1).and_then(|s| s.split('>').next()) else {
93        return;
94    };
95    let Ok(count) = count_str.parse::<u32>() else {
96        return;
97    };
98    counters.registers_declared += count;
99}
100
101/// Track `ld.*` / `st.*` operations, flagging shared-memory traffic separately.
102fn tally_memory_op(trimmed: &str, counters: &mut PtxCounters) {
103    if !(trimmed.starts_with("ld.") || trimmed.starts_with("st.")) {
104        return;
105    }
106    counters.memory_ops += 1;
107    if trimmed.contains(".shared") {
108        counters.shared_ops += 1;
109    }
110}
111
112/// Track arithmetic ops and record whether FMA-class instructions are present.
113fn tally_compute_op(trimmed: &str, counters: &mut PtxCounters) {
114    let is_compute = trimmed.starts_with("add.")
115        || trimmed.starts_with("mul.")
116        || trimmed.starts_with("mad.")
117        || trimmed.starts_with("fma.");
118    if !is_compute {
119        return;
120    }
121    counters.compute_ops += 1;
122    if trimmed.starts_with("fma.") || trimmed.starts_with("mad.") {
123        counters.has_fma = true;
124    }
125}
126
127/// Track branches and predicated branches; warn on data-dependent divergence.
128fn tally_control_op(trimmed: &str, counters: &mut PtxCounters) {
129    if !(trimmed.starts_with("bra") || trimmed.starts_with('@')) {
130        return;
131    }
132    counters.control_ops += 1;
133    if trimmed.starts_with("@%p") && trimmed.contains("bra") {
134        counters
135            .warnings
136            .push("Data-dependent branch may cause warp divergence".to_string());
137    }
138}
139
140/// Add register-pressure and sync-overhead warnings after counting is complete.
141fn append_ptx_pressure_warnings(counters: &mut PtxCounters) {
142    if counters.registers_declared > 128 {
143        counters.warnings.push(format!(
144            "High register usage ({}) may limit occupancy",
145            counters.registers_declared
146        ));
147    }
148    if counters.sync_ops > 2 {
149        counters.warnings.push(format!(
150            "{} barrier syncs — review if all are necessary",
151            counters.sync_ops
152        ));
153    }
154}
155
/// Summary produced by [`analyze_ptx`] for one PTX source text.
#[derive(Debug)]
pub struct PtxAnalysis {
    /// Lines of the source that were counted as instructions.
    pub total_instructions: u32,
    /// Load/store instructions.
    pub memory_ops: u32,
    /// Arithmetic instructions.
    pub compute_ops: u32,
    /// Branches and predicated instructions.
    pub control_ops: u32,
    /// Barrier instructions.
    pub sync_ops: u32,
    /// Loads/stores whose text mentions `.shared`.
    pub shared_ops: u32,
    /// Sum of the counts taken from `.reg` declarations.
    pub registers_declared: u32,
    /// Whether any `wmma.` / `mma.` instruction text was found.
    pub has_wmma: bool,
    /// Whether any `fma.` / `mad.` instruction was found.
    pub has_fma: bool,
    /// `compute_ops / memory_ops`; infinite when `memory_ops` is zero.
    pub compute_memory_ratio: f64,
    /// Findings collected during the analysis.
    pub warnings: Vec<String>,
}
171
172/// Analyze a WGSL shader for compute patterns.
173pub fn analyze_wgsl(source: &str) -> WgslAnalysis {
174    let lines: Vec<&str> = source.lines().collect();
175    let total_lines = u32::try_from(lines.len()).unwrap_or(u32::MAX);
176
177    let mut counters = WgslCounters::default();
178    for line in &lines {
179        classify_wgsl_line(line, &mut counters);
180    }
181    append_wgsl_workgroup_warnings(&mut counters);
182
183    WgslAnalysis {
184        total_lines,
185        workgroup_size: counters.workgroup_size,
186        bindings: counters.bindings,
187        has_atomics: counters.has_atomics,
188        has_shared: counters.has_shared,
189        warnings: counters.warnings,
190    }
191}
192
/// Running state gathered while the WGSL source is scanned line by line.
///
/// `Default` yields no workgroup size, zero bindings, cleared flags and no
/// warnings.
#[derive(Default)]
struct WgslCounters {
    /// Text between the parentheses of the last `@workgroup_size(...)` seen.
    workgroup_size: Option<String>,
    /// Number of `@binding` attributes counted so far.
    bindings: u32,
    /// Set once atomic usage is detected.
    has_atomics: bool,
    /// Set once a `var<workgroup>` declaration is detected.
    has_shared: bool,
    /// Human-readable findings, in the order they were produced.
    warnings: Vec<String>,
}
202
203/// Dispatch a single WGSL line to each feature detector.
204fn classify_wgsl_line(line: &str, counters: &mut WgslCounters) {
205    let trimmed = line.trim();
206    capture_workgroup_size(trimmed, counters);
207    if trimmed.contains("@binding") {
208        counters.bindings += 1;
209    }
210    if trimmed.contains("atomicAdd") || trimmed.contains("atomicStore") {
211        counters.has_atomics = true;
212    }
213    if trimmed.contains("var<workgroup>") {
214        counters.has_shared = true;
215    }
216}
217
218/// Extract the literal arguments of `@workgroup_size(...)` when present.
219fn capture_workgroup_size(trimmed: &str, counters: &mut WgslCounters) {
220    if !trimmed.contains("@workgroup_size") {
221        return;
222    }
223    let Some(start) = trimmed.find('(').map(|i| i + 1) else {
224        return;
225    };
226    let Some(end) = trimmed.find(')') else {
227        return;
228    };
229    counters.workgroup_size = Some(trimmed[start..end].to_string());
230}
231
232/// Emit warnings when the workgroup size is too small or too large for typical GPUs.
233fn append_wgsl_workgroup_warnings(counters: &mut WgslCounters) {
234    let Some(ref ws) = counters.workgroup_size else {
235        return;
236    };
237    let total: u32 = ws
238        .split(',')
239        .filter_map(|s| s.trim().parse::<u32>().ok())
240        .product();
241    if total < 64 {
242        counters.warnings.push(format!(
243            "Workgroup size ({ws}) = {total} threads — consider >=64 for GPU occupancy"
244        ));
245    }
246    if total > 1024 {
247        counters.warnings.push(format!(
248            "Workgroup size ({ws}) = {total} threads — exceeds common hardware limit (1024)"
249        ));
250    }
251}
252
/// Summary produced by [`analyze_wgsl`] for one WGSL source text.
#[derive(Debug)]
pub struct WgslAnalysis {
    /// Number of lines in the analysed source.
    pub total_lines: u32,
    /// Argument text of `@workgroup_size(...)`, if the attribute was found.
    pub workgroup_size: Option<String>,
    /// Number of `@binding` attributes counted.
    pub bindings: u32,
    /// Whether atomic usage was detected.
    pub has_atomics: bool,
    /// Whether a `var<workgroup>` declaration was detected.
    pub has_shared: bool,
    /// Findings collected during the analysis.
    pub warnings: Vec<String>,
}
263
264/// Run the explain command.
265pub fn run_explain(target: &str, kernel: Option<&str>) -> Result<()> {
266    println!("\n=== CGP Explain: {target} ===\n");
267
268    match target {
269        "ptx" => {
270            let kernel_name = kernel.unwrap_or("*");
271            println!("  Target: PTX (CUDA assembly)");
272            println!("  Kernel: {kernel_name}");
273
274            // Try to find PTX files
275            let ptx_path = find_ptx_file(kernel_name);
276            match ptx_path {
277                Some(path) => {
278                    let source = std::fs::read_to_string(&path)?;
279                    let analysis = analyze_ptx(&source);
280                    println!("  File: {path}");
281                    render_ptx_analysis(&analysis);
282                }
283                None => {
284                    println!("  No PTX file found for kernel '{kernel_name}'.");
285                    println!("  Generate with: cargo build -p trueno-gpu --features cuda");
286                    println!("  Or provide path: cgp explain ptx --kernel path/to/kernel.ptx");
287                }
288            }
289        }
290        "wgsl" | "shader" => {
291            let shader_path = kernel.unwrap_or("*.wgsl");
292            println!("  Target: WGSL (WebGPU shader)");
293
294            if Path::new(shader_path).exists() {
295                let source = std::fs::read_to_string(shader_path)?;
296                let analysis = analyze_wgsl(&source);
297                println!("  File: {shader_path}");
298                render_wgsl_analysis(&analysis);
299            } else {
300                println!("  Shader file not found: {shader_path}");
301                println!(
302                    "  Provide path: cgp explain wgsl --kernel src/backends/gpu/shaders/gemm.wgsl"
303                );
304            }
305        }
306        "simd" => {
307            println!("  Target: SIMD (x86/ARM assembly analysis)");
308            println!("  Analysis: instruction mix, vectorization rate, register usage");
309            println!(
310                "  Use: cgp profile simd --function <fn> --arch avx2 for runtime SIMD analysis"
311            );
312        }
313        _ => {
314            println!("  Unknown target: {target}");
315            println!("  Supported: ptx, wgsl, simd");
316        }
317    }
318
319    println!();
320    Ok(())
321}
322
323/// Render PTX analysis results.
324fn render_ptx_analysis(analysis: &PtxAnalysis) {
325    println!("\n  Instruction Mix:");
326    println!("    Total instructions: {}", analysis.total_instructions);
327    println!("    Compute ops:       {}", analysis.compute_ops);
328    println!("    Memory ops:        {}", analysis.memory_ops);
329    println!("    Control flow:      {}", analysis.control_ops);
330    println!("    Sync barriers:     {}", analysis.sync_ops);
331    println!("    Shared memory ops: {}", analysis.shared_ops);
332
333    println!(
334        "\n  Compute/Memory Ratio: {:.2}",
335        analysis.compute_memory_ratio
336    );
337    if analysis.compute_memory_ratio < 1.0 {
338        println!("    Status: MEMORY-INTENSIVE (more loads than compute)");
339    } else if analysis.compute_memory_ratio > 4.0 {
340        println!("    Status: COMPUTE-INTENSIVE (good arithmetic density)");
341    } else {
342        println!("    Status: BALANCED");
343    }
344
345    println!("\n  Features:");
346    println!("    Registers declared: {}", analysis.registers_declared);
347    println!(
348        "    Tensor cores (WMMA/MMA): {}",
349        if analysis.has_wmma { "YES" } else { "no" }
350    );
351    println!(
352        "    FMA instructions: {}",
353        if analysis.has_fma { "YES" } else { "no" }
354    );
355
356    if !analysis.warnings.is_empty() {
357        println!("\n  Warnings:");
358        for w in &analysis.warnings {
359            println!("    \x1b[33m[WARN]\x1b[0m {w}");
360        }
361    }
362}
363
364/// Render WGSL analysis results.
365fn render_wgsl_analysis(analysis: &WgslAnalysis) {
366    println!("\n  Shader Info:");
367    println!("    Lines: {}", analysis.total_lines);
368    println!(
369        "    Workgroup size: {}",
370        analysis
371            .workgroup_size
372            .as_deref()
373            .unwrap_or("not specified")
374    );
375    println!("    Bindings: {}", analysis.bindings);
376    println!(
377        "    Atomics: {}",
378        if analysis.has_atomics { "YES" } else { "no" }
379    );
380    println!(
381        "    Shared memory: {}",
382        if analysis.has_shared { "YES" } else { "no" }
383    );
384
385    if !analysis.warnings.is_empty() {
386        println!("\n  Warnings:");
387        for w in &analysis.warnings {
388            println!("    \x1b[33m[WARN]\x1b[0m {w}");
389        }
390    }
391}
392
/// Find a PTX file for a given kernel name.
///
/// `kernel_name` is first tried as a path: if it names something that exists
/// and is not a directory, it is returned unchanged. Otherwise the directories
/// `src/backends/gpu/kernels`, `trueno-gpu/src` and `.` are searched in that
/// order for a non-directory entry whose name ends in `.ptx` and contains
/// `kernel_name` (`"*"` matches every `.ptx` file). Within a directory the
/// lexicographically smallest matching path wins, so the result does not
/// depend on the platform's directory iteration order.
///
/// Returns `None` when nothing matches; unreadable directories are skipped.
fn find_ptx_file(kernel_name: &str) -> Option<String> {
    const SEARCH_DIRS: [&str; 3] = ["src/backends/gpu/kernels", "trueno-gpu/src", "."];

    // A directory that merely shares the kernel's name is not a PTX file and
    // would only fail later when the caller tries to read it.
    let direct = Path::new(kernel_name);
    if direct.exists() && !direct.is_dir() {
        return Some(kernel_name.to_string());
    }

    SEARCH_DIRS.iter().find_map(|dir| {
        let entries = std::fs::read_dir(dir).ok()?;
        entries
            .flatten()
            .filter(|entry| {
                let name = entry.file_name();
                let name = name.to_string_lossy();
                name.ends_with(".ptx") && (kernel_name == "*" || name.contains(kernel_name))
            })
            .map(|entry| entry.path())
            .filter(|path| !path.is_dir())
            .min()
            .map(|path| path.display().to_string())
    })
}
417
/// Unit tests for the PTX/WGSL analysers and the `explain` entry point.
///
/// Counters are asserted with exact values so that double counting or a
/// missing/incorrect warning fails the test instead of slipping through a
/// `>=` or `!is_empty()` check.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_analyze_ptx_basic() {
        let ptx = r#"
.version 8.0
.target sm_89
.entry gemm_kernel {
    .reg .f32 %f<32>;
    .reg .pred %p<4>;
    ld.global.f32 %f1, [%rd1];
    ld.global.f32 %f2, [%rd2];
    fma.rn.f32 %f3, %f1, %f2, %f0;
    st.global.f32 [%rd3], %f3;
    bar.sync 0;
}
"#;
        let analysis = analyze_ptx(ptx);
        assert_eq!(analysis.total_instructions, 5); // ld, ld, fma, st, bar
        assert_eq!(analysis.memory_ops, 3); // 2 loads + 1 store
        assert_eq!(analysis.compute_ops, 1); // fma
        assert_eq!(analysis.control_ops, 0);
        assert_eq!(analysis.sync_ops, 1);
        assert_eq!(analysis.shared_ops, 0);
        assert_eq!(analysis.registers_declared, 36); // 32 f32 + 4 predicates
        assert!(analysis.has_fma);
        assert!(!analysis.has_wmma);
        assert!(analysis.warnings.is_empty());
    }

    #[test]
    fn test_analyze_ptx_wmma() {
        let ptx = "wmma.mma.sync.aligned.m16n16k16.row.col.f32.f16 {a}, {b}, {c};";
        let analysis = analyze_ptx(ptx);
        assert!(analysis.has_wmma);
    }

    #[test]
    fn test_analyze_ptx_high_register_warning() {
        let ptx = ".reg .f32 %f<256>;";
        let analysis = analyze_ptx(ptx);
        assert_eq!(analysis.registers_declared, 256);
        assert_eq!(
            analysis.warnings,
            vec!["High register usage (256) may limit occupancy".to_string()]
        );
    }

    #[test]
    fn test_analyze_wgsl_basic() {
        let wgsl = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> b: array<f32>;

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    b[gid.x] = a[gid.x] * 2.0;
}
"#;
        let analysis = analyze_wgsl(wgsl);
        assert_eq!(analysis.bindings, 2);
        assert_eq!(analysis.workgroup_size.as_deref(), Some("256, 1, 1"));
        assert!(!analysis.has_atomics);
        assert!(!analysis.has_shared);
        assert!(analysis.warnings.is_empty()); // 256 threads is within 64..=1024
    }

    #[test]
    fn test_analyze_wgsl_small_workgroup() {
        let wgsl = "@compute @workgroup_size(8, 1, 1)\nfn main() {}";
        let analysis = analyze_wgsl(wgsl);
        assert_eq!(analysis.warnings.len(), 1);
        assert!(analysis.warnings[0].contains("= 8 threads"));
    }

    #[test]
    fn test_run_explain_ptx() {
        let result = run_explain("ptx", None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_explain_simd() {
        let result = run_explain("simd", None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_explain_unknown() {
        let result = run_explain("unknown_target", None);
        assert!(result.is_ok());
    }
}