// hpt_allocator/allocators/cpu.rs
use std::{
2 alloc::Layout,
3 collections::{HashMap, HashSet},
4 num::NonZeroUsize,
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Mutex,
8 },
9};
10
11use crate::{
12 ptr::SafePtr,
13 storage::{cpu::CPU_STORAGE, CommonStorage, Storage},
14 traits::Allocator,
15};
16use hpt_common::error::base::TensorError;
17use lru::LruCache;
18use once_cell::sync::Lazy;
19
/// Process-wide CPU allocator singleton, shared behind a mutex.
pub(crate) static CACHE: Lazy<Mutex<CpuAllocator>> = Lazy::new(|| Mutex::new(CpuAllocator::new()));

/// Capacity (in layout entries) used when a new per-device LRU cache is
/// created; updated by `resize_cpu_lru_cache`.
pub(crate) static CPU_LRU_CACHE_SIZE: AtomicUsize = AtomicUsize::new(100);
24
/// CPU memory allocator that maintains one LRU-cached sub-allocator per
/// device id.
#[derive(Clone)]
pub struct CpuAllocator {
    // Per-device allocator state, keyed by device id.
    allocator: HashMap<usize, _Allocator>,
}
42
43impl Allocator for CpuAllocator {
44 type Output = *mut u8;
45 type CpuAllocator = CpuAllocator;
46 #[cfg(feature = "cuda")]
47 type CudaAllocator = crate::allocators::cuda::CudaAllocator;
48 fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError> {
49 if let Some(allocator) = self.allocator.get_mut(&device_id) {
50 allocator.allocate(layout, device_id)
51 } else {
52 let mut allocator = _Allocator {
53 cache: LruCache::new(
54 NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
55 ),
56 allocated: HashSet::new(),
57 };
58 let ptr = allocator.allocate(layout, device_id)?;
59 self.allocator.insert(device_id, allocator);
60 Ok(ptr)
61 }
62 }
63 fn allocate_zeroed(
64 &mut self,
65 layout: Layout,
66 device_id: usize,
67 ) -> Result<Self::Output, TensorError> {
68 if let Some(allocator) = self.allocator.get_mut(&device_id) {
69 allocator.allocate_zeroed(layout, device_id)
70 } else {
71 let mut allocator = _Allocator {
72 cache: LruCache::new(
73 NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
74 ),
75 allocated: HashSet::new(),
76 };
77 let ptr = allocator.allocate_zeroed(layout, device_id)?;
78 self.allocator.insert(device_id, allocator);
79 Ok(ptr)
80 }
81 }
82
83 fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
84 if let Some(allocator) = self.allocator.get_mut(&device_id) {
85 allocator.deallocate(ptr, layout, should_drop, device_id);
86 } else {
87 if !should_drop {
88 return;
89 }
90 panic!("device {} not found in allocator", device_id);
91 }
92 }
93 fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
94 if let Some(allocator) = self.allocator.get_mut(&device_id) {
95 allocator.insert_ptr(ptr, device_id);
96 } else {
97 let mut allocator = _Allocator {
98 cache: LruCache::new(
99 NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
100 ),
101 allocated: HashSet::new(),
102 };
103 allocator.insert_ptr(ptr, device_id);
104 self.allocator.insert(device_id, allocator);
105 }
106 }
107 fn clear(&mut self) {
108 if let Ok(mut storage) = CPU_STORAGE.lock() {
109 let storage: &mut HashMap<usize, CommonStorage> = &mut storage;
110 for (device_id, allocator) in self.allocator.iter_mut() {
111 for (layout, ptrs) in allocator.cache.iter_mut() {
112 for ptr in ptrs.iter() {
113 storage
114 .get_mut(device_id)
115 .unwrap()
116 .decrement_ref(SafePtr { ptr: ptr.ptr });
117 unsafe {
118 std::alloc::dealloc(ptr.ptr, layout.clone());
119 }
120 }
121 }
122 allocator.cache.clear();
123 assert_eq!(allocator.allocated.len(), 0);
124 assert_eq!(storage[device_id].storage.len(), 0);
125 }
126 }
127 }
128
129 fn new() -> Self {
130 CpuAllocator {
131 allocator: HashMap::new(),
132 }
133 }
134
135 fn forget(&mut self, ptr: *mut u8, device_id: usize) {
139 if let Some(allocator) = self.allocator.get_mut(&device_id) {
140 if allocator.allocated.get(&SafePtr { ptr }).is_some() {
141 allocator.allocated.remove(&SafePtr { ptr });
142 }
143 }
144 }
145}
146
/// Per-device allocator state.
#[derive(Clone)]
struct _Allocator {
    // Freed blocks kept for reuse, grouped by their `Layout`; entries
    // evicted from the LRU are returned to the system allocator.
    cache: LruCache<Layout, Vec<SafePtr>>,
    // Pointers currently handed out (live allocations) for this device.
    allocated: HashSet<SafePtr>,
}
152
153impl _Allocator {
154 fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<*mut u8, TensorError> {
155 if let Ok(mut storage) = CPU_STORAGE.lock() {
156 crate::utils::allocate::allocate_helper(
157 &mut self.cache,
158 &mut self.allocated,
159 &mut storage,
160 || unsafe { std::alloc::alloc(layout) },
161 |_, _| {},
162 |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
163 layout,
164 device_id,
165 )
166 } else {
167 panic!("Failed to lock CPU_STORAGE");
168 }
169 }
170
171 fn allocate_zeroed(
172 &mut self,
173 layout: Layout,
174 device_id: usize,
175 ) -> Result<*mut u8, TensorError> {
176 if let Ok(mut storage) = CPU_STORAGE.lock() {
177 crate::utils::allocate::allocate_helper(
178 &mut self.cache,
179 &mut self.allocated,
180 &mut storage,
181 || unsafe { std::alloc::alloc_zeroed(layout) },
182 |ptr, layout| unsafe {
183 std::ptr::write_bytes(ptr, 0, layout.size());
184 },
185 |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
186 layout,
187 device_id,
188 )
189 } else {
190 panic!("Failed to lock CPU_STORAGE");
191 }
192 }
193
194 fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
200 if let Ok(mut storage) = CPU_STORAGE.lock() {
201 crate::utils::deallocate::deallocate_helper(
202 &mut self.cache,
203 &mut self.allocated,
204 &mut storage,
205 layout,
206 ptr,
207 should_drop,
208 device_id,
209 );
210 } else {
211 panic!("Failed to lock CPU_STORAGE");
212 }
213 }
214
215 fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
221 self.allocated.insert(SafePtr { ptr });
222 if let Ok(mut map) = CPU_STORAGE.lock() {
223 if let Some(storage) = map.get_mut(&device_id) {
224 storage.increment_ref(SafePtr { ptr });
225 } else {
226 let mut storage = CommonStorage::new();
227 storage.increment_ref(SafePtr { ptr });
228 map.insert(device_id, storage);
229 }
230 }
231 }
232}
233
234pub fn resize_cpu_lru_cache(new_size: usize, device_id: usize) {
240 if let Ok(mut cache) = CACHE.lock() {
241 if let Some(allocator) = cache.allocator.get_mut(&device_id) {
242 crate::utils::cache_resize::resize_lru_cache(
243 &mut allocator.cache,
244 |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
245 new_size,
246 );
247 } else {
248 let allocator = _Allocator {
249 cache: LruCache::new(NonZeroUsize::new(new_size).unwrap()),
250 allocated: HashSet::new(),
251 };
252 cache.allocator.insert(device_id, allocator);
253 }
254 } else {
255 panic!("Failed to lock CACHE");
256 }
257 CPU_LRU_CACHE_SIZE.store(new_size, Ordering::Relaxed);
258}