hpt/
lib.rs

//! This crate is a dynamic-graph-based tensor library.
#![deny(missing_docs)]
3
/// Backend implementations of all tensor operations: the CPU backend, the CUDA backend
/// (behind the `cuda` feature), and helpers shared between them.
pub(crate) mod backends {
    /// CPU implementations of the tensor operations.
    pub(crate) mod cpu {
        // CPU helper modules, grouped by operation family.
        pub(crate) mod utils {
            pub(crate) mod reduce {
                pub(crate) mod reduce;
                pub(crate) mod reduce_template;
                pub(crate) mod reduce_utils;
            }
            pub(crate) mod diff {
                pub(crate) mod diff_utils;
            }
            pub(crate) mod binary {
                pub(crate) mod binary_normal;
            }
            pub(crate) mod unary {
                pub(crate) mod unary;
            }
        }
        /// Implementations of the `std::ops` operator traits.
        pub(crate) mod std_ops;
        /// CPU kernels.
        pub(crate) mod kernels {
            pub(crate) mod argreduce_kernels;
            pub(crate) mod reduce;
            pub(crate) mod softmax;
            pub(crate) mod pooling {
                pub(crate) mod common;
            }
            pub(crate) mod normalization {
                pub(crate) mod batch_norm;
                // NOTE(review): both `log_softmax` and `logsoftmax` are declared here (and
                // `softmax` exists both here and in `kernels`); confirm both are still needed.
                pub(crate) mod log_softmax;
                pub(crate) mod logsoftmax;
                pub(crate) mod normalize_utils;
                pub(crate) mod softmax;
            }
            // 2D convolution kernels.
            pub(crate) mod conv2d {
                pub(crate) mod batchnorm_conv2d;
                pub(crate) mod conv2d;
                pub(crate) mod conv2d_direct;
                pub(crate) mod conv2d_group;
                pub(crate) mod dwconv2d;
                // One microkernel module per supported element type.
                pub(crate) mod type_kernels {
                    pub(crate) mod bf16_microkernels;
                    pub(crate) mod bool_microkernels;
                    pub(crate) mod complex32_microkernels;
                    pub(crate) mod complex64_microkernels;
                    pub(crate) mod f16_microkernels;
                    pub(crate) mod f32_microkernels;
                    pub(crate) mod f64_microkernels;
                    pub(crate) mod i16_microkernels;
                    pub(crate) mod i32_microkernels;
                    pub(crate) mod i64_microkernels;
                    pub(crate) mod i8_microkernels;
                    pub(crate) mod isize_microkernels;
                    pub(crate) mod u16_microkernels;
                    pub(crate) mod u32_microkernels;
                    pub(crate) mod u64_microkernels;
                    pub(crate) mod u8_microkernels;
                    pub(crate) mod usize_microkernels;
                }
                pub(crate) mod conv2d_img2col;
                pub(crate) mod conv2d_micro_kernels;
                pub(crate) mod conv2d_new_mp;
                pub(crate) mod microkernel_trait;
                pub(crate) mod utils;
            }
            /// GEMM (matrix multiplication) for the CPU.
            pub(crate) mod matmul {
                pub(crate) mod common;
                pub(crate) mod matmul;
                pub(crate) mod matmul_mixed_precision;
                pub(crate) mod matmul_mp_post;
                pub(crate) mod matmul_post;
                pub(crate) mod microkernel_trait;
                pub(crate) mod microkernels;
                pub(crate) mod utils;
                // One microkernel module per supported element type.
                pub(crate) mod type_kernels {
                    pub(crate) mod bf16_microkernels;
                    pub(crate) mod bool_microkernels;
                    pub(crate) mod complex32_microkernels;
                    pub(crate) mod complex64_microkernels;
                    pub(crate) mod f16_microkernels;
                    pub(crate) mod f32_microkernels;
                    pub(crate) mod f64_microkernels;
                    pub(crate) mod i16_microkernels;
                    pub(crate) mod i32_microkernels;
                    pub(crate) mod i64_microkernels;
                    pub(crate) mod i8_microkernels;
                    pub(crate) mod isize_microkernels;
                    pub(crate) mod u16_microkernels;
                    pub(crate) mod u32_microkernels;
                    pub(crate) mod u64_microkernels;
                    pub(crate) mod u8_microkernels;
                    pub(crate) mod usize_microkernels;
                }
            }
        }
        /// Functions exposed to external users (a differentiable tensor may be added in the future).
        pub(crate) mod tensor_external {
            /// Advanced operations.
            pub(crate) mod advance;
            /// Arg-reduce functions.
            pub(crate) mod arg_reduce;
            /// Normal binary operations.
            pub(crate) mod binary;
            /// Tensor comparison functions.
            pub(crate) mod cmp;
            /// Common reduce functions.
            pub(crate) mod common_reduce;
            /// Convolution functions.
            pub(crate) mod conv;
            /// Cumulative operations.
            pub(crate) mod cumulative;
            /// FFT operations.
            pub(crate) mod fft;
            /// Binary operations with floating-point output.
            pub(crate) mod float_out_binary;
            /// Unary operations with floating-point output.
            pub(crate) mod float_out_unary;
            /// GEMM functions.
            pub(crate) mod gemm;
            /// Matrix multiplication operations.
            pub(crate) mod matmul;
            /// Normal tensor-creation methods.
            pub(crate) mod normal_creation;
            /// Unary operations whose output has the input's own type.
            pub(crate) mod normal_out_unary;
            /// Normalization functions.
            pub(crate) mod normalization;
            /// Pooling functions.
            pub(crate) mod pooling;
            /// Random number generation functions.
            pub(crate) mod random;
            /// Regularization functions.
            pub(crate) mod regularization;
            /// Shape manipulation functions.
            pub(crate) mod shape_manipulate;
            /// Slice functions.
            pub(crate) mod slice;
            /// Tensordot functions.
            pub(crate) mod tensordot;
            /// Window-creation functions.
            pub(crate) mod windows;
        }
        /// Functions only for internal users (a differentiable tensor may be added in the future).
        pub(crate) mod tensor_internal {
            /// Advanced operations.
            pub(crate) mod advance;
            /// Arg-reduce functions.
            pub(crate) mod arg_reduce;
            /// Tensor comparison functions.
            pub(crate) mod cmp;
            /// Common reduce functions.
            pub(crate) mod common_reduce;
            /// Convolution functions.
            pub(crate) mod conv;
            /// Cumulative operations.
            pub(crate) mod cumulative;
            /// FFT operations.
            pub(crate) mod fft;
            /// Binary operations with floating-point output.
            pub(crate) mod float_out_binary;
            /// Unary operations with floating-point output.
            pub(crate) mod float_out_unary;
            /// GEMM functions.
            pub(crate) mod gemm;
            /// Matrix multiplication operations.
            pub(crate) mod matmul;
            /// Normal tensor-creation methods.
            pub(crate) mod normal_creation;
            /// Unary operations whose output has the input's own type.
            pub(crate) mod normal_out_unary;
            /// Normalization functions.
            pub(crate) mod normalization;
            /// Pooling functions.
            pub(crate) mod pooling;
            /// Random number generation functions.
            pub(crate) mod random;
            /// Regularization functions.
            pub(crate) mod regularization;
            /// Shape manipulation functions.
            pub(crate) mod shape_manipulate;
            /// Tensordot functions.
            pub(crate) mod tensordot;
            /// Window-creation functions.
            pub(crate) mod windows;
        }
        /// CPU tensor implementations.
        pub(crate) mod tensor_impls;
    }

