hpt_types/
lib.rs

//! This crate implements type utilities for tensor operations.

// `portable_simd` is the (nightly-only) feature gate for `std::simd`; it is enabled only
// when the `stdsimd` backend is selected.
#![cfg_attr(feature = "stdsimd", feature(portable_simd))]
// Every public item of this crate must carry documentation.
#![deny(missing_docs)]

/// A module that defines a set of data types and utilities.
pub mod dtype;
/// Re-export of the `half` crate, so dependents can use the exact version this crate is
/// built against.
pub extern crate half;
/// A module that implements type conversion.
pub mod convertion;
/// A module that implements type conversion.
pub mod into_scalar;
/// A module that implements SIMD vector conversion.
pub mod into_vec;
/// A module that defines a set of traits for tensor operations and implements computation
/// functions for scalar and vector types.
pub mod type_promote;
/// A module that defines a set of traits for scalar operations.
///
/// Crate-private. Its submodules are named after scalar types (`_bf16`, `_bool`, `_f16`,
/// `_f32`, `_f64`), plus an `impls` submodule.
pub(crate) mod scalars {
    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod impls;
}

/// A module that defines a set of traits for type promotion.
pub mod promotion {
    // NOTE(review): every item declared here is `pub(crate)`, so this `pub` module exposes
    // nothing outside the crate — confirm that is intended.

    // One submodule per element type (named after it); compiled only when the
    // `normal_promote` feature is enabled.
    #[cfg(feature = "normal_promote")]
    pub(crate) mod normal_promote {
        pub(crate) mod _bf16;
        pub(crate) mod _bool;
        pub(crate) mod _cplx32;
        pub(crate) mod _cplx64;
        pub(crate) mod _f16;
        pub(crate) mod _f32;
        pub(crate) mod _f64;
        pub(crate) mod _i16;
        pub(crate) mod _i32;
        pub(crate) mod _i64;
        pub(crate) mod _i8;
        pub(crate) mod _isize;
        pub(crate) mod _u16;
        pub(crate) mod _u32;
        pub(crate) mod _u64;
        pub(crate) mod _u8;
        pub(crate) mod _usize;
    }
    pub(crate) mod utils;
}

