hpt_types/
lib.rs

1//! This crate implment type utilities for tensor operations
2#![deny(missing_docs)]
3
4/// A module implement type conversion
5pub mod convertion;
6/// A module defines a set of data types and utilities
7pub mod dtype;
8/// A module implement type conversion
9pub mod into_scalar;
10/// A module implement simd vector conversion
11pub mod into_vec;
12/// A module defines a set of traits for tensor operations, and implement computation functions for scalar and vector types
13pub mod type_promote;
14/// A module defines a set of traits for scalar operations
15pub(crate) mod scalars {
16    pub(crate) mod _bf16;
17    pub(crate) mod _bool;
18    pub(crate) mod _f16;
19    pub(crate) mod _f32;
20    pub(crate) mod _f64;
21    pub(crate) mod impls;
22}
23
24/// A module defines a set of traits for type promotion
25pub mod promotion {
26    #[cfg(feature = "normal_promote")]
27    pub(crate) mod normal_promote {
28        pub(crate) mod _bf16;
29        pub(crate) mod _bool;
30        pub(crate) mod _cplx32;
31        pub(crate) mod _cplx64;
32        pub(crate) mod _f16;
33        pub(crate) mod _f32;
34        pub(crate) mod _f64;
35        pub(crate) mod _i16;
36        pub(crate) mod _i32;
37        pub(crate) mod _i64;
38        pub(crate) mod _i8;
39        pub(crate) mod _isize;
40        pub(crate) mod _u16;
41        pub(crate) mod _u32;
42        pub(crate) mod _u64;
43        pub(crate) mod _u8;
44        pub(crate) mod _usize;
45    }
46    pub(crate) mod utils;
47}
48
49/// A module defines a set of vector types
50pub mod vectors {
51    /// A module defines a set of vector types using stdsimd
52    pub mod arch_simd {
53        /// A module defines a set of 128-bit vector types
54        #[cfg(any(
55            all(not(target_feature = "avx2"), target_feature = "sse"),
56            target_arch = "arm",
57            target_arch = "aarch64",
58            target_feature = "neon"
59        ))]
60        pub mod _128bit {
61            /// A module defines a set of 128-bit vector types for bf16
62            pub mod bf16x8;
63            /// A module defines a set of 128-bit vector types for bool
64            pub mod boolx16;
65            /// A module defines a set of 128-bit vector types for cplx32
66            pub mod cplx32x2;
67            /// A module defines a set of 128-bit vector types for cplx64
68            pub mod cplx64x1;
69            /// A module defines a set of 128-bit vector types for f16
70            pub mod f16x8;
71            /// A module defines a set of 128-bit vector types for f32
72            pub mod f32x4;
73            /// A module defines a set of 128-bit vector types for f64
74            pub mod f64x2;
75            /// A module defines a set of 128-bit vector types for i16
76            pub mod i16x8;
77            /// A module defines a set of 128-bit vector types for i32
78            pub mod i32x4;
79            /// A module defines a set of 128-bit vector types for i64
80            pub mod i64x2;
81            /// A module defines a set of 128-bit vector types for i8
82            pub mod i8x16;
83            /// A module defines a set of 128-bit vector types for isize
84            pub mod isizex2;
85            /// A module defines a set of 128-bit vector types for u16
86            pub mod u16x8;
87            /// A module defines a set of 128-bit vector types for u32
88            pub mod u32x4;
89            /// A module defines a set of 128-bit vector types for u64
90            pub mod u64x2;
91            /// A module defines a set of 128-bit vector types for u8
92            pub mod u8x16;
93            /// A module defines a set of 128-bit vector types for usize
94            pub mod usizex2;
95        }
96        /// A module defines a set of 256-bit vector types
97        #[cfg(target_feature = "avx2")]
98        pub mod _256bit {
99            /// A module defines a set of 256-bit vector types for bf16
100            pub mod bf16x16;
101            /// A module defines a set of 256-bit vector types for bool
102            pub mod boolx32;
103            /// A module defines a set of 256-bit vector types for cplx32
104            pub mod cplx32x4;
105            /// A module defines a set of 256-bit vector types for cplx64
106            pub mod cplx64x2;
107            /// A module defines a set of 256-bit vector types for f16
108            pub mod f16x16;
109            /// A module defines a set of 256-bit vector types for f32
110            pub mod f32x8;
111            /// A module defines a set of 256-bit vector types for f64
112            pub mod f64x4;
113            /// A module defines a set of 256-bit vector types for i16
114            pub mod i16x16;
115            /// A module defines a set of 256-bit vector types for i32
116            pub mod i32x8;
117            /// A module defines a set of 256-bit vector types for i64
118            pub mod i64x4;
119            /// A module defines a set of 256-bit vector types for i8
120            pub mod i8x32;
121            /// A module defines a set of 256-bit vector types for isize
122            pub mod isizex4;
123            /// A module defines a set of 256-bit vector types for u16
124            pub mod u16x16;
125            /// A module defines a set of 256-bit vector types for u32
126            pub mod u32x8;
127            /// A module defines a set of 256-bit vector types for u64
128            pub mod u64x4;
129            /// A module defines a set of 256-bit vector types for u8
130            pub mod u8x32;
131            /// A module defines a set of 256-bit vector types for usize
132            pub mod usizex4;
133        }
134        /// A module defines a set of 512-bit vector types
135        #[cfg(target_feature = "avx512f")]
136        pub mod _512bit {
137            /// A module defines a set of 512-bit vector types for bf16
138            pub mod bf16x32;
139            /// A module defines a set of 512-bit vector types for bool
140            pub mod boolx64;
141            /// A module defines a set of 512-bit vector types for cplx32
142            pub mod cplx32x8;
143            /// A module defines a set of 512-bit vector types for cplx64
144            pub mod cplx64x4;
145            /// A module defines a set of 512-bit vector types for f16
146            pub mod f16x32;
147            /// A module defines a set of 512-bit vector types for f32
148            pub mod f32x16;
149            /// A module defines a set of 512-bit vector types for f64
150            pub mod f64x8;
151            /// A module defines a set of 512-bit vector types for i16
152            pub mod i16x32;
153            /// A module defines a set of 512-bit vector types for i32
154            pub mod i32x16;
155            /// A module defines a set of 512-bit vector types for i64
156            pub mod i64x8;
157            /// A module defines a set of 512-bit vector types for i8
158            pub mod i8x64;
159            /// A module defines a set of 512-bit vector types for isize
160            pub mod isizex8;
161            /// A module defines a set of 512-bit vector types for u16
162            pub mod u16x32;
163            /// A module defines a set of 512-bit vector types for u32
164            pub mod u32x16;
165            /// A module defines a set of 512-bit vector types for u64
166            pub mod u64x8;
167            /// A module defines a set of 512-bit vector types for u8
168            pub mod u8x64;
169            /// A module defines a set of 512-bit vector types for usize
170            pub mod usizex8;
171        }
172
173        // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
174        //
175        // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
176        // Modified work Copyright (c) 2024 hpt Contributors
177        //
178        // Boost Software License - Version 1.0 - August 17th, 2003
179        //
180        // Permission is hereby granted, free of charge, to any person or organization
181        // obtaining a copy of the software and accompanying documentation covered by
182        // this license (the "Software") to use, reproduce, display, distribute,
183        // execute, and transmit the Software, and to prepare derivative works of the
184        // Software, and to permit third-parties to whom the Software is furnished to
185        // do so, all subject to the following:
186        //
187        // The copyright notices in the Software and this entire statement, including
188        // the above license grant, this restriction and the following disclaimer,
189        // must be included in all copies of the Software, in whole or in part, and
190        // all derivative works of the Software, unless such copies or derivative
191        // works are solely in the form of machine-executable object code generated by
192        // a source language processor.
193        //
194        // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
195        // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
196        // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
197        // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
198        // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
199        // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
200        // DEALINGS IN THE SOFTWARE.
201        //
202        // This Rust port is additionally licensed under Apache-2.0 OR MIT
203        // See repository root for details
204        /// A module defines a set of vector types for sleef
205        pub mod sleef {
206            /// A module defines a set of vector types for table
207            pub mod table;
208            /// A module defines a set of vector types for helper
209            pub mod arch {
210                /// A module defines a set of vector types for helper
211                #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
212                pub mod helper_aarch64;
213                /// A module defines a set of vector types for helper
214                #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
215                pub mod helper_avx2;
216                /// A module defines a set of vector types for helper
217                #[cfg(all(
218                    target_arch = "x86_64",
219                    target_feature = "sse",
220                    not(target_feature = "avx2")
221                ))]
222                pub mod helper_sse;
223            }
224            /// A module defines a set of vector types for common
225            pub mod common {
226                /// A module defines a set of vector types for common
227                pub mod commonfuncs;
228                /// A module defines a set of vector types for common
229                pub mod dd;
230                /// A module defines a set of vector types for common
231                pub mod df;
232                /// A module defines a macro for polynomial approximation
233                pub mod estrin;
234                /// A module defines a set of vector types for common
235                pub mod misc;
236            }
237            /// A module defines a set of vector types for libm
238            pub mod libm {
239                /// a module defins a set of double precision floating point functions
240                pub mod sleefsimddp;
241                /// a module defins a set of single precision floating point functions
242                pub mod sleefsimdsp;
243            }
244        }
245    }
246    /// A module defines a set of traits for vector
247    pub mod traits;
248    /// A module defines a set of utils for vector
249    pub mod utils;
250
251    #[cfg(target_feature = "avx2")]
252    pub(crate) mod vector_promote {
253        #[cfg(target_pointer_width = "64")]
254        pub(crate) use crate::vectors::arch_simd::_256bit::isizex4::isize_promote;
255        #[cfg(target_pointer_width = "32")]
256        pub(crate) use crate::vectors::arch_simd::_256bit::isizex8::isize_promote;
257        #[cfg(target_pointer_width = "64")]
258        pub(crate) use crate::vectors::arch_simd::_256bit::usizex4::usize_promote;
259        #[cfg(target_pointer_width = "32")]
260        pub(crate) use crate::vectors::arch_simd::_256bit::usizex8::usize_promote;
261        pub(crate) use crate::vectors::arch_simd::_256bit::{
262            bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
263            cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
264            f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
265            i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
266            u8x32::u8_promote,
267        };
268    }
269    #[cfg(any(
270        all(not(target_feature = "avx2"), target_feature = "sse"),
271        target_arch = "arm",
272        target_arch = "aarch64",
273        target_feature = "neon"
274    ))]
275    pub(crate) mod vector_promote {
276        #[cfg(target_pointer_width = "64")]
277        pub(crate) use crate::vectors::arch_simd::_128bit::isizex2::isize_promote;
278        #[cfg(target_pointer_width = "32")]
279        pub(crate) use crate::vectors::arch_simd::_128bit::isizex4::isize_promote;
280        #[cfg(target_pointer_width = "64")]
281        pub(crate) use crate::vectors::arch_simd::_128bit::usizex2::usize_promote;
282        #[cfg(target_pointer_width = "32")]
283        pub(crate) use crate::vectors::arch_simd::_128bit::usizex4::usize_promote;
284        pub(crate) use crate::vectors::arch_simd::_128bit::{
285            bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
286            cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
287            f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
288            i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
289            u8x16::u8_promote,
290        };
291    }
292    #[cfg(target_feature = "avx512f")]
293    pub(crate) mod vector_promote {
294        #[cfg(target_pointer_width = "32")]
295        pub(crate) use crate::vectors::arch_simd::_512bit::isizex16::isize_promote;
296        #[cfg(target_pointer_width = "64")]
297        pub(crate) use crate::vectors::arch_simd::_512bit::isizex8::isize_promote;
298        #[cfg(target_pointer_width = "32")]
299        pub(crate) use crate::vectors::arch_simd::_512bit::usizex16::usize_promote;
300        #[cfg(target_pointer_width = "64")]
301        pub(crate) use crate::vectors::arch_simd::_512bit::usizex8::usize_promote;
302        pub(crate) use crate::vectors::arch_simd::_512bit::{
303            bf16x32::bf16_promote, boolx64::bool_promote, cplx32x8::Complex32_promote,
304            cplx64x4::Complex64_promote, f16x32::f16_promote, f32x16::f32_promote,
305            f64x8::f64_promote, i16x32::i16_promote, i32x16::i32_promote, i64x8::i64_promote,
306            i8x64::i8_promote, u16x32::u16_promote, u32x16::u32_promote, u64x8::u64_promote,
307            u8x64::u8_promote,
308        };
309    }
310}
311
/// Defines a set of types for CUDA.
///
/// Only compiled when the `cuda` cargo feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda_types {
    /// Defines conversions for CUDA types.
    pub mod convertion;
    /// Defines a scalar type for CUDA.
    pub mod scalar;

    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _cplx32;
    pub(crate) mod _cplx64;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod _i16;
    pub(crate) mod _i32;
    pub(crate) mod _i64;
    pub(crate) mod _i8;
    pub(crate) mod _isize;
    pub(crate) mod _u16;
    pub(crate) mod _u32;
    pub(crate) mod _u64;
    pub(crate) mod _u8;
    pub(crate) mod _usize;
}
338
339pub use vectors::*;
340mod simd {
341    pub use crate::vectors::arch_simd::*;
342}
343
/// Vector type aliases for x86_64 builds with AVX2 enabled.
///
/// All aliases are 256-bit registers except `VInt`, which is 128-bit.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128i, __m256, __m256d, __m256i};

    /// 256-bit vector of `f64` lanes.
    pub(crate) type VDouble = __m256d;
    /// 256-bit integer vector.
    pub(crate) type VMask = __m256i;
    /// 256-bit integer vector.
    pub(crate) type Vopmask = __m256i;
    /// 256-bit vector of `f32` lanes.
    pub(crate) type VFloat = __m256;
    /// 128-bit integer vector.
    pub(crate) type VInt = __m128i;
    /// 256-bit integer vector.
    pub(crate) type VInt2 = __m256i;
    /// 256-bit integer vector.
    pub(crate) type VInt64 = __m256i;
    /// 256-bit integer vector.
    pub(crate) type VUInt64 = __m256i;
}
356
/// Vector type aliases for x86_64 builds with SSE enabled and AVX2 disabled.
///
/// Every alias is a 128-bit register.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse",
    not(target_feature = "avx2")
))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128, __m128d, __m128i};

    /// 128-bit vector of `f64` lanes.
    pub(crate) type VDouble = __m128d;
    /// 128-bit integer vector.
    pub(crate) type VMask = __m128i;
    /// 128-bit integer vector.
    pub(crate) type Vopmask = __m128i;
    /// 128-bit vector of `f32` lanes.
    pub(crate) type VFloat = __m128;
    /// 128-bit integer vector.
    pub(crate) type VInt = __m128i;
    /// 128-bit integer vector.
    pub(crate) type VInt2 = __m128i;
}
371
/// Vector type aliases for AArch64 builds with NEON enabled.
///
/// All aliases are 128-bit registers except `VInt`, which is 64-bit.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod sleef_types {
    use std::arch::aarch64::{float32x4_t, float64x2_t, int32x2_t, int32x4_t, uint32x4_t};

    /// Two `f64` lanes.
    pub(crate) type VDouble = float64x2_t;
    /// Four `u32` lanes.
    pub(crate) type VMask = uint32x4_t;
    /// Four `u32` lanes.
    pub(crate) type Vopmask = uint32x4_t;
    /// Four `f32` lanes.
    pub(crate) type VFloat = float32x4_t;
    /// Two `i32` lanes.
    pub(crate) type VInt = int32x2_t;
    /// Four `i32` lanes.
    pub(crate) type VInt2 = int32x4_t;
}