hpt_allocator/traits.rs
1use std::alloc::Layout;
2
3use hpt_common::error::base::TensorError;
4
5/// traits for the allocator
6pub trait Allocator: Clone {
7 /// the output type of the allocator
8 type Output: AllocatorOutputRetrive;
9 /// cpu type of the allocator
10 type CpuAllocator: Allocator;
11 /// cuda type of the allocator
12 #[cfg(feature = "cuda")]
13 type CudaAllocator: Allocator;
14 /// allocate memory by using lru cache strategy
15 ///
16 /// # Logic
17 ///
18 /// 1. check if the layout is found in the cache
19 ///
20 /// 2. if the layout is found in the cache, pop the memory out, if it return None, there is no available cached memory, we need to allocate new memory
21 ///
22 /// 3. if the layout is not found in the cache, allocate new memory
23 ///
24 /// 4. eventually, if the cache is full, pop the least recently used memory and deallocate the memory
25 fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError>;
26
27 /// similar to `allocate`, but the memory is zeroed
28 fn allocate_zeroed(
29 &mut self,
30 layout: Layout,
31 device_id: usize,
32 ) -> Result<Self::Output, TensorError>;
33
34 /// deallocate memory by using lru cache strategy
35 ///
36 /// # Logic
37 ///
38 /// 1. check if the ptr is found in the storage
39 ///
40 /// 2. if the ptr is found in the storage, decrement the reference count
41 ///
42 /// 3. if the reference count is 0, remove the ptr from the storage, remove the ptr from the allocated set, and insert the ptr into the cache
43 fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize);
44 /// if the ptr is found in the storage, increment the reference count, otherwise insert the ptr into the storage
45 fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize);
46 /// clear the cache, deallocate all the memory allocated
47 ///
48 /// this is used when the program exits, it will be called automatically
49 fn clear(&mut self);
50
51 /// forget the data in the allocator
52 fn forget(&mut self, ptr: *mut u8, device_id: usize);
53
54 /// create a new allocator
55 fn new() -> Self;
56}
57
/// Accessors for pulling the individual pieces out of an allocator's output value.
pub trait AllocatorOutputRetrive {
    /// Returns the raw pointer carried by the allocator output.
    fn get_ptr(&self) -> *mut u8;

    /// Returns the device carried by the allocator output.
    ///
    /// Only available when the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice>;
}
66
67impl AllocatorOutputRetrive for *mut u8 {
68 fn get_ptr(&self) -> *mut u8 {
69 self.clone()
70 }
71 #[cfg(feature = "cuda")]
72 fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
73 panic!("cuda is not enabled");
74 }
75}
76
/// A `(pointer, device)` pair used as an allocator output when the `cuda` feature
/// is enabled.
#[cfg(feature = "cuda")]
impl AllocatorOutputRetrive for (*mut u8, std::sync::Arc<cudarc::driver::CudaDevice>) {
    /// Returns the pointer half of the pair.
    fn get_ptr(&self) -> *mut u8 {
        // Raw pointers are `Copy`; no `clone()` required.
        self.0
    }

    /// Returns a new handle to the device half of the pair.
    ///
    /// The whole impl is already gated on the `cuda` feature, so the method does not
    /// need its own `cfg` attribute.
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
        // Explicit `Arc::clone` makes it obvious this is only a reference-count bump.
        std::sync::Arc::clone(&self.1)
    }
}
/// Conversion from an allocator output of type `T` into a backend value.
pub trait FromAllocatorOutput<T> {
    /// Builds the backend value (`Self`) from `alloc_output`.
    fn from_allocator_output(alloc_output: T) -> Self;
}