Skip to main content

cgp/analysis/
explain.rs

1//! `cgp explain` — Static code analysis for PTX, SIMD assembly, and WGSL shaders.
2//! Spec section 2.7: wraps trueno-explain or performs inline analysis.
3//! Detects register pressure, instruction mix, and common performance pitfalls.
4
5use anyhow::Result;
6use std::path::Path;
7
/// Analyze PTX source text for performance-relevant patterns.
///
/// This is a line-oriented heuristic, not a PTX parser: each line is trimmed and
/// classified by the text it starts with. A leading predicate guard (`@%p1`,
/// `@!%p1`) is stripped before classification, so a guarded load, store or
/// arithmetic instruction is still attributed to its category in addition to
/// being counted as control flow.
///
/// Warnings are produced for:
/// - at least one guarded `bra` that is not `bra.uni` (reported once, not per branch),
/// - more than 128 declared registers,
/// - more than 2 barrier instructions.
///
/// `compute_memory_ratio` is `compute_ops / memory_ops`, or `f64::INFINITY` when
/// no memory operation was seen.
pub fn analyze_ptx(source: &str) -> PtxAnalysis {
    let mut total_instructions = 0u32;
    let mut memory_ops = 0u32;
    let mut compute_ops = 0u32;
    let mut control_ops = 0u32;
    let mut sync_ops = 0u32;
    let mut shared_ops = 0u32;
    let mut registers_declared = 0u32;
    let mut has_wmma = false;
    let mut has_fma = false;
    let mut divergent_branch_seen = false;

    for line in source.lines() {
        let trimmed = line.trim();

        // Anything that is not blank, a `//` comment, a directive or a brace is
        // counted as an instruction line.
        let is_instruction_line = !(trimmed.is_empty()
            || trimmed.starts_with("//")
            || trimmed.starts_with('.')
            || trimmed.starts_with('{')
            || trimmed.starts_with('}'));
        if is_instruction_line {
            total_instructions = total_instructions.saturating_add(1);
        }

        // Register declarations: `.reg .f32 %f<32>;` declares 32 registers.
        if trimmed.starts_with(".reg") {
            let declared = trimmed
                .split('<')
                .nth(1)
                .and_then(|rest| rest.split('>').next())
                .and_then(|count| count.trim().parse::<u32>().ok());
            if let Some(count) = declared {
                // Saturate so absurd declarations cannot overflow the total.
                registers_declared = registers_declared.saturating_add(count);
            }
        }

        // Classify the instruction that follows an optional predicate guard.
        let (guarded, op) = split_predicate_guard(trimmed);

        // Memory operations
        if op.starts_with("ld.") || op.starts_with("st.") {
            memory_ops += 1;
            if op.contains(".shared") {
                shared_ops += 1;
            }
        }

        // Compute operations
        let is_fused = op.starts_with("fma.") || op.starts_with("mad.");
        if is_fused || op.starts_with("add.") || op.starts_with("mul.") {
            compute_ops += 1;
            has_fma |= is_fused;
        }

        // Control flow: branches and every predicated instruction.
        let is_branch = op.starts_with("bra");
        if guarded || is_branch {
            control_ops += 1;
            // `bra.uni` asserts the branch is uniform across the warp, so only
            // plain guarded branches are flagged as potentially divergent.
            if guarded && is_branch && !op.starts_with("bra.uni") {
                divergent_branch_seen = true;
            }
        }

        // Synchronization (`bar.*` and the newer `barrier.*` spelling).
        if op.starts_with("bar.") || op.starts_with("barrier.") {
            sync_ops += 1;
        }

        // Tensor core instructions; "mma." also matches "wmma.".
        if trimmed.contains("mma.") {
            has_wmma = true;
        }
    }

    // Compute/memory ratio (higher = more compute-bound)
    let compute_memory_ratio = if memory_ops > 0 {
        f64::from(compute_ops) / f64::from(memory_ops)
    } else {
        f64::INFINITY
    };

    let mut warnings: Vec<String> = Vec::new();

    if divergent_branch_seen {
        warnings.push("Data-dependent branch may cause warp divergence".to_string());
    }

    // Register pressure warning
    if registers_declared > 128 {
        warnings.push(format!(
            "High register usage ({registers_declared}) may limit occupancy"
        ));
    }

    // Sync overhead
    if sync_ops > 2 {
        warnings.push(format!(
            "{sync_ops} barrier syncs — review if all are necessary"
        ));
    }

    PtxAnalysis {
        total_instructions,
        memory_ops,
        compute_ops,
        control_ops,
        sync_ops,
        shared_ops,
        registers_declared,
        has_wmma,
        has_fma,
        compute_memory_ratio,
        warnings,
    }
}

