Skip to main content

trueno_explain/
wgpu.rs

1//! wgpu/WGSL shader analyzer
2//!
3//! Analyzes WGSL compute shaders for workgroup configuration and potential issues.
4
5use crate::analyzer::{
6    AnalysisReport, Analyzer, MemoryPattern, MudaType, MudaWarning, RegisterUsage, RooflineMetric,
7};
8use crate::error::Result;
9use regex::Regex;
10
/// Workgroup size configuration, as declared by a WGSL `@workgroup_size` attribute.
///
/// The derived [`Default`] is `0×0×0`; the analyzer itself uses `1` for any
/// dimension that is missing from the attribute.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WorkgroupSize {
    /// X dimension
    pub x: u32,
    /// Y dimension
    pub y: u32,
    /// Z dimension
    pub z: u32,
}

impl WorkgroupSize {
    /// Thread-count granularity checked by [`Self::is_warp_aligned`].
    ///
    /// 32 is the warp width this analyzer assumes; hardware that schedules in
    /// other widths makes this a heuristic rather than a guarantee.
    const WARP_SIZE: u32 = 32;

    /// Total threads per workgroup (`x * y * z`).
    ///
    /// Saturates at `u32::MAX` instead of overflowing. Sizes come from
    /// shader text, so values like `@workgroup_size(65536, 65536)` must
    /// neither panic (debug builds) nor silently wrap to a small number
    /// (release builds).
    #[must_use]
    pub fn total(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Check if the total thread count is a multiple of 32 (warp efficiency).
    ///
    /// A total of `0` counts as aligned, since `0` is a multiple of every integer.
    #[must_use]
    pub fn is_warp_aligned(&self) -> bool {
        self.total().is_multiple_of(Self::WARP_SIZE)
    }
}
35
36/// WGSL shader statistics
37#[derive(Debug, Clone, Default)]
38pub struct WgslStats {
39    /// Workgroup size from `@workgroup_size` attribute
40    pub workgroup_size: WorkgroupSize,
41    /// Number of storage buffer bindings
42    pub storage_buffers: u32,
43    /// Number of uniform buffer bindings
44    pub uniform_buffers: u32,
45    /// Number of texture bindings
46    pub textures: u32,
47    /// Number of arithmetic operations
48    pub arithmetic_ops: u32,
49    /// Number of memory operations
50    pub memory_ops: u32,
51}
52
/// Analyzer for WGSL compute shaders targeting wgpu.
///
/// The two thresholds are plain public fields, so callers can tune them after
/// construction, e.g. `WgpuAnalyzer { min_workgroup_size: 128, ..Default::default() }`.
pub struct WgpuAnalyzer {
    /// Workgroups with fewer threads than this trigger an efficiency warning
    pub min_workgroup_size: u32,
    /// Workgroups with more threads than this trigger an occupancy warning
    pub max_workgroup_size: u32,
}

impl Default for WgpuAnalyzer {
    /// Thresholds of 64 (minimum) and 1024 (maximum) threads per workgroup.
    ///
    /// NOTE(review): WebGPU's default `maxComputeInvocationsPerWorkgroup` limit
    /// is 256; sizes above that need an adapter that grants a higher limit.
    /// Confirm 1024 is the intended warning threshold.
    fn default() -> Self {
        const MIN_THREADS: u32 = 64;
        const MAX_THREADS: u32 = 1024;
        Self {
            max_workgroup_size: MAX_THREADS,
            min_workgroup_size: MIN_THREADS,
        }
    }
}
69
70impl WgpuAnalyzer {
71    /// Create a new WGSL analyzer
72    #[must_use]
73    pub fn new() -> Self {
74        Self::default()
75    }
76
77    /// Parse workgroup size from WGSL
78    fn parse_workgroup_size(&self, wgsl: &str) -> WorkgroupSize {
79        // Match @workgroup_size(x), @workgroup_size(x, y), or @workgroup_size(x, y, z)
80        let pattern =
81            Regex::new(r"@workgroup_size\s*\(\s*(\d+)(?:\s*,\s*(\d+))?(?:\s*,\s*(\d+))?\s*\)")
82                .expect("valid regex pattern");
83
84        if let Some(caps) = pattern.captures(wgsl) {
85            let x = caps.get(1).map_or(1, |m| m.as_str().parse().unwrap_or(1));
86            let y = caps.get(2).map_or(1, |m| m.as_str().parse().unwrap_or(1));
87            let z = caps.get(3).map_or(1, |m| m.as_str().parse().unwrap_or(1));
88            WorkgroupSize { x, y, z }
89        } else {
90            WorkgroupSize { x: 1, y: 1, z: 1 }
91        }
92    }
93
94    /// Count bindings in WGSL
95    fn count_bindings(&self, wgsl: &str) -> (u32, u32, u32) {
96        let storage_pattern = Regex::new(r"var<storage").expect("valid regex pattern");
97        let uniform_pattern = Regex::new(r"var<uniform>").expect("valid regex pattern");
98        let texture_pattern = Regex::new(r"texture_\w+<").expect("valid regex pattern");
99
100        let storage = storage_pattern.find_iter(wgsl).count() as u32;
101        let uniform = uniform_pattern.find_iter(wgsl).count() as u32;
102        let textures = texture_pattern.find_iter(wgsl).count() as u32;
103
104        (storage, uniform, textures)
105    }
106
107    /// Count operations in WGSL
108    fn count_operations(&self, wgsl: &str) -> (u32, u32) {
109        // Arithmetic: +, -, *, /, dot, cross, etc.
110        let arith_pattern =
111            Regex::new(r"(\+|-|\*|/|dot|cross|normalize|length|sqrt|pow|exp|log|sin|cos|tan)")
112                .expect("valid regex pattern");
113        // Memory: load, store, array access
114        let mem_pattern = Regex::new(r"(\[[\w\s+\-*/]+\]|textureLoad|textureSample|textureStore)")
115            .expect("valid regex pattern");
116
117        let arith = arith_pattern.find_iter(wgsl).count() as u32;
118        let mem = mem_pattern.find_iter(wgsl).count() as u32;
119
120        (arith, mem)
121    }
122
123    /// Analyze WGSL shader
124    fn analyze_wgsl(&self, wgsl: &str) -> WgslStats {
125        let workgroup_size = self.parse_workgroup_size(wgsl);
126        let (storage_buffers, uniform_buffers, textures) = self.count_bindings(wgsl);
127        let (arithmetic_ops, memory_ops) = self.count_operations(wgsl);
128
129        WgslStats {
130            workgroup_size,
131            storage_buffers,
132            uniform_buffers,
133            textures,
134            arithmetic_ops,
135            memory_ops,
136        }
137    }
138
139    /// Detect potential issues in WGSL
140    fn detect_wgsl_muda(&self, stats: &WgslStats) -> Vec<MudaWarning> {
141        let mut warnings = Vec::new();
142
143        // Check workgroup size
144        let total = stats.workgroup_size.total();
145
146        if total < self.min_workgroup_size {
147            warnings.push(MudaWarning {
148                muda_type: MudaType::Waiting,
149                description: format!(
150                    "Small workgroup size: {} threads (minimum recommended: {})",
151                    total, self.min_workgroup_size
152                ),
153                impact: "Low GPU occupancy, potential for underutilization".to_string(),
154                line: None,
155                suggestion: Some(format!(
156                    "Consider increasing workgroup size to at least {}",
157                    self.min_workgroup_size
158                )),
159            });
160        }
161
162        if total > self.max_workgroup_size {
163            warnings.push(MudaWarning {
164                muda_type: MudaType::Overprocessing,
165                description: format!(
166                    "Large workgroup size: {} threads (maximum recommended: {})",
167                    total, self.max_workgroup_size
168                ),
169                impact: "May cause register pressure and reduce occupancy".to_string(),
170                line: None,
171                suggestion: Some(format!(
172                    "Consider reducing workgroup size to at most {}",
173                    self.max_workgroup_size
174                )),
175            });
176        }
177
178        if !stats.workgroup_size.is_warp_aligned() && total > 1 {
179            warnings.push(MudaWarning {
180                muda_type: MudaType::Waiting,
181                description: format!(
182                    "Workgroup size {} is not a multiple of 32 (warp size)",
183                    total
184                ),
185                impact: "Partial warp execution wastes GPU cycles".to_string(),
186                line: None,
187                suggestion: Some("Align workgroup size to a multiple of 32".to_string()),
188            });
189        }
190
191        warnings
192    }
193}
194
195impl Analyzer for WgpuAnalyzer {
196    fn target_name(&self) -> &str {
197        "WGSL (wgpu)"
198    }
199
200    fn analyze(&self, wgsl: &str) -> Result<AnalysisReport> {
201        let stats = self.analyze_wgsl(wgsl);
202        let warnings = self.detect_muda(wgsl);
203
204        // Estimate instruction count from operations
205        let instruction_count = stats.arithmetic_ops + stats.memory_ops;
206
207        // Estimate "occupancy" based on workgroup efficiency
208        let total_threads = stats.workgroup_size.total();
209        let occupancy = if total_threads >= self.min_workgroup_size {
210            (total_threads as f32 / self.max_workgroup_size as f32).min(1.0)
211        } else {
212            total_threads as f32 / self.min_workgroup_size as f32
213        };
214
215        Ok(AnalysisReport {
216            name: "wgsl_analysis".to_string(),
217            target: self.target_name().to_string(),
218            registers: RegisterUsage::default(), // WGSL doesn't expose register info
219            memory: MemoryPattern {
220                global_loads: stats.memory_ops,
221                global_stores: 0,
222                shared_loads: 0,
223                shared_stores: 0,
224                coalesced_ratio: 1.0, // Assume coalesced by default
225            },
226            roofline: self.estimate_roofline(&AnalysisReport::default()),
227            warnings,
228            instruction_count,
229            estimated_occupancy: occupancy,
230        })
231    }
232
233    fn detect_muda(&self, wgsl: &str) -> Vec<MudaWarning> {
234        let stats = self.analyze_wgsl(wgsl);
235        self.detect_wgsl_muda(&stats)
236    }
237
238    fn estimate_roofline(&self, _analysis: &AnalysisReport) -> RooflineMetric {
239        RooflineMetric {
240            arithmetic_intensity: 1.0,
241            theoretical_peak_gflops: 500.0, // Placeholder for wgpu
242            memory_bound: true,
243        }
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse only the `@workgroup_size` attribute with a default analyzer.
    fn workgroup_of(wgsl: &str) -> WorkgroupSize {
        WgpuAnalyzer::new().parse_workgroup_size(wgsl)
    }

    /// Run muda detection with a default analyzer.
    fn muda_of(wgsl: &str) -> Vec<MudaWarning> {
        WgpuAnalyzer::new().detect_muda(wgsl)
    }

    /// Whether any warning description contains `needle`.
    fn any_mentions(warnings: &[MudaWarning], needle: &str) -> bool {
        warnings.iter().any(|w| w.description.contains(needle))
    }

    #[test]
    fn test_parse_workgroup_size_1d() {
        let size = workgroup_of("@workgroup_size(64)\nfn main() {}");
        assert_eq!(size, WorkgroupSize { x: 64, y: 1, z: 1 });
        assert_eq!(size.total(), 64);
    }

    #[test]
    fn test_parse_workgroup_size_2d() {
        let size = workgroup_of("@workgroup_size(8, 8)\nfn main() {}");
        assert_eq!(size, WorkgroupSize { x: 8, y: 8, z: 1 });
        assert_eq!(size.total(), 64);
    }

    #[test]
    fn test_parse_workgroup_size_3d() {
        let size = workgroup_of("@workgroup_size(4, 4, 4)\nfn main() {}");
        assert_eq!(size, WorkgroupSize { x: 4, y: 4, z: 4 });
        assert_eq!(size.total(), 64);
    }

    #[test]
    fn test_parse_workgroup_size_missing() {
        assert_eq!(workgroup_of("fn main() {}").total(), 1);
    }

    #[test]
    fn test_warp_aligned() {
        let aligned = [(64, 1, 1), (8, 8, 1), (256, 1, 1)];
        let misaligned = [(33, 1, 1), (7, 7, 1)];

        for (x, y, z) in aligned {
            assert!(WorkgroupSize { x, y, z }.is_warp_aligned());
        }
        for (x, y, z) in misaligned {
            assert!(!WorkgroupSize { x, y, z }.is_warp_aligned());
        }
    }

    #[test]
    fn test_count_bindings() {
        let wgsl = r"
            @group(0) @binding(0) var<storage, read> input: array<f32>;
            @group(0) @binding(1) var<storage, read_write> output: array<f32>;
            @group(0) @binding(2) var<uniform> params: Params;
        ";
        assert_eq!(WgpuAnalyzer::new().count_bindings(wgsl), (2, 1, 0));
    }

    #[test]
    fn test_detect_small_workgroup() {
        let warnings = muda_of("@workgroup_size(8)\nfn main() {}");

        assert!(!warnings.is_empty(), "Should warn on small workgroup");
        assert!(any_mentions(&warnings, "Small workgroup"));
    }

    #[test]
    fn test_detect_non_warp_aligned() {
        let warnings = muda_of("@workgroup_size(33)\nfn main() {}");
        assert!(any_mentions(&warnings, "not a multiple of 32"));
    }

    #[test]
    fn test_optimal_workgroup_no_warnings() {
        // 256 is warp-aligned and within bounds
        let warnings = muda_of("@workgroup_size(256)\nfn main() {}");
        assert!(
            warnings.is_empty(),
            "Optimal workgroup should have no warnings"
        );
    }

    /// F067: Detects workgroup size
    #[test]
    fn f067_detect_workgroup_size() {
        let wgsl = r"
            @compute @workgroup_size(64, 4, 1)
            fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
                // compute work
            }
        ";
        let analyzer = WgpuAnalyzer::new();

        let report = analyzer.analyze(wgsl).expect("analysis should succeed");
        assert_eq!(report.target, "WGSL (wgpu)");

        // 64 * 4 * 1 = 256 threads
        let size = analyzer.analyze_wgsl(wgsl).workgroup_size;
        assert_eq!(size, WorkgroupSize { x: 64, y: 4, z: 1 });
        assert_eq!(size.total(), 256);
    }

    #[test]
    fn test_analyze_full_wgsl() {
        let wgsl = r"
            struct Params {
                size: u32,
            }

            @group(0) @binding(0) var<storage, read> a: array<f32>;
            @group(0) @binding(1) var<storage, read> b: array<f32>;
            @group(0) @binding(2) var<storage, read_write> result: array<f32>;
            @group(0) @binding(3) var<uniform> params: Params;

            @compute @workgroup_size(256)
            fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
                let idx = gid.x;
                if idx < params.size {
                    result[idx] = a[idx] + b[idx];
                }
            }
        ";

        let report = WgpuAnalyzer::new()
            .analyze(wgsl)
            .expect("analysis should succeed");

        assert_eq!(report.target, "WGSL (wgpu)");
        assert!(
            report.warnings.is_empty(),
            "Valid WGSL should have no warnings"
        );
    }
}