ringkernel_wgpu_codegen/
intrinsics.rs

//! Intrinsic registry for WGSL code generation.
//!
//! Maps Rust DSL function names to WGSL builtin values and intrinsic functions.
4
5use std::collections::HashMap;
6
/// WGSL intrinsic operations that Rust DSL calls lower to.
///
/// Each variant is a WGSL builtin value, synchronization call, atomic or math
/// function, or transpile-time constant; [`WgslIntrinsic::to_wgsl`] yields the
/// text emitted for it. `Hash` is derived alongside `Eq` so intrinsics can be
/// used as `HashSet`/`HashMap` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WgslIntrinsic {
    // Thread/workgroup indices: single components of WGSL builtin vectors.
    LocalInvocationIdX,
    LocalInvocationIdY,
    LocalInvocationIdZ,
    WorkgroupIdX,
    WorkgroupIdY,
    WorkgroupIdZ,
    GlobalInvocationIdX,
    GlobalInvocationIdY,
    GlobalInvocationIdZ,
    NumWorkgroupsX,
    NumWorkgroupsY,
    NumWorkgroupsZ,

    // Workgroup size constants (emitted as placeholders substituted at transpile time).
    WorkgroupSizeX,
    WorkgroupSizeY,
    WorkgroupSizeZ,

    // Synchronization.
    WorkgroupBarrier,
    StorageBarrier,

    // Atomics (emitted as bare function names; the caller supplies arguments).
    AtomicAdd,
    AtomicSub,
    AtomicMin,
    AtomicMax,
    AtomicExchange,
    /// Compare-and-swap; lowered to WGSL `atomicCompareExchangeWeak`.
    AtomicCompareExchangeWeak,
    AtomicLoad,
    AtomicStore,

    // Math - single argument.
    Sqrt,
    /// Reciprocal square root (WGSL `inverseSqrt`).
    InverseSqrt,
    Abs,
    Floor,
    Ceil,
    Round,
    Sin,
    Cos,
    Tan,
    Exp,
    Log,

    // Math - multiple arguments.
    Pow,
    Min,
    Max,
    Clamp,
    Fma,
    Mix,

    // Subgroup operations (require the WGSL subgroup extension; see
    // `requires_subgroup_extension`).
    SubgroupShuffle,
    SubgroupShuffleUp,
    SubgroupShuffleDown,
    SubgroupShuffleXor,
    SubgroupBallot,
    SubgroupAll,
    SubgroupAny,
    SubgroupInvocationId,
    SubgroupSize,
}
75
76impl WgslIntrinsic {
77    /// Get the WGSL code representation for this intrinsic.
78    pub fn to_wgsl(&self) -> &'static str {
79        match self {
80            // These are builtins accessed via variables, not function calls
81            WgslIntrinsic::LocalInvocationIdX => "local_invocation_id.x",
82            WgslIntrinsic::LocalInvocationIdY => "local_invocation_id.y",
83            WgslIntrinsic::LocalInvocationIdZ => "local_invocation_id.z",
84            WgslIntrinsic::WorkgroupIdX => "workgroup_id.x",
85            WgslIntrinsic::WorkgroupIdY => "workgroup_id.y",
86            WgslIntrinsic::WorkgroupIdZ => "workgroup_id.z",
87            WgslIntrinsic::GlobalInvocationIdX => "global_invocation_id.x",
88            WgslIntrinsic::GlobalInvocationIdY => "global_invocation_id.y",
89            WgslIntrinsic::GlobalInvocationIdZ => "global_invocation_id.z",
90            WgslIntrinsic::NumWorkgroupsX => "num_workgroups.x",
91            WgslIntrinsic::NumWorkgroupsY => "num_workgroups.y",
92            WgslIntrinsic::NumWorkgroupsZ => "num_workgroups.z",
93
94            // Workgroup size constants (substituted at transpile time)
95            WgslIntrinsic::WorkgroupSizeX => "WORKGROUP_SIZE_X",
96            WgslIntrinsic::WorkgroupSizeY => "WORKGROUP_SIZE_Y",
97            WgslIntrinsic::WorkgroupSizeZ => "WORKGROUP_SIZE_Z",
98
99            // Synchronization
100            WgslIntrinsic::WorkgroupBarrier => "workgroupBarrier()",
101            WgslIntrinsic::StorageBarrier => "storageBarrier()",
102
103            // Atomics - these need arguments, so just return the function name
104            WgslIntrinsic::AtomicAdd => "atomicAdd",
105            WgslIntrinsic::AtomicSub => "atomicSub",
106            WgslIntrinsic::AtomicMin => "atomicMin",
107            WgslIntrinsic::AtomicMax => "atomicMax",
108            WgslIntrinsic::AtomicExchange => "atomicExchange",
109            WgslIntrinsic::AtomicCompareExchangeWeak => "atomicCompareExchangeWeak",
110            WgslIntrinsic::AtomicLoad => "atomicLoad",
111            WgslIntrinsic::AtomicStore => "atomicStore",
112
113            // Math functions
114            WgslIntrinsic::Sqrt => "sqrt",
115            WgslIntrinsic::InverseSqrt => "inverseSqrt",
116            WgslIntrinsic::Abs => "abs",
117            WgslIntrinsic::Floor => "floor",
118            WgslIntrinsic::Ceil => "ceil",
119            WgslIntrinsic::Round => "round",
120            WgslIntrinsic::Sin => "sin",
121            WgslIntrinsic::Cos => "cos",
122            WgslIntrinsic::Tan => "tan",
123            WgslIntrinsic::Exp => "exp",
124            WgslIntrinsic::Log => "log",
125            WgslIntrinsic::Pow => "pow",
126            WgslIntrinsic::Min => "min",
127            WgslIntrinsic::Max => "max",
128            WgslIntrinsic::Clamp => "clamp",
129            WgslIntrinsic::Fma => "fma",
130            WgslIntrinsic::Mix => "mix",
131
132            // Subgroup operations
133            WgslIntrinsic::SubgroupShuffle => "subgroupShuffle",
134            WgslIntrinsic::SubgroupShuffleUp => "subgroupShuffleUp",
135            WgslIntrinsic::SubgroupShuffleDown => "subgroupShuffleDown",
136            WgslIntrinsic::SubgroupShuffleXor => "subgroupShuffleXor",
137            WgslIntrinsic::SubgroupBallot => "subgroupBallot",
138            WgslIntrinsic::SubgroupAll => "subgroupAll",
139            WgslIntrinsic::SubgroupAny => "subgroupAny",
140            WgslIntrinsic::SubgroupInvocationId => "subgroup_invocation_id",
141            WgslIntrinsic::SubgroupSize => "subgroup_size",
142        }
143    }
144
145    /// Check if this intrinsic requires the subgroup extension.
146    pub fn requires_subgroup_extension(&self) -> bool {
147        matches!(
148            self,
149            WgslIntrinsic::SubgroupShuffle
150                | WgslIntrinsic::SubgroupShuffleUp
151                | WgslIntrinsic::SubgroupShuffleDown
152                | WgslIntrinsic::SubgroupShuffleXor
153                | WgslIntrinsic::SubgroupBallot
154                | WgslIntrinsic::SubgroupAll
155                | WgslIntrinsic::SubgroupAny
156                | WgslIntrinsic::SubgroupInvocationId
157                | WgslIntrinsic::SubgroupSize
158        )
159    }
160}
161
162/// Registry mapping Rust DSL function names to WGSL intrinsics.
163pub struct IntrinsicRegistry {
164    mappings: HashMap<&'static str, WgslIntrinsic>,
165}
166
167impl Default for IntrinsicRegistry {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173impl IntrinsicRegistry {
174    /// Create a new intrinsic registry with all standard mappings.
175    pub fn new() -> Self {
176        let mut mappings = HashMap::new();
177
178        // Thread/workgroup indices
179        mappings.insert("thread_idx_x", WgslIntrinsic::LocalInvocationIdX);
180        mappings.insert("thread_idx_y", WgslIntrinsic::LocalInvocationIdY);
181        mappings.insert("thread_idx_z", WgslIntrinsic::LocalInvocationIdZ);
182        mappings.insert("block_idx_x", WgslIntrinsic::WorkgroupIdX);
183        mappings.insert("block_idx_y", WgslIntrinsic::WorkgroupIdY);
184        mappings.insert("block_idx_z", WgslIntrinsic::WorkgroupIdZ);
185        mappings.insert("global_thread_id", WgslIntrinsic::GlobalInvocationIdX);
186        mappings.insert("global_thread_id_y", WgslIntrinsic::GlobalInvocationIdY);
187        mappings.insert("global_thread_id_z", WgslIntrinsic::GlobalInvocationIdZ);
188        mappings.insert("grid_dim_x", WgslIntrinsic::NumWorkgroupsX);
189        mappings.insert("grid_dim_y", WgslIntrinsic::NumWorkgroupsY);
190        mappings.insert("grid_dim_z", WgslIntrinsic::NumWorkgroupsZ);
191
192        // Workgroup size
193        mappings.insert("block_dim_x", WgslIntrinsic::WorkgroupSizeX);
194        mappings.insert("block_dim_y", WgslIntrinsic::WorkgroupSizeY);
195        mappings.insert("block_dim_z", WgslIntrinsic::WorkgroupSizeZ);
196
197        // Synchronization
198        mappings.insert("sync_threads", WgslIntrinsic::WorkgroupBarrier);
199        mappings.insert("thread_fence", WgslIntrinsic::StorageBarrier);
200        mappings.insert("thread_fence_block", WgslIntrinsic::WorkgroupBarrier);
201
202        // Atomics
203        mappings.insert("atomic_add", WgslIntrinsic::AtomicAdd);
204        mappings.insert("atomic_sub", WgslIntrinsic::AtomicSub);
205        mappings.insert("atomic_min", WgslIntrinsic::AtomicMin);
206        mappings.insert("atomic_max", WgslIntrinsic::AtomicMax);
207        mappings.insert("atomic_exchange", WgslIntrinsic::AtomicExchange);
208        mappings.insert("atomic_cas", WgslIntrinsic::AtomicCompareExchangeWeak);
209        mappings.insert("atomic_load", WgslIntrinsic::AtomicLoad);
210        mappings.insert("atomic_store", WgslIntrinsic::AtomicStore);
211
212        // Math functions
213        mappings.insert("sqrt", WgslIntrinsic::Sqrt);
214        mappings.insert("rsqrt", WgslIntrinsic::InverseSqrt);
215        mappings.insert("abs", WgslIntrinsic::Abs);
216        mappings.insert("floor", WgslIntrinsic::Floor);
217        mappings.insert("ceil", WgslIntrinsic::Ceil);
218        mappings.insert("round", WgslIntrinsic::Round);
219        mappings.insert("sin", WgslIntrinsic::Sin);
220        mappings.insert("cos", WgslIntrinsic::Cos);
221        mappings.insert("tan", WgslIntrinsic::Tan);
222        mappings.insert("exp", WgslIntrinsic::Exp);
223        mappings.insert("log", WgslIntrinsic::Log);
224        mappings.insert("powf", WgslIntrinsic::Pow);
225        mappings.insert("min", WgslIntrinsic::Min);
226        mappings.insert("max", WgslIntrinsic::Max);
227        mappings.insert("clamp", WgslIntrinsic::Clamp);
228        mappings.insert("fma", WgslIntrinsic::Fma);
229        mappings.insert("mix", WgslIntrinsic::Mix);
230
231        // Subgroup operations
232        mappings.insert("warp_shuffle", WgslIntrinsic::SubgroupShuffle);
233        mappings.insert("warp_shuffle_up", WgslIntrinsic::SubgroupShuffleUp);
234        mappings.insert("warp_shuffle_down", WgslIntrinsic::SubgroupShuffleDown);
235        mappings.insert("warp_shuffle_xor", WgslIntrinsic::SubgroupShuffleXor);
236        mappings.insert("warp_ballot", WgslIntrinsic::SubgroupBallot);
237        mappings.insert("warp_all", WgslIntrinsic::SubgroupAll);
238        mappings.insert("warp_any", WgslIntrinsic::SubgroupAny);
239        mappings.insert("lane_id", WgslIntrinsic::SubgroupInvocationId);
240        mappings.insert("warp_size", WgslIntrinsic::SubgroupSize);
241
242        Self { mappings }
243    }
244
245    /// Look up an intrinsic by Rust DSL function name.
246    pub fn lookup(&self, name: &str) -> Option<WgslIntrinsic> {
247        self.mappings.get(name).copied()
248    }
249
250    /// Check if a function name is a known intrinsic.
251    pub fn is_intrinsic(&self, name: &str) -> bool {
252        self.mappings.contains_key(name)
253    }
254
255    /// Get all intrinsics that require the subgroup extension.
256    pub fn subgroup_intrinsics(&self) -> Vec<(&'static str, WgslIntrinsic)> {
257        self.mappings
258            .iter()
259            .filter(|(_, intrinsic)| intrinsic.requires_subgroup_extension())
260            .map(|(&name, &intrinsic)| (name, intrinsic))
261            .collect()
262    }
263}
264
/// Unit tests for the intrinsic enum and the DSL-name registry.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_lookup() {
        let registry = IntrinsicRegistry::new();
        assert_eq!(
            registry.lookup("thread_idx_x"),
            Some(WgslIntrinsic::LocalInvocationIdX)
        );
        assert_eq!(
            registry.lookup("sync_threads"),
            Some(WgslIntrinsic::WorkgroupBarrier)
        );
        assert_eq!(registry.lookup("unknown_function"), None);
    }

    #[test]
    fn test_intrinsic_wgsl_output() {
        assert_eq!(
            WgslIntrinsic::LocalInvocationIdX.to_wgsl(),
            "local_invocation_id.x"
        );
        assert_eq!(
            WgslIntrinsic::WorkgroupBarrier.to_wgsl(),
            "workgroupBarrier()"
        );
        assert_eq!(WgslIntrinsic::Sqrt.to_wgsl(), "sqrt");
    }

    #[test]
    fn test_subgroup_extension_detection() {
        assert!(WgslIntrinsic::SubgroupShuffle.requires_subgroup_extension());
        assert!(!WgslIntrinsic::Sqrt.requires_subgroup_extension());
        assert!(!WgslIntrinsic::WorkgroupBarrier.requires_subgroup_extension());
    }

    #[test]
    fn test_is_intrinsic() {
        let registry = IntrinsicRegistry::new();
        assert!(registry.is_intrinsic("atomic_add"));
        assert!(registry.is_intrinsic("warp_shuffle"));
        assert!(!registry.is_intrinsic("unknown_function"));
        // Lookup is an exact match: no case folding.
        assert!(!registry.is_intrinsic("Sqrt"));
    }

    #[test]
    fn test_fences_map_to_barriers() {
        let registry = IntrinsicRegistry::new();
        assert_eq!(
            registry.lookup("thread_fence"),
            Some(WgslIntrinsic::StorageBarrier)
        );
        assert_eq!(
            registry.lookup("thread_fence_block"),
            Some(WgslIntrinsic::WorkgroupBarrier)
        );
    }

    #[test]
    fn test_subgroup_intrinsics_listing() {
        let registry = IntrinsicRegistry::new();
        let listed = registry.subgroup_intrinsics();
        assert!(listed
            .iter()
            .all(|(_, intrinsic)| intrinsic.requires_subgroup_extension()));
        // Sort locally so the assertion does not depend on the listing order.
        let mut names: Vec<&str> = listed.iter().map(|&(name, _)| name).collect();
        names.sort_unstable();
        assert_eq!(
            names,
            vec![
                "lane_id",
                "warp_all",
                "warp_any",
                "warp_ballot",
                "warp_shuffle",
                "warp_shuffle_down",
                "warp_shuffle_up",
                "warp_shuffle_xor",
                "warp_size",
            ]
        );
    }

    #[test]
    fn test_every_registered_intrinsic_has_wgsl_spelling() {
        let registry = IntrinsicRegistry::new();
        for (&name, intrinsic) in &registry.mappings {
            assert!(!intrinsic.to_wgsl().is_empty(), "empty WGSL for {}", name);
            assert_eq!(registry.lookup(name), Some(*intrinsic));
        }
    }

    #[test]
    fn test_default_matches_new() {
        let from_default = IntrinsicRegistry::default();
        let from_new = IntrinsicRegistry::new();
        for name in ["sqrt", "lane_id", "block_dim_x", "unknown_function"].iter() {
            assert_eq!(from_default.lookup(name), from_new.lookup(name));
        }
    }
}