hpt_types/
dtype.rs

1use crate::{
2    into_vec::IntoVec,
3    type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
4    vectors::traits::VecTrait,
5};
6use core::f32;
7use half::{bf16, f16};
8use num_complex::{Complex32, Complex64};
9use std::fmt::Debug;
10
/// Associates a Rust scalar type with the name of a CUDA C type.
///
/// The name is exposed as a string constant, e.g. `"float"` for `f32`. How
/// the string is consumed is not visible in this file.
pub trait CudaType {
    /// The CUDA C type name that corresponds to the implementing type.
    const CUDA_TYPE: &'static str;
}
16
17impl CudaType for bool {
18    const CUDA_TYPE: &'static str = "bool";
19}
20
21impl CudaType for i8 {
22    const CUDA_TYPE: &'static str = "char";
23}
24
25impl CudaType for u8 {
26    const CUDA_TYPE: &'static str = "unsigned char";
27}
28
29impl CudaType for i16 {
30    const CUDA_TYPE: &'static str = "short";
31}
32
33impl CudaType for u16 {
34    const CUDA_TYPE: &'static str = "unsigned short";
35}
36
37impl CudaType for i32 {
38    const CUDA_TYPE: &'static str = "int";
39}
40
41impl CudaType for u32 {
42    const CUDA_TYPE: &'static str = "unsigned int";
43}
44
45#[cfg(target_os = "windows")]
46impl CudaType for i64 {
47    const CUDA_TYPE: &'static str = "long long";
48}
49
50#[cfg(not(target_os = "windows"))]
51impl CudaType for i64 {
52    const CUDA_TYPE: &'static str = "long";
53}
54
55#[cfg(target_os = "windows")]
56impl CudaType for u64 {
57    const CUDA_TYPE: &'static str = "unsigned long long";
58}
59
60#[cfg(not(target_os = "windows"))]
61impl CudaType for u64 {
62    const CUDA_TYPE: &'static str = "unsigned long";
63}
64
65impl CudaType for f32 {
66    const CUDA_TYPE: &'static str = "float";
67}
68
69impl CudaType for f64 {
70    const CUDA_TYPE: &'static str = "double";
71}
72
73impl CudaType for Complex32 {
74    const CUDA_TYPE: &'static str = "cuFloatComplex";
75}
76
77impl CudaType for Complex64 {
78    const CUDA_TYPE: &'static str = "cuDoubleComplex";
79}
80
81#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
82impl CudaType for isize {
83    const CUDA_TYPE: &'static str = "long long";
84}
85
86#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
87impl CudaType for isize {
88    const CUDA_TYPE: &'static str = "long";
89}
90
91#[cfg(target_pointer_width = "32")]
92impl CudaType for isize {
93    const CUDA_TYPE: &'static str = "int";
94}
95
96#[cfg(all(target_pointer_width = "64", target_os = "windows"))]
97impl CudaType for usize {
98    const CUDA_TYPE: &'static str = "unsigned long long";
99}
100
101#[cfg(all(target_pointer_width = "64", not(target_os = "windows")))]
102impl CudaType for usize {
103    const CUDA_TYPE: &'static str = "unsigned long";
104}
105
106#[cfg(target_pointer_width = "32")]
107impl CudaType for usize {
108    const CUDA_TYPE: &'static str = "unsigned int";
109}
110
111impl CudaType for f16 {
112    const CUDA_TYPE: &'static str = "__half";
113}
114
115impl CudaType for bf16 {
116    const CUDA_TYPE: &'static str = "__nv_bfloat16";
117}
118
119/// common trait for all data types
120///
121/// This trait is used to define the common properties of all data types
122pub trait TypeCommon
123where
124    Self: Sized + Copy,
125{
126    /// the maximum value of the data type
127    const MAX: Self;
128    /// the minimum value of the data type
129    const MIN: Self;
130    /// the zero value of the data type
131    const ZERO: Self;
132    /// the one value of the data type
133    const ONE: Self;
134    /// the infinity value of the data type, for integer types, it is the maximum value
135    const INF: Self;
136    /// the negative infinity value of the data type, for integer types, it is the minimum value
137    const NEG_INF: Self;
138    /// the two value of the data type
139    const TWO: Self;
140    /// the six value of the data type
141    const SIX: Self;
142    /// the ten value of the data type
143    const TEN: Self;
144    /// the string representation of the data type
145    const STR: &'static str;
146    /// the bit size of the data type, alias of `std::mem::size_of()`
147    const BYTE_SIZE: usize;
148    /// the simd vector type of the data type
149    type Vec: VecTrait<Self>
150        + Send
151        + Copy
152        + IntoVec<Self::Vec>
153        + std::ops::Index<usize, Output = Self>
154        + std::ops::IndexMut<usize>
155        + Sync
156        + Debug
157        + NormalOutUnary
158        + NormalOut<Self::Vec, Output = Self::Vec>
159        + FloatOutUnary
160        + FloatOutBinary
161        + FloatOutBinary<
162            <Self::Vec as FloatOutUnary>::Output,
163            Output = <Self::Vec as FloatOutUnary>::Output,
164        >;
165}
166
/// Implements `TypeCommon` for a scalar type and bounds-checked lane indexing
/// (`Index<usize>` / `IndexMut<usize>`) for its SIMD vector type.
///
/// Arguments are positional: scalar type, then the values for `MAX`, `MIN`,
/// `ZERO`, `ONE`, `INF`, `NEG_INF`, `TWO`, `SIX`, `TEN` and `STR`, then the
/// vector type and a mask type. The mask type is accepted but does not appear
/// in the expansion.
///
/// At the call site the vector type must provide a `SIZE` constant and
/// `as_ptr` / `as_mut_ptr` methods yielding pointers to the scalar type, and
/// `TypeCommon` and `size_of` must be in scope.
macro_rules! impl_type_common {
    (
        $scalar:ty,
        $max:expr,
        $min:expr,
        $zero:expr,
        $one:expr,
        $inf:expr,
        $neg_inf:expr,
        $two:expr,
        $six:expr,
        $ten:expr,
        $name:expr,
        $simd:ty,
        $mask:ty
    ) => {
        impl TypeCommon for $scalar {
            type Vec = $simd;

            const STR: &'static str = $name;
            const BYTE_SIZE: usize = size_of::<$scalar>();

            const MIN: Self = $min;
            const MAX: Self = $max;
            const NEG_INF: Self = $neg_inf;
            const INF: Self = $inf;

            const ZERO: Self = $zero;
            const ONE: Self = $one;
            const TWO: Self = $two;
            const SIX: Self = $six;
            const TEN: Self = $ten;
        }

        impl std::ops::Index<usize> for $simd {
            type Output = $scalar;

            /// Returns a reference to lane `lane`; panics if `lane >= SIZE`.
            fn index(&self, lane: usize) -> &Self::Output {
                assert!(
                    lane < <$simd>::SIZE,
                    "index out of bounds: the len is {} but the index is {}",
                    <$simd>::SIZE,
                    lane
                );
                // SAFETY: `lane < SIZE` was checked above. This relies on
                // `as_ptr` pointing at `SIZE` consecutive scalars, which is a
                // property of the vector type and not visible in this macro.
                unsafe { &*self.as_ptr().add(lane) }
            }
        }

        impl std::ops::IndexMut<usize> for $simd {
            /// Returns a mutable reference to lane `lane`; panics if `lane >= SIZE`.
            fn index_mut(&mut self, lane: usize) -> &mut Self::Output {
                assert!(
                    lane < <$simd>::SIZE,
                    "index out of bounds: the len is {} but the index is {}",
                    <$simd>::SIZE,
                    lane
                );
                // SAFETY: same bounds argument as in `index`, with the pointer
                // obtained through the exclusive borrow of `self`.
                unsafe { &mut *self.as_mut_ptr().add(lane) }
            }
        }
    };
}
224
/// `TypeCommon` implementations that use the 256-bit SIMD vector types.
///
/// Compiled only when the `avx2` target feature is enabled.
#[cfg(target_feature = "avx2")]
mod type_impl {
    use super::TypeCommon;
    use crate::simd::_256bit::*;
    use crate::vectors::traits::VecTrait;
    use half::*;
    use num_complex::{Complex32, Complex64};

