1use std::collections::HashMap;
6
/// A WGSL built-in value, synchronization built-in, or built-in function that a
/// source-level intrinsic can be lowered to.
///
/// [`WgslIntrinsic::to_wgsl`] gives the exact WGSL text emitted for each variant.
/// `Hash` is derived so variants can be collected into sets/maps (for example to
/// record which intrinsics a shader uses).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgslIntrinsic {
    // --- Built-in values: one variant per vector component. ---
    /// `local_invocation_id.x`: invocation index within its workgroup (x).
    LocalInvocationIdX,
    /// `local_invocation_id.y`: invocation index within its workgroup (y).
    LocalInvocationIdY,
    /// `local_invocation_id.z`: invocation index within its workgroup (z).
    LocalInvocationIdZ,
    /// `workgroup_id.x`: index of the current workgroup (x).
    WorkgroupIdX,
    /// `workgroup_id.y`: index of the current workgroup (y).
    WorkgroupIdY,
    /// `workgroup_id.z`: index of the current workgroup (z).
    WorkgroupIdZ,
    /// `global_invocation_id.x`: invocation index across the whole dispatch (x).
    GlobalInvocationIdX,
    /// `global_invocation_id.y`: invocation index across the whole dispatch (y).
    GlobalInvocationIdY,
    /// `global_invocation_id.z`: invocation index across the whole dispatch (z).
    GlobalInvocationIdZ,
    /// `num_workgroups.x`: number of workgroups dispatched (x).
    NumWorkgroupsX,
    /// `num_workgroups.y`: number of workgroups dispatched (y).
    NumWorkgroupsY,
    /// `num_workgroups.z`: number of workgroups dispatched (z).
    NumWorkgroupsZ,

    // --- Workgroup size: lowered to plain identifiers, not builtins. ---
    /// `WORKGROUP_SIZE_X`; presumably a constant the generated shader must
    /// declare (TODO confirm against the shader emitter).
    WorkgroupSizeX,
    /// `WORKGROUP_SIZE_Y`; see [`WgslIntrinsic::WorkgroupSizeX`].
    WorkgroupSizeY,
    /// `WORKGROUP_SIZE_Z`; see [`WgslIntrinsic::WorkgroupSizeX`].
    WorkgroupSizeZ,

    // --- Synchronization: emitted as complete zero-argument calls. ---
    /// `workgroupBarrier()`.
    WorkgroupBarrier,
    /// `storageBarrier()`.
    StorageBarrier,

    // --- Atomics: emitted as bare function names; caller supplies arguments. ---
    /// `atomicAdd`.
    AtomicAdd,
    /// `atomicSub`.
    AtomicSub,
    /// `atomicMin`.
    AtomicMin,
    /// `atomicMax`.
    AtomicMax,
    /// `atomicExchange`.
    AtomicExchange,
    /// `atomicCompareExchangeWeak`.
    AtomicCompareExchangeWeak,
    /// `atomicLoad`.
    AtomicLoad,
    /// `atomicStore`.
    AtomicStore,

    // --- Unary math built-ins (bare function names). ---
    /// `sqrt`.
    Sqrt,
    /// `inverseSqrt`.
    InverseSqrt,
    /// `abs`.
    Abs,
    /// `floor`.
    Floor,
    /// `ceil`.
    Ceil,
    /// `round`.
    Round,
    /// `sin`.
    Sin,
    /// `cos`.
    Cos,
    /// `tan`.
    Tan,
    /// `exp`.
    Exp,
    /// `log`.
    Log,

    // --- Multi-argument math built-ins (bare function names). ---
    /// `pow`.
    Pow,
    /// `min`.
    Min,
    /// `max`.
    Max,
    /// `clamp`.
    Clamp,
    /// `fma`.
    Fma,
    /// `mix`.
    Mix,

    // --- Subgroup operations: all require the subgroup extension. ---
    /// `subgroupShuffle`.
    SubgroupShuffle,
    /// `subgroupShuffleUp`.
    SubgroupShuffleUp,
    /// `subgroupShuffleDown`.
    SubgroupShuffleDown,
    /// `subgroupShuffleXor`.
    SubgroupShuffleXor,
    /// `subgroupBallot`.
    SubgroupBallot,
    /// `subgroupAll`.
    SubgroupAll,
    /// `subgroupAny`.
    SubgroupAny,
    /// `subgroup_invocation_id` built-in value.
    SubgroupInvocationId,
    /// `subgroup_size` built-in value.
    SubgroupSize,
}
75
76impl WgslIntrinsic {
77 pub fn to_wgsl(&self) -> &'static str {
79 match self {
80 WgslIntrinsic::LocalInvocationIdX => "local_invocation_id.x",
82 WgslIntrinsic::LocalInvocationIdY => "local_invocation_id.y",
83 WgslIntrinsic::LocalInvocationIdZ => "local_invocation_id.z",
84 WgslIntrinsic::WorkgroupIdX => "workgroup_id.x",
85 WgslIntrinsic::WorkgroupIdY => "workgroup_id.y",
86 WgslIntrinsic::WorkgroupIdZ => "workgroup_id.z",
87 WgslIntrinsic::GlobalInvocationIdX => "global_invocation_id.x",
88 WgslIntrinsic::GlobalInvocationIdY => "global_invocation_id.y",
89 WgslIntrinsic::GlobalInvocationIdZ => "global_invocation_id.z",
90 WgslIntrinsic::NumWorkgroupsX => "num_workgroups.x",
91 WgslIntrinsic::NumWorkgroupsY => "num_workgroups.y",
92 WgslIntrinsic::NumWorkgroupsZ => "num_workgroups.z",
93
94 WgslIntrinsic::WorkgroupSizeX => "WORKGROUP_SIZE_X",
96 WgslIntrinsic::WorkgroupSizeY => "WORKGROUP_SIZE_Y",
97 WgslIntrinsic::WorkgroupSizeZ => "WORKGROUP_SIZE_Z",
98
99 WgslIntrinsic::WorkgroupBarrier => "workgroupBarrier()",
101 WgslIntrinsic::StorageBarrier => "storageBarrier()",
102
103 WgslIntrinsic::AtomicAdd => "atomicAdd",
105 WgslIntrinsic::AtomicSub => "atomicSub",
106 WgslIntrinsic::AtomicMin => "atomicMin",
107 WgslIntrinsic::AtomicMax => "atomicMax",
108 WgslIntrinsic::AtomicExchange => "atomicExchange",
109 WgslIntrinsic::AtomicCompareExchangeWeak => "atomicCompareExchangeWeak",
110 WgslIntrinsic::AtomicLoad => "atomicLoad",
111 WgslIntrinsic::AtomicStore => "atomicStore",
112
113 WgslIntrinsic::Sqrt => "sqrt",
115 WgslIntrinsic::InverseSqrt => "inverseSqrt",
116 WgslIntrinsic::Abs => "abs",
117 WgslIntrinsic::Floor => "floor",
118 WgslIntrinsic::Ceil => "ceil",
119 WgslIntrinsic::Round => "round",
120 WgslIntrinsic::Sin => "sin",
121 WgslIntrinsic::Cos => "cos",
122 WgslIntrinsic::Tan => "tan",
123 WgslIntrinsic::Exp => "exp",
124 WgslIntrinsic::Log => "log",
125 WgslIntrinsic::Pow => "pow",
126 WgslIntrinsic::Min => "min",
127 WgslIntrinsic::Max => "max",
128 WgslIntrinsic::Clamp => "clamp",
129 WgslIntrinsic::Fma => "fma",
130 WgslIntrinsic::Mix => "mix",
131
132 WgslIntrinsic::SubgroupShuffle => "subgroupShuffle",
134 WgslIntrinsic::SubgroupShuffleUp => "subgroupShuffleUp",
135 WgslIntrinsic::SubgroupShuffleDown => "subgroupShuffleDown",
136 WgslIntrinsic::SubgroupShuffleXor => "subgroupShuffleXor",
137 WgslIntrinsic::SubgroupBallot => "subgroupBallot",
138 WgslIntrinsic::SubgroupAll => "subgroupAll",
139 WgslIntrinsic::SubgroupAny => "subgroupAny",
140 WgslIntrinsic::SubgroupInvocationId => "subgroup_invocation_id",
141 WgslIntrinsic::SubgroupSize => "subgroup_size",
142 }
143 }
144
145 pub fn requires_subgroup_extension(&self) -> bool {
147 matches!(
148 self,
149 WgslIntrinsic::SubgroupShuffle
150 | WgslIntrinsic::SubgroupShuffleUp
151 | WgslIntrinsic::SubgroupShuffleDown
152 | WgslIntrinsic::SubgroupShuffleXor
153 | WgslIntrinsic::SubgroupBallot
154 | WgslIntrinsic::SubgroupAll
155 | WgslIntrinsic::SubgroupAny
156 | WgslIntrinsic::SubgroupInvocationId
157 | WgslIntrinsic::SubgroupSize
158 )
159 }
160}
161
162pub struct IntrinsicRegistry {
164 mappings: HashMap<&'static str, WgslIntrinsic>,
165}
166
167impl Default for IntrinsicRegistry {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl IntrinsicRegistry {
174 pub fn new() -> Self {
176 let mut mappings = HashMap::new();
177
178 mappings.insert("thread_idx_x", WgslIntrinsic::LocalInvocationIdX);
180 mappings.insert("thread_idx_y", WgslIntrinsic::LocalInvocationIdY);
181 mappings.insert("thread_idx_z", WgslIntrinsic::LocalInvocationIdZ);
182 mappings.insert("block_idx_x", WgslIntrinsic::WorkgroupIdX);
183 mappings.insert("block_idx_y", WgslIntrinsic::WorkgroupIdY);
184 mappings.insert("block_idx_z", WgslIntrinsic::WorkgroupIdZ);
185 mappings.insert("global_thread_id", WgslIntrinsic::GlobalInvocationIdX);
186 mappings.insert("global_thread_id_y", WgslIntrinsic::GlobalInvocationIdY);
187 mappings.insert("global_thread_id_z", WgslIntrinsic::GlobalInvocationIdZ);
188 mappings.insert("grid_dim_x", WgslIntrinsic::NumWorkgroupsX);
189 mappings.insert("grid_dim_y", WgslIntrinsic::NumWorkgroupsY);
190 mappings.insert("grid_dim_z", WgslIntrinsic::NumWorkgroupsZ);
191
192 mappings.insert("block_dim_x", WgslIntrinsic::WorkgroupSizeX);
194 mappings.insert("block_dim_y", WgslIntrinsic::WorkgroupSizeY);
195 mappings.insert("block_dim_z", WgslIntrinsic::WorkgroupSizeZ);
196
197 mappings.insert("sync_threads", WgslIntrinsic::WorkgroupBarrier);
199 mappings.insert("thread_fence", WgslIntrinsic::StorageBarrier);
200 mappings.insert("thread_fence_block", WgslIntrinsic::WorkgroupBarrier);
201
202 mappings.insert("atomic_add", WgslIntrinsic::AtomicAdd);
204 mappings.insert("atomic_sub", WgslIntrinsic::AtomicSub);
205 mappings.insert("atomic_min", WgslIntrinsic::AtomicMin);
206 mappings.insert("atomic_max", WgslIntrinsic::AtomicMax);
207 mappings.insert("atomic_exchange", WgslIntrinsic::AtomicExchange);
208 mappings.insert("atomic_cas", WgslIntrinsic::AtomicCompareExchangeWeak);
209 mappings.insert("atomic_load", WgslIntrinsic::AtomicLoad);
210 mappings.insert("atomic_store", WgslIntrinsic::AtomicStore);
211
212 mappings.insert("sqrt", WgslIntrinsic::Sqrt);
214 mappings.insert("rsqrt", WgslIntrinsic::InverseSqrt);
215 mappings.insert("abs", WgslIntrinsic::Abs);
216 mappings.insert("floor", WgslIntrinsic::Floor);
217 mappings.insert("ceil", WgslIntrinsic::Ceil);
218 mappings.insert("round", WgslIntrinsic::Round);
219 mappings.insert("sin", WgslIntrinsic::Sin);
220 mappings.insert("cos", WgslIntrinsic::Cos);
221 mappings.insert("tan", WgslIntrinsic::Tan);
222 mappings.insert("exp", WgslIntrinsic::Exp);
223 mappings.insert("log", WgslIntrinsic::Log);
224 mappings.insert("powf", WgslIntrinsic::Pow);
225 mappings.insert("min", WgslIntrinsic::Min);
226 mappings.insert("max", WgslIntrinsic::Max);
227 mappings.insert("clamp", WgslIntrinsic::Clamp);
228 mappings.insert("fma", WgslIntrinsic::Fma);
229 mappings.insert("mix", WgslIntrinsic::Mix);
230
231 mappings.insert("warp_shuffle", WgslIntrinsic::SubgroupShuffle);
233 mappings.insert("warp_shuffle_up", WgslIntrinsic::SubgroupShuffleUp);
234 mappings.insert("warp_shuffle_down", WgslIntrinsic::SubgroupShuffleDown);
235 mappings.insert("warp_shuffle_xor", WgslIntrinsic::SubgroupShuffleXor);
236 mappings.insert("warp_ballot", WgslIntrinsic::SubgroupBallot);
237 mappings.insert("warp_all", WgslIntrinsic::SubgroupAll);
238 mappings.insert("warp_any", WgslIntrinsic::SubgroupAny);
239 mappings.insert("lane_id", WgslIntrinsic::SubgroupInvocationId);
240 mappings.insert("warp_size", WgslIntrinsic::SubgroupSize);
241
242 Self { mappings }
243 }
244
245 pub fn lookup(&self, name: &str) -> Option<WgslIntrinsic> {
247 self.mappings.get(name).copied()
248 }
249
250 pub fn is_intrinsic(&self, name: &str) -> bool {
252 self.mappings.contains_key(name)
253 }
254
255 pub fn subgroup_intrinsics(&self) -> Vec<(&'static str, WgslIntrinsic)> {
257 self.mappings
258 .iter()
259 .filter(|(_, intrinsic)| intrinsic.requires_subgroup_extension())
260 .map(|(&name, &intrinsic)| (name, intrinsic))
261 .collect()
262 }
263}
264
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_lookup() {
        let registry = IntrinsicRegistry::new();
        assert_eq!(
            registry.lookup("thread_idx_x"),
            Some(WgslIntrinsic::LocalInvocationIdX)
        );
        assert_eq!(
            registry.lookup("sync_threads"),
            Some(WgslIntrinsic::WorkgroupBarrier)
        );
        assert_eq!(registry.lookup("unknown_function"), None);
    }

    #[test]
    fn test_intrinsic_wgsl_output() {
        assert_eq!(
            WgslIntrinsic::LocalInvocationIdX.to_wgsl(),
            "local_invocation_id.x"
        );
        assert_eq!(
            WgslIntrinsic::WorkgroupBarrier.to_wgsl(),
            "workgroupBarrier()"
        );
        assert_eq!(WgslIntrinsic::Sqrt.to_wgsl(), "sqrt");
    }

    #[test]
    fn test_subgroup_extension_detection() {
        assert!(WgslIntrinsic::SubgroupShuffle.requires_subgroup_extension());
        assert!(!WgslIntrinsic::Sqrt.requires_subgroup_extension());
        assert!(!WgslIntrinsic::WorkgroupBarrier.requires_subgroup_extension());
    }

    /// `Default` must produce the same table as `new`.
    #[test]
    fn test_default_matches_new() {
        let registry = IntrinsicRegistry::default();
        assert_eq!(registry.lookup("sqrt"), Some(WgslIntrinsic::Sqrt));
        assert_eq!(
            registry.lookup("warp_size"),
            Some(WgslIntrinsic::SubgroupSize)
        );
    }

    #[test]
    fn test_is_intrinsic() {
        let registry = IntrinsicRegistry::new();
        assert!(registry.is_intrinsic("atomic_add"));
        assert!(registry.is_intrinsic("lane_id"));
        assert!(!registry.is_intrinsic("not_an_intrinsic"));
        assert!(!registry.is_intrinsic(""));
    }

    /// Fence aliases lower to barriers; pin both the mapping and the emitted text.
    #[test]
    fn test_fence_aliases() {
        let registry = IntrinsicRegistry::new();
        assert_eq!(
            registry.lookup("thread_fence").map(|i| i.to_wgsl()),
            Some("storageBarrier()")
        );
        assert_eq!(
            registry.lookup("thread_fence_block").map(|i| i.to_wgsl()),
            Some("workgroupBarrier()")
        );
    }

    #[test]
    fn test_subgroup_intrinsics_listing() {
        let registry = IntrinsicRegistry::new();
        let listed = registry.subgroup_intrinsics();
        assert!(listed
            .iter()
            .all(|(_, intrinsic)| intrinsic.requires_subgroup_extension()));

        // Sort locally so this test does not depend on the listing's order.
        let mut names: Vec<&str> = listed.iter().map(|&(name, _)| name).collect();
        names.sort_unstable();
        assert_eq!(
            names,
            [
                "lane_id",
                "warp_all",
                "warp_any",
                "warp_ballot",
                "warp_shuffle",
                "warp_shuffle_down",
                "warp_shuffle_up",
                "warp_shuffle_xor",
                "warp_size",
            ]
        );
    }
}