ringkernel_wgpu_codegen/
dsl.rs

1//! DSL marker functions for WGSL code generation.
2//!
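//! These functions are markers that the transpiler recognizes and converts to
//! WGSL intrinsics. When compiled as ordinary Rust they are only host-side
//! stand-ins: index functions return `0`, barriers are no-ops, math functions
//! are evaluated with ordinary Rust arithmetic, and atomic/subgroup intrinsics
//! panic if called.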
5//!
6//! # Thread/Workgroup Indices
7//!
8//! | Function | WGSL Equivalent |
9//! |----------|-----------------|
10//! | `thread_idx_x()` | `local_invocation_id.x` |
11//! | `thread_idx_y()` | `local_invocation_id.y` |
12//! | `thread_idx_z()` | `local_invocation_id.z` |
13//! | `block_idx_x()` | `workgroup_id.x` |
14//! | `block_idx_y()` | `workgroup_id.y` |
15//! | `block_idx_z()` | `workgroup_id.z` |
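//! | `block_dim_x()` | `WORKGROUP_SIZE_X` (constant) |
//! | `block_dim_y()` | `WORKGROUP_SIZE_Y` (constant) |
//! | `block_dim_z()` | `WORKGROUP_SIZE_Z` (constant) |
//! | `grid_dim_x()` | `num_workgroups.x` |
//! | `grid_dim_y()` | `num_workgroups.y` |
//! | `grid_dim_z()` | `num_workgroups.z` |
//! | `global_thread_id()` | `global_invocation_id.x` |
//! | `global_thread_id_y()` | `global_invocation_id.y` |
//! | `global_thread_id_z()` | `global_invocation_id.z` |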
18
19#![allow(unused_variables)]
20
21// =============================================================================
22// Thread/Workgroup Indices
23// =============================================================================
24
/// Local thread index along x within the current workgroup.
///
/// Corresponds to `local_invocation_id.x` in WGSL. When run as ordinary Rust
/// this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn thread_idx_x() -> i32 {
    // Host-side placeholder value.
    0_i32
}
31
/// Local thread index along y within the current workgroup.
///
/// Corresponds to `local_invocation_id.y` in WGSL. When run as ordinary Rust
/// this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn thread_idx_y() -> i32 {
    // Host-side placeholder value.
    0_i32
}
38
/// Local thread index along z within the current workgroup.
///
/// Corresponds to `local_invocation_id.z` in WGSL. When run as ordinary Rust
/// this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn thread_idx_z() -> i32 {
    // Host-side placeholder value.
    0_i32
}
45
/// Workgroup index along x.
///
/// Corresponds to `workgroup_id.x` in WGSL. When run as ordinary Rust this is
/// only a placeholder and always yields `0`.
#[inline(always)]
pub fn block_idx_x() -> i32 {
    // Host-side placeholder value.
    0_i32
}
52
/// Workgroup index along y.
///
/// Corresponds to `workgroup_id.y` in WGSL. When run as ordinary Rust this is
/// only a placeholder and always yields `0`.
#[inline(always)]
pub fn block_idx_y() -> i32 {
    // Host-side placeholder value.
    0_i32
}
59
/// Workgroup index along z.
///
/// Corresponds to `workgroup_id.z` in WGSL. When run as ordinary Rust this is
/// only a placeholder and always yields `0`.
#[inline(always)]
pub fn block_idx_z() -> i32 {
    // Host-side placeholder value.
    0_i32
}
66
/// Workgroup size along x.
///
/// Corresponds to the compile-time constant `WORKGROUP_SIZE_X` in WGSL. When
/// run as ordinary Rust this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn block_dim_x() -> i32 {
    // Host-side placeholder value.
    0_i32
}
73
/// Workgroup size along y.
///
/// Corresponds to the compile-time constant `WORKGROUP_SIZE_Y` in WGSL. When
/// run as ordinary Rust this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn block_dim_y() -> i32 {
    // Host-side placeholder value.
    0_i32
}
80
/// Workgroup size along z.
///
/// Corresponds to the compile-time constant `WORKGROUP_SIZE_Z` in WGSL. When
/// run as ordinary Rust this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn block_dim_z() -> i32 {
    // Host-side placeholder value.
    0_i32
}
87
/// Number of workgroups in the dispatch along x.
///
/// Corresponds to `num_workgroups.x` in WGSL. When run as ordinary Rust this
/// is only a placeholder and always yields `0`.
#[inline(always)]
pub fn grid_dim_x() -> i32 {
    // Host-side placeholder value.
    0_i32
}
94
/// Number of workgroups in the dispatch along y.
///
/// Corresponds to `num_workgroups.y` in WGSL. When run as ordinary Rust this
/// is only a placeholder and always yields `0`.
#[inline(always)]
pub fn grid_dim_y() -> i32 {
    // Host-side placeholder value.
    0_i32
}
101
/// Number of workgroups in the dispatch along z.
///
/// Corresponds to `num_workgroups.z` in WGSL. When run as ordinary Rust this
/// is only a placeholder and always yields `0`.
#[inline(always)]
pub fn grid_dim_z() -> i32 {
    // Host-side placeholder value.
    0_i32
}
108
/// Global thread ID along x.
///
/// Corresponds to `global_invocation_id.x` in WGSL. When run as ordinary Rust
/// this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn global_thread_id() -> i32 {
    // Host-side placeholder value.
    0_i32
}
115
/// Global thread ID along y.
///
/// Corresponds to `global_invocation_id.y` in WGSL. When run as ordinary Rust
/// this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn global_thread_id_y() -> i32 {
    // Host-side placeholder value.
    0_i32
}
122
/// Global thread ID along z.
///
/// Corresponds to `global_invocation_id.z` in WGSL. When run as ordinary Rust
/// this is only a placeholder and always yields `0`.
#[inline(always)]
pub fn global_thread_id_z() -> i32 {
    // Host-side placeholder value.
    0_i32
}
129
130// =============================================================================
131// Synchronization
132// =============================================================================
133
/// Barrier across all threads of the workgroup.
///
/// Corresponds to `workgroupBarrier()` in WGSL. The Rust body is empty, so
/// calling it on the host has no effect.
#[inline(always)]
pub fn sync_threads() {
    // Marker only: nothing to do on the host.
}
138
/// Memory fence for storage buffers.
///
/// Corresponds to `storageBarrier()` in WGSL. The Rust body is empty, so
/// calling it on the host has no effect.
#[inline(always)]
pub fn thread_fence() {
    // Marker only: nothing to do on the host.
}
143
/// Memory fence scoped to the workgroup.
///
/// Corresponds to `workgroupBarrier()` in WGSL. The Rust body is empty, so
/// calling it on the host has no effect.
#[inline(always)]
pub fn thread_fence_block() {
    // Marker only: nothing to do on the host.
}
148
149// =============================================================================
150// Atomic Operations (32-bit)
151// =============================================================================
152
/// Atomic addition marker; corresponds to `atomicAdd(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_add<T>(_ptr: &T, _val: T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_add is a transpiler intrinsic")
}
158
/// Atomic subtraction marker; corresponds to `atomicSub(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_sub<T>(_ptr: &T, _val: T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_sub is a transpiler intrinsic")
}
164
/// Atomic minimum marker; corresponds to `atomicMin(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_min<T>(_ptr: &T, _val: T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_min is a transpiler intrinsic")
}
170
/// Atomic maximum marker; corresponds to `atomicMax(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_max<T>(_ptr: &T, _val: T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_max is a transpiler intrinsic")
}
176
/// Atomic exchange marker; corresponds to `atomicExchange(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_exchange<T>(_ptr: &T, _val: T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_exchange is a transpiler intrinsic")
}
182
/// Atomic compare-and-swap marker.
///
/// Corresponds to `atomicCompareExchangeWeak(&var, compare, value)` in WGSL
/// and is documented to return the old value.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_cas<T>(_ptr: &T, _compare: T, _val: T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_cas is a transpiler intrinsic")
}
189
/// Atomic load marker; corresponds to `atomicLoad(&var)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_load<T>(_ptr: &T) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_load is a transpiler intrinsic")
}
195
/// Atomic store marker; corresponds to `atomicStore(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust: there is no host
/// implementation, only the transpiler gives this call a meaning.
#[inline(always)]
pub fn atomic_store<T>(_ptr: &T, _val: T) {
    // No host fallback exists for this intrinsic.
    unimplemented!("atomic_store is a transpiler intrinsic")
}
201
202// =============================================================================
203// Math Functions
204// =============================================================================
205
/// Square root of `x`; corresponds to `sqrt(x)` in WGSL.
///
/// On the host this forwards to [`f32::sqrt`].
#[inline(always)]
pub fn sqrt(x: f32) -> f32 {
    f32::sqrt(x)
}
211
/// Reciprocal of the square root of `x`; corresponds to `inverseSqrt(x)` in WGSL.
///
/// On the host this is computed as the reciprocal of [`f32::sqrt`].
#[inline(always)]
pub fn rsqrt(x: f32) -> f32 {
    let root = x.sqrt();
    root.recip()
}
217
218/// Absolute value. Maps to `abs(x)` in WGSL.
219#[inline(always)]
220pub fn abs<T: num_traits::Signed>(x: T) -> T {
221    x.abs()
222}
223
/// Largest integer value not greater than `x`; corresponds to `floor(x)` in WGSL.
///
/// On the host this forwards to [`f32::floor`].
#[inline(always)]
pub fn floor(x: f32) -> f32 {
    f32::floor(x)
}
229
/// Smallest integer value not less than `x`; corresponds to `ceil(x)` in WGSL.
///
/// On the host this forwards to [`f32::ceil`].
#[inline(always)]
pub fn ceil(x: f32) -> f32 {
    f32::ceil(x)
}
235
/// Round to the nearest integer value. Maps to `round(x)` in WGSL.
///
/// WGSL's `round` resolves halfway cases to the nearest *even* integer, so the
/// host implementation uses [`f32::round_ties_even`] rather than
/// [`f32::round`] (which rounds halfway cases away from zero). This keeps
/// host-side results identical to the generated shader for inputs such as
/// `2.5` (-> `2.0`) and `3.5` (-> `4.0`).
#[inline(always)]
pub fn round(x: f32) -> f32 {
    x.round_ties_even()
}
241
/// Sine of `x`; corresponds to `sin(x)` in WGSL.
///
/// On the host this forwards to [`f32::sin`].
#[inline(always)]
pub fn sin(x: f32) -> f32 {
    f32::sin(x)
}
247
/// Cosine of `x`; corresponds to `cos(x)` in WGSL.
///
/// On the host this forwards to [`f32::cos`].
#[inline(always)]
pub fn cos(x: f32) -> f32 {
    f32::cos(x)
}
253
/// Tangent of `x`; corresponds to `tan(x)` in WGSL.
///
/// On the host this forwards to [`f32::tan`].
#[inline(always)]
pub fn tan(x: f32) -> f32 {
    f32::tan(x)
}
259
/// Natural exponential of `x`; corresponds to `exp(x)` in WGSL.
///
/// On the host this forwards to [`f32::exp`].
#[inline(always)]
pub fn exp(x: f32) -> f32 {
    f32::exp(x)
}
265
/// Natural logarithm of `x`; corresponds to `log(x)` in WGSL.
///
/// On the host this forwards to [`f32::ln`] (base *e*, not base 10).
#[inline(always)]
pub fn log(x: f32) -> f32 {
    f32::ln(x)
}
271
/// Raises `x` to the power `y`; corresponds to `pow(x, y)` in WGSL.
///
/// On the host this forwards to [`f32::powf`].
#[inline(always)]
pub fn powf(x: f32, y: f32) -> f32 {
    f32::powf(x, y)
}
277
/// Minimum of `a` and `b`. Maps to `min(a, b)` in WGSL.
///
/// The bound is `PartialOrd` rather than `Ord` so that floating-point values
/// (`f32`) are accepted as well as integers. For totally ordered types the
/// result is the same as [`std::cmp::min`]: `a` is returned when the two
/// compare equal.
///
/// If the operands are unordered (for example one is NaN), `a` is returned.
#[inline(always)]
pub fn min<T: PartialOrd>(a: T, b: T) -> T {
    // Only pick `b` when it is strictly smaller, mirroring `std::cmp::min`.
    if b < a {
        b
    } else {
        a
    }
}
283
/// Maximum of `a` and `b`. Maps to `max(a, b)` in WGSL.
///
/// The bound is `PartialOrd` rather than `Ord` so that floating-point values
/// (`f32`) are accepted as well as integers. For totally ordered types the
/// result is the same as [`std::cmp::max`]: `b` is returned when the two
/// compare equal.
///
/// If the operands are unordered (for example one is NaN), `b` is returned.
#[inline(always)]
pub fn max<T: PartialOrd>(a: T, b: T) -> T {
    // Only pick `a` when it is strictly larger, mirroring `std::cmp::max`.
    if a > b {
        a
    } else {
        b
    }
}
289
/// Restrict `x` to the range `[lo, hi]`. Maps to `clamp(x, lo, hi)` in WGSL.
///
/// The bound is `PartialOrd` rather than `Ord` so that floating-point values
/// (`f32`) are accepted as well as integers. The value is computed as
/// `max(lo, min(x, hi))`, exactly as before, so results for totally ordered
/// types are unchanged. Unlike [`Ord::clamp`], this never panics; callers are
/// expected to pass `lo <= hi`.
#[inline(always)]
pub fn clamp<T: PartialOrd>(x: T, lo: T, hi: T) -> T {
    // Upper bound first: keep `x` unless it strictly exceeds `hi`.
    let capped = if hi < x { hi } else { x };
    // Then the lower bound: raise to `lo` only when strictly below it.
    if lo > capped {
        lo
    } else {
        capped
    }
}
295
/// Fused multiply-add `a * b + c`; corresponds to `fma(a, b, c)` in WGSL.
///
/// On the host this forwards to [`f32::mul_add`].
#[inline(always)]
pub fn fma(a: f32, b: f32, c: f32) -> f32 {
    f32::mul_add(a, b, c)
}
301
/// Linear interpolation between `a` and `b`; corresponds to `mix(a, b, t)` in WGSL.
///
/// Evaluates `a * (1 - t) + b * t`, so `t = 0` yields `a` and `t = 1` yields `b`.
#[inline(always)]
pub fn mix(a: f32, b: f32, t: f32) -> f32 {
    let weight_a = 1.0 - t;
    let from_a = a * weight_a;
    let from_b = b * t;
    from_a + from_b
}
307
308// =============================================================================
309// Subgroup Operations (require WGSL extensions)
310// =============================================================================
311
/// Subgroup shuffle marker; corresponds to `subgroupShuffle(value, lane)` in WGSL.
///
/// Requires: `enable chromium_experimental_subgroups;`
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_shuffle<T>(_val: T, _lane: u32) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_shuffle requires WGSL subgroup extension")
}
318
/// Subgroup shuffle-up marker; corresponds to `subgroupShuffleUp(value, delta)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_shuffle_up<T>(_val: T, _delta: u32) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_shuffle_up requires WGSL subgroup extension")
}
324
/// Subgroup shuffle-down marker; corresponds to `subgroupShuffleDown(value, delta)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_shuffle_down<T>(_val: T, _delta: u32) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_shuffle_down requires WGSL subgroup extension")
}
330
/// Subgroup shuffle-XOR marker; corresponds to `subgroupShuffleXor(value, mask)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_shuffle_xor<T>(_val: T, _mask: u32) -> T {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_shuffle_xor requires WGSL subgroup extension")
}
336
/// Subgroup ballot marker; corresponds to `subgroupBallot(predicate)` in WGSL.
///
/// The WGSL result is a `vec4<u32>` in which each bit represents one lane's
/// predicate; the Rust signature models it as `[u32; 4]`.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_ballot(_pred: bool) -> [u32; 4] {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_ballot requires WGSL subgroup extension")
}
343
/// Marker for "predicate holds on all lanes of the subgroup"; corresponds to
/// `subgroupAll(predicate)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_all(_pred: bool) -> bool {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_all requires WGSL subgroup extension")
}
349
/// Marker for "predicate holds on at least one lane of the subgroup";
/// corresponds to `subgroupAny(predicate)` in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_any(_pred: bool) -> bool {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_any requires WGSL subgroup extension")
}
355
/// Lane ID of the calling invocation within its subgroup.
///
/// Corresponds to the `subgroup_invocation_id` builtin in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn lane_id() -> u32 {
    // No host fallback exists for this intrinsic.
    unimplemented!("lane_id requires WGSL subgroup extension")
}
362
/// Size of the subgroup.
///
/// Corresponds to the `subgroup_size` builtin in WGSL.
///
/// # Panics
///
/// Always panics when executed as ordinary Rust; there is no host fallback.
#[inline(always)]
pub fn warp_size() -> u32 {
    // No host fallback exists for this intrinsic.
    unimplemented!("warp_size requires WGSL subgroup extension")
}
369
// Workaround: the `num_traits` crate isn't in the dependencies, so this
// private module supplies the single trait that `abs` is bounded on.
mod num_traits {
    /// Types that provide an absolute value.
    pub trait Signed {
        /// Returns the absolute value of `self`.
        fn abs(self) -> Self;
    }

    // Each impl simply forwards to the primitive's own inherent `abs`.
    macro_rules! impl_signed {
        ($($ty:ty),* $(,)?) => {
            $(
                impl Signed for $ty {
                    fn abs(self) -> Self {
                        <$ty>::abs(self)
                    }
                }
            )*
        };
    }

    impl_signed!(f32, i32);
}
388
#[cfg(test)]
mod tests {
    use super::*;

    /// The host-side math helpers should agree with Rust's own results.
    #[test]
    fn test_math_functions() {
        let cases = [(sqrt(4.0), 2.0), (floor(3.7), 3.0), (ceil(3.2), 4.0)];
        for (actual, expected) in cases {
            assert_eq!(actual, expected);
        }
    }

    /// Index markers only need to be callable; their values are placeholders.
    #[test]
    fn test_thread_indices_compile() {
        let _placeholders = (thread_idx_x(), block_idx_x(), block_dim_x());
    }

    /// Barrier markers only need to be callable.
    #[test]
    fn test_sync_compiles() {
        sync_threads();
        thread_fence();
    }
}