hpt_types/
dtype.rs

1#[cfg(target_feature = "avx512f")]
2use crate::vectors::_512bit::*;
3use crate::{
4    into_vec::IntoVec,
5    type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
6    vectors::traits::VecTrait,
7};
8use core::f32;
9use half::{bf16, f16};
10use num_complex::{Complex32, Complex64};
11use std::fmt::Debug;
12
/// Associates a Rust scalar type with the name of its CUDA C++ counterpart.
///
/// The name is a plain string such as `"float"` or `"unsigned int"`; it is
/// presumably spliced into CUDA C++ source text by kernel-generation code.
pub trait CudaType {
    /// Spelling of the matching CUDA C++ type.
    const CUDA_TYPE: &'static str;
}
18
19impl CudaType for bool {
20    const CUDA_TYPE: &'static str = "bool";
21}
22
23impl CudaType for i8 {
24    const CUDA_TYPE: &'static str = "char";
25}
26
27impl CudaType for u8 {
28    const CUDA_TYPE: &'static str = "unsigned char";
29}
30
31impl CudaType for i16 {
32    const CUDA_TYPE: &'static str = "short";
33}
34
35impl CudaType for u16 {
36    const CUDA_TYPE: &'static str = "unsigned short";
37}
38
39impl CudaType for i32 {
40    const CUDA_TYPE: &'static str = "int";
41}
42
43impl CudaType for u32 {
44    const CUDA_TYPE: &'static str = "unsigned int";
45}
46
47#[cfg(target_os = "windows")]
48impl CudaType for i64 {
49    const CUDA_TYPE: &'static str = "long long";
50}
51
52#[cfg(not(target_os = "windows"))]
53impl CudaType for i64 {
54    const CUDA_TYPE: &'static str = "long";
55}
56
57#[cfg(target_os = "windows")]
58impl CudaType for u64 {
59    const CUDA_TYPE: &'static str = "unsigned long long";
60}
61
62#[cfg(not(target_os = "windows"))]
63impl CudaType for u64 {
64    const CUDA_TYPE: &'static str = "unsigned long";
65}
66
67impl CudaType for f32 {
68    const CUDA_TYPE: &'static str = "float";
69}
70
71impl CudaType for f64 {
72    const CUDA_TYPE: &'static str = "double";
73}
74
75impl CudaType for Complex32 {
76    const CUDA_TYPE: &'static str = "cuFloatComplex";
77}
78
79impl CudaType for Complex64 {
80    const CUDA_TYPE: &'static str = "cuDoubleComplex";
81}
82
83#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
84impl CudaType for isize {
85    const CUDA_TYPE: &'static str = "long long";
86}
87
88#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
89impl CudaType for isize {
90    const CUDA_TYPE: &'static str = "long";
91}
92
93#[cfg(target_pointer_width = "32")]
94impl CudaType for isize {
95    const CUDA_TYPE: &'static str = "int";
96}
97
98#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
99impl CudaType for usize {
100    const CUDA_TYPE: &'static str = "unsigned long long";
101}
102
103#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
104impl CudaType for usize {
105    const CUDA_TYPE: &'static str = "unsigned long";
106}
107
108#[cfg(target_pointer_width = "32")]
109impl CudaType for usize {
110    const CUDA_TYPE: &'static str = "unsigned int";
111}
112
113impl CudaType for f16 {
114    const CUDA_TYPE: &'static str = "__half";
115}
116
117impl CudaType for bf16 {
118    const CUDA_TYPE: &'static str = "__nv_bfloat16";
119}
120
121/// common trait for all data types
122///
123/// This trait is used to define the common properties of all data types
124pub trait TypeCommon
125where
126    Self: Sized + Copy,
127{
128    /// the maximum value of the data type
129    const MAX: Self;
130    /// the minimum value of the data type
131    const MIN: Self;
132    /// the zero value of the data type
133    const ZERO: Self;
134    /// the one value of the data type
135    const ONE: Self;
136    /// the infinity value of the data type, for integer types, it is the maximum value
137    const INF: Self;
138    /// the negative infinity value of the data type, for integer types, it is the minimum value
139    const NEG_INF: Self;
140    /// the two value of the data type
141    const TWO: Self;
142    /// the six value of the data type
143    const SIX: Self;
144    /// the ten value of the data type
145    const TEN: Self;
146    /// the string representation of the data type
147    const STR: &'static str;
148    /// the bit size of the data type, alias of `std::mem::size_of()`
149    const BIT_SIZE: usize;
150    /// the simd vector type of the data type
151    type Vec: VecTrait<Self>
152        + Send
153        + Copy
154        + IntoVec<Self::Vec>
155        + std::ops::Index<usize, Output = Self>
156        + std::ops::IndexMut<usize>
157        + Sync
158        + Debug
159        + NormalOutUnary
160        + NormalOut<Self::Vec, Output = Self::Vec>
161        + FloatOutUnary
162        + FloatOutBinary
163        + FloatOutBinary<
164            <Self::Vec as FloatOutUnary>::Output,
165            Output = <Self::Vec as FloatOutUnary>::Output,
166        >;
167}
168
/// Implements `TypeCommon` for one scalar type and adds bounds-checked lane
/// indexing (`Index`/`IndexMut`) to its SIMD vector type.
///
/// Positional arguments: the scalar type; its `MAX`, `MIN`, `ZERO`, `ONE`,
/// `INF`, `NEG_INF`, `TWO`, `SIX` and `TEN` constants; its `STR` name; the SIMD
/// vector type used as `TypeCommon::Vec`; and a mask type. The mask type is
/// accepted for call-site compatibility but is not used by the expansion.
///
/// The expansion names `TypeCommon` unqualified and calls `SIZE`, `as_ptr` and
/// `as_mut_ptr` on the vector type, so the invoking module must have those in
/// scope (presumably via `VecTrait`).
macro_rules! impl_type_common {
    (
        $scalar:ty,
        $max:expr,
        $min:expr,
        $zero:expr,
        $one:expr,
        $inf:expr,
        $neg_inf:expr,
        $two:expr,
        $six:expr,
        $ten:expr,
        $name:expr,
        $simd:ty,
        $mask:ty
    ) => {
        impl std::ops::Index<usize> for $simd {
            type Output = $scalar;
            fn index(&self, index: usize) -> &Self::Output {
                let lanes = <$simd>::SIZE;
                assert!(
                    index < lanes,
                    "index out of bounds: the len is {} but the index is {}",
                    lanes,
                    index
                );
                // SAFETY: `index < SIZE` was checked above. NOTE(review): this
                // assumes `as_ptr` points at `SIZE` contiguous lane values.
                unsafe { &*self.as_ptr().add(index) }
            }
        }

        impl std::ops::IndexMut<usize> for $simd {
            fn index_mut(&mut self, index: usize) -> &mut Self::Output {
                let lanes = <$simd>::SIZE;
                assert!(
                    index < lanes,
                    "index out of bounds: the len is {} but the index is {}",
                    lanes,
                    index
                );
                // SAFETY: same bound as `index`; the `&mut self` borrow makes
                // the returned reference unique.
                unsafe { &mut *self.as_mut_ptr().add(index) }
            }
        }

        impl TypeCommon for $scalar {
            const MAX: Self = $max;
            const MIN: Self = $min;
            const ZERO: Self = $zero;
            const ONE: Self = $one;
            const INF: Self = $inf;
            const NEG_INF: Self = $neg_inf;
            const TWO: Self = $two;
            const SIX: Self = $six;
            const TEN: Self = $ten;
            const STR: &'static str = $name;
            // Byte size, despite the constant's name.
            const BIT_SIZE: usize = std::mem::size_of::<$scalar>();
            type Vec = $simd;
        }
    };
}
226
/// `TypeCommon` implementations backed by the 256-bit SIMD vectors, selected
/// when the `avx2` target feature is enabled.
#[cfg(target_feature = "avx2")]
mod type_impl {
    use super::TypeCommon;
    use crate::simd::_256bit::*;
    use crate::vectors::traits::VecTrait;
    use half::*;
    use num_complex::{Complex32, Complex64};

