trueno_explain/ptx/bugs/analyzer/
mod.rs1use regex::Regex;
2
3use super::types::{PtxBugClass, PtxBugReport};
4
5#[derive(Debug, Clone)]
7pub struct WhitelistEntry {
8 pub kernel_pattern: String,
10 pub bug_class: PtxBugClass,
12 pub reason: String,
14}
15
16#[derive(Debug, Default, Clone)]
18pub struct PtxBugAnalyzer {
19 pub strict: bool,
21 pub whitelist: Vec<WhitelistEntry>,
23}
24
25impl PtxBugAnalyzer {
26 #[must_use]
28 pub fn new() -> Self {
29 Self::default()
30 }
31
32 #[must_use]
34 pub fn strict() -> Self {
35 Self {
36 strict: true,
37 whitelist: Vec::new(),
38 }
39 }
40
41 #[must_use]
43 pub fn with_whitelist(
44 mut self,
45 kernel_pattern: &str,
46 bug_class: PtxBugClass,
47 reason: &str,
48 ) -> Self {
49 self.whitelist.push(WhitelistEntry {
50 kernel_pattern: kernel_pattern.to_string(),
51 bug_class,
52 reason: reason.to_string(),
53 });
54 self
55 }
56
57 #[must_use]
59 pub fn with_quantized_whitelist() -> Self {
60 Self::new()
61 .with_whitelist(
62 "q4k*",
63 PtxBugClass::HighRegisterPressure,
64 "Quantized kernels require high registers for dequantization",
65 )
66 .with_whitelist(
67 "q5k*",
68 PtxBugClass::HighRegisterPressure,
69 "Quantized kernels require high registers for dequantization",
70 )
71 .with_whitelist(
72 "q6k*",
73 PtxBugClass::HighRegisterPressure,
74 "Quantized kernels require high registers for dequantization",
75 )
76 .with_whitelist(
77 "q8k*",
78 PtxBugClass::HighRegisterPressure,
79 "Quantized kernels require high registers for dequantization",
80 )
81 }
82
83 #[must_use]
92 pub fn with_performance_whitelist() -> Self {
93 Self::new()
94 .with_whitelist(
98 "gemm_tensor_core*",
99 PtxBugClass::HighRegisterPressure,
100 "Tensor Core WMMA requires many registers for matrix fragments",
101 )
102 .with_whitelist(
103 "gemm_tensor_core*",
104 PtxBugClass::PredicateOverflow,
105 "Tensor Core kernels use predicates for bounds checking and masking",
106 )
107 .with_whitelist(
108 "gemm_wmma*",
109 PtxBugClass::HighRegisterPressure,
110 "WMMA FP16 requires registers for A/B/C/D matrix fragments",
111 )
112 .with_whitelist(
113 "gemm_wmma*",
114 PtxBugClass::PredicateOverflow,
115 "WMMA kernels use predicates for tile boundary handling",
116 )
117 .with_whitelist(
119 "flash_attention*",
120 PtxBugClass::HighRegisterPressure,
121 "FlashAttention tiling requires registers for Q/K/V/O tiles and softmax state",
122 )
123 .with_whitelist(
124 "attention*",
125 PtxBugClass::HighRegisterPressure,
126 "Attention kernels require registers for Q/K/V tiles and reduction",
127 )
128 .with_whitelist(
130 "q4k*",
131 PtxBugClass::HighRegisterPressure,
132 "Q4_K dequantization requires registers for scale/min extraction",
133 )
134 .with_whitelist(
135 "q5k*",
136 PtxBugClass::HighRegisterPressure,
137 "Q5_K dequantization requires registers for 5-bit value reconstruction",
138 )
139 .with_whitelist(
140 "q6k*",
141 PtxBugClass::HighRegisterPressure,
142 "Q6_K dequantization requires registers for 6-bit value reconstruction",
143 )
144 .with_whitelist(
145 "q8k*",
146 PtxBugClass::HighRegisterPressure,
147 "Q8_K dequantization requires registers for scale application",
148 )
149 }
150
151 fn is_whitelisted(&self, kernel_name: Option<&String>, bug_class: &PtxBugClass) -> bool {
153 let Some(kernel) = kernel_name else {
154 return false;
155 };
156
157 for entry in &self.whitelist {
158 if &entry.bug_class != bug_class {
159 continue;
160 }
161 if entry.kernel_pattern.ends_with('*') {
163 let prefix = &entry.kernel_pattern[..entry.kernel_pattern.len() - 1];
164 if kernel.starts_with(prefix) {
165 return true;
166 }
167 } else if &entry.kernel_pattern == kernel {
168 return true;
169 }
170 }
171 false
172 }
173
174 #[must_use]
176 pub fn analyze(&self, ptx: &str) -> PtxBugReport {
177 let mut bugs = Vec::new();
178 let lines: Vec<&str> = ptx.lines().collect();
179
180 let kernel_name = self.extract_kernel_name(ptx);
182
183 bugs.extend(self.detect_shared_mem_u64(ptx, &lines));
185 bugs.extend(self.detect_loop_branch_to_end(ptx, &lines));
186 bugs.extend(self.detect_missing_barrier_sync(ptx, &lines));
187 bugs.extend(self.detect_early_exit_before_barrier(ptx));
188 bugs.extend(self.detect_register_spills(ptx, &lines));
189 bugs.extend(self.detect_missing_entry_point(ptx, &lines));
190 bugs.extend(self.detect_redundant_moves(ptx, &lines));
191 bugs.extend(self.detect_unoptimized_memory(ptx, &lines));
192 bugs.extend(self.detect_high_register_pressure(ptx, &lines));
193 bugs.extend(self.detect_predicate_overflow(ptx, &lines));
194 bugs.extend(self.detect_placeholder_code(ptx, &lines));
195 bugs.extend(self.detect_empty_loop_body(ptx, &lines));
197 bugs.extend(self.detect_missing_bounds_check(ptx, &lines));
198 bugs.extend(self.detect_dead_code(ptx, &lines));
199
200 bugs.retain(|bug| !self.is_whitelisted(kernel_name.as_ref(), &bug.class));
202
203 PtxBugReport {
204 kernel_name,
205 bugs,
206 lines_analyzed: lines.len(),
207 strict_mode: self.strict,
208 }
209 }
210
211 fn extract_kernel_name(&self, ptx: &str) -> Option<String> {
213 let entry_pattern = Regex::new(r"\.(?:visible\s+)?\.entry\s+(\w+)")
214 .expect("invariant: regex pattern is valid");
215 entry_pattern.captures(ptx).map(|c| {
216 c.get(1)
217 .expect("invariant: capture group 1 exists")
218 .as_str()
219 .to_string()
220 })
221 }
222}
223
224mod detectors;