    #[cfg(feature = "cuda")]
    /// CUDA implementations of the tensor operations (requires the `cuda` feature).
    pub(crate) mod cuda {
        /// CUDA tensor implementations.
        pub(crate) mod tensor_impls;
        /// CUDA implementations only for internal users.
        pub(crate) mod tensor_internal {
            /// Advanced operations.
            pub(crate) mod advance;
            /// Arg-reduce operations.
            pub(crate) mod arg_reduce;
            /// Common reduce operations.
            pub(crate) mod common_reduce;
            /// Convolution operations.
            pub(crate) mod conv2d;
            /// Binary operations with floating-point output.
            pub(crate) mod float_out_binary;
            /// Unary operations with floating-point output.
            pub(crate) mod float_out_unary;
            /// Layer-norm operations.
            pub(crate) mod layernorm;
            /// Matrix multiplication.
            pub(crate) mod matmul;
            /// Normal tensor-creation methods.
            pub(crate) mod normal_creation;
            /// Unary operations whose output has the input's own type.
            pub(crate) mod normal_out_unary;
            /// Normalization operations.
            pub(crate) mod normalization;
            /// Pooling operations.
            pub(crate) mod pooling;
            /// Shape manipulation operations.
            pub(crate) mod shape_manipulate;
            /// Softmax operations.
            pub(crate) mod softmax;
            /// Window-creation operations.
            pub(crate) mod windows;
        }
        // CUDA functions exposed to external users.
        pub(crate) mod tensor_external {
            /// Arg-reduce operations.
            pub(crate) mod arg_reduce;
            /// In-place binary operations.
            pub(crate) mod binary;
            /// Tensor comparison operations.
            pub(crate) mod cmp;
            /// Common reduce operations.
            pub(crate) mod common_reduce;
            /// 2D convolution operations.
            pub(crate) mod conv2d;
            /// Binary operations with floating-point output.
            pub(crate) mod float_out_binary;
            /// Unary operations with floating-point output.
            pub(crate) mod float_out_unary;
            /// GEMM operations.
            pub(crate) mod gemm;
            /// Matrix multiplication.
            pub(crate) mod matmul;
            /// Normal tensor-creation methods.
            pub(crate) mod normal_creation;
            /// Unary operations whose output has the input's own type.
            pub(crate) mod normal_out_unary;
            /// Normalization operations.
            pub(crate) mod normalization;
            /// Random number generation.
            pub(crate) mod random;
            /// Shape manipulation operations.
            pub(crate) mod shape_manipulate;
            /// Window-creation operations.
            pub(crate) mod windows;
        }
        // CUDA helper modules, grouped by operation family.
        pub(crate) mod utils {
            pub(crate) mod reduce {
                pub(crate) mod reduce;
                pub(crate) mod reduce_template;
                pub(crate) mod reduce_utils;
            }
            pub(crate) mod binary {
                pub(crate) mod binary_normal;
            }
            pub(crate) mod unary {
                pub(crate) mod unary;
            }
            pub(crate) mod launch_cfg {
                pub(crate) mod launch_cfg_trait;
            }
        }
        /// CUDA slice implementation.
        pub(crate) mod cuda_slice;
        /// CUDA utilities.
        pub(crate) mod cuda_utils;
        /// Implementations of the `std::ops` operator traits for CUDA tensors.
        pub(crate) mod std_ops;
    }