    /// Expands `int => "name", Vector, Mask;` rows into `impl_type_common!`
    /// calls using the shared integer pattern: `MAX`/`MIN` bounds, `INF` and
    /// `NEG_INF` equal to those bounds, and the literals 0, 1, 2, 6 and 10.
    /// Attributes before a row (e.g. `#[cfg]`) are forwarded to its invocation.
    macro_rules! impl_int_common {
        ($($(#[$attr:meta])* $int:ident => $name:expr, $vec:ty, $mask:ty;)*) => {
            $(
                $(#[$attr])*
                impl_type_common!(
                    $int, $int::MAX, $int::MIN, 0, 1, $int::MAX, $int::MIN, 2, 6, 10,
                    $name, $vec, $mask
                );
            )*
        };
    }

    // NOTE(review): bool's TWO is `false` while SIX and TEN are `true`;
    // confirm this is intended.
    impl_type_common!(
        bool,
        true,
        false,
        false,
        true,
        true,
        false,
        false,
        true,
        true,
        "bool",
        boolx32::boolx32,
        u8
    );

    impl_int_common! {
        i8 => "i8", i8x32::i8x32, u8;
        u8 => "u8", u8x32::u8x32, u8;
        i16 => "i16", i16x16::i16x16, u16;
        u16 => "u16", u16x16::u16x16, u16;
        i32 => "i32", i32x8::i32x8, u32;
        u32 => "u32", u32x8::u32x8, u32;
        i64 => "i64", i64x4::i64x4, u64;
        u64 => "u64", u64x4::u64x4, u64;
        #[cfg(target_pointer_width = "64")]
        isize => "isize", isizex4::isizex4, usize;
        #[cfg(target_pointer_width = "32")]
        isize => "isize", isizex8::isizex8, usize;
        #[cfg(target_pointer_width = "64")]
        usize => "usize", usizex4::usizex4, usize;
        #[cfg(target_pointer_width = "32")]
        usize => "usize", usizex8::usizex8, usize;
    }

    impl_type_common!(
        f32,
        f32::MAX,
        f32::MIN,
        0.0,
        1.0,
        f32::INFINITY,
        f32::NEG_INFINITY,
        2.0,
        6.0,
        10.0,
        "f32",
        f32x8::f32x8,
        u32
    );
    impl_type_common!(
        f64,
        f64::MAX,
        f64::MIN,
        0.0,
        1.0,
        f64::INFINITY,
        f64::NEG_INFINITY,
        2.0,
        6.0,
        10.0,
        "f64",
        f64x4::f64x4,
        u64
    );
    impl_type_common!(
        f16,
        f16::MAX,
        f16::MIN,
        f16::ZERO,
        f16::ONE,
        f16::INFINITY,
        f16::NEG_INFINITY,
        f16::from_f32_const(2.0),
        f16::from_f32_const(6.0),
        f16::from_f32_const(10.0),
        "f16",
        f16x16::f16x16,
        u16
    );
    impl_type_common!(
        bf16,
        bf16::MAX,
        bf16::MIN,
        bf16::ZERO,
        bf16::ONE,
        bf16::INFINITY,
        bf16::NEG_INFINITY,
        bf16::from_f32_const(2.0),
        bf16::from_f32_const(6.0),
        bf16::from_f32_const(10.0),
        "bf16",
        bf16x16::bf16x16,
        u16
    );

    // Complex constants apply the real limits to both components; the small
    // integer constants are purely real.
    impl_type_common!(
        Complex32,
        Complex32::new(f32::MAX, f32::MAX),
        Complex32::new(f32::MIN, f32::MIN),
        Complex32::new(0.0, 0.0),
        Complex32::new(1.0, 0.0),
        Complex32::new(f32::INFINITY, f32::INFINITY),
        Complex32::new(f32::NEG_INFINITY, f32::NEG_INFINITY),
        Complex32::new(2.0, 0.0),
        Complex32::new(6.0, 0.0),
        Complex32::new(10.0, 0.0),
        "c32",
        cplx32x4::cplx32x4,
        (u32, u32)
    );
    impl_type_common!(
        Complex64,
        Complex64::new(f64::MAX, f64::MAX),
        Complex64::new(f64::MIN, f64::MIN),
        Complex64::new(0.0, 0.0),
        Complex64::new(1.0, 0.0),
        Complex64::new(f64::INFINITY, f64::INFINITY),
        Complex64::new(f64::NEG_INFINITY, f64::NEG_INFINITY),
        Complex64::new(2.0, 0.0),
        Complex64::new(6.0, 0.0),
        Complex64::new(10.0, 0.0),
        "c64",
        cplx64x2::cplx64x2,
        (u64, u64)
    );
}
524
525#[cfg(all(
526    any(target_feature = "sse", target_arch = "arm", target_arch = "aarch64"),
527    not(target_feature = "avx2")
528))]
529mod type_impl {
530    use super::TypeCommon;
531    use crate::simd::_128bit::*;
532    use crate::vectors::traits::VecTrait;
533    use half::*;
534    use num_complex::{Complex32, Complex64};
535    impl_type_common!(
536        bool,
537        true,
538        false,
539        false,
540        true,
541        true,
542        false,
543        false,
544        true,
545        true,
546        "bool",
547        boolx16::boolx16,
548        u8
549    );
550    impl_type_common!(
551        i8,
552        i8::MAX,
553        i8::MIN,
554        0,
555        1,
556        i8::MAX,
557        i8::MIN,
558        2,
559        6,
560        10,
561        "i8",
562        i8x16::i8x16,
563        u8
564    );
565    impl_type_common!(
566        u8,
567        u8::MAX,
568        u8::MIN,
569        0,
570        1,
571        u8::MAX,
572        u8::MIN,
573        2,
574        6,
575        10,
576        "u8",
577        u8x16::u8x16,
578        u8
579    );
580    impl_type_common!(
581        i16,
582        i16::MAX,
583        i16::MIN,
584        0,
585        1,
586        i16::MAX,
587        i16::MIN,
588        2,
589        6,
590        10,
591        "i16",
592        i16x8::i16x8,
593        u16
594    );
595    impl_type_common!(
596        u16,
597        u16::MAX,
598        u16::MIN,
599        0,
600        1,
601        u16::MAX,
602        u16::MIN,
603        2,
604        6,
605        10,
606        "u16",
607        u16x8::u16x8,
608        u16
609    );
610    impl_type_common!(
611        i32,
612        i32::MAX,
613        i32::MIN,
614        0,
615        1,
616        i32::MAX,
617        i32::MIN,
618        2,
619        6,
620        10,
621        "i32",
622        i32x4::i32x4,
623        u32
624    );
625    impl_type_common!(
626        u32,
627        u32::MAX,
628        u32::MIN,
629        0,
630        1,
631        u32::MAX,
632        u32::MIN,
633        2,
634        6,
635        10,
636        "u32",
637        u32x4::u32x4,
638        u32
639    );
640    impl_type_common!(
641        i64,
642        i64::MAX,
643        i64::MIN,
644        0,
645        1,
646        i64::MAX,
647        i64::MIN,
648        2,
649        6,
650        10,
651        "i64",
652        i64x2::i64x2,
653        u64
654    );
655    impl_type_common!(
656        u64,
657        u64::MAX,
658        u64::MIN,
659        0,
660        1,
661        u64::MAX,
662        u64::MIN,
663        2,
664        6,
665        10,
666        "u64",
667        u64x2::u64x2,
668        u64
669    );
670    impl_type_common!(
671        f32,
672        f32::MAX,
673        f32::MIN,
674        0.0,
675        1.0,
676        f32::INFINITY,
677        f32::NEG_INFINITY,
678        2.0,
679        6.0,
680        10.0,
681        "f32",
682        f32x4::f32x4,
683        u32
684    );
685    impl_type_common!(
686        f64,
687        f64::MAX,
688        f64::MIN,
689        0.0,
690        1.0,
691        f64::INFINITY,
692        f64::NEG_INFINITY,
693        2.0,
694        6.0,
695        10.0,
696        "f64",
697        f64x2::f64x2,
698        u64
699    );
700    #[cfg(target_pointer_width = "64")]
701    impl_type_common!(
702        isize,
703        isize::MAX,
704        isize::MIN,
705        0,
706        1,
707        isize::MAX,
708        isize::MIN,
709        2,
710        6,
711        10,
712        "isize",
713        isizex2::isizex2,
714        u64
715    );
716    #[cfg(target_pointer_width = "32")]
717    impl_type_common!(
718        isize,
719        isize::MAX,
720        isize::MIN,
721        0,
722        1,
723        isize::MAX,
724        isize::MIN,
725        2,
726        6,
727        10,
728        "isize",
729        "int",
730        isizex4::isizex4,
731        u32
732    );
733    #[cfg(target_pointer_width = "64")]
734    impl_type_common!(
735        usize,
736        usize::MAX,
737        usize::MIN,
738        0,
739        1,
740        usize::MAX,
741        usize::MIN,
742        2,
743        6,
744        10,
745        "usize",
746        usizex2::usizex2,
747        usize
748    );
749    #[cfg(target_pointer_width = "32")]
750    impl_type_common!(
751        usize,
752        usize::MAX,
753        usize::MIN,
754        0,
755        1,
756        usize::MAX,
757        usize::MIN,
758        2,
759        6,
760        10,
761        "usize",
762        "unsigned int",
763        usizex4::usizex4,
764        usize
765    );
766    impl_type_common!(
767        f16,
768        f16::MAX,
769        f16::MIN,
770        f16::ZERO,
771        f16::ONE,
772        f16::INFINITY,
773        f16::NEG_INFINITY,
774        f16::from_f32_const(2.0),
775        f16::from_f32_const(6.0),
776        f16::from_f32_const(10.0),
777        "f16",
778        f16x8::f16x8,
779        u16
780    );
781    impl_type_common!(
782        bf16,
783        bf16::MAX,
784        bf16::MIN,
785        bf16::ZERO,
786        bf16::ONE,
787        bf16::INFINITY,
788        bf16::NEG_INFINITY,
789        bf16::from_f32_const(2.0),
790        bf16::from_f32_const(6.0),
791        bf16::from_f32_const(10.0),
792        "bf16",
793        bf16x8::bf16x8,
794        u16
795    );
796    impl_type_common!(
797        Complex32,
798        Complex32::new(f32::MAX, f32::MAX),
799        Complex32::new(f32::MIN, f32::MIN),
800        Complex32::new(0.0, 0.0),
801        Complex32::new(1.0, 0.0),
802        Complex32::new(f32::INFINITY, f32::INFINITY),
803        Complex32::new(f32::NEG_INFINITY, f32::NEG_INFINITY),
804        Complex32::new(2.0, 0.0),
805        Complex32::new(6.0, 0.0),
806        Complex32::new(10.0, 0.0),
807        "c32",
808        cplx32x2::cplx32x2,
809        (u32, u32)
810    );
811    impl_type_common!(
812        Complex64,
813        Complex64::new(f64::MAX, f64::MAX),
814        Complex64::new(f64::MIN, f64::MIN),
815        Complex64::new(0.0, 0.0),
816        Complex64::new(1.0, 0.0),
817        Complex64::new(f64::INFINITY, f64::INFINITY),
818        Complex64::new(f64::NEG_INFINITY, f64::NEG_INFINITY),
819        Complex64::new(2.0, 0.0),
820        Complex64::new(6.0, 0.0),
821        Complex64::new(10.0, 0.0),
822        "c64",
823        cplx64x1::cplx64x1,
824        (u64, u64)
825    );
826}
827
/// Frequently used mathematical constants for floating-point element types.
///
/// Each constant is expressed in the implementing type itself, so reduced
/// precision types carry the nearest representable value.
pub trait FloatConst {
    /// One half (`0.5`).
    const HALF: Self;
    /// Euler's number, `e`.
    const E: Self;
    /// Archimedes' constant, `π`.
    const PI: Self;
    /// The literal `3.0`.
    const THREE: Self;
    /// A full turn, `2π`.
    const TWOPI: Self;
    /// Two full turns, `4π`.
    const FOURPI: Self;
    /// The literal `0.2`.
    const POINT_TWO: Self;
    /// The reciprocal of the square root of two, `1/√2`.
    const FRAC_1_SQRT_2: Self;
}
847
848impl FloatConst for f32 {
849    const HALF: Self = 0.5;
850    const E: Self = f32::consts::E;
851    const PI: Self = f32::consts::PI;
852    const THREE: Self = 3.0;
853    const TWOPI: Self = f32::consts::PI * 2.0;
854    const FOURPI: Self = f32::consts::PI * 4.0;
855    const POINT_TWO: Self = 0.2;
856    const FRAC_1_SQRT_2: Self = f32::consts::FRAC_1_SQRT_2;
857}
858
859impl FloatConst for f64 {
860    const HALF: Self = 0.5;
861    const E: Self = std::f64::consts::E;
862    const PI: Self = std::f64::consts::PI;
863    const THREE: Self = 3.0;
864    const TWOPI: Self = std::f64::consts::PI * 2.0;
865    const FOURPI: Self = std::f64::consts::PI * 4.0;
866    const POINT_TWO: Self = 0.2;
867    const FRAC_1_SQRT_2: Self = std::f64::consts::FRAC_1_SQRT_2;
868}
869
870impl FloatConst for f16 {
871    const HALF: Self = f16::from_f32_const(0.5);
872    const E: Self = f16::from_f32_const(f32::consts::E);
873    const PI: Self = f16::from_f32_const(f32::consts::PI);
874    const THREE: Self = f16::from_f32_const(3.0);
875    const TWOPI: Self = f16::from_f32_const(f32::consts::PI * 2.0);
876    const FOURPI: Self = f16::from_f32_const(f32::consts::PI * 4.0);
877    const POINT_TWO: Self = f16::from_f32_const(0.2);
878    const FRAC_1_SQRT_2: Self = f16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
879}
880
881impl FloatConst for bf16 {
882    const HALF: Self = bf16::from_f32_const(0.5);
883    const E: Self = bf16::from_f32_const(f32::consts::E);
884    const PI: Self = bf16::from_f32_const(f32::consts::PI);
885    const THREE: Self = bf16::from_f32_const(3.0);
886    const TWOPI: Self = bf16::from_f32_const(f32::consts::PI * 2.0);
887    const FOURPI: Self = bf16::from_f32_const(f32::consts::PI * 4.0);
888    const POINT_TWO: Self = bf16::from_f32_const(0.2);
889    const FRAC_1_SQRT_2: Self = bf16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
890}