use std::{
alloc::Layout,
collections::{HashMap, HashSet},
num::NonZeroUsize,
panic::Location,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use crate::{
ptr::SafePtr,
storage::{CommonStorage, Storage},
traits::Allocator,
CUDA_STORAGE,
};
use hpt_common::error::base::TensorError;
use lru::LruCache;
use once_cell::sync::Lazy;
/// Process-wide CUDA allocator, created lazily on first access and guarded by a mutex.
pub(crate) static CUDA_CACHE: Lazy<Mutex<CudaAllocator>> =
Lazy::new(|| Mutex::new(CudaAllocator::new()));
/// Capacity used for the LRU cache of every newly created per-device allocator.
///
/// Starts at 1 and is overwritten by `resize_cuda_lru_cache`.
pub(crate) static CUDA_LRU_CACHE_SIZE: AtomicUsize = AtomicUsize::new(1);
/// CUDA allocator that keeps one device handle and one per-device allocator
/// for every device id it has been asked to serve.
#[derive(Clone)]
pub struct CudaAllocator {
/// Per-device state keyed by device id: the device handle together with the
/// allocator holding that device's pointer cache and allocation bookkeeping.
allocator: HashMap<usize, (Arc<cudarc::driver::CudaDevice>, _Allocator)>,
}
impl Allocator for CudaAllocator {
type Output = (*mut u8, Arc<cudarc::driver::CudaDevice>);
type CpuAllocator = crate::allocators::cpu::CpuAllocator;
#[cfg(feature = "cuda")]
type CudaAllocator = CudaAllocator;
fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError> {
if let Some((device, allocator)) = self.allocator.get_mut(&device_id) {
Ok((
allocator.allocate(layout, device_id, device.clone())?,
device.clone(),
))
} else {
let mut allocator = _Allocator {
cache: LruCache::new(
NonZeroUsize::new(CUDA_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
),
allocated: HashSet::new(),
};
let device = cudarc::driver::CudaDevice::new(device_id).unwrap();
let ptr = allocator.allocate(layout, device_id, device.clone())?;
self.allocator
.insert(device_id, (device.clone(), allocator));
Ok((ptr, device))
}
}
fn allocate_zeroed(
&mut self,
layout: Layout,
device_id: usize,
) -> Result<Self::Output, TensorError> {
if let Some((device, allocator)) = self.allocator.get_mut(&device_id) {
Ok((
allocator.allocate_zeroed(layout, device_id, device.clone())?,
device.clone(),
))
} else {
let mut allocator = _Allocator {
cache: LruCache::new(
NonZeroUsize::new(CUDA_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
),
allocated: HashSet::new(),
};
let device = cudarc::driver::CudaDevice::new(device_id).unwrap();
let ptr = allocator.allocate_zeroed(layout, device_id, device.clone())?;
self.allocator
.insert(device_id, (device.clone(), allocator));
Ok((ptr, device))
}
}
fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
if let Some((_, allocator)) = self.allocator.get_mut(&device_id) {
allocator.deallocate(ptr, layout, should_drop, device_id);
} else {
if !should_drop {
return;
}
panic!("Allocator for device {} not found", device_id);
}
}
fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
if let Some((_, allocator)) = self.allocator.get_mut(&device_id) {
allocator.insert_ptr(ptr, device_id);
} else {
panic!("Allocator for device {} not found", device_id);
}
}
fn clear(&mut self) {
if let Ok(mut storage) = CUDA_STORAGE.lock() {
for (device, allocator) in self.allocator.values_mut() {
for (layout, ptrs) in allocator.cache.iter_mut() {
for ptr in ptrs.iter() {
storage
.get_mut(&device.ordinal())
.unwrap()
.decrement_ref(SafePtr { ptr: ptr.ptr });
unsafe { device.upgrade_device_ptr::<u8>(ptr.ptr as u64, layout.size()) };
}
}
allocator.cache.clear();
assert_eq!(allocator.allocated.len(), 0);
assert_eq!(storage[&device.ordinal()].storage.len(), 0);
}
}
}
fn new() -> Self {
CudaAllocator {
allocator: HashMap::new(),
}
}
fn forget(&mut self, ptr: *mut u8, device_id: usize) {
if let Some((_, allocator)) = self.allocator.get_mut(&device_id) {
if allocator.allocated.get(&SafePtr { ptr }).is_some() {
allocator.allocated.remove(&SafePtr { ptr });
}
}
}
}
/// Per-device allocation state used by `CudaAllocator`.
#[derive(Clone)]
struct _Allocator {
/// LRU cache of device pointers, grouped by the layout they are stored under.
cache: LruCache<Layout, Vec<SafePtr>>,
/// Pointers currently tracked as allocated (added by `insert_ptr`, removed by `forget`).
allocated: HashSet<SafePtr>,
}
impl _Allocator {
    /// Allocates `layout.size()` bytes on `device` through the shared
    /// `allocate_helper`, passing it this allocator's cache, allocation set and
    /// the global CUDA storage map.
    ///
    /// The helper receives three callbacks: one that performs a fresh device
    /// allocation, a no-op, and one that frees a device pointer.
    /// NOTE(review): the no-op presumably runs when a cached pointer is reused
    /// and the free callback when an entry is evicted — confirm against
    /// `allocate_helper`.
    ///
    /// # Panics
    ///
    /// Panics if `CUDA_STORAGE` cannot be locked or the driver fails to
    /// allocate memory.
    #[track_caller]
    fn allocate(
        &mut self,
        layout: Layout,
        device_id: usize,
        device: Arc<cudarc::driver::CudaDevice>,
    ) -> Result<*mut u8, TensorError> {
        // Captured up front: closures are not `#[track_caller]`, so calling
        // `Location::caller()` inside them would report the closure body instead
        // of the caller of this method.
        let location = Location::caller();
        let mut storage = match CUDA_STORAGE.lock() {
            Ok(guard) => guard,
            Err(_) => panic!("Failed to lock CUDA_STORAGE"),
        };
        crate::utils::allocate::allocate_helper(
            &mut self.cache,
            &mut self.allocated,
            &mut storage,
            || {
                // SAFETY: `alloc` hands back uninitialised device memory; it is only
                // exposed as a raw pointer here and never read by this function.
                let slice = unsafe { device.alloc::<u8>(layout.size()) }
                    .map_err(
                        |e| hpt_common::error::device::DeviceError::CudaDriverError {
                            message: format!(
                                "Failed to allocate memory, for {} MB",
                                layout.size() / 1024 / 1024
                            ),
                            source: Some(e),
                            location,
                        },
                    )
                    .expect("Failed to allocate memory");
                // Leak the slice so the memory outlives it; ownership is tracked by pointer.
                slice.leak() as *mut u8
            },
            |_, _| {},
            |ptr, layout| {
                // SAFETY: assumes `ptr` is a leaked `layout.size()`-byte allocation of
                // this device; rebuilding the slice and dropping it frees the memory.
                drop(unsafe { device.upgrade_device_ptr::<u8>(ptr as u64, layout.size()) });
            },
            layout,
            device_id,
        )
    }

    /// Zero-initialising counterpart of [`_Allocator::allocate`].
    ///
    /// Fresh allocations use `alloc_zeros`. The second callback re-zeroes an
    /// existing pointer with `memset_zeros` and leaks the temporary slice again
    /// so the memory stays alive.
    ///
    /// # Panics
    ///
    /// Panics if `CUDA_STORAGE` cannot be locked, the driver fails to allocate
    /// memory, or zeroing an existing pointer fails.
    #[track_caller]
    fn allocate_zeroed(
        &mut self,
        layout: Layout,
        device_id: usize,
        device: Arc<cudarc::driver::CudaDevice>,
    ) -> Result<*mut u8, TensorError> {
        // Captured up front: closures are not `#[track_caller]`, so calling
        // `Location::caller()` inside them would report the closure body instead
        // of the caller of this method.
        let location = Location::caller();
        let mut storage = match CUDA_STORAGE.lock() {
            Ok(guard) => guard,
            Err(_) => panic!("Failed to lock CUDA_STORAGE"),
        };
        crate::utils::allocate::allocate_helper(
            &mut self.cache,
            &mut self.allocated,
            &mut storage,
            || {
                let slice = device
                    .alloc_zeros::<u8>(layout.size())
                    .map_err(
                        |e| hpt_common::error::device::DeviceError::CudaDriverError {
                            message: format!(
                                "Failed to allocate memory, for {} MB",
                                layout.size() / 1024 / 1024
                            ),
                            source: Some(e),
                            location,
                        },
                    )
                    .expect("Failed to allocate memory");
                // Leak the slice so the memory outlives it; ownership is tracked by pointer.
                slice.leak() as *mut u8
            },
            |ptr, layout| {
                // SAFETY: assumes `ptr` is a leaked `layout.size()`-byte allocation of
                // this device, so it may be viewed as a slice temporarily.
                let mut slice =
                    unsafe { device.upgrade_device_ptr::<u8>(ptr as u64, layout.size()) };
                device
                    .memset_zeros(&mut slice)
                    .expect("Failed to memset zeros");
                // Leak again: dropping the temporary slice would free the memory.
                slice.leak();
            },
            |ptr, layout| {
                // SAFETY: assumes `ptr` is a leaked `layout.size()`-byte allocation of
                // this device; rebuilding the slice and dropping it frees the memory.
                drop(unsafe { device.upgrade_device_ptr::<u8>(ptr as u64, layout.size()) });
            },
            layout,
            device_id,
        )
    }

    /// Hands `ptr` to the shared `deallocate_helper` together with this
    /// allocator's cache, allocation set and the global CUDA storage map.
    ///
    /// # Panics
    ///
    /// Panics if `CUDA_STORAGE` cannot be locked.
    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
        let mut storage = match CUDA_STORAGE.lock() {
            Ok(guard) => guard,
            Err(_) => panic!("Failed to lock CUDA_STORAGE"),
        };
        crate::utils::deallocate::deallocate_helper(
            &mut self.cache,
            &mut self.allocated,
            &mut storage,
            layout,
            ptr,
            should_drop,
            device_id,
        );
    }

    /// Marks `ptr` as allocated and increments its reference count in the
    /// storage of `device_id`, creating that storage entry if it is missing.
    ///
    /// NOTE(review): unlike the other methods, a poisoned `CUDA_STORAGE` lock is
    /// silently ignored here, leaving the pointer tracked in `allocated` but not
    /// reference-counted — confirm this is intended.
    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
        self.allocated.insert(SafePtr { ptr });
        if let Ok(mut map) = CUDA_STORAGE.lock() {
            match map.get_mut(&device_id) {
                Some(existing) => {
                    existing.increment_ref(SafePtr { ptr });
                }
                None => {
                    let mut fresh = CommonStorage::new();
                    fresh.increment_ref(SafePtr { ptr });
                    map.insert(device_id, fresh);
                }
            }
        }
    }
}
/// Changes the LRU cache capacity of device `device_id` to `new_size`.
///
/// When the device already has an allocator, its cache is resized through
/// `resize_lru_cache`, which is given a callback that frees a device pointer.
/// Otherwise a device handle and an empty allocator with the requested
/// capacity are registered. In both cases `new_size` is then stored in
/// `CUDA_LRU_CACHE_SIZE`.
///
/// # Panics
///
/// Panics if `CUDA_CACHE` cannot be locked. For a device without an allocator
/// it also panics when `new_size` is zero or the device handle cannot be
/// created.
pub fn resize_cuda_lru_cache(new_size: usize, device_id: usize) {
    {
        // Scoped so the global allocator lock is released before the atomic store.
        let mut guard = match CUDA_CACHE.lock() {
            Ok(guard) => guard,
            Err(_) => panic!("Failed to lock CUDA_CACHE"),
        };
        match guard.allocator.get_mut(&device_id) {
            Some((device, per_device)) => {
                crate::utils::cache_resize::resize_lru_cache(
                    &mut per_device.cache,
                    |ptr, layout| {
                        // Rebuilding the slice and dropping it frees the device memory.
                        drop(unsafe {
                            device.upgrade_device_ptr::<u8>(ptr as u64, layout.size())
                        });
                    },
                    new_size,
                );
            }
            None => {
                let capacity = NonZeroUsize::new(new_size).unwrap();
                let fresh = _Allocator {
                    cache: LruCache::new(capacity),
                    allocated: HashSet::new(),
                };
                let device = cudarc::driver::CudaDevice::new(device_id).unwrap();
                guard.allocator.insert(device_id, (device, fresh));
            }
        }
    }
    CUDA_LRU_CACHE_SIZE.store(new_size, Ordering::Relaxed);
}