// ringkernel_codegen/dsl_common.rs

//! Common DSL marker functions shared across GPU backends.
//!
//! These functions provide CPU-fallback implementations of GPU intrinsics.
//! During transpilation, the CUDA and WGSL code generators recognize these
//! function names and replace them with their backend-specific equivalents.
//!
//! # Thread/Block Indices
//!
//! | Function | CUDA | WGSL |
//! |----------|------|------|
//! | `thread_idx_x()` | `threadIdx.x` | `local_invocation_id.x` |
//! | `block_idx_x()` | `blockIdx.x` | `workgroup_id.x` |
//! | `block_dim_x()` | `blockDim.x` | `WORKGROUP_SIZE_X` |
//! | `grid_dim_x()` | `gridDim.x` | `num_workgroups.x` |
//!
//! # Math Functions
//!
//! All math functions use Rust's standard library for CPU fallback:
//! `sqrt`, `rsqrt`, `floor`, `ceil`, `round`, `sin`, `cos`, `tan`,
//! `exp`, `log`, `fma`, `powf`.
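//!
//! # Example
//!
//! Kernel bodies are written against these markers, and the code generators
//! substitute the backend intrinsics at transpile time. A minimal sketch
//! (the kernel name and buffer layout are illustrative, not part of this
//! module):
//!
//! ```ignore
//! fn scale_kernel(data: &mut [f32], factor: f32) {
//!     let gid = (block_idx_x() * block_dim_x() + thread_idx_x()) as usize;
//!     if gid < data.len() {
//!         data[gid] *= factor;
//!     }
//! }
//! ```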

// =============================================================================
// Thread/Block Index Functions
// =============================================================================

/// Get the thread index within a block/workgroup (x dimension).
///
/// - CUDA: `threadIdx.x`
/// - WGSL: `local_invocation_id.x`
#[inline(always)]
pub fn thread_idx_x() -> i32 {
    0 // CPU fallback: single-threaded execution
}

/// Get the thread index within a block/workgroup (y dimension).
///
/// - CUDA: `threadIdx.y`
/// - WGSL: `local_invocation_id.y`
#[inline(always)]
pub fn thread_idx_y() -> i32 {
    0
}

/// Get the thread index within a block/workgroup (z dimension).
///
/// - CUDA: `threadIdx.z`
/// - WGSL: `local_invocation_id.z`
#[inline(always)]
pub fn thread_idx_z() -> i32 {
    0
}

/// Get the block/workgroup index (x dimension).
///
/// - CUDA: `blockIdx.x`
/// - WGSL: `workgroup_id.x`
#[inline(always)]
pub fn block_idx_x() -> i32 {
    0
}

/// Get the block/workgroup index (y dimension).
///
/// - CUDA: `blockIdx.y`
/// - WGSL: `workgroup_id.y`
#[inline(always)]
pub fn block_idx_y() -> i32 {
    0
}

/// Get the block/workgroup index (z dimension).
///
/// - CUDA: `blockIdx.z`
/// - WGSL: `workgroup_id.z`
#[inline(always)]
pub fn block_idx_z() -> i32 {
    0
}

/// Get the block/workgroup dimension (x dimension).
///
/// - CUDA: `blockDim.x`
/// - WGSL: `WORKGROUP_SIZE_X`
#[inline(always)]
pub fn block_dim_x() -> i32 {
    1 // CPU fallback: one thread per block
}

/// Get the block/workgroup dimension (y dimension).
///
/// - CUDA: `blockDim.y`
/// - WGSL: `WORKGROUP_SIZE_Y`
#[inline(always)]
pub fn block_dim_y() -> i32 {
    1
}

/// Get the block/workgroup dimension (z dimension).
///
/// - CUDA: `blockDim.z`
/// - WGSL: `WORKGROUP_SIZE_Z`
#[inline(always)]
pub fn block_dim_z() -> i32 {
    1
}

/// Get the grid/dispatch dimension (x dimension).
///
/// - CUDA: `gridDim.x`
/// - WGSL: `num_workgroups.x`
#[inline(always)]
pub fn grid_dim_x() -> i32 {
    1 // CPU fallback: one block in the grid
}

/// Get the grid/dispatch dimension (y dimension).
///
/// - CUDA: `gridDim.y`
/// - WGSL: `num_workgroups.y`
#[inline(always)]
pub fn grid_dim_y() -> i32 {
    1
}

/// Get the grid/dispatch dimension (z dimension).
///
/// - CUDA: `gridDim.z`
/// - WGSL: `num_workgroups.z`
#[inline(always)]
pub fn grid_dim_z() -> i32 {
    1
}
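
// The canonical 1-D global index combines the markers above:
//
//     let gid = block_idx_x() * block_dim_x() + thread_idx_x();
//
// With the CPU-fallback values this is always 0 * 1 + 0 = 0, so on the host
// a kernel body runs as if it were thread 0 of a 1x1x1 launch.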

// =============================================================================
// Synchronization Functions
// =============================================================================

/// Synchronize all threads in a block/workgroup.
///
/// - CUDA: `__syncthreads()`
/// - WGSL: `workgroupBarrier()`
#[inline(always)]
pub fn sync_threads() {
    // CPU fallback: no-op (single-threaded)
}
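
// Note: `sync_threads` is an execution barrier (every thread in the group
// reaches it before any thread continues), while the fences below only
// constrain memory ordering. For the single-threaded CPU fallback a no-op
// barrier is sound; the fences are still emitted as real atomic fences so
// their ordering guarantees hold even if the fallback code is ever driven
// from more than one host thread.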

/// Thread/storage memory fence.
///
/// - CUDA: `__threadfence()`
/// - WGSL: `storageBarrier()`
#[inline(always)]
pub fn thread_fence() {
    std::sync::atomic::fence(std::sync::atomic::Ordering::SeqCst);
}

/// Block-level/workgroup memory fence.
///
/// - CUDA: `__threadfence_block()`
/// - WGSL: `workgroupBarrier()`
#[inline(always)]
pub fn thread_fence_block() {
    // AcqRel orders both prior and subsequent loads and stores, matching
    // the two-sided semantics of a block-scoped fence (a bare Release
    // fence would only order prior writes).
    std::sync::atomic::fence(std::sync::atomic::Ordering::AcqRel);
}

// =============================================================================
// Math Functions (f32)
// =============================================================================

/// Square root.
///
/// - CUDA: `sqrtf(x)`
/// - WGSL: `sqrt(x)`
#[inline(always)]
pub fn sqrt(x: f32) -> f32 {
    x.sqrt()
}

/// Reciprocal square root.
///
/// - CUDA: `rsqrtf(x)`
/// - WGSL: `inverseSqrt(x)`
#[inline(always)]
pub fn rsqrt(x: f32) -> f32 {
    1.0 / x.sqrt()
}

/// Floor.
///
/// - CUDA: `floorf(x)`
/// - WGSL: `floor(x)`
#[inline(always)]
pub fn floor(x: f32) -> f32 {
    x.floor()
}

/// Ceiling.
///
/// - CUDA: `ceilf(x)`
/// - WGSL: `ceil(x)`
#[inline(always)]
pub fn ceil(x: f32) -> f32 {
    x.ceil()
}

/// Round to nearest.
///
/// - CUDA: `roundf(x)`
/// - WGSL: `round(x)`
#[inline(always)]
pub fn round(x: f32) -> f32 {
    x.round()
}

/// Sine.
///
/// - CUDA: `sinf(x)`
/// - WGSL: `sin(x)`
#[inline(always)]
pub fn sin(x: f32) -> f32 {
    x.sin()
}

/// Cosine.
///
/// - CUDA: `cosf(x)`
/// - WGSL: `cos(x)`
#[inline(always)]
pub fn cos(x: f32) -> f32 {
    x.cos()
}

/// Tangent.
///
/// - CUDA: `tanf(x)`
/// - WGSL: `tan(x)`
#[inline(always)]
pub fn tan(x: f32) -> f32 {
    x.tan()
}

/// Exponential (base e).
///
/// - CUDA: `expf(x)`
/// - WGSL: `exp(x)`
#[inline(always)]
pub fn exp(x: f32) -> f32 {
    x.exp()
}

/// Natural logarithm (base e).
///
/// - CUDA: `logf(x)`
/// - WGSL: `log(x)`
#[inline(always)]
pub fn log(x: f32) -> f32 {
    x.ln()
}

/// Fused multiply-add: `a * b + c`.
///
/// - CUDA: `fmaf(a, b, c)`
/// - WGSL: `fma(a, b, c)`
#[inline(always)]
pub fn fma(a: f32, b: f32, c: f32) -> f32 {
    a.mul_add(b, c)
}

/// Power: `x^y`.
///
/// - CUDA: `powf(x, y)`
/// - WGSL: `pow(x, y)`
#[inline(always)]
pub fn powf(x: f32, y: f32) -> f32 {
    x.powf(y)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_indices() {
        assert_eq!(thread_idx_x(), 0);
        assert_eq!(thread_idx_y(), 0);
        assert_eq!(thread_idx_z(), 0);
    }

    #[test]
    fn test_block_indices() {
        assert_eq!(block_idx_x(), 0);
        assert_eq!(block_idx_y(), 0);
        assert_eq!(block_idx_z(), 0);
    }

    #[test]
    fn test_dimensions() {
        assert_eq!(block_dim_x(), 1);
        assert_eq!(block_dim_y(), 1);
        assert_eq!(block_dim_z(), 1);
        assert_eq!(grid_dim_x(), 1);
        assert_eq!(grid_dim_y(), 1);
        assert_eq!(grid_dim_z(), 1);
    }
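
    // Illustrative check: the canonical global-index formula from the
    // module docs. With the CPU-fallback constants it always evaluates
    // to 0 (thread 0 of a 1x1x1 launch).
    #[test]
    fn test_global_index_formula() {
        let gid = block_idx_x() * block_dim_x() + thread_idx_x();
        assert_eq!(gid, 0);
    }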

    #[test]
    fn test_sync_compiles() {
        sync_threads();
        thread_fence();
        thread_fence_block();
    }

    #[test]
    fn test_math_functions() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-6);
        assert!((rsqrt(4.0) - 0.5).abs() < 1e-6);
        assert_eq!(floor(3.7), 3.0);
        assert_eq!(ceil(3.2), 4.0);
        assert_eq!(round(3.5), 4.0);
        assert!((sin(0.0)).abs() < 1e-6);
        assert!((cos(0.0) - 1.0).abs() < 1e-6);
        assert!((tan(0.0)).abs() < 1e-6);
        assert!((exp(0.0) - 1.0).abs() < 1e-6);
        assert!((log(1.0)).abs() < 1e-6);
        assert!((fma(2.0, 3.0, 1.0) - 7.0).abs() < 1e-6);
        assert!((powf(2.0, 3.0) - 8.0).abs() < 1e-6);
    }
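
    // Illustrative sketch: composing `fma` and `rsqrt` as one would on the
    // GPU to normalize the x component of the vector (3, 4). Note that the
    // CPU fallback computes rsqrt exactly, while GPU rsqrt intrinsics may
    // be approximate.
    #[test]
    fn test_fma_rsqrt_compose() {
        let len_sq = fma(3.0, 3.0, 4.0 * 4.0); // 3*3 + 4*4 = 25
        let nx = 3.0 * rsqrt(len_sq);          // 3 / 5 = 0.6
        assert!((nx - 0.6).abs() < 1e-6);
    }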
}