1use anyhow::Result;
6use std::path::Path;
7
/// Performs a line-based heuristic analysis of PTX assembly text.
///
/// Every line is trimmed and classified purely by its textual prefix; no real
/// PTX parsing takes place:
///
/// * `.reg` lines add the `<N>` count of their first parameterized register
///   name to `registers_declared`.
/// * `ld.*` / `st.*` lines are memory ops; those that mention `.shared` are
///   additionally counted as shared-memory ops.
/// * `add.`, `mul.`, `mad.` and `fma.` lines are compute ops; `mad.` / `fma.`
///   also set `has_fma`.
/// * `bra` lines and every line starting with `@` are control ops.
/// * `bar.*` lines are barrier syncs.
/// * Any line containing `mma.` (which also covers `wmma.`) sets `has_wmma`.
///
/// Warnings are emitted for predicated branches (once, no matter how many
/// there are), for more than 128 declared registers, and for more than two
/// barrier syncs.
///
/// The compute/memory ratio is `f64::INFINITY` when no memory ops were found.
pub fn analyze_ptx(source: &str) -> PtxAnalysis {
    const DIVERGENCE_WARNING: &str = "Data-dependent branch may cause warp divergence";

    // Everything that is not blank, a `//` comment, a directive or a brace
    // counts as an instruction line.
    let instruction_lines = source
        .lines()
        .map(str::trim)
        .filter(|line| {
            !(line.is_empty()
                || line.starts_with("//")
                || line.starts_with('.')
                || line.starts_with('{')
                || line.starts_with('}'))
        })
        .count();

    let mut memory_ops = 0u32;
    let mut compute_ops = 0u32;
    let mut control_ops = 0u32;
    let mut sync_ops = 0u32;
    let mut shared_ops = 0u32;
    let mut registers_declared = 0u32;
    let mut has_wmma = false;
    let mut has_fma = false;
    let mut divergence_reported = false;
    let mut warnings: Vec<String> = Vec::new();

    for line in source.lines().map(str::trim) {
        if line.starts_with(".reg") {
            let declared = line
                .split('<')
                .nth(1)
                .and_then(|rest| rest.split('>').next())
                .and_then(|count| count.parse::<u32>().ok());
            if let Some(count) = declared {
                // Saturate: a huge or repeated count must not overflow the total.
                registers_declared = registers_declared.saturating_add(count);
            }
        }

        if line.starts_with("ld.") || line.starts_with("st.") {
            memory_ops += 1;
            if line.contains(".shared") {
                shared_ops += 1;
            }
        }

        let fused = line.starts_with("fma.") || line.starts_with("mad.");
        if fused || line.starts_with("add.") || line.starts_with("mul.") {
            compute_ops += 1;
            has_fma |= fused;
        }

        if line.starts_with("bra") || line.starts_with('@') {
            control_ops += 1;
            let predicated_branch = line.starts_with("@%p") && line.contains("bra");
            // Report the divergence hazard once instead of once per branch.
            if predicated_branch && !divergence_reported {
                divergence_reported = true;
                warnings.push(DIVERGENCE_WARNING.to_string());
            }
        }

        if line.starts_with("bar.") {
            sync_ops += 1;
        }

        // "wmma." contains "mma.", so a single substring test covers both.
        if line.contains("mma.") {
            has_wmma = true;
        }
    }

    let compute_memory_ratio = if memory_ops > 0 {
        f64::from(compute_ops) / f64::from(memory_ops)
    } else {
        f64::INFINITY
    };

    if registers_declared > 128 {
        warnings.push(format!(
            "High register usage ({registers_declared}) may limit occupancy"
        ));
    }

    if sync_ops > 2 {
        warnings.push(format!(
            "{sync_ops} barrier syncs — review if all are necessary"
        ));
    }

    PtxAnalysis {
        // Clamp rather than silently truncate on absurdly large inputs.
        total_instructions: u32::try_from(instruction_lines).unwrap_or(u32::MAX),
        memory_ops,
        compute_ops,
        control_ops,
        sync_ops,
        shared_ops,
        registers_declared,
        has_wmma,
        has_fma,
        compute_memory_ratio,
        warnings,
    }
}

