hpt_allocator/allocators/
cpu.rs

1use std::{
2    alloc::Layout,
3    collections::{HashMap, HashSet},
4    num::NonZeroUsize,
5    sync::{
6        atomic::{AtomicUsize, Ordering},
7        Mutex,
8    },
9};
10
11use crate::{
12    ptr::SafePtr,
13    storage::{cpu::CPU_STORAGE, CommonStorage, Storage},
14    traits::Allocator,
15};
16use hpt_common::error::base::TensorError;
17use lru::LruCache;
18use once_cell::sync::Lazy;
19
/// Process-wide CPU allocator singleton. Guarded by a `Mutex` so it can be
/// shared across threads; initialized lazily on first access via `Lazy`.
pub(crate) static CACHE: Lazy<Mutex<CpuAllocator>> = Lazy::new(|| Mutex::new(CpuAllocator::new()));

/// Default capacity used when a per-device LRU cache is created (initially
/// 100); updated by `resize_cpu_lru_cache` and read with relaxed ordering.
pub(crate) static CPU_LRU_CACHE_SIZE: AtomicUsize = AtomicUsize::new(100);
24
/// # Allocator
///
/// An LRU-cache-backed CPU memory allocator: freed blocks are kept in a
/// per-device cache keyed by `Layout` so subsequent allocations of the same
/// layout can reuse them instead of hitting the system allocator.
///
/// # Safety
///
/// Thread safe when accessed through the global `CACHE` mutex.
///
/// # Potential Memory Leak
///
/// The developer must carefully manage the reference count of each pointer
/// allocated; a pointer whose count never reaches zero is never reclaimed.
#[derive(Clone)]
pub struct CpuAllocator {
    // One `_Allocator` (LRU cache + live-pointer set) per device id.
    allocator: HashMap<usize, _Allocator>,
}
42
43impl Allocator for CpuAllocator {
44    type Output = *mut u8;
45    type CpuAllocator = CpuAllocator;
46    #[cfg(feature = "cuda")]
47    type CudaAllocator = crate::allocators::cuda::CudaAllocator;
48    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError> {
49        if let Some(allocator) = self.allocator.get_mut(&device_id) {
50            allocator.allocate(layout, device_id)
51        } else {
52            let mut allocator = _Allocator {
53                cache: LruCache::new(
54                    NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
55                ),
56                allocated: HashSet::new(),
57            };
58            let ptr = allocator.allocate(layout, device_id)?;
59            self.allocator.insert(device_id, allocator);
60            Ok(ptr)
61        }
62    }
63    fn allocate_zeroed(
64        &mut self,
65        layout: Layout,
66        device_id: usize,
67    ) -> Result<Self::Output, TensorError> {
68        if let Some(allocator) = self.allocator.get_mut(&device_id) {
69            allocator.allocate_zeroed(layout, device_id)
70        } else {
71            let mut allocator = _Allocator {
72                cache: LruCache::new(
73                    NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
74                ),
75                allocated: HashSet::new(),
76            };
77            let ptr = allocator.allocate_zeroed(layout, device_id)?;
78            self.allocator.insert(device_id, allocator);
79            Ok(ptr)
80        }
81    }
82
83    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
84        if let Some(allocator) = self.allocator.get_mut(&device_id) {
85            allocator.deallocate(ptr, layout, should_drop, device_id);
86        } else {
87            if !should_drop {
88                return;
89            }
90            panic!("device {} not found in allocator", device_id);
91        }
92    }
93    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
94        if let Some(allocator) = self.allocator.get_mut(&device_id) {
95            allocator.insert_ptr(ptr, device_id);
96        } else {
97            let mut allocator = _Allocator {
98                cache: LruCache::new(
99                    NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
100                ),
101                allocated: HashSet::new(),
102            };
103            allocator.insert_ptr(ptr, device_id);
104            self.allocator.insert(device_id, allocator);
105        }
106    }
107    fn clear(&mut self) {
108        if let Ok(mut storage) = CPU_STORAGE.lock() {
109            let storage: &mut HashMap<usize, CommonStorage> = &mut storage;
110            for (device_id, allocator) in self.allocator.iter_mut() {
111                for (layout, ptrs) in allocator.cache.iter_mut() {
112                    for ptr in ptrs.iter() {
113                        storage
114                            .get_mut(device_id)
115                            .unwrap()
116                            .decrement_ref(SafePtr { ptr: ptr.ptr });
117                        unsafe {
118                            std::alloc::dealloc(ptr.ptr, layout.clone());
119                        }
120                    }
121                }
122                allocator.cache.clear();
123                assert_eq!(allocator.allocated.len(), 0);
124                assert_eq!(storage[device_id].storage.len(), 0);
125            }
126        }
127    }
128
129    fn new() -> Self {
130        CpuAllocator {
131            allocator: HashMap::new(),
132        }
133    }
134
135    /// # Forget
136    ///
137    /// forget the ptr from the storage, remove the ptr from the allocated set
138    fn forget(&mut self, ptr: *mut u8, device_id: usize) {
139        if let Some(allocator) = self.allocator.get_mut(&device_id) {
140            if allocator.allocated.get(&SafePtr { ptr }).is_some() {
141                allocator.allocated.remove(&SafePtr { ptr });
142            }
143        }
144    }
145}
146
// Per-device allocator state: an LRU cache of reusable freed blocks plus the
// set of pointers currently handed out.
#[derive(Clone)]
struct _Allocator {
    // Freed blocks grouped by their `Layout`, evicted in LRU order.
    cache: LruCache<Layout, Vec<SafePtr>>,
    // Pointers currently allocated (live) through this allocator.
    allocated: HashSet<SafePtr>,
}
152
153impl _Allocator {
154    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<*mut u8, TensorError> {
155        if let Ok(mut storage) = CPU_STORAGE.lock() {
156            crate::utils::allocate::allocate_helper(
157                &mut self.cache,
158                &mut self.allocated,
159                &mut storage,
160                || unsafe { std::alloc::alloc(layout) },
161                |_, _| {},
162                |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
163                layout,
164                device_id,
165            )
166        } else {
167            panic!("Failed to lock CPU_STORAGE");
168        }
169    }
170
171    fn allocate_zeroed(
172        &mut self,
173        layout: Layout,
174        device_id: usize,
175    ) -> Result<*mut u8, TensorError> {
176        if let Ok(mut storage) = CPU_STORAGE.lock() {
177            crate::utils::allocate::allocate_helper(
178                &mut self.cache,
179                &mut self.allocated,
180                &mut storage,
181                || unsafe { std::alloc::alloc_zeroed(layout) },
182                |ptr, layout| unsafe {
183                    std::ptr::write_bytes(ptr, 0, layout.size());
184                },
185                |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
186                layout,
187                device_id,
188            )
189        } else {
190            panic!("Failed to lock CPU_STORAGE");
191        }
192    }
193
194    /// # Main Deallocation Function
195    ///
196    /// deallocate memory based on the ptr provided, if the ptr is found in the storage, decrement the reference count
197    ///
198    /// if the reference count is 0, remove the ptr from the storage, remove the ptr from the allocated set, and insert the ptr into the cache
199    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
200        if let Ok(mut storage) = CPU_STORAGE.lock() {
201            crate::utils::deallocate::deallocate_helper(
202                &mut self.cache,
203                &mut self.allocated,
204                &mut storage,
205                layout,
206                ptr,
207                should_drop,
208                device_id,
209            );
210        } else {
211            panic!("Failed to lock CPU_STORAGE");
212        }
213    }
214
215    /// # Insert Pointer
216    ///
217    /// insert the ptr into the allocated set, and increment the reference count in the storage
218    ///
219    /// this function is used to insert the ptr into the allocated set, and increment the reference count in the storage
220    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
221        self.allocated.insert(SafePtr { ptr });
222        if let Ok(mut map) = CPU_STORAGE.lock() {
223            if let Some(storage) = map.get_mut(&device_id) {
224                storage.increment_ref(SafePtr { ptr });
225            } else {
226                let mut storage = CommonStorage::new();
227                storage.increment_ref(SafePtr { ptr });
228                map.insert(device_id, storage);
229            }
230        }
231    }
232}
233
234/// resize the lru cache of the cpu allocator
235///
236/// when `new_size` >= `old_size`, cache size will increase and data won't be deallocated
237///
238/// when `new_size` < `old_size`, all the data in cache will be deallocated
239pub fn resize_cpu_lru_cache(new_size: usize, device_id: usize) {
240    if let Ok(mut cache) = CACHE.lock() {
241        if let Some(allocator) = cache.allocator.get_mut(&device_id) {
242            crate::utils::cache_resize::resize_lru_cache(
243                &mut allocator.cache,
244                |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
245                new_size,
246            );
247        } else {
248            let allocator = _Allocator {
249                cache: LruCache::new(NonZeroUsize::new(new_size).unwrap()),
250                allocated: HashSet::new(),
251            };
252            cache.allocator.insert(device_id, allocator);
253        }
254    } else {
255        panic!("Failed to lock CACHE");
256    }
257    CPU_LRU_CACHE_SIZE.store(new_size, Ordering::Relaxed);
258}