ringkernel_codegen/dsl_common.rs
//! Common DSL marker functions shared across GPU backends.
//!
//! These functions provide CPU-fallback implementations of GPU intrinsics.
//! During transpilation, the CUDA and WGSL code generators recognize these
//! function names and replace them with their backend-specific equivalents.
//!
//! # Thread/Block Indices
//!
//! | Function | CUDA | WGSL |
//! |----------|------|------|
//! | `thread_idx_x()` | `threadIdx.x` | `local_invocation_id.x` |
//! | `block_idx_x()` | `blockIdx.x` | `workgroup_id.x` |
//! | `block_dim_x()` | `blockDim.x` | `WORKGROUP_SIZE_X` |
//! | `grid_dim_x()` | `gridDim.x` | `num_workgroups.x` |
//!
//! The `_y`/`_z` variants map to the corresponding `.y`/`.z` fields.
//!
//! # Math Functions
//!
//! All math functions use Rust's standard library for CPU fallback:
//! `sqrt`, `rsqrt`, `floor`, `ceil`, `round`, `sin`, `cos`, `tan`,
//! `exp`, `log`, `fma`, `powf`.
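//!
//! # Example
//!
//! A kernel body written against these markers typically derives its global
//! element index from the block and thread indices. A minimal sketch (the
//! import path is illustrative; adjust it to wherever this module is exposed
//! in your crate):
//!
//! ```ignore
//! use ringkernel_codegen::dsl_common::*;
//!
//! // On the GPU this distributes elements across threads; on the CPU
//! // fallback it evaluates to 0 * 1 + 0 = 0.
//! let gid = block_idx_x() * block_dim_x() + thread_idx_x();
//! assert_eq!(gid, 0);
//! ```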

// =============================================================================
// Thread/Block Index Functions
// =============================================================================

/// Get the thread index within a block/workgroup (x dimension).
///
/// - CUDA: `threadIdx.x`
/// - WGSL: `local_invocation_id.x`
#[inline(always)]
pub fn thread_idx_x() -> i32 {
    0 // CPU fallback: single-threaded execution
}

/// Get the thread index within a block/workgroup (y dimension).
///
/// - CUDA: `threadIdx.y`
/// - WGSL: `local_invocation_id.y`
#[inline(always)]
pub fn thread_idx_y() -> i32 {
    0
}

/// Get the thread index within a block/workgroup (z dimension).
///
/// - CUDA: `threadIdx.z`
/// - WGSL: `local_invocation_id.z`
#[inline(always)]
pub fn thread_idx_z() -> i32 {
    0
}

/// Get the block/workgroup index (x dimension).
///
/// - CUDA: `blockIdx.x`
/// - WGSL: `workgroup_id.x`
#[inline(always)]
pub fn block_idx_x() -> i32 {
    0
}

/// Get the block/workgroup index (y dimension).
///
/// - CUDA: `blockIdx.y`
/// - WGSL: `workgroup_id.y`
#[inline(always)]
pub fn block_idx_y() -> i32 {
    0
}

/// Get the block/workgroup index (z dimension).
///
/// - CUDA: `blockIdx.z`
/// - WGSL: `workgroup_id.z`
#[inline(always)]
pub fn block_idx_z() -> i32 {
    0
}

/// Get the block/workgroup size (x dimension).
///
/// - CUDA: `blockDim.x`
/// - WGSL: `WORKGROUP_SIZE_X`
#[inline(always)]
pub fn block_dim_x() -> i32 {
    1 // CPU fallback: one thread per block
}

/// Get the block/workgroup size (y dimension).
///
/// - CUDA: `blockDim.y`
/// - WGSL: `WORKGROUP_SIZE_Y`
#[inline(always)]
pub fn block_dim_y() -> i32 {
    1
}

/// Get the block/workgroup size (z dimension).
///
/// - CUDA: `blockDim.z`
/// - WGSL: `WORKGROUP_SIZE_Z`
#[inline(always)]
pub fn block_dim_z() -> i32 {
    1
}

/// Get the grid/dispatch size (x dimension).
///
/// - CUDA: `gridDim.x`
/// - WGSL: `num_workgroups.x`
#[inline(always)]
pub fn grid_dim_x() -> i32 {
    1 // CPU fallback: one block in the grid
}

/// Get the grid/dispatch size (y dimension).
///
/// - CUDA: `gridDim.y`
/// - WGSL: `num_workgroups.y`
#[inline(always)]
pub fn grid_dim_y() -> i32 {
    1
}

/// Get the grid/dispatch size (z dimension).
///
/// - CUDA: `gridDim.z`
/// - WGSL: `num_workgroups.z`
#[inline(always)]
pub fn grid_dim_z() -> i32 {
    1
}

// =============================================================================
// Synchronization Functions
// =============================================================================

/// Synchronize all threads in a block/workgroup.
///
/// - CUDA: `__syncthreads()`
/// - WGSL: `workgroupBarrier()`
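///
/// # Example
///
/// In generated device code a barrier separates a write phase from a read
/// phase over block-shared data. A minimal sketch (`shared` and `partial`
/// are hypothetical; on the single-threaded CPU fallback the barrier is a
/// no-op):
///
/// ```ignore
/// shared[thread_idx_x() as usize] = partial;
/// sync_threads(); // every write above is visible to every thread below
/// let neighbor = shared[((thread_idx_x() + 1) % block_dim_x()) as usize];
/// ```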
#[inline(always)]
pub fn sync_threads() {
    // CPU fallback: no-op (single-threaded)
}

/// Device-level/storage memory fence.
///
/// - CUDA: `__threadfence()`
/// - WGSL: `storageBarrier()`
#[inline(always)]
pub fn thread_fence() {
    // CPU fallback: full sequentially consistent fence
    std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
}

/// Block-level/workgroup memory fence.
///
/// - CUDA: `__threadfence_block()`
/// - WGSL: `workgroupBarrier()`
#[inline(always)]
pub fn thread_fence_block() {
    // CPU fallback: AcqRel orders both prior and subsequent accesses,
    // matching the bidirectional ordering of `__threadfence_block()`;
    // a Release-only fence would not order subsequent loads.
    std::sync::atomic::fence(std::sync::atomic::Ordering::AcqRel);
}

// =============================================================================
// Math Functions (f32)
// =============================================================================

/// Square root.
///
/// - CUDA: `sqrtf(x)`
/// - WGSL: `sqrt(x)`
#[inline(always)]
pub fn sqrt(x: f32) -> f32 {
    x.sqrt()
}

/// Reciprocal square root.
///
/// - CUDA: `rsqrtf(x)`
/// - WGSL: `inverseSqrt(x)`
#[inline(always)]
pub fn rsqrt(x: f32) -> f32 {
    1.0 / x.sqrt()
}

/// Floor.
///
/// - CUDA: `floorf(x)`
/// - WGSL: `floor(x)`
#[inline(always)]
pub fn floor(x: f32) -> f32 {
    x.floor()
}

/// Ceiling.
///
/// - CUDA: `ceilf(x)`
/// - WGSL: `ceil(x)`
#[inline(always)]
pub fn ceil(x: f32) -> f32 {
    x.ceil()
}

/// Round to nearest.
///
/// - CUDA: `roundf(x)`
/// - WGSL: `round(x)`
#[inline(always)]
pub fn round(x: f32) -> f32 {
    x.round()
}

/// Sine.
///
/// - CUDA: `sinf(x)`
/// - WGSL: `sin(x)`
#[inline(always)]
pub fn sin(x: f32) -> f32 {
    x.sin()
}

/// Cosine.
///
/// - CUDA: `cosf(x)`
/// - WGSL: `cos(x)`
#[inline(always)]
pub fn cos(x: f32) -> f32 {
    x.cos()
}

/// Tangent.
///
/// - CUDA: `tanf(x)`
/// - WGSL: `tan(x)`
#[inline(always)]
pub fn tan(x: f32) -> f32 {
    x.tan()
}

/// Exponential (base e).
///
/// - CUDA: `expf(x)`
/// - WGSL: `exp(x)`
#[inline(always)]
pub fn exp(x: f32) -> f32 {
    x.exp()
}

/// Natural logarithm (base e).
///
/// - CUDA: `logf(x)`
/// - WGSL: `log(x)`
#[inline(always)]
pub fn log(x: f32) -> f32 {
    x.ln()
}

/// Fused multiply-add: `a * b + c`.
///
/// - CUDA: `fmaf(a, b, c)`
/// - WGSL: `fma(a, b, c)`
#[inline(always)]
pub fn fma(a: f32, b: f32, c: f32) -> f32 {
    a.mul_add(b, c)
}

/// Power: `x^y`.
///
/// - CUDA: `powf(x, y)`
/// - WGSL: `pow(x, y)`
#[inline(always)]
pub fn powf(x: f32, y: f32) -> f32 {
    x.powf(y)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_indices() {
        assert_eq!(thread_idx_x(), 0);
        assert_eq!(thread_idx_y(), 0);
        assert_eq!(thread_idx_z(), 0);
    }

    #[test]
    fn test_block_indices() {
        assert_eq!(block_idx_x(), 0);
        assert_eq!(block_idx_y(), 0);
        assert_eq!(block_idx_z(), 0);
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(block_dim_x(), 1);
        assert_eq!(block_dim_y(), 1);
        assert_eq!(block_dim_z(), 1);
        assert_eq!(grid_dim_x(), 1);
        assert_eq!(grid_dim_y(), 1);
        assert_eq!(grid_dim_z(), 1);
    }

    #[test]
    fn test_sync_compiles() {
        sync_threads();
        thread_fence();
        thread_fence_block();
    }

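    // A sketch of the grid-stride loop that generated kernels typically use
    // to cover a buffer larger than the launch. On the CPU fallback the
    // indices are 0 and the sizes are 1, so the loop degenerates to a plain
    // sequential pass over every element.
    #[test]
    fn test_grid_stride_pattern() {
        let mut data = [1.0f32, 2.0, 3.0, 4.0];
        let mut i = (block_idx_x() * block_dim_x() + thread_idx_x()) as usize;
        let stride = (grid_dim_x() * block_dim_x()) as usize;
        while i < data.len() {
            data[i] *= 2.0;
            i += stride;
        }
        assert_eq!(data, [2.0, 4.0, 6.0, 8.0]);
    }
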
    #[test]
    fn test_math_functions() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-6);
        assert!((rsqrt(4.0) - 0.5).abs() < 1e-6);
        assert_eq!(floor(3.7), 3.0);
        assert_eq!(ceil(3.2), 4.0);
        assert_eq!(round(3.5), 4.0);
        assert!((sin(0.0)).abs() < 1e-6);
        assert!((cos(0.0) - 1.0).abs() < 1e-6);
        assert!((tan(0.0)).abs() < 1e-6);
        assert!((exp(0.0) - 1.0).abs() < 1e-6);
        assert!((log(1.0)).abs() < 1e-6);
        assert!((fma(2.0, 3.0, 1.0) - 7.0).abs() < 1e-6);
        assert!((powf(2.0, 3.0) - 8.0).abs() < 1e-6);
    }
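
    // A sketch of how the math markers compose: normalize a 3-vector using
    // `fma` to accumulate the squared length and `rsqrt` for the inverse
    // length. The (3, 4, 12) values are illustrative, chosen so |v| = 13.
    #[test]
    fn test_math_composition() {
        let (x, y, z) = (3.0f32, 4.0, 12.0);
        let inv_len = rsqrt(fma(x, x, fma(y, y, z * z)));
        let (nx, ny, nz) = (x * inv_len, y * inv_len, z * inv_len);
        let len = sqrt(fma(nx, nx, fma(ny, ny, nz * nz)));
        assert!((len - 1.0).abs() < 1e-6);
    }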
}