/// Split an optional leading predicate guard off a trimmed PTX line.
///
/// Returns `(true, instruction)` for `@%p1 instruction` / `@!%p1 instruction`
/// (the instruction is empty when nothing follows the guard) and
/// `(false, line)` when the line does not start with `@`.
fn split_predicate_guard(line: &str) -> (bool, &str) {
    match line.strip_prefix('@') {
        Some(rest) => {
            let op = rest
                .split_once(char::is_whitespace)
                .map_or("", |(_, tail)| tail.trim_start());
            (true, op)
        }
        None => (false, line),
    }
}

/// Result of PTX static analysis, as produced by [`analyze_ptx`].
#[derive(Debug, Clone, PartialEq)]
pub struct PtxAnalysis {
    /// Lines that are not blank, `//` comments, directives (leading `.`) or braces.
    pub total_instructions: u32,
    /// Instructions starting with `ld.` or `st.`.
    pub memory_ops: u32,
    /// Instructions starting with `add.`, `mul.`, `mad.` or `fma.`.
    pub compute_ops: u32,
    /// Branches (`bra…`) plus every predicated (`@…`) instruction.
    pub control_ops: u32,
    /// Instructions starting with `bar.` or `barrier.`.
    pub sync_ops: u32,
    /// Subset of `memory_ops` whose text contains `.shared`.
    pub shared_ops: u32,
    /// Sum of the `<N>` counts found on `.reg` lines (saturating).
    pub registers_declared: u32,
    /// `true` when any line contains `mma.` (which includes `wmma.`).
    pub has_wmma: bool,
    /// `true` when an `fma.` or `mad.` instruction was seen.
    pub has_fma: bool,
    /// `compute_ops / memory_ops`, or infinity when `memory_ops` is zero.
    pub compute_memory_ratio: f64,
    /// Human-readable findings, in the order they were detected.
    pub warnings: Vec<String>,
}
140
/// WGSL atomic built-in functions whose presence marks a shader as using atomics.
const WGSL_ATOMIC_BUILTINS: [&str; 11] = [
    "atomicLoad",
    "atomicStore",
    "atomicAdd",
    "atomicSub",
    "atomicMax",
    "atomicMin",
    "atomicAnd",
    "atomicOr",
    "atomicXor",
    "atomicExchange",
    "atomicCompareExchangeWeak",
];

/// Analyze a WGSL shader for compute patterns.
///
/// The scan is line-based: it records the argument list of the last
/// `@workgroup_size(...)` attribute, counts lines containing `@binding`, and
/// flags use of atomic built-ins and `var<workgroup>` storage.
///
/// Workgroup-size warnings (fewer than 64 or more than 1024 threads) are only
/// emitted when every dimension is an integer literal; sizes given through
/// named constants cannot be evaluated here and are left unjudged.
pub fn analyze_wgsl(source: &str) -> WgslAnalysis {
    let mut total_lines = 0u32;
    let mut workgroup_size: Option<String> = None;
    let mut bindings = 0u32;
    let mut has_atomics = false;
    let mut has_shared = false;

    for line in source.lines() {
        total_lines = total_lines.saturating_add(1);
        let trimmed = line.trim();

        if let Some(args) = workgroup_size_args(trimmed) {
            workgroup_size = Some(args.to_string());
        }

        if trimmed.contains("@binding") {
            bindings += 1;
        }

        if !has_atomics
            && WGSL_ATOMIC_BUILTINS
                .iter()
                .any(|name| trimmed.contains(*name))
        {
            has_atomics = true;
        }

        if trimmed.contains("var<workgroup>") {
            has_shared = true;
        }
    }

    // Workgroup size warnings
    let mut warnings: Vec<String> = Vec::new();
    if let Some(ws) = workgroup_size.as_deref() {
        if let Some(total) = workgroup_thread_count(ws) {
            if total < 64 {
                warnings.push(format!(
                    "Workgroup size ({ws}) = {total} threads — consider >=64 for GPU occupancy"
                ));
            }
            if total > 1024 {
                warnings.push(format!(
                    "Workgroup size ({ws}) = {total} threads — exceeds common hardware limit (1024)"
                ));
            }
        }
    }

    WgslAnalysis {
        total_lines,
        workgroup_size,
        bindings,
        has_atomics,
        has_shared,
        warnings,
    }
}

