#![deny(missing_docs)]
//! Memory allocators for the hpt backends: an LRU-cached CPU allocator and,
//! behind the `cuda` feature, a matching CUDA allocator built on `cudarc`.

mod allocators;
mod backend;
mod ptr;
mod storage;
pub(crate) mod utils {
    pub(crate) mod allocate;
    pub(crate) mod cache_resize;
    pub(crate) mod deallocate;
}
pub mod traits;

use std::marker::PhantomData;

use crate::allocators::cpu::CACHE;
#[cfg(feature = "cuda")]
use crate::allocators::cuda::CUDA_CACHE;
pub use crate::storage::clone_storage;
pub use crate::allocators::cpu::resize_cpu_lru_cache;
#[cfg(feature = "cuda")]
pub use crate::allocators::cuda::resize_cuda_lru_cache;
pub use crate::backend::*;
pub use crate::storage::cpu::CPU_STORAGE;
#[cfg(feature = "cuda")]
pub use crate::storage::cuda::CUDA_STORAGE;
use crate::traits::Allocator;
/// Drains the global memory pools (CPU, and CUDA when the `cuda` feature is
/// enabled) when the process shuts down.
#[allow(non_snake_case)]
#[ctor::dtor]
fn free_pools() {
    CACHE.lock().unwrap().clear();
    #[cfg(feature = "cuda")]
    CUDA_CACHE.lock().unwrap().clear();
}

/// Memory allocator for hpt, parameterized over the backend type
/// (`Cpu`, or `Cuda` when the `cuda` feature is enabled).
pub struct HptAllocator<B: BackendTy> {
    /// Zero-sized marker tying this allocator to its backend type.
    phantom: PhantomData<B>,
}

// Manual `Clone` impl so that `B` itself does not need to be `Clone`
// (a derive would add that bound).
impl<B: BackendTy> Clone for HptAllocator<B> {
    fn clone(&self) -> Self {
        HptAllocator {
            phantom: PhantomData,
        }
    }
}

// CPU backend: every request goes through the process-global, mutex-protected
// `CACHE` LRU pool.
impl Allocator for HptAllocator<Cpu> {
    type Output = *mut u8;
    type CpuAllocator = HptAllocator<Cpu>;
    #[cfg(feature = "cuda")]
    type CudaAllocator = HptAllocator<Cuda>;
    fn allocate(
        &mut self,
        layout: std::alloc::Layout,
        device_id: usize,
    ) -> Result<Self::Output, hpt_common::error::base::TensorError> {
        CACHE.lock().unwrap().allocate(layout, device_id)
    }
    fn allocate_zeroed(
        &mut self,
        layout: std::alloc::Layout,
        device_id: usize,
    ) -> Result<Self::Output, hpt_common::error::base::TensorError> {
        CACHE.lock().unwrap().allocate_zeroed(layout, device_id)
    }
    fn deallocate(&mut self, ptr: *mut u8, layout: &std::alloc::Layout, device_id: usize) {
        CACHE.lock().unwrap().deallocate(ptr, layout, device_id);
    }

    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
        CACHE.lock().unwrap().insert_ptr(ptr, device_id);
    }

    fn clear(&mut self) {
        CACHE.lock().unwrap().clear();
    }

    fn new() -> Self {
        HptAllocator {
            phantom: PhantomData,
        }
    }
}

// CUDA backend: mirrors the CPU path but routes through `CUDA_CACHE` and also
// returns the `cudarc` device handle alongside the raw pointer.
#[cfg(feature = "cuda")]
impl Allocator for HptAllocator<Cuda> {
    type Output = (*mut u8, std::sync::Arc<cudarc::driver::CudaDevice>);
    type CpuAllocator = HptAllocator<Cpu>;
    type CudaAllocator = HptAllocator<Cuda>;

    fn allocate(
        &mut self,
        layout: std::alloc::Layout,
        device_id: usize,
    ) -> Result<Self::Output, hpt_common::error::base::TensorError> {
        CUDA_CACHE.lock().unwrap().allocate(layout, device_id)
    }

    fn allocate_zeroed(
        &mut self,
        layout: std::alloc::Layout,
        device_id: usize,
    ) -> Result<Self::Output, hpt_common::error::base::TensorError> {
        CUDA_CACHE
            .lock()
            .unwrap()
            .allocate_zeroed(layout, device_id)
    }

    fn deallocate(&mut self, ptr: *mut u8, layout: &std::alloc::Layout, device_id: usize) {
        CUDA_CACHE
            .lock()
            .unwrap()
            .deallocate(ptr, layout, device_id);
    }

    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
        CUDA_CACHE.lock().unwrap().insert_ptr(ptr, device_id);
    }

    fn clear(&mut self) {
        CUDA_CACHE.lock().unwrap().clear();
    }

    fn new() -> Self {
        HptAllocator {
            phantom: PhantomData,
        }
    }
}

// SAFETY: `HptAllocator` holds only a `PhantomData<B>`; all mutable state it
// touches lives in global, mutex-protected caches.
unsafe impl<B: BackendTy> Send for HptAllocator<B> {}
unsafe impl<B: BackendTy> Sync for HptAllocator<B> {}
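
// A minimal usage sketch of the CPU allocator, added for illustration (not part
// of the original source). It assumes `device_id` 0 denotes the default host
// pool and that `hpt_common`'s `TensorError` implements `Debug` (needed by
// `expect`); the test name is hypothetical.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Allocator;
    use std::alloc::Layout;

    #[test]
    fn cpu_allocator_roundtrip() {
        let mut allocator = HptAllocator::<Cpu>::new();
        // A 1 KiB, 64-byte-aligned request served by the global CPU LRU cache.
        let layout = Layout::from_size_align(1024, 64).expect("valid layout");
        let ptr = allocator
            .allocate(layout, 0)
            .expect("CPU allocation should succeed");
        assert!(!ptr.is_null());
        // Return the block so the cache can recycle it for later requests.
        allocator.deallocate(ptr, &layout, 0);
    }
}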