use std::{
alloc::Layout,
collections::{HashMap, HashSet},
num::NonZeroUsize,
sync::{
atomic::{AtomicUsize, Ordering},
Mutex,
},
};
use crate::{
ptr::SafePtr,
storage::{cpu::CPU_STORAGE, CommonStorage, Storage},
traits::Allocator,
};
use hpt_common::error::base::TensorError;
use lru::LruCache;
use once_cell::sync::Lazy;
/// Default capacity (number of distinct layouts) of each per-device CPU LRU cache.
const DEFAULT_CPU_LRU_CACHE_SIZE: usize = 100;

/// Process-wide CPU allocator, constructed lazily on first access and guarded by a mutex.
pub(crate) static CACHE: Lazy<Mutex<CpuAllocator>> =
    Lazy::new(|| Mutex::new(CpuAllocator::new()));

/// Capacity given to a per-device LRU cache when its sub-allocator is first created;
/// overwritten by [`resize_cpu_lru_cache`].
pub(crate) static CPU_LRU_CACHE_SIZE: AtomicUsize =
    AtomicUsize::new(DEFAULT_CPU_LRU_CACHE_SIZE);
/// CPU allocator that keeps one caching sub-allocator per device id.
///
/// Sub-allocators are created on demand the first time a device is used.
///
/// NOTE(review): `Clone` copies the raw pointers cached by every sub-allocator without
/// copying the memory behind them; calling `clear` on both a clone and the original
/// would presumably deallocate the same blocks twice — confirm clones are never used so.
#[derive(Clone)]
pub struct CpuAllocator {
    /// Per-device sub-allocators, keyed by device id.
    allocator: HashMap<usize, _Allocator>,
}
impl Allocator for CpuAllocator {
type Output = *mut u8;
type CpuAllocator = CpuAllocator;
#[cfg(feature = "cuda")]
type CudaAllocator = crate::allocators::cuda::CudaAllocator;
fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<Self::Output, TensorError> {
if let Some(allocator) = self.allocator.get_mut(&device_id) {
allocator.allocate(layout, device_id)
} else {
let mut allocator = _Allocator {
cache: LruCache::new(
NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
),
allocated: HashSet::new(),
};
let ptr = allocator.allocate(layout, device_id)?;
self.allocator.insert(device_id, allocator);
Ok(ptr)
}
}
fn allocate_zeroed(
&mut self,
layout: Layout,
device_id: usize,
) -> Result<Self::Output, TensorError> {
if let Some(allocator) = self.allocator.get_mut(&device_id) {
allocator.allocate_zeroed(layout, device_id)
} else {
let mut allocator = _Allocator {
cache: LruCache::new(
NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
),
allocated: HashSet::new(),
};
let ptr = allocator.allocate_zeroed(layout, device_id)?;
self.allocator.insert(device_id, allocator);
Ok(ptr)
}
}
fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
if let Some(allocator) = self.allocator.get_mut(&device_id) {
allocator.deallocate(ptr, layout, should_drop, device_id);
} else {
if !should_drop {
return;
}
panic!("device {} not found in allocator", device_id);
}
}
fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
if let Some(allocator) = self.allocator.get_mut(&device_id) {
allocator.insert_ptr(ptr, device_id);
} else {
let mut allocator = _Allocator {
cache: LruCache::new(
NonZeroUsize::new(CPU_LRU_CACHE_SIZE.load(Ordering::Relaxed)).unwrap(),
),
allocated: HashSet::new(),
};
allocator.insert_ptr(ptr, device_id);
self.allocator.insert(device_id, allocator);
}
}
fn clear(&mut self) {
if let Ok(mut storage) = CPU_STORAGE.lock() {
let storage: &mut HashMap<usize, CommonStorage> = &mut storage;
for (device_id, allocator) in self.allocator.iter_mut() {
for (layout, ptrs) in allocator.cache.iter_mut() {
for ptr in ptrs.iter() {
storage
.get_mut(device_id)
.unwrap()
.decrement_ref(SafePtr { ptr: ptr.ptr });
unsafe {
std::alloc::dealloc(ptr.ptr, layout.clone());
}
}
}
allocator.cache.clear();
assert_eq!(allocator.allocated.len(), 0);
assert_eq!(storage[device_id].storage.len(), 0);
}
}
}
fn new() -> Self {
CpuAllocator {
allocator: HashMap::new(),
}
}
fn forget(&mut self, ptr: *mut u8, device_id: usize) {
if let Some(allocator) = self.allocator.get_mut(&device_id) {
if allocator.allocated.get(&SafePtr { ptr }).is_some() {
allocator.allocated.remove(&SafePtr { ptr });
}
}
}
}
/// Per-device state of the CPU allocator: a layout-keyed LRU cache of blocks plus the
/// set of pointers currently tracked as allocated.
///
/// NOTE(review): `Clone` duplicates the raw pointers held in `cache`/`allocated` without
/// duplicating the memory they point to; releasing through both copies would presumably
/// double-free — confirm clones are never used that way.
#[derive(Clone)]
struct _Allocator {
    // Blocks grouped by the `Layout` they were allocated with; the number of distinct
    // layouts is bounded by the LRU capacity. Presumably these are blocks returned via
    // `deallocate` and kept for reuse — confirm in `allocate_helper`/`deallocate_helper`.
    cache: LruCache<Layout, Vec<SafePtr>>,
    // Pointers registered via `insert_ptr` (and presumably by `allocate_helper`) that
    // have not yet been released or forgotten.
    allocated: HashSet<SafePtr>,
}
impl _Allocator {
    /// Allocates a block for `layout` on `device_id` through `allocate_helper`, with
    /// `std::alloc::alloc` as the fresh-allocation path (presumably a cached block of the
    /// same layout is reused first — confirm in `allocate_helper`).
    ///
    /// # Errors
    ///
    /// Propagates any `TensorError` returned by `allocate_helper`.
    ///
    /// # Panics
    ///
    /// Panics if the `CPU_STORAGE` mutex is poisoned.
    fn allocate(&mut self, layout: Layout, device_id: usize) -> Result<*mut u8, TensorError> {
        let mut storage = CPU_STORAGE
            .lock()
            .unwrap_or_else(|_| panic!("Failed to lock CPU_STORAGE"));
        crate::utils::allocate::allocate_helper(
            &mut self.cache,
            &mut self.allocated,
            &mut storage,
            // NOTE(review): `std::alloc::alloc` is UB for zero-sized layouts; presumably
            // `allocate_helper` handles size 0 before calling this — confirm.
            || unsafe { std::alloc::alloc(layout) },
            // No preparation is needed before handing out a reused block.
            |_, _| {},
            |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
            layout,
            device_id,
        )
    }

