1use crate::{
2 into_vec::IntoVec,
3 type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
4 vectors::traits::VecTrait,
5};
6use core::f32;
7use half::{bf16, f16};
8use num_complex::{Complex32, Complex64};
9use std::fmt::Debug;
10
/// Associates a Rust scalar type with the spelling of its CUDA-side type.
pub trait CudaType {
    /// Name of the corresponding type on the CUDA side, e.g. `"float"` for `f32`.
    const CUDA_TYPE: &'static str;
}
16
/// `bool` keeps the same spelling on the CUDA side.
impl CudaType for bool {
    const CUDA_TYPE: &'static str = "bool";
}
20
/// `i8` is spelled `char` on the CUDA side.
// NOTE(review): plain `char` has implementation-defined signedness in C/C++;
// confirm the CUDA toolchains targeted here treat it as signed.
impl CudaType for i8 {
    const CUDA_TYPE: &'static str = "char";
}
24
/// `u8` is spelled `unsigned char` on the CUDA side.
impl CudaType for u8 {
    const CUDA_TYPE: &'static str = "unsigned char";
}
28
/// `i16` is spelled `short` on the CUDA side.
impl CudaType for i16 {
    const CUDA_TYPE: &'static str = "short";
}
32
/// `u16` is spelled `unsigned short` on the CUDA side.
impl CudaType for u16 {
    const CUDA_TYPE: &'static str = "unsigned short";
}
36
/// `i32` is spelled `int` on the CUDA side.
impl CudaType for i32 {
    const CUDA_TYPE: &'static str = "int";
}
40
/// `u32` is spelled `unsigned int` on the CUDA side.
impl CudaType for u32 {
    const CUDA_TYPE: &'static str = "unsigned int";
}
44
/// On Windows targets `i64` is spelled `long long`.
// Presumably because C `long` is only 32 bits under the Windows data model — confirm.
#[cfg(target_os = "windows")]
impl CudaType for i64 {
    const CUDA_TYPE: &'static str = "long long";
}
49
/// On every non-Windows target `i64` is spelled `long`.
// NOTE(review): this assumes `long` is 64 bits on all non-Windows targets that
// reach this impl — confirm for any 32-bit platform.
#[cfg(not(target_os = "windows"))]
impl CudaType for i64 {
    const CUDA_TYPE: &'static str = "long";
}
54
/// On Windows targets `u64` is spelled `unsigned long long`.
// Presumably because C `unsigned long` is only 32 bits under the Windows data model — confirm.
#[cfg(target_os = "windows")]
impl CudaType for u64 {
    const CUDA_TYPE: &'static str = "unsigned long long";
}
59
/// On every non-Windows target `u64` is spelled `unsigned long`.
// NOTE(review): this assumes `unsigned long` is 64 bits on all non-Windows
// targets that reach this impl — confirm for any 32-bit platform.
#[cfg(not(target_os = "windows"))]
impl CudaType for u64 {
    const CUDA_TYPE: &'static str = "unsigned long";
}
64
/// `f32` is spelled `float` on the CUDA side.
impl CudaType for f32 {
    const CUDA_TYPE: &'static str = "float";
}
68
/// `f64` is spelled `double` on the CUDA side.
impl CudaType for f64 {
    const CUDA_TYPE: &'static str = "double";
}
72
/// `Complex32` is spelled `cuFloatComplex` on the CUDA side.
impl CudaType for Complex32 {
    const CUDA_TYPE: &'static str = "cuFloatComplex";
}
76
/// `Complex64` is spelled `cuDoubleComplex` on the CUDA side.
impl CudaType for Complex64 {
    const CUDA_TYPE: &'static str = "cuDoubleComplex";
}
80
/// 64-bit Windows targets: `isize` uses the same spelling as `i64` there (`long long`).
#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
impl CudaType for isize {
    const CUDA_TYPE: &'static str = "long long";
}
85
/// 64-bit non-Windows targets: `isize` uses the same spelling as `i64` there (`long`).
#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
impl CudaType for isize {
    const CUDA_TYPE: &'static str = "long";
}
90
/// 32-bit targets: `isize` uses the same spelling as `i32` (`int`).
#[cfg(target_pointer_width = "32")]
impl CudaType for isize {
    const CUDA_TYPE: &'static str = "int";
}
95
/// 64-bit Windows targets: `usize` uses the same spelling as `u64` there (`unsigned long long`).
#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
impl CudaType for usize {
    const CUDA_TYPE: &'static str = "unsigned long long";
}
100
/// 64-bit non-Windows targets: `usize` uses the same spelling as `u64` there (`unsigned long`).
#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
impl CudaType for usize {
    const CUDA_TYPE: &'static str = "unsigned long";
}
105
/// 32-bit targets: `usize` uses the same spelling as `u32` (`unsigned int`).
#[cfg(target_pointer_width = "32")]
impl CudaType for usize {
    const CUDA_TYPE: &'static str = "unsigned int";
}
110
/// `half::f16` is spelled `__half` on the CUDA side.
impl CudaType for f16 {
    const CUDA_TYPE: &'static str = "__half";
}
114
/// `half::bf16` is spelled `__nv_bfloat16` on the CUDA side.
impl CudaType for bf16 {
    const CUDA_TYPE: &'static str = "__nv_bfloat16";
}
118
119pub trait TypeCommon
123where
124 Self: Sized + Copy,
125{
126 const MAX: Self;
128 const MIN: Self;
130 const ZERO: Self;
132 const ONE: Self;
134 const INF: Self;
136 const NEG_INF: Self;
138 const TWO: Self;
140 const SIX: Self;
142 const TEN: Self;
144 const STR: &'static str;
146 const BYTE_SIZE: usize;
148 type Vec: VecTrait<Self>
150 + Send
151 + Copy
152 + IntoVec<Self::Vec>
153 + std::ops::Index<usize, Output = Self>
154 + std::ops::IndexMut<usize>
155 + Sync
156 + Debug
157 + NormalOutUnary
158 + NormalOut<Self::Vec, Output = Self::Vec>
159 + FloatOutUnary
160 + FloatOutBinary
161 + FloatOutBinary<
162 <Self::Vec as FloatOutUnary>::Output,
163 Output = <Self::Vec as FloatOutUnary>::Output,
164 >;
165}
166
/// Generates lane indexing for a vector type and the `TypeCommon` impl for its scalar.
///
/// Positional arguments: scalar type, `MAX`, `MIN`, `ZERO`, `ONE`, `INF`,
/// `NEG_INF`, `TWO`, `SIX`, `TEN`, `STR`, vector type, mask type.
/// The mask type is accepted but not used by the expansion.
///
/// The call site must have `TypeCommon` in scope, and the vector type must
/// provide `SIZE`, `as_ptr()` and `as_mut_ptr()`.
macro_rules! impl_type_common {
    (
        $type:ty,
        $max:expr,
        $min:expr,
        $zero:expr,
        $one:expr,
        $inf:expr,
        $neg_inf:expr,
        $two:expr,
        $six:expr,
        $ten:expr,
        $str:expr,
        $vec:ty,
        $mask:ty
    ) => {
        impl std::ops::Index<usize> for $vec {
            type Output = $type;

            /// Borrows lane `lane`; panics when `lane >= SIZE`.
            fn index(&self, lane: usize) -> &Self::Output {
                assert!(
                    lane < <$vec>::SIZE,
                    "index out of bounds: the len is {} but the index is {}",
                    <$vec>::SIZE,
                    lane
                );
                // SAFETY: `lane < SIZE` was checked above; this relies on `as_ptr()`
                // pointing at `SIZE` contiguous scalars (assumed from the vector type).
                unsafe { &*self.as_ptr().add(lane) }
            }
        }

        impl std::ops::IndexMut<usize> for $vec {
            /// Mutably borrows lane `lane`; panics when `lane >= SIZE`.
            fn index_mut(&mut self, lane: usize) -> &mut Self::Output {
                assert!(
                    lane < <$vec>::SIZE,
                    "index out of bounds: the len is {} but the index is {}",
                    <$vec>::SIZE,
                    lane
                );
                // SAFETY: same bounds argument as in `index`, with exclusive access
                // guaranteed by `&mut self`.
                unsafe { &mut *self.as_mut_ptr().add(lane) }
            }
        }

        impl TypeCommon for $type {
            type Vec = $vec;

            const STR: &'static str = $str;
            const BYTE_SIZE: usize = std::mem::size_of::<$type>();

            const MAX: Self = $max;
            const MIN: Self = $min;
            const INF: Self = $inf;
            const NEG_INF: Self = $neg_inf;

            const ZERO: Self = $zero;
            const ONE: Self = $one;
            const TWO: Self = $two;
            const SIX: Self = $six;
            const TEN: Self = $ten;
        }
    };
}
224
/// `TypeCommon` implementations paired with the 256-bit vector types.
/// Compiled only when the `avx2` target feature is enabled.
#[cfg(target_feature = "avx2")]
mod type_impl {
    use super::TypeCommon;
    use crate::simd::_256bit::*;
    use crate::vectors::traits::VecTrait;
    use half::*;
    use num_complex::{Complex32, Complex64};

