// hpt_types/lib.rs

1//! This crate implment type utilities for tensor operations
2#![deny(missing_docs)]
3
4/// A module implement type conversion
5pub mod convertion;
6/// A module defines a set of data types and utilities
7pub mod dtype;
8/// A module implement type conversion
9pub mod into_scalar;
10/// A module implement simd vector conversion
11pub mod into_vec;
12/// A module defines a set of traits for tensor operations, and implement computation functions for scalar and vector types
13pub mod type_promote;
14/// A module defines a set of traits for scalar operations
15pub(crate) mod scalars {
16    pub(crate) mod _bf16;
17    pub(crate) mod _bool;
18    pub(crate) mod _f16;
19    pub(crate) mod _f32;
20    pub(crate) mod _f64;
21    pub(crate) mod impls;
22}
23
24/// A module defines a set of traits for type promotion
25pub mod promotion {
26    #[cfg(feature = "normal_promote")]
27    pub(crate) mod normal_promote {
28        pub(crate) mod _bf16;
29        pub(crate) mod _bool;
30        pub(crate) mod _cplx32;
31        pub(crate) mod _cplx64;
32        pub(crate) mod _f16;
33        pub(crate) mod _f32;
34        pub(crate) mod _f64;
35        pub(crate) mod _i16;
36        pub(crate) mod _i32;
37        pub(crate) mod _i64;
38        pub(crate) mod _i8;
39        pub(crate) mod _isize;
40        pub(crate) mod _u16;
41        pub(crate) mod _u32;
42        pub(crate) mod _u64;
43        pub(crate) mod _u8;
44        pub(crate) mod _usize;
45    }
46    pub(crate) mod utils;
47}
48
49/// A module defines a set of vector types
50pub mod vectors {
51    /// A module defines a set of vector types using stdsimd
52    pub mod arch_simd {
53        /// A module defines a set of 128-bit vector types
54        #[cfg(any(
55            all(not(target_feature = "avx2"), target_feature = "sse"),
56            target_arch = "arm",
57            target_arch = "aarch64",
58            target_feature = "neon"
59        ))]
60        pub mod _128bit {
61            pub(crate) mod common {
62                pub(crate) mod bf16x8;
63                pub(crate) mod boolx16;
64                pub(crate) mod cplx32x2;
65                pub(crate) mod cplx64x1;
66                pub(crate) mod f16x8;
67                pub(crate) mod f32x4;
68                pub(crate) mod f64x2;
69                pub(crate) mod i16x8;
70                pub(crate) mod i32x4;
71                pub(crate) mod i64x2;
72                pub(crate) mod i8x16;
73                pub(crate) mod isizex2;
74                pub(crate) mod u16x8;
75                pub(crate) mod u32x4;
76                pub(crate) mod u64x2;
77                pub(crate) mod u8x16;
78                pub(crate) mod usizex2;
79            }
80
81            #[cfg(target_feature = "neon")]
82            #[cfg(target_arch = "aarch64")]
83            pub(crate) mod neon {
84                pub(crate) mod bf16x8;
85                pub(crate) mod boolx16;
86                pub(crate) mod f16x8;
87                pub(crate) mod f32x4;
88                pub(crate) mod f64x2;
89                pub(crate) mod i16x8;
90                pub(crate) mod i32x4;
91                pub(crate) mod i64x2;
92                pub(crate) mod i8x16;
93                pub(crate) mod u16x8;
94                pub(crate) mod u32x4;
95                pub(crate) mod u64x2;
96                pub(crate) mod u8x16;
97            }
98
99            #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
100            pub(crate) mod sse {
101                pub(crate) mod bf16x8;
102                pub(crate) mod boolx16;
103                pub(crate) mod f16x8;
104                pub(crate) mod f32x4;
105                pub(crate) mod f64x2;
106                pub(crate) mod i16x8;
107                pub(crate) mod i32x4;
108                pub(crate) mod i64x2;
109                pub(crate) mod i8x16;
110                pub(crate) mod u16x8;
111                pub(crate) mod u32x4;
112                pub(crate) mod u64x2;
113                pub(crate) mod u8x16;
114            }
115
116            pub use crate::arch_simd::_128bit::common::{
117                bf16x8::bf16x8, boolx16::boolx16, cplx32x2::cplx32x2, cplx64x1::cplx64x1,
118                f16x8::f16x8, f32x4::f32x4, f64x2::f64x2, i16x8::i16x8, i32x4::i32x4, i64x2::i64x2,
119                i8x16::i8x16, isizex2::isizex2, u16x8::u16x8, u32x4::u32x4, u64x2::u64x2,
120                u8x16::u8x16, usizex2::usizex2,
121            };
122        }
123        /// A module defines a set of 256-bit vector types
124        #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
125        pub mod _256bit {
126            pub(crate) mod common {
127                pub(crate) mod bf16x16;
128                pub(crate) mod boolx32;
129                pub(crate) mod cplx32x4;
130                pub(crate) mod cplx64x2;
131                pub(crate) mod f16x16;
132                pub(crate) mod f32x8;
133                pub(crate) mod f64x4;
134                pub(crate) mod i16x16;
135                pub(crate) mod i32x8;
136                pub(crate) mod i64x4;
137                pub(crate) mod i8x32;
138                pub(crate) mod isizex4;
139                pub(crate) mod u16x16;
140                pub(crate) mod u32x8;
141                pub(crate) mod u64x4;
142                pub(crate) mod u8x32;
143                pub(crate) mod usizex4;
144            }
145            #[cfg(target_feature = "avx2")]
146            pub(crate) mod avx2 {
147                pub(crate) mod bf16x16;
148                pub(crate) mod f16x16;
149                pub(crate) mod f32x8;
150                pub(crate) mod f64x4;
151                pub(crate) mod i16x16;
152                pub(crate) mod i32x8;
153                pub(crate) mod i64x4;
154                pub(crate) mod i8x32;
155                pub(crate) mod u16x16;
156                pub(crate) mod u32x8;
157                pub(crate) mod u64x4;
158                pub(crate) mod u8x32;
159            }
160
161            pub use crate::arch_simd::_256bit::common::{
162                bf16x16::bf16x16, boolx32::boolx32, cplx32x4::cplx32x4, cplx64x2::cplx64x2,
163                f16x16::f16x16, f32x8::f32x8, f64x4::f64x4, i16x16::i16x16, i32x8::i32x8,
164                i64x4::i64x4, i8x32::i8x32, isizex4::isizex4, u16x16::u16x16, u32x8::u32x8,
165                u64x4::u64x4, u8x32::u8x32, usizex4::usizex4,
166            };
167        }
168
169        // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
170        //
171        // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
172        // Modified work Copyright (c) 2024 hpt Contributors
173        //
174        // Boost Software License - Version 1.0 - August 17th, 2003
175        //
176        // Permission is hereby granted, free of charge, to any person or organization
177        // obtaining a copy of the software and accompanying documentation covered by
178        // this license (the "Software") to use, reproduce, display, distribute,
179        // execute, and transmit the Software, and to prepare derivative works of the
180        // Software, and to permit third-parties to whom the Software is furnished to
181        // do so, all subject to the following:
182        //
183        // The copyright notices in the Software and this entire statement, including
184        // the above license grant, this restriction and the following disclaimer,
185        // must be included in all copies of the Software, in whole or in part, and
186        // all derivative works of the Software, unless such copies or derivative
187        // works are solely in the form of machine-executable object code generated by
188        // a source language processor.
189        //
190        // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
191        // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
192        // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
193        // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
194        // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
195        // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
196        // DEALINGS IN THE SOFTWARE.
197        //
198        // This Rust port is additionally licensed under Apache-2.0 OR MIT
199        // See repository root for details
200        /// A module defines a set of vector types for sleef
201        #[allow(clippy::approx_constant)]
202        #[allow(clippy::excessive_precision)]
203        #[allow(clippy::unreadable_literal)]
204        pub mod sleef {
205            /// A module defines a set of vector types for table
206            pub mod table;
207            /// A module defines a set of vector types for helper
208            pub mod arch {
209                /// A module defines a set of vector types for helper
210                #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
211                pub mod helper_aarch64;
212                /// A module defines a set of vector types for helper
213                #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
214                pub mod helper_avx2;
215                /// A module defines a set of vector types for helper
216                #[cfg(all(
217                    target_arch = "x86_64",
218                    target_feature = "sse",
219                    not(target_feature = "avx2")
220                ))]
221                pub mod helper_sse;
222            }
223            /// A module defines a set of vector types for common
224            pub mod common {
225                /// A module defines a set of vector types for common
226                pub mod commonfuncs;
227                /// A module defines a set of vector types for common
228                pub mod dd;
229                /// A module defines a set of vector types for common
230                pub mod df;
231                /// A module defines a macro for polynomial approximation
232                pub mod estrin;
233                /// A module defines a set of vector types for common
234                pub mod misc;
235            }
236            /// A module defines a set of vector types for libm
237            pub mod libm {
238                /// a module defins a set of double precision floating point functions
239                pub mod sleefsimddp;
240                /// a module defins a set of single precision floating point functions
241                pub mod sleefsimdsp;
242            }
243        }
244    }
245    /// A module defines a set of traits for vector
246    pub mod traits;
247    /// A module defines a set of utils for vector
248    pub mod utils;
249
250    #[cfg(target_feature = "avx2")]
251    pub(crate) mod vector_promote {
252        #[cfg(target_pointer_width = "64")]
253        pub(crate) use crate::vectors::arch_simd::_256bit::common::isizex4::isize_promote;
254        #[cfg(target_pointer_width = "32")]
255        pub(crate) use crate::vectors::arch_simd::_256bit::common::isizex8::isize_promote;
256        #[cfg(target_pointer_width = "64")]
257        pub(crate) use crate::vectors::arch_simd::_256bit::common::usizex4::usize_promote;
258        #[cfg(target_pointer_width = "32")]
259        pub(crate) use crate::vectors::arch_simd::_256bit::common::usizex8::usize_promote;
260        pub(crate) use crate::vectors::arch_simd::_256bit::common::{
261            bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
262            cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
263            f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
264            i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
265            u8x32::u8_promote,
266        };
267    }
268    #[cfg(any(
269        all(not(target_feature = "avx2"), target_feature = "sse"),
270        target_arch = "arm",
271        target_arch = "aarch64",
272        target_feature = "neon"
273    ))]
274    pub(crate) mod vector_promote {
275        #[cfg(target_pointer_width = "64")]
276        pub(crate) use crate::vectors::arch_simd::_128bit::common::isizex2::isize_promote;
277        #[cfg(target_pointer_width = "32")]
278        pub(crate) use crate::vectors::arch_simd::_128bit::common::isizex4::isize_promote;
279        #[cfg(target_pointer_width = "64")]
280        pub(crate) use crate::vectors::arch_simd::_128bit::common::usizex2::usize_promote;
281        #[cfg(target_pointer_width = "32")]
282        pub(crate) use crate::vectors::arch_simd::_128bit::common::usizex4::usize_promote;
283        pub(crate) use crate::vectors::arch_simd::_128bit::common::{
284            bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
285            cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
286            f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
287            i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
288            u8x16::u8_promote,
289        };
290    }
291}
292
/// A module that defines a set of types for cuda
///
/// Only compiled when the `cuda` cargo feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda_types {
    /// A module that defines conversion for cuda types
    pub mod convertion;
    /// A module that defines a scalar type for cuda
    pub mod scalar;

    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _cplx32;
    pub(crate) mod _cplx64;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod _i16;
    pub(crate) mod _i32;
    pub(crate) mod _i64;
    pub(crate) mod _i8;
    pub(crate) mod _isize;
    pub(crate) mod _u16;
    pub(crate) mod _u32;
    pub(crate) mod _u64;
    pub(crate) mod _u8;
    pub(crate) mod _usize;
}