    /// Like [`_Allocator::allocate`], but the memory handed out is zero-filled: fresh
    /// blocks come from `std::alloc::alloc_zeroed`, and the preparation closure
    /// overwrites a block with zeros (presumably applied to reused blocks, which may hold
    /// stale data).
    ///
    /// # Errors
    ///
    /// Propagates any `TensorError` returned by `allocate_helper`.
    ///
    /// # Panics
    ///
    /// Panics if the `CPU_STORAGE` mutex is poisoned.
    fn allocate_zeroed(
        &mut self,
        layout: Layout,
        device_id: usize,
    ) -> Result<*mut u8, TensorError> {
        let mut storage = CPU_STORAGE
            .lock()
            .unwrap_or_else(|_| panic!("Failed to lock CPU_STORAGE"));
        crate::utils::allocate::allocate_helper(
            &mut self.cache,
            &mut self.allocated,
            &mut storage,
            || unsafe { std::alloc::alloc_zeroed(layout) },
            |ptr, layout| unsafe {
                std::ptr::write_bytes(ptr, 0, layout.size());
            },
            |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
            layout,
            device_id,
        )
    }

    /// Hands `ptr` (allocated with `layout`) back through `deallocate_helper`.
    ///
    /// # Panics
    ///
    /// Panics if the `CPU_STORAGE` mutex is poisoned.
    fn deallocate(&mut self, ptr: *mut u8, layout: &Layout, should_drop: bool, device_id: usize) {
        let mut storage = CPU_STORAGE
            .lock()
            .unwrap_or_else(|_| panic!("Failed to lock CPU_STORAGE"));
        crate::utils::deallocate::deallocate_helper(
            &mut self.cache,
            &mut self.allocated,
            &mut storage,
            layout,
            ptr,
            should_drop,
            device_id,
        );
    }

    /// Tracks `ptr` as allocated and increments its reference count in the
    /// `CPU_STORAGE` record of `device_id`, creating that record if needed.
    ///
    /// # Panics
    ///
    /// Panics if the `CPU_STORAGE` mutex is poisoned. The lock is taken first, so
    /// `allocated` and the reference count are never updated independently of each
    /// other (previously a poisoned lock silently skipped the increment).
    fn insert_ptr(&mut self, ptr: *mut u8, device_id: usize) {
        let mut map = CPU_STORAGE
            .lock()
            .unwrap_or_else(|_| panic!("Failed to lock CPU_STORAGE"));
        self.allocated.insert(SafePtr { ptr });
        map.entry(device_id)
            .or_insert_with(CommonStorage::new)
            .increment_ref(SafePtr { ptr });
    }
}
/// Resizes the LRU cache of the CPU sub-allocator for `device_id` to `new_size`, and
/// makes `new_size` the capacity of sub-allocators created from now on.
///
/// If `device_id` already has a sub-allocator, its cache is resized by
/// `resize_lru_cache`, which is given `std::alloc::dealloc` to release blocks (presumably
/// the ones evicted by shrinking — confirm in `cache_resize`). Otherwise an empty
/// sub-allocator with capacity `new_size` is registered for the device.
///
/// # Panics
///
/// Panics if the `CACHE` mutex is poisoned, or if `new_size` is 0 and `device_id` has no
/// sub-allocator yet. Note that storing a size of 0 makes any later on-demand
/// sub-allocator creation panic.
pub fn resize_cpu_lru_cache(new_size: usize, device_id: usize) {
    let mut cache = CACHE
        .lock()
        .unwrap_or_else(|_| panic!("Failed to lock CACHE"));
    if let Some(allocator) = cache.allocator.get_mut(&device_id) {
        // NOTE(review): unlike `CpuAllocator::clear`, this release path does not
        // decrement the freed blocks' CPU_STORAGE reference counts — confirm that
        // `resize_lru_cache` (or the blocks' state) makes that unnecessary.
        crate::utils::cache_resize::resize_lru_cache(
            &mut allocator.cache,
            |ptr, layout| unsafe { std::alloc::dealloc(ptr, layout) },
            new_size,
        );
    } else {
        let capacity =
            NonZeroUsize::new(new_size).expect("CPU LRU cache size must be non-zero");
        cache.allocator.insert(
            device_id,
            _Allocator {
                cache: LruCache::new(capacity),
                allocated: HashSet::new(),
            },
        );
    }
    // Publish the new default while `CACHE` is still locked. Sub-allocators read this
    // value when `CpuAllocator` creates them; assuming callers reach the allocator via
    // `CACHE`, none created after this call can observe the old size.
    CPU_LRU_CACHE_SIZE.store(new_size, Ordering::Relaxed);
}