hpt_types/lib.rs
1//! This crate implements type utilities for tensor operations.
2
3#![cfg_attr(feature = "stdsimd", feature(portable_simd))] // `portable_simd` is requested only when the `stdsimd` feature is on
4#![deny(missing_docs)] // a public item without documentation is a hard error
5
6/// A module that defines a set of data types and utilities.
7pub mod dtype;
8pub extern crate half; // publicly re-exports the `half` crate
9/// A module that implements type conversion.
10pub mod convertion;
11/// A module that implements type conversion (into scalar values, going by the module name).
12pub mod into_scalar;
13/// A module that implements SIMD vector conversion.
14pub mod into_vec;
15/// A module that defines a set of traits for tensor operations and implements computation functions for scalar and vector types.
16pub mod type_promote;
17/// A crate-private module that defines a set of traits for scalar operations; it groups the `_bf16`, `_bool`, `_f16`, `_f32`, `_f64` and `impls` submodules.
18pub(crate) mod scalars {
19 pub(crate) mod _bf16;
20 pub(crate) mod _bool;
21 pub(crate) mod _f16;
22 pub(crate) mod _f32;
23 pub(crate) mod _f64;
24 pub(crate) mod impls;
25}
26
27/// A module that defines a set of traits for type promotion.
28pub mod promotion {
29 #[cfg(feature = "normal_promote")] // the whole `normal_promote` tree is compiled only with this feature
30 pub(crate) mod normal_promote { // one crate-private submodule per element type
31 pub(crate) mod _bf16;
32 pub(crate) mod _bool;
33 pub(crate) mod _cplx32;
34 pub(crate) mod _cplx64;
35 pub(crate) mod _f16;
36 pub(crate) mod _f32;
37 pub(crate) mod _f64;
38 pub(crate) mod _i16;
39 pub(crate) mod _i32;
40 pub(crate) mod _i64;
41 pub(crate) mod _i8;
42 pub(crate) mod _isize;
43 pub(crate) mod _u16;
44 pub(crate) mod _u32;
45 pub(crate) mod _u64;
46 pub(crate) mod _u8;
47 pub(crate) mod _usize;
48 }
49 pub(crate) mod utils; // always compiled, independent of the `normal_promote` feature
50}
51
52/// A module defines a set of vector types
53pub mod vectors {
54 /// A module that defines a set of vector types using stdsimd; compiled only when the `stdsimd` feature is enabled.
55 #[cfg(feature = "stdsimd")]
56 pub mod std_simd {
57 /// A module that defines a set of 128-bit vector types; compiled for SSE targets without AVX2, for `arm`/`aarch64` targets, or when NEON is available.
58 #[cfg(any(
59 all(not(target_feature = "avx2"), target_feature = "sse"),
60 target_arch = "arm",
61 target_arch = "aarch64",
62 target_feature = "neon"
63 ))]
64 pub mod _128bit {
65 /// 128-bit vector types for `bf16`
66 pub mod bf16x8;
67 /// 128-bit vector types for `bool`
68 pub mod boolx16;
69 /// 128-bit vector types for `cplx32`
70 pub mod cplx32x2;
71 /// 128-bit vector types for `cplx64`
72 pub mod cplx64x1;
73 /// 128-bit vector types for `f16`
74 pub mod f16x8;
75 /// 128-bit vector types for `f32`
76 pub mod f32x4;
77 /// 128-bit vector types for `f64`
78 pub mod f64x2;
79 /// 128-bit vector types for `i16`
80 pub mod i16x8;
81 /// 128-bit vector types for `i32`
82 pub mod i32x4;
83 /// 128-bit vector types for `i64`
84 pub mod i64x2;
85 /// 128-bit vector types for `i8`
86 pub mod i8x16;
87 /// 128-bit vector types for `isize`
88 pub mod isizex2;
89 /// 128-bit vector types for `u16`
90 pub mod u16x8;
91 /// 128-bit vector types for `u32`
92 pub mod u32x4;
93 /// 128-bit vector types for `u64`
94 pub mod u64x2;
95 /// 128-bit vector types for `u8`
96 pub mod u8x16;
97 /// 128-bit vector types for `usize`
98 pub mod usizex2;
99 }
100 /// A module that defines a set of 256-bit vector types; compiled only when the `avx2` target feature is enabled.
101 #[cfg(target_feature = "avx2")]
102 pub mod _256bit {
103 /// 256-bit vector types for `bf16`
104 pub mod bf16x16;
105 /// 256-bit vector types for `bool`
106 pub mod boolx32;
107 /// 256-bit vector types for `cplx32`
108 pub mod cplx32x4;
109 /// 256-bit vector types for `cplx64`
110 pub mod cplx64x2;
111 /// 256-bit vector types for `f16`
112 pub mod f16x16;
113 /// 256-bit vector types for `f32`
114 pub mod f32x8;
115 /// 256-bit vector types for `f64`
116 pub mod f64x4;
117 /// 256-bit vector types for `i16`
118 pub mod i16x16;
119 /// 256-bit vector types for `i32`
120 pub mod i32x8;
121 /// 256-bit vector types for `i64`
122 pub mod i64x4;
123 /// 256-bit vector types for `i8`
124 pub mod i8x32;
125 /// 256-bit vector types for `isize`
126 pub mod isizex4;
127 /// 256-bit vector types for `u16`
128 pub mod u16x16;
129 /// 256-bit vector types for `u32`
130 pub mod u32x8;
131 /// 256-bit vector types for `u64`
132 pub mod u64x4;
133 /// 256-bit vector types for `u8`
134 pub mod u8x32;
135 /// 256-bit vector types for `usize`
136 pub mod usizex4;
137 }
138 /// A module that defines a set of 512-bit vector types; compiled only when the `avx512f` target feature is enabled.
139 #[cfg(target_feature = "avx512f")]
140 pub mod _512bit {
141 /// 512-bit vector types for `bf16`
142 pub mod bf16x32;
143 /// 512-bit vector types for `bool`
144 pub mod boolx64;
145 /// 512-bit vector types for `cplx32`
146 pub mod cplx32x8;
147 /// 512-bit vector types for `cplx64`
148 pub mod cplx64x4;
149 /// 512-bit vector types for `f16`
150 pub mod f16x32;
151 /// 512-bit vector types for `f32`
152 pub mod f32x16;
153 /// 512-bit vector types for `f64`
154 pub mod f64x8;
155 /// 512-bit vector types for `i16`
156 pub mod i16x32;
157 /// 512-bit vector types for `i32`
158 pub mod i32x16;
159 /// 512-bit vector types for `i64`
160 pub mod i64x8;
161 /// 512-bit vector types for `i8`
162 pub mod i8x64;
163 /// 512-bit vector types for `isize`
164 pub mod isizex8;
165 /// 512-bit vector types for `u16`
166 pub mod u16x32;
167 /// 512-bit vector types for `u32`
168 pub mod u32x16;
169 /// 512-bit vector types for `u64`
170 pub mod u64x8;
171 /// 512-bit vector types for `u8`
172 pub mod u8x64;
173 /// 512-bit vector types for `usize`
174 pub mod usizex8;
175 }
176 }
177 /// A module that defines a set of vector types for the `archsimd` feature (the counterpart of `std_simd`, which is gated on `stdsimd`).
178 #[cfg(feature = "archsimd")]
179 pub mod arch_simd {
180 /// A module that defines a set of 128-bit vector types; compiled for SSE targets without AVX2, for `arm`/`aarch64` targets, or when NEON is available.
181 #[cfg(any(
182 all(not(target_feature = "avx2"), target_feature = "sse"),
183 target_arch = "arm",
184 target_arch = "aarch64",
185 target_feature = "neon"
186 ))]
187 pub mod _128bit {
188 /// 128-bit vector types for `bf16`
189 pub mod bf16x8;
190 /// 128-bit vector types for `bool`
191 pub mod boolx16;
192 /// 128-bit vector types for `cplx32`
193 pub mod cplx32x2;
194 /// 128-bit vector types for `cplx64`
195 pub mod cplx64x1;
196 /// 128-bit vector types for `f16`
197 pub mod f16x8;
198 /// 128-bit vector types for `f32`
199 pub mod f32x4;
200 /// 128-bit vector types for `f64`
201 pub mod f64x2;
202 /// 128-bit vector types for `i16`
203 pub mod i16x8;
204 /// 128-bit vector types for `i32`
205 pub mod i32x4;
206 /// 128-bit vector types for `i64`
207 pub mod i64x2;
208 /// 128-bit vector types for `i8`
209 pub mod i8x16;
210 /// 128-bit vector types for `isize`
211 pub mod isizex2;
212 /// 128-bit vector types for `u16`
213 pub mod u16x8;
214 /// 128-bit vector types for `u32`
215 pub mod u32x4;
216 /// 128-bit vector types for `u64`
217 pub mod u64x2;
218 /// 128-bit vector types for `u8`
219 pub mod u8x16;
220 /// 128-bit vector types for `usize`
221 pub mod usizex2;
222 }
223 /// A module that defines a set of 256-bit vector types; compiled only when the `avx2` target feature is enabled.
224 #[cfg(target_feature = "avx2")]
225 pub mod _256bit {
226 /// 256-bit vector types for `bf16`
227 pub mod bf16x16;
228 /// 256-bit vector types for `bool`
229 pub mod boolx32;
230 /// 256-bit vector types for `cplx32`
231 pub mod cplx32x4;
232 /// 256-bit vector types for `cplx64`
233 pub mod cplx64x2;
234 /// 256-bit vector types for `f16`
235 pub mod f16x16;
236 /// 256-bit vector types for `f32`
237 pub mod f32x8;
238 /// 256-bit vector types for `f64`
239 pub mod f64x4;
240 /// 256-bit vector types for `i16`
241 pub mod i16x16;
242 /// 256-bit vector types for `i32`
243 pub mod i32x8;
244 /// 256-bit vector types for `i64`
245 pub mod i64x4;
246 /// 256-bit vector types for `i8`
247 pub mod i8x32;
248 /// 256-bit vector types for `isize`
249 pub mod isizex4;
250 /// 256-bit vector types for `u16`
251 pub mod u16x16;
252 /// 256-bit vector types for `u32`
253 pub mod u32x8;
254 /// 256-bit vector types for `u64`
255 pub mod u64x4;
256 /// 256-bit vector types for `u8`
257 pub mod u8x32;
258 /// 256-bit vector types for `usize`
259 pub mod usizex4;
260 }
261 /// A module that defines a set of 512-bit vector types; compiled only when the `avx512f` target feature is enabled.
262 #[cfg(target_feature = "avx512f")]
263 pub mod _512bit {
264 /// 512-bit vector types for `bf16`
265 pub mod bf16x32;
266 /// 512-bit vector types for `bool`
267 pub mod boolx64;
268 /// 512-bit vector types for `cplx32`
269 pub mod cplx32x8;
270 /// 512-bit vector types for `cplx64`
271 pub mod cplx64x4;
272 /// 512-bit vector types for `f16`
273 pub mod f16x32;
274 /// 512-bit vector types for `f32`
275 pub mod f32x16;
276 /// 512-bit vector types for `f64`
277 pub mod f64x8;
278 /// 512-bit vector types for `i16`
279 pub mod i16x32;
280 /// 512-bit vector types for `i32`
281 pub mod i32x16;
282 /// 512-bit vector types for `i64`
283 pub mod i64x8;
284 /// 512-bit vector types for `i8`
285 pub mod i8x64;
286 /// 512-bit vector types for `isize`
287 pub mod isizex8;
288 /// 512-bit vector types for `u16`
289 pub mod u16x32;
290 /// 512-bit vector types for `u32`
291 pub mod u32x16;
292 /// 512-bit vector types for `u64`
293 pub mod u64x8;
294 /// 512-bit vector types for `u8`
295 pub mod u8x64;
296 /// 512-bit vector types for `usize`
297 pub mod usizex8;
298 }
299
300 // This file contains code ported from SLEEF (https://github.com/shibatch/sleef)
301 //
302 // Original work Copyright (c) 2010-2022, Naoki Shibata and contributors
303 // Modified work Copyright (c) 2024 hpt Contributors
304 //
305 // Boost Software License - Version 1.0 - August 17th, 2003
306 //
307 // Permission is hereby granted, free of charge, to any person or organization
308 // obtaining a copy of the software and accompanying documentation covered by
309 // this license (the "Software") to use, reproduce, display, distribute,
310 // execute, and transmit the Software, and to prepare derivative works of the
311 // Software, and to permit third-parties to whom the Software is furnished to
312 // do so, all subject to the following:
313 //
314 // The copyright notices in the Software and this entire statement, including
315 // the above license grant, this restriction and the following disclaimer,
316 // must be included in all copies of the Software, in whole or in part, and
317 // all derivative works of the Software, unless such copies or derivative
318 // works are solely in the form of machine-executable object code generated by
319 // a source language processor.
320 //
321 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
322 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
323 // FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
324 // SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
325 // FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
326 // ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
327 // DEALINGS IN THE SOFTWARE.
328 //
329 // This Rust port is additionally licensed under Apache-2.0 OR MIT
330 // See repository root for details
331 /// A module that contains the code ported from SLEEF.
332 pub mod sleef {
333 /// SLEEF port: the `table` module.
334 pub mod table;
335 /// SLEEF port: target-specific helper modules; which helper is compiled depends on the cfg on each one.
336 pub mod arch {
337 /// Helper for `aarch64` targets with NEON.
338 #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
339 pub mod helper_aarch64;
340 /// Helper for `x86_64` targets with AVX2.
341 #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
342 pub mod helper_avx2;
343 /// Helper for `x86_64` targets with SSE but without AVX2.
344 #[cfg(all(
345 target_arch = "x86_64",
346 target_feature = "sse",
347 not(target_feature = "avx2")
348 ))]
349 pub mod helper_sse;
350 }
351 /// SLEEF port: modules grouped under `common`.
352 pub mod common {
353 /// Part of the common SLEEF code: `commonfuncs`.
354 pub mod commonfuncs;
355 /// Part of the common SLEEF code: `dd` (presumably double-double arithmetic as in upstream SLEEF; confirm).
356 pub mod dd;
357 /// Part of the common SLEEF code: `df` (presumably the single-precision counterpart of `dd`; confirm).
358 pub mod df;
359 /// A module that defines a macro for polynomial approximation.
360 pub mod estrin;
361 /// Part of the common SLEEF code: `misc`.
362 pub mod misc;
363 }
364 /// SLEEF port: the math function modules.
365 pub mod libm {
366 /// A module that defines a set of double precision floating point functions.
367 pub mod sleefsimddp;
368 /// A module that defines a set of single precision floating point functions.
369 pub mod sleefsimdsp;
370 }
371 }
372 }
373 /// A module that defines a set of traits for vectors.
374 pub mod traits;
375 /// A module that defines a set of utilities for vectors.
376 pub mod utils;
377
378 #[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
 // `vector_promote` re-exports the `*_promote` items of exactly one vector width.
 // The three cfg gates must be mutually exclusive: if two of them hold at once the
 // module is defined twice and the build stops with a duplicate-definition error.
 // `avx512f` implies `avx2`, so the 256-bit variant explicitly steps aside on AVX-512
 // targets and leaves them to the 512-bit variant further down.
 //
 // NOTE(review): every variant imports from `arch_simd`, which only exists with the
 // `archsimd` feature; confirm that a `stdsimd`-only build is not expected to compile.
 // NOTE(review): the 32-bit paths (`isizex8`/`usizex8` under `_256bit`, `isizex4`/`usizex4`
 // under `_128bit`, `isizex16`/`usizex16` under `_512bit`) name modules that are not
 // declared in this file; verify them before building for a 32-bit target.