52/// A module defines a set of vector types
53pub mod vectors {
54    /// A module defines a set of vector types using stdsimd
55    #[cfg(feature = "stdsimd")]
56    pub mod std_simd {
57        /// A module defines a set of 128-bit vector types
58        #[cfg(any(
59            all(not(target_feature = "avx2"), target_feature = "sse"),
60            target_arch = "arm",
61            target_arch = "aarch64",
62            target_feature = "neon"
63        ))]
64        pub mod _128bit {
65            /// A module defines a set of 128-bit vector types for bf16
66            pub mod bf16x8;
67            /// A module defines a set of 128-bit vector types for bool
68            pub mod boolx16;
69            /// A module defines a set of 128-bit vector types for cplx32
70            pub mod cplx32x2;
71            /// A module defines a set of 128-bit vector types for cplx64
72            pub mod cplx64x1;
73            /// A module defines a set of 128-bit vector types for f16
74            pub mod f16x8;
75            /// A module defines a set of 128-bit vector types for f32
76            pub mod f32x4;
77            /// A module defines a set of 128-bit vector types for f64
78            pub mod f64x2;
79            /// A module defines a set of 128-bit vector types for i16
80            pub mod i16x8;
81            /// A module defines a set of 128-bit vector types for i32
82            pub mod i32x4;
83            /// A module defines a set of 128-bit vector types for i64
84            pub mod i64x2;
85            /// A module defines a set of 128-bit vector types for i8
86            pub mod i8x16;
87            /// A module defines a set of 128-bit vector types for isize
88            pub mod isizex2;
89            /// A module defines a set of 128-bit vector types for u16
90            pub mod u16x8;
91            /// A module defines a set of 128-bit vector types for u32
92            pub mod u32x4;
93            /// A module defines a set of 128-bit vector types for u64
94            pub mod u64x2;
95            /// A module defines a set of 128-bit vector types for u8
96            pub mod u8x16;
97            /// A module defines a set of 128-bit vector types for usize
98            pub mod usizex2;
99        }
100        /// A module defines a set of 256-bit vector types
101        #[cfg(target_feature = "avx2")]
102        pub mod _256bit {
103            /// A module defines a set of 256-bit vector types for bf16
104            pub mod bf16x16;
105            /// A module defines a set of 256-bit vector types for bool
106            pub mod boolx32;
107            /// A module defines a set of 256-bit vector types for cplx32
108            pub mod cplx32x4;
109            /// A module defines a set of 256-bit vector types for cplx64
110            pub mod cplx64x2;
111            /// A module defines a set of 256-bit vector types for f16
112            pub mod f16x16;
113            /// A module defines a set of 256-bit vector types for f32
114            pub mod f32x8;
115            /// A module defines a set of 256-bit vector types for f64
116            pub mod f64x4;
117            /// A module defines a set of 256-bit vector types for i16
118            pub mod i16x16;
119            /// A module defines a set of 256-bit vector types for i32
120            pub mod i32x8;
121            /// A module defines a set of 256-bit vector types for i64
122            pub mod i64x4;
123            /// A module defines a set of 256-bit vector types for i8
124            pub mod i8x32;
125            /// A module defines a set of 256-bit vector types for isize
126            pub mod isizex4;
127            /// A module defines a set of 256-bit vector types for u16
128            pub mod u16x16;
129            /// A module defines a set of 256-bit vector types for u32
130            pub mod u32x8;
131            /// A module defines a set of 256-bit vector types for u64
132            pub mod u64x4;
133            /// A module defines a set of 256-bit vector types for u8
134            pub mod u8x32;
135            /// A module defines a set of 256-bit vector types for usize
136            pub mod usizex4;
137        }
138        /// A module defines a set of 512-bit vector types
139        #[cfg(target_feature = "avx512f")]
140        pub mod _512bit {
141            /// A module defines a set of 512-bit vector types for bf16
142            pub mod bf16x32;
143            /// A module defines a set of 512-bit vector types for bool
144            pub mod boolx64;
145            /// A module defines a set of 512-bit vector types for cplx32
146            pub mod cplx32x8;
147            /// A module defines a set of 512-bit vector types for cplx64
148            pub mod cplx64x4;
149            /// A module defines a set of 512-bit vector types for f16
150            pub mod f16x32;
151            /// A module defines a set of 512-bit vector types for f32
152            pub mod f32x16;
153            /// A module defines a set of 512-bit vector types for f64
154            pub mod f64x8;
155            /// A module defines a set of 512-bit vector types for i16
156            pub mod i16x32;
157            /// A module defines a set of 512-bit vector types for i32
158            pub mod i32x16;
159            /// A module defines a set of 512-bit vector types for i64
160            pub mod i64x8;
161            /// A module defines a set of 512-bit vector types for i8
162            pub mod i8x64;
163            /// A module defines a set of 512-bit vector types for isize
164            pub mod isizex8;
165            /// A module defines a set of 512-bit vector types for u16
166            pub mod u16x32;
167            /// A module defines a set of 512-bit vector types for u32
168            pub mod u32x16;
169            /// A module defines a set of 512-bit vector types for u64
170            pub mod u64x8;
171            /// A module defines a set of 512-bit vector types for u8
172            pub mod u8x64;
173            /// A module defines a set of 512-bit vector types for usize
174            pub mod usizex8;
175        }
176    }
177    /// A module defines a set of vector types using stdsimd
178    #[cfg(feature = "archsimd")]
179    pub mod arch_simd {
180        /// A module defines a set of 128-bit vector types
181        #[cfg(any(
182            all(not(target_feature = "avx2"), target_feature = "sse"),
183            target_arch = "arm",
184            target_arch = "aarch64",
185            target_feature = "neon"
186        ))]
187        pub mod _128bit {
188            /// A module defines a set of 128-bit vector types for bf16
189            pub mod bf16x8;
190            /// A module defines a set of 128-bit vector types for bool
191            pub mod boolx16;
192            /// A module defines a set of 128-bit vector types for cplx32
193            pub mod cplx32x2;
194            /// A module defines a set of 128-bit vector types for cplx64
195            pub mod cplx64x1;
196            /// A module defines a set of 128-bit vector types for f16
197            pub mod f16x8;
198            /// A module defines a set of 128-bit vector types for f32
199            pub mod f32x4;
200            /// A module defines a set of 128-bit vector types for f64
201            pub mod f64x2;
202            /// A module defines a set of 128-bit vector types for i16
203            pub mod i16x8;
204            /// A module defines a set of 128-bit vector types for i32
205            pub mod i32x4;
206            /// A module defines a set of 128-bit vector types for i64
207            pub mod i64x2;
208            /// A module defines a set of 128-bit vector types for i8
209            pub mod i8x16;
210            /// A module defines a set of 128-bit vector types for isize
211            pub mod isizex2;
212            /// A module defines a set of 128-bit vector types for u16
213            pub mod u16x8;
214            /// A module defines a set of 128-bit vector types for u32
215            pub mod u32x4;
216            /// A module defines a set of 128-bit vector types for u64
217            pub mod u64x2;
218            /// A module defines a set of 128-bit vector types for u8
219            pub mod u8x16;
220            /// A module defines a set of 128-bit vector types for usize
221            pub mod usizex2;
222        }
223        /// A module defines a set of 256-bit vector types
224        #[cfg(target_feature = "avx2")]
225        pub mod _256bit {
226            /// A module defines a set of 256-bit vector types for bf16
227            pub mod bf16x16;
228            /// A module defines a set of 256-bit vector types for bool
229            pub mod boolx32;
230            /// A module defines a set of 256-bit vector types for cplx32
231            pub mod cplx32x4;
232            /// A module defines a set of 256-bit vector types for cplx64
233            pub mod cplx64x2;
234            /// A module defines a set of 256-bit vector types for f16
235            pub mod f16x16;
236            /// A module defines a set of 256-bit vector types for f32
237            pub mod f32x8;
238            /// A module defines a set of 256-bit vector types for f64
239            pub mod f64x4;
240            /// A module defines a set of 256-bit vector types for i16
241            pub mod i16x16;
242            /// A module defines a set of 256-bit vector types for i32
243            pub mod i32x8;
244            /// A module defines a set of 256-bit vector types for i64
245            pub mod i64x4;
246            /// A module defines a set of 256-bit vector types for i8
247            pub mod i8x32;
248            /// A module defines a set of 256-bit vector types for isize
249            pub mod isizex4;
250            /// A module defines a set of 256-bit vector types for u16
251            pub mod u16x16;
252            /// A module defines a set of 256-bit vector types for u32
253            pub mod u32x8;
254            /// A module defines a set of 256-bit vector types for u64
255            pub mod u64x4;
256            /// A module defines a set of 256-bit vector types for u8
257            pub mod u8x32;
258            /// A module defines a set of 256-bit vector types for usize
259            pub mod usizex4;
260        }
261        /// A module defines a set of 512-bit vector types
262        #[cfg(target_feature = "avx512f")]
263        pub mod _512bit {
264            /// A module defines a set of 512-bit vector types for bf16
265            pub mod bf16x32;
266            /// A module defines a set of 512-bit vector types for bool
267            pub mod boolx64;
268            /// A module defines a set of 512-bit vector types for cplx32
269            pub mod cplx32x8;
270            /// A module defines a set of 512-bit vector types for cplx64
271            pub mod cplx64x4;
272            /// A module defines a set of 512-bit vector types for f16
273            pub mod f16x32;
274            /// A module defines a set of 512-bit vector types for f32
275            pub mod f32x16;
276            /// A module defines a set of 512-bit vector types for f64
277            pub mod f64x8;
278            /// A module defines a set of 512-bit vector types for i16
279            pub mod i16x32;
280            /// A module defines a set of 512-bit vector types for i32
281            pub mod i32x16;
282            /// A module defines a set of 512-bit vector types for i64
283            pub mod i64x8;
284            /// A module defines a set of 512-bit vector types for i8
285            pub mod i8x64;
286            /// A module defines a set of 512-bit vector types for isize
287            pub mod isizex8;
288            /// A module defines a set of 512-bit vector types for u16
289            pub mod u16x32;
290            /// A module defines a set of 512-bit vector types for u32
291            pub mod u32x16;
292            /// A module defines a set of 512-bit vector types for u64
293            pub mod u64x8;
294            /// A module defines a set of 512-bit vector types for u8
295            pub mod u8x64;
296            /// A module defines a set of 512-bit vector types for usize
297            pub mod usizex8;
298        }
299
300        // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
301        //
302        // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
303        // Modified work Copyright (c) 2024 hpt Contributors
304        //
305        // Boost Software License - Version 1.0 - August 17th, 2003
306        //
307        // Permission is hereby granted, free of charge, to any person or organization
308        // obtaining a copy of the software and accompanying documentation covered by
309        // this license (the "Software") to use, reproduce, display, distribute,
310        // execute, and transmit the Software, and to prepare derivative works of the
311        // Software, and to permit third-parties to whom the Software is furnished to
312        // do so, all subject to the following:
313        //
314        // The copyright notices in the Software and this entire statement, including
315        // the above license grant, this restriction and the following disclaimer,
316        // must be included in all copies of the Software, in whole or in part, and
317        // all derivative works of the Software, unless such copies or derivative
318        // works are solely in the form of machine-executable object code generated by
319        // a source language processor.
320        //
321        // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
322        // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
323        // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
324        // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
325        // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
326        // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
327        // DEALINGS IN THE SOFTWARE.
328        //
329        // This Rust port is additionally licensed under Apache-2.0 OR MIT
330        // See repository root for details
331        /// A module defines a set of vector types for sleef
332        pub mod sleef {
333            /// A module defines a set of vector types for table
334            pub mod table;
335            /// A module defines a set of vector types for helper
336            pub mod arch {
337                /// A module defines a set of vector types for helper
338                #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
339                pub mod helper_aarch64;
340                /// A module defines a set of vector types for helper
341                #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
342                pub mod helper_avx2;
343                /// A module defines a set of vector types for helper
344                #[cfg(all(
345                    target_arch = "x86_64",
346                    target_feature = "sse",
347                    not(target_feature = "avx2")
348                ))]
349                pub mod helper_sse;
350            }
351            /// A module defines a set of vector types for common
352            pub mod common {
353                /// A module defines a set of vector types for common
354                pub mod commonfuncs;
355                /// A module defines a set of vector types for common
356                pub mod dd;
357                /// A module defines a set of vector types for common
358                pub mod df;
359                /// A module defines a macro for polynomial approximation
360                pub mod estrin;
361                /// A module defines a set of vector types for common
362                pub mod misc;
363            }
364            /// A module defines a set of vector types for libm
365            pub mod libm {
366                /// a module defins a set of double precision floating point functions
367                pub mod sleefsimddp;
368                /// a module defins a set of single precision floating point functions
369                pub mod sleefsimdsp;
370            }
371        }
372    }
373    /// A module defines a set of traits for vector
374    pub mod traits;
375    /// A module defines a set of utils for vector
376    pub mod utils;
377
378    #[cfg(target_feature = "avx2")]
379    pub(crate) mod vector_promote {
380        #[cfg(target_pointer_width = "64")]
381        pub(crate) use crate::vectors::arch_simd::_256bit::isizex4::isize_promote;
382        #[cfg(target_pointer_width = "32")]
383        pub(crate) use crate::vectors::arch_simd::_256bit::isizex8::isize_promote;
384        #[cfg(target_pointer_width = "64")]
385        pub(crate) use crate::vectors::arch_simd::_256bit::usizex4::usize_promote;
386        #[cfg(target_pointer_width = "32")]
387        pub(crate) use crate::vectors::arch_simd::_256bit::usizex8::usize_promote;
388        pub(crate) use crate::vectors::arch_simd::_256bit::{
389            bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
390            cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
391            f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
392            i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
393            u8x32::u8_promote,
394        };
395    }
396    #[cfg(any(
397        all(not(target_feature = "avx2"), target_feature = "sse"),
398        target_arch = "arm",
399        target_arch = "aarch64",
400        target_feature = "neon"
401    ))]
402    pub(crate) mod vector_promote {
403        #[cfg(target_pointer_width = "64")]
404        pub(crate) use crate::vectors::arch_simd::_128bit::isizex2::isize_promote;
405        #[cfg(target_pointer_width = "32")]
406        pub(crate) use crate::vectors::arch_simd::_128bit::isizex4::isize_promote;
407        #[cfg(target_pointer_width = "64")]
408        pub(crate) use crate::vectors::arch_simd::_128bit::usizex2::usize_promote;
409        #[cfg(target_pointer_width = "32")]
410        pub(crate) use crate::vectors::arch_simd::_128bit::usizex4::usize_promote;
411        pub(crate) use crate::vectors::arch_simd::_128bit::{
412            bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
413            cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
414            f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
415            i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
416            u8x16::u8_promote,
417        };
418    }
419    #[cfg(target_feature = "avx512f")]
420    pub(crate) mod vector_promote {
421        #[cfg(target_pointer_width = "32")]
422        pub(crate) use crate::vectors::arch_simd::_512bit::isizex16::isize_promote;
423        #[cfg(target_pointer_width = "64")]
424        pub(crate) use crate::vectors::arch_simd::_512bit::isizex8::isize_promote;
425        #[cfg(target_pointer_width = "32")]
426        pub(crate) use crate::vectors::arch_simd::_512bit::usizex16::usize_promote;
427        #[cfg(target_pointer_width = "64")]
428        pub(crate) use crate::vectors::arch_simd::_512bit::usizex8::usize_promote;
429        pub(crate) use crate::vectors::arch_simd::_512bit::{
430            bf16x32::bf16_promote, boolx64::bool_promote, cplx32x8::Complex32_promote,
431            cplx64x4::Complex64_promote, f16x32::f16_promote, f32x16::f32_promote,
432            f64x8::f64_promote, i16x32::i16_promote, i32x16::i32_promote, i64x8::i64_promote,
433            i8x64::i8_promote, u16x32::u16_promote, u32x16::u32_promote, u64x8::u64_promote,
434            u8x64::u8_promote,
435        };
436    }
437}
438
#[cfg(feature = "cuda")]
/// A module that defines a set of types for CUDA; compiled only with the `cuda` feature.
pub mod cuda_types {
    /// A module that defines conversions for CUDA types.
    pub mod convertion;
    /// A module that defines a scalar type for CUDA.
    pub mod scalar;

