hpt_allocator/allocators/cpu.rs

use std::{alloc::Layout, num::NonZeroUsize, panic::Location, sync::Mutex};

use crate::{ptr::SafePtr, storage::cpu::CPU_STORAGE, storage::Storage, traits::Allocator};
use hashbrown::{HashMap, HashSet};
use hpt_common::error::base::TensorError;
use hpt_common::error::memory::MemoryError;
use lru::LruCache;
use once_cell::sync::Lazy;

/// Global `lru`-cache based CPU allocator instance, shared across the library.
pub static CACHE: Lazy<Mutex<CpuAllocator>> = Lazy::new(|| Mutex::new(CpuAllocator::new()));

/// # Allocator
///
/// An `LRU`-based allocator used to allocate and deallocate CPU memory,
/// reusing recently freed blocks instead of returning them to the system
/// immediately.
///
/// This allocator is used widely in the library; one internal sub-allocator
/// is kept per device id.
///
/// # Safety
///
/// Thread safe: the global [`CACHE`] instance is guarded by a `Mutex`.
///
/// # Potential Memory Leak
///
/// The developer must carefully manage the reference count of every pointer
/// allocated; a pointer whose count never reaches zero is never returned to
/// the cache.
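///
/// # Example
///
/// A minimal usage sketch through the global [`CACHE`] instance (marked
/// `ignore` since it mutates global allocator state; the `0` device id and
/// the 1 KiB layout are illustrative only):
///
/// ```ignore
/// use std::alloc::Layout;
/// use crate::traits::Allocator;
///
/// let layout = Layout::from_size_align(1024, 64).unwrap();
/// let mut cache = CACHE.lock().unwrap();
/// // allocate 1 KiB (aligned to 64 bytes) on device 0
/// let ptr = cache.allocate(layout, 0).expect("allocation failed");
/// // ... use the buffer ...
/// // hand the block back; it is cached for reuse once its ref count hits 0
/// cache.deallocate(ptr, &layout, 0);
/// ```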
pub struct CpuAllocator {
    allocator: HashMap<usize, _Allocator>,
}

impl Allocator for CpuAllocator {
    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<*mut u8, TensorError> {
        if let Some(allocator) = self.allocator.get_mut(&device_id) {
            allocator.allocate(layout, device_id)
        } else {
            // lazily create the per-device sub-allocator on first use
            let mut allocator = _Allocator {
                cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
                allocated: HashSet::new(),
            };
            let ptr = allocator.allocate(layout, device_id)?;
            self.allocator.insert(device_id, allocator);
            Ok(ptr)
        }
    }

    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, device_id: usize) {
        if let Some(allocator) = self.allocator.get_mut(&device_id) {
            allocator.deallocate(ptr, layout);
        }
    }

    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
        if let Some(allocator) = self.allocator.get_mut(&device_id) {
            allocator.insert_ptr(ptr);
        } else {
            // lazily create the per-device sub-allocator on first use
            let mut allocator = _Allocator {
                cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
                allocated: HashSet::new(),
            };
            allocator.insert_ptr(ptr);
            self.allocator.insert(device_id, allocator);
        }
    }

    fn clear(&mut self) {
        // free every cached block on every device, then check that no
        // allocation handed out by the allocator is still outstanding
        for allocator in self.allocator.values_mut() {
            for (layout, ptrs) in allocator.cache.iter_mut() {
                for ptr in ptrs.iter() {
                    unsafe {
                        std::alloc::dealloc(ptr.ptr, *layout);
                    }
                }
            }
            allocator.cache.clear();
            assert_eq!(allocator.allocated.len(), 0);
        }
    }
}

impl CpuAllocator {
    /// Creates an empty `CpuAllocator` with no per-device sub-allocators.
    pub fn new() -> Self {
        CpuAllocator {
            allocator: HashMap::new(),
        }
    }
}

struct _Allocator {
    /// freed blocks cached for reuse, keyed by their `Layout`
    cache: LruCache<Layout, Vec<SafePtr>>,
    /// pointers currently handed out by this allocator
    allocated: HashSet<SafePtr>,
}

impl _Allocator {
    /// # Main Allocation Function
    ///
    /// Allocating and freeing memory is expensive, so an `LRU` (least
    /// recently used) cache is used to reuse previously freed blocks.
    ///
    /// Memory is allocated according to the `layout` provided: if a cached
    /// block with the same layout is available, it is popped from the cache;
    /// otherwise fresh memory is requested from the system allocator.
    ///
    /// If the cache is full, the least recently used layout is evicted and
    /// its cached blocks are deallocated. Finally, the reference count of
    /// the returned pointer is incremented in the global storage.
    ///
    /// # Safety
    ///
    /// This function checks for a `null` pointer internally, so memory
    /// allocated through it does not need a `null` check downstream.
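    ///
    /// # Example
    ///
    /// An illustrative sketch of the cache-hit path (not compiled; it assumes
    /// the block's reference count reaches zero on the first `deallocate`):
    ///
    /// ```ignore
    /// let mut alloc = _Allocator {
    ///     cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
    ///     allocated: HashSet::new(),
    /// };
    /// let layout = Layout::from_size_align(4096, 64).unwrap();
    /// let p1 = alloc.allocate(layout, 0).expect("allocation failed");
    /// alloc.deallocate(p1, &layout);  // ref count drops to 0, block is cached
    /// let p2 = alloc.allocate(layout, 0).expect("allocation failed");
    /// assert_eq!(p1, p2);             // same layout: the cached block is reused
    /// ```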
    fn allocate(
        &mut self,
        layout: Layout,
        device_id: usize,
    ) -> std::result::Result<*mut u8, TensorError> {
        // check whether we previously allocated memory with the same layout
        let ptr = if let Some(ptr) = self.cache.get_mut(&layout) {
            // try to pop a cached block; if this returns `None`, no cached
            // memory is available and new memory has to be allocated
            if let Some(safe_ptr) = ptr.pop() {
                // the reused block is live again, so track it the same way
                // as a freshly allocated pointer
                self.allocated.insert(SafePtr { ptr: safe_ptr.ptr });
                safe_ptr.ptr
            } else {
                let ptr = unsafe { std::alloc::alloc(layout) };
                if ptr.is_null() {
                    return Err(TensorError::Memory(MemoryError::AllocationFailed {
                        device: "cpu".to_string(),
                        id: device_id,
                        size: layout.size() / 1024 / 1024,
                        source: None,
                        location: Location::caller(),
                    }));
                }
                self.allocated.insert(SafePtr { ptr });
                ptr
            }
        } else {
            let ptr = unsafe { std::alloc::alloc(layout) };
            if ptr.is_null() {
                return Err(TensorError::Memory(MemoryError::AllocationFailed {
                    device: "cpu".to_string(),
                    id: device_id,
                    size: layout.size() / 1024 / 1024,
                    source: None,
                    location: Location::caller(),
                }));
            }
            self.allocated.insert(SafePtr { ptr });
            ptr
        };
        // if the cache is full, evict the least recently used layout and
        // deallocate its cached blocks
        if self.cache.cap().get() == self.cache.len() {
            if let Some((layout, ptrs)) = self.cache.pop_lru() {
                for safe_ptr in ptrs {
                    unsafe {
                        std::alloc::dealloc(safe_ptr.ptr, layout);
                    }
                }
            }
        }
        // record the returned pointer in the global storage by incrementing
        // its reference count
        if let Ok(mut storage) = CPU_STORAGE.lock() {
            storage.increment_ref(SafePtr { ptr });
        }
        Ok(ptr)
    }

    /// # Main Deallocation Function
    ///
    /// Deallocates memory based on the `ptr` provided: if the pointer is
    /// found in the global storage, its reference count is decremented.
    ///
    /// When the count reaches zero, the pointer is removed from the storage
    /// and from the `allocated` set, and the block is inserted into the
    /// cache (keyed by its `Layout`) for later reuse.
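    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled; `alloc` is an `_Allocator` built as in
    /// the `allocate` example above). The `layout` passed here must match the
    /// one used for the allocation, since cached blocks are keyed, and later
    /// freed, by that `Layout`:
    ///
    /// ```ignore
    /// let layout = Layout::from_size_align(1024, 64).unwrap();
    /// let ptr = alloc.allocate(layout, 0).expect("allocation failed");
    /// // ... use the buffer ...
    /// alloc.deallocate(ptr, &layout); // same layout as the allocation
    /// ```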
    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout) {
        if let Ok(mut storage) = CPU_STORAGE.lock() {
            if storage.decrement_ref(SafePtr { ptr }) {
                // the last reference is gone: cache the block for reuse
                self.allocated.remove(&SafePtr { ptr });
                if let Some(ptrs) = self.cache.get_mut(layout) {
                    ptrs.push(SafePtr { ptr });
                } else {
                    self.cache.put(*layout, vec![SafePtr { ptr }]);
                }
            }
        }
    }

    /// # Insert Pointer
    ///
    /// Registers an externally provided pointer with this allocator: the
    /// pointer is inserted into the `allocated` set and its reference count
    /// in the global storage is incremented.
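    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled; `alloc` is an `_Allocator` built as in
    /// the `allocate` example above) showing how memory obtained outside the
    /// allocator can be registered so it participates in the same reference
    /// counting and caching:
    ///
    /// ```ignore
    /// let layout = Layout::from_size_align(256, 8).unwrap();
    /// let ptr = unsafe { std::alloc::alloc(layout) };
    /// assert!(!ptr.is_null());
    /// alloc.insert_ptr(ptr);          // track the external block
    /// alloc.deallocate(ptr, &layout); // later: ref count hits 0, block enters the cache
    /// ```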
    fn insert_ptr(&mut self, ptr: *mut u8) {
        self.allocated.insert(SafePtr { ptr });
        if let Ok(mut storage) = CPU_STORAGE.lock() {
            storage.increment_ref(SafePtr { ptr });
        }
    }
}