//! Crate root of the tensor library: declares the crate-private backend module tree
//! and re-exports the public API (`Tensor`, `ops`, `types`, `utils`, and related modules).
#![deny(missing_docs)]
3
// Crate-private implementation tree for every backend. Nothing in here is public;
// selected items are re-exported through the public modules further down
// (for example `buitin_templates::cpu`).
//
// NOTE(review): keep declaration order stable — `macro_rules!` macros are textually
// scoped, so reordering `mod` items can change which macros are visible where.
pub(crate) mod backends {
    // CPU backend (compiled unconditionally).
    pub(crate) mod cpu {
        // Helper modules grouped by operation family.
        pub(crate) mod utils {
            pub(crate) mod reduce {
                pub(crate) mod reduce;
                pub(crate) mod reduce_template;
                pub(crate) mod reduce_utils;
            }
            pub(crate) mod diff {
                pub(crate) mod diff_utils;
            }
            pub(crate) mod binary {
                pub(crate) mod binary_normal;
            }
            pub(crate) mod unary {
                pub(crate) mod unary;
            }
        }
        pub(crate) mod std_ops;
        // Compute kernels.
        pub(crate) mod kernels {
            pub(crate) mod argreduce_kernels;
            pub(crate) mod reduce;
            pub(crate) mod softmax;
            pub(crate) mod pooling {
                pub(crate) mod common;
            }
            // NOTE(review): `log_softmax` and `logsoftmax` are both declared here, and a
            // `softmax` module exists both here and directly under `kernels` — confirm
            // these are intentionally distinct and not leftovers.
            pub(crate) mod normalization {
                pub(crate) mod batch_norm;
                pub(crate) mod log_softmax;
                pub(crate) mod logsoftmax;
                pub(crate) mod normalize_utils;
                pub(crate) mod softmax;
            }
            pub(crate) mod conv2d {
                pub(crate) mod batchnorm_conv2d;
                pub(crate) mod conv2d;
                pub(crate) mod conv2d_direct;
                pub(crate) mod conv2d_group;
                pub(crate) mod dwconv2d;
                // One micro-kernel module per element type.
                pub(crate) mod type_kernels {
                    pub(crate) mod bf16_microkernels;
                    pub(crate) mod bool_microkernels;
                    pub(crate) mod complex32_microkernels;
                    pub(crate) mod complex64_microkernels;
                    pub(crate) mod f16_microkernels;
                    pub(crate) mod f32_microkernels;
                    pub(crate) mod f64_microkernels;
                    pub(crate) mod i16_microkernels;
                    pub(crate) mod i32_microkernels;
                    pub(crate) mod i64_microkernels;
                    pub(crate) mod i8_microkernels;
                    pub(crate) mod isize_microkernels;
                    pub(crate) mod u16_microkernels;
                    pub(crate) mod u32_microkernels;
                    pub(crate) mod u64_microkernels;
                    pub(crate) mod u8_microkernels;
                    pub(crate) mod usize_microkernels;
                }
                pub(crate) mod conv2d_img2col;
                pub(crate) mod conv2d_micro_kernels;
                pub(crate) mod conv2d_new_mp;
                pub(crate) mod microkernel_trait;
                pub(crate) mod utils;
            }
            pub(crate) mod matmul {
                pub(crate) mod common;
                pub(crate) mod matmul;
                pub(crate) mod matmul_mixed_precision;
                pub(crate) mod matmul_mp_post;
                pub(crate) mod matmul_post;
                pub(crate) mod microkernel_trait;
                pub(crate) mod microkernels;
                pub(crate) mod utils;
                // One micro-kernel module per element type.
                pub(crate) mod type_kernels {
                    pub(crate) mod bf16_microkernels;
                    pub(crate) mod bool_microkernels;
                    pub(crate) mod complex32_microkernels;
                    pub(crate) mod complex64_microkernels;
                    pub(crate) mod f16_microkernels;
                    pub(crate) mod f32_microkernels;
                    pub(crate) mod f64_microkernels;
                    pub(crate) mod i16_microkernels;
                    pub(crate) mod i32_microkernels;
                    pub(crate) mod i64_microkernels;
                    pub(crate) mod i8_microkernels;
                    pub(crate) mod isize_microkernels;
                    pub(crate) mod u16_microkernels;
                    pub(crate) mod u32_microkernels;
                    pub(crate) mod u64_microkernels;
                    pub(crate) mod u8_microkernels;
                    pub(crate) mod usize_microkernels;
                }
            }
        }
        pub(crate) mod tensor_external {
            pub(crate) mod advance;
            pub(crate) mod arg_reduce;
            pub(crate) mod binary;
            pub(crate) mod cmp;
            pub(crate) mod common_reduce;
            pub(crate) mod conv;
            pub(crate) mod cumulative;
            pub(crate) mod fft;
            pub(crate) mod float_out_binary;
            pub(crate) mod float_out_unary;
            pub(crate) mod gemm;
            pub(crate) mod matmul;
            pub(crate) mod normal_creation;
            pub(crate) mod normal_out_unary;
            pub(crate) mod normalization;
            pub(crate) mod pooling;
            pub(crate) mod random;
            pub(crate) mod regularization;
            pub(crate) mod shape_manipulate;
            pub(crate) mod slice;
            pub(crate) mod tensordot;
            pub(crate) mod windows;
        }
        pub(crate) mod tensor_internal {
            pub(crate) mod advance;
            pub(crate) mod arg_reduce;
            pub(crate) mod cmp;
            pub(crate) mod common_reduce;
            pub(crate) mod conv;
            pub(crate) mod cumulative;
            pub(crate) mod fft;
            pub(crate) mod float_out_binary;
            pub(crate) mod float_out_unary;
            pub(crate) mod gemm;
            pub(crate) mod matmul;
            pub(crate) mod normal_creation;
            pub(crate) mod normal_out_unary;
            pub(crate) mod normalization;
            pub(crate) mod pooling;
            pub(crate) mod random;
            pub(crate) mod regularization;
            pub(crate) mod shape_manipulate;
            pub(crate) mod tensordot;
            pub(crate) mod windows;
        }
        pub(crate) mod tensor_impls;
    }

