1use std::sync::atomic::{fence, Ordering};
52
/// Returns the X component of the calling thread's index within its block.
///
/// This implementation unconditionally reports `0`.
/// NOTE(review): presumably a single-threaded host stand-in for a GPU
/// thread-index query — confirm against the callers.
#[inline]
pub fn thread_idx_x() -> i32 {
    const THREAD_IDX_X: i32 = 0;
    THREAD_IDX_X
}
63
/// Returns the Y component of the calling thread's index within its block.
///
/// This implementation unconditionally reports `0`.
#[inline]
pub fn thread_idx_y() -> i32 {
    const THREAD_IDX_Y: i32 = 0;
    THREAD_IDX_Y
}
70
/// Returns the Z component of the calling thread's index within its block.
///
/// This implementation unconditionally reports `0`.
#[inline]
pub fn thread_idx_z() -> i32 {
    const THREAD_IDX_Z: i32 = 0;
    THREAD_IDX_Z
}
77
/// Returns the X component of the current block's index within the grid.
///
/// This implementation unconditionally reports `0`.
#[inline]
pub fn block_idx_x() -> i32 {
    const BLOCK_IDX_X: i32 = 0;
    BLOCK_IDX_X
}
84
/// Returns the Y component of the current block's index within the grid.
///
/// This implementation unconditionally reports `0`.
#[inline]
pub fn block_idx_y() -> i32 {
    const BLOCK_IDX_Y: i32 = 0;
    BLOCK_IDX_Y
}
91
/// Returns the Z component of the current block's index within the grid.
///
/// This implementation unconditionally reports `0`.
#[inline]
pub fn block_idx_z() -> i32 {
    const BLOCK_IDX_Z: i32 = 0;
    BLOCK_IDX_Z
}
98
/// Returns the block extent along X (number of threads per block in X).
///
/// This implementation unconditionally reports `1`.
#[inline]
pub fn block_dim_x() -> i32 {
    const BLOCK_DIM_X: i32 = 1;
    BLOCK_DIM_X
}
105
/// Returns the block extent along Y (number of threads per block in Y).
///
/// This implementation unconditionally reports `1`.
#[inline]
pub fn block_dim_y() -> i32 {
    const BLOCK_DIM_Y: i32 = 1;
    BLOCK_DIM_Y
}
112
/// Returns the block extent along Z (number of threads per block in Z).
///
/// This implementation unconditionally reports `1`.
#[inline]
pub fn block_dim_z() -> i32 {
    const BLOCK_DIM_Z: i32 = 1;
    BLOCK_DIM_Z
}
119
/// Returns the grid extent along X (number of blocks in X).
///
/// This implementation unconditionally reports `1`.
#[inline]
pub fn grid_dim_x() -> i32 {
    const GRID_DIM_X: i32 = 1;
    GRID_DIM_X
}
126
/// Returns the grid extent along Y (number of blocks in Y).
///
/// This implementation unconditionally reports `1`.
#[inline]
pub fn grid_dim_y() -> i32 {
    const GRID_DIM_Y: i32 = 1;
    GRID_DIM_Y
}
133
/// Returns the grid extent along Z (number of blocks in Z).
///
/// This implementation unconditionally reports `1`.
#[inline]
pub fn grid_dim_z() -> i32 {
    const GRID_DIM_Z: i32 = 1;
    GRID_DIM_Z
}
140
/// Returns the number of lanes in a warp.
///
/// This implementation unconditionally reports `32`.
#[inline]
pub fn warp_size() -> i32 {
    const WARP_SIZE: i32 = 32;
    WARP_SIZE
}
147
/// Block-wide barrier.
///
/// The body is intentionally empty: this implementation performs no
/// synchronization at all.
/// NOTE(review): presumably safe because the host emulation runs a single
/// thread per block — confirm.
#[inline]
pub fn sync_threads() {
    // Nothing to wait for.
}
158
/// Barrier that reports how many participants passed a true `predicate`.
///
/// Only the caller's own predicate is considered, so the result is `1` when
/// `predicate` is true and `0` otherwise.
#[inline]
pub fn sync_threads_count(predicate: bool) -> i32 {
    i32::from(predicate)
}
169
/// Barrier that reports whether `predicate` held for every participant.
///
/// Only the caller's own predicate is considered: returns `1` when it is
/// true and `0` otherwise.
#[inline]
pub fn sync_threads_and(predicate: bool) -> i32 {
    i32::from(predicate)
}
180
/// Barrier that reports whether `predicate` held for any participant.
///
/// Only the caller's own predicate is considered: returns `1` when it is
/// true and `0` otherwise.
#[inline]
pub fn sync_threads_or(predicate: bool) -> i32 {
    i32::from(predicate)
}
191
/// Issues a memory fence with sequentially consistent ordering.
#[inline]
pub fn thread_fence() {
    let ordering = Ordering::SeqCst;
    fence(ordering);
}
198
/// Issues a block-scope memory fence.
///
/// Uses a sequentially consistent fence. A release-only fence orders prior
/// writes but does not stop later loads from being reordered before it, so it
/// is weaker than the full barrier the sibling `thread_fence*` functions
/// provide.
/// NOTE(review): CUDA documents `__threadfence_block` as a seq-cst fence at
/// block scope; this assumes the function is meant to mirror it.
#[inline]
pub fn thread_fence_block() {
    fence(Ordering::SeqCst);
}
205
/// Issues a system-scope memory fence with sequentially consistent ordering.
#[inline]
pub fn thread_fence_system() {
    let ordering = Ordering::SeqCst;
    fence(ordering);
}
212
/// Adds `val` to `*addr` and returns the value stored before the update.
///
/// The addition wraps on overflow (two's complement) instead of panicking in
/// debug builds, so the result is the same in every build profile.
///
/// Exclusive access is guaranteed by the `&mut` borrow; no hardware atomic
/// instruction is involved.
#[inline]
pub fn atomic_add(addr: &mut i32, val: i32) -> i32 {
    let updated = addr.wrapping_add(val);
    std::mem::replace(addr, updated)
}
225
/// Adds `val` to the float at `*addr` and returns the previous value.
///
/// Exclusive access is guaranteed by the `&mut` borrow.
#[inline]
pub fn atomic_add_f32(addr: &mut f32, val: f32) -> f32 {
    let updated = *addr + val;
    std::mem::replace(addr, updated)
}
233
/// Subtracts `val` from `*addr` and returns the value stored before the update.
///
/// The subtraction wraps on overflow (two's complement) instead of panicking
/// in debug builds, so the result is the same in every build profile.
///
/// Exclusive access is guaranteed by the `&mut` borrow.
#[inline]
pub fn atomic_sub(addr: &mut i32, val: i32) -> i32 {
    let updated = addr.wrapping_sub(val);
    std::mem::replace(addr, updated)
}
241
/// Stores `min(*addr, val)` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_min(addr: &mut i32, val: i32) -> i32 {
    let smaller = (*addr).min(val);
    std::mem::replace(addr, smaller)
}
249
/// Stores `max(*addr, val)` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_max(addr: &mut i32, val: i32) -> i32 {
    let larger = (*addr).max(val);
    std::mem::replace(addr, larger)
}
257
/// Stores `val` into `*addr` and returns the value it replaced.
#[inline]
pub fn atomic_exchange(addr: &mut i32, val: i32) -> i32 {
    std::mem::replace(addr, val)
}
265
/// Compare-and-swap: writes `val` to `*addr` only when the current value
/// equals `compare`.
///
/// Always returns the value that was stored before the call, whether or not
/// the swap happened.
#[inline]
pub fn atomic_cas(addr: &mut i32, compare: i32, val: i32) -> i32 {
    match *addr {
        current if current == compare => std::mem::replace(addr, val),
        current => current,
    }
}
275
/// Stores `*addr & val` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_and(addr: &mut i32, val: i32) -> i32 {
    let masked = *addr & val;
    std::mem::replace(addr, masked)
}
283
/// Stores `*addr | val` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_or(addr: &mut i32, val: i32) -> i32 {
    let merged = *addr | val;
    std::mem::replace(addr, merged)
}
291
/// Stores `*addr ^ val` into `*addr` and returns the previous value.
#[inline]
pub fn atomic_xor(addr: &mut i32, val: i32) -> i32 {
    let toggled = *addr ^ val;
    std::mem::replace(addr, toggled)
}
299
/// Wrapping increment bounded by `val`.
///
/// If the current value is `>= val` the counter resets to `0`; otherwise it
/// is incremented by one. Returns the value stored before the update.
#[inline]
pub fn atomic_inc(addr: &mut u32, val: u32) -> u32 {
    let next = match *addr {
        current if current >= val => 0,
        // `current < val <= u32::MAX`, so the increment cannot overflow.
        current => current + 1,
    };
    std::mem::replace(addr, next)
}
307
/// Wrapping decrement bounded by `val`.
///
/// If the current value is `0` or greater than `val`, the counter is set to
/// `val`; otherwise it is decremented by one. Returns the value stored before
/// the update.
#[inline]
pub fn atomic_dec(addr: &mut u32, val: u32) -> u32 {
    let next = match *addr {
        0 => val,
        current if current > val => val,
        current => current - 1,
    };
    std::mem::replace(addr, next)
}
315
/// Square root of `x` (delegates to [`f32::sqrt`]).
#[inline]
pub fn sqrt(x: f32) -> f32 {
    f32::sqrt(x)
}
325
/// Reciprocal square root, computed as `1 / sqrt(x)`.
#[inline]
pub fn rsqrt(x: f32) -> f32 {
    x.sqrt().recip()
}
331
/// Absolute value of `x` (delegates to [`f32::abs`]).
#[inline]
pub fn fabs(x: f32) -> f32 {
    f32::abs(x)
}
337
/// Largest integral value not greater than `x` (delegates to [`f32::floor`]).
#[inline]
pub fn floor(x: f32) -> f32 {
    f32::floor(x)
}
343
/// Smallest integral value not less than `x` (delegates to [`f32::ceil`]).
#[inline]
pub fn ceil(x: f32) -> f32 {
    f32::ceil(x)
}
349
/// Rounds `x` to the nearest integer, with halfway cases rounded away from
/// zero (delegates to [`f32::round`]).
#[inline]
pub fn round(x: f32) -> f32 {
    f32::round(x)
}
355
/// Drops the fractional part of `x`, rounding toward zero (delegates to
/// [`f32::trunc`]).
#[inline]
pub fn trunc(x: f32) -> f32 {
    f32::trunc(x)
}
361
/// Fused multiply-add: computes `a * b + c` with a single rounding
/// (delegates to [`f32::mul_add`]).
#[inline]
pub fn fma(a: f32, b: f32, c: f32) -> f32 {
    f32::mul_add(a, b, c)
}
367
/// Smaller of `a` and `b` (delegates to [`f32::min`]).
///
/// If exactly one argument is NaN, the other one is returned.
#[inline]
pub fn fmin(a: f32, b: f32) -> f32 {
    f32::min(a, b)
}
373
/// Larger of `a` and `b` (delegates to [`f32::max`]).
///
/// If exactly one argument is NaN, the other one is returned.
#[inline]
pub fn fmax(a: f32, b: f32) -> f32 {
    f32::max(a, b)
}
379
/// Floating-point remainder of `x / y` with the sign of `x` (truncated
/// division), i.e. Rust's `%` operator on `f32`.
#[inline]
pub fn fmod(x: f32, y: f32) -> f32 {
    std::ops::Rem::rem(x, y)
}
385
/// IEEE-style remainder: `x - n * y`, where `n` is the integer nearest to
/// `x / y` and halfway cases pick the even `n`.
///
/// Halfway quotients are rounded to even rather than away from zero, so
/// `remainder(5.0, 2.0)` is `1.0`. A finite `x` with an infinite `y` returns
/// `x` unchanged.
///
/// NaN inputs, an infinite `x`, or `y == 0` produce NaN. The quotient is
/// formed in `f32`, so results lose accuracy when `|x / y|` is very large.
#[inline]
pub fn remainder(x: f32, y: f32) -> f32 {
    if x.is_finite() && y.is_infinite() {
        // The nearest integer to x / ±inf is 0, so the remainder is x itself.
        return x;
    }
    let quotient = x / y;
    let mut nearest = quotient.round();
    // `round` breaks ties away from zero; step back toward zero when that
    // landed on an odd integer so that ties resolve to even.
    let is_tie = (quotient - quotient.trunc()).abs() == 0.5;
    if is_tie && nearest % 2.0 != 0.0 {
        nearest -= quotient.signum();
    }
    x - nearest * y
}
391
/// Returns a value with the magnitude of `x` and the sign of `y`
/// (delegates to [`f32::copysign`]).
#[inline]
pub fn copysign(x: f32, y: f32) -> f32 {
    f32::copysign(x, y)
}
397
/// Cube root of `x` (delegates to [`f32::cbrt`]).
#[inline]
pub fn cbrt(x: f32) -> f32 {
    f32::cbrt(x)
}
403
/// Length of the hypotenuse, `sqrt(x² + y²)` (delegates to [`f32::hypot`]).
#[inline]
pub fn hypot(x: f32, y: f32) -> f32 {
    f32::hypot(x, y)
}
409
/// Sine of `x`, with `x` in radians (delegates to [`f32::sin`]).
#[inline]
pub fn sin(x: f32) -> f32 {
    f32::sin(x)
}
419
/// Cosine of `x`, with `x` in radians (delegates to [`f32::cos`]).
#[inline]
pub fn cos(x: f32) -> f32 {
    f32::cos(x)
}
425
/// Tangent of `x`, with `x` in radians (delegates to [`f32::tan`]).
#[inline]
pub fn tan(x: f32) -> f32 {
    f32::tan(x)
}
431
/// Arcsine of `x` in radians (delegates to [`f32::asin`]).
///
/// Returns NaN when `x` lies outside `[-1, 1]`.
#[inline]
pub fn asin(x: f32) -> f32 {
    f32::asin(x)
}
437
/// Arccosine of `x` in radians (delegates to [`f32::acos`]).
///
/// Returns NaN when `x` lies outside `[-1, 1]`.
#[inline]
pub fn acos(x: f32) -> f32 {
    f32::acos(x)
}
443
/// Arctangent of `x` in radians (delegates to [`f32::atan`]).
#[inline]
pub fn atan(x: f32) -> f32 {
    f32::atan(x)
}
449
/// Four-quadrant arctangent of `y / x` in radians (delegates to
/// [`f32::atan2`]).
///
/// Note the argument order: `y` comes first, then `x`.
#[inline]
pub fn atan2(y: f32, x: f32) -> f32 {
    f32::atan2(y, x)
}
455
/// Computes the sine and cosine of `x` (radians) in one call.
///
/// Returns the pair `(sin(x), cos(x))`.
#[inline]
pub fn sincos(x: f32) -> (f32, f32) {
    x.sin_cos()
}
461
/// Computes `sin(π · x)`.
///
/// The argument is reduced before it is multiplied by π, so integer inputs
/// give exactly `0` and half-integer inputs give exactly `±1`, instead of
/// carrying the rounding error of the `f32` value of π.
///
/// Non-finite inputs yield NaN.
#[inline]
pub fn sinpi(x: f32) -> f32 {
    use std::f32::consts::PI;

    // sin(πx) has period 2; `%` is exact, giving r in (-2, 2).
    let r = x % 2.0;
    // Fold into [-1, 1]; both adjustments are exact in f32.
    let r = if r > 1.0 {
        r - 2.0
    } else if r < -1.0 {
        r + 2.0
    } else {
        r
    };
    // Reflect about ±1/2: sin(π(1 - r)) == sin(πr). Now r is in [-1/2, 1/2].
    let r = if r > 0.5 {
        1.0 - r
    } else if r < -0.5 {
        -1.0 - r
    } else {
        r
    };
    if r.abs() <= 0.25 {
        (r * PI).sin()
    } else {
        // Near ±1/2 use the co-function so the argument stays small:
        // sin(πr) = sign(r) · cos(π(1/2 - |r|)).
        ((0.5 - r.abs()) * PI).cos().copysign(r)
    }
}
467
/// Computes `cos(π · x)`.
///
/// The argument is reduced before it is multiplied by π, so half-integer
/// inputs give exactly `0` and integer inputs give exactly `±1`, instead of
/// carrying the rounding error of the `f32` value of π.
///
/// Non-finite inputs yield NaN.
#[inline]
pub fn cospi(x: f32) -> f32 {
    use std::f32::consts::PI;

    // cos(πx) is even with period 2; `%` is exact, giving r in [0, 2).
    let r = (x % 2.0).abs();
    // cos(π(2 - r)) == cos(πr): fold into [0, 1].
    let r = if r > 1.0 { 2.0 - r } else { r };
    if r <= 0.25 {
        (r * PI).cos()
    } else if r < 0.75 {
        // cos(πr) = sin(π(1/2 - r)); exact zero at r = 1/2.
        ((0.5 - r) * PI).sin()
    } else {
        // cos(πr) = -cos(π(1 - r)); exact -1 at r = 1.
        -((1.0 - r) * PI).cos()
    }
}
473
/// Hyperbolic sine of `x` (delegates to [`f32::sinh`]).
#[inline]
pub fn sinh(x: f32) -> f32 {
    f32::sinh(x)
}
483
/// Hyperbolic cosine of `x` (delegates to [`f32::cosh`]).
#[inline]
pub fn cosh(x: f32) -> f32 {
    f32::cosh(x)
}
489
/// Hyperbolic tangent of `x` (delegates to [`f32::tanh`]).
#[inline]
pub fn tanh(x: f32) -> f32 {
    f32::tanh(x)
}
495
/// Inverse hyperbolic sine of `x` (delegates to [`f32::asinh`]).
#[inline]
pub fn asinh(x: f32) -> f32 {
    f32::asinh(x)
}
501
/// Inverse hyperbolic cosine of `x` (delegates to [`f32::acosh`]).
///
/// Returns NaN when `x < 1`.
#[inline]
pub fn acosh(x: f32) -> f32 {
    f32::acosh(x)
}
507
/// Inverse hyperbolic tangent of `x` (delegates to [`f32::atanh`]).
#[inline]
pub fn atanh(x: f32) -> f32 {
    f32::atanh(x)
}
513
/// Natural exponential `e^x` (delegates to [`f32::exp`]).
#[inline]
pub fn exp(x: f32) -> f32 {
    f32::exp(x)
}
523
/// Base-2 exponential `2^x` (delegates to [`f32::exp2`]).
#[inline]
pub fn exp2(x: f32) -> f32 {
    f32::exp2(x)
}
529
/// Base-10 exponential `10^x`.
///
/// Computed as `10.powf(x)`. The identity `exp(x · ln 10)` is avoided because
/// the rounding error of the product `x · ln 10` is amplified by `exp` and
/// costs several ULPs for moderately large `x`.
#[inline]
pub fn exp10(x: f32) -> f32 {
    10.0_f32.powf(x)
}
535
/// Computes `e^x - 1`, staying accurate when `x` is close to zero
/// (delegates to [`f32::exp_m1`]).
#[inline]
pub fn expm1(x: f32) -> f32 {
    f32::exp_m1(x)
}
541
/// Natural logarithm of `x` (delegates to [`f32::ln`]).
#[inline]
pub fn log(x: f32) -> f32 {
    f32::ln(x)
}
547
/// Base-2 logarithm of `x` (delegates to [`f32::log2`]).
#[inline]
pub fn log2(x: f32) -> f32 {
    f32::log2(x)
}
553
/// Base-10 logarithm of `x` (delegates to [`f32::log10`]).
#[inline]
pub fn log10(x: f32) -> f32 {
    f32::log10(x)
}
559
/// Computes `ln(1 + x)`, staying accurate when `x` is close to zero
/// (delegates to [`f32::ln_1p`]).
#[inline]
pub fn log1p(x: f32) -> f32 {
    f32::ln_1p(x)
}
565
/// Raises `x` to the floating-point power `y` (delegates to [`f32::powf`]).
#[inline]
pub fn pow(x: f32, y: f32) -> f32 {
    f32::powf(x, y)
}
571
/// Computes `x · 2^exp` with a single rounding.
///
/// The scale factor is built in `f64`, where it cannot overflow or underflow
/// for the clamped exponent range, and the exact product is rounded to `f32`
/// once. Building `2^exp` in `f32` first would turn it into `inf` or `0`
/// whenever `|exp|` exceeds the `f32` exponent range, even when the final
/// result is representable.
///
/// Results too large become `±inf`; results too small become `±0`.
/// Zero, infinity and NaN inputs are returned unchanged.
#[inline]
pub fn ldexp(x: f32, exp: i32) -> f32 {
    // Any finite non-zero f32 lies in [2^-149, 2^128), so exponents beyond
    // ±300 always saturate to 0 / inf and clamping does not change the result.
    const EXP_LIMIT: i32 = 300;
    const F64_EXP_BIAS: i32 = 1023;
    const F64_MANTISSA_BITS: u32 = 52;

    let e = exp.clamp(-EXP_LIMIT, EXP_LIMIT);
    // Exact 2^e: biased exponent field with a zero mantissa.
    let scale = f64::from_bits(((F64_EXP_BIAS + e) as u64) << F64_MANTISSA_BITS);
    (f64::from(x) * scale) as f32
}
577
/// Computes `x · 2^n` with a single rounding.
///
/// The scale factor is built in `f64`, where it cannot overflow or underflow
/// for the clamped exponent range, and the exact product is rounded to `f32`
/// once. Building `2^n` in `f32` first would turn it into `inf` or `0`
/// whenever `|n|` exceeds the `f32` exponent range, even when the final
/// result is representable.
///
/// Results too large become `±inf`; results too small become `±0`.
/// Zero, infinity and NaN inputs are returned unchanged.
#[inline]
pub fn scalbn(x: f32, n: i32) -> f32 {
    // Any finite non-zero f32 lies in [2^-149, 2^128), so exponents beyond
    // ±300 always saturate to 0 / inf and clamping does not change the result.
    const EXP_LIMIT: i32 = 300;
    const F64_EXP_BIAS: i32 = 1023;
    const F64_MANTISSA_BITS: u32 = 52;

    let e = n.clamp(-EXP_LIMIT, EXP_LIMIT);
    // Exact 2^e: biased exponent field with a zero mantissa.
    let scale = f64::from_bits(((F64_EXP_BIAS + e) as u64) << F64_MANTISSA_BITS);
    (f64::from(x) * scale) as f32
}
583
/// Returns the unbiased binary exponent of `x`, i.e. `floor(log2(|x|))`.
///
/// The exponent is read from the IEEE-754 bit pattern, so it is exact for
/// every input. Going through `log2().floor()` is off by one just below
/// powers of two, e.g. `16777215.0`, whose `log2` rounds up to `24.0`.
///
/// Special cases:
/// * `±0`  → `i32::MIN`
/// * `±∞`  → `i32::MAX`
/// * `NaN` → `i32::MIN` (NOTE(review): chosen to follow CUDA's `ilogbf`;
///   the previous code produced `0` via a NaN-to-int cast — confirm).
#[inline]
pub fn ilogb(x: f32) -> i32 {
    const MANTISSA_BITS: u32 = 23;
    const EXPONENT_BIAS: i32 = 127;
    // Exponent of the least significant bit of a subnormal number.
    const SUBNORMAL_LSB_EXPONENT: i32 = -149;

    let magnitude = x.to_bits() & 0x7FFF_FFFF;
    let biased = (magnitude >> MANTISSA_BITS) as i32;
    let mantissa = magnitude & 0x007F_FFFF;

    match (biased, mantissa) {
        (0, 0) => i32::MIN,
        (0xFF, 0) => i32::MAX,
        (0xFF, _) => i32::MIN,
        // Subnormal: value = mantissa · 2^-149, so locate its highest set bit.
        (0, m) => (31 - m.leading_zeros() as i32) + SUBNORMAL_LSB_EXPONENT,
        (e, _) => e - EXPONENT_BIAS,
    }
}
595
/// Approximates the error function `erf(x)`.
///
/// Evaluates a five-term polynomial in `t = 1 / (1 + p·|x|)` multiplied by
/// `exp(-x²)`, then restores the sign of the argument (the function is odd).
/// NOTE(review): the coefficients look like the Abramowitz–Stegun 7.1.26
/// set — confirm the source and its stated error bound.
#[inline]
pub fn erf(x: f32) -> f32 {
    // Polynomial coefficients, highest order first, and the scaling constant.
    const LEADING: f32 = 1.061_405_4_f32;
    const REMAINING: [f32; 4] = [
        -1.453_152_f32,
        1.421_413_7_f32,
        -0.284_496_74_f32,
        0.254_829_6_f32,
    ];
    const P: f32 = 0.327_591_1_f32;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    // Horner evaluation: (((a5·t + a4)·t + a3)·t + a2)·t + a1.
    let poly = REMAINING.iter().fold(LEADING, |acc, &coeff| acc * t + coeff);
    let tail = poly * t * (-magnitude * magnitude).exp();
    sign * (1.0 - tail)
}
613
614#[inline]
616pub fn erfc(x: f32) -> f32 {
617 1.0 - erf(x)
618}
619
/// Returns `true` when `x` is NaN (delegates to [`f32::is_nan`]).
#[inline]
pub fn is_nan(x: f32) -> bool {
    f32::is_nan(x)
}
629
/// Returns `true` when `x` is positive or negative infinity.
#[inline]
pub fn is_infinite(x: f32) -> bool {
    x.abs() == f32::INFINITY
}
635
/// Returns `true` when `x` is neither infinite nor NaN.
#[inline]
pub fn is_finite(x: f32) -> bool {
    // NaN fails every ordered comparison, so it is rejected here as well.
    x.abs() < f32::INFINITY
}
641
/// Returns `true` when `x` is a normal number: not zero, subnormal,
/// infinite or NaN (delegates to [`f32::is_normal`]).
#[inline]
pub fn is_normal(x: f32) -> bool {
    f32::is_normal(x)
}
647
/// Returns `true` when the sign bit of `x` is set.
///
/// This inspects the raw bit, so it is `true` for `-0.0` and for NaNs whose
/// sign bit is set.
#[inline]
pub fn signbit(x: f32) -> bool {
    const SIGN_SHIFT: u32 = 31;
    (x.to_bits() >> SIGN_SHIFT) == 1
}
653
/// Returns the next representable `f32` after `x` in the direction of `y`.
///
/// * If either argument is NaN, the result is NaN.
/// * If `x == y`, `y` is returned.
/// * If `x` is `±0`, the result is the smallest subnormal carrying the sign
///   of the direction of travel.
///
/// The raw bit pattern is ordered by magnitude, not by value, so stepping
/// away from zero increments the bits and stepping toward zero decrements
/// them, whatever the sign of `x`.
#[inline]
pub fn nextafter(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        return f32::NAN;
    }
    if x == y {
        return y;
    }
    if x == 0.0 {
        // Leaving zero: smallest subnormal, signed toward `y` (y != 0 here).
        return f32::from_bits(1).copysign(y);
    }

    let bits = x.to_bits();
    // Magnitude grows when moving up from a positive x or down from a negative x.
    let away_from_zero = (y > x) == (x > 0.0);
    // `bits + 1` cannot overflow (NaN patterns are excluded) and `bits - 1`
    // cannot underflow (zero is excluded).
    f32::from_bits(if away_from_zero { bits + 1 } else { bits - 1 })
}
665
/// Positive difference: `x - y` when `x > y`, otherwise `+0.0`.
///
/// If either argument is NaN the result is NaN. A failed `x > y` comparison
/// alone would otherwise map NaN inputs to `0.0` and hide the invalid
/// operand.
#[inline]
pub fn fdim(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        f32::NAN
    } else if x > y {
        x - y
    } else {
        0.0
    }
}
675
/// Returns the bitmask of currently active lanes in the warp.
///
/// This implementation unconditionally reports `1`, i.e. only bit 0 set.
/// NOTE(review): presumably models a warp with a single active lane (lane 0)
/// on the host — confirm.
#[inline]
pub fn warp_active_mask() -> u32 {
    const LANE_ZERO_ONLY: u32 = 1;
    LANE_ZERO_ONLY
}
685
/// Collects `predicate` from the participating lanes into a bitmask.
///
/// `_mask` is ignored. Only the caller's own predicate is recorded, in bit 0:
/// the result is `1` when `predicate` is true and `0` otherwise.
#[inline]
pub fn warp_ballot(_mask: u32, predicate: bool) -> u32 {
    u32::from(predicate)
}
695
/// Reports whether `predicate` holds for all participating lanes.
///
/// `_mask` is ignored and only the caller's own predicate is considered, so
/// the result is simply `predicate`.
#[inline]
pub fn warp_all(_mask: u32, predicate: bool) -> bool {
    // With one voter, "all" collapses to that voter's predicate.
    predicate
}
701
/// Reports whether `predicate` holds for any participating lane.
///
/// `_mask` is ignored and only the caller's own predicate is considered, so
/// the result is simply `predicate`.
#[inline]
pub fn warp_any(_mask: u32, predicate: bool) -> bool {
    // With one voter, "any" collapses to that voter's predicate.
    predicate
}
707
/// Reads `val` from the lane selected by `_lane`.
///
/// Both `_mask` and `_lane` are ignored: the caller's own `val` is returned
/// unchanged.
/// NOTE(review): presumably because the host emulation has no other lanes to
/// read from — confirm.
#[inline]
pub fn warp_shfl<T: Copy>(_mask: u32, val: T, _lane: i32) -> T {
    // No lane exchange takes place; the value is handed straight back.
    val
}
713
/// Reads `val` from the lane `_delta` positions below the caller.
///
/// Both `_mask` and `_delta` are ignored: the caller's own `val` is returned
/// unchanged.
#[inline]
pub fn warp_shfl_up<T: Copy>(_mask: u32, val: T, _delta: u32) -> T {
    // No lane exchange takes place; the value is handed straight back.
    val
}
719
/// Reads `val` from the lane `_delta` positions above the caller.
///
/// Both `_mask` and `_delta` are ignored: the caller's own `val` is returned
/// unchanged.
#[inline]
pub fn warp_shfl_down<T: Copy>(_mask: u32, val: T, _delta: u32) -> T {
    // No lane exchange takes place; the value is handed straight back.
    val
}
725
/// Reads `val` from the lane whose id is the caller's id XOR `_lane_mask`.
///
/// Both `_mask` and `_lane_mask` are ignored: the caller's own `val` is
/// returned unchanged.
#[inline]
pub fn warp_shfl_xor<T: Copy>(_mask: u32, val: T, _lane_mask: i32) -> T {
    // No lane exchange takes place; the value is handed straight back.
    val
}
731
/// Sums `val` across the participating lanes.
///
/// `_mask` is ignored and only the caller contributes, so the result is
/// `val` itself.
#[inline]
pub fn warp_reduce_add(_mask: u32, val: i32) -> i32 {
    // A sum over a single contribution is that contribution.
    val
}
737
/// Takes the minimum of `val` across the participating lanes.
///
/// `_mask` is ignored and only the caller contributes, so the result is
/// `val` itself.
#[inline]
pub fn warp_reduce_min(_mask: u32, val: i32) -> i32 {
    // A minimum over a single contribution is that contribution.
    val
}
743
/// Takes the maximum of `val` across the participating lanes.
///
/// `_mask` is ignored and only the caller contributes, so the result is
/// `val` itself.
#[inline]
pub fn warp_reduce_max(_mask: u32, val: i32) -> i32 {
    // A maximum over a single contribution is that contribution.
    val
}
749
/// Bitwise-ANDs `val` across the participating lanes.
///
/// `_mask` is ignored and only the caller contributes, so the result is
/// `val` itself.
#[inline]
pub fn warp_reduce_and(_mask: u32, val: u32) -> u32 {
    // An AND over a single contribution is that contribution.
    val
}
755
/// Bitwise-ORs `val` across the participating lanes.
///
/// `_mask` is ignored and only the caller contributes, so the result is
/// `val` itself.
#[inline]
pub fn warp_reduce_or(_mask: u32, val: u32) -> u32 {
    // An OR over a single contribution is that contribution.
    val
}
761
/// Bitwise-XORs `val` across the participating lanes.
///
/// `_mask` is ignored and only the caller contributes, so the result is
/// `val` itself.
#[inline]
pub fn warp_reduce_xor(_mask: u32, val: u32) -> u32 {
    // An XOR over a single contribution is that contribution.
    val
}
767
/// Returns the mask of lanes holding the same value as the caller.
///
/// Both arguments are ignored; the result is always `1` (bit 0 only).
/// NOTE(review): presumably because the caller, modelled as lane 0, always
/// matches itself — confirm.
#[inline]
pub fn warp_match_any(_mask: u32, _val: u32) -> u32 {
    const SELF_ONLY: u32 = 1;
    SELF_ONLY
}
773
/// Reports whether every participating lane holds the same value.
///
/// Both arguments are ignored. Returns the pair `(1, true)`: a lane mask with
/// only bit 0 set, and a flag saying that all lanes matched.
#[inline]
pub fn warp_match_all(_mask: u32, _val: u32) -> (u32, bool) {
    const MATCHING_LANES: u32 = 1;
    const ALL_MATCHED: bool = true;
    (MATCHING_LANES, ALL_MATCHED)
}
779
/// Counts the bits set in `x` (population count).
#[inline]
pub fn popc(x: u32) -> i32 {
    // The count is at most 32, so the cast to i32 is lossless.
    u32::count_ones(x) as i32
}
789
/// Counts the bits set in the two's-complement representation of `x`.
#[inline]
pub fn popcount(x: i32) -> i32 {
    // The count is at most 32, so the cast to i32 is lossless.
    x.count_ones() as i32
}
795
/// Counts the leading zero bits of `x`; returns `32` when `x` is zero.
#[inline]
pub fn clz(x: u32) -> i32 {
    u32::leading_zeros(x) as i32
}
801
/// Counts the leading zero bits in the two's-complement representation of
/// `x`; returns `32` when `x` is zero and `0` for any negative value.
#[inline]
pub fn leading_zeros(x: i32) -> i32 {
    x.leading_zeros() as i32
}
807
/// Counts the trailing zero bits of `x`; returns `32` when `x` is zero.
#[inline]
pub fn ctz(x: u32) -> i32 {
    // `u32::trailing_zeros` already yields 32 for a zero input.
    x.trailing_zeros() as i32
}
817
/// Counts the trailing zero bits in the two's-complement representation of
/// `x`; returns `32` when `x` is zero.
#[inline]
pub fn trailing_zeros(x: i32) -> i32 {
    // `i32::trailing_zeros` already yields 32 for a zero input.
    x.trailing_zeros() as i32
}
827
/// Finds the first (least significant) set bit of `x`.
///
/// Positions are 1-based: bit 0 is reported as `1`. Returns `0` when no bit
/// is set.
#[inline]
pub fn ffs(x: u32) -> i32 {
    match x {
        0 => 0,
        bits => (bits.trailing_zeros() + 1) as i32,
    }
}
837
/// Reverses the order of the 32 bits of `x`.
#[inline]
pub fn brev(x: u32) -> u32 {
    u32::reverse_bits(x)
}
843
/// Reverses the order of the 32 bits in the two's-complement representation
/// of `x`.
#[inline]
pub fn reverse_bits(x: i32) -> i32 {
    x.reverse_bits()
}
849
/// Builds a 32-bit word from four bytes picked out of `x` and `y`.
///
/// The eight source bytes are numbered 0–3 for `x` (least significant first)
/// and 4–7 for `y`. Each of the four low nibbles of `s` selects one source
/// byte through its low three bits: nibble 0 supplies result byte 0, nibble 1
/// result byte 1, and so on. The top bit of each nibble and bits 16 and
/// above of `s` are ignored.
#[inline]
pub fn byte_perm(x: u32, y: u32, s: u32) -> u32 {
    // Concatenate y:x so that source byte `i` sits at bits [8i, 8i + 8).
    let pool = (u64::from(y) << 32) | u64::from(x);
    let select = |slot: u32| -> u32 {
        let index = (s >> (4 * slot)) & 0x7;
        ((pool >> (8 * index)) & 0xFF) as u32
    };
    (0..4u32).fold(0u32, |word, slot| word | (select(slot) << (8 * slot)))
}
869
/// Funnel shift left: concatenates `hi:lo` into a 64-bit value, shifts it
/// left by `shift & 31` bits and returns the most significant 32 bits.
///
/// A shift of `0` (or any multiple of 32) therefore returns `hi`, the
/// unshifted upper word. Returning `lo` in that case would break continuity
/// with the non-zero shifts, where the result is
/// `(hi << s) | (lo >> (32 - s))`.
#[inline]
pub fn funnel_shift_left(lo: u32, hi: u32, shift: u32) -> u32 {
    let wide = (u64::from(hi) << 32) | u64::from(lo);
    // Bits shifted past bit 63 are discarded; the answer is the upper word.
    ((wide << (shift & 31)) >> 32) as u32
}
880
/// Funnel shift right: concatenates `hi:lo` into a 64-bit value, shifts it
/// right by `shift & 31` bits and returns the least significant 32 bits.
///
/// A shift of `0` (or any multiple of 32) returns `lo` unchanged.
#[inline]
pub fn funnel_shift_right(lo: u32, hi: u32, shift: u32) -> u32 {
    let wide = (u64::from(hi) << 32) | u64::from(lo);
    // Truncating to u32 keeps exactly the lower word of the shifted value.
    (wide >> (shift & 31)) as u32
}
891
/// Loads the value behind `ptr` and returns a copy of it.
///
/// NOTE(review): the name suggests a read-only / cached global load on the
/// device; here it is a plain dereference — confirm intent.
#[inline]
pub fn ldg<T: Copy>(ptr: &T) -> T {
    let &value = ptr;
    value
}
901
/// Loads the value behind `ptr` and returns a copy of it (plain dereference).
#[inline]
pub fn load_global<T: Copy>(ptr: &T) -> T {
    let &value = ptr;
    value
}
907
/// Prefetch hint for the data behind `_ptr` (L1 variant).
///
/// The body is intentionally empty: the reference is ignored and no prefetch
/// is issued.
#[inline]
pub fn prefetch_l1<T>(_ptr: &T) {
    // Intentionally a no-op.
}
913
/// Prefetch hint for the data behind `_ptr` (L2 variant).
///
/// The body is intentionally empty: the reference is ignored and no prefetch
/// is issued.
#[inline]
pub fn prefetch_l2<T>(_ptr: &T) {
    // Intentionally a no-op.
}
919
/// Reciprocal of `x`, i.e. `1 / x`.
#[inline]
pub fn rcp(x: f32) -> f32 {
    x.recip()
}
929
/// Divides `x` by `y`.
///
/// Despite the name, this is an ordinary IEEE `f32` division; no
/// reduced-precision shortcut is taken.
#[inline]
pub fn fast_div(x: f32, y: f32) -> f32 {
    std::ops::Div::div(x, y)
}
935
/// Clamps `x` to the closed interval `[+0.0, 1.0]`.
///
/// * `x <= 0` (including `-0.0`) → `+0.0`
/// * `x >= 1` → `1.0`
/// * `NaN` → `+0.0`
///
/// NOTE(review): the NaN and `-0.0` results follow CUDA's `__saturatef`
/// contract rather than `f32::clamp`, which would pass both through
/// unchanged — confirm this is the intended contract.
#[inline]
pub fn saturate(x: f32) -> f32 {
    // `x > 0.0` is false for NaN and for both zeros, so all of them map to +0.
    if x > 0.0 {
        x.min(1.0)
    } else {
        0.0
    }
}
941
942#[inline]
944pub fn clamp_01(x: f32) -> f32 {
945 saturate(x)
946}
947
/// Returns a 32-bit nanosecond counter derived from the system clock.
///
/// The value is the number of nanoseconds since the Unix epoch truncated to
/// its low 32 bits, so it wraps roughly every 4.3 seconds. Returns `0` if the
/// system clock reports a time before the epoch.
///
/// The source is wall-clock time and is therefore not guaranteed to be
/// monotonic.
#[inline]
pub fn clock() -> u32 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        // Truncation to the low 32 bits is intentional (wrapping counter).
        Ok(since_epoch) => since_epoch.as_nanos() as u32,
        Err(_) => 0,
    }
}
961
/// Returns a 64-bit nanosecond counter derived from the system clock.
///
/// The value is the number of nanoseconds since the Unix epoch, truncated to
/// 64 bits. Returns `0` if the system clock reports a time before the epoch.
///
/// The source is wall-clock time and is therefore not guaranteed to be
/// monotonic.
#[inline]
pub fn clock64() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_nanos() as u64)
}
970
/// Blocks the calling OS thread for at least `ns` nanoseconds.
///
/// Delegates to [`std::thread::sleep`], so the actual pause may be longer
/// depending on scheduler granularity.
#[inline]
pub fn nanosleep(ns: u32) {
    use std::time::Duration;

    let pause = Duration::from_nanos(u64::from(ns));
    std::thread::sleep(pause);
}
976
/// Unit tests for the intrinsic stand-ins defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Thread indices report 0 on every axis.
    #[test]
    fn test_thread_indices_default() {
        assert_eq!(thread_idx_x(), 0);
        assert_eq!(thread_idx_y(), 0);
        assert_eq!(thread_idx_z(), 0);
    }

    /// Block indices report 0 on every axis.
    #[test]
    fn test_block_indices_default() {
        assert_eq!(block_idx_x(), 0);
        assert_eq!(block_idx_y(), 0);
        assert_eq!(block_idx_z(), 0);
    }

    /// Block and grid extents report 1 on every axis; the warp has 32 lanes.
    #[test]
    fn test_dimensions_default() {
        assert_eq!(block_dim_x(), 1);
        assert_eq!(block_dim_y(), 1);
        assert_eq!(block_dim_z(), 1);
        assert_eq!(grid_dim_x(), 1);
        assert_eq!(grid_dim_y(), 1);
        assert_eq!(grid_dim_z(), 1);
        assert_eq!(warp_size(), 32);
    }

    /// Spot checks for the basic math wrappers.
    #[test]
    fn test_math_functions() {
        assert!((sqrt(4.0) - 2.0).abs() < 1e-6);
        assert!((rsqrt(4.0) - 0.5).abs() < 1e-6);
        assert!((sin(0.0)).abs() < 1e-6);
        assert!((cos(0.0) - 1.0).abs() < 1e-6);
        assert!((exp(0.0) - 1.0).abs() < 1e-6);
        assert!((log(1.0)).abs() < 1e-6);
    }

    /// Trigonometric wrappers at well-known angles.
    #[test]
    fn test_trigonometric_functions() {
        let pi = std::f32::consts::PI;
        assert!((sin(pi / 2.0) - 1.0).abs() < 1e-6);
        assert!((cos(pi) + 1.0).abs() < 1e-6);
        assert!((tan(0.0)).abs() < 1e-6);
        assert!((asin(1.0) - pi / 2.0).abs() < 1e-6);
        assert!((atan2(1.0, 1.0) - pi / 4.0).abs() < 1e-6);
    }

    /// Hyperbolic wrappers at the origin.
    #[test]
    fn test_hyperbolic_functions() {
        assert!((sinh(0.0)).abs() < 1e-6);
        assert!((cosh(0.0) - 1.0).abs() < 1e-6);
        assert!((tanh(0.0)).abs() < 1e-6);
    }

    /// Exponential and logarithmic wrappers on exact powers.
    #[test]
    fn test_exponential_functions() {
        assert!((exp2(3.0) - 8.0).abs() < 1e-6);
        assert!((log2(8.0) - 3.0).abs() < 1e-6);
        assert!((log10(100.0) - 2.0).abs() < 1e-6);
        assert!((pow(2.0, 3.0) - 8.0).abs() < 1e-6);
    }

    /// NaN / infinity / finite classification.
    #[test]
    fn test_classification_functions() {
        assert!(is_nan(f32::NAN));
        assert!(!is_nan(1.0));
        assert!(is_infinite(f32::INFINITY));
        assert!(!is_infinite(1.0));
        assert!(is_finite(1.0));
        assert!(!is_finite(f32::INFINITY));
    }

    /// Population count, leading/trailing zeros, find-first-set, bit reverse.
    #[test]
    fn test_bit_manipulation() {
        assert_eq!(popc(0b1010_1010), 4);
        assert_eq!(clz(1u32), 31);
        assert_eq!(clz(0x8000_0000u32), 0);
        assert_eq!(ctz(0b1000), 3);
        assert_eq!(ffs(0b1000), 4);
        assert_eq!(brev(1u32), 0x8000_0000);
    }

    /// Warp votes and reductions as seen by a single caller.
    #[test]
    fn test_warp_operations() {
        assert_eq!(warp_active_mask(), 1);
        assert_eq!(warp_ballot(0xFFFF_FFFF, true), 1);
        assert!(warp_all(0xFFFF_FFFF, true));
        assert!(warp_any(0xFFFF_FFFF, true));
        assert_eq!(warp_reduce_add(0xFFFF_FFFF, 5), 5);
    }

    /// Reciprocal, division and saturation helpers.
    #[test]
    fn test_special_functions() {
        assert!((rcp(2.0) - 0.5).abs() < 1e-6);
        assert!((fast_div(10.0, 2.0) - 5.0).abs() < 1e-6);
        assert_eq!(saturate(-1.0), 0.0);
        assert_eq!(saturate(0.5), 0.5);
        assert_eq!(saturate(2.0), 1.0);
    }

    /// Read-modify-write helpers return the old value and store the new one.
    #[test]
    fn test_atomic_operations() {
        let mut val = 10;
        assert_eq!(atomic_add(&mut val, 5), 10);
        assert_eq!(val, 15);

        let mut val = 10;
        assert_eq!(atomic_sub(&mut val, 3), 10);
        assert_eq!(val, 7);

        let mut val = 10;
        assert_eq!(atomic_cas(&mut val, 10, 20), 10);
        assert_eq!(val, 20);
    }

    /// Funnel shifts by half a word merge the two halves.
    #[test]
    fn test_funnel_shift() {
        assert_eq!(funnel_shift_left(0xFFFF_0000, 0x0000_FFFF, 16), 0xFFFF_FFFF);
        assert_eq!(
            funnel_shift_right(0xFFFF_0000, 0x0000_FFFF, 16),
            0xFFFF_FFFF
        );
    }

    /// Byte permutation: identity selector and "all bytes from y" selector.
    #[test]
    fn test_byte_perm() {
        let x = 0x04030201u32;
        let y = 0x08070605u32;
        // Selector 0x3210 picks bytes 0..=3, i.e. reproduces `x`.
        assert_eq!(byte_perm(x, y, 0x3210), 0x04030201);
        // Selector 0x7654 picks bytes 4..=7, i.e. reproduces `y`.
        assert_eq!(byte_perm(x, y, 0x7654), 0x08070605);
    }
}