/// Result of [`analyze_ptx`]: heuristic instruction-mix counters for one PTX source.
#[derive(Debug)]
pub struct PtxAnalysis {
    /// Lines that are not blank, `//` comments, directives (leading `.`) or braces.
    pub total_instructions: u32,
    /// Lines starting with `ld.` or `st.`.
    pub memory_ops: u32,
    /// Lines starting with `add.`, `mul.`, `mad.` or `fma.`.
    pub compute_ops: u32,
    /// Lines starting with `bra` or with a predicate guard (`@`).
    pub control_ops: u32,
    /// Lines starting with `bar.`.
    pub sync_ops: u32,
    /// Memory-op lines that also contain `.shared`.
    pub shared_ops: u32,
    /// Sum of the `<N>` counts found on `.reg` lines (saturating at `u32::MAX`).
    pub registers_declared: u32,
    /// `true` if any line contains `mma.` (including `wmma.`).
    pub has_wmma: bool,
    /// `true` if any line starts with `fma.` or `mad.`.
    pub has_fma: bool,
    /// `compute_ops / memory_ops`, or `f64::INFINITY` when `memory_ops` is zero.
    pub compute_memory_ratio: f64,
    /// Human-readable warnings gathered during the analysis.
    pub warnings: Vec<String>,
}
140
/// Performs a line-based heuristic analysis of WGSL shader text.
///
/// Each trimmed line is scanned for:
///
/// * `@workgroup_size(...)` — the text between the parentheses that follow the
///   attribute is recorded verbatim (the last occurrence in the file wins);
/// * `@binding` — counted once per line that contains it;
/// * atomic usage — the `atomic<` type or any WGSL atomic builtin;
/// * `var<workgroup>` — marks use of workgroup-shared memory.
///
/// When every workgroup dimension is an integer literal, the total thread
/// count is checked and a warning is emitted if it is below 64 or above 1024.
/// If any dimension is not a literal (for example a named constant) the total
/// is unknown and no size warning is produced.
pub fn analyze_wgsl(source: &str) -> WgslAnalysis {
    // The `atomic<T>` type plus the atomic builtin functions defined by WGSL.
    const ATOMIC_MARKERS: [&str; 12] = [
        "atomic<",
        "atomicLoad",
        "atomicStore",
        "atomicAdd",
        "atomicSub",
        "atomicMax",
        "atomicMin",
        "atomicAnd",
        "atomicOr",
        "atomicXor",
        "atomicExchange",
        "atomicCompareExchangeWeak",
    ];

    // Clamp rather than silently truncate on absurdly large inputs.
    let total_lines = u32::try_from(source.lines().count()).unwrap_or(u32::MAX);

    let mut workgroup_size: Option<String> = None;
    let mut bindings = 0u32;
    let mut has_atomics = false;
    let mut has_shared = false;
    let mut warnings: Vec<String> = Vec::new();

    for line in source.lines().map(str::trim) {
        if let Some(args) = workgroup_size_args(line) {
            workgroup_size = Some(args.to_string());
        }

        if line.contains("@binding") {
            bindings += 1;
        }

        if ATOMIC_MARKERS.iter().any(|marker| line.contains(*marker)) {
            has_atomics = true;
        }

        if line.contains("var<workgroup>") {
            has_shared = true;
        }
    }

    if let Some(ws) = workgroup_size.as_deref() {
        if let Some(total) = workgroup_thread_count(ws) {
            if total < 64 {
                warnings.push(format!(
                    "Workgroup size ({ws}) = {total} threads — consider >=64 for GPU occupancy"
                ));
            }
            if total > 1024 {
                warnings.push(format!(
                    "Workgroup size ({ws}) = {total} threads — exceeds common hardware limit (1024)"
                ));
            }
        }
    }

    WgslAnalysis {
        total_lines,
        workgroup_size,
        bindings,
        has_atomics,
        has_shared,
        warnings,
    }
}

