1use anyhow::Result;
6use std::path::Path;
7
8pub fn analyze_ptx(source: &str) -> PtxAnalysis {
10 let lines: Vec<&str> = source.lines().collect();
11 let total_instructions = count_ptx_instructions(&lines);
12
13 let mut counters = PtxCounters::default();
14 for line in &lines {
15 classify_ptx_line(line, &mut counters);
16 }
17 append_ptx_pressure_warnings(&mut counters);
18
19 let compute_memory_ratio = if counters.memory_ops > 0 {
20 f64::from(counters.compute_ops) / f64::from(counters.memory_ops)
21 } else {
22 f64::INFINITY
23 };
24
25 PtxAnalysis {
26 total_instructions,
27 memory_ops: counters.memory_ops,
28 compute_ops: counters.compute_ops,
29 control_ops: counters.control_ops,
30 sync_ops: counters.sync_ops,
31 shared_ops: counters.shared_ops,
32 registers_declared: counters.registers_declared,
33 has_wmma: counters.has_wmma,
34 has_fma: counters.has_fma,
35 compute_memory_ratio,
36 warnings: counters.warnings,
37 }
38}
39
/// Running tallies accumulated while scanning PTX source line by line.
#[derive(Default)]
struct PtxCounters {
    // Instruction-class counts.
    compute_ops: u32,
    memory_ops: u32,
    shared_ops: u32,
    control_ops: u32,
    sync_ops: u32,
    // Registers declared through `.reg` directives.
    registers_declared: u32,
    // Feature flags observed anywhere in the source.
    has_fma: bool,
    has_wmma: bool,
    // Human-readable diagnostics collected during the scan.
    warnings: Vec<String>,
}
53
/// Counts the lines of PTX source that hold an executable instruction.
///
/// Skipped lines:
/// - blank lines and `//` comments;
/// - directives (`.version`, `.reg`, `.param`, …);
/// - scope braces `{` / `}`;
/// - the `)` that closes an entry's parameter list;
/// - branch-target labels such as `$L__BB0_1:`.
///
/// The result saturates at `u32::MAX`.
fn count_ptx_instructions(lines: &[&str]) -> u32 {
    let count = lines
        .iter()
        .map(|line| line.trim())
        .filter(|t| {
            !t.is_empty()
                && !t.starts_with("//")
                && !t.starts_with(|c: char| matches!(c, '.' | '{' | '}' | ')'))
                // Instructions end in `;`; a line ending in `:` is a bare label.
                && !t.ends_with(':')
        })
        .count();
    u32::try_from(count).unwrap_or(u32::MAX)
}
71
72fn classify_ptx_line(line: &str, counters: &mut PtxCounters) {
74 let trimmed = line.trim();
75 tally_register_declaration(trimmed, counters);
76 tally_memory_op(trimmed, counters);
77 tally_compute_op(trimmed, counters);
78 tally_control_op(trimmed, counters);
79 if trimmed.starts_with("bar.") {
80 counters.sync_ops += 1;
81 }
82 if trimmed.contains("wmma.") || trimmed.contains("mma.") {
83 counters.has_wmma = true;
84 }
85}
86
87fn tally_register_declaration(trimmed: &str, counters: &mut PtxCounters) {
89 if !trimmed.starts_with(".reg") {
90 return;
91 }
92 let Some(count_str) = trimmed.split('<').nth(1).and_then(|s| s.split('>').next()) else {
93 return;
94 };
95 let Ok(count) = count_str.parse::<u32>() else {
96 return;
97 };
98 counters.registers_declared += count;
99}
100
101fn tally_memory_op(trimmed: &str, counters: &mut PtxCounters) {
103 if !(trimmed.starts_with("ld.") || trimmed.starts_with("st.")) {
104 return;
105 }
106 counters.memory_ops += 1;
107 if trimmed.contains(".shared") {
108 counters.shared_ops += 1;
109 }
110}
111
112fn tally_compute_op(trimmed: &str, counters: &mut PtxCounters) {
114 let is_compute = trimmed.starts_with("add.")
115 || trimmed.starts_with("mul.")
116 || trimmed.starts_with("mad.")
117 || trimmed.starts_with("fma.");
118 if !is_compute {
119 return;
120 }
121 counters.compute_ops += 1;
122 if trimmed.starts_with("fma.") || trimmed.starts_with("mad.") {
123 counters.has_fma = true;
124 }
125}
126
127fn tally_control_op(trimmed: &str, counters: &mut PtxCounters) {
129 if !(trimmed.starts_with("bra") || trimmed.starts_with('@')) {
130 return;
131 }
132 counters.control_ops += 1;
133 if trimmed.starts_with("@%p") && trimmed.contains("bra") {
134 counters
135 .warnings
136 .push("Data-dependent branch may cause warp divergence".to_string());
137 }
138}
139
140fn append_ptx_pressure_warnings(counters: &mut PtxCounters) {
142 if counters.registers_declared > 128 {
143 counters.warnings.push(format!(
144 "High register usage ({}) may limit occupancy",
145 counters.registers_declared
146 ));
147 }
148 if counters.sync_ops > 2 {
149 counters.warnings.push(format!(
150 "{} barrier syncs — review if all are necessary",
151 counters.sync_ops
152 ));
153 }
154}
155
/// Instruction-mix summary of a PTX file, as produced by `analyze_ptx`.
#[derive(Debug, Clone, PartialEq)]
pub struct PtxAnalysis {
    /// Lines that hold an executable instruction.
    pub total_instructions: u32,
    /// Load/store instructions.
    pub memory_ops: u32,
    /// Arithmetic instructions (add/mul/mad/fma).
    pub compute_ops: u32,
    /// Branches and predicated instructions.
    pub control_ops: u32,
    /// Barrier instructions (`bar.*`).
    pub sync_ops: u32,
    /// Memory ops that touch `.shared` state space.
    pub shared_ops: u32,
    /// Registers declared via `.reg` directives.
    pub registers_declared: u32,
    /// Whether any WMMA/MMA (tensor-core) instruction appears.
    pub has_wmma: bool,
    /// Whether any `fma.`/`mad.` instruction appears.
    pub has_fma: bool,
    /// `compute_ops / memory_ops`; may be infinite when no memory ops exist.
    pub compute_memory_ratio: f64,
    /// Human-readable diagnostics.
    pub warnings: Vec<String>,
}
171
172pub fn analyze_wgsl(source: &str) -> WgslAnalysis {
174 let lines: Vec<&str> = source.lines().collect();
175 let total_lines = u32::try_from(lines.len()).unwrap_or(u32::MAX);
176
177 let mut counters = WgslCounters::default();
178 for line in &lines {
179 classify_wgsl_line(line, &mut counters);
180 }
181 append_wgsl_workgroup_warnings(&mut counters);
182
183 WgslAnalysis {
184 total_lines,
185 workgroup_size: counters.workgroup_size,
186 bindings: counters.bindings,
187 has_atomics: counters.has_atomics,
188 has_shared: counters.has_shared,
189 warnings: counters.warnings,
190 }
191}
192
/// Running tallies accumulated while scanning WGSL source line by line.
#[derive(Default)]
struct WgslCounters {
    // Number of `@binding` attributes seen.
    bindings: u32,
    // Raw text between the parentheses of `@workgroup_size(...)`, if any.
    workgroup_size: Option<String>,
    // Feature flags.
    has_shared: bool,
    has_atomics: bool,
    // Human-readable diagnostics.
    warnings: Vec<String>,
}
202
203fn classify_wgsl_line(line: &str, counters: &mut WgslCounters) {
205 let trimmed = line.trim();
206 capture_workgroup_size(trimmed, counters);
207 if trimmed.contains("@binding") {
208 counters.bindings += 1;
209 }
210 if trimmed.contains("atomicAdd") || trimmed.contains("atomicStore") {
211 counters.has_atomics = true;
212 }
213 if trimmed.contains("var<workgroup>") {
214 counters.has_shared = true;
215 }
216}
217
218fn capture_workgroup_size(trimmed: &str, counters: &mut WgslCounters) {
220 if !trimmed.contains("@workgroup_size") {
221 return;
222 }
223 let Some(start) = trimmed.find('(').map(|i| i + 1) else {
224 return;
225 };
226 let Some(end) = trimmed.find(')') else {
227 return;
228 };
229 counters.workgroup_size = Some(trimmed[start..end].to_string());
230}
231
232fn append_wgsl_workgroup_warnings(counters: &mut WgslCounters) {
234 let Some(ref ws) = counters.workgroup_size else {
235 return;
236 };
237 let total: u32 = ws
238 .split(',')
239 .filter_map(|s| s.trim().parse::<u32>().ok())
240 .product();
241 if total < 64 {
242 counters.warnings.push(format!(
243 "Workgroup size ({ws}) = {total} threads — consider >=64 for GPU occupancy"
244 ));
245 }
246 if total > 1024 {
247 counters.warnings.push(format!(
248 "Workgroup size ({ws}) = {total} threads — exceeds common hardware limit (1024)"
249 ));
250 }
251}
252
/// Summary of a WGSL shader, as produced by `analyze_wgsl`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WgslAnalysis {
    /// Number of source lines.
    pub total_lines: u32,
    /// Raw text inside `@workgroup_size(...)`, if the attribute was found.
    pub workgroup_size: Option<String>,
    /// Number of `@binding` attributes.
    pub bindings: u32,
    /// Whether an atomic built-in is used.
    pub has_atomics: bool,
    /// Whether a `var<workgroup>` declaration is present.
    pub has_shared: bool,
    /// Human-readable diagnostics.
    pub warnings: Vec<String>,
}
263
264pub fn run_explain(target: &str, kernel: Option<&str>) -> Result<()> {
266 println!("\n=== CGP Explain: {target} ===\n");
267
268 match target {
269 "ptx" => {
270 let kernel_name = kernel.unwrap_or("*");
271 println!(" Target: PTX (CUDA assembly)");
272 println!(" Kernel: {kernel_name}");
273
274 let ptx_path = find_ptx_file(kernel_name);
276 match ptx_path {
277 Some(path) => {
278 let source = std::fs::read_to_string(&path)?;
279 let analysis = analyze_ptx(&source);
280 println!(" File: {path}");
281 render_ptx_analysis(&analysis);
282 }
283 None => {
284 println!(" No PTX file found for kernel '{kernel_name}'.");
285 println!(" Generate with: cargo build -p trueno-gpu --features cuda");
286 println!(" Or provide path: cgp explain ptx --kernel path/to/kernel.ptx");
287 }
288 }
289 }
290 "wgsl" | "shader" => {
291 let shader_path = kernel.unwrap_or("*.wgsl");
292 println!(" Target: WGSL (WebGPU shader)");
293
294 if Path::new(shader_path).exists() {
295 let source = std::fs::read_to_string(shader_path)?;
296 let analysis = analyze_wgsl(&source);
297 println!(" File: {shader_path}");
298 render_wgsl_analysis(&analysis);
299 } else {
300 println!(" Shader file not found: {shader_path}");
301 println!(
302 " Provide path: cgp explain wgsl --kernel src/backends/gpu/shaders/gemm.wgsl"
303 );
304 }
305 }
306 "simd" => {
307 println!(" Target: SIMD (x86/ARM assembly analysis)");
308 println!(" Analysis: instruction mix, vectorization rate, register usage");
309 println!(
310 " Use: cgp profile simd --function <fn> --arch avx2 for runtime SIMD analysis"
311 );
312 }
313 _ => {
314 println!(" Unknown target: {target}");
315 println!(" Supported: ptx, wgsl, simd");
316 }
317 }
318
319 println!();
320 Ok(())
321}
322
323fn render_ptx_analysis(analysis: &PtxAnalysis) {
325 println!("\n Instruction Mix:");
326 println!(" Total instructions: {}", analysis.total_instructions);
327 println!(" Compute ops: {}", analysis.compute_ops);
328 println!(" Memory ops: {}", analysis.memory_ops);
329 println!(" Control flow: {}", analysis.control_ops);
330 println!(" Sync barriers: {}", analysis.sync_ops);
331 println!(" Shared memory ops: {}", analysis.shared_ops);
332
333 println!(
334 "\n Compute/Memory Ratio: {:.2}",
335 analysis.compute_memory_ratio
336 );
337 if analysis.compute_memory_ratio < 1.0 {
338 println!(" Status: MEMORY-INTENSIVE (more loads than compute)");
339 } else if analysis.compute_memory_ratio > 4.0 {
340 println!(" Status: COMPUTE-INTENSIVE (good arithmetic density)");
341 } else {
342 println!(" Status: BALANCED");
343 }
344
345 println!("\n Features:");
346 println!(" Registers declared: {}", analysis.registers_declared);
347 println!(
348 " Tensor cores (WMMA/MMA): {}",
349 if analysis.has_wmma { "YES" } else { "no" }
350 );
351 println!(
352 " FMA instructions: {}",
353 if analysis.has_fma { "YES" } else { "no" }
354 );
355
356 if !analysis.warnings.is_empty() {
357 println!("\n Warnings:");
358 for w in &analysis.warnings {
359 println!(" \x1b[33m[WARN]\x1b[0m {w}");
360 }
361 }
362}
363
364fn render_wgsl_analysis(analysis: &WgslAnalysis) {
366 println!("\n Shader Info:");
367 println!(" Lines: {}", analysis.total_lines);
368 println!(
369 " Workgroup size: {}",
370 analysis
371 .workgroup_size
372 .as_deref()
373 .unwrap_or("not specified")
374 );
375 println!(" Bindings: {}", analysis.bindings);
376 println!(
377 " Atomics: {}",
378 if analysis.has_atomics { "YES" } else { "no" }
379 );
380 println!(
381 " Shared memory: {}",
382 if analysis.has_shared { "YES" } else { "no" }
383 );
384
385 if !analysis.warnings.is_empty() {
386 println!("\n Warnings:");
387 for w in &analysis.warnings {
388 println!(" \x1b[33m[WARN]\x1b[0m {w}");
389 }
390 }
391}
392
/// Resolves `kernel_name` to a PTX file path.
///
/// If `kernel_name` names an existing regular file it is returned as-is.
/// Directories are rejected, since they cannot be read as PTX source.
/// Otherwise a fixed list of directories is searched, in order, for `.ptx`
/// files whose name contains `kernel_name`; `"*"` matches any `.ptx` file.
/// Within a directory the lexicographically smallest matching path is chosen.
/// This makes the result deterministic, because `read_dir` order is
/// platform-dependent.
fn find_ptx_file(kernel_name: &str) -> Option<String> {
    const SEARCH_DIRS: [&str; 3] = ["src/backends/gpu/kernels", "trueno-gpu/src", "."];

    if Path::new(kernel_name).is_file() {
        return Some(kernel_name.to_string());
    }

    SEARCH_DIRS.iter().find_map(|dir| {
        let entries = std::fs::read_dir(dir).ok()?;
        entries
            .flatten()
            .filter(|entry| {
                let name = entry.file_name();
                let name = name.to_string_lossy();
                name.ends_with(".ptx") && (kernel_name == "*" || name.contains(kernel_name))
            })
            .map(|entry| entry.path())
            .filter(|path| path.is_file())
            .min()
            .map(|path| path.display().to_string())
    })
}
417
#[cfg(test)]
mod tests {
    use super::*;

