hpt_types/lib.rs
//! This crate implements type utilities for tensor operations.
2#![deny(missing_docs)]
3
4/// A module implement type conversion
5pub mod convertion;
6/// A module defines a set of data types and utilities
7pub mod dtype;
8/// A module implement type conversion
9pub mod into_scalar;
10/// A module implement simd vector conversion
11pub mod into_vec;
12/// A module defines a set of traits for tensor operations, and implement computation functions for scalar and vector types
13pub mod type_promote;
14/// A module defines a set of traits for scalar operations
15pub(crate) mod scalars {
16 pub(crate) mod _bf16;
17 pub(crate) mod _bool;
18 pub(crate) mod _f16;
19 pub(crate) mod _f32;
20 pub(crate) mod _f64;
21 pub(crate) mod impls;
22}
23
24/// A module defines a set of traits for type promotion
25pub mod promotion {
26 #[cfg(feature = "normal_promote")]
27 pub(crate) mod normal_promote {
28 pub(crate) mod _bf16;
29 pub(crate) mod _bool;
30 pub(crate) mod _cplx32;
31 pub(crate) mod _cplx64;
32 pub(crate) mod _f16;
33 pub(crate) mod _f32;
34 pub(crate) mod _f64;
35 pub(crate) mod _i16;
36 pub(crate) mod _i32;
37 pub(crate) mod _i64;
38 pub(crate) mod _i8;
39 pub(crate) mod _isize;
40 pub(crate) mod _u16;
41 pub(crate) mod _u32;
42 pub(crate) mod _u64;
43 pub(crate) mod _u8;
44 pub(crate) mod _usize;
45 }
46 pub(crate) mod utils;
47}
48
49/// A module defines a set of vector types
50pub mod vectors {
51 /// A module defines a set of vector types using stdsimd
52 pub mod arch_simd {
53 /// A module defines a set of 128-bit vector types
54 #[cfg(any(
55 all(not(target_feature = "avx2"), target_feature = "sse"),
56 target_arch = "arm",
57 target_arch = "aarch64",
58 target_feature = "neon"
59 ))]
60 pub mod _128bit {
61 /// A module defines a set of 128-bit vector types for bf16
62 pub mod bf16x8;
63 /// A module defines a set of 128-bit vector types for bool
64 pub mod boolx16;
65 /// A module defines a set of 128-bit vector types for cplx32
66 pub mod cplx32x2;
67 /// A module defines a set of 128-bit vector types for cplx64
68 pub mod cplx64x1;
69 /// A module defines a set of 128-bit vector types for f16
70 pub mod f16x8;
71 /// A module defines a set of 128-bit vector types for f32
72 pub mod f32x4;
73 /// A module defines a set of 128-bit vector types for f64
74 pub mod f64x2;
75 /// A module defines a set of 128-bit vector types for i16
76 pub mod i16x8;
77 /// A module defines a set of 128-bit vector types for i32
78 pub mod i32x4;
79 /// A module defines a set of 128-bit vector types for i64
80 pub mod i64x2;
81 /// A module defines a set of 128-bit vector types for i8
82 pub mod i8x16;
83 /// A module defines a set of 128-bit vector types for isize
84 pub mod isizex2;
85 /// A module defines a set of 128-bit vector types for u16
86 pub mod u16x8;
87 /// A module defines a set of 128-bit vector types for u32
88 pub mod u32x4;
89 /// A module defines a set of 128-bit vector types for u64
90 pub mod u64x2;
91 /// A module defines a set of 128-bit vector types for u8
92 pub mod u8x16;
93 /// A module defines a set of 128-bit vector types for usize
94 pub mod usizex2;
95 }
96 /// A module defines a set of 256-bit vector types
97 #[cfg(target_feature = "avx2")]
98 pub mod _256bit {
99 /// A module defines a set of 256-bit vector types for bf16
100 pub mod bf16x16;
101 /// A module defines a set of 256-bit vector types for bool
102 pub mod boolx32;
103 /// A module defines a set of 256-bit vector types for cplx32
104 pub mod cplx32x4;
105 /// A module defines a set of 256-bit vector types for cplx64
106 pub mod cplx64x2;
107 /// A module defines a set of 256-bit vector types for f16
108 pub mod f16x16;
109 /// A module defines a set of 256-bit vector types for f32
110 pub mod f32x8;
111 /// A module defines a set of 256-bit vector types for f64
112 pub mod f64x4;
113 /// A module defines a set of 256-bit vector types for i16
114 pub mod i16x16;
115 /// A module defines a set of 256-bit vector types for i32
116 pub mod i32x8;
117 /// A module defines a set of 256-bit vector types for i64
118 pub mod i64x4;
119 /// A module defines a set of 256-bit vector types for i8
120 pub mod i8x32;
121 /// A module defines a set of 256-bit vector types for isize
122 pub mod isizex4;
123 /// A module defines a set of 256-bit vector types for u16
124 pub mod u16x16;
125 /// A module defines a set of 256-bit vector types for u32
126 pub mod u32x8;
127 /// A module defines a set of 256-bit vector types for u64
128 pub mod u64x4;
129 /// A module defines a set of 256-bit vector types for u8
130 pub mod u8x32;
131 /// A module defines a set of 256-bit vector types for usize
132 pub mod usizex4;
133 }
134 /// A module defines a set of 512-bit vector types
135 #[cfg(target_feature = "avx512f")]
136 pub mod _512bit {
137 /// A module defines a set of 512-bit vector types for bf16
138 pub mod bf16x32;
139 /// A module defines a set of 512-bit vector types for bool
140 pub mod boolx64;
141 /// A module defines a set of 512-bit vector types for cplx32
142 pub mod cplx32x8;
143 /// A module defines a set of 512-bit vector types for cplx64
144 pub mod cplx64x4;
145 /// A module defines a set of 512-bit vector types for f16
146 pub mod f16x32;
147 /// A module defines a set of 512-bit vector types for f32
148 pub mod f32x16;
149 /// A module defines a set of 512-bit vector types for f64
150 pub mod f64x8;
151 /// A module defines a set of 512-bit vector types for i16
152 pub mod i16x32;
153 /// A module defines a set of 512-bit vector types for i32
154 pub mod i32x16;
155 /// A module defines a set of 512-bit vector types for i64
156 pub mod i64x8;
157 /// A module defines a set of 512-bit vector types for i8
158 pub mod i8x64;
159 /// A module defines a set of 512-bit vector types for isize
160 pub mod isizex8;
161 /// A module defines a set of 512-bit vector types for u16
162 pub mod u16x32;
163 /// A module defines a set of 512-bit vector types for u32
164 pub mod u32x16;
165 /// A module defines a set of 512-bit vector types for u64
166 pub mod u64x8;
167 /// A module defines a set of 512-bit vector types for u8
168 pub mod u8x64;
169 /// A module defines a set of 512-bit vector types for usize
170 pub mod usizex8;
171 }
172
173 // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
174 //
175 // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
176 // Modified work Copyright (c) 2024 hpt Contributors
177 //
178 // Boost Software License - Version 1.0 - August 17th, 2003
179 //
180 // Permission is hereby granted, free of charge, to any person or organization
181 // obtaining a copy of the software and accompanying documentation covered by
182 // this license (the "Software") to use, reproduce, display, distribute,
183 // execute, and transmit the Software, and to prepare derivative works of the
184 // Software, and to permit third-parties to whom the Software is furnished to
185 // do so, all subject to the following:
186 //
187 // The copyright notices in the Software and this entire statement, including
188 // the above license grant, this restriction and the following disclaimer,
189 // must be included in all copies of the Software, in whole or in part, and
190 // all derivative works of the Software, unless such copies or derivative
191 // works are solely in the form of machine-executable object code generated by
192 // a source language processor.
193 //
194 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
195 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
196 // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
197 // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
198 // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
199 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
200 // DEALINGS IN THE SOFTWARE.
201 //
202 // This Rust port is additionally licensed under Apache-2.0 OR MIT
203 // See repository root for details
204 /// A module defines a set of vector types for sleef
205 pub mod sleef {
206 /// A module defines a set of vector types for table
207 pub mod table;
208 /// A module defines a set of vector types for helper
209 pub mod arch {
210 /// A module defines a set of vector types for helper
211 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
212 pub mod helper_aarch64;
213 /// A module defines a set of vector types for helper
214 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
215 pub mod helper_avx2;
216 /// A module defines a set of vector types for helper
217 #[cfg(all(
218 target_arch = "x86_64",
219 target_feature = "sse",
220 not(target_feature = "avx2")
221 ))]
222 pub mod helper_sse;
223 }
224 /// A module defines a set of vector types for common
225 pub mod common {
226 /// A module defines a set of vector types for common
227 pub mod commonfuncs;
228 /// A module defines a set of vector types for common
229 pub mod dd;
230 /// A module defines a set of vector types for common
231 pub mod df;
232 /// A module defines a macro for polynomial approximation
233 pub mod estrin;
234 /// A module defines a set of vector types for common
235 pub mod misc;
236 }
237 /// A module defines a set of vector types for libm
238 pub mod libm {
239 /// a module defins a set of double precision floating point functions
240 pub mod sleefsimddp;
241 /// a module defins a set of single precision floating point functions
242 pub mod sleefsimdsp;
243 }
244 }
245 }
246 /// A module defines a set of traits for vector
247 pub mod traits;
248 /// A module defines a set of utils for vector
249 pub mod utils;
250
251 #[cfg(target_feature = "avx2")]
252 pub(crate) mod vector_promote {
253 #[cfg(target_pointer_width = "64")]
254 pub(crate) use crate::vectors::arch_simd::_256bit::isizex4::isize_promote;
255 #[cfg(target_pointer_width = "32")]
256 pub(crate) use crate::vectors::arch_simd::_256bit::isizex8::isize_promote;
257 #[cfg(target_pointer_width = "64")]
258 pub(crate) use crate::vectors::arch_simd::_256bit::usizex4::usize_promote;
259 #[cfg(target_pointer_width = "32")]
260 pub(crate) use crate::vectors::arch_simd::_256bit::usizex8::usize_promote;
261 pub(crate) use crate::vectors::arch_simd::_256bit::{
262 bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
263 cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
264 f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
265 i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
266 u8x32::u8_promote,
267 };
268 }
269 #[cfg(any(
270 all(not(target_feature = "avx2"), target_feature = "sse"),
271 target_arch = "arm",
272 target_arch = "aarch64",
273 target_feature = "neon"
274 ))]
275 pub(crate) mod vector_promote {
276 #[cfg(target_pointer_width = "64")]
277 pub(crate) use crate::vectors::arch_simd::_128bit::isizex2::isize_promote;
278 #[cfg(target_pointer_width = "32")]
279 pub(crate) use crate::vectors::arch_simd::_128bit::isizex4::isize_promote;
280 #[cfg(target_pointer_width = "64")]
281 pub(crate) use crate::vectors::arch_simd::_128bit::usizex2::usize_promote;
282 #[cfg(target_pointer_width = "32")]
283 pub(crate) use crate::vectors::arch_simd::_128bit::usizex4::usize_promote;
284 pub(crate) use crate::vectors::arch_simd::_128bit::{
285 bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
286 cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
287 f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
288 i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
289 u8x16::u8_promote,
290 };
291 }
292 #[cfg(target_feature = "avx512f")]
293 pub(crate) mod vector_promote {
294 #[cfg(target_pointer_width = "32")]
295 pub(crate) use crate::vectors::arch_simd::_512bit::isizex16::isize_promote;
296 #[cfg(target_pointer_width = "64")]
297 pub(crate) use crate::vectors::arch_simd::_512bit::isizex8::isize_promote;
298 #[cfg(target_pointer_width = "32")]
299 pub(crate) use crate::vectors::arch_simd::_512bit::usizex16::usize_promote;
300 #[cfg(target_pointer_width = "64")]
301 pub(crate) use crate::vectors::arch_simd::_512bit::usizex8::usize_promote;
302 pub(crate) use crate::vectors::arch_simd::_512bit::{
303 bf16x32::bf16_promote, boolx64::bool_promote, cplx32x8::Complex32_promote,
304 cplx64x4::Complex64_promote, f16x32::f16_promote, f32x16::f32_promote,
305 f64x8::f64_promote, i16x32::i16_promote, i32x16::i32_promote, i64x8::i64_promote,
306 i8x64::i8_promote, u16x32::u16_promote, u32x16::u32_promote, u64x8::u64_promote,
307 u8x64::u8_promote,
308 };
309 }
310}
311
/// Types for CUDA; compiled only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda_types {
    /// Conversions for CUDA types.
    pub mod convertion;
    /// Scalar type for CUDA.
    pub mod scalar;

    // Crate-private modules, one per element type.
    pub(crate) mod _bf16;
    pub(crate) mod _bool;
    pub(crate) mod _cplx32;
    pub(crate) mod _cplx64;
    pub(crate) mod _f16;
    pub(crate) mod _f32;
    pub(crate) mod _f64;
    pub(crate) mod _i16;
    pub(crate) mod _i32;
    pub(crate) mod _i64;
    pub(crate) mod _i8;
    pub(crate) mod _isize;
    pub(crate) mod _u16;
    pub(crate) mod _u32;
    pub(crate) mod _u64;
    pub(crate) mod _u8;
    pub(crate) mod _usize;
}
338
339pub use vectors::*;
340mod simd {
341 pub use crate::vectors::arch_simd::*;
342}
343
/// Concrete SIMD register types behind the SLEEF port's vector aliases on
/// x86_64 with AVX2.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128i, __m256, __m256d, __m256i};

    // Floating-point vectors (256-bit).
    pub(crate) type VDouble = __m256d;
    pub(crate) type VFloat = __m256;

    // Mask vectors (256-bit integer registers).
    pub(crate) type VMask = __m256i;
    pub(crate) type Vopmask = __m256i;

    // Integer vectors: note `VInt` stays 128 bits wide, unlike every other alias here.
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m256i;
    pub(crate) type VInt64 = __m256i;
    pub(crate) type VUInt64 = __m256i;
}
356
/// Concrete SIMD register types behind the SLEEF port's vector aliases on
/// x86_64 with SSE but without AVX2; every alias is a 128-bit register.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse",
    not(target_feature = "avx2")
))]
pub(crate) mod sleef_types {
    use std::arch::x86_64::{__m128, __m128d, __m128i};

    // Floating-point vectors.
    pub(crate) type VDouble = __m128d;
    pub(crate) type VFloat = __m128;

    // Masks and both integer aliases share the generic 128-bit integer register.
    pub(crate) type VMask = __m128i;
    pub(crate) type Vopmask = __m128i;
    pub(crate) type VInt = __m128i;
    pub(crate) type VInt2 = __m128i;
}
371
/// Concrete SIMD register types behind the SLEEF port's vector aliases on
/// AArch64 with NEON.
#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
pub(crate) mod sleef_types {
    use std::arch::aarch64::{float32x4_t, float64x2_t, int32x2_t, int32x4_t, uint32x4_t};

    // Floating-point vectors (128-bit).
    pub(crate) type VDouble = float64x2_t;
    pub(crate) type VFloat = float32x4_t;

    // Masks use unsigned 32-bit lanes.
    pub(crate) type VMask = uint32x4_t;
    pub(crate) type Vopmask = uint32x4_t;

    // Integer vectors: `VInt` is the 64-bit (two-lane) form, `VInt2` the 128-bit form.
    pub(crate) type VInt = int32x2_t;
    pub(crate) type VInt2 = int32x4_t;
}
381}