/// Returns the argument text of the first `@workgroup_size(...)` on `line`.
///
/// The parentheses are searched for *after* the attribute name, so unrelated
/// parentheses earlier on the line cannot be picked up.
fn workgroup_size_args(line: &str) -> Option<&str> {
    const ATTRIBUTE: &str = "@workgroup_size";
    let after_attribute = &line[line.find(ATTRIBUTE)? + ATTRIBUTE.len()..];
    let after_open = &after_attribute[after_attribute.find('(')? + 1..];
    let close = after_open.find(')')?;
    Some(&after_open[..close])
}

/// Multiplies the comma-separated integer dimensions in `dims`.
///
/// Accepts an optional `u` / `i` literal suffix and ignores empty segments.
/// Returns `None` when there is no dimension at all or when any dimension is
/// not an integer literal. The product saturates instead of overflowing.
fn workgroup_thread_count(dims: &str) -> Option<u64> {
    let mut total = 1u64;
    let mut seen_dimension = false;
    for dim in dims.split(',').map(str::trim).filter(|dim| !dim.is_empty()) {
        let digits = dim
            .strip_suffix('u')
            .or_else(|| dim.strip_suffix('i'))
            .unwrap_or(dim);
        total = total.saturating_mul(digits.parse::<u64>().ok()?);
        seen_dimension = true;
    }
    if seen_dimension {
        Some(total)
    } else {
        None
    }
}

