1#[cfg(target_feature = "avx512f")]
2use crate::vectors::_512bit::*;
3use crate::{
4 into_vec::IntoVec,
5 type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
6 vectors::traits::VecTrait,
7};
8use core::f32;
9use half::{bf16, f16};
10use num_complex::{Complex32, Complex64};
11use std::fmt::Debug;
12
/// Associates a Rust scalar type with the name of the equivalent CUDA C/C++ type.
pub trait CudaType {
    /// CUDA C/C++ spelling of `Self`.
    const CUDA_TYPE: &'static str;
}

impl CudaType for bool {
    const CUDA_TYPE: &'static str = "bool";
}

// NOTE(review): plain `char` has implementation-defined signedness in C/C++; `signed char`
// would pin it down. Left unchanged because other code may match on this exact string — confirm.
impl CudaType for i8 {
    const CUDA_TYPE: &'static str = "char";
}

impl CudaType for u8 {
    const CUDA_TYPE: &'static str = "unsigned char";
}

impl CudaType for i16 {
    const CUDA_TYPE: &'static str = "short";
}

impl CudaType for u16 {
    const CUDA_TYPE: &'static str = "unsigned short";
}

impl CudaType for i32 {
    const CUDA_TYPE: &'static str = "int";
}

impl CudaType for u32 {
    const CUDA_TYPE: &'static str = "unsigned int";
}

/// 64-bit signed integer.
///
/// C `long` is 64 bits only on LP64 targets (64-bit, non-Windows). Windows is LLP64 and 32-bit
/// targets are ILP32; on both, `long` is 32 bits, so `long long` is required to keep 64 bits.
impl CudaType for i64 {
    const CUDA_TYPE: &'static str =
        if cfg!(all(target_pointer_width = "64", not(target_os = "windows"))) {
            "long"
        } else {
            "long long"
        };
}

/// 64-bit unsigned integer; same data-model reasoning as the `i64` mapping.
impl CudaType for u64 {
    const CUDA_TYPE: &'static str =
        if cfg!(all(target_pointer_width = "64", not(target_os = "windows"))) {
            "unsigned long"
        } else {
            "unsigned long long"
        };
}

impl CudaType for f32 {
    const CUDA_TYPE: &'static str = "float";
}

