ringkernel_cuda_codegen/
dsl.rs

1//! Rust DSL functions for writing CUDA kernels.
2//!
3//! This module provides Rust functions that map to CUDA intrinsics during transpilation.
4//! These functions have CPU fallback implementations for testing but are transpiled
5//! to the corresponding CUDA operations when used in kernel code.
6//!
7//! # Thread/Block Index Access
8//!
9//! ```ignore
10//! use ringkernel_cuda_codegen::dsl::*;
11//!
12//! fn my_kernel(...) {
13//!     let tx = thread_idx_x();  // -> threadIdx.x
14//!     let bx = block_idx_x();   // -> blockIdx.x
15//!     let idx = bx * block_dim_x() + tx;  // Global thread index
16//! }
17//! ```
18//!
19//! # Thread Synchronization
20//!
21//! ```ignore
22//! sync_threads();  // -> __syncthreads()
23//! ```
24//!
25//! # Math Functions
26//!
27//! All standard math functions are available with CPU fallbacks:
28//! - Trigonometric: sin, cos, tan, asin, acos, atan, atan2
29//! - Hyperbolic: sinh, cosh, tanh, asinh, acosh, atanh
30//! - Exponential: exp, exp2, exp10, expm1, log, log2, log10, log1p
31//! - Power: pow, sqrt, rsqrt, cbrt
32//! - Rounding: floor, ceil, round, trunc
33//! - Comparison: fmin, fmax, fdim, copysign
34//!
35//! # Warp Operations
36//!
37//! ```ignore
38//! let mask = warp_active_mask();       // Get active lane mask
39//! let result = warp_reduce_add(mask, value);  // Warp-level sum
40//! let shuffled = warp_shfl(mask, value, lane); // Shuffle
41//! ```
42//!
43//! # Bit Manipulation
44//!
45//! ```ignore
46//! let bits = popc(x);      // Count set bits
47//! let zeros = clz(x);      // Count leading zeros
48//! let rev = brev(x);       // Reverse bits
49//! ```
50
51use std::sync::atomic::{fence, Ordering};
52
53// ============================================================================
54// Thread/Block Index Functions
55// ============================================================================
56
/// Thread index along the x dimension of the enclosing block.
///
/// Transpiles to: `threadIdx.x`
#[inline]
pub fn thread_idx_x() -> i32 {
    // CPU fallback: execution is single-threaded, so the index is fixed.
    const CPU_THREAD_IDX_X: i32 = 0;
    CPU_THREAD_IDX_X
}
63
/// Thread index along the y dimension of the enclosing block.
///
/// Transpiles to: `threadIdx.y`
#[inline]
pub fn thread_idx_y() -> i32 {
    // CPU fallback always reports the first (and only) position.
    const CPU_THREAD_IDX_Y: i32 = 0;
    CPU_THREAD_IDX_Y
}
70
/// Thread index along the z dimension of the enclosing block.
///
/// Transpiles to: `threadIdx.z`
#[inline]
pub fn thread_idx_z() -> i32 {
    // CPU fallback always reports the first (and only) position.
    const CPU_THREAD_IDX_Z: i32 = 0;
    CPU_THREAD_IDX_Z
}
77
/// Block index along the x dimension of the grid.
///
/// Transpiles to: `blockIdx.x`
#[inline]
pub fn block_idx_x() -> i32 {
    // CPU fallback always reports block 0.
    const CPU_BLOCK_IDX_X: i32 = 0;
    CPU_BLOCK_IDX_X
}
84
/// Block index along the y dimension of the grid.
///
/// Transpiles to: `blockIdx.y`
#[inline]
pub fn block_idx_y() -> i32 {
    // CPU fallback always reports block 0.
    const CPU_BLOCK_IDX_Y: i32 = 0;
    CPU_BLOCK_IDX_Y
}
91
/// Block index along the z dimension of the grid.
///
/// Transpiles to: `blockIdx.z`
#[inline]
pub fn block_idx_z() -> i32 {
    // CPU fallback always reports block 0.
    const CPU_BLOCK_IDX_Z: i32 = 0;
    CPU_BLOCK_IDX_Z
}
98
/// Number of threads per block along the x dimension.
///
/// Transpiles to: `blockDim.x`
#[inline]
pub fn block_dim_x() -> i32 {
    // CPU fallback reports an extent of one.
    const CPU_BLOCK_DIM_X: i32 = 1;
    CPU_BLOCK_DIM_X
}
105
/// Number of threads per block along the y dimension.
///
/// Transpiles to: `blockDim.y`
#[inline]
pub fn block_dim_y() -> i32 {
    // CPU fallback reports an extent of one.
    const CPU_BLOCK_DIM_Y: i32 = 1;
    CPU_BLOCK_DIM_Y
}
112
/// Number of threads per block along the z dimension.
///
/// Transpiles to: `blockDim.z`
#[inline]
pub fn block_dim_z() -> i32 {
    // CPU fallback reports an extent of one.
    const CPU_BLOCK_DIM_Z: i32 = 1;
    CPU_BLOCK_DIM_Z
}
119
/// Number of blocks in the grid along the x dimension.
///
/// Transpiles to: `gridDim.x`
#[inline]
pub fn grid_dim_x() -> i32 {
    // CPU fallback reports an extent of one.
    const CPU_GRID_DIM_X: i32 = 1;
    CPU_GRID_DIM_X
}
126
/// Number of blocks in the grid along the y dimension.
///
/// Transpiles to: `gridDim.y`
#[inline]
pub fn grid_dim_y() -> i32 {
    // CPU fallback reports an extent of one.
    const CPU_GRID_DIM_Y: i32 = 1;
    CPU_GRID_DIM_Y
}
133
/// Number of blocks in the grid along the z dimension.
///
/// Transpiles to: `gridDim.z`
#[inline]
pub fn grid_dim_z() -> i32 {
    // CPU fallback reports an extent of one.
    const CPU_GRID_DIM_Z: i32 = 1;
    CPU_GRID_DIM_Z
}
140
/// Number of lanes in a warp (always 32 on NVIDIA GPUs).
///
/// Transpiles to: `warpSize`
#[inline]
pub fn warp_size() -> i32 {
    const LANES_PER_WARP: i32 = 32;
    LANES_PER_WARP
}
147
148// ============================================================================
149// Synchronization Functions
150// ============================================================================
151
/// Block-wide barrier: wait until every thread in the block arrives.
///
/// Transpiles to: `__syncthreads()`
#[inline]
pub fn sync_threads() {
    // Intentionally empty: the CPU fallback is single-threaded, so there is
    // nobody else to wait for.
}
158
/// Barrier that also counts the threads whose predicate is true.
///
/// Transpiles to: `__syncthreads_count(predicate)`
///
/// The CPU fallback only sees its own predicate, so the count is 1 or 0.
#[inline]
pub fn sync_threads_count(predicate: bool) -> i32 {
    i32::from(predicate)
}
169
/// Barrier that ANDs the predicate across the block.
///
/// Transpiles to: `__syncthreads_and(predicate)`
///
/// The CPU fallback only sees its own predicate: returns 1 when it is true,
/// 0 otherwise.
#[inline]
pub fn sync_threads_and(predicate: bool) -> i32 {
    i32::from(predicate)
}
180
/// Barrier that ORs the predicate across the block.
///
/// Transpiles to: `__syncthreads_or(predicate)`
///
/// The CPU fallback only sees its own predicate: returns 1 when it is true,
/// 0 otherwise.
#[inline]
pub fn sync_threads_or(predicate: bool) -> i32 {
    i32::from(predicate)
}
191
/// Device-wide memory fence.
///
/// Transpiles to: `__threadfence()`
///
/// The CPU fallback issues a sequentially consistent fence.
#[inline]
pub fn thread_fence() {
    let ordering = Ordering::SeqCst;
    fence(ordering);
}
198
/// Block-level memory fence.
///
/// Transpiles to: `__threadfence_block()`
///
/// The CUDA programming guide describes `__threadfence_block()` as a
/// sequentially consistent fence at block scope, ordering both the reads and
/// the writes issued before the call against those issued after it. The CPU
/// fallback therefore uses `SeqCst`; a `Release` fence alone would not order
/// earlier loads against later loads.
#[inline]
pub fn thread_fence_block() {
    fence(Ordering::SeqCst);
}
205
/// System-wide memory fence.
///
/// Transpiles to: `__threadfence_system()`
///
/// The CPU fallback issues a sequentially consistent fence.
#[inline]
pub fn thread_fence_system() {
    let ordering = Ordering::SeqCst;
    fence(ordering);
}
212
213// ============================================================================
214// Atomic Operations (CPU fallbacks - not thread-safe!)
215// ============================================================================
216
/// Atomic add. Transpiles to: `atomicAdd(addr, val)`
///
/// Adds `val` to `*addr` and returns the value that was stored before the
/// addition. The sum wraps around on overflow (two's complement) so that
/// debug and release builds behave the same instead of panicking in debug.
///
/// WARNING: CPU fallback is NOT thread-safe!
#[inline]
pub fn atomic_add(addr: &mut i32, val: i32) -> i32 {
    let previous = *addr;
    *addr = previous.wrapping_add(val);
    previous
}
225
/// Atomic add for f32. Transpiles to: `atomicAdd(addr, val)`
///
/// Stores `*addr + val` back into `*addr` and returns the previous value.
/// The CPU fallback is a plain read-modify-write.
#[inline]
pub fn atomic_add_f32(addr: &mut f32, val: f32) -> f32 {
    let previous = *addr;
    let updated = previous + val;
    std::mem::replace(addr, updated)
}
233
/// Atomic subtract. Transpiles to: `atomicSub(addr, val)`
///
/// Subtracts `val` from `*addr` and returns the value that was stored before
/// the subtraction. The difference wraps around on overflow (two's
/// complement) so that debug and release builds behave the same instead of
/// panicking in debug. The CPU fallback is a plain read-modify-write.
#[inline]
pub fn atomic_sub(addr: &mut i32, val: i32) -> i32 {
    let previous = *addr;
    *addr = previous.wrapping_sub(val);
    previous
}
241
/// Atomic minimum. Transpiles to: `atomicMin(addr, val)`
///
/// Leaves the smaller of `*addr` and `val` in `*addr` and returns the
/// previous value.
#[inline]
pub fn atomic_min(addr: &mut i32, val: i32) -> i32 {
    let previous = *addr;
    if val < previous {
        *addr = val;
    }
    previous
}
249
/// Atomic maximum. Transpiles to: `atomicMax(addr, val)`
///
/// Leaves the larger of `*addr` and `val` in `*addr` and returns the
/// previous value.
#[inline]
pub fn atomic_max(addr: &mut i32, val: i32) -> i32 {
    let previous = *addr;
    if val > previous {
        *addr = val;
    }
    previous
}
257
/// Atomic exchange. Transpiles to: `atomicExch(addr, val)`
///
/// Stores `val` into `*addr` and returns the value it replaced.
#[inline]
pub fn atomic_exchange(addr: &mut i32, val: i32) -> i32 {
    std::mem::replace(addr, val)
}
265
/// Atomic compare and swap. Transpiles to: `atomicCAS(addr, compare, val)`
///
/// Writes `val` to `*addr` only when the current value equals `compare`.
/// Always returns the value observed before the call.
#[inline]
pub fn atomic_cas(addr: &mut i32, compare: i32, val: i32) -> i32 {
    match *addr {
        current if current == compare => std::mem::replace(addr, val),
        current => current,
    }
}
275
/// Atomic AND. Transpiles to: `atomicAnd(addr, val)`
///
/// Stores `*addr & val` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_and(addr: &mut i32, val: i32) -> i32 {
    let masked = *addr & val;
    std::mem::replace(addr, masked)
}
283
/// Atomic OR. Transpiles to: `atomicOr(addr, val)`
///
/// Stores `*addr | val` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_or(addr: &mut i32, val: i32) -> i32 {
    let merged = *addr | val;
    std::mem::replace(addr, merged)
}
291
/// Atomic XOR. Transpiles to: `atomicXor(addr, val)`
///
/// Stores `*addr ^ val` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_xor(addr: &mut i32, val: i32) -> i32 {
    let toggled = *addr ^ val;
    std::mem::replace(addr, toggled)
}
299
/// Atomic increment with wrap. Transpiles to: `atomicInc(addr, val)`
///
/// Increments `*addr` while it is below `val`; once the stored value has
/// reached (or exceeds) `val` it wraps back to 0. Returns the previous value.
#[inline]
pub fn atomic_inc(addr: &mut u32, val: u32) -> u32 {
    let previous = *addr;
    // `previous < val` guarantees `previous + 1` cannot overflow.
    let next = if previous < val { previous + 1 } else { 0 };
    std::mem::replace(addr, next)
}
307
/// Atomic decrement with wrap. Transpiles to: `atomicDec(addr, val)`
///
/// Decrements `*addr` while it lies in `1..=val`; a stored value of 0 or one
/// above `val` is reset to `val`. Returns the previous value.
#[inline]
pub fn atomic_dec(addr: &mut u32, val: u32) -> u32 {
    let previous = *addr;
    let next = if previous != 0 && previous <= val {
        previous - 1
    } else {
        val
    };
    std::mem::replace(addr, next)
}
315
316// ============================================================================
317// Basic Math Functions
318// ============================================================================
319
/// Square root of `x`. Transpiles to: `sqrtf(x)`
#[inline]
pub fn sqrt(x: f32) -> f32 {
    f32::sqrt(x)
}
325
/// Reciprocal of the square root of `x`. Transpiles to: `rsqrtf(x)`
///
/// The CPU fallback takes the square root and then its reciprocal.
#[inline]
pub fn rsqrt(x: f32) -> f32 {
    let root = x.sqrt();
    root.recip()
}
331
/// Absolute value of an f32. Transpiles to: `fabsf(x)`
#[inline]
pub fn fabs(x: f32) -> f32 {
    f32::abs(x)
}
337
/// Largest integer value not greater than `x`. Transpiles to: `floorf(x)`
#[inline]
pub fn floor(x: f32) -> f32 {
    f32::floor(x)
}
343
/// Smallest integer value not less than `x`. Transpiles to: `ceilf(x)`
#[inline]
pub fn ceil(x: f32) -> f32 {
    f32::ceil(x)
}
349
/// Round to the nearest integer value, with halfway cases rounded away from
/// zero. Transpiles to: `roundf(x)`
#[inline]
pub fn round(x: f32) -> f32 {
    f32::round(x)
}
355
/// Drop the fractional part of `x` (round toward zero).
/// Transpiles to: `truncf(x)`
#[inline]
pub fn trunc(x: f32) -> f32 {
    f32::trunc(x)
}
361
/// Fused multiply-add: `a * b + c` with a single rounding.
/// Transpiles to: `fmaf(a, b, c)`
#[inline]
pub fn fma(a: f32, b: f32, c: f32) -> f32 {
    f32::mul_add(a, b, c)
}
367
/// Smaller of two values. Transpiles to: `fminf(a, b)`
///
/// When exactly one argument is NaN the other argument is returned.
#[inline]
pub fn fmin(a: f32, b: f32) -> f32 {
    f32::min(a, b)
}
373
/// Larger of two values. Transpiles to: `fmaxf(a, b)`
///
/// When exactly one argument is NaN the other argument is returned.
#[inline]
pub fn fmax(a: f32, b: f32) -> f32 {
    f32::max(a, b)
}
379
/// Floating-point remainder of `x / y`, carrying the sign of `x`.
/// Transpiles to: `fmodf(x, y)`
#[inline]
pub fn fmod(x: f32, y: f32) -> f32 {
    std::ops::Rem::rem(x, y)
}
385
/// IEEE-style remainder. Transpiles to: `remainderf(x, y)`
///
/// Returns `x - n * y`, where `n` is the integer nearest to `x / y` and
/// halfway cases pick the *even* integer (e.g. `remainder(5.0, 2.0)` is
/// `1.0`, not `-1.0`).
///
/// Special cases: a finite `x` with an infinite `y` yields `x`; `y == 0`,
/// an infinite `x`, or a NaN operand yield NaN.
///
/// The quotient is formed in `f32`, so results lose accuracy when `x / y`
/// is too large to be represented exactly.
#[inline]
pub fn remainder(x: f32, y: f32) -> f32 {
    // x / inf == 0 exactly, but 0 * inf would poison the result with NaN.
    if y.is_infinite() && x.is_finite() {
        return x;
    }

    let quotient = x / y;
    let mut nearest = quotient.round();
    // `round` breaks ties away from zero; step back toward zero when that
    // landed on an odd integer so ties resolve to the even neighbour.
    let is_tie = (quotient - quotient.trunc()).abs() == 0.5;
    if is_tie && nearest % 2.0 != 0.0 {
        nearest -= quotient.signum();
    }
    x - nearest * y
}
391
/// Magnitude of `x` combined with the sign of `y`.
/// Transpiles to: `copysignf(x, y)`
#[inline]
pub fn copysign(x: f32, y: f32) -> f32 {
    f32::copysign(x, y)
}
397
/// Cube root of `x`. Transpiles to: `cbrtf(x)`
#[inline]
pub fn cbrt(x: f32) -> f32 {
    f32::cbrt(x)
}
403
/// Length of the hypotenuse of a right triangle with legs `x` and `y`.
/// Transpiles to: `hypotf(x, y)`
#[inline]
pub fn hypot(x: f32, y: f32) -> f32 {
    f32::hypot(x, y)
}
409
410// ============================================================================
411// Trigonometric Functions
412// ============================================================================
413
/// Sine of `x` (radians). Transpiles to: `sinf(x)`
#[inline]
pub fn sin(x: f32) -> f32 {
    f32::sin(x)
}
419
/// Cosine of `x` (radians). Transpiles to: `cosf(x)`
#[inline]
pub fn cos(x: f32) -> f32 {
    f32::cos(x)
}
425
/// Tangent of `x` (radians). Transpiles to: `tanf(x)`
#[inline]
pub fn tan(x: f32) -> f32 {
    f32::tan(x)
}
431
/// Inverse sine of `x`, in radians. Transpiles to: `asinf(x)`
#[inline]
pub fn asin(x: f32) -> f32 {
    f32::asin(x)
}
437
/// Inverse cosine of `x`, in radians. Transpiles to: `acosf(x)`
#[inline]
pub fn acos(x: f32) -> f32 {
    f32::acos(x)
}
443
/// Inverse tangent of `x`, in radians. Transpiles to: `atanf(x)`
#[inline]
pub fn atan(x: f32) -> f32 {
    f32::atan(x)
}
449
/// Four-quadrant inverse tangent of `y / x`, in radians.
/// Transpiles to: `atan2f(y, x)`
#[inline]
pub fn atan2(y: f32, x: f32) -> f32 {
    f32::atan2(y, x)
}
455
/// Sine and cosine of `x` in one call, returned as `(sin, cos)`.
/// Transpiles to: `sincosf(x, &s, &c)`
#[inline]
pub fn sincos(x: f32) -> (f32, f32) {
    let sine = x.sin();
    let cosine = x.cos();
    (sine, cosine)
}
461
/// Sine of `pi * x`. Transpiles to: `sinpif(x)`
///
/// The CPU fallback multiplies by the `f32` constant for pi first, so it
/// inherits that constant's rounding error (e.g. `sinpi(1.0)` is tiny but
/// not exactly zero).
#[inline]
pub fn sinpi(x: f32) -> f32 {
    let radians = x * std::f32::consts::PI;
    radians.sin()
}
467
/// Cosine of `pi * x`. Transpiles to: `cospif(x)`
///
/// The CPU fallback multiplies by the `f32` constant for pi first, so it
/// inherits that constant's rounding error.
#[inline]
pub fn cospi(x: f32) -> f32 {
    let radians = x * std::f32::consts::PI;
    radians.cos()
}
473
474// ============================================================================
475// Hyperbolic Functions
476// ============================================================================
477
/// Hyperbolic sine of `x`. Transpiles to: `sinhf(x)`
#[inline]
pub fn sinh(x: f32) -> f32 {
    f32::sinh(x)
}
483
/// Hyperbolic cosine of `x`. Transpiles to: `coshf(x)`
#[inline]
pub fn cosh(x: f32) -> f32 {
    f32::cosh(x)
}
489
/// Hyperbolic tangent of `x`. Transpiles to: `tanhf(x)`
#[inline]
pub fn tanh(x: f32) -> f32 {
    f32::tanh(x)
}
495
/// Inverse hyperbolic sine of `x`. Transpiles to: `asinhf(x)`
#[inline]
pub fn asinh(x: f32) -> f32 {
    f32::asinh(x)
}
501
/// Inverse hyperbolic cosine of `x`. Transpiles to: `acoshf(x)`
#[inline]
pub fn acosh(x: f32) -> f32 {
    f32::acosh(x)
}
507
/// Inverse hyperbolic tangent of `x`. Transpiles to: `atanhf(x)`
#[inline]
pub fn atanh(x: f32) -> f32 {
    f32::atanh(x)
}
513
514// ============================================================================
515// Exponential and Logarithmic Functions
516// ============================================================================
517
/// `e` raised to the power `x`. Transpiles to: `expf(x)`
#[inline]
pub fn exp(x: f32) -> f32 {
    f32::exp(x)
}
523
/// 2 raised to the power `x`. Transpiles to: `exp2f(x)`
#[inline]
pub fn exp2(x: f32) -> f32 {
    f32::exp2(x)
}
529
/// 10 raised to the power `x`. Transpiles to: `exp10f(x)`
///
/// The CPU fallback evaluates `exp(x * ln 10)`, so results may differ from
/// an exact power of ten by a few units in the last place.
#[inline]
pub fn exp10(x: f32) -> f32 {
    let natural_exponent = x * std::f32::consts::LN_10;
    natural_exponent.exp()
}
535
/// `exp(x) - 1`, computed so that it stays accurate for small `x`.
/// Transpiles to: `expm1f(x)`
#[inline]
pub fn expm1(x: f32) -> f32 {
    f32::exp_m1(x)
}
541
/// Natural (base e) logarithm of `x`. Transpiles to: `logf(x)`
#[inline]
pub fn log(x: f32) -> f32 {
    f32::ln(x)
}
547
/// Base-2 logarithm of `x`. Transpiles to: `log2f(x)`
#[inline]
pub fn log2(x: f32) -> f32 {
    f32::log2(x)
}
553
/// Base-10 logarithm of `x`. Transpiles to: `log10f(x)`
#[inline]
pub fn log10(x: f32) -> f32 {
    f32::log10(x)
}
559
/// `ln(1 + x)`, computed so that it stays accurate for small `x`.
/// Transpiles to: `log1pf(x)`
#[inline]
pub fn log1p(x: f32) -> f32 {
    f32::ln_1p(x)
}
565
/// `x` raised to the floating-point power `y`. Transpiles to: `powf(x, y)`
#[inline]
pub fn pow(x: f32, y: f32) -> f32 {
    f32::powf(x, y)
}
571
/// Load exponent: `x * 2^exp`. Transpiles to: `ldexpf(x, exp)`
///
/// The scaling is carried out in `f64` and rounded to `f32` once at the end.
/// Forming `2^exp` in `f32` first would overflow to infinity or underflow to
/// zero for large `|exp|` even when the final product is representable
/// (e.g. `ldexp(f32::MIN_POSITIVE, 130)` is `16.0`).
///
/// Results beyond the `f32` range become infinity; results below the
/// smallest subnormal become zero.
#[inline]
pub fn ldexp(x: f32, exp: i32) -> f32 {
    // Any finite non-zero f32 lies in [2^-149, 2^128), so exponents beyond
    // +/-400 already saturate to inf / 0. Clamping keeps 2^e exactly
    // representable (and finite) in f64.
    let bounded = exp.clamp(-400, 400);
    (f64::from(x) * 2.0_f64.powi(bounded)) as f32
}
577
/// Scale by a power of two: `x * 2^n`. Transpiles to: `scalbnf(x, n)`
///
/// The scaling is carried out in `f64` and rounded to `f32` once at the end.
/// Forming `2^n` in `f32` first would overflow to infinity or underflow to
/// zero for large `|n|` even when the final product is representable
/// (e.g. `scalbn(f32::MIN_POSITIVE, 130)` is `16.0`).
///
/// Results beyond the `f32` range become infinity; results below the
/// smallest subnormal become zero.
#[inline]
pub fn scalbn(x: f32, n: i32) -> f32 {
    // Any finite non-zero f32 lies in [2^-149, 2^128), so exponents beyond
    // +/-400 already saturate to inf / 0. Clamping keeps 2^n exactly
    // representable (and finite) in f64.
    let bounded = n.clamp(-400, 400);
    (f64::from(x) * 2.0_f64.powi(bounded)) as f32
}
583
/// Extract the unbiased binary exponent of `x` as an integer.
/// Transpiles to: `ilogbf(x)`
///
/// Special cases (following the CUDA math API description of `ilogbf`):
/// zero and NaN return `i32::MIN`, infinities return `i32::MAX`.
///
/// The exponent is read directly from the bit pattern rather than computed
/// as `floor(log2(|x|))`; the logarithm rounds up to the next integer for
/// values just below a large power of two and so reports an exponent that
/// is one too high. Subnormal inputs are handled by locating the highest set
/// mantissa bit.
#[inline]
pub fn ilogb(x: f32) -> i32 {
    if x == 0.0 || x.is_nan() {
        return i32::MIN;
    }
    if x.is_infinite() {
        return i32::MAX;
    }

    let magnitude_bits = x.to_bits() & 0x7FFF_FFFF;
    let biased_exponent = (magnitude_bits >> 23) as i32;
    if biased_exponent != 0 {
        // Normal number: remove the IEEE-754 single-precision bias.
        biased_exponent - 127
    } else {
        // Subnormal: value = mantissa * 2^-149 with a non-zero mantissa, so
        // the exponent is the index of the highest set mantissa bit - 149.
        let highest_bit = 31 - magnitude_bits.leading_zeros() as i32;
        highest_bit - 149
    }
}
595
/// Error function. Transpiles to: `erff(x)`
///
/// The CPU fallback is a five-term polynomial approximation in
/// `t = 1 / (1 + p * |x|)` multiplied by `exp(-x^2)`. The polynomial is
/// evaluated with Horner's scheme, and the result is mirrored for negative
/// inputs.
#[inline]
pub fn erf(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    // Highest-order coefficient first; the rest follow in descending order.
    const LEADING: f32 = 1.061_405_4;
    const LOWER_TERMS: [f32; 4] = [-1.453_152, 1.421_413_7, -0.284_496_74, 0.254_829_6];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let mut poly = LEADING;
    for &coeff in &LOWER_TERMS {
        poly = poly * t + coeff;
    }

    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    sign * y
}
613
/// Complementary error function, `1 - erf(x)`. Transpiles to: `erfcf(x)`
///
/// Uses the same polynomial approximation and coefficients as this module's
/// `erf` fallback, but evaluates the tail term `poly(t) * exp(-x^2)`
/// directly instead of computing `1.0 - erf(x)`. The subtraction cancels
/// catastrophically for large positive `x` (anything beyond `x ~ 4` collapses
/// to exactly `0.0`), whereas the direct form keeps a small positive value.
///
/// For negative `x` the identity `erfc(x) = 2 - erfc(-x)` is applied.
#[inline]
pub fn erfc(x: f32) -> f32 {
    const P: f32 = 0.327_591_1;
    // Polynomial coefficients, highest order first (Horner evaluation).
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152,
        1.421_413_7,
        -0.284_496_74,
        0.254_829_6,
    ];

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let poly = COEFFS.iter().fold(0.0_f32, |acc, &coeff| acc * t + coeff) * t;
    // This is erfc(|x|): the term `erf` subtracts from 1.
    let tail = poly * (-magnitude * magnitude).exp();

    if x < 0.0 {
        2.0 - tail
    } else {
        tail
    }
}
619
620// ============================================================================
621// Classification and Comparison Functions
622// ============================================================================
623
/// Returns `true` when `x` is NaN. Transpiles to: `isnan(x)`
#[inline]
pub fn is_nan(x: f32) -> bool {
    f32::is_nan(x)
}
629
/// Returns `true` when `x` is positive or negative infinity.
/// Transpiles to: `isinf(x)`
#[inline]
pub fn is_infinite(x: f32) -> bool {
    f32::is_infinite(x)
}
635
/// Returns `true` when `x` is neither infinite nor NaN.
/// Transpiles to: `isfinite(x)`
#[inline]
pub fn is_finite(x: f32) -> bool {
    !(x.is_nan() || x.is_infinite())
}
641
/// Returns `true` when `x` is a normal number, i.e. not zero, subnormal,
/// infinite, or NaN. Transpiles to: `isnormal(x)`
#[inline]
pub fn is_normal(x: f32) -> bool {
    matches!(x.classify(), std::num::FpCategory::Normal)
}
647
/// Returns `true` when the sign bit of `x` is set (this includes `-0.0`).
/// Transpiles to: `signbit(x)`
#[inline]
pub fn signbit(x: f32) -> bool {
    // The sign lives in the most significant bit of the IEEE-754 encoding.
    (x.to_bits() >> 31) == 1
}
653
/// Next representable `f32` after `x` in the direction of `y`.
/// Transpiles to: `nextafterf(x, y)`
///
/// * If either argument is NaN, NaN is returned.
/// * If `x == y`, `y` is returned.
/// * If `x` is zero (of either sign), the smallest subnormal carrying the
///   sign of the direction of travel is returned.
///
/// For non-zero `x` the bit pattern is stepped by one. Because IEEE-754
/// floats are sign-magnitude, moving *away from zero* increments the bits and
/// moving *toward zero* decrements them, regardless of the sign of `x`.
#[inline]
pub fn nextafter(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        return f32::NAN;
    }
    if x == y {
        return y;
    }
    if x == 0.0 {
        // Stepping off zero: smallest subnormal, signed toward `y`.
        return f32::from_bits(1).copysign(y);
    }

    let bits = x.to_bits();
    // `x` is non-zero here, so `bits - 1` cannot underflow; `x == y` was
    // handled above, so an infinite `x` can only step toward zero.
    let away_from_zero = (y > x) == (x > 0.0);
    let stepped = if away_from_zero { bits + 1 } else { bits - 1 };
    f32::from_bits(stepped)
}
665
/// Positive difference. Transpiles to: `fdimf(x, y)`
///
/// Returns `x - y` when `x > y` and `+0.0` otherwise. If either argument is
/// NaN the result is NaN (a bare `x > y` comparison would report `0.0`
/// instead, because every comparison with NaN is false).
#[inline]
pub fn fdim(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        f32::NAN
    } else if x > y {
        x - y
    } else {
        0.0
    }
}
675
676// ============================================================================
677// Warp-Level Operations
678// ============================================================================
679
/// Bit mask of the currently active lanes. Transpiles to: `__activemask()`
#[inline]
pub fn warp_active_mask() -> u32 {
    // CPU fallback: only one thread is active, reported as lane 0.
    const SINGLE_LANE_MASK: u32 = 1;
    SINGLE_LANE_MASK
}
685
/// Collect each lane's predicate into a bit mask.
/// Transpiles to: `__ballot_sync(mask, predicate)`
///
/// The CPU fallback ignores `_mask` and reports only its own vote in bit 0.
#[inline]
pub fn warp_ballot(_mask: u32, predicate: bool) -> u32 {
    u32::from(predicate)
}
695
/// True when the predicate holds on every participating lane.
/// Transpiles to: `__all_sync(mask, predicate)`
///
/// The CPU fallback ignores `_mask` and echoes its own predicate.
#[inline]
pub fn warp_all(_mask: u32, predicate: bool) -> bool {
    let unanimous = predicate;
    unanimous
}
701
/// True when the predicate holds on at least one participating lane.
/// Transpiles to: `__any_sync(mask, predicate)`
///
/// The CPU fallback ignores `_mask` and echoes its own predicate.
#[inline]
pub fn warp_any(_mask: u32, predicate: bool) -> bool {
    let some_lane_true = predicate;
    some_lane_true
}
707
/// Read `val` from another lane. Transpiles to: `__shfl_sync(mask, val, lane)`
///
/// The CPU fallback ignores `_mask` and `_lane` and hands the caller's own
/// value back.
#[inline]
pub fn warp_shfl<T>(_mask: u32, val: T, _lane: i32) -> T
where
    T: Copy,
{
    let own_value = val;
    own_value
}
713
/// Read `val` from a lower-numbered lane.
/// Transpiles to: `__shfl_up_sync(mask, val, delta)`
///
/// The CPU fallback ignores `_mask` and `_delta` and hands the caller's own
/// value back.
#[inline]
pub fn warp_shfl_up<T>(_mask: u32, val: T, _delta: u32) -> T
where
    T: Copy,
{
    let own_value = val;
    own_value
}
719
/// Read `val` from a higher-numbered lane.
/// Transpiles to: `__shfl_down_sync(mask, val, delta)`
///
/// The CPU fallback ignores `_mask` and `_delta` and hands the caller's own
/// value back.
#[inline]
pub fn warp_shfl_down<T>(_mask: u32, val: T, _delta: u32) -> T
where
    T: Copy,
{
    let own_value = val;
    own_value
}
725
/// Read `val` from the lane whose id is the caller's id XOR `lane_mask`.
/// Transpiles to: `__shfl_xor_sync(mask, val, lane_mask)`
///
/// The CPU fallback ignores `_mask` and `_lane_mask` and hands the caller's
/// own value back.
#[inline]
pub fn warp_shfl_xor<T>(_mask: u32, val: T, _lane_mask: i32) -> T
where
    T: Copy,
{
    let own_value = val;
    own_value
}
731
/// Sum of `val` across the warp. Transpiles to: `__reduce_add_sync(mask, val)`
///
/// CPU fallback: with a single thread the sum is the caller's own value.
#[inline]
pub fn warp_reduce_add(_mask: u32, val: i32) -> i32 {
    let total = val;
    total
}
737
/// Minimum of `val` across the warp.
/// Transpiles to: `__reduce_min_sync(mask, val)`
///
/// CPU fallback: with a single thread the minimum is the caller's own value.
#[inline]
pub fn warp_reduce_min(_mask: u32, val: i32) -> i32 {
    let smallest = val;
    smallest
}
743
/// Maximum of `val` across the warp.
/// Transpiles to: `__reduce_max_sync(mask, val)`
///
/// CPU fallback: with a single thread the maximum is the caller's own value.
#[inline]
pub fn warp_reduce_max(_mask: u32, val: i32) -> i32 {
    let largest = val;
    largest
}
749
/// Bitwise AND of `val` across the warp.
/// Transpiles to: `__reduce_and_sync(mask, val)`
///
/// CPU fallback: with a single thread the result is the caller's own value.
#[inline]
pub fn warp_reduce_and(_mask: u32, val: u32) -> u32 {
    let combined = val;
    combined
}
755
/// Bitwise OR of `val` across the warp.
/// Transpiles to: `__reduce_or_sync(mask, val)`
///
/// CPU fallback: with a single thread the result is the caller's own value.
#[inline]
pub fn warp_reduce_or(_mask: u32, val: u32) -> u32 {
    let combined = val;
    combined
}
761
/// Bitwise XOR of `val` across the warp.
/// Transpiles to: `__reduce_xor_sync(mask, val)`
///
/// CPU fallback: with a single thread the result is the caller's own value.
#[inline]
pub fn warp_reduce_xor(_mask: u32, val: u32) -> u32 {
    let combined = val;
    combined
}
767
/// Mask of lanes holding the same value as the caller.
/// Transpiles to: `__match_any_sync(mask, val)`
#[inline]
pub fn warp_match_any(_mask: u32, _val: u32) -> u32 {
    // CPU fallback: the single thread always matches itself (bit 0).
    const SELF_ONLY: u32 = 1;
    SELF_ONLY
}
773
/// Check whether every participating lane holds the same value.
/// Transpiles to: `__match_all_sync(mask, val, pred)`
///
/// Returns `(mask, pred)` as a tuple. The CPU fallback always reports
/// `(1, true)`: a single thread trivially agrees with itself.
#[inline]
pub fn warp_match_all(_mask: u32, _val: u32) -> (u32, bool) {
    let matching_lanes: u32 = 1;
    let all_match = true;
    (matching_lanes, all_match)
}
779
780// ============================================================================
781// Bit Manipulation Functions
782// ============================================================================
783
/// Number of set bits in `x`. Transpiles to: `__popc(x)`
#[inline]
pub fn popc(x: u32) -> i32 {
    u32::count_ones(x) as i32
}
789
/// Number of set bits in the two's-complement representation of `x`
/// (i32 version of the population count).
#[inline]
pub fn popcount(x: i32) -> i32 {
    x.count_ones() as i32
}
795
/// Number of consecutive zero bits starting from the most significant bit
/// (32 for an input of zero). Transpiles to: `__clz(x)`
#[inline]
pub fn clz(x: u32) -> i32 {
    u32::leading_zeros(x) as i32
}
801
/// Number of leading zero bits in the two's-complement representation of
/// `x` (i32 version); negative inputs therefore yield 0.
#[inline]
pub fn leading_zeros(x: i32) -> i32 {
    x.leading_zeros() as i32
}
807
/// Number of consecutive zero bits starting from the least significant bit.
/// Transpiles to: `__ffs(x) - 1`
///
/// The CPU fallback returns 32 for an input of zero.
/// NOTE(review): `__ffs(0)` is 0, so the transpiled expression would give -1
/// for zero — confirm whether callers rely on the zero case.
#[inline]
pub fn ctz(x: u32) -> i32 {
    // `trailing_zeros` already reports the full bit width (32) for zero.
    x.trailing_zeros() as i32
}
817
/// Number of trailing zero bits in the two's-complement representation of
/// `x` (i32 version). Returns 32 for an input of zero.
#[inline]
pub fn trailing_zeros(x: i32) -> i32 {
    // `trailing_zeros` already reports the full bit width (32) for zero.
    x.trailing_zeros() as i32
}
827
/// Position of the least significant set bit, counted from 1; returns 0 when
/// no bit is set. Transpiles to: `__ffs(x)`
#[inline]
pub fn ffs(x: u32) -> i32 {
    match x {
        0 => 0,
        bits => bits.trailing_zeros() as i32 + 1,
    }
}
837
/// Reverse the order of the 32 bits of `x`. Transpiles to: `__brev(x)`
#[inline]
pub fn brev(x: u32) -> u32 {
    u32::reverse_bits(x)
}
843
/// Reverse the order of the 32 bits of `x` (i32 version); the reversed bit
/// pattern is reinterpreted as a signed value.
#[inline]
pub fn reverse_bits(x: i32) -> i32 {
    x.reverse_bits()
}
849
/// Byte permutation. Transpiles to: `__byte_perm(x, y, s)`
///
/// The four bytes of `x` (indices 0-3, least significant first) and the four
/// bytes of `y` (indices 4-7) form an eight-byte pool. Each of the four low
/// nibbles of `s` selects one pool byte via its low three bits; nibble `k`
/// supplies byte `k` of the result.
#[inline]
pub fn byte_perm(x: u32, y: u32, s: u32) -> u32 {
    let [x0, x1, x2, x3] = x.to_le_bytes();
    let [y0, y1, y2, y3] = y.to_le_bytes();
    let pool = [x0, x1, x2, x3, y0, y1, y2, y3];

    // Only bits 0..=2 of each selector nibble are used.
    let pick = |nibble: u32| pool[((s >> (4 * nibble)) & 0x7) as usize];

    u32::from_le_bytes([pick(0), pick(1), pick(2), pick(3)])
}
869
/// Funnel shift left. Transpiles to: `__funnelshift_l(lo, hi, shift)`
///
/// Concatenates `hi:lo` into a 64-bit value, shifts it left by `shift & 31`
/// bits, and returns the most significant 32 bits. A shift amount of zero
/// (or any multiple of 32) therefore returns `hi`, the upper word of the
/// unshifted pair.
#[inline]
pub fn funnel_shift_left(lo: u32, hi: u32, shift: u32) -> u32 {
    let amount = shift & 31;
    let wide = (u64::from(hi) << 32) | u64::from(lo);
    // Bits shifted past bit 63 are discarded; the upper word is the result.
    ((wide << amount) >> 32) as u32
}
880
/// Funnel shift right. Transpiles to: `__funnelshift_r(lo, hi, shift)`
///
/// Concatenates `hi:lo` into a 64-bit value, shifts it right by `shift & 31`
/// bits, and returns the least significant 32 bits. A shift amount of zero
/// returns `lo` unchanged.
#[inline]
pub fn funnel_shift_right(lo: u32, hi: u32, shift: u32) -> u32 {
    let amount = shift & 31;
    let wide = (u64::from(hi) << 32) | u64::from(lo);
    // Keep only the low word of the shifted pair.
    ((wide >> amount) & 0xFFFF_FFFF) as u32
}
891
892// ============================================================================
893// Memory Operations
894// ============================================================================
895
/// Load a value through the read-only data cache. Transpiles to: `__ldg(ptr)`
///
/// The CPU fallback is a plain copy of the referenced value.
#[inline]
pub fn ldg<T>(ptr: &T) -> T
where
    T: Copy,
{
    let &value = ptr;
    value
}
901
/// Load a value from global memory (alias for `ldg`).
///
/// The CPU fallback is a plain copy of the referenced value.
#[inline]
pub fn load_global<T>(ptr: &T) -> T
where
    T: Copy,
{
    let &value = ptr;
    value
}
907
/// Hint that the referenced data should be brought into the L1 cache.
/// Transpiles to: `__prefetch_l1(ptr)`
#[inline]
pub fn prefetch_l1<T>(_ptr: &T) {
    // Intentionally empty: the CPU fallback issues no prefetch.
}
913
/// Hint that the referenced data should be brought into the L2 cache.
/// Transpiles to: `__prefetch_l2(ptr)`
#[inline]
pub fn prefetch_l2<T>(_ptr: &T) {
    // Intentionally empty: the CPU fallback issues no prefetch.
}
919
920// ============================================================================
921// Special Functions
922// ============================================================================
923
/// Reciprocal `1 / x`. Transpiles to: `__frcp_rn(x)`
#[inline]
pub fn rcp(x: f32) -> f32 {
    f32::recip(x)
}
929
/// Division `x / y`. Transpiles to: `__fdividef(x, y)`
///
/// The CPU fallback performs an ordinary `f32` division.
#[inline]
pub fn fast_div(x: f32, y: f32) -> f32 {
    std::ops::Div::div(x, y)
}
935
/// Saturate to `[+0.0, 1.0]`. Transpiles to: `__saturatef(x)`
///
/// Values below zero become `0.0`, values above one become `1.0`, and values
/// inside the interval are returned unchanged. A NaN input yields `0.0`, as
/// documented for `__saturatef`; a plain `clamp` would propagate the NaN.
/// Negative zero is normalised to `+0.0`.
#[inline]
pub fn saturate(x: f32) -> f32 {
    // `x > 0.0` is false for NaN, -0.0 and every negative value, so all of
    // those collapse to +0.0 here.
    if x > 0.0 {
        x.min(1.0)
    } else {
        0.0
    }
}
941
942/// Clamp to [0, 1] (alias for saturate).
943#[inline]
944pub fn clamp_01(x: f32) -> f32 {
945    saturate(x)
946}
947
948// ============================================================================
949// Clock and Timing
950// ============================================================================
951
/// Read the 32-bit clock counter. Transpiles to: `clock()`
///
/// The CPU fallback reports wall-clock nanoseconds since the Unix epoch,
/// truncated to the low 32 bits, or 0 if the system time is before the
/// epoch.
#[inline]
pub fn clock() -> u32 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        // Deliberate truncation: only the low 32 bits are kept.
        Ok(elapsed) => elapsed.as_nanos() as u32,
        Err(_) => 0,
    }
}
961
/// Read the 64-bit clock counter. Transpiles to: `clock64()`
///
/// The CPU fallback reports wall-clock nanoseconds since the Unix epoch,
/// truncated to 64 bits, or 0 if the system time is before the epoch.
#[inline]
pub fn clock64() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    }
}
970
/// Suspend the calling thread for roughly `ns` nanoseconds.
/// Transpiles to: `__nanosleep(ns)`
///
/// The CPU fallback sleeps the current OS thread.
#[inline]
pub fn nanosleep(ns: u32) {
    let pause = std::time::Duration::from_nanos(u64::from(ns));
    std::thread::sleep(pause);
}
976
/// Unit tests for the CPU fallback implementations of the DSL functions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_indices_default() {
        assert_eq!(thread_idx_x(), 0);
        assert_eq!(thread_idx_y(), 0);
        assert_eq!(thread_idx_z(), 0);
    }

    #[test]
    fn test_block_indices_default() {
        assert_eq!(block_idx_x(), 0);
        assert_eq!(block_idx_y(), 0);
        assert_eq!(block_idx_z(), 0);
    }

    #[test]
    fn test_dimensions_default() {
        // Every block/grid extent is covered, including the z dimensions.
        assert_eq!(block_dim_x(), 1);
        assert_eq!(block_dim_y(), 1);
        assert_eq!(block_dim_z(), 1);
        assert_eq!(grid_dim_x(), 1);
        assert_eq!(grid_dim_y(), 1);
        assert_eq!(grid_dim_z(), 1);
        assert_eq!(warp_size(), 32);
    }

    #[test]
    fn test_math_functions() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-6);
        assert!((rsqrt(4.0) - 0.5).abs() < 1e-6);
        assert!((sin(0.0)).abs() < 1e-6);
        assert!((cos(0.0) - 1.0).abs() < 1e-6);
        assert!((exp(0.0) - 1.0).abs() < 1e-6);
        assert!((log(1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_trigonometric_functions() {
        let pi = std::f32::consts::PI;
        assert!((sin(pi / 2.0) - 1.0).abs() < 1e-6);
        assert!((cos(pi) + 1.0).abs() < 1e-6);
        assert!((tan(0.0)).abs() < 1e-6);
        assert!((asin(1.0) - pi / 2.0).abs() < 1e-6);
        assert!((atan2(1.0, 1.0) - pi / 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_hyperbolic_functions() {
        assert!((sinh(0.0)).abs() < 1e-6);
        assert!((cosh(0.0) - 1.0).abs() < 1e-6);
        assert!((tanh(0.0)).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_functions() {
        assert!((exp2(3.0) - 8.0).abs() < 1e-6);
        assert!((log2(8.0) - 3.0).abs() < 1e-6);
        assert!((log10(100.0) - 2.0).abs() < 1e-6);
        assert!((pow(2.0, 3.0) - 8.0).abs() < 1e-6);
    }

    #[test]
    fn test_classification_functions() {
        assert!(is_nan(f32::NAN));
        assert!(!is_nan(1.0));
        assert!(is_infinite(f32::INFINITY));
        assert!(!is_infinite(1.0));
        assert!(is_finite(1.0));
        assert!(!is_finite(f32::INFINITY));
    }

    #[test]
    fn test_bit_manipulation() {
        assert_eq!(popc(0b1010_1010), 4);
        assert_eq!(clz(1u32), 31);
        assert_eq!(clz(0x8000_0000u32), 0);
        assert_eq!(ctz(0b1000), 3);
        assert_eq!(ffs(0b1000), 4);
        assert_eq!(brev(1u32), 0x8000_0000);
    }

    #[test]
    fn test_warp_operations() {
        assert_eq!(warp_active_mask(), 1);
        assert_eq!(warp_ballot(0xFFFF_FFFF, true), 1);
        assert!(warp_all(0xFFFF_FFFF, true));
        assert!(warp_any(0xFFFF_FFFF, true));
        assert_eq!(warp_reduce_add(0xFFFF_FFFF, 5), 5);
    }

    #[test]
    fn test_special_functions() {
        assert!((rcp(2.0) - 0.5).abs() < 1e-6);
        assert!((fast_div(10.0, 2.0) - 5.0).abs() < 1e-6);
        assert_eq!(saturate(-1.0), 0.0);
        assert_eq!(saturate(0.5), 0.5);
        assert_eq!(saturate(2.0), 1.0);
    }

    #[test]
    fn test_atomic_operations() {
        let mut val = 10;
        assert_eq!(atomic_add(&mut val, 5), 10);
        assert_eq!(val, 15);

        let mut val = 10;
        assert_eq!(atomic_sub(&mut val, 3), 10);
        assert_eq!(val, 7);

        let mut val = 10;
        assert_eq!(atomic_cas(&mut val, 10, 20), 10);
        assert_eq!(val, 20);
    }

    #[test]
    fn test_funnel_shift() {
        assert_eq!(funnel_shift_left(0xFFFF_0000, 0x0000_FFFF, 16), 0xFFFF_FFFF);
        assert_eq!(
            funnel_shift_right(0xFFFF_0000, 0x0000_FFFF, 16),
            0xFFFF_FFFF
        );
    }

    #[test]
    fn test_byte_perm() {
        let x = 0x04030201u32;
        let y = 0x08070605u32;
        // Select bytes 0, 1, 2, 3 from x
        assert_eq!(byte_perm(x, y, 0x3210), 0x04030201);
        // Select bytes 4, 5, 6, 7 from y
        assert_eq!(byte_perm(x, y, 0x7654), 0x08070605);
    }
}