    const BASIC_PTX: &str = r#"
.version 8.0
.target sm_89
.entry gemm_kernel {
    .reg .f32 %f<32>;
    .reg .pred %p<4>;
    ld.global.f32 %f1, [%rd1];
    ld.global.f32 %f2, [%rd2];
    fma.rn.f32 %f3, %f1, %f2, %f0;
    st.global.f32 [%rd3], %f3;
    bar.sync 0;
}
"#;

    const BASIC_WGSL: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> b: array<f32>;

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    b[gid.x] = a[gid.x] * 2.0;
}
"#;

    #[test]
    fn test_analyze_ptx_basic() {
        let analysis = analyze_ptx(BASIC_PTX);
        assert!(analysis.memory_ops >= 3, "two loads and one store expected");
        assert!(analysis.compute_ops >= 1, "the fma should count as compute");
        assert!(analysis.has_fma);
        assert!(analysis.sync_ops >= 1, "bar.sync should count as a barrier");
        assert!(analysis.registers_declared >= 32);
    }

    #[test]
    fn test_analyze_ptx_wmma() {
        let analysis =
            analyze_ptx("wmma.mma.sync.aligned.m16n16k16.row.col.f32.f16 {a}, {b}, {c};");
        assert!(analysis.has_wmma);
    }

    #[test]
    fn test_analyze_ptx_high_register_warning() {
        let analysis = analyze_ptx(".reg .f32 %f<256>;");
        assert!(analysis.registers_declared >= 256);
        assert!(!analysis.warnings.is_empty(), "256 registers should warn");
    }

    #[test]
    fn test_analyze_wgsl_basic() {
        let analysis = analyze_wgsl(BASIC_WGSL);
        assert_eq!(analysis.bindings, 2);
        assert_eq!(analysis.workgroup_size.as_deref(), Some("256, 1, 1"));
        assert!(!analysis.has_atomics);
    }

    #[test]
    fn test_analyze_wgsl_small_workgroup() {
        let analysis = analyze_wgsl("@compute @workgroup_size(8, 1, 1)\nfn main() {}");
        assert!(!analysis.warnings.is_empty(), "8 threads should warn");
    }

    #[test]
    fn test_run_explain_ptx() {
        run_explain("ptx", None).expect("ptx explain should not fail");
    }

    #[test]
    fn test_run_explain_simd() {
        run_explain("simd", None).expect("simd explain should not fail");
    }

    #[test]
    fn test_run_explain_unknown() {
        run_explain("unknown_target", None).expect("unknown targets only print a hint");
    }
}