    // Crate-private submodules, each named after the element type it covers.
    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _cplx32;
    pub(crate) mod _cplx64;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod _i16;
    pub(crate) mod _i32;
    pub(crate) mod _i64;
    pub(crate) mod _i8;
    pub(crate) mod _isize;
    pub(crate) mod _u16;
    pub(crate) mod _u32;
    pub(crate) mod _u64;
    pub(crate) mod _u8;
    pub(crate) mod _usize;
}

466pub use vectors::*;
467#[cfg(feature = "archsimd")]
468mod simd {
469    pub use crate::vectors::arch_simd::*;
470}
471#[cfg(feature = "stdsimd")]
472mod simd {
473    pub use crate::vectors::std_simd::*;
474}
475
// SLEEF vector type aliases for x86_64 with AVX2 (not used by the `stdsimd` backend).
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "avx2",
    not(feature = "stdsimd")
))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128i, __m256, __m256d, __m256i};

    // Floating-point vectors (256-bit).
    pub(crate) type VFloat = __m256;
    pub(crate) type VDouble = __m256d;

    // Bit mask and comparison mask share the 256-bit integer vector type.
    pub(crate) type VMask = __m256i;
    pub(crate) type Vopmask = __m256i;

    // Integer vectors; note that `VInt` is the 128-bit type on this target.
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m256i;
    pub(crate) type VInt64 = __m256i;
    pub(crate) type VUInt64 = __m256i;
}