/// Return the text between the parentheses of `@workgroup_size(...)` on `line`.
///
/// The parenthesis must belong to the attribute itself (only whitespace may
/// separate them), so unrelated parentheses elsewhere on the line are ignored.
fn workgroup_size_args(line: &str) -> Option<&str> {
    const ATTR: &str = "@workgroup_size";
    let after_attr = &line[line.find(ATTR)? + ATTR.len()..];
    let after_paren = after_attr.trim_start().strip_prefix('(')?;
    after_paren.split_once(')').map(|(args, _)| args)
}

/// Multiply the comma-separated dimensions of a workgroup size.
///
/// Accepts decimal literals with an optional `u`/`i` suffix and tolerates a
/// trailing comma. Returns `None` when there are no dimensions or when any
/// dimension is not such a literal. The product saturates instead of
/// overflowing.
fn workgroup_thread_count(dims: &str) -> Option<u64> {
    let mut total = 1u64;
    let mut seen_dimension = false;
    for part in dims.split(',').map(str::trim).filter(|p| !p.is_empty()) {
        let digits = part
            .strip_suffix(|c: char| c == 'u' || c == 'i')
            .unwrap_or(part);
        let dim: u32 = digits.parse().ok()?;
        total = total.saturating_mul(u64::from(dim));
        seen_dimension = true;
    }
    if seen_dimension {
        Some(total)
    } else {
        None
    }
}

