hpt_types/
lib.rs

1//! This crate implment type utilities for tensor operations
2#![deny(missing_docs)]
3
4/// A module implement type conversion
5pub mod convertion;
6/// A module defines a set of data types and utilities
7pub mod dtype;
8/// A module implement type conversion
9pub mod into_scalar;
10/// A module implement simd vector conversion
11pub mod into_vec;
12/// A module defines a set of traits for tensor operations, and implement computation functions for scalar and vector types
13pub mod type_promote;
14/// A module defines a set of traits for scalar operations
15pub(crate) mod scalars {
16    pub(crate) mod _bf16;
17    pub(crate) mod _bool;
18    pub(crate) mod _f16;
19    pub(crate) mod _f32;
20    pub(crate) mod _f64;
21    pub(crate) mod impls;
22}
23
24/// A module defines a set of traits for type promotion
25pub mod promotion {
26    #[cfg(feature = "normal_promote")]
27    pub(crate) mod normal_promote {
28        pub(crate) mod _bf16;
29        pub(crate) mod _bool;
30        pub(crate) mod _cplx32;
31        pub(crate) mod _cplx64;
32        pub(crate) mod _f16;
33        pub(crate) mod _f32;
34        pub(crate) mod _f64;
35        pub(crate) mod _i16;
36        pub(crate) mod _i32;
37        pub(crate) mod _i64;
38        pub(crate) mod _i8;
39        pub(crate) mod _isize;
40        pub(crate) mod _u16;
41        pub(crate) mod _u32;
42        pub(crate) mod _u64;
43        pub(crate) mod _u8;
44        pub(crate) mod _usize;
45    }
46    pub(crate) mod utils;
47}
48
49/// A module defines a set of vector types
50pub mod vectors {
51    /// A module defines a set of vector types using stdsimd
52    pub mod arch_simd {
53        /// A module defines a set of 128-bit vector types
54        #[cfg(any(
55            all(not(target_feature = "avx2"), target_feature = "sse"),
56            target_arch = "arm",
57            target_arch = "aarch64",
58            target_feature = "neon"
59        ))]
60        pub mod _128bit {
61            /// A module defines a set of 128-bit vector types for bf16
62            pub mod bf16x8;
63            /// A module defines a set of 128-bit vector types for bool
64            pub mod boolx16;
65            /// A module defines a set of 128-bit vector types for cplx32
66            pub mod cplx32x2;
67            /// A module defines a set of 128-bit vector types for cplx64
68            pub mod cplx64x1;
69            /// A module defines a set of 128-bit vector types for f16
70            pub mod f16x8;
71            /// A module defines a set of 128-bit vector types for f32
72            pub mod f32x4;
73            /// A module defines a set of 128-bit vector types for f64
74            pub mod f64x2;
75            /// A module defines a set of 128-bit vector types for i16
76            pub mod i16x8;
77            /// A module defines a set of 128-bit vector types for i32
78            pub mod i32x4;
79            /// A module defines a set of 128-bit vector types for i64
80            pub mod i64x2;
81            /// A module defines a set of 128-bit vector types for i8
82            pub mod i8x16;
83            /// A module defines a set of 128-bit vector types for isize
84            pub mod isizex2;
85            /// A module defines a set of 128-bit vector types for u16
86            pub mod u16x8;
87            /// A module defines a set of 128-bit vector types for u32
88            pub mod u32x4;
89            /// A module defines a set of 128-bit vector types for u64
90            pub mod u64x2;
91            /// A module defines a set of 128-bit vector types for u8
92            pub mod u8x16;
93            /// A module defines a set of 128-bit vector types for usize
94            pub mod usizex2;
95        }
96        /// A module defines a set of 256-bit vector types
97        #[cfg(target_feature = "avx2")]
98        pub mod _256bit {
99            /// A module defines a set of 256-bit vector types for bf16
100            pub mod bf16x16;
101            /// A module defines a set of 256-bit vector types for bool
102            pub mod boolx32;
103            /// A module defines a set of 256-bit vector types for cplx32
104            pub mod cplx32x4;
105            /// A module defines a set of 256-bit vector types for cplx64
106            pub mod cplx64x2;
107            /// A module defines a set of 256-bit vector types for f16
108            pub mod f16x16;
109            /// A module defines a set of 256-bit vector types for f32
110            pub mod f32x8;
111            /// A module defines a set of 256-bit vector types for f64
112            pub mod f64x4;
113            /// A module defines a set of 256-bit vector types for i16
114            pub mod i16x16;
115            /// A module defines a set of 256-bit vector types for i32
116            pub mod i32x8;
117            /// A module defines a set of 256-bit vector types for i64
118            pub mod i64x4;
119            /// A module defines a set of 256-bit vector types for i8
120            pub mod i8x32;
121            /// A module defines a set of 256-bit vector types for isize
122            pub mod isizex4;
123            /// A module defines a set of 256-bit vector types for u16
124            pub mod u16x16;
125            /// A module defines a set of 256-bit vector types for u32
126            pub mod u32x8;
127            /// A module defines a set of 256-bit vector types for u64
128            pub mod u64x4;
129            /// A module defines a set of 256-bit vector types for u8
130            pub mod u8x32;
131            /// A module defines a set of 256-bit vector types for usize
132            pub mod usizex4;
133        }
134
135        // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
136        //
137        // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
138        // Modified work Copyright (c) 2024 hpt Contributors
139        //
140        // Boost Software License - Version 1.0 - August 17th, 2003
141        //
142        // Permission is hereby granted, free of charge, to any person or organization
143        // obtaining a copy of the software and accompanying documentation covered by
144        // this license (the "Software") to use, reproduce, display, distribute,
145        // execute, and transmit the Software, and to prepare derivative works of the
146        // Software, and to permit third-parties to whom the Software is furnished to
147        // do so, all subject to the following:
148        //
149        // The copyright notices in the Software and this entire statement, including
150        // the above license grant, this restriction and the following disclaimer,
151        // must be included in all copies of the Software, in whole or in part, and
152        // all derivative works of the Software, unless such copies or derivative
153        // works are solely in the form of machine-executable object code generated by
154        // a source language processor.
155        //
156        // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
157        // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
158        // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
159        // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
160        // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
161        // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
162        // DEALINGS IN THE SOFTWARE.
163        //
164        // This Rust port is additionally licensed under Apache-2.0 OR MIT
165        // See repository root for details
166        /// A module defines a set of vector types for sleef
167        pub mod sleef {
168            /// A module defines a set of vector types for table
169            pub mod table;
170            /// A module defines a set of vector types for helper
171            pub mod arch {
172                /// A module defines a set of vector types for helper
173                #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
174                pub mod helper_aarch64;
175                /// A module defines a set of vector types for helper
176                #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
177                pub mod helper_avx2;
178                /// A module defines a set of vector types for helper
179                #[cfg(all(
180                    target_arch = "x86_64",
181                    target_feature = "sse",
182                    not(target_feature = "avx2")
183                ))]
184                pub mod helper_sse;
185            }
186            /// A module defines a set of vector types for common
187            pub mod common {
188                /// A module defines a set of vector types for common
189                pub mod commonfuncs;
190                /// A module defines a set of vector types for common
191                pub mod dd;
192                /// A module defines a set of vector types for common
193                pub mod df;
194                /// A module defines a macro for polynomial approximation
195                pub mod estrin;
196                /// A module defines a set of vector types for common
197                pub mod misc;
198            }
199            /// A module defines a set of vector types for libm
200            pub mod libm {
201                /// a module defins a set of double precision floating point functions
202                pub mod sleefsimddp;
203                /// a module defins a set of single precision floating point functions
204                pub mod sleefsimdsp;
205            }
206        }
207    }
208    /// A module defines a set of traits for vector
209    pub mod traits;
210    /// A module defines a set of utils for vector
211    pub mod utils;
212
213    #[cfg(target_feature = "avx2")]
214    pub(crate) mod vector_promote {
215        #[cfg(target_pointer_width = "64")]
216        pub(crate) use crate::vectors::arch_simd::_256bit::isizex4::isize_promote;
217        #[cfg(target_pointer_width = "32")]
218        pub(crate) use crate::vectors::arch_simd::_256bit::isizex8::isize_promote;
219        #[cfg(target_pointer_width = "64")]
220        pub(crate) use crate::vectors::arch_simd::_256bit::usizex4::usize_promote;
221        #[cfg(target_pointer_width = "32")]
222        pub(crate) use crate::vectors::arch_simd::_256bit::usizex8::usize_promote;
223        pub(crate) use crate::vectors::arch_simd::_256bit::{
224            bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
225            cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
226            f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
227            i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
228            u8x32::u8_promote,
229        };
230    }
231    #[cfg(any(
232        all(not(target_feature = "avx2"), target_feature = "sse"),
233        target_arch = "arm",
234        target_arch = "aarch64",
235        target_feature = "neon"
236    ))]
237    pub(crate) mod vector_promote {
238        #[cfg(target_pointer_width = "64")]
239        pub(crate) use crate::vectors::arch_simd::_128bit::isizex2::isize_promote;
240        #[cfg(target_pointer_width = "32")]
241        pub(crate) use crate::vectors::arch_simd::_128bit::isizex4::isize_promote;
242        #[cfg(target_pointer_width = "64")]
243        pub(crate) use crate::vectors::arch_simd::_128bit::usizex2::usize_promote;
244        #[cfg(target_pointer_width = "32")]
245        pub(crate) use crate::vectors::arch_simd::_128bit::usizex4::usize_promote;
246        pub(crate) use crate::vectors::arch_simd::_128bit::{
247            bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
248            cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
249            f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
250            i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
251            u8x16::u8_promote,
252        };
253    }
254}
255
/// A module that defines a set of types for cuda
///
/// Only compiled when the `cuda` cargo feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda_types {
    /// A module that defines conversion for cuda types
    pub mod convertion;
    /// A module that defines a scalar type for cuda
    pub mod scalar;

    // Crate-private per-type submodules.
    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _cplx32;
    pub(crate) mod _cplx64;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod _i16;
    pub(crate) mod _i32;
    pub(crate) mod _i64;
    pub(crate) mod _i8;
    pub(crate) mod _isize;
    pub(crate) mod _u16;
    pub(crate) mod _u32;
    pub(crate) mod _u64;
    pub(crate) mod _u8;
    pub(crate) mod _usize;
}