    // CUDA backend; only compiled when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    pub(crate) mod cuda {
        pub(crate) mod tensor_impls;
        pub(crate) mod tensor_internal {
            pub(crate) mod advance;
            pub(crate) mod arg_reduce;
            pub(crate) mod common_reduce;
            pub(crate) mod conv2d;
            pub(crate) mod float_out_binary;
            pub(crate) mod float_out_unary;
            pub(crate) mod layernorm;
            pub(crate) mod matmul;
            pub(crate) mod normal_creation;
            pub(crate) mod normal_out_unary;
            pub(crate) mod normalization;
            pub(crate) mod pooling;
            pub(crate) mod shape_manipulate;
            pub(crate) mod softmax;
            pub(crate) mod windows;
        }
        pub(crate) mod tensor_external {
            pub(crate) mod arg_reduce;
            pub(crate) mod binary;
            pub(crate) mod cmp;
            pub(crate) mod common_reduce;
            pub(crate) mod conv2d;
            pub(crate) mod float_out_binary;
            pub(crate) mod float_out_unary;
            pub(crate) mod gemm;
            pub(crate) mod matmul;
            pub(crate) mod normal_creation;
            pub(crate) mod normal_out_unary;
            pub(crate) mod normalization;
            pub(crate) mod random;
            pub(crate) mod shape_manipulate;
            pub(crate) mod windows;
        }
        // Helper modules grouped by operation family, plus launch-configuration support.
        pub(crate) mod utils {
            pub(crate) mod reduce {
                pub(crate) mod reduce;
                pub(crate) mod reduce_template;
                pub(crate) mod reduce_utils;
            }
            pub(crate) mod binary {
                pub(crate) mod binary_normal;
            }
            pub(crate) mod unary {
                pub(crate) mod unary;
            }
            pub(crate) mod launch_cfg {
                pub(crate) mod launch_cfg_trait;
            }
        }
        pub(crate) mod cuda_slice;
        pub(crate) mod cuda_utils;
        pub(crate) mod std_ops;
    }

    // Modules that live outside the `cpu`/`cuda` subtrees.
    pub(crate) mod common {
        pub(crate) mod conv;
        pub(crate) mod creation;
        pub(crate) mod divmod;
        pub(crate) mod reduce;
        pub(crate) mod shape_manipulate;
        pub(crate) mod slice;
    }
}
// Core tensor modules; `tensor::Tensor` is re-exported at the crate root.
pub(crate) mod tensor;
pub(crate) mod tensor_base;
pub(crate) mod to_tensor;
#[cfg(feature = "cuda")]
pub(crate) mod cuda_compiled {
    //! Process-wide, lazily created cache of compiled CUDA kernel register information.

    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use hpt_cudakernels::RegisterInfo;
    use once_cell::sync::Lazy;

    /// `RegisterInfo` entries keyed by a name (presumably the kernel name — confirm).
    type RegisterTable = HashMap<String, RegisterInfo>;

