trueno_explain/ptx/parser/
mod.rs1use crate::analyzer::{
6 AnalysisReport, Analyzer, MemoryPattern, MudaType, MudaWarning, RegisterUsage, RooflineMetric,
7};
8use crate::error::Result;
9use regex::Regex;
10
/// Heuristic analyzer for PTX text.
///
/// The two public fields are the thresholds at which the analyzer starts
/// emitting warnings; everything else is derived from the PTX passed to it.
#[derive(Debug, Clone, PartialEq)]
pub struct PtxAnalyzer {
    /// Total register count above which a register-pressure warning is
    /// emitted (the comparison is strict: `total > threshold`).
    pub register_warning_threshold: u32,
    /// Coalescing ratio below which an uncoalesced-access warning is emitted
    /// (the comparison is strict: `ratio < threshold`). The ratio computed by
    /// this analyzer is capped at `1.0`.
    pub coalescing_warning_threshold: f32,
}
18
19impl Default for PtxAnalyzer {
20 fn default() -> Self {
21 Self {
22 register_warning_threshold: 128,
23 coalescing_warning_threshold: 0.8,
24 }
25 }
26}
27
28impl PtxAnalyzer {
29 #[must_use]
31 pub fn new() -> Self {
32 Self::default()
33 }
34
35 fn parse_registers(&self, ptx: &str) -> RegisterUsage {
37 let mut usage = RegisterUsage::default();
38
39 let reg_pattern =
41 Regex::new(r"\.reg\s+\.(\w+)\s+%\w+<(\d+)>").expect("valid regex pattern");
42
43 for cap in reg_pattern.captures_iter(ptx) {
44 let reg_type = &cap[1];
45 let count: u32 = cap[2].parse().unwrap_or(0);
46
47 match reg_type {
48 "f32" => usage.f32_regs += count,
49 "f64" => usage.f64_regs += count,
50 "b32" | "u32" | "s32" => usage.b32_regs += count,
51 "b64" | "u64" | "s64" => usage.b64_regs += count,
52 "pred" => usage.pred_regs += count,
53 _ => {}
54 }
55 }
56
57 usage
58 }
59
60 fn parse_memory_ops(&self, ptx: &str) -> MemoryPattern {
62 let mut pattern = MemoryPattern::default();
63
64 let global_load = Regex::new(r"ld\.global").expect("valid regex pattern");
66 pattern.global_loads = global_load.find_iter(ptx).count() as u32;
67
68 let global_store = Regex::new(r"st\.global").expect("valid regex pattern");
70 pattern.global_stores = global_store.find_iter(ptx).count() as u32;
71
72 let shared_load = Regex::new(r"ld\.shared").expect("valid regex pattern");
74 pattern.shared_loads = shared_load.find_iter(ptx).count() as u32;
75
76 let shared_store = Regex::new(r"st\.shared").expect("valid regex pattern");
78 pattern.shared_stores = shared_store.find_iter(ptx).count() as u32;
79
80 let tid_pattern =
88 Regex::new(r"%tid\.[xy]|%ntid\.[xy]|%ctaid\.[xy]").expect("valid regex pattern");
89 let tid_refs = tid_pattern.find_iter(ptx).count();
90
91 let mad_pattern = Regex::new(r"mad\.lo").expect("valid regex pattern");
93 let mad_refs = mad_pattern.find_iter(ptx).count();
94
95 let mul_lo_pattern = Regex::new(r"mul\.lo").expect("valid regex pattern");
97 let mul_lo_refs = mul_lo_pattern.find_iter(ptx).count();
98
99 let stride_pattern = Regex::new(r"mul\.wide\.[us]32").expect("valid regex pattern");
101 let stride_refs = stride_pattern.find_iter(ptx).count();
102
103 let shfl_pattern = Regex::new(r"shfl\.(down|up|bfly|idx)").expect("valid regex pattern");
105 let shfl_refs = shfl_pattern.find_iter(ptx).count();
106
107 let lane_pattern = Regex::new(r"rem\.u32|div\.u32").expect("valid regex pattern");
109 let lane_refs = lane_pattern.find_iter(ptx).count();
110
111 let total_accesses = pattern.global_loads + pattern.global_stores;
112 if total_accesses > 0 {
113 let coalescing_score = tid_refs as f32
116 + (mad_refs as f32 * 0.6) + (mul_lo_refs as f32 * 0.4) + (stride_refs as f32 * 0.3) + (shfl_refs as f32 * 0.3) + (lane_refs as f32 * 0.2); pattern.coalesced_ratio = (coalescing_score / total_accesses as f32).min(1.0);
122 } else {
123 pattern.coalesced_ratio = 1.0;
124 }
125
126 pattern
127 }
128
129 fn count_instructions(&self, ptx: &str) -> u32 {
131 let instruction_pattern = Regex::new(r"^\s+(add|sub|mul|div|mad|fma|ld|st|mov|setp|bra|ret|cvt|and|or|xor|shl|shr|min|max|abs|neg|sqrt|rsqrt|sin|cos|ex2|lg2|rcp|selp|set|bar)").expect("valid regex pattern");
133
134 ptx.lines()
135 .filter(|line| instruction_pattern.is_match(line))
136 .count() as u32
137 }
138
139 fn extract_kernel_name(&self, ptx: &str) -> String {
141 let entry_pattern = Regex::new(r"\.entry\s+(\w+)").expect("valid regex pattern");
142 entry_pattern
143 .captures(ptx)
144 .map(|c| c[1].to_string())
145 .unwrap_or_else(|| "unknown".to_string())
146 }
147
148 fn detect_spills(&self, ptx: &str) -> Option<MudaWarning> {
150 let local_pattern = Regex::new(r"\.local").expect("valid regex pattern");
152 let spill_count = local_pattern.find_iter(ptx).count();
153
154 if spill_count > 0 {
155 Some(MudaWarning {
156 muda_type: MudaType::Transport,
157 description: format!("{} potential register spills detected", spill_count),
158 impact: "High latency local memory access".to_string(),
159 line: None,
160 suggestion: Some(
161 "Reduce live variables or increase register allocation".to_string(),
162 ),
163 })
164 } else {
165 None
166 }
167 }
168
169 fn detect_uncoalesced(&self, memory: &MemoryPattern) -> Option<MudaWarning> {
171 if memory.coalesced_ratio < self.coalescing_warning_threshold {
172 Some(MudaWarning {
173 muda_type: MudaType::Waiting,
174 description: format!(
175 "Memory coalescing ratio {:.1}% below threshold {:.1}%",
176 memory.coalesced_ratio * 100.0,
177 self.coalescing_warning_threshold * 100.0
178 ),
179 impact: "Serialized memory transactions, reduced bandwidth".to_string(),
180 line: None,
181 suggestion: Some(
182 "Ensure adjacent threads access adjacent memory addresses".to_string(),
183 ),
184 })
185 } else {
186 None
187 }
188 }
189
190 fn detect_register_pressure(&self, registers: &RegisterUsage) -> Option<MudaWarning> {
192 let total = registers.total();
193 if total > self.register_warning_threshold {
194 Some(MudaWarning {
195 muda_type: MudaType::Overprocessing,
196 description: format!(
197 "High register usage: {} registers (threshold: {})",
198 total, self.register_warning_threshold
199 ),
200 impact: "Reduced occupancy, fewer concurrent warps".to_string(),
201 line: None,
202 suggestion: Some(
203 "Consider loop tiling or reducing intermediate values".to_string(),
204 ),
205 })
206 } else {
207 None
208 }
209 }
210}
211
212impl Analyzer for PtxAnalyzer {
213 fn target_name(&self) -> &str {
214 "PTX"
215 }
216
217 fn analyze(&self, ptx: &str) -> Result<AnalysisReport> {
218 let registers = self.parse_registers(ptx);
219 let memory = self.parse_memory_ops(ptx);
220 let instruction_count = self.count_instructions(ptx);
221 let name = self.extract_kernel_name(ptx);
222 let warnings = self.detect_muda(ptx);
223 let estimated_occupancy = registers.estimated_occupancy();
224
225 let mut report = AnalysisReport {
226 name,
227 target: self.target_name().to_string(),
228 registers,
229 memory,
230 warnings,
231 instruction_count,
232 estimated_occupancy,
233 ..Default::default()
234 };
235
236 report.roofline = self.estimate_roofline(&report);
237 Ok(report)
238 }
239
240 fn detect_muda(&self, ptx: &str) -> Vec<MudaWarning> {
241 let mut warnings = Vec::new();
242
243 if let Some(w) = self.detect_spills(ptx) {
244 warnings.push(w);
245 }
246
247 let memory = self.parse_memory_ops(ptx);
248 if let Some(w) = self.detect_uncoalesced(&memory) {
249 warnings.push(w);
250 }
251
252 let registers = self.parse_registers(ptx);
253 if let Some(w) = self.detect_register_pressure(®isters) {
254 warnings.push(w);
255 }
256
257 warnings
258 }
259
260 fn estimate_roofline(&self, analysis: &AnalysisReport) -> RooflineMetric {
261 let mem_ops = analysis.memory.global_loads + analysis.memory.global_stores;
264 let bytes = mem_ops * 4; let flops = analysis.instruction_count; let arithmetic_intensity = if bytes > 0 {
269 flops as f32 / bytes as f32
270 } else {
271 0.0
272 };
273
274 let theoretical_peak_gflops = 15000.0;
276
277 let memory_bound = arithmetic_intensity < 10.0;
279
280 RooflineMetric {
281 arithmetic_intensity,
282 theoretical_peak_gflops,
283 memory_bound,
284 }
285 }
286}
287
288#[cfg(test)]
289mod tests;
290
291#[cfg(test)]
292mod property_tests;