283pub use vectors::*;
284mod simd {
285    pub use crate::vectors::arch_simd::*;
286}
287
/// SIMD register type aliases for x86_64 builds with AVX2 enabled.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128i, __m256, __m256d, __m256i};

    /// Double-precision float vector (`__m256d`).
    pub(crate) type VDouble = __m256d;
    /// Mask vector (`__m256i`).
    pub(crate) type VMask = __m256i;
    /// Operation-mask vector (`__m256i`).
    pub(crate) type Vopmask = __m256i;
    /// Single-precision float vector (`__m256`).
    pub(crate) type VFloat = __m256;
    /// Integer vector; note this is the 128-bit `__m128i`, not a 256-bit type.
    pub(crate) type VInt = __m128i;
    /// Second integer vector type (`__m256i`).
    pub(crate) type VInt2 = __m256i;
    /// Signed 64-bit integer vector (`__m256i`).
    pub(crate) type VInt64 = __m256i;
    /// Unsigned 64-bit integer vector (`__m256i`).
    pub(crate) type VUInt64 = __m256i;
}

/// SIMD register type aliases for x86_64 builds with SSE enabled and AVX2 disabled.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse",
    not(target_feature = "avx2")
))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128, __m128d, __m128i};

    /// Double-precision float vector (`__m128d`).
    pub(crate) type VDouble = __m128d;
    /// Mask vector (`__m128i`).
    pub(crate) type VMask = __m128i;
    /// Operation-mask vector (`__m128i`).
    pub(crate) type Vopmask = __m128i;
    /// Single-precision float vector (`__m128`).
    pub(crate) type VFloat = __m128;
    /// Integer vector (`__m128i`).
    pub(crate) type VInt = __m128i;
    /// Second integer vector type (`__m128i`).
    pub(crate) type VInt2 = __m128i;
}

/// SIMD register type aliases for aarch64 builds with NEON enabled.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod sleef_types {
    use std::arch::aarch64::{float32x4_t, float64x2_t, int32x2_t, int32x4_t, uint32x4_t};

    /// Double-precision float vector (`float64x2_t`).
    pub(crate) type VDouble = float64x2_t;
    /// Mask vector (`uint32x4_t`).
    pub(crate) type VMask = uint32x4_t;
    /// Operation-mask vector (`uint32x4_t`).
    pub(crate) type Vopmask = uint32x4_t;
    /// Single-precision float vector (`float32x4_t`).
    pub(crate) type VFloat = float32x4_t;
    /// Integer vector; note this is the 64-bit `int32x2_t`.
    pub(crate) type VInt = int32x2_t;
    /// Second integer vector type (`int32x4_t`).
    pub(crate) type VInt2 = int32x4_t;
}