1use crate::analyzer::{
6 AnalysisReport, Analyzer, MemoryPattern, MudaType, MudaWarning, RegisterUsage, RooflineMetric,
7};
8use crate::error::Result;
9use regex::Regex;
10
/// Dimensions of a compute shader's `@workgroup_size(x, y, z)` attribute.
///
/// Note that the derived `Default` is `0x0x0`; the WGSL parser in this module
/// fills omitted dimensions with 1 instead.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WorkgroupSize {
    /// Invocations along the X axis.
    pub x: u32,
    /// Invocations along the Y axis.
    pub y: u32,
    /// Invocations along the Z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// Warp width, in threads, used by [`Self::is_warp_aligned`].
    pub const WARP_SIZE: u32 = 32;

    /// Total number of invocations in one workgroup (`x * y * z`).
    ///
    /// Saturates at `u32::MAX` instead of overflowing: the dimensions come from
    /// untrusted shader text, and a plain multiply would panic in debug builds
    /// or silently wrap in release builds.
    #[must_use]
    pub fn total(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when [`total`](Self::total) is a multiple of
    /// [`WARP_SIZE`](Self::WARP_SIZE), i.e. no warp is left partially filled.
    ///
    /// A total of 0 counts as aligned.
    #[must_use]
    pub fn is_warp_aligned(&self) -> bool {
        self.total().is_multiple_of(Self::WARP_SIZE)
    }
}
35
36#[derive(Debug, Clone, Default)]
38pub struct WgslStats {
39 pub workgroup_size: WorkgroupSize,
41 pub storage_buffers: u32,
43 pub uniform_buffers: u32,
45 pub textures: u32,
47 pub arithmetic_ops: u32,
49 pub memory_ops: u32,
51}
52
/// Heuristic static analyzer for WGSL compute shaders targeting wgpu.
///
/// Warns when the total workgroup size falls outside
/// `[min_workgroup_size, max_workgroup_size]`, and when a workgroup of more
/// than one thread is not a multiple of the 32-thread warp width.
#[derive(Debug, Clone)]
pub struct WgpuAnalyzer {
    /// Smallest recommended total workgroup size, in threads (default 64).
    pub min_workgroup_size: u32,
    /// Largest recommended total workgroup size, in threads (default 1024).
    pub max_workgroup_size: u32,
}