impl CudaType for f64 {
    const CUDA_TYPE: &'static str = "double";
}
74
75impl CudaType for Complex32 {
76 const CUDA_TYPE: &'static str = "cuFloatComplex";
77}
78
79impl CudaType for Complex64 {
80 const CUDA_TYPE: &'static str = "cuDoubleComplex";
81}
82
83#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
84impl CudaType for isize {
85 const CUDA_TYPE: &'static str = "long long";
86}
87
88#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
89impl CudaType for isize {
90 const CUDA_TYPE: &'static str = "long";
91}
92
93#[cfg(target_pointer_width = "32")]
94impl CudaType for isize {
95 const CUDA_TYPE: &'static str = "int";
96}
97
98#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
99impl CudaType for usize {
100 const CUDA_TYPE: &'static str = "unsigned long long";
101}
102
103#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
104impl CudaType for usize {
105 const CUDA_TYPE: &'static str = "unsigned long";
106}
107
108#[cfg(target_pointer_width = "32")]
109impl CudaType for usize {
110 const CUDA_TYPE: &'static str = "unsigned int";
111}
112
113impl CudaType for f16 {
114 const CUDA_TYPE: &'static str = "__half";
115}
116
117impl CudaType for bf16 {
118 const CUDA_TYPE: &'static str = "__nv_bfloat16";
119}
120
121pub trait TypeCommon
125where
126 Self: Sized + Copy,
127{
128 const MAX: Self;
130 const MIN: Self;
132 const ZERO: Self;
134 const ONE: Self;
136 const INF: Self;
138 const NEG_INF: Self;
140 const TWO: Self;
142 const SIX: Self;
144 const TEN: Self;
146 const STR: &'static str;
148 const BIT_SIZE: usize;
150 type Vec: VecTrait<Self>
152 + Send
153 + Copy
154 + IntoVec<Self::Vec>
155 + std::ops::Index<usize, Output = Self>
156 + std::ops::IndexMut<usize>
157 + Sync
158 + Debug
159 + NormalOutUnary
160 + NormalOut<Self::Vec, Output = Self::Vec>
161 + FloatOutUnary
162 + FloatOutBinary
163 + FloatOutBinary<
164 <Self::Vec as FloatOutUnary>::Output,
165 Output = <Self::Vec as FloatOutUnary>::Output,
166 >;
167}
168
/// Implements `TypeCommon` for a scalar type and `Index<usize>`/`IndexMut<usize>` for its
/// SIMD vector type.
///
/// Parameters, in order:
/// - the scalar type;
/// - its `MAX`, `MIN`, `ZERO`, `ONE`, `INF`, `NEG_INF`, `TWO`, `SIX` and `TEN` values;
/// - its `STR` name;
/// - the vector type used as `TypeCommon::Vec`;
/// - a mask type.
///
/// The expansion names `TypeCommon` unqualified and calls `SIZE`, `as_ptr` and `as_mut_ptr` on
/// the vector type. Those must therefore be in scope at the invocation site; the `type_impl`
/// modules import `TypeCommon` and `VecTrait` for this reason.
///
/// NOTE(review): `$mask` is accepted but never used by the expansion.
///
/// # Panics
///
/// The generated `index`/`index_mut` panic when `index >= SIZE`.
macro_rules! impl_type_common {
    (
        $type:ty,
        $max:expr,
        $min:expr,
        $zero:expr,
        $one:expr,
        $inf:expr,
        $neg_inf:expr,
        $two:expr,
        $six:expr,
        $ten:expr,
        $str:expr,
        $vec:ty,
        $mask:ty
    ) => {
        impl TypeCommon for $type {
            const MAX: Self = $max;
            const MIN: Self = $min;
            const ZERO: Self = $zero;
            const ONE: Self = $one;
            const INF: Self = $inf;
            const NEG_INF: Self = $neg_inf;
            const TWO: Self = $two;
            const SIX: Self = $six;
            const TEN: Self = $ten;
            const STR: &'static str = $str;
            const BIT_SIZE: usize = std::mem::size_of::<$type>();
            type Vec = $vec;
        }

        impl std::ops::Index<usize> for $vec {
            type Output = $type;

            fn index(&self, index: usize) -> &Self::Output {
                let len = <$vec>::SIZE;
                assert!(
                    index < len,
                    "index out of bounds: the len is {} but the index is {}",
                    len,
                    index
                );
                // SAFETY: `index < SIZE` was checked above. This relies on `as_ptr` pointing at
                // `SIZE` contiguous elements (assumed VecTrait contract, not visible here).
                unsafe { &*self.as_ptr().add(index) }
            }
        }

        impl std::ops::IndexMut<usize> for $vec {
            fn index_mut(&mut self, index: usize) -> &mut Self::Output {
                let len = <$vec>::SIZE;
                assert!(
                    index < len,
                    "index out of bounds: the len is {} but the index is {}",
                    len,
                    index
                );
                // SAFETY: same bounds check and pointer assumption as `index`; `&mut self`
                // guarantees exclusive access to the element.
                unsafe { &mut *self.as_mut_ptr().add(index) }
            }
        }
    };
}
226
/// `TypeCommon` implementations backed by the 256-bit vector types in `crate::simd::_256bit`,
/// compiled when the `avx2` target feature is enabled.
#[cfg(target_feature = "avx2")]
mod type_impl {
    use super::TypeCommon;
    use crate::simd::_256bit::*;
    use crate::vectors::traits::VecTrait;
    use half::*;
    use num_complex::{Complex32, Complex64};

