hpt_allocator/allocators/
cpu.rs1use std::{alloc::Layout, num::NonZeroUsize, panic::Location, sync::Mutex};
2
3use crate::{ptr::SafePtr, storage::cpu::CPU_STORAGE, storage::Storage, traits::Allocator};
4use hashbrown::{HashMap, HashSet};
5use hpt_common::error::base::TensorError;
6use hpt_common::error::memory::MemoryError;
7use lru::LruCache;
8use once_cell::sync::Lazy;
9
10pub static CACHE: Lazy<Mutex<CpuAllocator>> = Lazy::new(|| Mutex::new(CpuAllocator::new()));
12
13pub struct CpuAllocator {
27 allocator: HashMap<usize, _Allocator>,
28}
29
30impl Allocator for CpuAllocator {
31 fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<*mut u8, TensorError> {
32 if let Some(allocator) = self.allocator.get_mut(&device_id) {
33 allocator.allocate(layout, device_id)
34 } else {
35 let mut allocator = _Allocator {
36 cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
37 allocated: HashSet::new(),
38 };
39 let ptr = allocator.allocate(layout, device_id)?;
40 self.allocator.insert(device_id, allocator);
41 Ok(ptr)
42 }
43 }
44 fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, device_id: usize) {
45 if let Some(allocator) = self.allocator.get_mut(&device_id) {
46 allocator.deallocate(ptr, layout);
47 }
48 }
49
50 fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
51 if let Some(allocator) = self.allocator.get_mut(&device_id) {
52 allocator.insert_ptr(ptr);
53 } else {
54 let mut allocator = _Allocator {
55 cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
56 allocated: HashSet::new(),
57 };
58 allocator.insert_ptr(ptr);
59 self.allocator.insert(device_id, allocator);
60 }
61 }
62
63 fn clear(&mut self) {
64 for (_, allocator) in self.allocator.iter_mut() {
65 for (layout, ptrs) in allocator.cache.iter_mut() {
66 for ptr in ptrs.iter() {
67 unsafe {
68 std::alloc::dealloc(ptr.ptr, layout.clone());
69 }
70 }
71 }
72 allocator.cache.clear();
73 assert_eq!(allocator.allocated.len(), 0);
74 }
75 }
76}
77
78impl CpuAllocator {
79 pub fn new() -> Self {
80 CpuAllocator {
81 allocator: HashMap::new(),
82 }
83 }
84}
85
86struct _Allocator {
87 cache: LruCache<Layout, Vec<SafePtr>>,
88 allocated: HashSet<SafePtr>,
89}
90
91impl _Allocator {
92 fn allocate(
105 &mut self,
106 layout: Layout,
107 device_id: usize,
108 ) -> std::result::Result<*mut u8, TensorError> {
109 let ptr = if let Some(ptr) = self.cache.get_mut(&layout)
110 {
112 if let Some(safe_ptr) = ptr.pop() {
114 safe_ptr.ptr
115 } else {
116 let ptr = unsafe { std::alloc::alloc(layout) };
117 if ptr.is_null() {
118 return Err(TensorError::Memory(MemoryError::AllocationFailed {
119 device: "cpu".to_string(),
120 id: device_id,
121 size: layout.size() / 1024 / 1024,
122 source: None,
123 location: Location::caller(),
124 }));
125 }
126 self.allocated.insert(SafePtr { ptr });
127 ptr
128 }
129 } else {
130 let ptr = unsafe { std::alloc::alloc(layout) };
131 if ptr.is_null() {
132 return Err(TensorError::Memory(MemoryError::AllocationFailed {
133 device: "cpu".to_string(),
134 id: device_id,
135 size: layout.size() / 1024 / 1024,
136 source: None,
137 location: Location::caller(),
138 }));
139 }
140 self.allocated.insert(SafePtr { ptr });
141 ptr
142 };
143 if self.cache.cap().get() == self.cache.len() {
145 if let Some((layout, ptrs)) = self.cache.pop_lru() {
146 for safe_ptr in ptrs {
147 unsafe {
148 std::alloc::dealloc(safe_ptr.ptr, layout);
149 }
150 }
151 }
152 }
153 if let Ok(mut storage) = CPU_STORAGE.lock() {
155 storage.increment_ref(SafePtr { ptr });
156 }
157 Ok(ptr)
158 }
159
160 fn deallocate(&mut self, ptr: *mut u8, layout: &Layout) {
166 if let Ok(mut storage) = CPU_STORAGE.lock() {
167 if storage.decrement_ref(SafePtr { ptr }) {
168 self.allocated.remove(&SafePtr { ptr });
169 if let Some(ptrs) = self.cache.get_mut(layout) {
170 ptrs.push(SafePtr { ptr });
171 } else {
172 self.cache.put(layout.clone(), vec![SafePtr { ptr }]);
173 }
174 }
175 }
176 }
177
178 fn insert_ptr(&mut self, ptr: *mut u8) {
184 self.allocated.insert(SafePtr { ptr });
185 if let Ok(mut storage) = CPU_STORAGE.lock() {
186 storage.increment_ref(SafePtr { ptr });
187 }
188 }
189}