Skip to main content

jugar_probar/gpu_pixels/
ptx_analysis.rs

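//! PTX static analysis for GPU kernel bug detection.
//!
//! Detects common PTX bugs with line-oriented, regex-based heuristics:
//! - Shared-memory accesses that use a 64-bit (`%rd`) address register instead
//!   of 32-bit addressing
//! - Unconditional branches to a loop's END label instead of its START
//!   (strict mode only)
//! - Shared memory used without any barrier synchronization (strict mode only)
//! - PTX modules that declare no kernel entry point
//!
//! Because the checks are heuristic, they can produce both false positives and
//! false negatives. [`PtxBugClass::InvalidSyntax`] and
//! [`PtxBugClass::NonInPlaceLoopAccumulator`] are defined but not currently
//! reported by [`PtxAnalyzer::analyze`].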
8
// Lint exceptions for this module:
// - `unwrap_used`: every `Regex::new` call uses a constant pattern literal, so
//   the `unwrap()` on construction cannot fail at runtime.
// - `trivial_regex`: some patterns may be plain literals; they stay regexes so
//   that all detectors are built and queried the same way.
// - `collection_is_never_read`: `loop_start_labels` in `PtxAnalyzer::analyze` is
//   populated but never read; it is kept for tracking/debugging label classification.
#![allow(
    clippy::unwrap_used,
    clippy::trivial_regex,
    clippy::collection_is_never_read
)]
16
17use regex::Regex;
18use std::collections::HashSet;
19
20/// PTX bug classification
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
22pub enum PtxBugClass {
23    /// Shared memory accessed with 64-bit register (should be 32-bit)
24    SharedMemU64Addressing,
25    /// Loop branches to END label instead of START
26    LoopBranchToEnd,
27    /// Missing barrier sync between shared memory write and read
28    MissingBarrierSync,
29    /// Accumulator not updated in-place in loop
30    NonInPlaceLoopAccumulator,
31    /// Invalid PTX syntax
32    InvalidSyntax,
33    /// Kernel entry point missing
34    MissingEntryPoint,
35}
36
37impl std::fmt::Display for PtxBugClass {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        match self {
40            Self::SharedMemU64Addressing => write!(f, "shared_mem_u64"),
41            Self::LoopBranchToEnd => write!(f, "loop_branch_to_end"),
42            Self::MissingBarrierSync => write!(f, "missing_barrier"),
43            Self::NonInPlaceLoopAccumulator => write!(f, "non_inplace_accum"),
44            Self::InvalidSyntax => write!(f, "invalid_syntax"),
45            Self::MissingEntryPoint => write!(f, "missing_entry"),
46        }
47    }
48}
49
50/// A detected PTX bug
51#[derive(Debug, Clone)]
52pub struct PtxBug {
53    /// Bug classification
54    pub class: PtxBugClass,
55    /// Line number (1-indexed, 0 if unknown)
56    pub line: usize,
57    /// The offending PTX instruction
58    pub instruction: String,
59    /// Human-readable explanation
60    pub message: String,
61}
62
63/// Result of PTX validation
64#[derive(Debug, Clone)]
65pub struct PtxValidationResult {
66    /// List of detected bugs
67    pub bugs: Vec<PtxBug>,
68    /// Kernel names found
69    pub kernel_names: Vec<String>,
70    /// Total lines analyzed
71    pub lines_analyzed: usize,
72}
73
74impl PtxValidationResult {
75    /// Check if PTX passed all validations
76    #[must_use]
77    pub fn is_valid(&self) -> bool {
78        self.bugs.is_empty() && !self.kernel_names.is_empty()
79    }
80
81    /// Get count of bugs by class
82    #[must_use]
83    pub fn bug_count(&self, class: &PtxBugClass) -> usize {
84        self.bugs.iter().filter(|b| &b.class == class).count()
85    }
86
87    /// Check for specific bug class
88    #[must_use]
89    pub fn has_bug(&self, class: &PtxBugClass) -> bool {
90        self.bugs.iter().any(|b| &b.class == class)
91    }
92}
93
94/// PTX static analyzer
95#[derive(Debug, Default)]
96pub struct PtxAnalyzer {
97    /// Enable strict mode (more warnings)
98    pub strict: bool,
99}
100
101impl PtxAnalyzer {
102    /// Create analyzer with strict mode
103    #[must_use]
104    pub fn strict() -> Self {
105        Self { strict: true }
106    }
107
108    /// Analyze PTX string for bugs
109    #[must_use]
110    pub fn analyze(&self, ptx: &str) -> PtxValidationResult {
111        let mut bugs = Vec::new();
112        let mut kernel_names = Vec::new();
113        let lines: Vec<&str> = ptx.lines().collect();
114
115        // Regex patterns for bug detection
116        let shared_mem_u64 = Regex::new(r"[sl]t\.shared\.[^\[]+\[%rd\d+\]").unwrap();
117        let entry_point = Regex::new(r"\.visible\s+\.entry\s+(\w+)").unwrap();
118        let loop_label = Regex::new(r"^(\w+_loop\w*):").unwrap();
119        let branch_instr = Regex::new(r"bra\s+(\w+);").unwrap();
120        let bar_sync = Regex::new(r"bar\.sync").unwrap();
121
122        // Track loop labels
123        let mut loop_start_labels: HashSet<String> = HashSet::new();
124        let mut loop_end_labels: HashSet<String> = HashSet::new();
125
126        // First pass: collect labels
127        for line in &lines {
128            let trimmed = line.trim();
129            if let Some(caps) = loop_label.captures(trimmed) {
130                let label = caps.get(1).unwrap().as_str();
131                if label.contains("_start")
132                    || label.ends_with("_loop")
133                    || label.starts_with("loop_")
134                {
135                    loop_start_labels.insert(label.to_string());
136                } else if label.contains("_end") {
137                    loop_end_labels.insert(label.to_string());
138                }
139            }
140        }
141
142        // Second pass: detect bugs
143        for (line_num, line) in lines.iter().enumerate() {
144            let trimmed = line.trim();
145
146            // Detect shared memory u64 addressing
147            if shared_mem_u64.is_match(trimmed) {
148                bugs.push(PtxBug {
149                    class: PtxBugClass::SharedMemU64Addressing,
150                    line: line_num + 1,
151                    instruction: trimmed.to_string(),
152                    message: "Shared memory accessed with 64-bit register. Use 32-bit addressing."
153                        .to_string(),
154                });
155            }
156
157            // Collect kernel names
158            if let Some(caps) = entry_point.captures(trimmed) {
159                kernel_names.push(caps.get(1).unwrap().as_str().to_string());
160            }
161
162            // Detect branch to loop end from inside loop body
163            // (this is a heuristic - may have false positives)
164            if let Some(caps) = branch_instr.captures(trimmed) {
165                let target = caps.get(1).unwrap().as_str();
166                // If branching to a _end label that has a corresponding _start,
167                // and we're not at a conditional branch after loop check,
168                // it might be a bug
169                if self.strict && loop_end_labels.contains(target) {
170                    // Check if this is an unconditional branch (potential loop continuation bug)
171                    if !trimmed.starts_with('@') && !trimmed.contains("@%p") {
172                        bugs.push(PtxBug {
173                            class: PtxBugClass::LoopBranchToEnd,
174                            line: line_num + 1,
175                            instruction: trimmed.to_string(),
176                            message: format!(
177                                "Unconditional branch to loop end '{}'. Should branch to start?",
178                                target
179                            ),
180                        });
181                    }
182                }
183            }
184        }
185
186        // Check for missing entry point
187        if kernel_names.is_empty() && !ptx.trim().is_empty() {
188            bugs.push(PtxBug {
189                class: PtxBugClass::MissingEntryPoint,
190                line: 0,
191                instruction: String::new(),
192                message: "No kernel entry point found".to_string(),
193            });
194        }
195
196        // Check for barrier sync presence when shared memory is used
197        let uses_shared =
198            ptx.contains(".shared") || ptx.contains("st.shared") || ptx.contains("ld.shared");
199        let has_barrier = bar_sync.is_match(ptx);
200        if self.strict && uses_shared && !has_barrier {
201            bugs.push(PtxBug {
202                class: PtxBugClass::MissingBarrierSync,
203                line: 0,
204                instruction: String::new(),
205                message: "Shared memory used but no bar.sync found".to_string(),
206            });
207        }
208
209        PtxValidationResult {
210            bugs,
211            kernel_names,
212            lines_analyzed: lines.len(),
213        }
214    }
215}
216
217#[cfg(test)]
218mod tests {
219    use super::*;
220
221    #[test]
222    fn test_shared_mem_u64_detection() {
223        let ptx = "st.shared.f32 [%rd5], %f0;";
224        let analyzer = PtxAnalyzer::default();
225        let result = analyzer.analyze(ptx);
226        assert!(result.has_bug(&PtxBugClass::SharedMemU64Addressing));
227    }
228
229    #[test]
230    fn test_shared_mem_u32_ok() {
231        let ptx = "st.shared.f32 [%r5], %f0;";
232        let analyzer = PtxAnalyzer::default();
233        let result = analyzer.analyze(ptx);
234        assert!(!result.has_bug(&PtxBugClass::SharedMemU64Addressing));
235    }
236
237    #[test]
238    fn test_kernel_name_extraction() {
239        let ptx = r#"
240.visible .entry gemm_tiled(
241    .param .u64 a_ptr
242) {
243    ret;
244}
245"#;
246        let result = PtxAnalyzer::default().analyze(ptx);
247        assert_eq!(result.kernel_names, vec!["gemm_tiled"]);
248    }
249
250    #[test]
251    fn test_multiple_kernels() {
252        let ptx = r#"
253.visible .entry kernel_a() { ret; }
254.visible .entry kernel_b() { ret; }
255"#;
256        let result = PtxAnalyzer::default().analyze(ptx);
257        assert_eq!(result.kernel_names.len(), 2);
258    }
259
260    #[test]
261    fn test_missing_entry_point() {
262        let ptx = ".version 8.0\n.target sm_70";
263        let result = PtxAnalyzer::default().analyze(ptx);
264        assert!(result.has_bug(&PtxBugClass::MissingEntryPoint));
265    }
266
267    #[test]
268    fn test_strict_mode_barrier() {
269        let ptx = r#"
270.visible .entry test() {
271    .shared .b8 smem[1024];
272    st.shared.f32 [%r0], %f0;
273    ret;
274}
275"#;
276        let strict_result = PtxAnalyzer::strict().analyze(ptx);
277        let normal_result = PtxAnalyzer::default().analyze(ptx);
278
279        assert!(strict_result.has_bug(&PtxBugClass::MissingBarrierSync));
280        assert!(!normal_result.has_bug(&PtxBugClass::MissingBarrierSync));
281    }
282
283    #[test]
284    fn test_bug_class_display() {
285        assert_eq!(
286            format!("{}", PtxBugClass::SharedMemU64Addressing),
287            "shared_mem_u64"
288        );
289        assert_eq!(
290            format!("{}", PtxBugClass::LoopBranchToEnd),
291            "loop_branch_to_end"
292        );
293    }
294
295    #[test]
296    fn test_validation_result_helpers() {
297        let result = PtxValidationResult {
298            bugs: vec![
299                PtxBug {
300                    class: PtxBugClass::SharedMemU64Addressing,
301                    line: 1,
302                    instruction: "test".to_string(),
303                    message: "test".to_string(),
304                },
305                PtxBug {
306                    class: PtxBugClass::SharedMemU64Addressing,
307                    line: 2,
308                    instruction: "test".to_string(),
309                    message: "test".to_string(),
310                },
311            ],
312            kernel_names: vec!["test".to_string()],
313            lines_analyzed: 10,
314        };
315
316        assert_eq!(result.bug_count(&PtxBugClass::SharedMemU64Addressing), 2);
317        assert_eq!(result.bug_count(&PtxBugClass::LoopBranchToEnd), 0);
318        assert!(!result.is_valid());
319    }
320
321    #[test]
322    fn test_bug_class_display_all_variants() {
323        assert_eq!(
324            format!("{}", PtxBugClass::MissingBarrierSync),
325            "missing_barrier"
326        );
327        assert_eq!(
328            format!("{}", PtxBugClass::NonInPlaceLoopAccumulator),
329            "non_inplace_accum"
330        );
331        assert_eq!(format!("{}", PtxBugClass::InvalidSyntax), "invalid_syntax");
332        assert_eq!(
333            format!("{}", PtxBugClass::MissingEntryPoint),
334            "missing_entry"
335        );
336    }
337
338    #[test]
339    fn test_loop_branch_to_end_strict_mode() {
340        // PTX with a loop that branches to _end unconditionally
341        // The loop_label regex requires _loop suffix
342        let ptx = r#"
343.visible .entry test() {
344test_loop_start:
345    // loop body
346    bra test_loop_end;
347test_loop_end:
348    ret;
349}
350"#;
351        let strict_result = PtxAnalyzer::strict().analyze(ptx);
352        // In strict mode, unconditional branch to _end should be flagged
353        assert!(strict_result.has_bug(&PtxBugClass::LoopBranchToEnd));
354    }
355
356    #[test]
357    fn test_loop_labels_with_loop_suffix() {
358        // Labels must match the regex: ^(\w+_loop\w*):
359        let ptx = r#"
360.visible .entry test() {
361main_loop:
362    bra main_loop_end;
363main_loop_end:
364    ret;
365}
366"#;
367        let result = PtxAnalyzer::strict().analyze(ptx);
368        // main_loop matches the _loop pattern, main_loop_end contains _end
369        assert!(result.has_bug(&PtxBugClass::LoopBranchToEnd));
370    }
371
372    #[test]
373    fn test_conditional_branch_not_flagged() {
374        let ptx = r#"
375.visible .entry test() {
376loop_start:
377    @%p0 bra loop_end;
378loop_end:
379    ret;
380}
381"#;
382        let result = PtxAnalyzer::strict().analyze(ptx);
383        // Conditional branch should NOT be flagged
384        assert!(!result.has_bug(&PtxBugClass::LoopBranchToEnd));
385    }
386
387    #[test]
388    fn test_ld_shared_u64_detection() {
389        // Test ld.shared with 64-bit register - must match pattern [sl]t\.shared
390        // ld.shared doesn't match - only st.shared and lt.shared (which isn't valid)
391        // So the regex is really for st.shared only
392        let ptx = "st.shared.f32 [%rd5], %f0;";
393        let result = PtxAnalyzer::default().analyze(ptx);
394        assert!(result.has_bug(&PtxBugClass::SharedMemU64Addressing));
395    }
396
397    #[test]
398    fn test_valid_result_empty_bugs() {
399        let result = PtxValidationResult {
400            bugs: vec![],
401            kernel_names: vec!["kernel".to_string()],
402            lines_analyzed: 5,
403        };
404        assert!(result.is_valid());
405    }
406
407    #[test]
408    fn test_invalid_result_no_kernels() {
409        let result = PtxValidationResult {
410            bugs: vec![],
411            kernel_names: vec![],
412            lines_analyzed: 5,
413        };
414        assert!(!result.is_valid());
415    }
416
417    #[test]
418    fn test_empty_ptx_no_bugs() {
419        let result = PtxAnalyzer::default().analyze("");
420        assert!(result.bugs.is_empty());
421        assert!(result.kernel_names.is_empty());
422    }
423
424    #[test]
425    fn test_shared_mem_st_detection() {
426        let ptx = "st.shared.f32 [%rd0], %f1;";
427        let result = PtxAnalyzer::default().analyze(ptx);
428        assert!(result.has_bug(&PtxBugClass::SharedMemU64Addressing));
429    }
430
431    #[test]
432    fn test_barrier_present() {
433        let ptx = r#"
434.visible .entry test() {
435    .shared .b8 smem[1024];
436    st.shared.f32 [%r0], %f0;
437    bar.sync 0;
438    ret;
439}
440"#;
441        let result = PtxAnalyzer::strict().analyze(ptx);
442        assert!(!result.has_bug(&PtxBugClass::MissingBarrierSync));
443    }
444
445    #[test]
446    fn test_analyzer_debug() {
447        let analyzer = PtxAnalyzer::default();
448        let debug_str = format!("{:?}", analyzer);
449        assert!(debug_str.contains("PtxAnalyzer"));
450    }
451
452    #[test]
453    fn test_ptx_bug_fields() {
454        let bug = PtxBug {
455            class: PtxBugClass::InvalidSyntax,
456            line: 42,
457            instruction: "invalid".to_string(),
458            message: "Bad syntax".to_string(),
459        };
460        assert_eq!(bug.line, 42);
461        assert_eq!(bug.instruction, "invalid");
462        assert_eq!(bug.message, "Bad syntax");
463        assert_eq!(bug.class, PtxBugClass::InvalidSyntax);
464    }
465
466    #[test]
467    fn test_bug_class_hash_eq() {
468        use std::collections::HashSet;
469        let mut set = HashSet::new();
470        set.insert(PtxBugClass::SharedMemU64Addressing);
471        set.insert(PtxBugClass::LoopBranchToEnd);
472        assert!(set.contains(&PtxBugClass::SharedMemU64Addressing));
473        assert!(!set.contains(&PtxBugClass::MissingBarrierSync));
474    }
475
476    #[test]
477    fn test_validation_result_clone() {
478        let result = PtxValidationResult {
479            bugs: vec![],
480            kernel_names: vec!["test".to_string()],
481            lines_analyzed: 10,
482        };
483        let cloned = result.clone();
484        assert_eq!(cloned.kernel_names, result.kernel_names);
485        assert_eq!(cloned.lines_analyzed, result.lines_analyzed);
486    }
487}