    /// Integer scalar with these conventions:
    /// - `MAX`/`MIN` double as `INF`/`NEG_INF`;
    /// - the small constants are integer literals;
    /// - `STR` is the type's own name.
    macro_rules! impl_int_type_common {
        ($t:ident, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                $t::MAX,
                $t::MIN,
                0,
                1,
                $t::MAX,
                $t::MIN,
                2,
                6,
                10,
                stringify!($t),
                $vec,
                $mask
            );
        };
    }

    /// Primitive float scalar (`f32`/`f64`): real infinities, float literals, and `STR` set to
    /// the type's own name.
    macro_rules! impl_float_type_common {
        ($t:ident, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                $t::MAX,
                $t::MIN,
                0.0,
                1.0,
                $t::INFINITY,
                $t::NEG_INFINITY,
                2.0,
                6.0,
                10.0,
                stringify!($t),
                $vec,
                $mask
            );
        };
    }

    /// `half` float scalar (`f16`/`bf16`): small constants built with `from_f32_const`.
    macro_rules! impl_half_type_common {
        ($t:ident, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                $t::MAX,
                $t::MIN,
                $t::ZERO,
                $t::ONE,
                $t::INFINITY,
                $t::NEG_INFINITY,
                $t::from_f32_const(2.0),
                $t::from_f32_const(6.0),
                $t::from_f32_const(10.0),
                stringify!($t),
                $vec,
                $mask
            );
        };
    }

    // NOTE(review): `TWO` is `false` for bool even though `2 != 0`; confirm this is intended.
    impl_type_common!(
        bool,
        true,
        false,
        false,
        true,
        true,
        false,
        false,
        true,
        true,
        "bool",
        boolx32::boolx32,
        u8
    );

    impl_int_type_common!(i8, i8x32::i8x32, u8);
    impl_int_type_common!(u8, u8x32::u8x32, u8);
    impl_int_type_common!(i16, i16x16::i16x16, u16);
    impl_int_type_common!(u16, u16x16::u16x16, u16);
    impl_int_type_common!(i32, i32x8::i32x8, u32);
    impl_int_type_common!(u32, u32x8::u32x8, u32);
    impl_int_type_common!(i64, i64x4::i64x4, u64);
    impl_int_type_common!(u64, u64x4::u64x4, u64);

    // Pointer-sized integers: the lane count depends on the pointer width.
    #[cfg(target_pointer_width = "64")]
    impl_int_type_common!(isize, isizex4::isizex4, usize);
    #[cfg(target_pointer_width = "32")]
    impl_int_type_common!(isize, isizex8::isizex8, usize);
    #[cfg(target_pointer_width = "64")]
    impl_int_type_common!(usize, usizex4::usizex4, usize);
    #[cfg(target_pointer_width = "32")]
    impl_int_type_common!(usize, usizex8::usizex8, usize);

    impl_float_type_common!(f32, f32x8::f32x8, u32);
    impl_float_type_common!(f64, f64x4::f64x4, u64);

    impl_half_type_common!(f16, f16x16::f16x16, u16);
    impl_half_type_common!(bf16, bf16x16::bf16x16, u16);

    // Complex numbers: bounds and infinities apply to both components, while the small
    // constants are purely real.
    impl_type_common!(
        Complex32,
        Complex32::new(f32::MAX, f32::MAX),
        Complex32::new(f32::MIN, f32::MIN),
        Complex32::new(0.0, 0.0),
        Complex32::new(1.0, 0.0),
        Complex32::new(f32::INFINITY, f32::INFINITY),
        Complex32::new(f32::NEG_INFINITY, f32::NEG_INFINITY),
        Complex32::new(2.0, 0.0),
        Complex32::new(6.0, 0.0),
        Complex32::new(10.0, 0.0),
        "c32",
        cplx32x4::cplx32x4,
        (u32, u32)
    );
    impl_type_common!(
        Complex64,
        Complex64::new(f64::MAX, f64::MAX),
        Complex64::new(f64::MIN, f64::MIN),
        Complex64::new(0.0, 0.0),
        Complex64::new(1.0, 0.0),
        Complex64::new(f64::INFINITY, f64::INFINITY),
        Complex64::new(f64::NEG_INFINITY, f64::NEG_INFINITY),
        Complex64::new(2.0, 0.0),
        Complex64::new(6.0, 0.0),
        Complex64::new(10.0, 0.0),
        "c64",
        cplx64x2::cplx64x2,
        (u64, u64)
    );
}
524
525#[cfg(all(
526 any(target_feature = "sse", target_arch = "arm", target_arch = "aarch64"),
527 not(target_feature = "avx2")
528))]
529mod type_impl {
530 use super::TypeCommon;
531 use crate::simd::_128bit::*;
532 use crate::vectors::traits::VecTrait;
533 use half::*;
534 use num_complex::{Complex32, Complex64};
535 impl_type_common!(
536 bool,
537 true,
538 false,
539 false,
540 true,
541 true,
542 false,
543 false,
544 true,
545 true,
546 "bool",
547 boolx16::boolx16,
548 u8
549 );
550 impl_type_common!(
551 i8,
552 i8::MAX,
553 i8::MIN,
554 0,
555 1,
556 i8::MAX,
557 i8::MIN,
558 2,
559 6,
560 10,
561 "i8",
562 i8x16::i8x16,
563 u8
564 );
565 impl_type_common!(
566 u8,
567 u8::MAX,
568 u8::MIN,
569 0,
570 1,
571 u8::MAX,
572 u8::MIN,
573 2,
574 6,
575 10,
576 "u8",
577 u8x16::u8x16,
578 u8
579 );
580 impl_type_common!(
581 i16,
582 i16::MAX,
583 i16::MIN,
584 0,
585 1,
586 i16::MAX,
587 i16::MIN,
588 2,
589 6,
590 10,
591 "i16",
592 i16x8::i16x8,
593 u16
594 );
595 impl_type_common!(
596 u16,
597 u16::MAX,
598 u16::MIN,
599 0,
600 1,
601 u16::MAX,
602 u16::MIN,
603 2,
604 6,
605 10,
606 "u16",
607 u16x8::u16x8,
608 u16
609 );
610 impl_type_common!(
611 i32,
612 i32::MAX,
613 i32::MIN,
614 0,
615 1,
616 i32::MAX,
617 i32::MIN,
618 2,
619 6,
620 10,
621 "i32",
622 i32x4::i32x4,
623 u32
624 );
625 impl_type_common!(
626 u32,
627 u32::MAX,
628 u32::MIN,
629 0,
630 1,
631 u32::MAX,
632 u32::MIN,
633 2,
634 6,
635 10,
636 "u32",
637 u32x4::u32x4,
638 u32
639 );
640 impl_type_common!(
641 i64,
642 i64::MAX,
643 i64::MIN,
644 0,
645 1,
646 i64::MAX,
647 i64::MIN,
648 2,
649 6,
650 10,
651 "i64",
652 i64x2::i64x2,
653 u64
654 );
655 impl_type_common!(
656 u64,
657 u64::MAX,
658 u64::MIN,
659 0,
660 1,
661 u64::MAX,
662 u64::MIN,
663 2,
664 6,
665 10,
666 "u64",
667 u64x2::u64x2,
668 u64
669 );
670 impl_type_common!(
671 f32,
672 f32::MAX,
673 f32::MIN,
674 0.0,
675 1.0,
676 f32::INFINITY,
677 f32::NEG_INFINITY,
678 2.0,
679 6.0,
680 10.0,
681 "f32",
682 f32x4::f32x4,
683 u32
684 );
685 impl_type_common!(
686 f64,
687 f64::MAX,
688 f64::MIN,
689 0.0,
690 1.0,
691 f64::INFINITY,
692 f64::NEG_INFINITY,
693 2.0,
694 6.0,
695 10.0,
696 "f64",
697 f64x2::f64x2,
698 u64
699 );
700 #[cfg(target_pointer_width = "64")]
701 impl_type_common!(
702 isize,
703 isize::MAX,
704 isize::MIN,
705 0,
706 1,
707 isize::MAX,
708 isize::MIN,
709 2,
710 6,
711 10,
712 "isize",
713 isizex2::isizex2,
714 u64
715 );
716 #[cfg(target_pointer_width = "32")]
717 impl_type_common!(
718 isize,
719 isize::MAX,
720 isize::MIN,
721 0,
722 1,
723 isize::MAX,
724 isize::MIN,
725 2,
726 6,
727 10,
728 "isize",
729 "int",
730 isizex4::isizex4,
731 u32
732 );
733 #[cfg(target_pointer_width = "64")]
734 impl_type_common!(
735 usize,
736 usize::MAX,
737 usize::MIN,
738 0,
739 1,
740 usize::MAX,
741 usize::MIN,
742 2,
743 6,
744 10,
745 "usize",
746 usizex2::usizex2,
747 usize
748 );
749 #[cfg(target_pointer_width = "32")]
750 impl_type_common!(
751 usize,
752 usize::MAX,
753 usize::MIN,
754 0,
755 1,
756 usize::MAX,
757 usize::MIN,
758 2,
759 6,
760 10,
761 "usize",
762 "unsigned int",
763 usizex4::usizex4,
764 usize
765 );
766 impl_type_common!(
767 f16,
768 f16::MAX,
769 f16::MIN,
770 f16::ZERO,
771 f16::ONE,
772 f16::INFINITY,
773 f16::NEG_INFINITY,
774 f16::from_f32_const(2.0),
775 f16::from_f32_const(6.0),
776 f16::from_f32_const(10.0),
777 "f16",
778 f16x8::f16x8,
779 u16
780 );
781 impl_type_common!(
782 bf16,
783 bf16::MAX,
784 bf16::MIN,
785 bf16::ZERO,
786 bf16::ONE,
787 bf16::INFINITY,
788 bf16::NEG_INFINITY,
789 bf16::from_f32_const(2.0),
790 bf16::from_f32_const(6.0),
791 bf16::from_f32_const(10.0),
792 "bf16",
793 bf16x8::bf16x8,
794 u16
795 );
796 impl_type_common!(
797 Complex32,
798 Complex32::new(f32::MAX, f32::MAX),
799 Complex32::new(f32::MIN, f32::MIN),
800 Complex32::new(0.0, 0.0),
801 Complex32::new(1.0, 0.0),
802 Complex32::new(f32::INFINITY, f32::INFINITY),
803 Complex32::new(f32::NEG_INFINITY, f32::NEG_INFINITY),
804 Complex32::new(2.0, 0.0),
805 Complex32::new(6.0, 0.0),
806 Complex32::new(10.0, 0.0),
807 "c32",
808 cplx32x2::cplx32x2,
809 (u32, u32)
810 );
811 impl_type_common!(
812 Complex64,
813 Complex64::new(f64::MAX, f64::MAX),
814 Complex64::new(f64::MIN, f64::MIN),
815 Complex64::new(0.0, 0.0),
816 Complex64::new(1.0, 0.0),
817 Complex64::new(f64::INFINITY, f64::INFINITY),
818 Complex64::new(f64::NEG_INFINITY, f64::NEG_INFINITY),
819 Complex64::new(2.0, 0.0),
820 Complex64::new(6.0, 0.0),
821 Complex64::new(10.0, 0.0),
822 "c64",
823 cplx64x1::cplx64x1,
824 (u64, u64)
825 );
826}
827
/// Floating-point constants shared by the float scalar types.
pub trait FloatConst {
    /// 0.5
    const HALF: Self;
    /// Euler's number *e*.
    const E: Self;
    /// π
    const PI: Self;
    /// 3.0
    const THREE: Self;
    /// 2π
    const TWOPI: Self;
    /// 4π
    const FOURPI: Self;
    /// 0.2
    const POINT_TWO: Self;
    /// 1/√2
    const FRAC_1_SQRT_2: Self;
}

