use std::alloc::Layout;
use hpt_common::error::base::TensorError;
/// A cloneable memory allocator whose operations are addressed by a `device_id`.
///
/// Every allocation request returns an [`Allocator::Output`], from which the raw byte
/// pointer can be read through [`AllocatorOutputRetrive::get_ptr`].
pub trait Allocator: Clone {
    /// Value handed back by `allocate` / `allocate_zeroed`.
    type Output: AllocatorOutputRetrive;

    /// Associated allocator type named `CpuAllocator`.
    /// NOTE(review): presumably the allocator used for host memory — confirm with implementors.
    type CpuAllocator: Allocator;

    /// Associated allocator type named `CudaAllocator` (only with the `cuda` feature).
    /// NOTE(review): presumably the allocator used for CUDA device memory — confirm.
    #[cfg(feature = "cuda")]
    type CudaAllocator: Allocator;

    /// Constructs a new allocator instance.
    fn new() -> Self;

    /// Requests a buffer described by `layout` on device `device_id`.
    ///
    /// # Errors
    /// Returns `Err(TensorError)` on failure; the failure conditions are implementation-defined.
    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError>;

    /// Same request shape as [`Allocator::allocate`].
    /// NOTE(review): by its name the returned buffer is expected to be zero-filled — confirm
    /// with implementors.
    ///
    /// # Errors
    /// Returns `Err(TensorError)` on failure; the failure conditions are implementation-defined.
    fn allocate_zeroed(
        &mut self,
        layout: Layout,
        device_id: usize,
    ) -> Result<Self::Output, TensorError>;

    /// Hands `ptr`, described by `layout`, on device `device_id` back to the allocator.
    /// NOTE(review): the exact effect of `should_drop` is implementation-defined — TODO confirm.
    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize);

    /// Registers `ptr` with the allocator for device `device_id`.
    /// NOTE(review): semantics (e.g. reference tracking) are implementation-defined — confirm.
    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize);

    /// Tells the allocator to forget `ptr` on device `device_id`.
    /// NOTE(review): whether this releases memory is implementation-defined — confirm.
    fn forget(&mut self, ptr: *mut u8, device_id: usize);

    /// Clears the allocator's state; what is released is implementation-defined.
    fn clear(&mut self);
}
/// Read access to the value an allocator returns from an allocation request.
///
/// The trait name is spelled `Retrive` (sic). It is kept unchanged so existing
/// implementors and callers continue to compile.
pub trait AllocatorOutputRetrive {
    /// Handle to the CUDA device associated with this output (only with the `cuda` feature).
    #[cfg(feature = "cuda")]
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice>;

    /// Raw byte pointer carried by this allocation output.
    fn get_ptr(&self) -> *mut u8;
}
/// A bare raw pointer used directly as an allocator output; it carries no device handle.
impl AllocatorOutputRetrive for *mut u8 {
    /// Returns the pointer itself.
    fn get_ptr(&self) -> *mut u8 {
        // `*mut u8` is `Copy`: dereference instead of calling `clone()` (clippy::clone_on_copy).
        *self
    }

    /// # Panics
    /// Always panics, because a bare pointer has no CUDA device attached.
    #[cfg(feature = "cuda")]
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
        // This method only exists when the `cuda` feature IS enabled. The message
        // therefore states the real cause: no device handle.
        panic!("get_device called on a bare `*mut u8` allocation output, which has no CUDA device");
    }
}
/// An allocator output that pairs a raw pointer with a CUDA device handle.
#[cfg(feature = "cuda")]
impl AllocatorOutputRetrive for (*mut u8, std::sync::Arc<cudarc::driver::CudaDevice>) {
    /// Returns the pointer half of the pair.
    fn get_ptr(&self) -> *mut u8 {
        // `*mut u8` is `Copy`, so no `clone()` is needed (clippy::clone_on_copy).
        self.0
    }

    /// Returns another handle to the paired device.
    // The enclosing impl is already `cuda`-gated, so no per-method `#[cfg]` is required.
    fn get_device(&self) -> std::sync::Arc<cudarc::driver::CudaDevice> {
        // `Arc::clone` makes it explicit that only the reference count is bumped.
        std::sync::Arc::clone(&self.1)
    }
}
/// Conversion from an allocator output value of type `T` into `Self`.
pub trait FromAllocatorOutput<T> {
    /// Builds `Self` from `output`, taking ownership of it.
    fn from_allocator_output(output: T) -> Self;
}