    // Integer rows: `INF` / `NEG_INF` fall back to `MAX` / `MIN`, and `STR` is
    // the type's own name. Attributes in front of a row gate that row only.
    macro_rules! impl_int {
        ($($(#[$attr:meta])* $int:ident => $vec:ty, $mask:ty;)+) => {
            $(
                $(#[$attr])*
                impl_type_common!(
                    $int,
                    $int::MAX,
                    $int::MIN,
                    0,
                    1,
                    $int::MAX,
                    $int::MIN,
                    2,
                    6,
                    10,
                    stringify!($int),
                    $vec,
                    $mask
                );
            )+
        };
    }

    // Native float rows: real infinities and float literals.
    macro_rules! impl_float {
        ($($float:ident => $vec:ty, $mask:ty;)+) => {
            $(
                impl_type_common!(
                    $float,
                    $float::MAX,
                    $float::MIN,
                    0.0,
                    1.0,
                    $float::INFINITY,
                    $float::NEG_INFINITY,
                    2.0,
                    6.0,
                    10.0,
                    stringify!($float),
                    $vec,
                    $mask
                );
            )+
        };
    }

    // Half-precision rows: constants that have no named associated constant
    // are converted from `f32` with `from_f32_const`.
    macro_rules! impl_half {
        ($($half:ident => $vec:ty, $mask:ty;)+) => {
            $(
                impl_type_common!(
                    $half,
                    $half::MAX,
                    $half::MIN,
                    $half::ZERO,
                    $half::ONE,
                    $half::INFINITY,
                    $half::NEG_INFINITY,
                    $half::from_f32_const(2.0),
                    $half::from_f32_const(6.0),
                    $half::from_f32_const(10.0),
                    stringify!($half),
                    $vec,
                    $mask
                );
            )+
        };
    }

    // Complex rows: the extreme values set both components, the small
    // integers are purely real.
    macro_rules! impl_complex {
        ($($cplx:ident($part:ident) => $name:literal, $vec:ty, $mask:ty;)+) => {
            $(
                impl_type_common!(
                    $cplx,
                    $cplx::new($part::MAX, $part::MAX),
                    $cplx::new($part::MIN, $part::MIN),
                    $cplx::new(0.0, 0.0),
                    $cplx::new(1.0, 0.0),
                    $cplx::new($part::INFINITY, $part::INFINITY),
                    $cplx::new($part::NEG_INFINITY, $part::NEG_INFINITY),
                    $cplx::new(2.0, 0.0),
                    $cplx::new(6.0, 0.0),
                    $cplx::new(10.0, 0.0),
                    $name,
                    $vec,
                    $mask
                );
            )+
        };
    }

    impl_type_common!(
        bool, true, false, false, true, true, false, false, true, true, "bool", boolx32, u8
    );

    impl_int! {
        i8 => i8x32, u8;
        u8 => u8x32, u8;
        i16 => i16x16, u16;
        u16 => u16x16, u16;
        i32 => i32x8, u32;
        u32 => u32x8, u32;
        i64 => i64x4, u64;
        u64 => u64x4, u64;
        #[cfg(target_pointer_width = "64")]
        isize => isizex4, usize;
        #[cfg(target_pointer_width = "32")]
        isize => isizex8, usize;
        #[cfg(target_pointer_width = "64")]
        usize => usizex4, usize;
        #[cfg(target_pointer_width = "32")]
        usize => usizex8, usize;
    }

    impl_float! {
        f32 => f32x8, u32;
        f64 => f64x4, u64;
    }

    impl_half! {
        f16 => f16x16, u16;
        bf16 => bf16x16, u16;
    }

    impl_complex! {
        Complex32(f32) => "c32", cplx32x4, (u32, u32);
        Complex64(f64) => "c64", cplx64x2, (u64, u64);
    }
}
510
511#[cfg(all(
512    any(target_feature = "sse", target_arch = "arm", target_arch = "aarch64"),
513    not(target_feature = "avx2")
514))]
515mod type_impl {
516    use super::TypeCommon;
517    use crate::simd::_128bit::*;
518    use crate::vectors::traits::VecTrait;
519    use half::*;
520    use num_complex::{Complex32, Complex64};
521    impl_type_common!(
522        bool, true, false, false, true, true, false, false, true, true, "bool", boolx16, u8
523    );
524    impl_type_common!(
525        i8,
526        i8::MAX,
527        i8::MIN,
528        0,
529        1,
530        i8::MAX,
531        i8::MIN,
532        2,
533        6,
534        10,
535        "i8",
536        i8x16,
537        u8
538    );
539    impl_type_common!(
540        u8,
541        u8::MAX,
542        u8::MIN,
543        0,
544        1,
545        u8::MAX,
546        u8::MIN,
547        2,
548        6,
549        10,
550        "u8",
551        u8x16,
552        u8
553    );
554    impl_type_common!(
555        i16,
556        i16::MAX,
557        i16::MIN,
558        0,
559        1,
560        i16::MAX,
561        i16::MIN,
562        2,
563        6,
564        10,
565        "i16",
566        i16x8,
567        u16
568    );
569    impl_type_common!(
570        u16,
571        u16::MAX,
572        u16::MIN,
573        0,
574        1,
575        u16::MAX,
576        u16::MIN,
577        2,
578        6,
579        10,
580        "u16",
581        u16x8,
582        u16
583    );
584    impl_type_common!(
585        i32,
586        i32::MAX,
587        i32::MIN,
588        0,
589        1,
590        i32::MAX,
591        i32::MIN,
592        2,
593        6,
594        10,
595        "i32",
596        i32x4,
597        u32
598    );
599    impl_type_common!(
600        u32,
601        u32::MAX,
602        u32::MIN,
603        0,
604        1,
605        u32::MAX,
606        u32::MIN,
607        2,
608        6,
609        10,
610        "u32",
611        u32x4,
612        u32
613    );
614    impl_type_common!(
615        i64,
616        i64::MAX,
617        i64::MIN,
618        0,
619        1,
620        i64::MAX,
621        i64::MIN,
622        2,
623        6,
624        10,
625        "i64",
626        i64x2,
627        u64
628    );
629    impl_type_common!(
630        u64,
631        u64::MAX,
632        u64::MIN,
633        0,
634        1,
635        u64::MAX,
636        u64::MIN,
637        2,
638        6,
639        10,
640        "u64",
641        u64x2,
642        u64
643    );
644    impl_type_common!(
645        f32,
646        f32::MAX,
647        f32::MIN,
648        0.0,
649        1.0,
650        f32::INFINITY,
651        f32::NEG_INFINITY,
652        2.0,
653        6.0,
654        10.0,
655        "f32",
656        f32x4,
657        u32
658    );
659    impl_type_common!(
660        f64,
661        f64::MAX,
662        f64::MIN,
663        0.0,
664        1.0,
665        f64::INFINITY,
666        f64::NEG_INFINITY,
667        2.0,
668        6.0,
669        10.0,
670        "f64",
671        f64x2,
672        u64
673    );
674    #[cfg(target_pointer_width = "64")]
675    impl_type_common!(
676        isize,
677        isize::MAX,
678        isize::MIN,
679        0,
680        1,
681        isize::MAX,
682        isize::MIN,
683        2,
684        6,
685        10,
686        "isize",
687        isizex2,
688        u64
689    );
690    #[cfg(target_pointer_width = "32")]
691    impl_type_common!(
692        isize,
693        isize::MAX,
694        isize::MIN,
695        0,
696        1,
697        isize::MAX,
698        isize::MIN,
699        2,
700        6,
701        10,
702        "isize",
703        "int",
704        isizex4,
705        u32
706    );
707    #[cfg(target_pointer_width = "64")]
708    impl_type_common!(
709        usize,
710        usize::MAX,
711        usize::MIN,
712        0,
713        1,
714        usize::MAX,
715        usize::MIN,
716        2,
717        6,
718        10,
719        "usize",
720        usizex2,
721        usize
722    );
723    #[cfg(target_pointer_width = "32")]
724    impl_type_common!(
725        usize,
726        usize::MAX,
727        usize::MIN,
728        0,
729        1,
730        usize::MAX,
731        usize::MIN,
732        2,
733        6,
734        10,
735        "usize",
736        "unsigned int",
737        usizex4,
738        usize
739    );
740    impl_type_common!(
741        f16,
742        f16::MAX,
743        f16::MIN,
744        f16::ZERO,
745        f16::ONE,
746        f16::INFINITY,
747        f16::NEG_INFINITY,
748        f16::from_f32_const(2.0),
749        f16::from_f32_const(6.0),
750        f16::from_f32_const(10.0),
751        "f16",
752        f16x8,
753        u16
754    );
755    impl_type_common!(
756        bf16,
757        bf16::MAX,
758        bf16::MIN,
759        bf16::ZERO,
760        bf16::ONE,
761        bf16::INFINITY,
762        bf16::NEG_INFINITY,
763        bf16::from_f32_const(2.0),
764        bf16::from_f32_const(6.0),
765        bf16::from_f32_const(10.0),
766        "bf16",
767        bf16x8,
768        u16
769    );
770    impl_type_common!(
771        Complex32,
772        Complex32::new(f32::MAX, f32::MAX),
773        Complex32::new(f32::MIN, f32::MIN),
774        Complex32::new(0.0, 0.0),
775        Complex32::new(1.0, 0.0),
776        Complex32::new(f32::INFINITY, f32::INFINITY),
777        Complex32::new(f32::NEG_INFINITY, f32::NEG_INFINITY),
778        Complex32::new(2.0, 0.0),
779        Complex32::new(6.0, 0.0),
780        Complex32::new(10.0, 0.0),
781        "c32",
782        cplx32x2,
783        (u32, u32)
784    );
785    impl_type_common!(
786        Complex64,
787        Complex64::new(f64::MAX, f64::MAX),
788        Complex64::new(f64::MIN, f64::MIN),
789        Complex64::new(0.0, 0.0),
790        Complex64::new(1.0, 0.0),
791        Complex64::new(f64::INFINITY, f64::INFINITY),
792        Complex64::new(f64::NEG_INFINITY, f64::NEG_INFINITY),
793        Complex64::new(2.0, 0.0),
794        Complex64::new(6.0, 0.0),
795        Complex64::new(10.0, 0.0),
796        "c64",
797        cplx64x1,
798        (u64, u64)
799    );
800}
801
/// Constant values for floating point data types.
///
/// Each constant is the implementing type's representation of the named
/// mathematical value.
pub trait FloatConst {
    /// The value 0.5.
    const HALF: Self;
    /// Euler's number, e.
    const E: Self;
    /// The constant π.
    const PI: Self;
    /// The value 3.0.
    const THREE: Self;
    /// Twice π (2π).
    const TWOPI: Self;
    /// Four times π (4π).
    const FOURPI: Self;
    /// The value 0.2.
    const POINT_TWO: Self;
    /// The reciprocal of the square root of two, 1/√2.
    const FRAC_1_SQRT_2: Self;
}
821
822impl FloatConst for f32 {
823    const HALF: Self = 0.5;
824    const E: Self = f32::consts::E;
825    const PI: Self = f32::consts::PI;
826    const THREE: Self = 3.0;
827    const TWOPI: Self = f32::consts::PI * 2.0;
828    const FOURPI: Self = f32::consts::PI * 4.0;
829    const POINT_TWO: Self = 0.2;
830    const FRAC_1_SQRT_2: Self = f32::consts::FRAC_1_SQRT_2;
831}
832
833impl FloatConst for f64 {
834    const HALF: Self = 0.5;
835    const E: Self = std::f64::consts::E;
836    const PI: Self = std::f64::consts::PI;
837    const THREE: Self = 3.0;
838    const TWOPI: Self = std::f64::consts::PI * 2.0;
839    const FOURPI: Self = std::f64::consts::PI * 4.0;
840    const POINT_TWO: Self = 0.2;
841    const FRAC_1_SQRT_2: Self = std::f64::consts::FRAC_1_SQRT_2;
842}
843
844impl FloatConst for f16 {
845    const HALF: Self = f16::from_f32_const(0.5);
846    const E: Self = f16::from_f32_const(f32::consts::E);
847    const PI: Self = f16::from_f32_const(f32::consts::PI);
848    const THREE: Self = f16::from_f32_const(3.0);
849    const TWOPI: Self = f16::from_f32_const(f32::consts::PI * 2.0);
850    const FOURPI: Self = f16::from_f32_const(f32::consts::PI * 4.0);
851    const POINT_TWO: Self = f16::from_f32_const(0.2);
852    const FRAC_1_SQRT_2: Self = f16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
853}
854
855impl FloatConst for bf16 {
856    const HALF: Self = bf16::from_f32_const(0.5);
857    const E: Self = bf16::from_f32_const(f32::consts::E);
858    const PI: Self = bf16::from_f32_const(f32::consts::PI);
859    const THREE: Self = bf16::from_f32_const(3.0);
860    const TWOPI: Self = bf16::from_f32_const(f32::consts::PI * 2.0);
861    const FOURPI: Self = bf16::from_f32_const(f32::consts::PI * 4.0);
862    const POINT_TWO: Self = bf16::from_f32_const(0.2);
863    const FRAC_1_SQRT_2: Self = bf16::from_f32_const(f32::consts::FRAC_1_SQRT_2);
864}