1use std::sync::atomic::{fence, Ordering};
55
56pub use ringkernel_codegen::dsl_common::{
59 block_dim_x,
60 block_dim_y,
61 block_dim_z,
62 block_idx_x,
63 block_idx_y,
64 block_idx_z,
65 ceil,
66 cos,
67 exp,
68 floor,
69 fma,
70 grid_dim_x,
71 grid_dim_y,
72 grid_dim_z,
73 log,
74 powf,
75 round,
76 rsqrt,
77 sin,
78 sqrt,
80 sync_threads,
82 tan,
83 thread_fence,
84 thread_fence_block,
85 thread_idx_x,
87 thread_idx_y,
88 thread_idx_z,
89};
90
/// Returns the number of lanes in a warp.
///
/// This implementation always reports 32.
#[inline]
pub fn warp_size() -> i32 {
    const LANES_PER_WARP: i32 = 32;
    LANES_PER_WARP
}
101
/// Counts how many threads report a true `predicate`.
///
/// Only the calling thread is taken into account here, so the result is
/// 1 when `predicate` is true and 0 otherwise.
#[inline]
pub fn sync_threads_count(predicate: bool) -> i32 {
    i32::from(predicate)
}
116
/// Logical AND of `predicate` over the participating threads.
///
/// Only the calling thread is taken into account here, so the result is
/// simply `predicate` encoded as 1 (true) or 0 (false).
#[inline]
pub fn sync_threads_and(predicate: bool) -> i32 {
    match predicate {
        true => 1,
        false => 0,
    }
}
127
/// Logical OR of `predicate` over the participating threads.
///
/// Only the calling thread is taken into account here, so the result is
/// simply `predicate` encoded as 1 (true) or 0 (false).
#[inline]
pub fn sync_threads_or(predicate: bool) -> i32 {
    let flag = i32::from(predicate);
    flag
}
138
/// Issues a sequentially consistent memory fence through
/// [`std::sync::atomic::fence`].
#[inline]
pub fn thread_fence_system() {
    let ordering = Ordering::SeqCst;
    fence(ordering);
}
145
/// Adds `val` to the integer behind `addr` and returns the value it held
/// before the update.
///
/// Despite the name this is a plain read-modify-write; exclusivity comes
/// from the `&mut` borrow. The addition wraps on overflow (two's
/// complement) so that debug and release builds behave the same instead of
/// panicking in debug builds.
#[inline]
pub fn atomic_add(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr = old.wrapping_add(val);
    old
}
158
/// Adds `val` to the float behind `addr` and returns the value it held
/// before the update.
///
/// This is a plain read-modify-write; exclusivity comes from the `&mut`
/// borrow.
#[inline]
pub fn atomic_add_f32(addr: &mut f32, val: f32) -> f32 {
    let sum = *addr + val;
    std::mem::replace(addr, sum)
}
166
/// Subtracts `val` from the integer behind `addr` and returns the value it
/// held before the update.
///
/// Despite the name this is a plain read-modify-write; exclusivity comes
/// from the `&mut` borrow. The subtraction wraps on overflow (two's
/// complement) so that debug and release builds behave the same instead of
/// panicking in debug builds.
#[inline]
pub fn atomic_sub(addr: &mut i32, val: i32) -> i32 {
    let old = *addr;
    *addr = old.wrapping_sub(val);
    old
}
174
/// Stores the smaller of the current value and `val` into `addr` and
/// returns the value held before the update.
#[inline]
pub fn atomic_min(addr: &mut i32, val: i32) -> i32 {
    let smallest = (*addr).min(val);
    std::mem::replace(addr, smallest)
}
182
/// Stores the larger of the current value and `val` into `addr` and
/// returns the value held before the update.
#[inline]
pub fn atomic_max(addr: &mut i32, val: i32) -> i32 {
    let largest = (*addr).max(val);
    std::mem::replace(addr, largest)
}
190
/// Writes `val` into `addr` and returns the value that was there before.
#[inline]
pub fn atomic_exchange(addr: &mut i32, val: i32) -> i32 {
    std::mem::replace(addr, val)
}
198
/// Compare-and-swap: if the value behind `addr` equals `compare`, it is
/// replaced by `val`; otherwise it is left untouched.
///
/// Returns the value held before the call in both cases.
#[inline]
pub fn atomic_cas(addr: &mut i32, compare: i32, val: i32) -> i32 {
    let previous = *addr;
    // Re-storing the unchanged value on a mismatch leaves memory as it was.
    let next = if previous == compare { val } else { previous };
    *addr = next;
    previous
}
208
/// Bitwise-ANDs `val` into the integer behind `addr` and returns the value
/// held before the update.
#[inline]
pub fn atomic_and(addr: &mut i32, val: i32) -> i32 {
    let masked = *addr & val;
    std::mem::replace(addr, masked)
}
216
/// Bitwise-ORs `val` into the integer behind `addr` and returns the value
/// held before the update.
#[inline]
pub fn atomic_or(addr: &mut i32, val: i32) -> i32 {
    let merged = *addr | val;
    std::mem::replace(addr, merged)
}
224
/// Bitwise-XORs `val` into the integer behind `addr` and returns the value
/// held before the update.
#[inline]
pub fn atomic_xor(addr: &mut i32, val: i32) -> i32 {
    let toggled = *addr ^ val;
    std::mem::replace(addr, toggled)
}
232
/// Wrapping increment with an upper bound.
///
/// If the current value is below `val` it is incremented by one; otherwise
/// it is reset to 0. Returns the value held before the update.
#[inline]
pub fn atomic_inc(addr: &mut u32, val: u32) -> u32 {
    let current = *addr;
    // `current < val` guarantees `current + 1` cannot overflow.
    let next = if current < val { current + 1 } else { 0 };
    *addr = next;
    current
}
240
/// Wrapping decrement with an upper bound.
///
/// If the current value is 0 or greater than `val` it is reset to `val`;
/// otherwise it is decremented by one. Returns the value held before the
/// update.
#[inline]
pub fn atomic_dec(addr: &mut u32, val: u32) -> u32 {
    let current = *addr;
    *addr = match current {
        0 => val,
        above if above > val => val,
        within => within - 1,
    };
    current
}
248
/// Returns the absolute value of `x`.
#[inline]
pub fn fabs(x: f32) -> f32 {
    f32::abs(x)
}
258
/// Returns the integer part of `x`, rounding toward zero.
#[inline]
pub fn trunc(x: f32) -> f32 {
    f32::trunc(x)
}
264
/// Returns the smaller of `a` and `b`, as defined by [`f32::min`].
#[inline]
pub fn fmin(a: f32, b: f32) -> f32 {
    f32::min(a, b)
}
270
/// Returns the larger of `a` and `b`, as defined by [`f32::max`].
#[inline]
pub fn fmax(a: f32, b: f32) -> f32 {
    f32::max(a, b)
}
276
/// Floating-point remainder of `x / y`, computed with Rust's `%` operator.
///
/// The result carries the sign of `x`.
#[inline]
pub fn fmod(x: f32, y: f32) -> f32 {
    let (dividend, divisor) = (x, y);
    dividend % divisor
}
282
/// IEEE-style remainder: `x - n * y`, where `n` is the integer nearest to
/// `x / y` and exact ties are resolved toward the even integer.
///
/// A finite `x` with an infinite `y` is returned unchanged. A zero `y` or
/// an infinite `x` yields NaN.
#[inline]
pub fn remainder(x: f32, y: f32) -> f32 {
    // Without this guard the quotient is 0 and `0 * inf` would produce NaN.
    if x.is_finite() && y.is_infinite() {
        return x;
    }
    let quotient = x / y;
    let mut nearest = quotient.round();
    // `round()` breaks ties away from zero; step back toward zero when that
    // landed on an odd integer so that ties end up on the even neighbour.
    let is_tie = (quotient - quotient.trunc()).abs() == 0.5;
    if is_tie && nearest % 2.0 != 0.0 {
        nearest -= quotient.signum();
    }
    x - nearest * y
}
288
/// Returns a value with the magnitude of `x` and the sign of `y`.
#[inline]
pub fn copysign(x: f32, y: f32) -> f32 {
    f32::copysign(x, y)
}
294
/// Returns the cube root of `x`.
#[inline]
pub fn cbrt(x: f32) -> f32 {
    f32::cbrt(x)
}
300
/// Returns the length of the hypotenuse of a right triangle with legs `x`
/// and `y`, i.e. `sqrt(x² + y²)` as computed by [`f32::hypot`].
#[inline]
pub fn hypot(x: f32, y: f32) -> f32 {
    f32::hypot(x, y)
}
306
/// Returns the arcsine of `x` in radians.
#[inline]
pub fn asin(x: f32) -> f32 {
    f32::asin(x)
}
316
/// Returns the arccosine of `x` in radians.
#[inline]
pub fn acos(x: f32) -> f32 {
    f32::acos(x)
}
322
/// Returns the arctangent of `x` in radians.
#[inline]
pub fn atan(x: f32) -> f32 {
    f32::atan(x)
}
328
/// Returns the four-quadrant arctangent of `y / x` in radians.
///
/// Note the argument order: `y` comes first, then `x`.
#[inline]
pub fn atan2(y: f32, x: f32) -> f32 {
    f32::atan2(y, x)
}
334
/// Returns `(sin(x), cos(x))` for `x` in radians.
#[inline]
pub fn sincos(x: f32) -> (f32, f32) {
    x.sin_cos()
}
340
/// Returns `sin(π · x)`.
///
/// The argument is scaled by the `f32` value of π before the sine is
/// taken, so results at integer `x` are close to, but not exactly, zero.
#[inline]
pub fn sinpi(x: f32) -> f32 {
    let radians = std::f32::consts::PI * x;
    f32::sin(radians)
}
346
/// Returns `cos(π · x)`.
///
/// The argument is scaled by the `f32` value of π before the cosine is
/// taken, so results at half-integer `x` are close to, but not exactly,
/// zero.
#[inline]
pub fn cospi(x: f32) -> f32 {
    let radians = std::f32::consts::PI * x;
    f32::cos(radians)
}
352
/// Returns the hyperbolic sine of `x`.
#[inline]
pub fn sinh(x: f32) -> f32 {
    f32::sinh(x)
}
362
/// Returns the hyperbolic cosine of `x`.
#[inline]
pub fn cosh(x: f32) -> f32 {
    f32::cosh(x)
}
368
/// Returns the hyperbolic tangent of `x`.
#[inline]
pub fn tanh(x: f32) -> f32 {
    f32::tanh(x)
}
374
/// Returns the inverse hyperbolic sine of `x`.
#[inline]
pub fn asinh(x: f32) -> f32 {
    f32::asinh(x)
}
380
/// Returns the inverse hyperbolic cosine of `x`.
#[inline]
pub fn acosh(x: f32) -> f32 {
    f32::acosh(x)
}
386
/// Returns the inverse hyperbolic tangent of `x`.
#[inline]
pub fn atanh(x: f32) -> f32 {
    f32::atanh(x)
}
392
/// Returns 2 raised to the power `x`.
#[inline]
pub fn exp2(x: f32) -> f32 {
    f32::exp2(x)
}
402
/// Returns 10 raised to the power `x`, evaluated as `exp(x · ln 10)`.
#[inline]
pub fn exp10(x: f32) -> f32 {
    let natural_exponent = std::f32::consts::LN_10 * x;
    f32::exp(natural_exponent)
}
408
/// Returns `e^x - 1` via [`f32::exp_m1`].
#[inline]
pub fn expm1(x: f32) -> f32 {
    f32::exp_m1(x)
}
414
/// Returns the base-2 logarithm of `x`.
#[inline]
pub fn log2(x: f32) -> f32 {
    f32::log2(x)
}
420
/// Returns the base-10 logarithm of `x`.
#[inline]
pub fn log10(x: f32) -> f32 {
    f32::log10(x)
}
426
/// Returns `ln(1 + x)` via [`f32::ln_1p`].
#[inline]
pub fn log1p(x: f32) -> f32 {
    f32::ln_1p(x)
}
432
/// Returns `x` raised to the floating-point power `y`.
#[inline]
pub fn pow(x: f32, y: f32) -> f32 {
    f32::powf(x, y)
}
438
/// Returns `x · 2^exp`.
///
/// The scaling is carried out in `f64` and rounded to `f32` once at the
/// end. Forming `2^exp` as an `f32` first would overflow to infinity or
/// underflow to zero for exponents whose final product is still
/// representable, and would turn `0 · 2^big` into NaN.
#[inline]
pub fn ldexp(x: f32, exp: i32) -> f32 {
    // Finite non-zero `f32` magnitudes lie in 2^-149..2^128, so any exponent
    // beyond ±300 already saturates the result; clamping keeps the power of
    // two inside the normal `f64` range without changing the outcome.
    let exp = exp.clamp(-300, 300);
    // Assemble 2^exp straight from its IEEE-754 binary64 encoding
    // (exponent bias 1023, zero mantissa); the biased value is in 723..=1323.
    let scale = f64::from_bits(((1023 + exp) as u64) << 52);
    // The product of a 24-bit significand and a power of two is exact in
    // f64, so the only rounding happens in the final narrowing cast.
    (f64::from(x) * scale) as f32
}
444
/// Returns `x · 2^n`.
///
/// The scaling is carried out in `f64` and rounded to `f32` once at the
/// end. Forming `2^n` as an `f32` first would overflow to infinity or
/// underflow to zero for exponents whose final product is still
/// representable, and would turn `0 · 2^big` into NaN.
#[inline]
pub fn scalbn(x: f32, n: i32) -> f32 {
    // Finite non-zero `f32` magnitudes lie in 2^-149..2^128, so any exponent
    // beyond ±300 already saturates the result; clamping keeps the power of
    // two inside the normal `f64` range without changing the outcome.
    let n = n.clamp(-300, 300);
    // Assemble 2^n straight from its IEEE-754 binary64 encoding
    // (exponent bias 1023, zero mantissa); the biased value is in 723..=1323.
    let scale = f64::from_bits(((1023 + n) as u64) << 52);
    // The product of a 24-bit significand and a power of two is exact in
    // f64, so the only rounding happens in the final narrowing cast.
    (f64::from(x) * scale) as f32
}
450
/// Returns the unbiased binary exponent of `x`, i.e. `floor(log2(|x|))`.
///
/// Special cases: zero and NaN return `i32::MIN`; infinities return
/// `i32::MAX`.
///
/// The exponent is read from the IEEE-754 bit pattern rather than computed
/// with a floating-point logarithm, because `log2` can round up to the
/// next integer for inputs just below a power of two.
#[inline]
pub fn ilogb(x: f32) -> i32 {
    if x == 0.0 || x.is_nan() {
        return i32::MIN;
    }
    if x.is_infinite() {
        return i32::MAX;
    }
    let magnitude = x.to_bits() & 0x7FFF_FFFF;
    let biased_exponent = (magnitude >> 23) as i32;
    if biased_exponent == 0 {
        // Subnormal: the value is `mantissa · 2^-149`, so the exponent is
        // the index of the highest set mantissa bit minus 149.
        let mantissa = magnitude & 0x007F_FFFF;
        (31 - mantissa.leading_zeros() as i32) - 149
    } else {
        biased_exponent - 127
    }
}
462
/// Approximates the error function of `x`.
///
/// Evaluates a five-term polynomial in `t = 1 / (1 + p·|x|)` scaled by
/// `exp(-x²)` and restores the sign of `x` at the end (the function is
/// odd).
#[inline]
pub fn erf(x: f32) -> f32 {
    // Polynomial coefficients, highest order first, and the `p` constant.
    const LEADING: f32 = 1.061_405_4_f32;
    const REMAINING: [f32; 4] = [
        -1.453_152_f32,
        1.421_413_7_f32,
        -0.284_496_74_f32,
        0.254_829_6_f32,
    ];
    const P: f32 = 0.327_591_1_f32;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    // Horner evaluation: ((((a5·t + a4)·t + a3)·t + a2)·t + a1).
    let poly = REMAINING
        .iter()
        .fold(LEADING, |acc, &coeff| acc * t + coeff);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    sign * y
}
480
481#[inline]
483pub fn erfc(x: f32) -> f32 {
484 1.0 - erf(x)
485}
486
/// Returns `true` when `x` is NaN.
#[inline]
pub fn is_nan(x: f32) -> bool {
    f32::is_nan(x)
}
496
/// Returns `true` when `x` is positive or negative infinity.
#[inline]
pub fn is_infinite(x: f32) -> bool {
    f32::is_infinite(x)
}
502
/// Returns `true` when `x` is neither infinite nor NaN.
#[inline]
pub fn is_finite(x: f32) -> bool {
    f32::is_finite(x)
}
508
/// Returns `true` when `x` is a normal number, i.e. not zero, subnormal,
/// infinite or NaN (see [`f32::is_normal`]).
#[inline]
pub fn is_normal(x: f32) -> bool {
    f32::is_normal(x)
}
514
/// Returns `true` when the sign bit of `x` is set.
///
/// This also holds for `-0.0` and for NaNs with a negative sign.
#[inline]
pub fn signbit(x: f32) -> bool {
    // The sign occupies the most significant bit of the IEEE-754 encoding.
    (x.to_bits() >> 31) == 1
}
520
/// Returns the next representable `f32` after `x` in the direction of `y`.
///
/// * If either argument is NaN, NaN is returned.
/// * If `x == y`, `y` is returned.
/// * Stepping away from zero (either sign) yields the smallest subnormal
///   carrying the sign of the direction.
#[inline]
pub fn nextafter(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        return f32::NAN;
    }
    if x == y {
        return y;
    }
    if x == 0.0 {
        // The bit pattern of zero cannot be decremented; jump straight to
        // the smallest subnormal on the side of `y`.
        let tiny = f32::from_bits(1);
        return if y > 0.0 { tiny } else { -tiny };
    }
    // For a non-zero value, adding one to the bit pattern increases the
    // magnitude and subtracting one decreases it, regardless of sign.
    let bits = x.to_bits();
    let away_from_zero = (y > x) == (x > 0.0);
    f32::from_bits(if away_from_zero { bits + 1 } else { bits - 1 })
}
532
/// Positive difference: `x - y` when `x > y`, otherwise `0.0`.
///
/// If either argument is NaN the result is NaN, rather than the `0.0` a
/// bare `x > y` comparison would fall through to.
#[inline]
pub fn fdim(x: f32, y: f32) -> f32 {
    if x.is_nan() || y.is_nan() {
        f32::NAN
    } else if x > y {
        x - y
    } else {
        0.0
    }
}
542
/// Returns the mask of active lanes.
///
/// This implementation always reports a mask with only bit 0 set.
#[inline]
pub fn warp_active_mask() -> u32 {
    const LANE_ZERO_ONLY: u32 = 1;
    LANE_ZERO_ONLY
}
552
/// Collects `predicate` into a lane bitmask.
///
/// `_mask` is ignored and only the caller contributes, so the result is 1
/// (bit 0 set) when `predicate` is true and 0 otherwise.
#[inline]
pub fn warp_ballot(_mask: u32, predicate: bool) -> u32 {
    u32::from(predicate)
}
562
/// Reports whether `predicate` holds for all participating lanes.
///
/// `_mask` is ignored and only the caller is consulted, so this returns
/// `predicate` unchanged.
#[inline]
pub fn warp_all(_mask: u32, predicate: bool) -> bool {
    predicate
}
568
/// Reports whether `predicate` holds for any participating lane.
///
/// `_mask` is ignored and only the caller is consulted, so this returns
/// `predicate` unchanged.
#[inline]
pub fn warp_any(_mask: u32, predicate: bool) -> bool {
    predicate
}
574
/// Shuffle: reads `val` from the lane selected by `_lane`.
///
/// Both `_mask` and `_lane` are ignored here; the caller's own `val` is
/// returned.
#[inline]
pub fn warp_shfl<T: Copy>(_mask: u32, val: T, _lane: i32) -> T {
    val
}
580
/// Shuffle-up: reads `val` from the lane `_delta` positions below.
///
/// Both `_mask` and `_delta` are ignored here; the caller's own `val` is
/// returned.
#[inline]
pub fn warp_shfl_up<T: Copy>(_mask: u32, val: T, _delta: u32) -> T {
    val
}
586
/// Shuffle-down: reads `val` from the lane `_delta` positions above.
///
/// Both `_mask` and `_delta` are ignored here; the caller's own `val` is
/// returned.
#[inline]
pub fn warp_shfl_down<T: Copy>(_mask: u32, val: T, _delta: u32) -> T {
    val
}
592
/// Shuffle-xor: reads `val` from the lane whose index is the caller's
/// index XOR `_lane_mask`.
///
/// Both `_mask` and `_lane_mask` are ignored here; the caller's own `val`
/// is returned.
#[inline]
pub fn warp_shfl_xor<T: Copy>(_mask: u32, val: T, _lane_mask: i32) -> T {
    val
}
598
/// Sum-reduction of `val` across the lanes in `_mask`.
///
/// `_mask` is ignored and only the caller contributes, so `val` is
/// returned unchanged.
#[inline]
pub fn warp_reduce_add(_mask: u32, val: i32) -> i32 {
    val
}
604
/// Minimum-reduction of `val` across the lanes in `_mask`.
///
/// `_mask` is ignored and only the caller contributes, so `val` is
/// returned unchanged.
#[inline]
pub fn warp_reduce_min(_mask: u32, val: i32) -> i32 {
    val
}
610
/// Maximum-reduction of `val` across the lanes in `_mask`.
///
/// `_mask` is ignored and only the caller contributes, so `val` is
/// returned unchanged.
#[inline]
pub fn warp_reduce_max(_mask: u32, val: i32) -> i32 {
    val
}
616
/// Bitwise-AND reduction of `val` across the lanes in `_mask`.
///
/// `_mask` is ignored and only the caller contributes, so `val` is
/// returned unchanged.
#[inline]
pub fn warp_reduce_and(_mask: u32, val: u32) -> u32 {
    val
}
622
/// Bitwise-OR reduction of `val` across the lanes in `_mask`.
///
/// `_mask` is ignored and only the caller contributes, so `val` is
/// returned unchanged.
#[inline]
pub fn warp_reduce_or(_mask: u32, val: u32) -> u32 {
    val
}
628
/// Bitwise-XOR reduction of `val` across the lanes in `_mask`.
///
/// `_mask` is ignored and only the caller contributes, so `val` is
/// returned unchanged.
#[inline]
pub fn warp_reduce_xor(_mask: u32, val: u32) -> u32 {
    val
}
634
/// Returns the mask of lanes holding the same `_val` as the caller.
///
/// Both arguments are ignored; the result is always a mask with only bit 0
/// set.
#[inline]
pub fn warp_match_any(_mask: u32, _val: u32) -> u32 {
    const LANE_ZERO_ONLY: u32 = 1;
    LANE_ZERO_ONLY
}
640
/// Returns `(mask, all_equal)` describing whether every lane holds the
/// same `_val`.
///
/// Both arguments are ignored; the result is always `(1, true)`.
#[inline]
pub fn warp_match_all(_mask: u32, _val: u32) -> (u32, bool) {
    let matching_lanes: u32 = 1;
    let all_equal = true;
    (matching_lanes, all_equal)
}
646
/// Returns the number of set bits in `x`.
#[inline]
pub fn popc(x: u32) -> i32 {
    // The count is at most 32, so the cast to `i32` is lossless.
    u32::count_ones(x) as i32
}
656
/// Returns the number of set bits in the two's-complement representation
/// of `x`.
#[inline]
pub fn popcount(x: i32) -> i32 {
    // The count is at most 32, so the cast to `i32` is lossless.
    x.count_ones() as i32
}
662
/// Returns the number of leading zero bits in `x` (32 for `x == 0`).
#[inline]
pub fn clz(x: u32) -> i32 {
    // The count is at most 32, so the cast to `i32` is lossless.
    u32::leading_zeros(x) as i32
}
668
/// Returns the number of leading zero bits in the two's-complement
/// representation of `x` (32 for `x == 0`, 0 for negative values).
#[inline]
pub fn leading_zeros(x: i32) -> i32 {
    // The count is at most 32, so the cast to `i32` is lossless.
    x.leading_zeros() as i32
}
674
/// Returns the number of trailing zero bits in `x`.
///
/// For `x == 0` the result is 32, which is what [`u32::trailing_zeros`]
/// already reports for zero.
#[inline]
pub fn ctz(x: u32) -> i32 {
    u32::trailing_zeros(x) as i32
}
684
/// Returns the number of trailing zero bits in the two's-complement
/// representation of `x`.
///
/// For `x == 0` the result is 32, which is what [`i32::trailing_zeros`]
/// already reports for zero.
#[inline]
pub fn trailing_zeros(x: i32) -> i32 {
    x.trailing_zeros() as i32
}
694
/// Returns the 1-based position of the least significant set bit of `x`,
/// or 0 when `x` has no bits set.
#[inline]
pub fn ffs(x: u32) -> i32 {
    match x {
        0 => 0,
        // At most 31 trailing zeros here, so the sum fits comfortably.
        bits => (bits.trailing_zeros() + 1) as i32,
    }
}
704
/// Reverses the order of the 32 bits of `x`.
#[inline]
pub fn brev(x: u32) -> u32 {
    u32::reverse_bits(x)
}
710
/// Reverses the order of the 32 bits in the two's-complement
/// representation of `x`.
#[inline]
pub fn reverse_bits(x: i32) -> i32 {
    x.reverse_bits()
}
716
/// Builds a 32-bit word out of four bytes picked from `x` and `y`.
///
/// The eight source bytes are numbered 0..=3 for `x` (least significant
/// byte first) and 4..=7 for `y`. Each of the four low nibbles of `s`
/// selects, through its low three bits, the source byte for the
/// corresponding result byte: nibble 0 fills result byte 0, nibble 1 fills
/// result byte 1, and so on.
#[inline]
pub fn byte_perm(x: u32, y: u32, s: u32) -> u32 {
    // Lay the eight candidate bytes out in one 64-bit pool: x low, y high.
    let pool = (u64::from(y) << 32) | u64::from(x);
    (0..4u32).fold(0u32, |word, slot| {
        let selector = (s >> (4 * slot)) & 0x7;
        let byte = ((pool >> (8 * selector)) & 0xFF) as u32;
        word | (byte << (8 * slot))
    })
}
736
/// Funnel shift left: concatenates `hi:lo` into a 64-bit value, shifts it
/// left by `shift & 31` bits and returns the upper 32 bits.
///
/// A shift of zero (including any multiple of 32) therefore returns `hi`,
/// since the upper word of the unshifted concatenation is `hi`.
#[inline]
pub fn funnel_shift_left(lo: u32, hi: u32, shift: u32) -> u32 {
    let shift = shift & 31;
    let wide = (u64::from(hi) << 32) | u64::from(lo);
    // Bits pushed past position 63 are discarded; the answer is the top word.
    ((wide << shift) >> 32) as u32
}
747
/// Funnel shift right: concatenates `hi:lo` into a 64-bit value, shifts it
/// right by `shift & 31` bits and returns the lower 32 bits.
///
/// A shift of zero (including any multiple of 32) returns `lo`.
#[inline]
pub fn funnel_shift_right(lo: u32, hi: u32, shift: u32) -> u32 {
    let amount = shift & 31;
    let wide = (u64::from(hi) << 32) | u64::from(lo);
    // Truncating to u32 keeps exactly the low word of the shifted value.
    (wide >> amount) as u32
}
758
/// Loads the value behind `ptr`.
///
/// Here this is an ordinary dereference that copies the value out.
#[inline]
pub fn ldg<T: Copy>(ptr: &T) -> T {
    let value = *ptr;
    value
}
768
/// Loads the value behind `ptr`.
///
/// Here this is an ordinary dereference that copies the value out.
#[inline]
pub fn load_global<T: Copy>(ptr: &T) -> T {
    let value = *ptr;
    value
}
774
/// Prefetch hint for the data behind `_ptr`.
///
/// This implementation is a no-op: the argument is ignored.
#[inline]
pub fn prefetch_l1<T>(_ptr: &T) {}
780
/// Prefetch hint for the data behind `_ptr`.
///
/// This implementation is a no-op: the argument is ignored.
#[inline]
pub fn prefetch_l2<T>(_ptr: &T) {}
786
/// Returns the reciprocal `1 / x`.
#[inline]
pub fn rcp(x: f32) -> f32 {
    x.recip()
}
796
/// Returns `x / y`.
///
/// Here this is the ordinary `f32` division operator.
#[inline]
pub fn fast_div(x: f32, y: f32) -> f32 {
    let (numerator, denominator) = (x, y);
    numerator / denominator
}
802
/// Clamps `x` to the interval `[0.0, 1.0]`.
///
/// NaN is mapped to `0.0` so that the result always lies inside the
/// interval; a bare `clamp` would let NaN pass through.
/// NOTE(review): this follows the NaN rule documented for CUDA's
/// `__saturatef` — confirm that is the intended model for this function.
#[inline]
pub fn saturate(x: f32) -> f32 {
    if x.is_nan() {
        0.0
    } else {
        x.clamp(0.0, 1.0)
    }
}
808
809#[inline]
811pub fn clamp_01(x: f32) -> f32 {
812 saturate(x)
813}
814
/// Returns a 32-bit tick counter derived from the wall clock.
///
/// The value is the number of nanoseconds since the Unix epoch truncated
/// to its low 32 bits, or 0 if the system time is earlier than the epoch.
#[inline]
pub fn clock() -> u32 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        // Truncation is intentional: only the low 32 bits are reported.
        Ok(elapsed) => elapsed.as_nanos() as u32,
        Err(_) => 0,
    }
}
828
/// Returns a 64-bit tick counter derived from the wall clock.
///
/// The value is the number of nanoseconds since the Unix epoch truncated
/// to 64 bits, or 0 if the system time is earlier than the epoch.
#[inline]
pub fn clock64() -> u64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    match since_epoch {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    }
}
837
/// Puts the calling thread to sleep for at least `ns` nanoseconds using
/// [`std::thread::sleep`].
#[inline]
pub fn nanosleep(ns: u32) {
    let pause = std::time::Duration::from_nanos(u64::from(ns));
    std::thread::sleep(pause);
}
843
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floating-point expressions differ by less than 1e-6.
    macro_rules! assert_close {
        ($actual:expr, $expected:expr) => {
            assert!(($actual - $expected).abs() < 1e-6)
        };
    }

    /// Thread indices default to zero on every axis.
    #[test]
    fn test_thread_indices_default() {
        let indices = [thread_idx_x(), thread_idx_y(), thread_idx_z()];
        for index in indices {
            assert_eq!(index, 0);
        }
    }

    /// Block indices default to zero on every axis.
    #[test]
    fn test_block_indices_default() {
        let indices = [block_idx_x(), block_idx_y(), block_idx_z()];
        for index in indices {
            assert_eq!(index, 0);
        }
    }

    /// Block and grid extents default to one; the warp size is 32.
    #[test]
    fn test_dimensions_default() {
        assert_eq!(block_dim_x(), 1);
        assert_eq!(block_dim_y(), 1);
        assert_eq!(grid_dim_x(), 1);
        assert_eq!(warp_size(), 32);
    }

    /// Spot checks for the basic math functions.
    #[test]
    fn test_math_functions() {
        assert_close!(sqrt(4.0), 2.0);
        assert_close!(rsqrt(4.0), 0.5);
        assert_close!(sin(0.0), 0.0);
        assert_close!(cos(0.0), 1.0);
        assert_close!(exp(0.0), 1.0);
        assert_close!(log(1.0), 0.0);
    }

    /// Spot checks for the trigonometric functions and their inverses.
    #[test]
    fn test_trigonometric_functions() {
        let pi = std::f32::consts::PI;
        assert_close!(sin(pi / 2.0), 1.0);
        assert_close!(cos(pi), -1.0);
        assert_close!(tan(0.0), 0.0);
        assert_close!(asin(1.0), pi / 2.0);
        assert_close!(atan2(1.0, 1.0), pi / 4.0);
    }

    /// Hyperbolic functions evaluated at the origin.
    #[test]
    fn test_hyperbolic_functions() {
        assert_close!(sinh(0.0), 0.0);
        assert_close!(cosh(0.0), 1.0);
        assert_close!(tanh(0.0), 0.0);
    }

    /// Exponentials and logarithms on exactly representable inputs.
    #[test]
    fn test_exponential_functions() {
        assert_close!(exp2(3.0), 8.0);
        assert_close!(log2(8.0), 3.0);
        assert_close!(log10(100.0), 2.0);
        assert_close!(pow(2.0, 3.0), 8.0);
    }

    /// NaN / infinity / finite classification.
    #[test]
    fn test_classification_functions() {
        let (nan, inf, one) = (f32::NAN, f32::INFINITY, 1.0);

        assert!(is_nan(nan));
        assert!(!is_nan(one));
        assert!(is_infinite(inf));
        assert!(!is_infinite(one));
        assert!(is_finite(one));
        assert!(!is_finite(inf));
    }

    /// Bit counting, scanning and reversal.
    #[test]
    fn test_bit_manipulation() {
        assert_eq!(popc(0b1010_1010), 4);
        assert_eq!(clz(1u32), 31);
        assert_eq!(clz(0x8000_0000u32), 0);
        assert_eq!(ctz(0b1000), 3);
        assert_eq!(ffs(0b1000), 4);
        assert_eq!(brev(1u32), 0x8000_0000);
    }

    /// Warp primitives with only the calling lane participating.
    #[test]
    fn test_warp_operations() {
        let full_mask = 0xFFFF_FFFF;

        assert_eq!(warp_active_mask(), 1);
        assert_eq!(warp_ballot(full_mask, true), 1);
        assert!(warp_all(full_mask, true));
        assert!(warp_any(full_mask, true));
        assert_eq!(warp_reduce_add(full_mask, 5), 5);
    }

    /// Reciprocal, division and saturation helpers.
    #[test]
    fn test_special_functions() {
        assert_close!(rcp(2.0), 0.5);
        assert_close!(fast_div(10.0, 2.0), 5.0);
        assert_eq!(saturate(-1.0), 0.0);
        assert_eq!(saturate(0.5), 0.5);
        assert_eq!(saturate(2.0), 1.0);
    }

    /// Each atomic returns the previous value and stores the new one.
    #[test]
    fn test_atomic_operations() {
        let mut added = 10;
        assert_eq!(atomic_add(&mut added, 5), 10);
        assert_eq!(added, 15);

        let mut subtracted = 10;
        assert_eq!(atomic_sub(&mut subtracted, 3), 10);
        assert_eq!(subtracted, 7);

        let mut swapped = 10;
        assert_eq!(atomic_cas(&mut swapped, 10, 20), 10);
        assert_eq!(swapped, 20);
    }

    /// Half-word funnel shifts in both directions.
    #[test]
    fn test_funnel_shift() {
        let (lo, hi) = (0xFFFF_0000, 0x0000_FFFF);
        assert_eq!(funnel_shift_left(lo, hi, 16), 0xFFFF_FFFF);
        assert_eq!(funnel_shift_right(lo, hi, 16), 0xFFFF_FFFF);
    }

    /// Selectors that reproduce `x` and `y` unchanged.
    #[test]
    fn test_byte_perm() {
        let x = 0x04030201u32;
        let y = 0x08070605u32;
        // Selector 0x3210 picks bytes 0..=3, i.e. all of `x`.
        assert_eq!(byte_perm(x, y, 0x3210), 0x04030201);
        // Selector 0x7654 picks bytes 4..=7, i.e. all of `y`.
        assert_eq!(byte_perm(x, y, 0x7654), 0x08070605);
    }
}