320pub use vectors::*;
321mod simd {
322    pub use crate::vectors::arch_simd::*;
323}
324
/// SLEEF-style vector type aliases for `x86_64` targets with AVX2.
///
/// Each alias maps a SLEEF type name onto the `std::arch::x86_64` register type used for it on
/// this target.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
pub(crate) mod sleef_types {
    // Import only the register types that are aliased below instead of the whole intrinsics set.
    use std::arch::x86_64::{__m128i, __m256, __m256d, __m256i};

    pub(crate) type VDouble = __m256d;
    pub(crate) type VMask = __m256i;
    pub(crate) type Vopmask = __m256i;
    pub(crate) type VFloat = __m256;
    // `VInt` is deliberately the 128-bit integer type here, unlike the other aliases.
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m256i;
    pub(crate) type VInt64 = __m256i;
    pub(crate) type VUInt64 = __m256i;
}

/// SLEEF-style vector type aliases for `x86_64` targets with SSE but without AVX2.
///
/// Each alias maps a SLEEF type name onto the 128-bit `std::arch::x86_64` register type used
/// for it on this target.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse",
    not(target_feature = "avx2")
))]
pub(crate) mod sleef_types {
    // Import only the register types that are aliased below instead of the whole intrinsics set.
    use std::arch::x86_64::{__m128, __m128d, __m128i};

    pub(crate) type VDouble = __m128d;
    pub(crate) type VMask = __m128i;
    pub(crate) type Vopmask = __m128i;
    pub(crate) type VFloat = __m128;
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m128i;
}

/// SLEEF-style vector type aliases for `aarch64` targets with NEON.
///
/// Each alias maps a SLEEF type name onto the `std::arch::aarch64` vector type used for it on
/// this target.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod sleef_types {
    // Import only the vector types that are aliased below instead of the whole intrinsics set.
    use std::arch::aarch64::{float32x4_t, float64x2_t, int32x2_t, int32x4_t, uint32x4_t};

    pub(crate) type VDouble = float64x2_t;
    pub(crate) type VMask = uint32x4_t;
    pub(crate) type Vopmask = uint32x4_t;
    pub(crate) type VFloat = float32x4_t;
    // `VInt` is the 64-bit (two-lane) integer vector here, unlike the 128-bit aliases above.
    pub(crate) type VInt = int32x2_t;
    pub(crate) type VInt2 = int32x4_t;
}