1#![allow(
12 clippy::unwrap_used,
13 clippy::trivial_regex,
14 clippy::collection_is_never_read
15)]
16
17use regex::Regex;
18use std::collections::HashSet;
19
/// Category of a problem reported by [`PtxAnalyzer`].
///
/// The `Display` form is a short snake_case identifier (see
/// [`PtxBugClass::as_str`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PtxBugClass {
    /// A shared-memory access whose address operand is a 64-bit `%rd` register.
    SharedMemU64Addressing,
    /// An unconditional branch to a label classified as a loop end (strict mode only).
    LoopBranchToEnd,
    /// Shared memory is used but no barrier synchronization was found (strict mode only).
    MissingBarrierSync,
    /// Reserved; not currently reported by `PtxAnalyzer::analyze`.
    NonInPlaceLoopAccumulator,
    /// Reserved; not currently reported by `PtxAnalyzer::analyze`.
    InvalidSyntax,
    /// Non-blank input declares no kernel entry point.
    MissingEntryPoint,
}

impl PtxBugClass {
    /// Returns the short identifier printed by the `Display` impl.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::SharedMemU64Addressing => "shared_mem_u64",
            Self::LoopBranchToEnd => "loop_branch_to_end",
            Self::MissingBarrierSync => "missing_barrier",
            Self::NonInPlaceLoopAccumulator => "non_inplace_accum",
            Self::InvalidSyntax => "invalid_syntax",
            Self::MissingEntryPoint => "missing_entry",
        }
    }
}