    /// Primitive integers: `INF`/`NEG_INF` reuse `MAX`/`MIN`; small constants are literals.
    macro_rules! int_common {
        ($t:ident, $name:expr, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                <$t>::MAX,
                <$t>::MIN,
                0,
                1,
                <$t>::MAX,
                <$t>::MIN,
                2,
                6,
                10,
                $name,
                $vec,
                $mask
            );
        };
    }

    /// Primitive floats: real infinities and float literals.
    macro_rules! float_common {
        ($t:ident, $name:expr, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                <$t>::MAX,
                <$t>::MIN,
                0.0,
                1.0,
                <$t>::INFINITY,
                <$t>::NEG_INFINITY,
                2.0,
                6.0,
                10.0,
                $name,
                $vec,
                $mask
            );
        };
    }

    /// `half` floats: constants come from the type's own consts and `from_f32_const`.
    macro_rules! half_common {
        ($t:ident, $name:expr, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                <$t>::MAX,
                <$t>::MIN,
                <$t>::ZERO,
                <$t>::ONE,
                <$t>::INFINITY,
                <$t>::NEG_INFINITY,
                <$t>::from_f32_const(2.0),
                <$t>::from_f32_const(6.0),
                <$t>::from_f32_const(10.0),
                $name,
                $vec,
                $mask
            );
        };
    }

    /// Complex types over the real type `$re`: extremes fill both components,
    /// the small constants are purely real.
    macro_rules! complex_common {
        ($t:ident, $re:ident, $name:expr, $vec:ty, $mask:ty) => {
            impl_type_common!(
                $t,
                <$t>::new(<$re>::MAX, <$re>::MAX),
                <$t>::new(<$re>::MIN, <$re>::MIN),
                <$t>::new(0.0, 0.0),
                <$t>::new(1.0, 0.0),
                <$t>::new(<$re>::INFINITY, <$re>::INFINITY),
                <$t>::new(<$re>::NEG_INFINITY, <$re>::NEG_INFINITY),
                <$t>::new(2.0, 0.0),
                <$t>::new(6.0, 0.0),
                <$t>::new(10.0, 0.0),
                $name,
                $vec,
                $mask
            );
        };
    }

    // NOTE(review): `TWO` is `false` for `bool` while `SIX` and `TEN` are `true`;
    // kept exactly as before — confirm this is intentional.
    impl_type_common!(
        bool, true, false, false, true, true, false, false, true, true, "bool", boolx32, u8
    );

    int_common!(i8, "i8", i8x32, u8);
    int_common!(u8, "u8", u8x32, u8);
    int_common!(i16, "i16", i16x16, u16);
    int_common!(u16, "u16", u16x16, u16);
    int_common!(i32, "i32", i32x8, u32);
    int_common!(u32, "u32", u32x8, u32);
    int_common!(i64, "i64", i64x4, u64);
    int_common!(u64, "u64", u64x4, u64);

    float_common!(f32, "f32", f32x8, u32);
    float_common!(f64, "f64", f64x4, u64);

    // Pointer-sized integers: the lane count follows the pointer width.
    #[cfg(target_pointer_width = "64")]
    int_common!(isize, "isize", isizex4, usize);
    #[cfg(target_pointer_width = "32")]
    int_common!(isize, "isize", isizex8, usize);
    #[cfg(target_pointer_width = "64")]
    int_common!(usize, "usize", usizex4, usize);
    #[cfg(target_pointer_width = "32")]
    int_common!(usize, "usize", usizex8, usize);

    half_common!(f16, "f16", f16x16, u16);
    half_common!(bf16, "bf16", bf16x16, u16);

    complex_common!(Complex32, f32, "c32", cplx32x4, (u32, u32));
    complex_common!(Complex64, f64, "c64", cplx64x2, (u64, u64));
}
510
511#[cfg(all(
512 any(target_feature = "sse", target_arch = "arm", target_arch = "aarch64"),
513 not(target_feature = "avx2")
514))]
515mod type_impl {
516 use super::TypeCommon;
517 use crate::simd::_128bit::*;
518 use crate::vectors::traits::VecTrait;
519 use half::*;
520 use num_complex::{Complex32, Complex64};
521 impl_type_common!(
522 bool, true, false, false, true, true, false, false, true, true, "bool", boolx16, u8
523 );
524 impl_type_common!(
525 i8,
526 i8::MAX,
527 i8::MIN,
528 0,
529 1,
530 i8::MAX,
531 i8::MIN,
532 2,
533 6,
534 10,
535 "i8",
536 i8x16,
537 u8
538 );
539 impl_type_common!(
540 u8,
541 u8::MAX,
542 u8::MIN,
543 0,
544 1,
545 u8::MAX,
546 u8::MIN,
547 2,
548 6,
549 10,
550 "u8",
551 u8x16,
552 u8
553 );
554 impl_type_common!(
555 i16,
556 i16::MAX,
557 i16::MIN,
558 0,
559 1,
560 i16::MAX,
561 i16::MIN,
562 2,
563 6,
564 10,
565 "i16",
566 i16x8,
567 u16
568 );
569 impl_type_common!(
570 u16,
571 u16::MAX,
572 u16::MIN,
573 0,
574 1,
575 u16::MAX,
576 u16::MIN,
577 2,
578 6,
579 10,
580 "u16",
581 u16x8,
582 u16
583 );
584 impl_type_common!(
585 i32,
586 i32::MAX,
587 i32::MIN,
588 0,
589 1,
590 i32::MAX,
591 i32::MIN,
592 2,
593 6,
594 10,
595 "i32",
596 i32x4,
597 u32
598 );
599 impl_type_common!(
600 u32,
601 u32::MAX,
602 u32::MIN,
603 0,
604 1,
605 u32::MAX,
606 u32::MIN,
607 2,
608 6,
609 10,
610 "u32",
611 u32x4,
612 u32
613 );
614 impl_type_common!(
615 i64,
616 i64::MAX,
617 i64::MIN,
618 0,
619 1,
620 i64::MAX,
621 i64::MIN,
622 2,
623 6,
624 10,
625 "i64",
626 i64x2,
627 u64
628 );
629 impl_type_common!(
630 u64,
631 u64::MAX,
632 u64::MIN,
633 0,
634 1,
635 u64::MAX,
636 u64::MIN,
637 2,
638 6,
639 10,
640 "u64",
641 u64x2,
642 u64
643 );
644 impl_type_common!(
645 f32,
646 f32::MAX,
647 f32::MIN,
648 0.0,
649 1.0,
650 f32::INFINITY,
651 f32::NEG_INFINITY,
652 2.0,
653 6.0,
654 10.0,
655 "f32",
656 f32x4,
657 u32
658 );
659 impl_type_common!(
660 f64,
661 f64::MAX,
662 f64::MIN,
663 0.0,
664 1.0,
665 f64::INFINITY,
666 f64::NEG_INFINITY,
667 2.0,
668 6.0,
669 10.0,
670 "f64",
671 f64x2,
672 u64
673 );
674 #[cfg(target_pointer_width = "64")]
675 impl_type_common!(
676 isize,
677 isize::MAX,
678 isize::MIN,
679 0,
680 1,
681 isize::MAX,
682 isize::MIN,
683 2,
684 6,
685 10,
686 "isize",
687 isizex2,
688 u64
689 );
690 #[cfg(target_pointer_width = "32")]
691 impl_type_common!(
692 isize,
693 isize::MAX,
694 isize::MIN,
695 0,
696 1,
697 isize::MAX,
698 isize::MIN,
699 2,
700 6,
701 10,
702 "isize",
703 "int",
704 isizex4,
705 u32
706 );
707 #[cfg(target_pointer_width = "64")]
708 impl_type_common!(
709 usize,
710 usize::MAX,
711 usize::MIN,
712 0,
713 1,
714 usize::MAX,
715 usize::MIN,
716 2,
717 6,
718 10,
719 "usize",
720 usizex2,
721 usize
722 );
723 #[cfg(target_pointer_width = "32")]
724 impl_type_common!(
725 usize,
726 usize::MAX,
727 usize::MIN,
728 0,
729 1,
730 usize::MAX,
731 usize::MIN,
732 2,
733 6,
734 10,
735 "usize",
736 "unsigned int",
737 usizex4,
738 usize
739 );
740 impl_type_common!(
741 f16,
742 f16::MAX,
743 f16::MIN,
744 f16::ZERO,
745 f16::ONE,
746 f16::INFINITY,
747 f16::NEG_INFINITY,
748 f16::from_f32_const(2.0),
749 f16::from_f32_const(6.0),
750 f16::from_f32_const(10.0),
751 "f16",
752 f16x8,
753 u16
754 );
755 impl_type_common!(
756 bf16,
757 bf16::MAX,
758 bf16::MIN,
759 bf16::ZERO,
760 bf16::ONE,
761 bf16::INFINITY,
762 bf16::NEG_INFINITY,
763 bf16::from_f32_const(2.0),
764 bf16::from_f32_const(6.0),
765 bf16::from_f32_const(10.0),
766 "bf16",
767 bf16x8,
768 u16
769 );
770 impl_type_common!(
771 Complex32,
772 Complex32::new(f32::MAX, f32::MAX),
773 Complex32::new(f32::MIN, f32::MIN),
774 Complex32::new(0.0, 0.0),
775 Complex32::new(1.0, 0.0),
776 Complex32::new(f32::INFINITY, f32::INFINITY),
777 Complex32::new(f32::NEG_INFINITY, f32::NEG_INFINITY),
778 Complex32::new(2.0, 0.0),
779 Complex32::new(6.0, 0.0),
780 Complex32::new(10.0, 0.0),
781 "c32",
782 cplx32x2,
783 (u32, u32)
784 );
785 impl_type_common!(
786 Complex64,
787 Complex64::new(f64::MAX, f64::MAX),
788 Complex64::new(f64::MIN, f64::MIN),
789 Complex64::new(0.0, 0.0),
790 Complex64::new(1.0, 0.0),
791 Complex64::new(f64::INFINITY, f64::INFINITY),
792 Complex64::new(f64::NEG_INFINITY, f64::NEG_INFINITY),
793 Complex64::new(2.0, 0.0),
794 Complex64::new(6.0, 0.0),
795 Complex64::new(10.0, 0.0),
796 "c64",
797 cplx64x1,
798 (u64, u64)
799 );
800}
801
/// Mathematical constants expressed in a floating-point type.
pub trait FloatConst {
    /// The value 0.5.
    const HALF: Self;
    /// Euler's number.
    const E: Self;
    /// The constant pi.
    const PI: Self;
    /// The value 3.
    const THREE: Self;
    /// Two times pi.
    const TWOPI: Self;
    /// Four times pi.
    const FOURPI: Self;
    /// The value 0.2.
    const POINT_TWO: Self;
    /// One over the square root of two.
    const FRAC_1_SQRT_2: Self;
}
821
822impl FloatConst for f32 {
823 const HALF: Self = 0.5;
824 const E: Self = f32::consts::E;
825 const PI: Self = f32::consts::PI;
826 const THREE: Self = 3.0;
827 const TWOPI: Self = f32::consts::PI * 2.0;
828 const FOURPI: Self = f32::consts::PI * 4.0;
829 const POINT_TWO: Self = 0.2;
830 const FRAC_1_SQRT_2: Self = f32::consts::FRAC_1_SQRT_2;
831}
832
833impl FloatConst for f64 {
834 const HALF: Self = 0.5;
835 const E: Self = std::f64::consts::E;
836 const PI: Self = std::f64::consts::PI;
837 const THREE: Self = 3.0;
838 const TWOPI: Self = std::f64::consts::PI * 2.0;
839 const FOURPI: Self = std::f64::consts::PI * 4.0;
840 const POINT_TWO: Self = 0.2;
841 const FRAC_1_SQRT_2: Self = std::f64::consts::FRAC_1_SQRT_2;
842}
843
844impl FloatConst for f16 {
845 const HALF: Self = f16::from_f32_const(0.5);
846 const E: Self = f16::from_f32_const(f32::consts::E);
847 const PI: Self = f16::from_f32_const(f32::consts::PI);
848 const THREE: Self = f16::from_f32_const(3.0);
849 const TWOPI: Self = f16::from_f32_const(f32::consts::PI * 2.0);
850 const FOURPI: Self = f16::from_f32_const(f32::consts::PI * 4.0);
851 const POINT_TWO: Self = f16::from_f32_const(0.2);
852 const FRAC_1_SQRT_2: Self = f16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
853}
854
855impl FloatConst for bf16 {
856 const HALF: Self = bf16::from_f32_const(0.5);
857 const E: Self = bf16::from_f32_const(f32::consts::E);
858 const PI: Self = bf16::from_f32_const(f32::consts::PI);
859 const THREE: Self = bf16::from_f32_const(3.0);
860 const TWOPI: Self = bf16::from_f32_const(f32::consts::PI * 2.0);
861 const FOURPI: Self = bf16::from_f32_const(f32::consts::PI * 4.0);
862 const POINT_TWO: Self = bf16::from_f32_const(0.2);
863 const FRAC_1_SQRT_2: Self = bf16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
864}