
ringkernel_cuda_codegen/dsl.rs

//! Rust DSL functions for writing CUDA kernels.
//!
//! This module provides Rust functions that map to CUDA intrinsics during transpilation.
//! These functions have CPU fallback implementations for testing but are transpiled
//! to the corresponding CUDA operations when used in kernel code.
//!
//! Common functions (thread indices, basic math, synchronization) are shared with
//! the WGSL backend via `ringkernel_codegen::dsl_common`.
//!
//! # Thread/Block Index Access
//!
//! ```ignore
//! use ringkernel_cuda_codegen::dsl::*;
//!
//! fn my_kernel(...) {
//!     let tx = thread_idx_x();  // -> threadIdx.x
//!     let bx = block_idx_x();   // -> blockIdx.x
//!     let idx = bx * block_dim_x() + tx;  // Global thread index
//! }
//! ```
//!
//! # Thread Synchronization
//!
//! ```ignore
//! sync_threads();  // -> __syncthreads()
//! ```
//!
//! # Math Functions
//!
//! All standard math functions are available with CPU fallbacks:
//! - Trigonometric: sin, cos, tan, asin, acos, atan, atan2
//! - Hyperbolic: sinh, cosh, tanh, asinh, acosh, atanh
//! - Exponential: exp, exp2, exp10, expm1, log, log2, log10, log1p
//! - Power: pow, sqrt, rsqrt, cbrt
//! - Rounding: floor, ceil, round, trunc
//! - Comparison: fmin, fmax, fdim, copysign
//!
//! # Warp Operations
//!
//! ```ignore
//! let mask = warp_active_mask();       // Get active lane mask
//! let result = warp_reduce_add(mask, value);  // Warp-level sum
//! let shuffled = warp_shfl(mask, value, lane); // Shuffle
//! ```
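//!
//! # Atomic Operations
//!
//! Illustrative usage; the CPU fallbacks below are NOT thread-safe and only
//! mimic the return-old-value semantics of the CUDA intrinsics:
//!
//! ```ignore
//! let old = atomic_add(&mut counter, 1);     // -> atomicAdd(&counter, 1)
//! let prev = atomic_cas(&mut slot, 0, 42);   // -> atomicCAS(&slot, 0, 42)
//! let hi = atomic_max(&mut best, candidate); // -> atomicMax(&best, candidate)
//! ```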
//!
//! # Bit Manipulation
//!
//! ```ignore
//! let bits = popc(x);      // Count set bits
//! let zeros = clz(x);      // Count leading zeros
//! let rev = brev(x);       // Reverse bits
//! ```

use std::sync::atomic::{fence, Ordering};

// Re-export common DSL functions shared across backends.
// These provide: thread/block indices, sync primitives, and basic math.
pub use ringkernel_codegen::dsl_common::{
    block_dim_x,
    block_dim_y,
    block_dim_z,
    block_idx_x,
    block_idx_y,
    block_idx_z,
    ceil,
    cos,
    exp,
    floor,
    fma,
    grid_dim_x,
    grid_dim_y,
    grid_dim_z,
    log,
    powf,
    round,
    rsqrt,
    sin,
    // Math
    sqrt,
    // Synchronization
    sync_threads,
    tan,
    thread_fence,
    thread_fence_block,
    // Thread/block indices
    thread_idx_x,
    thread_idx_y,
    thread_idx_z,
};

// ============================================================================
// CUDA-Specific Thread/Block Functions
// ============================================================================

/// Get the warp size (always 32 on NVIDIA GPUs).
/// Transpiles to: `warpSize`
#[inline]
pub fn warp_size() -> i32 {
    32
}

// ============================================================================
// CUDA-Specific Synchronization Functions
// ============================================================================

/// Synchronize threads and count predicate.
/// Transpiles to: `__syncthreads_count(predicate)`
#[inline]
pub fn sync_threads_count(predicate: bool) -> i32 {
    if predicate {
        1
    } else {
        0
    }
}

/// Synchronize threads with AND of predicate.
/// Transpiles to: `__syncthreads_and(predicate)`
#[inline]
pub fn sync_threads_and(predicate: bool) -> i32 {
    if predicate {
        1
    } else {
        0
    }
}

/// Synchronize threads with OR of predicate.
/// Transpiles to: `__syncthreads_or(predicate)`
#[inline]
pub fn sync_threads_or(predicate: bool) -> i32 {
    if predicate {
        1
    } else {
        0
    }
}

/// System-wide memory fence.
/// Transpiles to: `__threadfence_system()`
#[inline]
pub fn thread_fence_system() {
    fence(Ordering::SeqCst);
}

// ============================================================================
// Atomic Operations (CPU fallbacks - not thread-safe!)
// ============================================================================

/// Atomic add. Transpiles to: `atomicAdd(addr, val)`
/// WARNING: CPU fallback is NOT thread-safe!
#[inline]
pub fn atomic_add(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr += val;
    old
}

/// Atomic add for f32. Transpiles to: `atomicAdd(addr, val)`
#[inline]
pub fn atomic_add_f32(addr: &mut f32, val: f32) -> f32 {
    let old = *addr;
    *addr += val;
    old
}

/// Atomic subtract. Transpiles to: `atomicSub(addr, val)`
#[inline]
pub fn atomic_sub(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr -= val;
    old
}

/// Atomic minimum. Transpiles to: `atomicMin(addr, val)`
#[inline]
pub fn atomic_min(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr = old.min(val);
    old
}

/// Atomic maximum. Transpiles to: `atomicMax(addr, val)`
#[inline]
pub fn atomic_max(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr = old.max(val);
    old
}

/// Atomic exchange. Transpiles to: `atomicExch(addr, val)`
#[inline]
pub fn atomic_exchange(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr = val;
    old
}

/// Atomic compare and swap. Transpiles to: `atomicCAS(addr, compare, val)`
#[inline]
pub fn atomic_cas(addr: &mut i32, compare: i32, val: i32) -> i32 {
    let old = *addr;
    if old == compare {
        *addr = val;
    }
    old
}

/// Atomic AND. Transpiles to: `atomicAnd(addr, val)`
#[inline]
pub fn atomic_and(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr &= val;
    old
}

/// Atomic OR. Transpiles to: `atomicOr(addr, val)`
#[inline]
pub fn atomic_or(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr |= val;
    old
}

/// Atomic XOR. Transpiles to: `atomicXor(addr, val)`
#[inline]
pub fn atomic_xor(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr ^= val;
    old
}

/// Atomic increment with wrap. Transpiles to: `atomicInc(addr, val)`
#[inline]
pub fn atomic_inc(addr: &mut u32, val: u32) -> u32 {
    let old = *addr;
    *addr = if old >= val { 0 } else { old + 1 };
    old
}

/// Atomic decrement with wrap. Transpiles to: `atomicDec(addr, val)`
#[inline]
pub fn atomic_dec(addr: &mut u32, val: u32) -> u32 {
    let old = *addr;
    *addr = if old == 0 || old > val { val } else { old - 1 };
    old
}

// ============================================================================
// CUDA-Specific Math Functions
// ============================================================================

/// Absolute value for f32. Transpiles to: `fabsf(x)`
#[inline]
pub fn fabs(x: f32) -> f32 {
    x.abs()
}

/// Truncate toward zero. Transpiles to: `truncf(x)`
#[inline]
pub fn trunc(x: f32) -> f32 {
    x.trunc()
}

/// Minimum. Transpiles to: `fminf(a, b)`
#[inline]
pub fn fmin(a: f32, b: f32) -> f32 {
    a.min(b)
}

/// Maximum. Transpiles to: `fmaxf(a, b)`
#[inline]
pub fn fmax(a: f32, b: f32) -> f32 {
    a.max(b)
}

/// Floating-point modulo. Transpiles to: `fmodf(x, y)`
#[inline]
pub fn fmod(x: f32, y: f32) -> f32 {
    x % y
}

/// Remainder. Transpiles to: `remainderf(x, y)`
#[inline]
pub fn remainder(x: f32, y: f32) -> f32 {
    x - (x / y).round() * y
}

/// Copy sign. Transpiles to: `copysignf(x, y)`
#[inline]
pub fn copysign(x: f32, y: f32) -> f32 {
    x.copysign(y)
}

/// Cube root. Transpiles to: `cbrtf(x)`
#[inline]
pub fn cbrt(x: f32) -> f32 {
    x.cbrt()
}

/// Hypotenuse. Transpiles to: `hypotf(x, y)`
#[inline]
pub fn hypot(x: f32, y: f32) -> f32 {
    x.hypot(y)
}

// ============================================================================
// CUDA-Specific Trigonometric Functions
// ============================================================================

/// Arcsine. Transpiles to: `asinf(x)`
#[inline]
pub fn asin(x: f32) -> f32 {
    x.asin()
}

/// Arccosine. Transpiles to: `acosf(x)`
#[inline]
pub fn acos(x: f32) -> f32 {
    x.acos()
}

/// Arctangent. Transpiles to: `atanf(x)`
#[inline]
pub fn atan(x: f32) -> f32 {
    x.atan()
}

/// Two-argument arctangent. Transpiles to: `atan2f(y, x)`
#[inline]
pub fn atan2(y: f32, x: f32) -> f32 {
    y.atan2(x)
}

/// Sine and cosine together. Transpiles to: `sincosf(x, &s, &c)`
#[inline]
pub fn sincos(x: f32) -> (f32, f32) {
    (x.sin(), x.cos())
}

/// Sine of pi*x. Transpiles to: `sinpif(x)`
#[inline]
pub fn sinpi(x: f32) -> f32 {
    (x * std::f32::consts::PI).sin()
}

/// Cosine of pi*x. Transpiles to: `cospif(x)`
#[inline]
pub fn cospi(x: f32) -> f32 {
    (x * std::f32::consts::PI).cos()
}

// ============================================================================
// CUDA-Specific Hyperbolic Functions
// ============================================================================

/// Hyperbolic sine. Transpiles to: `sinhf(x)`
#[inline]
pub fn sinh(x: f32) -> f32 {
    x.sinh()
}

/// Hyperbolic cosine. Transpiles to: `coshf(x)`
#[inline]
pub fn cosh(x: f32) -> f32 {
    x.cosh()
}

/// Hyperbolic tangent. Transpiles to: `tanhf(x)`
#[inline]
pub fn tanh(x: f32) -> f32 {
    x.tanh()
}

/// Inverse hyperbolic sine. Transpiles to: `asinhf(x)`
#[inline]
pub fn asinh(x: f32) -> f32 {
    x.asinh()
}

/// Inverse hyperbolic cosine. Transpiles to: `acoshf(x)`
#[inline]
pub fn acosh(x: f32) -> f32 {
    x.acosh()
}

/// Inverse hyperbolic tangent. Transpiles to: `atanhf(x)`
#[inline]
pub fn atanh(x: f32) -> f32 {
    x.atanh()
}

// ============================================================================
// CUDA-Specific Exponential and Logarithmic Functions
// ============================================================================

/// Exponential (base 2). Transpiles to: `exp2f(x)`
#[inline]
pub fn exp2(x: f32) -> f32 {
    x.exp2()
}

/// Exponential (base 10). Transpiles to: `exp10f(x)`
#[inline]
pub fn exp10(x: f32) -> f32 {
    (x * std::f32::consts::LN_10).exp()
}

/// exp(x) - 1 (accurate for small x). Transpiles to: `expm1f(x)`
#[inline]
pub fn expm1(x: f32) -> f32 {
    x.exp_m1()
}

/// Logarithm (base 2). Transpiles to: `log2f(x)`
#[inline]
pub fn log2(x: f32) -> f32 {
    x.log2()
}

/// Logarithm (base 10). Transpiles to: `log10f(x)`
#[inline]
pub fn log10(x: f32) -> f32 {
    x.log10()
}

/// log(1 + x) (accurate for small x). Transpiles to: `log1pf(x)`
#[inline]
pub fn log1p(x: f32) -> f32 {
    x.ln_1p()
}

/// Power. Transpiles to: `powf(x, y)`
#[inline]
pub fn pow(x: f32, y: f32) -> f32 {
    x.powf(y)
}

/// Load exponent. Transpiles to: `ldexpf(x, exp)`
#[inline]
pub fn ldexp(x: f32, exp: i32) -> f32 {
    x * 2.0_f32.powi(exp)
}

/// Scale by power of 2. Transpiles to: `scalbnf(x, n)`
#[inline]
pub fn scalbn(x: f32, n: i32) -> f32 {
    x * 2.0_f32.powi(n)
}

/// Extract exponent. Transpiles to: `ilogbf(x)`
#[inline]
pub fn ilogb(x: f32) -> i32 {
    if x == 0.0 {
        i32::MIN
    } else if x.is_infinite() {
        i32::MAX
    } else {
        x.abs().log2().floor() as i32
    }
}

/// Error function. Transpiles to: `erff(x)`
#[inline]
pub fn erf(x: f32) -> f32 {
    // Approximation using Horner form
    let a1 = 0.254_829_6_f32;
    let a2 = -0.284_496_74_f32;
    let a3 = 1.421_413_7_f32;
    let a4 = -1.453_152_f32;
    let a5 = 1.061_405_4_f32;
    let p = 0.327_591_1_f32;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + p * x);
    let y = 1.0 - (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * (-x * x).exp();
    sign * y
}

/// Complementary error function. Transpiles to: `erfcf(x)`
#[inline]
pub fn erfc(x: f32) -> f32 {
    1.0 - erf(x)
}

// ============================================================================
// Classification and Comparison Functions
// ============================================================================

/// Check if NaN. Transpiles to: `isnan(x)`
#[inline]
pub fn is_nan(x: f32) -> bool {
    x.is_nan()
}

/// Check if infinite. Transpiles to: `isinf(x)`
#[inline]
pub fn is_infinite(x: f32) -> bool {
    x.is_infinite()
}

/// Check if finite. Transpiles to: `isfinite(x)`
#[inline]
pub fn is_finite(x: f32) -> bool {
    x.is_finite()
}

/// Check if normal. Transpiles to: `isnormal(x)`
#[inline]
pub fn is_normal(x: f32) -> bool {
    x.is_normal()
}

/// Check sign bit. Transpiles to: `signbit(x)`
#[inline]
pub fn signbit(x: f32) -> bool {
    x.is_sign_negative()
}

/// Next representable value toward `y`. Transpiles to: `nextafterf(x, y)`
#[inline]
pub fn nextafter(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        return f32::NAN;
    }
    if x == y {
        return y;
    }
    let bits = if x == 0.0 {
        // Step off zero to the smallest subnormal with the sign of y.
        if y > 0.0 { 1 } else { 0x8000_0001 }
    } else if (y > x) == (x > 0.0) {
        x.to_bits() + 1 // step away from zero
    } else {
        x.to_bits() - 1 // step toward zero
    };
    f32::from_bits(bits)
}

/// Floating-point difference. Transpiles to: `fdimf(x, y)`
#[inline]
pub fn fdim(x: f32, y: f32) -> f32 {
    if x > y {
        x - y
    } else {
        0.0
    }
}

// ============================================================================
// Warp-Level Operations
// ============================================================================

/// Get active thread mask. Transpiles to: `__activemask()`
#[inline]
pub fn warp_active_mask() -> u32 {
    1 // CPU fallback: only one thread active
}

/// Warp ballot. Transpiles to: `__ballot_sync(mask, predicate)`
#[inline]
pub fn warp_ballot(_mask: u32, predicate: bool) -> u32 {
    if predicate {
        1
    } else {
        0
    }
}

/// Warp all predicate. Transpiles to: `__all_sync(mask, predicate)`
#[inline]
pub fn warp_all(_mask: u32, predicate: bool) -> bool {
    predicate
}

/// Warp any predicate. Transpiles to: `__any_sync(mask, predicate)`
#[inline]
pub fn warp_any(_mask: u32, predicate: bool) -> bool {
    predicate
}

/// Warp shuffle. Transpiles to: `__shfl_sync(mask, val, lane)`
#[inline]
pub fn warp_shfl<T: Copy>(_mask: u32, val: T, _lane: i32) -> T {
    val // CPU fallback: return same value
}

/// Warp shuffle up. Transpiles to: `__shfl_up_sync(mask, val, delta)`
#[inline]
pub fn warp_shfl_up<T: Copy>(_mask: u32, val: T, _delta: u32) -> T {
    val
}

/// Warp shuffle down. Transpiles to: `__shfl_down_sync(mask, val, delta)`
#[inline]
pub fn warp_shfl_down<T: Copy>(_mask: u32, val: T, _delta: u32) -> T {
    val
}

/// Warp shuffle XOR. Transpiles to: `__shfl_xor_sync(mask, val, lane_mask)`
#[inline]
pub fn warp_shfl_xor<T: Copy>(_mask: u32, val: T, _lane_mask: i32) -> T {
    val
}

/// Warp reduce add. Transpiles to: `__reduce_add_sync(mask, val)`
#[inline]
pub fn warp_reduce_add(_mask: u32, val: i32) -> i32 {
    val // CPU: single thread, no reduction needed
}

/// Warp reduce min. Transpiles to: `__reduce_min_sync(mask, val)`
#[inline]
pub fn warp_reduce_min(_mask: u32, val: i32) -> i32 {
    val
}

/// Warp reduce max. Transpiles to: `__reduce_max_sync(mask, val)`
#[inline]
pub fn warp_reduce_max(_mask: u32, val: i32) -> i32 {
    val
}

/// Warp reduce AND. Transpiles to: `__reduce_and_sync(mask, val)`
#[inline]
pub fn warp_reduce_and(_mask: u32, val: u32) -> u32 {
    val
}

/// Warp reduce OR. Transpiles to: `__reduce_or_sync(mask, val)`
#[inline]
pub fn warp_reduce_or(_mask: u32, val: u32) -> u32 {
    val
}

/// Warp reduce XOR. Transpiles to: `__reduce_xor_sync(mask, val)`
#[inline]
pub fn warp_reduce_xor(_mask: u32, val: u32) -> u32 {
    val
}

/// Warp match any. Transpiles to: `__match_any_sync(mask, val)`
#[inline]
pub fn warp_match_any(_mask: u32, _val: u32) -> u32 {
    1 // CPU: single thread always matches itself
}

/// Warp match all. Transpiles to: `__match_all_sync(mask, val, pred)`
#[inline]
pub fn warp_match_all(_mask: u32, _val: u32) -> (u32, bool) {
    (1, true) // CPU: single thread, trivially all match
}

// ============================================================================
// Bit Manipulation Functions
// ============================================================================

/// Population count (count set bits). Transpiles to: `__popc(x)`
#[inline]
pub fn popc(x: u32) -> i32 {
    x.count_ones() as i32
}

/// Population count (i32 version).
#[inline]
pub fn popcount(x: i32) -> i32 {
    (x as u32).count_ones() as i32
}

/// Count leading zeros. Transpiles to: `__clz(x)`
#[inline]
pub fn clz(x: u32) -> i32 {
    x.leading_zeros() as i32
}

/// Count leading zeros (i32 version).
#[inline]
pub fn leading_zeros(x: i32) -> i32 {
    (x as u32).leading_zeros() as i32
}

/// Count trailing zeros. Transpiles to: `__ffs(x) - 1`
#[inline]
pub fn ctz(x: u32) -> i32 {
    if x == 0 {
        32
    } else {
        x.trailing_zeros() as i32
    }
}

/// Count trailing zeros (i32 version).
#[inline]
pub fn trailing_zeros(x: i32) -> i32 {
    if x == 0 {
        32
    } else {
        (x as u32).trailing_zeros() as i32
    }
}

/// Find first set bit (1-indexed, 0 if none). Transpiles to: `__ffs(x)`
#[inline]
pub fn ffs(x: u32) -> i32 {
    if x == 0 {
        0
    } else {
        (x.trailing_zeros() + 1) as i32
    }
}

/// Bit reverse. Transpiles to: `__brev(x)`
#[inline]
pub fn brev(x: u32) -> u32 {
    x.reverse_bits()
}

/// Bit reverse (i32 version).
#[inline]
pub fn reverse_bits(x: i32) -> i32 {
    (x as u32).reverse_bits() as i32
}

/// Byte permutation. Transpiles to: `__byte_perm(x, y, s)`
#[inline]
pub fn byte_perm(x: u32, y: u32, s: u32) -> u32 {
    let bytes = [
        (x & 0xFF) as u8,
        ((x >> 8) & 0xFF) as u8,
        ((x >> 16) & 0xFF) as u8,
        ((x >> 24) & 0xFF) as u8,
        (y & 0xFF) as u8,
        ((y >> 8) & 0xFF) as u8,
        ((y >> 16) & 0xFF) as u8,
        ((y >> 24) & 0xFF) as u8,
    ];
    let b0 = bytes[(s & 0x7) as usize] as u32;
    let b1 = bytes[((s >> 4) & 0x7) as usize] as u32;
    let b2 = bytes[((s >> 8) & 0x7) as usize] as u32;
    let b3 = bytes[((s >> 12) & 0x7) as usize] as u32;
    b0 | (b1 << 8) | (b2 << 16) | (b3 << 24)
}

/// Funnel shift left. Transpiles to: `__funnelshift_l(lo, hi, shift)`
/// Returns the most significant 32 bits of `hi:lo` shifted left by `shift & 31`.
#[inline]
pub fn funnel_shift_left(lo: u32, hi: u32, shift: u32) -> u32 {
    let shift = shift & 31;
    if shift == 0 {
        hi // a zero shift leaves the high word unchanged
    } else {
        (hi << shift) | (lo >> (32 - shift))
    }
}

/// Funnel shift right. Transpiles to: `__funnelshift_r(lo, hi, shift)`
/// Returns the least significant 32 bits of `hi:lo` shifted right by `shift & 31`.
#[inline]
pub fn funnel_shift_right(lo: u32, hi: u32, shift: u32) -> u32 {
    let shift = shift & 31;
    if shift == 0 {
        lo
    } else {
        (lo >> shift) | (hi << (32 - shift))
    }
}

// ============================================================================
// Memory Operations
// ============================================================================

/// Read-only cache load. Transpiles to: `__ldg(ptr)`
#[inline]
pub fn ldg<T: Copy>(ptr: &T) -> T {
    *ptr
}

/// Load from global memory (alias for ldg).
#[inline]
pub fn load_global<T: Copy>(ptr: &T) -> T {
    *ptr
}

/// Prefetch to L1 cache. Transpiles to: `__prefetch_l1(ptr)`
#[inline]
pub fn prefetch_l1<T>(_ptr: &T) {
    // CPU fallback: no-op
}

/// Prefetch to L2 cache. Transpiles to: `__prefetch_l2(ptr)`
#[inline]
pub fn prefetch_l2<T>(_ptr: &T) {
    // CPU fallback: no-op
}

// ============================================================================
// Special Functions
// ============================================================================

/// Fast reciprocal. Transpiles to: `__frcp_rn(x)`
#[inline]
pub fn rcp(x: f32) -> f32 {
    1.0 / x
}

/// Fast division. Transpiles to: `__fdividef(x, y)`
#[inline]
pub fn fast_div(x: f32, y: f32) -> f32 {
    x / y
}

/// Saturate to [0, 1]. Transpiles to: `__saturatef(x)`
#[inline]
pub fn saturate(x: f32) -> f32 {
    x.clamp(0.0, 1.0)
}

/// Clamp to [0, 1] (alias for saturate).
#[inline]
pub fn clamp_01(x: f32) -> f32 {
    saturate(x)
}

// ============================================================================
// Clock and Timing
// ============================================================================

/// Read clock counter. Transpiles to: `clock()`
#[inline]
pub fn clock() -> u32 {
    // CPU fallback: use std time
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos() as u32)
        .unwrap_or(0)
}

/// Read 64-bit clock counter. Transpiles to: `clock64()`
#[inline]
pub fn clock64() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_nanos() as u64)
        .unwrap_or(0)
}

/// Nanosleep. Transpiles to: `__nanosleep(ns)`
#[inline]
pub fn nanosleep(ns: u32) {
    std::thread::sleep(std::time::Duration::from_nanos(ns as u64));
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thread_indices_default() {
        assert_eq!(thread_idx_x(), 0);
        assert_eq!(thread_idx_y(), 0);
        assert_eq!(thread_idx_z(), 0);
    }

    #[test]
    fn test_block_indices_default() {
        assert_eq!(block_idx_x(), 0);
        assert_eq!(block_idx_y(), 0);
        assert_eq!(block_idx_z(), 0);
    }

    #[test]
    fn test_dimensions_default() {
        assert_eq!(block_dim_x(), 1);
        assert_eq!(block_dim_y(), 1);
        assert_eq!(grid_dim_x(), 1);
        assert_eq!(warp_size(), 32);
    }

    #[test]
    fn test_math_functions() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-6);
        assert!((rsqrt(4.0) - 0.5).abs() < 1e-6);
        assert!((sin(0.0)).abs() < 1e-6);
        assert!((cos(0.0) - 1.0).abs() < 1e-6);
        assert!((exp(0.0) - 1.0).abs() < 1e-6);
        assert!((log(1.0)).abs() < 1e-6);
    }
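
    // Illustrative extra checks for the remaining f32 math fallbacks
    // (fabs, fmin/fmax, fmod, hypot, cbrt); tolerances are loose by design.
    #[test]
    fn test_math_function_extras() {
        assert!((fabs(-3.0) - 3.0).abs() < 1e-6);
        assert!((fmin(1.0, 2.0) - 1.0).abs() < 1e-6);
        assert!((fmax(1.0, 2.0) - 2.0).abs() < 1e-6);
        assert!((fmod(5.5, 2.0) - 1.5).abs() < 1e-6);
        assert!((hypot(3.0, 4.0) - 5.0).abs() < 1e-6);
        assert!((cbrt(27.0) - 3.0).abs() < 1e-5);
    }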

    #[test]
    fn test_trigonometric_functions() {
        let pi = std::f32::consts::PI;
        assert!((sin(pi / 2.0) - 1.0).abs() < 1e-6);
        assert!((cos(pi) + 1.0).abs() < 1e-6);
        assert!((tan(0.0)).abs() < 1e-6);
        assert!((asin(1.0) - pi / 2.0).abs() < 1e-6);
        assert!((atan2(1.0, 1.0) - pi / 4.0).abs() < 1e-6);
    }
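
    // Sketch of checks for the combined and pi-scaled trig helpers
    // (sincos, sinpi, cospi); expected values follow directly from sin/cos.
    #[test]
    fn test_trigonometric_extras() {
        let (s, c) = sincos(0.0);
        assert!(s.abs() < 1e-6);
        assert!((c - 1.0).abs() < 1e-6);
        assert!((sinpi(0.5) - 1.0).abs() < 1e-6);
        assert!((cospi(1.0) + 1.0).abs() < 1e-6);
    }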

    #[test]
    fn test_hyperbolic_functions() {
        assert!((sinh(0.0)).abs() < 1e-6);
        assert!((cosh(0.0) - 1.0).abs() < 1e-6);
        assert!((tanh(0.0)).abs() < 1e-6);
    }

    #[test]
    fn test_exponential_functions() {
        assert!((exp2(3.0) - 8.0).abs() < 1e-6);
        assert!((log2(8.0) - 3.0).abs() < 1e-6);
        assert!((log10(100.0) - 2.0).abs() < 1e-6);
        assert!((pow(2.0, 3.0) - 8.0).abs() < 1e-6);
    }
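
    // Illustrative checks for the remaining exponential-family fallbacks
    // (exp10, expm1, log1p, ldexp, scalbn) and the erf approximation,
    // which is only accurate to a few decimal places.
    #[test]
    fn test_exponential_extras() {
        assert!((exp10(2.0) - 100.0).abs() < 1e-3);
        assert!(expm1(0.0).abs() < 1e-6);
        assert!(log1p(0.0).abs() < 1e-6);
        assert!((ldexp(1.5, 3) - 12.0).abs() < 1e-6);
        assert!((scalbn(1.0, -1) - 0.5).abs() < 1e-6);
        assert!(erf(0.0).abs() < 1e-6);
        assert!((erf(3.0) - 1.0).abs() < 1e-3);
        assert!((erfc(0.0) - 1.0).abs() < 1e-6);
    }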

    #[test]
    fn test_classification_functions() {
        assert!(is_nan(f32::NAN));
        assert!(!is_nan(1.0));
        assert!(is_infinite(f32::INFINITY));
        assert!(!is_infinite(1.0));
        assert!(is_finite(1.0));
        assert!(!is_finite(f32::INFINITY));
    }
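
    // Sketch of checks for the comparison helpers and the nextafter fallback:
    // stepping toward a larger/smaller target moves one ULP in the right direction.
    #[test]
    fn test_comparison_functions() {
        assert!(signbit(-0.0));
        assert!(!signbit(1.0));
        assert_eq!(copysign(3.0, -1.0), -3.0);
        assert_eq!(fdim(5.0, 3.0), 2.0);
        assert_eq!(fdim(3.0, 5.0), 0.0);
        assert!(nextafter(1.0, 2.0) > 1.0);
        assert!(nextafter(1.0, 0.0) < 1.0);
        assert!(nextafter(-1.0, 0.0) > -1.0);
        assert_eq!(nextafter(2.0, 2.0), 2.0);
    }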

    #[test]
    fn test_bit_manipulation() {
        assert_eq!(popc(0b1010_1010), 4);
        assert_eq!(clz(1u32), 31);
        assert_eq!(clz(0x8000_0000u32), 0);
        assert_eq!(ctz(0b1000), 3);
        assert_eq!(ffs(0b1000), 4);
        assert_eq!(brev(1u32), 0x8000_0000);
    }
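
    // Edge cases and the i32 variants of the bit helpers; expected values
    // follow the zero-input conventions documented on ctz/ffs above.
    #[test]
    fn test_bit_manipulation_edge_cases() {
        assert_eq!(ctz(0), 32);
        assert_eq!(ffs(0), 0);
        assert_eq!(popcount(-1), 32);
        assert_eq!(leading_zeros(1), 31);
        assert_eq!(trailing_zeros(0b1000), 3);
        assert_eq!(reverse_bits(1), i32::MIN);
    }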

    #[test]
    fn test_warp_operations() {
        assert_eq!(warp_active_mask(), 1);
        assert_eq!(warp_ballot(0xFFFF_FFFF, true), 1);
        assert!(warp_all(0xFFFF_FFFF, true));
        assert!(warp_any(0xFFFF_FFFF, true));
        assert_eq!(warp_reduce_add(0xFFFF_FFFF, 5), 5);
    }
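
    // The shuffle and match fallbacks model a single active lane, so they
    // simply pass values through; these checks pin down that behavior.
    #[test]
    fn test_warp_shuffle_fallbacks() {
        assert_eq!(warp_shfl(1, 42i32, 0), 42);
        assert_eq!(warp_shfl_up(1, 7u32, 1), 7);
        assert_eq!(warp_shfl_down(1, 7u32, 1), 7);
        assert_eq!(warp_shfl_xor(1, 3.5f32, 1), 3.5);
        assert_eq!(warp_match_any(1, 5), 1);
        assert_eq!(warp_match_all(1, 5), (1, true));
    }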

    #[test]
    fn test_special_functions() {
        assert!((rcp(2.0) - 0.5).abs() < 1e-6);
        assert!((fast_div(10.0, 2.0) - 5.0).abs() < 1e-6);
        assert_eq!(saturate(-1.0), 0.0);
        assert_eq!(saturate(0.5), 0.5);
        assert_eq!(saturate(2.0), 1.0);
    }
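
    // The read-only-cache load fallbacks just dereference; a quick check that
    // ldg/load_global return the pointee unchanged and prefetch is a no-op.
    #[test]
    fn test_memory_operations() {
        let v = 123i32;
        assert_eq!(ldg(&v), 123);
        assert_eq!(load_global(&v), 123);
        prefetch_l1(&v); // no-op on CPU
        prefetch_l2(&v); // no-op on CPU
    }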

    #[test]
    fn test_atomic_operations() {
        let mut val = 10;
        assert_eq!(atomic_add(&mut val, 5), 10);
        assert_eq!(val, 15);

        let mut val = 10;
        assert_eq!(atomic_sub(&mut val, 3), 10);
        assert_eq!(val, 7);

        let mut val = 10;
        assert_eq!(atomic_cas(&mut val, 10, 20), 10);
        assert_eq!(val, 20);
    }
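
    // Checks for the remaining atomic fallbacks; the wrap-around cases follow
    // the atomicInc/atomicDec semantics described in their doc comments.
    #[test]
    fn test_atomic_operation_extras() {
        let mut val = 10;
        assert_eq!(atomic_min(&mut val, 3), 10);
        assert_eq!(val, 3);

        let mut val = 10;
        assert_eq!(atomic_max(&mut val, 30), 10);
        assert_eq!(val, 30);

        let mut val = 10;
        assert_eq!(atomic_exchange(&mut val, 1), 10);
        assert_eq!(val, 1);

        let mut val = 3u32;
        assert_eq!(atomic_inc(&mut val, 3), 3); // old >= val wraps to 0
        assert_eq!(val, 0);

        let mut val = 0u32;
        assert_eq!(atomic_dec(&mut val, 7), 0); // old == 0 wraps to val
        assert_eq!(val, 7);
    }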

    #[test]
    fn test_funnel_shift() {
        assert_eq!(funnel_shift_left(0xFFFF_0000, 0x0000_FFFF, 16), 0xFFFF_FFFF);
        assert_eq!(
            funnel_shift_right(0xFFFF_0000, 0x0000_FFFF, 16),
            0xFFFF_FFFF
        );
    }
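
    // Zero-shift behavior, assuming the __funnelshift_l/_r convention of
    // taking the most/least significant 32 bits of hi:lo respectively.
    #[test]
    fn test_funnel_shift_zero() {
        assert_eq!(funnel_shift_left(0x1234_5678, 0x9ABC_DEF0, 0), 0x9ABC_DEF0);
        assert_eq!(funnel_shift_right(0x1234_5678, 0x9ABC_DEF0, 0), 0x1234_5678);
    }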

    #[test]
    fn test_byte_perm() {
        let x = 0x04030201u32;
        let y = 0x08070605u32;
        // Select bytes 0, 1, 2, 3 from x
        assert_eq!(byte_perm(x, y, 0x3210), 0x04030201);
        // Select bytes 4, 5, 6, 7 from y
        assert_eq!(byte_perm(x, y, 0x7654), 0x08070605);
    }
}