impl std::fmt::Display for PtxBugClass {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!` with a literal) honours width/alignment flags,
        // so e.g. `{:<20}` produces an aligned column.
        f.pad(self.as_str())
    }
}
49
50#[derive(Debug, Clone)]
52pub struct PtxBug {
53 pub class: PtxBugClass,
55 pub line: usize,
57 pub instruction: String,
59 pub message: String,
61}
62
63#[derive(Debug, Clone)]
65pub struct PtxValidationResult {
66 pub bugs: Vec<PtxBug>,
68 pub kernel_names: Vec<String>,
70 pub lines_analyzed: usize,
72}
73
74impl PtxValidationResult {
75 #[must_use]
77 pub fn is_valid(&self) -> bool {
78 self.bugs.is_empty() && !self.kernel_names.is_empty()
79 }
80
81 #[must_use]
83 pub fn bug_count(&self, class: &PtxBugClass) -> usize {
84 self.bugs.iter().filter(|b| &b.class == class).count()
85 }
86
87 #[must_use]
89 pub fn has_bug(&self, class: &PtxBugClass) -> bool {
90 self.bugs.iter().any(|b| &b.class == class)
91 }
92}
93
/// Regex-based checker for PTX source text.
///
/// The default analyzer runs only the always-on checks; set `strict` (or use
/// `PtxAnalyzer::strict()`) to enable the additional strict-mode checks.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct PtxAnalyzer {
    /// Enables the strict-mode checks in `PtxAnalyzer::analyze`.
    pub strict: bool,
}
100
101impl PtxAnalyzer {
102 #[must_use]
104 pub fn strict() -> Self {
105 Self { strict: true }
106 }
107
108 #[must_use]
110 pub fn analyze(&self, ptx: &str) -> PtxValidationResult {
111 let mut bugs = Vec::new();
112 let mut kernel_names = Vec::new();
113 let lines: Vec<&str> = ptx.lines().collect();
114
115 let shared_mem_u64 = Regex::new(r"[sl]t\.shared\.[^\[]+\[%rd\d+\]").unwrap();
117 let entry_point = Regex::new(r"\.visible\s+\.entry\s+(\w+)").unwrap();
118 let loop_label = Regex::new(r"^(\w+_loop\w*):").unwrap();
119 let branch_instr = Regex::new(r"bra\s+(\w+);").unwrap();
120 let bar_sync = Regex::new(r"bar\.sync").unwrap();
121
122 let mut loop_start_labels: HashSet<String> = HashSet::new();
124 let mut loop_end_labels: HashSet<String> = HashSet::new();
125
126 for line in &lines {
128 let trimmed = line.trim();
129 if let Some(caps) = loop_label.captures(trimmed) {
130 let label = caps.get(1).unwrap().as_str();
131 if label.contains("_start")
132 || label.ends_with("_loop")
133 || label.starts_with("loop_")
134 {
135 loop_start_labels.insert(label.to_string());
136 } else if label.contains("_end") {
137 loop_end_labels.insert(label.to_string());
138 }
139 }
140 }
141
142 for (line_num, line) in lines.iter().enumerate() {
144 let trimmed = line.trim();
145
146 if shared_mem_u64.is_match(trimmed) {
148 bugs.push(PtxBug {
149 class: PtxBugClass::SharedMemU64Addressing,
150 line: line_num + 1,
151 instruction: trimmed.to_string(),
152 message: "Shared memory accessed with 64-bit register. Use 32-bit addressing."
153 .to_string(),
154 });
155 }
156
157 if let Some(caps) = entry_point.captures(trimmed) {
159 kernel_names.push(caps.get(1).unwrap().as_str().to_string());
160 }
161
162 if let Some(caps) = branch_instr.captures(trimmed) {
165 let target = caps.get(1).unwrap().as_str();
166 if self.strict && loop_end_labels.contains(target) {
170 if !trimmed.starts_with('@') && !trimmed.contains("@%p") {
172 bugs.push(PtxBug {
173 class: PtxBugClass::LoopBranchToEnd,
174 line: line_num + 1,
175 instruction: trimmed.to_string(),
176 message: format!(
177 "Unconditional branch to loop end '{}'. Should branch to start?",
178 target
179 ),
180 });
181 }
182 }
183 }
184 }
185
186 if kernel_names.is_empty() && !ptx.trim().is_empty() {
188 bugs.push(PtxBug {
189 class: PtxBugClass::MissingEntryPoint,
190 line: 0,
191 instruction: String::new(),
192 message: "No kernel entry point found".to_string(),
193 });
194 }
195
196 let uses_shared =
198 ptx.contains(".shared") || ptx.contains("st.shared") || ptx.contains("ld.shared");
199 let has_barrier = bar_sync.is_match(ptx);
200 if self.strict && uses_shared && !has_barrier {
201 bugs.push(PtxBug {
202 class: PtxBugClass::MissingBarrierSync,
203 line: 0,
204 instruction: String::new(),
205 message: "Shared memory used but no bar.sync found".to_string(),
206 });
207 }
208
209 PtxValidationResult {
210 bugs,
211 kernel_names,
212 lines_analyzed: lines.len(),
213 }
214 }
215}
216
217#[cfg(test)]
218mod tests {
219 use super::*;
220
221 #[test]
222 fn test_shared_mem_u64_detection() {
223 let ptx = "st.shared.f32 [%rd5], %f0;";
224 let analyzer = PtxAnalyzer::default();
225 let result = analyzer.analyze(ptx);
226 assert!(result.has_bug(&PtxBugClass::SharedMemU64Addressing));
227 }
228
229 #[test]
230 fn test_shared_mem_u32_ok() {
231 let ptx = "st.shared.f32 [%r5], %f0;";
232 let analyzer = PtxAnalyzer::default();
233 let result = analyzer.analyze(ptx);
234 assert!(!result.has_bug(&PtxBugClass::SharedMemU64Addressing));
235 }
236
237 #[test]
238 fn test_kernel_name_extraction() {
239 let ptx = r#"
240.visible .entry gemm_tiled(
241 .param .u64 a_ptr
242) {
243 ret;
244}
245"#;
246 let result = PtxAnalyzer::default().analyze(ptx);
247 assert_eq!(result.kernel_names, vec!["gemm_tiled"]);
248 }
249
250 #[test]
251 fn test_multiple_kernels() {
252 let ptx = r#"
253.visible .entry kernel_a() { ret; }
254.visible .entry kernel_b() { ret; }
255"#;
256 let result = PtxAnalyzer::default().analyze(ptx);
257 assert_eq!(result.kernel_names.len(), 2);
258 }
259
260 #[test]
261 fn test_missing_entry_point() {
262 let ptx = ".version 8.0\n.target sm_70";
263 let result = PtxAnalyzer::default().analyze(ptx);
264 assert!(result.has_bug(&PtxBugClass::MissingEntryPoint));
265 }
266
267 #[test]
268 fn test_strict_mode_barrier() {
269 let ptx = r#"
270.visible .entry test() {
271 .shared .b8 smem[1024];
272 st.shared.f32 [%r0], %f0;
273 ret;
274}
275"#;
276 let strict_result = PtxAnalyzer::strict().analyze(ptx);
277 let normal_result = PtxAnalyzer::default().analyze(ptx);
278
279 assert!(strict_result.has_bug(&PtxBugClass::MissingBarrierSync));
280 assert!(!normal_result.has_bug(&PtxBugClass::MissingBarrierSync));
281 }
282
283 #[test]
284 fn test_bug_class_display() {
285 assert_eq!(
286 format!("{}", PtxBugClass::SharedMemU64Addressing),
287 "shared_mem_u64"
288 );
289 assert_eq!(
290 format!("{}", PtxBugClass::LoopBranchToEnd),
291 "loop_branch_to_end"
292 );
293 }
294
295 #[test]
296 fn test_validation_result_helpers() {
297 let result = PtxValidationResult {
298 bugs: vec![
299 PtxBug {
300 class: PtxBugClass::SharedMemU64Addressing,
301 line: 1,
302 instruction: "test".to_string(),
303 message: "test".to_string(),
304 },
305 PtxBug {
306 class: PtxBugClass::SharedMemU64Addressing,
307 line: 2,
308 instruction: "test".to_string(),
309 message: "test".to_string(),
310 },
311 ],
312 kernel_names: vec!["test".to_string()],
313 lines_analyzed: 10,
314 };
315
316 assert_eq!(result.bug_count(&PtxBugClass::SharedMemU64Addressing), 2);
317 assert_eq!(result.bug_count(&PtxBugClass::LoopBranchToEnd), 0);
318 assert!(!result.is_valid());
319 }
320
321 #[test]
322 fn test_bug_class_display_all_variants() {
323 assert_eq!(
324 format!("{}", PtxBugClass::MissingBarrierSync),
325 "missing_barrier"
326 );
327 assert_eq!(
328 format!("{}", PtxBugClass::NonInPlaceLoopAccumulator),
329 "non_inplace_accum"
330 );
331 assert_eq!(format!("{}", PtxBugClass::InvalidSyntax), "invalid_syntax");
332 assert_eq!(
333 format!("{}", PtxBugClass::MissingEntryPoint),
334 "missing_entry"
335 );
336 }
337
338 #[test]
339 fn test_loop_branch_to_end_strict_mode() {
340 let ptx = r#"
343.visible .entry test() {
344test_loop_start:
345 // loop body
346 bra test_loop_end;
347test_loop_end:
348 ret;
349}
350"#;
351 let strict_result = PtxAnalyzer::strict().analyze(ptx);
352 assert!(strict_result.has_bug(&PtxBugClass::LoopBranchToEnd));
354 }
355
356 #[test]
357 fn test_loop_labels_with_loop_suffix() {
358 let ptx = r#"
360.visible .entry test() {
361main_loop:
362 bra main_loop_end;
363main_loop_end:
364 ret;
365}
366"#;
367 let result = PtxAnalyzer::strict().analyze(ptx);
368 assert!(result.has_bug(&PtxBugClass::LoopBranchToEnd));
370 }
371
372 #[test]
373 fn test_conditional_branch_not_flagged() {
374 let ptx = r#"
375.visible .entry test() {
376loop_start:
377 @%p0 bra loop_end;
378loop_end:
379 ret;
380}
381"#;
382 let result = PtxAnalyzer::strict().analyze(ptx);
383 assert!(!result.has_bug(&PtxBugClass::LoopBranchToEnd));
385 }
386
387 #[test]
388 fn test_ld_shared_u64_detection() {
389 let ptx = "st.shared.f32 [%rd5], %f0;";
393 let result = PtxAnalyzer::default().analyze(ptx);
394 assert!(result.has_bug(&PtxBugClass::SharedMemU64Addressing));
395 }
396
397 #[test]
398 fn test_valid_result_empty_bugs() {
399 let result = PtxValidationResult {
400 bugs: vec![],
401 kernel_names: vec!["kernel".to_string()],
402 lines_analyzed: 5,
403 };
404 assert!(result.is_valid());
405 }
406
407 #[test]
408 fn test_invalid_result_no_kernels() {
409 let result = PtxValidationResult {
410 bugs: vec![],
411 kernel_names: vec![],
412 lines_analyzed: 5,
413 };
414 assert!(!result.is_valid());
415 }
416
417 #[test]
418 fn test_empty_ptx_no_bugs() {
419 let result = PtxAnalyzer::default().analyze("");
420 assert!(result.bugs.is_empty());
421 assert!(result.kernel_names.is_empty());
422 }
423
424 #[test]
425 fn test_shared_mem_st_detection() {
426 let ptx = "st.shared.f32 [%rd0], %f1;";
427 let result = PtxAnalyzer::default().analyze(ptx);
428 assert!(result.has_bug(&PtxBugClass::SharedMemU64Addressing));
429 }
430
431 #[test]
432 fn test_barrier_present() {
433 let ptx = r#"
434.visible .entry test() {
435 .shared .b8 smem[1024];
436 st.shared.f32 [%r0], %f0;
437 bar.sync 0;
438 ret;
439}
440"#;
441 let result = PtxAnalyzer::strict().analyze(ptx);
442 assert!(!result.has_bug(&PtxBugClass::MissingBarrierSync));
443 }
444
445 #[test]
446 fn test_analyzer_debug() {
447 let analyzer = PtxAnalyzer::default();
448 let debug_str = format!("{:?}", analyzer);
449 assert!(debug_str.contains("PtxAnalyzer"));
450 }
451
452 #[test]
453 fn test_ptx_bug_fields() {
454 let bug = PtxBug {
455 class: PtxBugClass::InvalidSyntax,
456 line: 42,
457 instruction: "invalid".to_string(),
458 message: "Bad syntax".to_string(),
459 };
460 assert_eq!(bug.line, 42);
461 assert_eq!(bug.instruction, "invalid");
462 assert_eq!(bug.message, "Bad syntax");
463 assert_eq!(bug.class, PtxBugClass::InvalidSyntax);
464 }
465
466 #[test]
467 fn test_bug_class_hash_eq() {
468 use std::collections::HashSet;
469 let mut set = HashSet::new();
470 set.insert(PtxBugClass::SharedMemU64Addressing);
471 set.insert(PtxBugClass::LoopBranchToEnd);
472 assert!(set.contains(&PtxBugClass::SharedMemU64Addressing));
473 assert!(!set.contains(&PtxBugClass::MissingBarrierSync));
474 }
475
476 #[test]
477 fn test_validation_result_clone() {
478 let result = PtxValidationResult {
479 bugs: vec![],
480 kernel_names: vec!["test".to_string()],
481 lines_analyzed: 10,
482 };
483 let cloned = result.clone();
484 assert_eq!(cloned.kernel_names, result.kernel_names);
485 assert_eq!(cloned.lines_analyzed, result.lines_analyzed);
486 }
487}