ringkernel_wgpu_codegen/dsl.rs
1//! DSL marker functions for WGSL code generation.
2//!
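//! These functions are markers that the transpiler recognizes and converts to
//! WGSL intrinsics. On the host (plain Rust) they are only placeholders: the
//! index helpers return `0`, the atomic and subgroup stubs panic if called, and
//! the math helpers evaluate on the CPU.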
5//!
6//! Common functions (thread indices, basic math, synchronization) are shared with
7//! the CUDA backend via `ringkernel_codegen::dsl_common`.
8//!
9//! # Thread/Workgroup Indices
10//!
11//! | Function | WGSL Equivalent |
12//! |----------|-----------------|
13//! | `thread_idx_x()` | `local_invocation_id.x` |
14//! | `thread_idx_y()` | `local_invocation_id.y` |
15//! | `thread_idx_z()` | `local_invocation_id.z` |
16//! | `block_idx_x()` | `workgroup_id.x` |
17//! | `block_idx_y()` | `workgroup_id.y` |
18//! | `block_idx_z()` | `workgroup_id.z` |
19//! | `block_dim_x()` | `WORKGROUP_SIZE_X` (constant) |
20//! | `global_thread_id()` | `global_invocation_id.x` |
21
22#![allow(unused_variables)]
23
24// Re-export common DSL functions shared across backends.
25// These provide: thread/block indices, sync primitives, and basic math.
26pub use ringkernel_codegen::dsl_common::{
27 block_dim_x,
28 block_dim_y,
29 block_dim_z,
30 block_idx_x,
31 block_idx_y,
32 block_idx_z,
33 ceil,
34 cos,
35 exp,
36 floor,
37 fma,
38 grid_dim_x,
39 grid_dim_y,
40 grid_dim_z,
41 log,
42 powf,
43 round,
44 rsqrt,
45 sin,
46 // Math
47 sqrt,
48 // Synchronization
49 sync_threads,
50 tan,
51 thread_fence,
52 thread_fence_block,
53 // Thread/block indices
54 thread_idx_x,
55 thread_idx_y,
56 thread_idx_z,
57};
58
59// =============================================================================
60// WGSL-Specific Thread/Workgroup Indices
61// =============================================================================
62
/// Returns the x-component of the invocation's global thread ID.
///
/// Transpiles to `global_invocation_id.x` in WGSL. On the host this is only a
/// marker and always evaluates to `0`.
#[inline(always)]
pub fn global_thread_id() -> i32 {
    0_i32
}
69
/// Returns the y-component of the invocation's global thread ID.
///
/// Transpiles to `global_invocation_id.y` in WGSL. On the host this is only a
/// marker and always evaluates to `0`.
#[inline(always)]
pub fn global_thread_id_y() -> i32 {
    0_i32
}
76
/// Returns the z-component of the invocation's global thread ID.
///
/// Transpiles to `global_invocation_id.z` in WGSL. On the host this is only a
/// marker and always evaluates to `0`.
#[inline(always)]
pub fn global_thread_id_z() -> i32 {
    0_i32
}
83
84// =============================================================================
85// Atomic Operations (32-bit)
86// =============================================================================
87
/// Marker for an atomic addition; transpiles to `atomicAdd(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_add<T>(_ptr: &T, _val: T) -> T {
    unimplemented!("atomic_add is a transpiler intrinsic")
}
93
/// Marker for an atomic subtraction; transpiles to `atomicSub(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_sub<T>(_ptr: &T, _val: T) -> T {
    unimplemented!("atomic_sub is a transpiler intrinsic")
}
99
/// Marker for an atomic minimum; transpiles to `atomicMin(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_min<T>(_ptr: &T, _val: T) -> T {
    unimplemented!("atomic_min is a transpiler intrinsic")
}
105
/// Marker for an atomic maximum; transpiles to `atomicMax(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_max<T>(_ptr: &T, _val: T) -> T {
    unimplemented!("atomic_max is a transpiler intrinsic")
}
111
/// Marker for an atomic exchange; transpiles to `atomicExchange(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_exchange<T>(_ptr: &T, _val: T) -> T {
    unimplemented!("atomic_exchange is a transpiler intrinsic")
}
117
/// Marker for an atomic compare-and-swap; transpiles to
/// `atomicCompareExchangeWeak(&var, compare, value)` in WGSL.
///
/// In transpiled code the call yields the old value.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_cas<T>(_ptr: &T, _compare: T, _val: T) -> T {
    unimplemented!("atomic_cas is a transpiler intrinsic")
}
124
/// Marker for an atomic load; transpiles to `atomicLoad(&var)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_load<T>(_ptr: &T) -> T {
    unimplemented!("atomic_load is a transpiler intrinsic")
}
130
/// Marker for an atomic store; transpiles to `atomicStore(&var, value)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host: the function exists only so the
/// transpiler can recognize the call, and has no CPU implementation.
#[inline(always)]
pub fn atomic_store<T>(_ptr: &T, _val: T) {
    unimplemented!("atomic_store is a transpiler intrinsic")
}
136
137// =============================================================================
138// WGSL-Specific Math Functions
139// =============================================================================
140
141/// Absolute value. Maps to `abs(x)` in WGSL.
142#[inline(always)]
143pub fn abs<T: num_traits::Signed>(x: T) -> T {
144 x.abs()
145}
146
/// Returns the smaller of `a` and `b`; transpiles to `min(a, b)` in WGSL.
///
/// The bound is [`PartialOrd`] rather than `Ord` so that floating-point
/// operands such as `f32` are accepted as well as integers. Every type that
/// satisfied the previous `Ord` bound still satisfies this one.
///
/// When the operands compare equal, or are unordered (for example one of them
/// is NaN), `a` is returned. For totally ordered types this is the same
/// tie-breaking as `std::cmp::min`.
#[inline(always)]
pub fn min<T: PartialOrd>(a: T, b: T) -> T {
    if a > b {
        b
    } else {
        a
    }
}
152
/// Returns the larger of `a` and `b`; transpiles to `max(a, b)` in WGSL.
///
/// The bound is [`PartialOrd`] rather than `Ord` so that floating-point
/// operands such as `f32` are accepted as well as integers. Every type that
/// satisfied the previous `Ord` bound still satisfies this one.
///
/// When the operands compare equal, or are unordered (for example one of them
/// is NaN), `b` is returned. For totally ordered types this is the same
/// tie-breaking as `std::cmp::max`.
#[inline(always)]
pub fn max<T: PartialOrd>(a: T, b: T) -> T {
    if a > b {
        a
    } else {
        b
    }
}
158
/// Restricts `x` to the range `lo..=hi`; transpiles to `clamp(x, lo, hi)` in WGSL.
///
/// The bound is [`PartialOrd`] rather than `Ord` so that floating-point
/// operands such as `f32` are accepted as well as integers. Every type that
/// satisfied the previous `Ord` bound still satisfies this one.
///
/// The result is computed as `min(max(x, lo), hi)`, the form WGSL specifies for
/// integer operands. For well-formed bounds (`lo <= hi`) this is the ordinary
/// clamp. If the bounds are inverted (`lo > hi`) the result is `hi`, so the host
/// evaluation agrees with that WGSL formula.
#[inline(always)]
pub fn clamp<T: PartialOrd>(x: T, lo: T, hi: T) -> T {
    // max(x, lo): raise the value to the lower bound first.
    let raised = if x < lo { lo } else { x };

    // min(raised, hi): then cap it at the upper bound.
    if raised > hi {
        hi
    } else {
        raised
    }
}
164
/// Linearly interpolates between `a` and `b`; transpiles to `mix(a, b, t)` in WGSL.
///
/// On the host the blend is evaluated as `a * (1 - t) + b * t`, so `t == 0.0`
/// yields `a` and `t == 1.0` yields `b` for finite inputs.
#[inline(always)]
pub fn mix(a: f32, b: f32, t: f32) -> f32 {
    let from_a = a * (1.0 - t);
    let from_b = b * t;
    from_a + from_b
}
170
171// =============================================================================
172// Subgroup Operations (require WGSL extensions)
173// =============================================================================
174
/// Marker for a subgroup shuffle; transpiles to `subgroupShuffle(value, lane)` in WGSL.
///
/// Requires: `enable chromium_experimental_subgroups;`
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_shuffle<T>(_val: T, _lane: u32) -> T {
    unimplemented!("warp_shuffle requires WGSL subgroup extension")
}
181
/// Marker for a subgroup shuffle-up; transpiles to `subgroupShuffleUp(value, delta)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_shuffle_up<T>(_val: T, _delta: u32) -> T {
    unimplemented!("warp_shuffle_up requires WGSL subgroup extension")
}
187
/// Marker for a subgroup shuffle-down; transpiles to `subgroupShuffleDown(value, delta)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_shuffle_down<T>(_val: T, _delta: u32) -> T {
    unimplemented!("warp_shuffle_down requires WGSL subgroup extension")
}
193
/// Marker for a subgroup XOR shuffle; transpiles to `subgroupShuffleXor(value, mask)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_shuffle_xor<T>(_val: T, _mask: u32) -> T {
    unimplemented!("warp_shuffle_xor requires WGSL subgroup extension")
}
199
/// Marker for a subgroup ballot; transpiles to `subgroupBallot(predicate)` in WGSL.
///
/// In WGSL the result is a `vec4<u32>` in which each bit holds one lane's
/// predicate; the host signature models it as `[u32; 4]`.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_ballot(_pred: bool) -> [u32; 4] {
    unimplemented!("warp_ballot requires WGSL subgroup extension")
}
206
/// Marker testing whether every lane in the subgroup satisfies the predicate;
/// transpiles to `subgroupAll(predicate)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_all(_pred: bool) -> bool {
    unimplemented!("warp_all requires WGSL subgroup extension")
}
212
/// Marker testing whether at least one lane in the subgroup satisfies the
/// predicate; transpiles to `subgroupAny(predicate)` in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_any(_pred: bool) -> bool {
    unimplemented!("warp_any requires WGSL subgroup extension")
}
218
/// Marker for the invocation's lane ID within its subgroup.
///
/// Transpiles to the `subgroup_invocation_id` builtin in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn lane_id() -> u32 {
    unimplemented!("lane_id requires WGSL subgroup extension")
}
225
/// Marker for the size of the current subgroup.
///
/// Transpiles to the `subgroup_size` builtin in WGSL.
///
/// # Panics
///
/// Always panics when executed on the host; there is no CPU implementation.
#[inline(always)]
pub fn warp_size() -> u32 {
    unimplemented!("warp_size requires WGSL subgroup extension")
}
232
// Workaround: the `num_traits` crate isn't in the dependencies, so this private
// module supplies the minimal `Signed` trait that `abs` is bounded by.
mod num_traits {
    /// Signed scalar types that have an absolute value.
    pub trait Signed {
        /// Returns the absolute value of `self`.
        fn abs(self) -> Self;
    }

    impl Signed for f32 {
        fn abs(self) -> Self {
            f32::abs(self)
        }
    }

    impl Signed for i32 {
        fn abs(self) -> Self {
            // WGSL defines `abs` of the most negative integer to be that same
            // value. `i32::abs` overflows on `i32::MIN` (a panic in debug
            // builds), so use the wrapping form to match the shader result.
            i32::wrapping_abs(self)
        }
    }
}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// The re-exported math helpers evaluate on the host, so exact results can
    /// be asserted.
    #[test]
    fn test_math_functions() {
        let root = sqrt(4.0);
        assert_eq!(root, 2.0);

        let rounded_down = floor(3.7);
        assert_eq!(rounded_down, 3.0);

        let rounded_up = ceil(3.2);
        assert_eq!(rounded_up, 4.0);
    }

    /// The index helpers are placeholders that return 0; calling them only
    /// checks that they compile and run on the host.
    #[test]
    fn test_thread_indices_compile() {
        let _ = thread_idx_x();
        let _ = block_idx_x();
        let _ = block_dim_x();
    }

    /// The synchronization markers only need to compile and be callable on the host.
    #[test]
    fn test_sync_compiles() {
        sync_threads();
        thread_fence();
    }
}
278}