Skip to main content

trueno_explain/ptx/bugs/analyzer/
mod.rs

1use regex::Regex;
2
3use super::types::{PtxBugClass, PtxBugReport};
4
5/// Whitelist entry for suppressing known acceptable warnings
6#[derive(Debug, Clone)]
7pub struct WhitelistEntry {
8    /// Kernel name pattern (supports prefix matching with *)
9    pub kernel_pattern: String,
10    /// Bug class to suppress
11    pub bug_class: PtxBugClass,
12    /// Reason for whitelisting
13    pub reason: String,
14}
15
16/// PTX bug hunting analyzer (inspired by probar `gpu_pixels`)
17#[derive(Debug, Default, Clone)]
18pub struct PtxBugAnalyzer {
19    /// Enable strict mode (more warnings, catches PARITY-114 pattern)
20    pub strict: bool,
21    /// Whitelist for suppressing known acceptable warnings
22    pub whitelist: Vec<WhitelistEntry>,
23}
24
25impl PtxBugAnalyzer {
26    /// Create analyzer with default (non-strict) mode
27    #[must_use]
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Create analyzer with strict mode enabled
33    #[must_use]
34    pub fn strict() -> Self {
35        Self {
36            strict: true,
37            whitelist: Vec::new(),
38        }
39    }
40
41    /// Add a whitelist entry to suppress warnings
42    #[must_use]
43    pub fn with_whitelist(
44        mut self,
45        kernel_pattern: &str,
46        bug_class: PtxBugClass,
47        reason: &str,
48    ) -> Self {
49        self.whitelist.push(WhitelistEntry {
50            kernel_pattern: kernel_pattern.to_string(),
51            bug_class,
52            reason: reason.to_string(),
53        });
54        self
55    }
56
57    /// Create analyzer with default whitelist for quantized kernels
58    #[must_use]
59    pub fn with_quantized_whitelist() -> Self {
60        Self::new()
61            .with_whitelist(
62                "q4k*",
63                PtxBugClass::HighRegisterPressure,
64                "Quantized kernels require high registers for dequantization",
65            )
66            .with_whitelist(
67                "q5k*",
68                PtxBugClass::HighRegisterPressure,
69                "Quantized kernels require high registers for dequantization",
70            )
71            .with_whitelist(
72                "q6k*",
73                PtxBugClass::HighRegisterPressure,
74                "Quantized kernels require high registers for dequantization",
75            )
76            .with_whitelist(
77                "q8k*",
78                PtxBugClass::HighRegisterPressure,
79                "Quantized kernels require high registers for dequantization",
80            )
81    }
82
83    /// Create analyzer with comprehensive whitelist for all high-performance kernels
84    ///
85    /// This whitelist covers expected register pressure and predicate usage in:
86    /// - Tensor Core kernels (WMMA requires many registers for matrix fragments)
87    /// - Attention kernels (`FlashAttention` needs registers for tiling state)
88    /// - Quantized kernels (dequantization requires intermediate values)
89    ///
90    /// These are documented performance tradeoffs, not bugs.
91    #[must_use]
92    pub fn with_performance_whitelist() -> Self {
93        Self::new()
94            // Tensor Core / WMMA kernels - high register usage is expected
95            // WMMA m16n16k16 requires 8 registers per fragment × 3 fragments = 24+ registers
96            // Plus accumulator, addresses, loop counters, etc.
97            .with_whitelist(
98                "gemm_tensor_core*",
99                PtxBugClass::HighRegisterPressure,
100                "Tensor Core WMMA requires many registers for matrix fragments",
101            )
102            .with_whitelist(
103                "gemm_tensor_core*",
104                PtxBugClass::PredicateOverflow,
105                "Tensor Core kernels use predicates for bounds checking and masking",
106            )
107            .with_whitelist(
108                "gemm_wmma*",
109                PtxBugClass::HighRegisterPressure,
110                "WMMA FP16 requires registers for A/B/C/D matrix fragments",
111            )
112            .with_whitelist(
113                "gemm_wmma*",
114                PtxBugClass::PredicateOverflow,
115                "WMMA kernels use predicates for tile boundary handling",
116            )
117            // Attention kernels - FlashAttention tiling requires state
118            .with_whitelist(
119                "flash_attention*",
120                PtxBugClass::HighRegisterPressure,
121                "FlashAttention tiling requires registers for Q/K/V/O tiles and softmax state",
122            )
123            .with_whitelist(
124                "attention*",
125                PtxBugClass::HighRegisterPressure,
126                "Attention kernels require registers for Q/K/V tiles and reduction",
127            )
128            // Quantized kernels - dequantization math
129            .with_whitelist(
130                "q4k*",
131                PtxBugClass::HighRegisterPressure,
132                "Q4_K dequantization requires registers for scale/min extraction",
133            )
134            .with_whitelist(
135                "q5k*",
136                PtxBugClass::HighRegisterPressure,
137                "Q5_K dequantization requires registers for 5-bit value reconstruction",
138            )
139            .with_whitelist(
140                "q6k*",
141                PtxBugClass::HighRegisterPressure,
142                "Q6_K dequantization requires registers for 6-bit value reconstruction",
143            )
144            .with_whitelist(
145                "q8k*",
146                PtxBugClass::HighRegisterPressure,
147                "Q8_K dequantization requires registers for scale application",
148            )
149    }
150
151    /// Check if a bug should be suppressed by whitelist
152    fn is_whitelisted(&self, kernel_name: Option<&String>, bug_class: &PtxBugClass) -> bool {
153        let Some(kernel) = kernel_name else {
154            return false;
155        };
156
157        for entry in &self.whitelist {
158            if &entry.bug_class != bug_class {
159                continue;
160            }
161            // Pattern matching: "q4k*" matches "q4k_gemm_ggml"
162            if entry.kernel_pattern.ends_with('*') {
163                let prefix = &entry.kernel_pattern[..entry.kernel_pattern.len() - 1];
164                if kernel.starts_with(prefix) {
165                    return true;
166                }
167            } else if &entry.kernel_pattern == kernel {
168                return true;
169            }
170        }
171        false
172    }
173
174    /// Analyze PTX for bugs
175    #[must_use]
176    pub fn analyze(&self, ptx: &str) -> PtxBugReport {
177        let mut bugs = Vec::new();
178        let lines: Vec<&str> = ptx.lines().collect();
179
180        // Extract kernel name
181        let kernel_name = self.extract_kernel_name(ptx);
182
183        // Execute all pattern detectors
184        bugs.extend(self.detect_shared_mem_u64(ptx, &lines));
185        bugs.extend(self.detect_loop_branch_to_end(ptx, &lines));
186        bugs.extend(self.detect_missing_barrier_sync(ptx, &lines));
187        bugs.extend(self.detect_early_exit_before_barrier(ptx));
188        bugs.extend(self.detect_register_spills(ptx, &lines));
189        bugs.extend(self.detect_missing_entry_point(ptx, &lines));
190        bugs.extend(self.detect_redundant_moves(ptx, &lines));
191        bugs.extend(self.detect_unoptimized_memory(ptx, &lines));
192        bugs.extend(self.detect_high_register_pressure(ptx, &lines));
193        bugs.extend(self.detect_predicate_overflow(ptx, &lines));
194        bugs.extend(self.detect_placeholder_code(ptx, &lines));
195        // New extended detectors
196        bugs.extend(self.detect_empty_loop_body(ptx, &lines));
197        bugs.extend(self.detect_missing_bounds_check(ptx, &lines));
198        bugs.extend(self.detect_dead_code(ptx, &lines));
199
200        // Filter out whitelisted bugs
201        bugs.retain(|bug| !self.is_whitelisted(kernel_name.as_ref(), &bug.class));
202
203        PtxBugReport {
204            kernel_name,
205            bugs,
206            lines_analyzed: lines.len(),
207            strict_mode: self.strict,
208        }
209    }
210
211    /// Extract kernel name from PTX
212    fn extract_kernel_name(&self, ptx: &str) -> Option<String> {
213        let entry_pattern = Regex::new(r"\.(?:visible\s+)?\.entry\s+(\w+)")
214            .expect("invariant: regex pattern is valid");
215        entry_pattern.captures(ptx).map(|c| {
216            c.get(1)
217                .expect("invariant: capture group 1 exists")
218                .as_str()
219                .to_string()
220        })
221    }
222}
223
224mod detectors;