379 pub(crate) mod vector_promote {
380 #[cfg(target_pointer_width = "64")]
381 pub(crate) use crate::vectors::arch_simd::_256bit::isizex4::isize_promote;
382 #[cfg(target_pointer_width = "32")]
383 pub(crate) use crate::vectors::arch_simd::_256bit::isizex8::isize_promote;
384 #[cfg(target_pointer_width = "64")]
385 pub(crate) use crate::vectors::arch_simd::_256bit::usizex4::usize_promote;
386 #[cfg(target_pointer_width = "32")]
387 pub(crate) use crate::vectors::arch_simd::_256bit::usizex8::usize_promote;
388 pub(crate) use crate::vectors::arch_simd::_256bit::{
389 bf16x16::bf16_promote, boolx32::bool_promote, cplx32x4::Complex32_promote,
390 cplx64x2::Complex64_promote, f16x16::f16_promote, f32x8::f32_promote,
391 f64x4::f64_promote, i16x16::i16_promote, i32x8::i32_promote, i64x4::i64_promote,
392 i8x32::i8_promote, u16x16::u16_promote, u32x8::u32_promote, u64x4::u64_promote,
393 u8x32::u8_promote,
394 };
395 }
396 #[cfg(any(
397 all(not(target_feature = "avx2"), target_feature = "sse"),
398 target_arch = "arm",
399 target_arch = "aarch64",
400 target_feature = "neon"
401 ))]
 // 128-bit variant: same gate as the `_128bit` vector modules.
