hpt_allocator/traits.rs

use std::alloc::Layout;

use hpt_common::error::base::TensorError;

/// trait for the allocator
pub trait Allocator: Clone {
    /// the output type of the allocator
    type Output: AllocatorOutputRetrive;
    /// the CPU allocator type
    type CpuAllocator: Allocator;
    /// the CUDA allocator type
    #[cfg(feature = "cuda")]
    type CudaAllocator: Allocator;
    /// allocate memory by using the LRU cache strategy; see the example below for a usage sketch
    ///
    /// # Logic
    ///
    /// 1. check if the layout is found in the cache
    ///
    /// 2. if the layout is found in the cache, pop a block out; if the pop returns `None`, there is no available cached memory and new memory must be allocated
    ///
    /// 3. if the layout is not found in the cache, allocate new memory
    ///
    /// 4. finally, if the cache is full, evict the least recently used memory and deallocate it
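    ///
    /// # Example
    ///
    /// a minimal usage sketch; `CpuAllocator` stands in for any concrete type
    /// implementing this trait, so the snippet is illustrative rather than a doctest
    ///
    /// ```ignore
    /// use std::alloc::Layout;
    ///
    /// let mut allocator = CpuAllocator::new();
    /// let layout = Layout::from_size_align(1024, 64).expect("invalid layout");
    /// // allocate 1 KiB on device 0; a cached block is reused when one matches the layout
    /// let output = allocator.allocate(layout, 0)?;
    /// let ptr = output.get_ptr();
    /// ```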
    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError>;

    /// similar to `allocate`, but the memory is zeroed
    fn allocate_zeroed(
        &mut self,
        layout: Layout,
        device_id: usize,
    ) -> Result<Self::Output, TensorError>;

    /// deallocate memory by using the LRU cache strategy; see the sketch below
    ///
    /// # Logic
    ///
    /// 1. check if the ptr is found in the storage
    ///
    /// 2. if the ptr is found in the storage, decrement the reference count
    ///
    /// 3. if the reference count reaches 0, remove the ptr from the storage, remove the ptr from the allocated set, and insert the ptr into the cache
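    ///
    /// # Example
    ///
    /// an illustrative sketch of the bookkeeping described above; the `storage`,
    /// `allocated` and `cache` structures are hypothetical and only mirror the steps
    ///
    /// ```ignore
    /// if let Some(ref_count) = storage.get_mut(&ptr) {
    ///     // step 2: the ptr is tracked, so decrement its reference count
    ///     *ref_count -= 1;
    ///     if *ref_count == 0 {
    ///         // step 3: last reference gone, move the block into the cache for reuse
    ///         storage.remove(&ptr);
    ///         allocated.remove(&ptr);
    ///         cache.entry(*layout).or_default().push(ptr);
    ///     }
    /// }
    /// ```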
    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize);
    /// if the ptr is found in the storage, increment the reference count, otherwise insert the ptr into the storage
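    ///
    /// # Example
    ///
    /// a minimal sketch of shared ownership; the `should_drop = true` arguments and the
    /// idea that the block is recycled on the last release are assumptions for illustration
    ///
    /// ```ignore
    /// let output = allocator.allocate(layout, 0)?;
    /// let ptr = output.get_ptr();
    /// // a second owner of the same buffer registers itself, bumping the reference count
    /// allocator.insert_ptr(ptr, 0);
    /// // each owner releases the buffer; the memory is only recycled after the last call
    /// allocator.deallocate(ptr, &layout, true, 0);
    /// allocator.deallocate(ptr, &layout, true, 0);
    /// ```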
    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize);
    /// clear the cache and deallocate all the memory allocated
    ///
    /// this is used when the program exits; it will be called automatically
    fn clear(&mut self);

    /// forget the pointer: remove it from the allocator's bookkeeping without deallocating the memory
    fn forget(&mut self, ptr: *mut u8, device_id: usize);

    /// create a new allocator
    fn new() -> Self;
}

/// trait for retrieving data from the allocator output
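///
/// # Example
///
/// a sketch of a generic helper that only needs the raw pointer; the helper name is
/// illustrative and not part of the crate
///
/// ```ignore
/// fn raw_alloc<A: Allocator>(alloc: &mut A, layout: Layout) -> Result<*mut u8, TensorError> {
///     // `A::Output` implements `AllocatorOutputRetrive`, so the pointer can be extracted
///     let out = alloc.allocate(layout, 0)?;
///     Ok(out.get_ptr())
/// }
/// ```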
pub trait AllocatorOutputRetrive {
    /// get the pointer from the allocator output
    fn get_ptr(&self) -> *mut u8;
    /// get the device from the allocator output
    #[cfg(feature = "cuda")]
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice>;
}

impl AllocatorOutputRetrive for *mut u8 {
    fn get_ptr(&self) -> *mut u8 {
        *self
    }
    #[cfg(feature = "cuda")]
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
        panic!("a CPU allocator output has no associated CUDA device");
    }
}

#[cfg(feature = "cuda")]
impl AllocatorOutputRetrive for (*mut u8, std::sync::Arc<cudarc::driver::CudaDevice>) {
    fn get_ptr(&self) -> *mut u8 {
        self.0
    }
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
        self.1.clone()
    }
}

/// trait for converting the allocator output into a backend-specific type
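///
/// # Example
///
/// a sketch of a hypothetical CPU backend buffer adopting the raw pointer returned by
/// the allocator; `CpuBuffer` is illustrative and not part of the crate
///
/// ```ignore
/// struct CpuBuffer {
///     ptr: *mut u8,
/// }
///
/// impl FromAllocatorOutput<*mut u8> for CpuBuffer {
///     fn from_allocator_output(alloc_output: *mut u8) -> Self {
///         Self { ptr: alloc_output }
///     }
/// }
/// ```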
pub trait FromAllocatorOutput<T> {
    /// convert the allocator output into the backend type
    fn from_allocator_output(alloc_output: T) -> Self;
}