hpt_types/lib.rs
1//! This crate implment type utilities for tensor operations
2#![deny(missing_docs)]
3
4/// A module implement type conversion
5pub mod convertion;
6/// A module defines a set of data types and utilities
7pub mod dtype;
8/// A module implement type conversion
9pub mod into_scalar;
10/// A module implement simd vector conversion
11pub mod into_vec;
12/// A module defines a set of traits for tensor operations, and implement computation functions for scalar and vector types
13pub mod type_promote;
14/// A module defines a set of traits for scalar operations
15pub(crate) mod scalars {
16 pub(crate) mod _bf16;
17 pub(crate) mod _bool;
18 pub(crate) mod _f16;
19 pub(crate) mod _f32;
20 pub(crate) mod _f64;
21 pub(crate) mod impls;
22}
23
24/// A module defines a set of traits for type promotion
25pub mod promotion {
26 #[cfg(feature = "normal_promote")]
27 pub(crate) mod normal_promote {
28 pub(crate) mod _bf16;
29 pub(crate) mod _bool;
30 pub(crate) mod _cplx32;
31 pub(crate) mod _cplx64;
32 pub(crate) mod _f16;
33 pub(crate) mod _f32;
34 pub(crate) mod _f64;
35 pub(crate) mod _i16;
36 pub(crate) mod _i32;
37 pub(crate) mod _i64;
38 pub(crate) mod _i8;
39 pub(crate) mod _isize;
40 pub(crate) mod _u16;
41 pub(crate) mod _u32;
42 pub(crate) mod _u64;
43 pub(crate) mod _u8;
44 pub(crate) mod _usize;
45 }
46 pub(crate) mod utils;
47}
48
49/// A module defines a set of vector types
50pub mod vectors {
51 /// A module defines a set of vector types using stdsimd
52 pub mod arch_simd {
53 /// A module defines a set of 128-bit vector types
54 #[cfg(any(
55 all(not(target_feature = "avx2"), target_feature = "sse"),
56 target_arch = "arm",
57 target_arch = "aarch64",
58 target_feature = "neon"
59 ))]
60 pub mod _128bit {
61 /// A module defines a set of 128-bit vector types for bf16
62 pub mod bf16x8;
63 /// A module defines a set of 128-bit vector types for bool
64 pub mod boolx16;
65 /// A module defines a set of 128-bit vector types for cplx32
66 pub mod cplx32x2;
67 /// A module defines a set of 128-bit vector types for cplx64
68 pub mod cplx64x1;
69 /// A module defines a set of 128-bit vector types for f16
70 pub mod f16x8;
71 /// A module defines a set of 128-bit vector types for f32
72 pub mod f32x4;
73 /// A module defines a set of 128-bit vector types for f64
74 pub mod f64x2;
75 /// A module defines a set of 128-bit vector types for i16
76 pub mod i16x8;
77 /// A module defines a set of 128-bit vector types for i32
78 pub mod i32x4;
79 /// A module defines a set of 128-bit vector types for i64
80 pub mod i64x2;
81 /// A module defines a set of 128-bit vector types for i8
82 pub mod i8x16;
83 /// A module defines a set of 128-bit vector types for isize
84 pub mod isizex2;
85 /// A module defines a set of 128-bit vector types for u16
86 pub mod u16x8;
87 /// A module defines a set of 128-bit vector types for u32
88 pub mod u32x4;
89 /// A module defines a set of 128-bit vector types for u64
90 pub mod u64x2;
91 /// A module defines a set of 128-bit vector types for u8
92 pub mod u8x16;
93 /// A module defines a set of 128-bit vector types for usize
94 pub mod usizex2;
95 }
96 /// A module defines a set of 256-bit vector types
97 #[cfg(target_feature = "avx2")]
98 pub mod _256bit {
99 /// A module defines a set of 256-bit vector types for bf16
100 pub mod bf16x16;
101 /// A module defines a set of 256-bit vector types for bool
102 pub mod boolx32;
103 /// A module defines a set of 256-bit vector types for cplx32
104 pub mod cplx32x4;
105 /// A module defines a set of 256-bit vector types for cplx64
106 pub mod cplx64x2;
107 /// A module defines a set of 256-bit vector types for f16
108 pub mod f16x16;
109 /// A module defines a set of 256-bit vector types for f32
110 pub mod f32x8;
111 /// A module defines a set of 256-bit vector types for f64
112 pub mod f64x4;
113 /// A module defines a set of 256-bit vector types for i16
114 pub mod i16x16;
115 /// A module defines a set of 256-bit vector types for i32
116 pub mod i32x8;
117 /// A module defines a set of 256-bit vector types for i64
118 pub mod i64x4;
119 /// A module defines a set of 256-bit vector types for i8
120 pub mod i8x32;
121 /// A module defines a set of 256-bit vector types for isize
122 pub mod isizex4;
123 /// A module defines a set of 256-bit vector types for u16
124 pub mod u16x16;
125 /// A module defines a set of 256-bit vector types for u32
126 pub mod u32x8;
127 /// A module defines a set of 256-bit vector types for u64
128 pub mod u64x4;
129 /// A module defines a set of 256-bit vector types for u8
130 pub mod u8x32;
131 /// A module defines a set of 256-bit vector types for usize
132 pub mod usizex4;
133 }
134
135 // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
136 //
137 // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
138 // Modified work Copyright (c) 2024 hpt Contributors
139 //
140 // Boost Software License - Version 1.0 - August 17th, 2003
141 //
142 // Permission is hereby granted, free of charge, to any person or organization
143 // obtaining a copy of the software and accompanying documentation covered by
144 // this license (the "Software") to use, reproduce, display, distribute,
145 // execute, and transmit the Software, and to prepare derivative works of the
146 // Software, and to permit third-parties to whom the Software is furnished to
147 // do so, all subject to the following:
148 //
149 // The copyright notices in the Software and this entire statement, including
150 // the above license grant, this restriction and the following disclaimer,
151 // must be included in all copies of the Software, in whole or in part, and
152 // all derivative works of the Software, unless such copies or derivative
153 // works are solely in the form of machine-executable object code generated by
154 // a source language processor.
155 //
156 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
157 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
158 // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
159 // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
160 // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
161 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
162 // DEALINGS IN THE SOFTWARE.
163 //
164 // This Rust port is additionally licensed under Apache-2.0 OR MIT
165 // See repository root for details
166 /// A module defines a set of vector types for sleef
167 pub mod sleef {
168 /// A module defines a set of vector types for table
169 pub mod table;
170 /// A module defines a set of vector types for helper
171 pub mod arch {
172 /// A module defines a set of vector types for helper
173 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
174 pub mod helper_aarch64;
175 /// A module defines a set of vector types for helper
176 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
177 pub mod helper_avx2;
178 /// A module defines a set of vector types for helper
179 #[cfg(all(
180 target_arch = "x86_64",
181 target_feature = "sse",
182 not(target_feature = "avx2")
183 ))]
184 pub mod helper_sse;
185 }
186 /// A module defines a set of vector types for common
187 pub mod common {
188 /// A module defines a set of vector types for common
189 pub mod commonfuncs;
190 /// A module defines a set of vector types for common
191 pub mod dd;
192 /// A module defines a set of vector types for common
193 pub mod df;
194 /// A module defines a macro for polynomial approximation
195 pub mod estrin;
196 /// A module defines a set of vector types for common
197 pub mod misc;
198 }
199 /// A module defines a set of vector types for libm
200 pub mod libm {
201 /// a module defins a set of double precision floating point functions
202 pub mod sleefsimddp;
203 /// a module defins a set of single precision floating point functions
204 pub mod sleefsimdsp;
205 }
206 }
207 }
208 /// A module defines a set of traits for vector
209 pub mod traits;
210 /// A module defines a set of utils for vector
211 pub mod utils;
212
213 #[cfg(target_feature = "avx2")]
214 pub(crate) mod vector_promote {
215 #[cfg(target_pointer_width = "64")]
216 pub(crate) use crate::vectors::arch_simd::_256bit::isizex4::isize_promote;
217 #[cfg(target_pointer_width = "32")]
218 pub(crate) use crate::vectors::arch_simd::_256bit::isizex8::isize_promote;
219 #[cfg(target_pointer_width = "64")]
220 pub(crate) use crate::vectors::arch_simd::_256bit::usizex4::usize_promote;
221 #[cfg(target_pointer_width = "32")]
222 pub(crate) use crate::vectors::arch_simd::_256bit::usizex8::usize_promote;
223 pub(crate) use crate::vectors::arch_simd::_256bit::{
224 bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
225 cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
226 f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
227 i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
228 u8x32::u8_promote,
229 };
230 }
231 #[cfg(any(
232 all(not(target_feature = "avx2"), target_feature = "sse"),
233 target_arch = "arm",
234 target_arch = "aarch64",
235 target_feature = "neon"
236 ))]
237 pub(crate) mod vector_promote {
238 #[cfg(target_pointer_width = "64")]
239 pub(crate) use crate::vectors::arch_simd::_128bit::isizex2::isize_promote;
240 #[cfg(target_pointer_width = "32")]
241 pub(crate) use crate::vectors::arch_simd::_128bit::isizex4::isize_promote;
242 #[cfg(target_pointer_width = "64")]
243 pub(crate) use crate::vectors::arch_simd::_128bit::usizex2::usize_promote;
244 #[cfg(target_pointer_width = "32")]
245 pub(crate) use crate::vectors::arch_simd::_128bit::usizex4::usize_promote;
246 pub(crate) use crate::vectors::arch_simd::_128bit::{
247 bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
248 cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
249 f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
250 i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
251 u8x16::u8_promote,
252 };
253 }
254}
255
/// Types for CUDA support; compiled only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda_types {
    /// Conversions for CUDA types.
    pub mod convertion;
    /// Scalar type for CUDA.
    pub mod scalar;

    // Per-element-type submodules, visible only inside this crate.
    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _cplx32;
    pub(crate) mod _cplx64;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod _i16;
    pub(crate) mod _i32;
    pub(crate) mod _i64;
    pub(crate) mod _i8;
    pub(crate) mod _isize;
    pub(crate) mod _u16;
    pub(crate) mod _u32;
    pub(crate) mod _u64;
    pub(crate) mod _u8;
    pub(crate) mod _usize;
}
282
283pub use vectors::*;
284mod simd {
285 pub use crate::vectors::arch_simd::*;
286}
287
/// SLEEF vector type aliases for x86_64 targets with AVX2.
///
/// The double, float, mask and `VInt2` aliases are 256-bit registers, while `VInt` is a
/// 128-bit register. Unlike the SSE and NEON variants in this file, this variant also
/// defines `VInt64` and `VUInt64`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128i, __m256, __m256d, __m256i};

    // Floating-point lanes.
    pub(crate) type VDouble = __m256d;
    pub(crate) type VFloat = __m256;

    // Masks.
    pub(crate) type VMask = __m256i;
    pub(crate) type Vopmask = __m256i;

    // Integer lanes.
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m256i;
    pub(crate) type VInt64 = __m256i;
    pub(crate) type VUInt64 = __m256i;
}
300
/// SLEEF vector type aliases for x86_64 targets with SSE but without AVX2.
///
/// Every alias here is a 128-bit register type; no `VInt64` / `VUInt64` aliases are
/// defined in this variant.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse",
    not(target_feature = "avx2")
))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128, __m128d, __m128i};

    // Floating-point lanes.
    pub(crate) type VDouble = __m128d;
    pub(crate) type VFloat = __m128;

    // Masks.
    pub(crate) type VMask = __m128i;
    pub(crate) type Vopmask = __m128i;

    // Integer lanes.
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m128i;
}
315
/// SLEEF vector type aliases for AArch64 targets with NEON.
///
/// Masks are `uint32x4_t`, and `VInt` is the 64-bit `int32x2_t`. No `VInt64` / `VUInt64`
/// aliases are defined in this variant.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod sleef_types {
    use std::arch::aarch64::{float32x4_t, float64x2_t, int32x2_t, int32x4_t, uint32x4_t};

    // Floating-point lanes.
    pub(crate) type VDouble = float64x2_t;
    pub(crate) type VFloat = float32x4_t;

    // Masks.
    pub(crate) type VMask = uint32x4_t;
    pub(crate) type Vopmask = uint32x4_t;

    // Integer lanes.
    pub(crate) type VInt = int32x2_t;
    pub(crate) type VInt2 = int32x4_t;
}