/// Result of [`analyze_wgsl`]: heuristic facts about one WGSL shader source.
#[derive(Debug)]
pub struct WgslAnalysis {
    /// Number of lines in the source.
    pub total_lines: u32,
    /// Verbatim argument text of the last `@workgroup_size(...)` found, if any.
    pub workgroup_size: Option<String>,
    /// Number of lines containing `@binding`.
    pub bindings: u32,
    /// `true` if an atomic type or atomic builtin appears in the source.
    pub has_atomics: bool,
    /// `true` if the source contains `var<workgroup>`.
    pub has_shared: bool,
    /// Human-readable warnings gathered during the analysis.
    pub warnings: Vec<String>,
}
214
215pub fn run_explain(target: &str, kernel: Option<&str>) -> Result<()> {
217 println!("\n=== CGP Explain: {target} ===\n");
218
219 match target {
220 "ptx" => {
221 let kernel_name = kernel.unwrap_or("*");
222 println!(" Target: PTX (CUDA assembly)");
223 println!(" Kernel: {kernel_name}");
224
225 let ptx_path = find_ptx_file(kernel_name);
227 match ptx_path {
228 Some(path) => {
229 let source = std::fs::read_to_string(&path)?;
230 let analysis = analyze_ptx(&source);
231 println!(" File: {path}");
232 render_ptx_analysis(&analysis);
233 }
234 None => {
235 println!(" No PTX file found for kernel '{kernel_name}'.");
236 println!(" Generate with: cargo build -p trueno-gpu --features cuda");
237 println!(" Or provide path: cgp explain ptx --kernel path/to/kernel.ptx");
238 }
239 }
240 }
241 "wgsl" | "shader" => {
242 let shader_path = kernel.unwrap_or("*.wgsl");
243 println!(" Target: WGSL (WebGPU shader)");
244
245 if Path::new(shader_path).exists() {
246 let source = std::fs::read_to_string(shader_path)?;
247 let analysis = analyze_wgsl(&source);
248 println!(" File: {shader_path}");
249 render_wgsl_analysis(&analysis);
250 } else {
251 println!(" Shader file not found: {shader_path}");
252 println!(
253 " Provide path: cgp explain wgsl --kernel src/backends/gpu/shaders/gemm.wgsl"
254 );
255 }
256 }
257 "simd" => {
258 println!(" Target: SIMD (x86/ARM assembly analysis)");
259 println!(" Analysis: instruction mix, vectorization rate, register usage");
260 println!(
261 " Use: cgp profile simd --function <fn> --arch avx2 for runtime SIMD analysis"
262 );
263 }
264 _ => {
265 println!(" Unknown target: {target}");
266 println!(" Supported: ptx, wgsl, simd");
267 }
268 }
269
270 println!();
271 Ok(())
272}
273
274fn render_ptx_analysis(analysis: &PtxAnalysis) {
276 println!("\n Instruction Mix:");
277 println!(" Total instructions: {}", analysis.total_instructions);
278 println!(" Compute ops: {}", analysis.compute_ops);
279 println!(" Memory ops: {}", analysis.memory_ops);
280 println!(" Control flow: {}", analysis.control_ops);
281 println!(" Sync barriers: {}", analysis.sync_ops);
282 println!(" Shared memory ops: {}", analysis.shared_ops);
283
284 println!(
285 "\n Compute/Memory Ratio: {:.2}",
286 analysis.compute_memory_ratio
287 );
288 if analysis.compute_memory_ratio < 1.0 {
289 println!(" Status: MEMORY-INTENSIVE (more loads than compute)");
290 } else if analysis.compute_memory_ratio > 4.0 {
291 println!(" Status: COMPUTE-INTENSIVE (good arithmetic density)");
292 } else {
293 println!(" Status: BALANCED");
294 }
295
296 println!("\n Features:");
297 println!(" Registers declared: {}", analysis.registers_declared);
298 println!(
299 " Tensor cores (WMMA/MMA): {}",
300 if analysis.has_wmma { "YES" } else { "no" }
301 );
302 println!(
303 " FMA instructions: {}",
304 if analysis.has_fma { "YES" } else { "no" }
305 );
306
307 if !analysis.warnings.is_empty() {
308 println!("\n Warnings:");
309 for w in &analysis.warnings {
310 println!(" \x1b[33m[WARN]\x1b[0m {w}");
311 }
312 }
313}
314
315fn render_wgsl_analysis(analysis: &WgslAnalysis) {
317 println!("\n Shader Info:");
318 println!(" Lines: {}", analysis.total_lines);
319 println!(
320 " Workgroup size: {}",
321 analysis
322 .workgroup_size
323 .as_deref()
324 .unwrap_or("not specified")
325 );
326 println!(" Bindings: {}", analysis.bindings);
327 println!(
328 " Atomics: {}",
329 if analysis.has_atomics { "YES" } else { "no" }
330 );
331 println!(
332 " Shared memory: {}",
333 if analysis.has_shared { "YES" } else { "no" }
334 );
335
336 if !analysis.warnings.is_empty() {
337 println!("\n Warnings:");
338 for w in &analysis.warnings {
339 println!(" \x1b[33m[WARN]\x1b[0m {w}");
340 }
341 }
342}
343
/// Resolves `kernel_name` to the path of a PTX file.
///
/// If `kernel_name` itself names an existing regular file it is returned
/// unchanged. Otherwise a fixed list of directories is scanned, in order, for
/// files whose name ends in `.ptx` and contains `kernel_name` (`"*"` matches
/// every `.ptx` file). Within a directory the smallest matching path is
/// chosen, so the result does not depend on directory iteration order.
///
/// Returns `None` when nothing matches; unreadable directories are skipped.
fn find_ptx_file(kernel_name: &str) -> Option<String> {
    // `is_file` rather than `exists`: a directory that happens to share the
    // kernel's name must not be handed back as if it were a PTX file.
    if Path::new(kernel_name).is_file() {
        return Some(kernel_name.to_string());
    }

    const SEARCH_DIRS: [&str; 3] = ["src/backends/gpu/kernels", "trueno-gpu/src", "."];

    SEARCH_DIRS.iter().find_map(|dir| {
        let entries = std::fs::read_dir(dir).ok()?;
        entries
            .flatten()
            .filter(|entry| {
                let file_name = entry.file_name();
                let file_name = file_name.to_string_lossy();
                file_name.ends_with(".ptx")
                    && (kernel_name == "*" || file_name.contains(kernel_name))
            })
            .map(|entry| entry.path())
            .filter(|path| path.is_file())
            // Directory iteration order is platform-dependent; pick deterministically.
            .min()
            .map(|path| path.display().to_string())
    })
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// A small kernel with two loads, one FMA, one store and one barrier.
    #[test]
    fn test_analyze_ptx_basic() {
        let ptx = r#"
.version 8.0
.target sm_89
.entry gemm_kernel {
 .reg .f32 %f<32>;
 .reg .pred %p<4>;
 ld.global.f32 %f1, [%rd1];
 ld.global.f32 %f2, [%rd2];
 fma.rn.f32 %f3, %f1, %f2, %f0;
 st.global.f32 [%rd3], %f3;
 bar.sync 0;
}
"#;
        let analysis = analyze_ptx(ptx);
        // Exact counts: the fixture is fully deterministic, so pin every value.
        assert_eq!(analysis.total_instructions, 5);
        assert_eq!(analysis.memory_ops, 3);
        assert_eq!(analysis.shared_ops, 0);
        assert_eq!(analysis.compute_ops, 1);
        assert_eq!(analysis.control_ops, 0);
        assert_eq!(analysis.sync_ops, 1);
        assert_eq!(analysis.registers_declared, 36);
        assert!(analysis.has_fma);
        assert!(!analysis.has_wmma);
        assert!(analysis.warnings.is_empty());
    }

    #[test]
    fn test_analyze_ptx_wmma() {
        let ptx = "wmma.mma.sync.aligned.m16n16k16.row.col.f32.f16 {a}, {b}, {c};";
        let analysis = analyze_ptx(ptx);
        assert!(analysis.has_wmma);
        assert!(!analysis.has_fma);
    }

    #[test]
    fn test_analyze_ptx_high_register_warning() {
        let ptx = ".reg .f32 %f<256>;";
        let analysis = analyze_ptx(ptx);
        assert_eq!(analysis.registers_declared, 256);
        assert_eq!(analysis.warnings.len(), 1);
        assert!(analysis.warnings[0].contains("High register usage (256)"));
    }

    #[test]
    fn test_analyze_wgsl_basic() {
        let wgsl = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> b: array<f32>;

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 b[gid.x] = a[gid.x] * 2.0;
}
"#;
        let analysis = analyze_wgsl(wgsl);
        assert_eq!(analysis.bindings, 2);
        assert_eq!(analysis.workgroup_size.as_deref(), Some("256, 1, 1"));
        assert!(!analysis.has_atomics);
        assert!(!analysis.has_shared);
        // 256 threads is inside the 64..=1024 window, so nothing is flagged.
        assert!(analysis.warnings.is_empty());
    }

    #[test]
    fn test_analyze_wgsl_small_workgroup() {
        let wgsl = "@compute @workgroup_size(8, 1, 1)\nfn main() {}";
        let analysis = analyze_wgsl(wgsl);
        assert_eq!(analysis.warnings.len(), 1);
        assert!(analysis.warnings[0].contains("consider >=64"));
    }

    #[test]
    fn test_run_explain_ptx() {
        let result = run_explain("ptx", None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_explain_wgsl_missing_file() {
        // A missing shader is reported on stdout, not as an error.
        let result = run_explain("wgsl", Some("definitely/missing/shader.wgsl"));
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_explain_simd() {
        let result = run_explain("simd", None);
        assert!(result.is_ok());
    }

    #[test]
    fn test_run_explain_unknown() {
        let result = run_explain("unknown_target", None);
        assert!(result.is_ok());
    }
}