hpt_allocator/
traits.rs

1use std::alloc::Layout;
2
3use hpt_common::error::base::TensorError;
4
5/// traits for the allocator
6pub trait Allocator: Clone {
7    /// the output type of the allocator
8    type Output: AllocatorOutputRetrive;
9    /// cpu type of the allocator
10    type CpuAllocator: Allocator;
11    /// cuda type of the allocator
12    #[cfg(feature = "cuda")]
13    type CudaAllocator: Allocator;
14    /// allocate memory by using lru cache strategy
15    ///
16    /// # Logic
17    ///
18    /// 1. check if the layout is found in the cache
19    ///
20    /// 2. if the layout is found in the cache, pop the memory out, if it return None, there is no available cached memory, we need to allocate new memory
21    ///
22    /// 3. if the layout is not found in the cache, allocate new memory
23    ///
24    /// 4. eventually, if the cache is full, pop the least recently used memory and deallocate the memory
25    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError>;
26
27    /// similar to `allocate`, but the memory is zeroed
28    fn allocate_zeroed(
29        &mut self,
30        layout: Layout,
31        device_id: usize,
32    ) -> Result<Self::Output, TensorError>;
33
34    /// deallocate memory by using lru cache strategy
35    ///
36    /// # Logic
37    ///
38    /// 1. check if the ptr is found in the storage
39    ///
40    /// 2. if the ptr is found in the storage, decrement the reference count
41    ///
42    /// 3. if the reference count is 0, remove the ptr from the storage, remove the ptr from the allocated set, and insert the ptr into the cache
43    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, device_id: usize);
44    /// if the ptr is found in the storage, increment the reference count, otherwise insert the ptr into the storage
45    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize);
46    /// clear the cache, deallocate all the memory allocated
47    ///
48    /// this is used when the program exits, it will be called automatically
49    fn clear(&mut self);
50    /// create a new allocator
51    fn new() -> Self;
52}
53
/// Accessors for the value an allocator hands back.
///
/// Every [`Allocator::Output`] type is required to implement this trait.
pub trait AllocatorOutputRetrive {
    /// Returns the raw pointer held by this allocator output.
    fn get_ptr(&self) -> *mut u8;

    /// Returns the CUDA device associated with this allocator output.
    ///
    /// Only declared when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice>;
}
62
63impl AllocatorOutputRetrive for *mut u8 {
64    fn get_ptr(&self) -> *mut u8 {
65        self.clone()
66    }
67    #[cfg(feature = "cuda")]
68    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
69        panic!("cuda is not enabled");
70    }
71}
72
/// Allocator output made of a pointer paired with a CUDA device handle.
#[cfg(feature = "cuda")]
impl AllocatorOutputRetrive for (*mut u8, std::sync::Arc<cudarc::driver::CudaDevice>) {
    /// Returns the pointer stored in the first tuple field (raw pointers are `Copy`).
    fn get_ptr(&self) -> *mut u8 {
        self.0
    }

    /// Returns a new handle to the device stored in the second tuple field.
    ///
    /// This increments the `Arc` reference count; the device itself is not copied.
    // The enclosing impl is already gated on `feature = "cuda"`, so this method
    // needs no `cfg` attribute of its own.
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
        std::sync::Arc::clone(&self.1)
    }
}
/// Conversion from an allocator output into a backend value.
///
/// `T` is the allocator output type that gets consumed by the conversion.
pub trait FromAllocatorOutput<T> {
    /// Builds `Self` by taking ownership of `alloc_output`.
    fn from_allocator_output(alloc_output: T) -> Self;
}