// SLEEF vector type aliases for x86_64 with SSE but without AVX2 (not used by the
// `stdsimd` backend).
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse",
    not(target_feature = "avx2"),
    not(feature = "stdsimd")
))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128, __m128d, __m128i};

    // Floating-point vectors (128-bit).
    pub(crate) type VFloat = __m128;
    pub(crate) type VDouble = __m128d;

    // Bit mask and comparison mask share the 128-bit integer vector type.
    pub(crate) type VMask = __m128i;
    pub(crate) type Vopmask = __m128i;

    // Integer vectors.
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m128i;
}

// SLEEF vector type aliases for aarch64 with NEON.
// NOTE(review): unlike the x86_64 variants, this one is not gated on
// `not(feature = "stdsimd")` — confirm whether that asymmetry is intended.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod sleef_types {
    use std::arch::aarch64::{float32x4_t, float64x2_t, int32x2_t, int32x4_t, uint32x4_t};

    // Floating-point vectors (128-bit).
    pub(crate) type VFloat = float32x4_t;
    pub(crate) type VDouble = float64x2_t;

    // Bit mask and comparison mask share the 4 x u32 vector type.
    pub(crate) type VMask = uint32x4_t;
    pub(crate) type Vopmask = uint32x4_t;

    // Integer vectors; `VInt` is the 64-bit 2 x i32 type.
    pub(crate) type VInt = int32x2_t;
    pub(crate) type VInt2 = int32x4_t;
}