    /// Shared register tables keyed by a name (presumably the compiled module — confirm).
    type RegisterTables = HashMap<String, Arc<RegisterTable>>;

    /// Mutex-guarded cache, keyed first by a `usize` (presumably the CUDA device
    /// ordinal — confirm at call sites). It is created empty on first access.
    pub(crate) static CUDA_COMPILED: Lazy<Mutex<HashMap<usize, RegisterTables>>> =
        Lazy::new(|| Mutex::new(HashMap::new()));
}
324pub mod ops {
326 pub use hpt_traits::ops::advance::*;
327 pub use hpt_traits::ops::binary::*;
328 pub use hpt_traits::ops::cmp::*;
329 pub use hpt_traits::ops::conv::*;
330 pub use hpt_traits::ops::creation::*;
331 pub use hpt_traits::ops::cumulative::*;
332 pub use hpt_traits::ops::fft::*;
333 pub use hpt_traits::ops::normalization::*;
334 pub use hpt_traits::ops::pooling::*;
335 pub use hpt_traits::ops::random::*;
336 pub use hpt_traits::ops::reduce::*;
337 pub use hpt_traits::ops::regularization::*;
338 pub use hpt_traits::ops::shape_manipulate::*;
339 pub use hpt_traits::ops::slice::*;
340 pub use hpt_traits::ops::unary::*;
341 pub use hpt_traits::ops::windows::*;
342}
343
344pub mod error {
346 pub use hpt_common::error::base::TensorError;
347}
348
349pub mod common {
351 pub use hpt_common::{shape::shape::Shape, strides::strides::Strides, Pointer};
352 pub use hpt_traits::tensor::{CommonBounds, TensorInfo};
353 pub mod cpu {
355 pub use hpt_traits::tensor::TensorLike;
356 }
357}
358
359pub mod alloc {
361 pub use hpt_allocator::traits::{Allocator, AllocatorOutputRetrive};
362}
363
364pub mod iter {
366 pub use hpt_iterator::iterator_traits::*;
367 pub use hpt_iterator::TensorIterator;
368 pub use rayon;
369}
370
371pub mod types {
373 pub use half::{bf16, f16};
374 pub use num::complex::{Complex32, Complex64};
375 pub mod vectors {
377 pub use hpt_types::vectors::*;
378 pub mod traits {
380 pub use hpt_types::traits::VecTrait;
381 }
382 }
383 pub mod cast {
385 pub use hpt_types::into_scalar::Cast;
386 pub use hpt_types::into_vec::IntoVec;
387 }
388 pub mod math {
390 pub use hpt_types::type_promote::{
391 BitWiseOut, Eval, FloatOutBinary, FloatOutBinaryPromote, FloatOutUnary,
392 FloatOutUnaryPromote, NormalOut, NormalOutPromote, NormalOutUnary,
393 };
394 }
395 pub use hpt_types::dtype::TypeCommon;
397}
398
399pub mod re_exports {
401 #[cfg(feature = "cuda")]
402 pub use cudarc;
403 pub use seq_macro;
404 pub use serde;
405}
406
407pub use hpt_dataloader::{Load, Save};
408pub use hpt_macros::{Load, Save};
409
410pub mod save_load {
412 pub use flate2;
413 pub use hpt_dataloader::data_loader::parse_header_compressed;
414 pub use hpt_dataloader::{
415 save, CompressionAlgo, DataLoader, Endian, FromSafeTensors, MetaLoad, TensorLoader,
416 TensorSaver,
417 };
418}
419
420pub mod backend {
422 pub use hpt_allocator::Cpu;
423 #[cfg(feature = "cuda")]
424 pub use hpt_allocator::Cuda;
425
426 pub use hpt_allocator::{BackendTy, Buffer};
427}
428
429pub mod buitin_templates {
431 pub mod cpu {
433 pub use crate::backends::cpu::kernels::matmul::matmul::matmul_template;
434 pub use crate::backends::cpu::utils::binary::binary_normal::binary_with_out;
435 }
436}
437
438pub mod utils {
440 #[cfg(feature = "cuda")]
441 use crate::CUDA_SEED;
442 use crate::{DISPLAY_LR_ELEMENTS, DISPLAY_PRECISION, THREAD_POOL};
443 pub use hpt_allocator::resize_cpu_lru_cache;
444 #[cfg(feature = "cuda")]
445 pub use hpt_allocator::resize_cuda_lru_cache;
446 pub use hpt_macros::select;
447
448 pub fn get_num_threads() -> usize {
450 THREAD_POOL.with(|x| x.borrow().max_count())
451 }
452 pub fn set_display_precision(precision: usize) {
454 DISPLAY_PRECISION.store(precision, std::sync::atomic::Ordering::Relaxed);
455 }
456 pub fn set_display_elements(lr_elements: usize) {
458 DISPLAY_LR_ELEMENTS.store(lr_elements, std::sync::atomic::Ordering::Relaxed);
459 }
460 #[allow(unused)]
461 pub fn set_seed<B: crate::backend::BackendTy>(seed: u64) {
463 match B::ID {
464 0 => {
465 panic!("CPU backend does not support setting seed");
466 }
467 #[cfg(feature = "cuda")]
468 1 => {
469 CUDA_SEED.store(seed, std::sync::atomic::Ordering::Relaxed);
470 }
471 _ => {
472 panic!("Unsupported backend {:?}", B::ID);
473 }
474 }
475 }
476 pub fn set_num_threads(num_threads: usize) {
481 THREAD_POOL.with(|x| {
482 x.borrow_mut().set_num_threads(num_threads);
483 });
484 match rayon::ThreadPoolBuilder::new()
485 .num_threads(num_threads)
486 .stack_size(4 * 1024 * 1024)
487 .build_global()
488 {
489 Ok(_) => {}
490 Err(_) => {}
491 }
492 }
493}
494
495use ctor::ctor;
496use hpt_types::arch_simd as simd;
497use std::{cell::RefCell, sync::atomic::AtomicUsize};
498pub use tensor::Tensor;
499
500#[ctor]
501fn init() {
502 THREAD_POOL.with(|x| {
503 x.borrow_mut().set_num_threads(num_cpus::get_physical());
504 });
505}
506
507thread_local! {
508 static THREAD_POOL: RefCell<threadpool::ThreadPool> = RefCell::new(
509 threadpool::ThreadPool::new(num_cpus::get_physical())
510 );
511}
/// Display precision, initially 4. Written by `utils::set_display_precision`.
static DISPLAY_PRECISION: AtomicUsize = AtomicUsize::new(4);
/// Display element count, initially 3. Written by `utils::set_display_elements`.
static DISPLAY_LR_ELEMENTS: AtomicUsize = AtomicUsize::new(3);
/// Alignment value (64); presumably in bytes — confirm at use sites.
///
/// This is a `const` rather than a `static` because it is never mutated and needs no
/// fixed address, and a `const` can also be used in constant expressions.
const ALIGN: usize = 64;
515
// Register count chosen per target feature. Presumably this is the number of SIMD
// vector registers assumed when sizing kernels — confirm at use sites.
//
// The predicates must be mutually exclusive. Enabling `avx512f` also enables `avx2`
// (e.g. `-C target-cpu=native` on AVX-512 hardware), so the AVX2 arm excludes
// `avx512f`. Otherwise both arms would define `REGNUM` and compilation fails (E0428).
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
pub(crate) const REGNUM: usize = 16;
#[cfg(all(
    target_feature = "sse",
    not(target_feature = "avx2"),
    not(target_feature = "avx512f")
))]
pub(crate) const REGNUM: usize = 8;
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
pub(crate) const REGNUM: usize = 32;
522
523#[cfg(target_feature = "avx2")]
524type BoolVector = simd::_256bit::boolx32;
525#[cfg(any(
526 all(not(target_feature = "avx2"), target_feature = "sse"),
527 target_feature = "neon"
528))]
529type BoolVector = simd::_128bit::boolx16;
530
/// Seed for the CUDA backend; written by `utils::set_seed` when `B::ID == 1`.
///
/// This must be a `static`. A `const` item is instantiated as a fresh temporary at
/// every use, so `CUDA_SEED.store(..)` on a `const` atomic would write to a throwaway
/// copy, and the seed could never be changed.
#[cfg(feature = "cuda")]
static CUDA_SEED: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(2621654116416541);
533
#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map of shared cuDNN handles, keyed by a `usize` (presumably the
    /// device ordinal — confirm at call sites). Starts empty.
    static CUDNN: RefCell<
        std::collections::HashMap<usize, std::sync::Arc<cudarc::cudnn::Cudnn>>,
    > = RefCell::new(std::collections::HashMap::new());
}
540
#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map of shared cuBLAS handles, keyed by a `usize` (presumably the
    /// device ordinal — confirm at call sites). Starts empty.
    static CUBLAS: RefCell<
        std::collections::HashMap<usize, std::sync::Arc<cudarc::cublas::CudaBlas>>,
    > = RefCell::new(std::collections::HashMap::new());
}