/// Result of WGSL static analysis, as produced by [`analyze_wgsl`].
#[derive(Debug, Clone, PartialEq)]
pub struct WgslAnalysis {
    /// Number of lines in the shader source.
    pub total_lines: u32,
    /// Raw argument text of the last `@workgroup_size(...)` attribute, if any.
    pub workgroup_size: Option<String>,
    /// Number of lines containing `@binding`.
    pub bindings: u32,
    /// `true` when a WGSL atomic built-in function name appears in the source.
    pub has_atomics: bool,
    /// `true` when the source contains `var<workgroup>`.
    pub has_shared: bool,
    /// Human-readable findings about the workgroup size.
    pub warnings: Vec<String>,
}
214
215/// Run the explain command.
216pub fn run_explain(target: &str, kernel: Option<&str>) -> Result<()> {
217    println!("\n=== CGP Explain: {target} ===\n");
218
219    match target {
220        "ptx" => {
221            let kernel_name = kernel.unwrap_or("*");
222            println!("  Target: PTX (CUDA assembly)");
223            println!("  Kernel: {kernel_name}");
224
225            // Try to find PTX files
226            let ptx_path = find_ptx_file(kernel_name);
227            match ptx_path {
228                Some(path) => {
229                    let source = std::fs::read_to_string(&path)?;
230                    let analysis = analyze_ptx(&source);
231                    println!("  File: {path}");
232                    render_ptx_analysis(&analysis);
233                }
234                None => {
235                    println!("  No PTX file found for kernel '{kernel_name}'.");
236                    println!("  Generate with: cargo build -p trueno-gpu --features cuda");
237                    println!("  Or provide path: cgp explain ptx --kernel path/to/kernel.ptx");
238                }
239            }
240        }
241        "wgsl" | "shader" => {
242            let shader_path = kernel.unwrap_or("*.wgsl");
243            println!("  Target: WGSL (WebGPU shader)");
244
245            if Path::new(shader_path).exists() {
246                let source = std::fs::read_to_string(shader_path)?;
247                let analysis = analyze_wgsl(&source);
248                println!("  File: {shader_path}");
249                render_wgsl_analysis(&analysis);
250            } else {
251                println!("  Shader file not found: {shader_path}");
252                println!(
253                    "  Provide path: cgp explain wgsl --kernel src/backends/gpu/shaders/gemm.wgsl"
254                );
255            }
256        }
257        "simd" => {
258            println!("  Target: SIMD (x86/ARM assembly analysis)");
259            println!("  Analysis: instruction mix, vectorization rate, register usage");
260            println!(
261                "  Use: cgp profile simd --function <fn> --arch avx2 for runtime SIMD analysis"
262            );
263        }
264        _ => {
265            println!("  Unknown target: {target}");
266            println!("  Supported: ptx, wgsl, simd");
267        }
268    }
269
270    println!();
271    Ok(())
272}
273
274/// Render PTX analysis results.
275fn render_ptx_analysis(analysis: &PtxAnalysis) {
276    println!("\n  Instruction Mix:");
277    println!("    Total instructions: {}", analysis.total_instructions);
278    println!("    Compute ops:       {}", analysis.compute_ops);
279    println!("    Memory ops:        {}", analysis.memory_ops);
280    println!("    Control flow:      {}", analysis.control_ops);
281    println!("    Sync barriers:     {}", analysis.sync_ops);
282    println!("    Shared memory ops: {}", analysis.shared_ops);
283
284    println!(
285        "\n  Compute/Memory Ratio: {:.2}",
286        analysis.compute_memory_ratio
287    );
288    if analysis.compute_memory_ratio < 1.0 {
289        println!("    Status: MEMORY-INTENSIVE (more loads than compute)");
290    } else if analysis.compute_memory_ratio > 4.0 {
291        println!("    Status: COMPUTE-INTENSIVE (good arithmetic density)");
292    } else {
293        println!("    Status: BALANCED");
294    }
295
296    println!("\n  Features:");
297    println!("    Registers declared: {}", analysis.registers_declared);
298    println!(
299        "    Tensor cores (WMMA/MMA): {}",
300        if analysis.has_wmma { "YES" } else { "no" }
301    );
302    println!(
303        "    FMA instructions: {}",
304        if analysis.has_fma { "YES" } else { "no" }
305    );
306
307    if !analysis.warnings.is_empty() {
308        println!("\n  Warnings:");
309        for w in &analysis.warnings {
310            println!("    \x1b[33m[WARN]\x1b[0m {w}");
311        }
312    }
313}
314
315/// Render WGSL analysis results.
316fn render_wgsl_analysis(analysis: &WgslAnalysis) {
317    println!("\n  Shader Info:");
318    println!("    Lines: {}", analysis.total_lines);
319    println!(
320        "    Workgroup size: {}",
321        analysis
322            .workgroup_size
323            .as_deref()
324            .unwrap_or("not specified")
325    );
326    println!("    Bindings: {}", analysis.bindings);
327    println!(
328        "    Atomics: {}",
329        if analysis.has_atomics { "YES" } else { "no" }
330    );
331    println!(
332        "    Shared memory: {}",
333        if analysis.has_shared { "YES" } else { "no" }
334    );
335
336    if !analysis.warnings.is_empty() {
337        println!("\n  Warnings:");
338        for w in &analysis.warnings {
339            println!("    \x1b[33m[WARN]\x1b[0m {w}");
340        }
341    }
342}
343
/// Find a PTX file for a given kernel name.
///
/// Resolution order:
/// 1. If `kernel_name` is itself the path of an existing regular file, it is
///    returned unchanged. Directories are not accepted, since the caller goes
///    on to read the result as text.
/// 2. Otherwise `src/backends/gpu/kernels`, `trueno-gpu/src` and `.` are
///    searched (non-recursively, in that order) for a file whose name ends in
///    `.ptx` and contains `kernel_name`; `"*"` matches any `.ptx` file.
///
/// Within a directory the lexicographically smallest matching path is chosen,
/// so the result does not depend on the OS directory iteration order. Returns
/// `None` when nothing matches; unreadable directories are skipped.
fn find_ptx_file(kernel_name: &str) -> Option<String> {
    // Check if kernel_name is already a path to a file
    if Path::new(kernel_name).is_file() {
        return Some(kernel_name.to_string());
    }

    // Search common locations
    let search_dirs = ["src/backends/gpu/kernels", "trueno-gpu/src", "."];
    search_dirs.iter().find_map(|dir| {
        let entries = std::fs::read_dir(dir).ok()?;
        entries
            .flatten()
            .filter(|entry| {
                let name = entry.file_name();
                let name = name.to_string_lossy();
                name.ends_with(".ptx") && (kernel_name == "*" || name.contains(kernel_name))
            })
            .map(|entry| entry.path())
            // Skip directories (or dangling links) that merely look like PTX files.
            .filter(|path| path.is_file())
            .min()
            .map(|path| path.display().to_string())
    })
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// A small kernel with two loads, one FMA, one store and one barrier.
    /// Exact counts are asserted so over-counting is caught as well as
    /// under-counting.
    #[test]
    fn test_analyze_ptx_basic() {
        let ptx = r#"
.version 8.0
.target sm_89
.entry gemm_kernel {
    .reg .f32 %f<32>;
    .reg .pred %p<4>;
    ld.global.f32 %f1, [%rd1];
    ld.global.f32 %f2, [%rd2];
    fma.rn.f32 %f3, %f1, %f2, %f0;
    st.global.f32 [%rd3], %f3;
    bar.sync 0;
}
"#;
        let analysis = analyze_ptx(ptx);
        assert_eq!(analysis.total_instructions, 5);
        assert_eq!(analysis.memory_ops, 3); // 2 loads + 1 store
        assert_eq!(analysis.shared_ops, 0);
        assert_eq!(analysis.compute_ops, 1); // fma
        assert_eq!(analysis.control_ops, 0);
        assert!(analysis.has_fma);
        assert!(!analysis.has_wmma);
        assert_eq!(analysis.sync_ops, 1);
        assert_eq!(analysis.registers_declared, 36); // 32 f32 + 4 predicates
        assert!(analysis.warnings.is_empty());
    }

    /// A `wmma.` instruction must set the tensor-core flag.
    #[test]
    fn test_analyze_ptx_wmma() {
        let ptx = "wmma.mma.sync.aligned.m16n16k16.row.col.f32.f16 {a}, {b}, {c};";
        let analysis = analyze_ptx(ptx);
        assert!(analysis.has_wmma);
        assert_eq!(analysis.compute_ops, 0);
    }

    /// More than 128 declared registers yields exactly one warning.
    #[test]
    fn test_analyze_ptx_high_register_warning() {
        let ptx = ".reg .f32 %f<256>;";
        let analysis = analyze_ptx(ptx);
        assert_eq!(analysis.registers_declared, 256);
        assert_eq!(analysis.warnings.len(), 1);
        assert!(analysis.warnings[0].contains("High register usage (256)"));
    }

    /// Two bindings and a 256-thread workgroup: no atomics, no warnings.
    #[test]
    fn test_analyze_wgsl_basic() {
        let wgsl = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> b: array<f32>;

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    b[gid.x] = a[gid.x] * 2.0;
}
"#;
        let analysis = analyze_wgsl(wgsl);
        assert_eq!(analysis.bindings, 2);
        assert_eq!(analysis.workgroup_size.as_deref(), Some("256, 1, 1"));
        assert!(!analysis.has_atomics);
        assert!(!analysis.has_shared);
        assert!(analysis.warnings.is_empty());
    }

    /// An 8-thread workgroup is below the 64-thread threshold.
    #[test]
    fn test_analyze_wgsl_small_workgroup() {
        let wgsl = "@compute @workgroup_size(8, 1, 1)\nfn main() {}";
        let analysis = analyze_wgsl(wgsl);
        assert_eq!(analysis.warnings.len(), 1);
        assert!(analysis.warnings[0].contains(">=64"));
    }

    /// With no kernel given, the command succeeds whether or not a PTX file
    /// happens to be found in the searched directories.
    #[test]
    fn test_run_explain_ptx() {
        let result = run_explain("ptx", None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_explain_simd() {
        let result = run_explain("simd", None);
        assert!(result.is_ok());
    }

    /// Unknown targets are reported on stdout, not as an error.
    #[test]
    fn test_run_explain_unknown() {
        let result = run_explain("unknown_target", None);
        assert!(result.is_ok());
    }
}
452}