402 pub(crate) mod vector_promote {
403 #[cfg(target_pointer_width = "64")]
404 pub(crate) use crate::vectors::arch_simd::_128bit::isizex2::isize_promote;
405 #[cfg(target_pointer_width = "32")]
406 pub(crate) use crate::vectors::arch_simd::_128bit::isizex4::isize_promote;
407 #[cfg(target_pointer_width = "64")]
408 pub(crate) use crate::vectors::arch_simd::_128bit::usizex2::usize_promote;
409 #[cfg(target_pointer_width = "32")]
410 pub(crate) use crate::vectors::arch_simd::_128bit::usizex4::usize_promote;
411 pub(crate) use crate::vectors::arch_simd::_128bit::{
412 bf16x8::bf16_promote, boolx16::bool_promote, cplx32x2::Complex32_promote,
413 cplx64x1::Complex64_promote, f16x8::f16_promote, f32x4::f32_promote,
414 f64x2::f64_promote, i16x8::i16_promote, i32x4::i32_promote, i64x2::i64_promote,
415 i8x16::i8_promote, u16x8::u16_promote, u32x4::u32_promote, u64x2::u64_promote,
416 u8x16::u8_promote,
417 };
418 }
419 #[cfg(target_feature = "avx512f")]
 // 512-bit variant: the only one selected on AVX-512 targets.