impl Default for WgpuAnalyzer {
    /// Thresholds of 64 (minimum) and 1024 (maximum) threads per workgroup.
    fn default() -> Self {
        Self {
            min_workgroup_size: 64,
            max_workgroup_size: 1024,
        }
    }
}
69
70impl WgpuAnalyzer {
71 #[must_use]
73 pub fn new() -> Self {
74 Self::default()
75 }
76
77 fn parse_workgroup_size(&self, wgsl: &str) -> WorkgroupSize {
79 let pattern =
81 Regex::new(r"@workgroup_size\s*\(\s*(\d+)(?:\s*,\s*(\d+))?(?:\s*,\s*(\d+))?\s*\)")
82 .expect("valid regex pattern");
83
84 if let Some(caps) = pattern.captures(wgsl) {
85 let x = caps.get(1).map_or(1, |m| m.as_str().parse().unwrap_or(1));
86 let y = caps.get(2).map_or(1, |m| m.as_str().parse().unwrap_or(1));
87 let z = caps.get(3).map_or(1, |m| m.as_str().parse().unwrap_or(1));
88 WorkgroupSize { x, y, z }
89 } else {
90 WorkgroupSize { x: 1, y: 1, z: 1 }
91 }
92 }
93
94 fn count_bindings(&self, wgsl: &str) -> (u32, u32, u32) {
96 let storage_pattern = Regex::new(r"var<storage").expect("valid regex pattern");
97 let uniform_pattern = Regex::new(r"var<uniform>").expect("valid regex pattern");
98 let texture_pattern = Regex::new(r"texture_\w+<").expect("valid regex pattern");
99
100 let storage = storage_pattern.find_iter(wgsl).count() as u32;
101 let uniform = uniform_pattern.find_iter(wgsl).count() as u32;
102 let textures = texture_pattern.find_iter(wgsl).count() as u32;
103
104 (storage, uniform, textures)
105 }
106
107 fn count_operations(&self, wgsl: &str) -> (u32, u32) {
109 let arith_pattern =
111 Regex::new(r"(\+|-|\*|/|dot|cross|normalize|length|sqrt|pow|exp|log|sin|cos|tan)")
112 .expect("valid regex pattern");
113 let mem_pattern = Regex::new(r"(\[[\w\s+\-*/]+\]|textureLoad|textureSample|textureStore)")
115 .expect("valid regex pattern");
116
117 let arith = arith_pattern.find_iter(wgsl).count() as u32;
118 let mem = mem_pattern.find_iter(wgsl).count() as u32;
119
120 (arith, mem)
121 }
122
123 fn analyze_wgsl(&self, wgsl: &str) -> WgslStats {
125 let workgroup_size = self.parse_workgroup_size(wgsl);
126 let (storage_buffers, uniform_buffers, textures) = self.count_bindings(wgsl);
127 let (arithmetic_ops, memory_ops) = self.count_operations(wgsl);
128
129 WgslStats {
130 workgroup_size,
131 storage_buffers,
132 uniform_buffers,
133 textures,
134 arithmetic_ops,
135 memory_ops,
136 }
137 }
138
139 fn detect_wgsl_muda(&self, stats: &WgslStats) -> Vec<MudaWarning> {
141 let mut warnings = Vec::new();
142
143 let total = stats.workgroup_size.total();
145
146 if total < self.min_workgroup_size {
147 warnings.push(MudaWarning {
148 muda_type: MudaType::Waiting,
149 description: format!(
150 "Small workgroup size: {} threads (minimum recommended: {})",
151 total, self.min_workgroup_size
152 ),
153 impact: "Low GPU occupancy, potential for underutilization".to_string(),
154 line: None,
155 suggestion: Some(format!(
156 "Consider increasing workgroup size to at least {}",
157 self.min_workgroup_size
158 )),
159 });
160 }
161
162 if total > self.max_workgroup_size {
163 warnings.push(MudaWarning {
164 muda_type: MudaType::Overprocessing,
165 description: format!(
166 "Large workgroup size: {} threads (maximum recommended: {})",
167 total, self.max_workgroup_size
168 ),
169 impact: "May cause register pressure and reduce occupancy".to_string(),
170 line: None,
171 suggestion: Some(format!(
172 "Consider reducing workgroup size to at most {}",
173 self.max_workgroup_size
174 )),
175 });
176 }
177
178 if !stats.workgroup_size.is_warp_aligned() && total > 1 {
179 warnings.push(MudaWarning {
180 muda_type: MudaType::Waiting,
181 description: format!(
182 "Workgroup size {} is not a multiple of 32 (warp size)",
183 total
184 ),
185 impact: "Partial warp execution wastes GPU cycles".to_string(),
186 line: None,
187 suggestion: Some("Align workgroup size to a multiple of 32".to_string()),
188 });
189 }
190
191 warnings
192 }
193}
194
195impl Analyzer for WgpuAnalyzer {
196 fn target_name(&self) -> &str {
197 "WGSL (wgpu)"
198 }
199
200 fn analyze(&self, wgsl: &str) -> Result<AnalysisReport> {
201 let stats = self.analyze_wgsl(wgsl);
202 let warnings = self.detect_muda(wgsl);
203
204 let instruction_count = stats.arithmetic_ops + stats.memory_ops;
206
207 let total_threads = stats.workgroup_size.total();
209 let occupancy = if total_threads >= self.min_workgroup_size {
210 (total_threads as f32 / self.max_workgroup_size as f32).min(1.0)
211 } else {
212 total_threads as f32 / self.min_workgroup_size as f32
213 };
214
215 Ok(AnalysisReport {
216 name: "wgsl_analysis".to_string(),
217 target: self.target_name().to_string(),
218 registers: RegisterUsage::default(), memory: MemoryPattern {
220 global_loads: stats.memory_ops,
221 global_stores: 0,
222 shared_loads: 0,
223 shared_stores: 0,
224 coalesced_ratio: 1.0, },
226 roofline: self.estimate_roofline(&AnalysisReport::default()),
227 warnings,
228 instruction_count,
229 estimated_occupancy: occupancy,
230 })
231 }
232
233 fn detect_muda(&self, wgsl: &str) -> Vec<MudaWarning> {
234 let stats = self.analyze_wgsl(wgsl);
235 self.detect_wgsl_muda(&stats)
236 }
237
238 fn estimate_roofline(&self, _analysis: &AnalysisReport) -> RooflineMetric {
239 RooflineMetric {
240 arithmetic_intensity: 1.0,
241 theoretical_peak_gflops: 500.0, memory_bound: true,
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the workgroup size of `wgsl` using a default analyzer.
    fn workgroup_of(wgsl: &str) -> WorkgroupSize {
        WgpuAnalyzer::new().parse_workgroup_size(wgsl)
    }

    /// Collects the muda warnings a default analyzer reports for `wgsl`.
    fn warnings_for(wgsl: &str) -> Vec<MudaWarning> {
        WgpuAnalyzer::new().detect_muda(wgsl)
    }

    /// True when any warning's description contains `needle`.
    fn mentions(warnings: &[MudaWarning], needle: &str) -> bool {
        warnings.iter().any(|w| w.description.contains(needle))
    }

    #[test]
    fn test_parse_workgroup_size_1d() {
        let parsed = workgroup_of("@workgroup_size(64)\nfn main() {}");
        assert_eq!(parsed, WorkgroupSize { x: 64, y: 1, z: 1 });
        assert_eq!(parsed.total(), 64);
    }

    #[test]
    fn test_parse_workgroup_size_2d() {
        let parsed = workgroup_of("@workgroup_size(8, 8)\nfn main() {}");
        assert_eq!(parsed, WorkgroupSize { x: 8, y: 8, z: 1 });
        assert_eq!(parsed.total(), 64);
    }

    #[test]
    fn test_parse_workgroup_size_3d() {
        let parsed = workgroup_of("@workgroup_size(4, 4, 4)\nfn main() {}");
        assert_eq!(parsed, WorkgroupSize { x: 4, y: 4, z: 4 });
        assert_eq!(parsed.total(), 64);
    }

    #[test]
    fn test_parse_workgroup_size_missing() {
        assert_eq!(workgroup_of("fn main() {}").total(), 1);
    }

    #[test]
    fn test_warp_aligned() {
        let cases = [
            ((64, 1, 1), true),
            ((8, 8, 1), true),
            ((256, 1, 1), true),
            ((33, 1, 1), false),
            ((7, 7, 1), false),
        ];
        for ((x, y, z), expected) in cases {
            let size = WorkgroupSize { x, y, z };
            assert_eq!(size.is_warp_aligned(), expected, "{size:?}");
        }
    }

    #[test]
    fn test_count_bindings() {
        let wgsl = r"
        @group(0) @binding(0) var<storage, read> input: array<f32>;
        @group(0) @binding(1) var<storage, read_write> output: array<f32>;
        @group(0) @binding(2) var<uniform> params: Params;
        ";
        let counts = WgpuAnalyzer::new().count_bindings(wgsl);
        assert_eq!(counts, (2, 1, 0));
    }

    #[test]
    fn test_detect_small_workgroup() {
        let found = warnings_for("@workgroup_size(8)\nfn main() {}");
        assert!(!found.is_empty(), "Should warn on small workgroup");
        assert!(mentions(&found, "Small workgroup"));
    }

    #[test]
    fn test_detect_non_warp_aligned() {
        let found = warnings_for("@workgroup_size(33)\nfn main() {}");
        assert!(mentions(&found, "not a multiple of 32"));
    }

    #[test]
    fn test_optimal_workgroup_no_warnings() {
        let found = warnings_for("@workgroup_size(256)\nfn main() {}");
        assert!(
            found.is_empty(),
            "Optimal workgroup should have no warnings"
        );
    }

    #[test]
    fn f067_detect_workgroup_size() {
        let wgsl = r"
        @compute @workgroup_size(64, 4, 1)
        fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
            // compute work
        }
        ";
        let analyzer = WgpuAnalyzer::new();

        let report = analyzer.analyze(wgsl).unwrap();
        assert_eq!(report.target, "WGSL (wgpu)");

        let size = analyzer.analyze_wgsl(wgsl).workgroup_size;
        assert_eq!(size, WorkgroupSize { x: 64, y: 4, z: 1 });
        assert_eq!(size.total(), 256);
    }

    #[test]
    fn test_analyze_full_wgsl() {
        let wgsl = r"
        struct Params {
            size: u32,
        }

        @group(0) @binding(0) var<storage, read> a: array<f32>;
        @group(0) @binding(1) var<storage, read> b: array<f32>;
        @group(0) @binding(2) var<storage, read_write> result: array<f32>;
        @group(0) @binding(3) var<uniform> params: Params;

        @compute @workgroup_size(256)
        fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
            let idx = gid.x;
            if idx < params.size {
                result[idx] = a[idx] + b[idx];
            }
        }
        ";
        let report = WgpuAnalyzer::new().analyze(wgsl).unwrap();

        assert_eq!(report.target, "WGSL (wgpu)");
        assert!(
            report.warnings.is_empty(),
            "Valid WGSL should have no warnings"
        );
    }
}