Skip to main content

trueno_explain/ptx/parser/
mod.rs

1//! PTX Parser and Analyzer
2//!
3//! Implements the Analyzer trait for NVIDIA PTX assembly.
4
use crate::analyzer::{
    AnalysisReport, Analyzer, MemoryPattern, MudaType, MudaWarning, RegisterUsage, RooflineMetric,
};
use crate::error::Result;
use regex::Regex;
use std::sync::OnceLock;
10
/// PTX code analyzer.
///
/// Holds the thresholds used when turning statistics parsed from PTX text into
/// [`MudaWarning`]s. Build one with [`PtxAnalyzer::new`] or [`Default`] for the
/// stock thresholds, or set the public fields directly.
#[derive(Debug, Clone, PartialEq)]
pub struct PtxAnalyzer {
    /// Warn if the total declared register count exceeds this threshold.
    pub register_warning_threshold: u32,
    /// Warn if the estimated coalescing ratio (0.0..=1.0) falls below this threshold.
    pub coalescing_warning_threshold: f32,
}
18
19impl Default for PtxAnalyzer {
20    fn default() -> Self {
21        Self {
22            register_warning_threshold: 128,
23            coalescing_warning_threshold: 0.8,
24        }
25    }
26}
27
28impl PtxAnalyzer {
29    /// Create a new PTX analyzer with default thresholds
30    #[must_use]
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Parse register declarations from PTX
36    fn parse_registers(&self, ptx: &str) -> RegisterUsage {
37        let mut usage = RegisterUsage::default();
38
39        // Match patterns like: .reg .f32 %f<24>;
40        let reg_pattern =
41            Regex::new(r"\.reg\s+\.(\w+)\s+%\w+<(\d+)>").expect("valid regex pattern");
42
43        for cap in reg_pattern.captures_iter(ptx) {
44            let reg_type = &cap[1];
45            let count: u32 = cap[2].parse().unwrap_or(0);
46
47            match reg_type {
48                "f32" => usage.f32_regs += count,
49                "f64" => usage.f64_regs += count,
50                "b32" | "u32" | "s32" => usage.b32_regs += count,
51                "b64" | "u64" | "s64" => usage.b64_regs += count,
52                "pred" => usage.pred_regs += count,
53                _ => {}
54            }
55        }
56
57        usage
58    }
59
60    /// Parse memory operations from PTX
61    fn parse_memory_ops(&self, ptx: &str) -> MemoryPattern {
62        let mut pattern = MemoryPattern::default();
63
64        // Count global loads
65        let global_load = Regex::new(r"ld\.global").expect("valid regex pattern");
66        pattern.global_loads = global_load.find_iter(ptx).count() as u32;
67
68        // Count global stores
69        let global_store = Regex::new(r"st\.global").expect("valid regex pattern");
70        pattern.global_stores = global_store.find_iter(ptx).count() as u32;
71
72        // Count shared loads
73        let shared_load = Regex::new(r"ld\.shared").expect("valid regex pattern");
74        pattern.shared_loads = shared_load.find_iter(ptx).count() as u32;
75
76        // Count shared stores
77        let shared_store = Regex::new(r"st\.shared").expect("valid regex pattern");
78        pattern.shared_stores = shared_store.find_iter(ptx).count() as u32;
79
80        // Estimate coalescing based on access patterns
81        // Coalesced access indicators:
82        // 1. tid/ctaid references (thread and block IDs - used for index computation)
83        // 2. mad.lo with tid (computing linear index from thread/block IDs)
84        // 3. mul.wide with small constant (stride-1 access)
85        // 4. shfl instructions (warp shuffle - implicit coalescing)
86        // Note: Include both x and y dimensions since 2D kernels use both
87        let tid_pattern =
88            Regex::new(r"%tid\.[xy]|%ntid\.[xy]|%ctaid\.[xy]").expect("valid regex pattern");
89        let tid_refs = tid_pattern.find_iter(ptx).count();
90
91        // mad.lo often computes coalesced indices: mad.lo %r, %ctaid, %ntid, %tid
92        let mad_pattern = Regex::new(r"mad\.lo").expect("valid regex pattern");
93        let mad_refs = mad_pattern.find_iter(ptx).count();
94
95        // mul.lo also used for index computation
96        let mul_lo_pattern = Regex::new(r"mul\.lo").expect("valid regex pattern");
97        let mul_lo_refs = mul_lo_pattern.find_iter(ptx).count();
98
99        // mul.wide with small constants indicates stride-based access
100        let stride_pattern = Regex::new(r"mul\.wide\.[us]32").expect("valid regex pattern");
101        let stride_refs = stride_pattern.find_iter(ptx).count();
102
103        // Warp shuffles indicate warp-level data sharing (inherently coalesced)
104        let shfl_pattern = Regex::new(r"shfl\.(down|up|bfly|idx)").expect("valid regex pattern");
105        let shfl_refs = shfl_pattern.find_iter(ptx).count();
106
107        // rem/div operations often used for lane computation in coalesced patterns
108        let lane_pattern = Regex::new(r"rem\.u32|div\.u32").expect("valid regex pattern");
109        let lane_refs = lane_pattern.find_iter(ptx).count();
110
111        let total_accesses = pattern.global_loads + pattern.global_stores;
112        if total_accesses > 0 {
113            // Improved heuristic: weight different indicators
114            // Each indicator suggests thread-based indexing which implies coalescing potential
115            let coalescing_score = tid_refs as f32
116                + (mad_refs as f32 * 0.6)  // mad.lo strongly indicates index computation
117                + (mul_lo_refs as f32 * 0.4) // mul.lo also used for indices
118                + (stride_refs as f32 * 0.3) // stride patterns
119                + (shfl_refs as f32 * 0.3)  // warp shuffles
120                + (lane_refs as f32 * 0.2); // lane computation
121            pattern.coalesced_ratio = (coalescing_score / total_accesses as f32).min(1.0);
122        } else {
123            pattern.coalesced_ratio = 1.0;
124        }
125
126        pattern
127    }
128
129    /// Count total instructions
130    fn count_instructions(&self, ptx: &str) -> u32 {
131        // Count lines that look like instructions (not directives or labels)
132        let instruction_pattern = Regex::new(r"^\s+(add|sub|mul|div|mad|fma|ld|st|mov|setp|bra|ret|cvt|and|or|xor|shl|shr|min|max|abs|neg|sqrt|rsqrt|sin|cos|ex2|lg2|rcp|selp|set|bar)").expect("valid regex pattern");
133
134        ptx.lines()
135            .filter(|line| instruction_pattern.is_match(line))
136            .count() as u32
137    }
138
139    /// Extract kernel name from PTX
140    fn extract_kernel_name(&self, ptx: &str) -> String {
141        let entry_pattern = Regex::new(r"\.entry\s+(\w+)").expect("valid regex pattern");
142        entry_pattern
143            .captures(ptx)
144            .map(|c| c[1].to_string())
145            .unwrap_or_else(|| "unknown".to_string())
146    }
147
148    /// Detect spills (Muda of Transport)
149    fn detect_spills(&self, ptx: &str) -> Option<MudaWarning> {
150        // Spills manifest as .local memory usage
151        let local_pattern = Regex::new(r"\.local").expect("valid regex pattern");
152        let spill_count = local_pattern.find_iter(ptx).count();
153
154        if spill_count > 0 {
155            Some(MudaWarning {
156                muda_type: MudaType::Transport,
157                description: format!("{} potential register spills detected", spill_count),
158                impact: "High latency local memory access".to_string(),
159                line: None,
160                suggestion: Some(
161                    "Reduce live variables or increase register allocation".to_string(),
162                ),
163            })
164        } else {
165            None
166        }
167    }
168
169    /// Detect uncoalesced access (Muda of Waiting)
170    fn detect_uncoalesced(&self, memory: &MemoryPattern) -> Option<MudaWarning> {
171        if memory.coalesced_ratio < self.coalescing_warning_threshold {
172            Some(MudaWarning {
173                muda_type: MudaType::Waiting,
174                description: format!(
175                    "Memory coalescing ratio {:.1}% below threshold {:.1}%",
176                    memory.coalesced_ratio * 100.0,
177                    self.coalescing_warning_threshold * 100.0
178                ),
179                impact: "Serialized memory transactions, reduced bandwidth".to_string(),
180                line: None,
181                suggestion: Some(
182                    "Ensure adjacent threads access adjacent memory addresses".to_string(),
183                ),
184            })
185        } else {
186            None
187        }
188    }
189
190    /// Detect excessive register usage
191    fn detect_register_pressure(&self, registers: &RegisterUsage) -> Option<MudaWarning> {
192        let total = registers.total();
193        if total > self.register_warning_threshold {
194            Some(MudaWarning {
195                muda_type: MudaType::Overprocessing,
196                description: format!(
197                    "High register usage: {} registers (threshold: {})",
198                    total, self.register_warning_threshold
199                ),
200                impact: "Reduced occupancy, fewer concurrent warps".to_string(),
201                line: None,
202                suggestion: Some(
203                    "Consider loop tiling or reducing intermediate values".to_string(),
204                ),
205            })
206        } else {
207            None
208        }
209    }
210}
211
212impl Analyzer for PtxAnalyzer {
213    fn target_name(&self) -> &str {
214        "PTX"
215    }
216
217    fn analyze(&self, ptx: &str) -> Result<AnalysisReport> {
218        let registers = self.parse_registers(ptx);
219        let memory = self.parse_memory_ops(ptx);
220        let instruction_count = self.count_instructions(ptx);
221        let name = self.extract_kernel_name(ptx);
222        let warnings = self.detect_muda(ptx);
223        let estimated_occupancy = registers.estimated_occupancy();
224
225        let mut report = AnalysisReport {
226            name,
227            target: self.target_name().to_string(),
228            registers,
229            memory,
230            warnings,
231            instruction_count,
232            estimated_occupancy,
233            ..Default::default()
234        };
235
236        report.roofline = self.estimate_roofline(&report);
237        Ok(report)
238    }
239
240    fn detect_muda(&self, ptx: &str) -> Vec<MudaWarning> {
241        let mut warnings = Vec::new();
242
243        if let Some(w) = self.detect_spills(ptx) {
244            warnings.push(w);
245        }
246
247        let memory = self.parse_memory_ops(ptx);
248        if let Some(w) = self.detect_uncoalesced(&memory) {
249            warnings.push(w);
250        }
251
252        let registers = self.parse_registers(ptx);
253        if let Some(w) = self.detect_register_pressure(&registers) {
254            warnings.push(w);
255        }
256
257        warnings
258    }
259
260    fn estimate_roofline(&self, analysis: &AnalysisReport) -> RooflineMetric {
261        // Simplified roofline model
262        // Arithmetic intensity = FLOPs / Bytes transferred
263        let mem_ops = analysis.memory.global_loads + analysis.memory.global_stores;
264        let bytes = mem_ops * 4; // Assume f32
265
266        let flops = analysis.instruction_count; // Rough approximation
267
268        let arithmetic_intensity = if bytes > 0 {
269            flops as f32 / bytes as f32
270        } else {
271            0.0
272        };
273
274        // SM 7.0 theoretical peak: ~15 TFLOPS (varies by GPU)
275        let theoretical_peak_gflops = 15000.0;
276
277        // Memory bound if AI < ridge point (typically ~10 for modern GPUs)
278        let memory_bound = arithmetic_intensity < 10.0;
279
280        RooflineMetric {
281            arithmetic_intensity,
282            theoretical_peak_gflops,
283            memory_bound,
284        }
285    }
286}
287
288#[cfg(test)]
289mod tests;
290
291#[cfg(test)]
292mod property_tests;