1use std::collections::HashMap;
6
/// Symbolic name for a WGSL builtin value, constant, or builtin function.
///
/// The exact text each variant stands for is given by [`WgslIntrinsic::to_wgsl`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WgslIntrinsic {
    // Index builtins, one variant per vector component.
    LocalInvocationIdX,
    LocalInvocationIdY,
    LocalInvocationIdZ,
    WorkgroupIdX,
    WorkgroupIdY,
    WorkgroupIdZ,
    GlobalInvocationIdX,
    GlobalInvocationIdY,
    GlobalInvocationIdZ,
    NumWorkgroupsX,
    NumWorkgroupsY,
    NumWorkgroupsZ,

    // Workgroup dimensions; rendered as `WORKGROUP_SIZE_*` identifiers.
    WorkgroupSizeX,
    WorkgroupSizeY,
    WorkgroupSizeZ,

    // Barriers.
    WorkgroupBarrier,
    StorageBarrier,

    // Atomic functions.
    AtomicAdd,
    AtomicSub,
    AtomicMin,
    AtomicMax,
    AtomicExchange,
    AtomicCompareExchangeWeak,
    AtomicLoad,
    AtomicStore,

    // Math functions.
    Sqrt,
    InverseSqrt,
    Abs,
    Floor,
    Ceil,
    Round,
    Sin,
    Cos,
    Tan,
    Exp,
    Log,

    // More math functions.
    Pow,
    Min,
    Max,
    Clamp,
    Fma,
    Mix,

    // Subgroup builtin values (see `is_subgroup_builtin`).
    SubgroupInvocationId,
    SubgroupSize,

    // Subgroup functions: all / any / ballot / elect.
    SubgroupAll,
    SubgroupAny,
    SubgroupBallot,
    SubgroupElect,

    // Subgroup functions: shuffles and broadcasts.
    SubgroupShuffle,
    SubgroupShuffleUp,
    SubgroupShuffleDown,
    SubgroupShuffleXor,
    SubgroupBroadcast,
    SubgroupBroadcastFirst,

    // Subgroup functions: add / mul / min / max / and / or / xor.
    SubgroupAdd,
    SubgroupMul,
    SubgroupMin,
    SubgroupMax,
    SubgroupAnd,
    SubgroupOr,
    SubgroupXor,

    // Subgroup functions: inclusive / exclusive add and mul.
    SubgroupInclusiveAdd,
    SubgroupExclusiveAdd,
    SubgroupInclusiveMul,
    SubgroupExclusiveMul,
}