    /// Operations shared by all backends.
    pub(crate) mod common {
        /// Convolution utilities.
        pub(crate) mod conv;
        /// Helpers for creating tensors.
        pub(crate) mod creation;
        /// Fast div/mod operations.
        pub(crate) mod divmod;
        /// Reduce utilities.
        pub(crate) mod reduce;
        /// Shape manipulation operations.
        pub(crate) mod shape_manipulate;
        /// Slice operation.
        pub(crate) mod slice;
    }
}
// Crate-internal tensor modules; `Tensor` is re-exported at the crate root via
// `pub use tensor::Tensor`.
pub(crate) mod tensor;
pub(crate) mod tensor_base;
pub(crate) mod to_tensor;
#[cfg(feature = "cuda")]
pub(crate) mod cuda_compiled {
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex};

    use hpt_cudakernels::RegisterInfo;
    use once_cell::sync::Lazy;

    /// `RegisterInfo` entries keyed by a `String` (presumably a kernel name — TODO confirm).
    type RegisterInfoTable = HashMap<String, RegisterInfo>;

    /// Process-wide, lazily created registry guarded by a `Mutex`.
    ///
    /// Outer key is a `usize` (presumably a device ordinal — TODO confirm); the inner map is
    /// keyed by `String` and holds shared `RegisterInfoTable`s.
    pub(crate) static CUDA_COMPILED: Lazy<
        Mutex<HashMap<usize, HashMap<String, Arc<RegisterInfoTable>>>>,
    > = Lazy::new(|| Mutex::new(HashMap::new()));
}
324/// this module contains all the operators for the Tensor
325pub mod ops {
326    pub use hpt_traits::ops::advance::*;
327    pub use hpt_traits::ops::binary::*;
328    pub use hpt_traits::ops::cmp::*;
329    pub use hpt_traits::ops::conv::*;
330    pub use hpt_traits::ops::creation::*;
331    pub use hpt_traits::ops::cumulative::*;
332    pub use hpt_traits::ops::fft::*;
333    pub use hpt_traits::ops::normalization::*;
334    pub use hpt_traits::ops::pooling::*;
335    pub use hpt_traits::ops::random::*;
336    pub use hpt_traits::ops::reduce::*;
337    pub use hpt_traits::ops::regularization::*;
338    pub use hpt_traits::ops::shape_manipulate::*;
339    pub use hpt_traits::ops::slice::*;
340    pub use hpt_traits::ops::unary::*;
341    pub use hpt_traits::ops::windows::*;
342}
343
344/// module for error handling
345pub mod error {
346    pub use hpt_common::error::base::TensorError;
347}
348
349/// module for common utils like shape and strides
350pub mod common {
351    pub use hpt_common::{shape::shape::Shape, strides::strides::Strides, Pointer};
352    pub use hpt_traits::tensor::{CommonBounds, TensorInfo};
353    /// common utils for cpu
354    pub mod cpu {
355        pub use hpt_traits::tensor::TensorLike;
356    }
357}
358
359/// module for memory allocation
360pub mod alloc {
361    pub use hpt_allocator::traits::{Allocator, AllocatorOutputRetrive};
362}
363
364/// module for tensor iterator
365pub mod iter {
366    pub use hpt_iterator::iterator_traits::*;
367    pub use hpt_iterator::TensorIterator;
368    pub use rayon;
369}
370
371/// type related module
372pub mod types {
373    pub use half::{bf16, f16};
374    pub use num::complex::{Complex32, Complex64};
375    /// module contains vector types and traits
376    pub mod vectors {
377        pub use hpt_types::vectors::*;
378        /// module contains vector traits
379        pub mod traits {
380            pub use hpt_types::traits::VecTrait;
381        }
382    }
383    /// module contains cast traits, perform type conversion
384    pub mod cast {
385        pub use hpt_types::into_scalar::Cast;
386        pub use hpt_types::into_vec::IntoVec;
387    }
388    /// module contains math traits for scalar and vector, all the methods will auto promote the type
389    pub mod math {
390        pub use hpt_types::type_promote::{
391            BitWiseOut, Eval, FloatOutBinary, FloatOutBinaryPromote, FloatOutUnary,
392            FloatOutUnaryPromote, NormalOut, NormalOutPromote, NormalOutUnary,
393        };
394    }
395    /// module contains type common traits
396    pub use hpt_types::dtype::TypeCommon;
397}
398
399/// reexport serde
400pub mod re_exports {
401    #[cfg(feature = "cuda")]
402    pub use cudarc;
403    pub use seq_macro;
404    pub use serde;
405}
406
407pub use hpt_dataloader::{Load, Save};
408pub use hpt_macros::{Load, Save};
409
410/// module for save and load
411pub mod save_load {
412    pub use flate2;
413    pub use hpt_dataloader::data_loader::parse_header_compressed;
414    pub use hpt_dataloader::{
415        save, CompressionAlgo, DataLoader, Endian, FromSafeTensors, MetaLoad, TensorLoader,
416        TensorSaver,
417    };
418}
419
420/// module for backend
421pub mod backend {
422    pub use hpt_allocator::Cpu;
423    #[cfg(feature = "cuda")]
424    pub use hpt_allocator::Cuda;
425
426    pub use hpt_allocator::{BackendTy, Buffer};
427}
428
429/// module for buitin templates
430pub mod buitin_templates {
431    /// module for cpu buitin templates
432    pub mod cpu {
433        pub use crate::backends::cpu::kernels::matmul::matmul::matmul_template;
434        pub use crate::backends::cpu::utils::binary::binary_normal::binary_with_out;
435    }
436}
437
438/// module for utils, like set_num_threads, set_seed, etc.
439pub mod utils {
440    #[cfg(feature = "cuda")]
441    use crate::CUDA_SEED;
442    use crate::{DISPLAY_LR_ELEMENTS, DISPLAY_PRECISION, THREAD_POOL};
443    pub use hpt_allocator::resize_cpu_lru_cache;
444    #[cfg(feature = "cuda")]
445    pub use hpt_allocator::resize_cuda_lru_cache;
446    pub use hpt_macros::select;
447
448    /// Get the global number of threads
449    pub fn get_num_threads() -> usize {
450        THREAD_POOL.with(|x| x.borrow().max_count())
451    }
452    /// Set the Tensor display precision
453    pub fn set_display_precision(precision: usize) {
454        DISPLAY_PRECISION.store(precision, std::sync::atomic::Ordering::Relaxed);
455    }
456    /// Set the left and right elements to display for each dimension
457    pub fn set_display_elements(lr_elements: usize) {
458        DISPLAY_LR_ELEMENTS.store(lr_elements, std::sync::atomic::Ordering::Relaxed);
459    }
460    #[allow(unused)]
461    /// Set the seed for random number generation
462    pub fn set_seed<B: crate::backend::BackendTy>(seed: u64) {
463        match B::ID {
464            0 => {
465                panic!("CPU backend does not support setting seed");
466            }
467            #[cfg(feature = "cuda")]
468            1 => {
469                CUDA_SEED.store(seed, std::sync::atomic::Ordering::Relaxed);
470            }
471            _ => {
472                panic!("Unsupported backend {:?}", B::ID);
473            }
474        }
475    }
476    /// Set the global number of threads
477    ///
478    /// # Note
479    /// Rayon only allows the number of threads to be set once, so the rayon thread pool won't have any effect if it's called more than once.
480    pub fn set_num_threads(num_threads: usize) {
481        THREAD_POOL.with(|x| {
482            x.borrow_mut().set_num_threads(num_threads);
483        });
484        match rayon::ThreadPoolBuilder::new()
485            .num_threads(num_threads)
486            .stack_size(4 * 1024 * 1024)
487            .build_global()
488        {
489            Ok(_) => {}
490            Err(_) => {}
491        }
492    }
493}
494
495use ctor::ctor;
496use hpt_types::arch_simd as simd;
497use std::{cell::RefCell, sync::atomic::AtomicUsize};
498pub use tensor::Tensor;
499
500#[ctor]
501fn init() {
502    THREAD_POOL.with(|x| {
503        x.borrow_mut().set_num_threads(num_cpus::get_physical());
504    });
505}
506
507thread_local! {
508    static THREAD_POOL: RefCell<threadpool::ThreadPool> = RefCell::new(
509        threadpool::ThreadPool::new(num_cpus::get_physical())
510    );
511}
/// Alignment constant (64); presumably a byte alignment for buffers — TODO confirm at use sites.
static ALIGN: usize = 64;
/// Precision used when displaying tensors (default 4); updated by `utils::set_display_precision`.
static DISPLAY_PRECISION: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(4);
/// Leading/trailing elements shown per dimension when displaying tensors (default 3);
/// updated by `utils::set_display_elements`.
static DISPLAY_LR_ELEMENTS: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(3);
515
// Per-target register budget, presumably the number of SIMD vector registers available
// to the micro-kernels — TODO confirm.
//
// The arms must be mutually exclusive: enabling `avx512f` also enables `avx2`, so an
// unguarded AVX2 arm would define `REGNUM` a second time on AVX-512 targets and the crate
// would fail to compile.
#[cfg(any(target_feature = "avx512f", target_arch = "aarch64"))]
pub(crate) const REGNUM: usize = 32;
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512f")))]
pub(crate) const REGNUM: usize = 16;
#[cfg(all(
    target_feature = "sse",
    not(target_feature = "avx2"),
    not(target_feature = "avx512f")
))]
pub(crate) const REGNUM: usize = 8;
522
523#[cfg(target_feature = "avx2")]
524type BoolVector = simd::_256bit::boolx32;
525#[cfg(any(
526    all(not(target_feature = "avx2"), target_feature = "sse"),
527    target_feature = "neon"
528))]
529type BoolVector = simd::_128bit::boolx16;
530
/// Seed for CUDA random number generation; updated by `utils::set_seed`.
///
/// This must be a `static`: a `const` atomic is instantiated afresh at every use site,
/// so `CUDA_SEED.store(..)` would write to a temporary and the new seed would be lost.
#[cfg(feature = "cuda")]
static CUDA_SEED: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(2621654116416541);
533
#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map from a `usize` key (presumably a device ordinal — TODO confirm) to a
    /// shared cuDNN handle.
    static CUDNN: RefCell<
        std::collections::HashMap<usize, std::sync::Arc<cudarc::cudnn::Cudnn>>,
    > = RefCell::new(std::collections::HashMap::new());
}
540
#[cfg(feature = "cuda")]
thread_local! {
    /// Per-thread map from a `usize` key (presumably a device ordinal — TODO confirm) to a
    /// shared cuBLAS handle.
    static CUBLAS: RefCell<
        std::collections::HashMap<usize, std::sync::Arc<cudarc::cublas::CudaBlas>>,
    > = RefCell::new(std::collections::HashMap::new());
}