420 pub(crate) mod vector_promote {
421 #[cfg(target_pointer_width = "32")]
422 pub(crate) use crate::vectors::arch_simd::_512bit::isizex16::isize_promote;
423 #[cfg(target_pointer_width = "64")]
424 pub(crate) use crate::vectors::arch_simd::_512bit::isizex8::isize_promote;
425 #[cfg(target_pointer_width = "32")]
426 pub(crate) use crate::vectors::arch_simd::_512bit::usizex16::usize_promote;
427 #[cfg(target_pointer_width = "64")]
428 pub(crate) use crate::vectors::arch_simd::_512bit::usizex8::usize_promote;
429 pub(crate) use crate::vectors::arch_simd::_512bit::{
430 bf16x32::bf16_promote, boolx64::bool_promote, cplx32x8::Complex32_promote,
431 cplx64x4::Complex64_promote, f16x32::f16_promote, f32x16::f32_promote,
432 f64x8::f64_promote, i16x32::i16_promote, i32x16::i32_promote, i64x8::i64_promote,
433 i8x64::i8_promote, u16x32::u16_promote, u32x16::u32_promote, u64x8::u64_promote,
434 u8x64::u8_promote,
435 };
436 }
437}
438
439#[cfg(feature = "cuda")]
440/// A module that defines a set of types for CUDA; compiled only when the `cuda` feature is enabled.
441pub mod cuda_types {
442 /// A module that defines conversion for CUDA types.
443 pub mod convertion;
444 /// A module that defines a scalar type for CUDA.
445 pub mod scalar;
446
447 pub(crate) mod _bf16; // crate-private submodules, one per element type
448 pub(crate) mod _bool;
449 pub(crate) mod _cplx32;
450 pub(crate) mod _cplx64;
451 pub(crate) mod _f16;
452 pub(crate) mod _f32;
453 pub(crate) mod _f64;
454 pub(crate) mod _i16;
455 pub(crate) mod _i32;
456 pub(crate) mod _i64;
457 pub(crate) mod _i8;
458 pub(crate) mod _isize;
459 pub(crate) mod _u16;
460 pub(crate) mod _u32;
461 pub(crate) mod _u64;
462 pub(crate) mod _u8;
463 pub(crate) mod _usize;
464}
465
466pub use vectors::*; // glob re-export: everything public in `vectors` is also reachable from the crate root
467#[cfg(feature = "archsimd")]
468mod simd { // private module that re-exports `arch_simd` when the `archsimd` feature is on
469 pub use crate::vectors::arch_simd::*;
470}
471#[cfg(feature = "stdsimd")] // NOTE(review): with both features enabled `simd` is defined twice; confirm the features are meant to be mutually exclusive
472mod simd { // same module name, re-exporting `std_simd` when the `stdsimd` feature is on
473 pub use crate::vectors::std_simd::*;
474}
475
476#[cfg(all(
477 target_arch = "x86_64",
478 target_feature = "avx2",
479 not(feature = "stdsimd")
480))]
481pub(crate) mod sleef_types { // vector type aliases for x86_64 with AVX2
482 use std::arch::x86_64 as x86; // one named handle on the intrinsics module instead of a glob import
483 pub(crate) type VDouble = x86::__m256d;
484 pub(crate) type VMask = x86::__m256i;
485 pub(crate) type Vopmask = x86::__m256i;
486 pub(crate) type VFloat = x86::__m256;
487 pub(crate) type VInt = x86::__m128i;
488 pub(crate) type VInt2 = x86::__m256i;
489 pub(crate) type VInt64 = x86::__m256i;
490 pub(crate) type VUInt64 = x86::__m256i;
491}
492
493#[cfg(all(
494 target_arch = "x86_64",
495 target_feature = "sse",
496 not(target_feature = "avx2"),
497 not(feature = "stdsimd")
498))]
499pub(crate) mod sleef_types { // vector type aliases for x86_64 with SSE but without AVX2
500 use std::arch::x86_64 as x86; // one named handle on the intrinsics module instead of a glob import
501 pub(crate) type VDouble = x86::__m128d;
502 pub(crate) type VMask = x86::__m128i;
503 pub(crate) type Vopmask = x86::__m128i;
504 pub(crate) type VFloat = x86::__m128;
505 pub(crate) type VInt = x86::__m128i;
506 pub(crate) type VInt2 = x86::__m128i;
507}
508
509#[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
510pub(crate) mod sleef_types { // vector type aliases for aarch64 with NEON
511 use std::arch::aarch64 as neon; // one named handle on the intrinsics module instead of a glob import
512 pub(crate) type VDouble = neon::float64x2_t;
513 pub(crate) type VMask = neon::uint32x4_t;
514 pub(crate) type Vopmask = neon::uint32x4_t;
515 pub(crate) type VFloat = neon::float32x4_t;
516 pub(crate) type VInt = neon::int32x2_t;
517 pub(crate) type VInt2 = neon::int32x4_t;
518}