// `TAU` is exactly `PI * 2.0` in binary floating point, because scaling by a power of two is
// exact. `TAU` and `TAU * 2.0` are therefore bit-identical to `PI * 2.0` and `PI * 4.0`.

impl FloatConst for f32 {
    const HALF: Self = 0.5;
    const THREE: Self = 3.0;
    const POINT_TWO: Self = 0.2;
    const E: Self = f32::consts::E;
    const PI: Self = f32::consts::PI;
    const TWOPI: Self = f32::consts::TAU;
    const FOURPI: Self = f32::consts::TAU * 2.0;
    const FRAC_1_SQRT_2: Self = f32::consts::FRAC_1_SQRT_2;
}

impl FloatConst for f64 {
    const HALF: Self = 0.5;
    const THREE: Self = 3.0;
    const POINT_TWO: Self = 0.2;
    const E: Self = core::f64::consts::E;
    const PI: Self = core::f64::consts::PI;
    const TWOPI: Self = core::f64::consts::TAU;
    const FOURPI: Self = core::f64::consts::TAU * 2.0;
    const FRAC_1_SQRT_2: Self = core::f64::consts::FRAC_1_SQRT_2;
}
869
870impl FloatConst for f16 {
871 const HALF: Self = f16::from_f32_const(0.5);
872 const E: Self = f16::from_f32_const(f32::consts::E);
873 const PI: Self = f16::from_f32_const(f32::consts::PI);
874 const THREE: Self = f16::from_f32_const(3.0);
875 const TWOPI: Self = f16::from_f32_const(f32::consts::PI * 2.0);
876 const FOURPI: Self = f16::from_f32_const(f32::consts::PI * 4.0);
877 const POINT_TWO: Self = f16::from_f32_const(0.2);
878 const FRAC_1_SQRT_2: Self = f16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
879}
880
881impl FloatConst for bf16 {
882 const HALF: Self = bf16::from_f32_const(0.5);
883 const E: Self = bf16::from_f32_const(f32::consts::E);
884 const PI: Self = bf16::from_f32_const(f32::consts::PI);
885 const THREE: Self = bf16::from_f32_const(3.0);
886 const TWOPI: Self = bf16::from_f32_const(f32::consts::PI * 2.0);
887 const FOURPI: Self = bf16::from_f32_const(f32::consts::PI * 4.0);
888 const POINT_TWO: Self = bf16::from_f32_const(0.2);
889 const FRAC_1_SQRT_2: Self = bf16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
890}