impl WgslIntrinsic {
    /// Returns the WGSL source text for this intrinsic.
    ///
    /// The shape of the text depends on the variant:
    /// - index builtins include the component accessor (e.g. `local_invocation_id.x`);
    /// - workgroup sizes are upper-case identifiers (`WORKGROUP_SIZE_X`, ...), presumably
    ///   constants declared by the generated shader (confirm against the code generator);
    /// - the two barriers are complete call expressions, including `()`;
    /// - everything else is a bare builtin or function name without arguments.
    pub fn to_wgsl(&self) -> &'static str {
        match *self {
            Self::LocalInvocationIdX => "local_invocation_id.x",
            Self::LocalInvocationIdY => "local_invocation_id.y",
            Self::LocalInvocationIdZ => "local_invocation_id.z",
            Self::WorkgroupIdX => "workgroup_id.x",
            Self::WorkgroupIdY => "workgroup_id.y",
            Self::WorkgroupIdZ => "workgroup_id.z",
            Self::GlobalInvocationIdX => "global_invocation_id.x",
            Self::GlobalInvocationIdY => "global_invocation_id.y",
            Self::GlobalInvocationIdZ => "global_invocation_id.z",
            Self::NumWorkgroupsX => "num_workgroups.x",
            Self::NumWorkgroupsY => "num_workgroups.y",
            Self::NumWorkgroupsZ => "num_workgroups.z",

            Self::WorkgroupSizeX => "WORKGROUP_SIZE_X",
            Self::WorkgroupSizeY => "WORKGROUP_SIZE_Y",
            Self::WorkgroupSizeZ => "WORKGROUP_SIZE_Z",

            // Unlike every other function-like variant, barriers carry their call parentheses.
            Self::WorkgroupBarrier => "workgroupBarrier()",
            Self::StorageBarrier => "storageBarrier()",

            Self::AtomicAdd => "atomicAdd",
            Self::AtomicSub => "atomicSub",
            Self::AtomicMin => "atomicMin",
            Self::AtomicMax => "atomicMax",
            Self::AtomicExchange => "atomicExchange",
            Self::AtomicCompareExchangeWeak => "atomicCompareExchangeWeak",
            Self::AtomicLoad => "atomicLoad",
            Self::AtomicStore => "atomicStore",

            Self::Sqrt => "sqrt",
            Self::InverseSqrt => "inverseSqrt",
            Self::Abs => "abs",
            Self::Floor => "floor",
            Self::Ceil => "ceil",
            Self::Round => "round",
            Self::Sin => "sin",
            Self::Cos => "cos",
            Self::Tan => "tan",
            Self::Exp => "exp",
            Self::Log => "log",
            Self::Pow => "pow",
            Self::Min => "min",
            Self::Max => "max",
            Self::Clamp => "clamp",
            Self::Fma => "fma",
            Self::Mix => "mix",

            Self::SubgroupInvocationId => "subgroup_invocation_id",
            Self::SubgroupSize => "subgroup_size",

            Self::SubgroupAll => "subgroupAll",
            Self::SubgroupAny => "subgroupAny",
            Self::SubgroupBallot => "subgroupBallot",
            Self::SubgroupElect => "subgroupElect",

            Self::SubgroupShuffle => "subgroupShuffle",
            Self::SubgroupShuffleUp => "subgroupShuffleUp",
            Self::SubgroupShuffleDown => "subgroupShuffleDown",
            Self::SubgroupShuffleXor => "subgroupShuffleXor",
            Self::SubgroupBroadcast => "subgroupBroadcast",
            Self::SubgroupBroadcastFirst => "subgroupBroadcastFirst",

            Self::SubgroupAdd => "subgroupAdd",
            Self::SubgroupMul => "subgroupMul",
            Self::SubgroupMin => "subgroupMin",
            Self::SubgroupMax => "subgroupMax",
            Self::SubgroupAnd => "subgroupAnd",
            Self::SubgroupOr => "subgroupOr",
            Self::SubgroupXor => "subgroupXor",

            Self::SubgroupInclusiveAdd => "subgroupInclusiveAdd",
            Self::SubgroupExclusiveAdd => "subgroupExclusiveAdd",
            Self::SubgroupInclusiveMul => "subgroupInclusiveMul",
            Self::SubgroupExclusiveMul => "subgroupExclusiveMul",
        }
    }

    /// Returns `true` for every `Subgroup*` variant: the two subgroup builtin values
    /// plus all subgroup functions. Every other variant yields `false`.
    pub fn requires_subgroup_extension(&self) -> bool {
        // The builtin values are covered by `is_subgroup_builtin`; the list below holds
        // only the function-like subgroup variants.
        self.is_subgroup_builtin()
            || matches!(
                self,
                Self::SubgroupAll
                    | Self::SubgroupAny
                    | Self::SubgroupBallot
                    | Self::SubgroupElect
                    | Self::SubgroupShuffle
                    | Self::SubgroupShuffleUp
                    | Self::SubgroupShuffleDown
                    | Self::SubgroupShuffleXor
                    | Self::SubgroupBroadcast
                    | Self::SubgroupBroadcastFirst
                    | Self::SubgroupAdd
                    | Self::SubgroupMul
                    | Self::SubgroupMin
                    | Self::SubgroupMax
                    | Self::SubgroupAnd
                    | Self::SubgroupOr
                    | Self::SubgroupXor
                    | Self::SubgroupInclusiveAdd
                    | Self::SubgroupExclusiveAdd
                    | Self::SubgroupInclusiveMul
                    | Self::SubgroupExclusiveMul
            )
    }

    /// Returns `true` only for [`WgslIntrinsic::SubgroupInvocationId`] and
    /// [`WgslIntrinsic::SubgroupSize`], the subgroup variants that render as plain
    /// identifiers rather than function names.
    pub fn is_subgroup_builtin(&self) -> bool {
        matches!(self, Self::SubgroupInvocationId | Self::SubgroupSize)
    }
}
233
234pub struct IntrinsicRegistry {
236 mappings: HashMap<&'static str, WgslIntrinsic>,
237}
238
239impl Default for IntrinsicRegistry {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245impl IntrinsicRegistry {
246 pub fn new() -> Self {
248 let mut mappings = HashMap::new();
249
250 mappings.insert("thread_idx_x", WgslIntrinsic::LocalInvocationIdX);
252 mappings.insert("thread_idx_y", WgslIntrinsic::LocalInvocationIdY);
253 mappings.insert("thread_idx_z", WgslIntrinsic::LocalInvocationIdZ);
254 mappings.insert("block_idx_x", WgslIntrinsic::WorkgroupIdX);
255 mappings.insert("block_idx_y", WgslIntrinsic::WorkgroupIdY);
256 mappings.insert("block_idx_z", WgslIntrinsic::WorkgroupIdZ);
257 mappings.insert("global_thread_id", WgslIntrinsic::GlobalInvocationIdX);
258 mappings.insert("global_thread_id_y", WgslIntrinsic::GlobalInvocationIdY);
259 mappings.insert("global_thread_id_z", WgslIntrinsic::GlobalInvocationIdZ);
260 mappings.insert("grid_dim_x", WgslIntrinsic::NumWorkgroupsX);
261 mappings.insert("grid_dim_y", WgslIntrinsic::NumWorkgroupsY);
262 mappings.insert("grid_dim_z", WgslIntrinsic::NumWorkgroupsZ);
263
264 mappings.insert("block_dim_x", WgslIntrinsic::WorkgroupSizeX);
266 mappings.insert("block_dim_y", WgslIntrinsic::WorkgroupSizeY);
267 mappings.insert("block_dim_z", WgslIntrinsic::WorkgroupSizeZ);
268
269 mappings.insert("sync_threads", WgslIntrinsic::WorkgroupBarrier);
271 mappings.insert("thread_fence", WgslIntrinsic::StorageBarrier);
272 mappings.insert("thread_fence_block", WgslIntrinsic::WorkgroupBarrier);
273
274 mappings.insert("atomic_add", WgslIntrinsic::AtomicAdd);
276 mappings.insert("atomic_sub", WgslIntrinsic::AtomicSub);
277 mappings.insert("atomic_min", WgslIntrinsic::AtomicMin);
278 mappings.insert("atomic_max", WgslIntrinsic::AtomicMax);
279 mappings.insert("atomic_exchange", WgslIntrinsic::AtomicExchange);
280 mappings.insert("atomic_cas", WgslIntrinsic::AtomicCompareExchangeWeak);
281 mappings.insert("atomic_load", WgslIntrinsic::AtomicLoad);
282 mappings.insert("atomic_store", WgslIntrinsic::AtomicStore);
283
284 mappings.insert("sqrt", WgslIntrinsic::Sqrt);
286 mappings.insert("rsqrt", WgslIntrinsic::InverseSqrt);
287 mappings.insert("abs", WgslIntrinsic::Abs);
288 mappings.insert("floor", WgslIntrinsic::Floor);
289 mappings.insert("ceil", WgslIntrinsic::Ceil);
290 mappings.insert("round", WgslIntrinsic::Round);
291 mappings.insert("sin", WgslIntrinsic::Sin);
292 mappings.insert("cos", WgslIntrinsic::Cos);
293 mappings.insert("tan", WgslIntrinsic::Tan);
294 mappings.insert("exp", WgslIntrinsic::Exp);
295 mappings.insert("log", WgslIntrinsic::Log);
296 mappings.insert("powf", WgslIntrinsic::Pow);
297 mappings.insert("min", WgslIntrinsic::Min);
298 mappings.insert("max", WgslIntrinsic::Max);
299 mappings.insert("clamp", WgslIntrinsic::Clamp);
300 mappings.insert("fma", WgslIntrinsic::Fma);
301 mappings.insert("mix", WgslIntrinsic::Mix);
302
303 mappings.insert("lane_id", WgslIntrinsic::SubgroupInvocationId);
305 mappings.insert("subgroup_id", WgslIntrinsic::SubgroupInvocationId);
306 mappings.insert(
307 "subgroup_invocation_id",
308 WgslIntrinsic::SubgroupInvocationId,
309 );
310 mappings.insert("warp_size", WgslIntrinsic::SubgroupSize);
311 mappings.insert("subgroup_size", WgslIntrinsic::SubgroupSize);
312
313 mappings.insert("subgroup_all", WgslIntrinsic::SubgroupAll);
315 mappings.insert("warp_all", WgslIntrinsic::SubgroupAll);
316 mappings.insert("subgroup_any", WgslIntrinsic::SubgroupAny);
317 mappings.insert("warp_any", WgslIntrinsic::SubgroupAny);
318 mappings.insert("subgroup_ballot", WgslIntrinsic::SubgroupBallot);
319 mappings.insert("warp_ballot", WgslIntrinsic::SubgroupBallot);
320 mappings.insert("subgroup_elect", WgslIntrinsic::SubgroupElect);
321 mappings.insert("warp_elect", WgslIntrinsic::SubgroupElect);
322
323 mappings.insert("subgroup_shuffle", WgslIntrinsic::SubgroupShuffle);
325 mappings.insert("warp_shuffle", WgslIntrinsic::SubgroupShuffle);
326 mappings.insert("subgroup_shuffle_up", WgslIntrinsic::SubgroupShuffleUp);
327 mappings.insert("warp_shuffle_up", WgslIntrinsic::SubgroupShuffleUp);
328 mappings.insert("subgroup_shuffle_down", WgslIntrinsic::SubgroupShuffleDown);
329 mappings.insert("warp_shuffle_down", WgslIntrinsic::SubgroupShuffleDown);
330 mappings.insert("subgroup_shuffle_xor", WgslIntrinsic::SubgroupShuffleXor);
331 mappings.insert("warp_shuffle_xor", WgslIntrinsic::SubgroupShuffleXor);
332 mappings.insert("subgroup_broadcast", WgslIntrinsic::SubgroupBroadcast);
333 mappings.insert("warp_broadcast", WgslIntrinsic::SubgroupBroadcast);
334 mappings.insert(
335 "subgroup_broadcast_first",
336 WgslIntrinsic::SubgroupBroadcastFirst,
337 );
338 mappings.insert(
339 "warp_broadcast_first",
340 WgslIntrinsic::SubgroupBroadcastFirst,
341 );
342
343 mappings.insert("subgroup_add", WgslIntrinsic::SubgroupAdd);
345 mappings.insert("warp_reduce_add", WgslIntrinsic::SubgroupAdd);
346 mappings.insert("subgroup_mul", WgslIntrinsic::SubgroupMul);
347 mappings.insert("warp_reduce_mul", WgslIntrinsic::SubgroupMul);
348 mappings.insert("subgroup_min", WgslIntrinsic::SubgroupMin);
349 mappings.insert("warp_reduce_min", WgslIntrinsic::SubgroupMin);
350 mappings.insert("subgroup_max", WgslIntrinsic::SubgroupMax);
351 mappings.insert("warp_reduce_max", WgslIntrinsic::SubgroupMax);
352 mappings.insert("subgroup_and", WgslIntrinsic::SubgroupAnd);
353 mappings.insert("warp_reduce_and", WgslIntrinsic::SubgroupAnd);
354 mappings.insert("subgroup_or", WgslIntrinsic::SubgroupOr);
355 mappings.insert("warp_reduce_or", WgslIntrinsic::SubgroupOr);
356 mappings.insert("subgroup_xor", WgslIntrinsic::SubgroupXor);
357 mappings.insert("warp_reduce_xor", WgslIntrinsic::SubgroupXor);
358
359 mappings.insert(
361 "subgroup_inclusive_add",
362 WgslIntrinsic::SubgroupInclusiveAdd,
363 );
364 mappings.insert("warp_prefix_sum", WgslIntrinsic::SubgroupInclusiveAdd);
365 mappings.insert(
366 "subgroup_exclusive_add",
367 WgslIntrinsic::SubgroupExclusiveAdd,
368 );
369 mappings.insert("warp_exclusive_sum", WgslIntrinsic::SubgroupExclusiveAdd);
370 mappings.insert(
371 "subgroup_inclusive_mul",
372 WgslIntrinsic::SubgroupInclusiveMul,
373 );
374 mappings.insert(
375 "subgroup_exclusive_mul",
376 WgslIntrinsic::SubgroupExclusiveMul,
377 );
378
379 Self { mappings }
380 }
381
382 pub fn lookup(&self, name: &str) -> Option<WgslIntrinsic> {
384 self.mappings.get(name).copied()
385 }
386
387 pub fn is_intrinsic(&self, name: &str) -> bool {
389 self.mappings.contains_key(name)
390 }
391
392 pub fn subgroup_intrinsics(&self) -> Vec<(&'static str, WgslIntrinsic)> {
394 self.mappings
395 .iter()
396 .filter(|(_, intrinsic)| intrinsic.requires_subgroup_extension())
397 .map(|(&name, &intrinsic)| (name, intrinsic))
398 .collect()
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every name in `expected` resolves to its paired intrinsic.
    fn assert_lookups(registry: &IntrinsicRegistry, expected: &[(&str, WgslIntrinsic)]) {
        for &(name, intrinsic) in expected {
            assert_eq!(
                registry.lookup(name),
                Some(intrinsic),
                "lookup of {:?}",
                name
            );
        }
    }

    /// Asserts that every intrinsic in `expected` renders to its paired WGSL text.
    fn assert_wgsl(expected: &[(WgslIntrinsic, &str)]) {
        for &(intrinsic, text) in expected {
            assert_eq!(intrinsic.to_wgsl(), text, "to_wgsl of {:?}", intrinsic);
        }
    }

    #[test]
    fn test_registry_lookup() {
        let registry = IntrinsicRegistry::new();
        assert_lookups(
            &registry,
            &[
                ("thread_idx_x", WgslIntrinsic::LocalInvocationIdX),
                ("sync_threads", WgslIntrinsic::WorkgroupBarrier),
            ],
        );
        assert_eq!(registry.lookup("unknown_function"), None);
    }

    #[test]
    fn test_intrinsic_wgsl_output() {
        assert_wgsl(&[
            (WgslIntrinsic::LocalInvocationIdX, "local_invocation_id.x"),
            (WgslIntrinsic::WorkgroupBarrier, "workgroupBarrier()"),
            (WgslIntrinsic::Sqrt, "sqrt"),
        ]);
    }

    #[test]
    fn test_subgroup_extension_detection() {
        let needs_extension = [
            WgslIntrinsic::SubgroupShuffle,
            WgslIntrinsic::SubgroupAdd,
            WgslIntrinsic::SubgroupInclusiveAdd,
        ];
        for intrinsic in needs_extension.iter() {
            assert!(intrinsic.requires_subgroup_extension());
        }

        let no_extension = [WgslIntrinsic::Sqrt, WgslIntrinsic::WorkgroupBarrier];
        for intrinsic in no_extension.iter() {
            assert!(!intrinsic.requires_subgroup_extension());
        }
    }

    #[test]
    fn test_subgroup_operations_mappings() {
        let registry = IntrinsicRegistry::new();

        // All / any / ballot / elect.
        assert_lookups(
            &registry,
            &[
                ("subgroup_all", WgslIntrinsic::SubgroupAll),
                ("warp_all", WgslIntrinsic::SubgroupAll),
                ("subgroup_any", WgslIntrinsic::SubgroupAny),
                ("subgroup_ballot", WgslIntrinsic::SubgroupBallot),
                ("subgroup_elect", WgslIntrinsic::SubgroupElect),
            ],
        );

        // Shuffles and broadcasts.
        assert_lookups(
            &registry,
            &[
                ("subgroup_shuffle", WgslIntrinsic::SubgroupShuffle),
                ("warp_shuffle", WgslIntrinsic::SubgroupShuffle),
                ("subgroup_shuffle_xor", WgslIntrinsic::SubgroupShuffleXor),
                ("subgroup_broadcast", WgslIntrinsic::SubgroupBroadcast),
                (
                    "subgroup_broadcast_first",
                    WgslIntrinsic::SubgroupBroadcastFirst,
                ),
            ],
        );

        // Reductions.
        assert_lookups(
            &registry,
            &[
                ("subgroup_add", WgslIntrinsic::SubgroupAdd),
                ("warp_reduce_add", WgslIntrinsic::SubgroupAdd),
                ("subgroup_min", WgslIntrinsic::SubgroupMin),
                ("subgroup_max", WgslIntrinsic::SubgroupMax),
            ],
        );

        // Inclusive / exclusive add.
        assert_lookups(
            &registry,
            &[
                ("subgroup_inclusive_add", WgslIntrinsic::SubgroupInclusiveAdd),
                ("warp_prefix_sum", WgslIntrinsic::SubgroupInclusiveAdd),
                ("subgroup_exclusive_add", WgslIntrinsic::SubgroupExclusiveAdd),
            ],
        );
    }

    #[test]
    fn test_subgroup_wgsl_output() {
        assert_wgsl(&[
            (WgslIntrinsic::SubgroupAll, "subgroupAll"),
            (WgslIntrinsic::SubgroupAny, "subgroupAny"),
            (WgslIntrinsic::SubgroupBallot, "subgroupBallot"),
            (WgslIntrinsic::SubgroupElect, "subgroupElect"),
            (WgslIntrinsic::SubgroupShuffle, "subgroupShuffle"),
            (WgslIntrinsic::SubgroupShuffleXor, "subgroupShuffleXor"),
            (WgslIntrinsic::SubgroupBroadcast, "subgroupBroadcast"),
            (WgslIntrinsic::SubgroupAdd, "subgroupAdd"),
            (WgslIntrinsic::SubgroupMin, "subgroupMin"),
            (WgslIntrinsic::SubgroupMax, "subgroupMax"),
            (WgslIntrinsic::SubgroupInclusiveAdd, "subgroupInclusiveAdd"),
            (WgslIntrinsic::SubgroupExclusiveAdd, "subgroupExclusiveAdd"),
            (WgslIntrinsic::SubgroupInvocationId, "subgroup_invocation_id"),
            (WgslIntrinsic::SubgroupSize, "subgroup_size"),
        ]);
    }

    #[test]
    fn test_subgroup_builtin_detection() {
        let builtins = [
            WgslIntrinsic::SubgroupInvocationId,
            WgslIntrinsic::SubgroupSize,
        ];
        for intrinsic in builtins.iter() {
            assert!(intrinsic.is_subgroup_builtin());
        }

        let functions = [WgslIntrinsic::SubgroupAdd, WgslIntrinsic::SubgroupShuffle];
        for intrinsic in functions.iter() {
            assert!(!intrinsic.is_subgroup_builtin());
        }
    }

    #[test]
    fn test_subgroup_intrinsics_list() {
        let subgroup_ops = IntrinsicRegistry::new().subgroup_intrinsics();

        // The registry carries more than twenty subgroup names once aliases are counted.
        assert!(subgroup_ops.len() > 20);

        // Nothing outside the subgroup family may leak into the list.
        for (_, intrinsic) in &subgroup_ops {
            assert!(
                intrinsic.requires_subgroup_extension(),
                "Intrinsic {:?} should require subgroup extension",
                intrinsic
            );
        }
    }
}