use std::{
collections::{HashMap, HashSet},
panic::Location,
};
use hpt_common::error::{base::TensorError, memory::MemoryError};
use lru::LruCache;
use crate::{
ptr::SafePtr,
storage::{CommonStorage, Storage},
};
/// Obtains a buffer from `allocate_fn`, evicting cached buffers and retrying when it fails.
///
/// The first attempt calls `allocate_fn` directly. If it returns a null pointer, buckets are
/// popped from `cache` in least-recently-used order and their buffers are released one at a
/// time through `deallocate_fn`, retrying `allocate_fn` after every release. Eviction stops
/// once the released byte count reaches `layout.size()` or the cache is empty, after which one
/// final attempt is made.
///
/// Every pointer handed back is recorded in `allocated`.
///
/// # Errors
///
/// Returns `MemoryError::AllocationFailed` (wrapped in `TensorError::Memory`) when
/// `allocate_fn` still yields a null pointer after eviction. The reported `size` is
/// `layout.size()` integer-divided by 1024 twice, and `location` is the call site of the
/// outermost `#[track_caller]` function in the call chain.
#[track_caller]
fn allocate_mem(
    allocate_fn: impl Fn() -> *mut u8,
    deallocate_fn: &impl Fn(*mut u8, std::alloc::Layout),
    device_id: usize,
    layout: std::alloc::Layout,
    allocated: &mut HashSet<SafePtr>,
    cache: &mut LruCache<std::alloc::Layout, Vec<SafePtr>>,
) -> std::result::Result<*mut u8, TensorError> {
    let ptr = allocate_fn();
    if !ptr.is_null() {
        allocated.insert(SafePtr { ptr });
        return Ok(ptr);
    }

    let needed_size = layout.size();
    let mut freed_size = 0;
    while freed_size < needed_size {
        let (evicted_layout, evicted) = match cache.pop_lru() {
            Some(entry) => entry,
            // Nothing left to evict.
            None => break,
        };
        // An explicit iterator is kept so the untouched tail can be recovered on success.
        let mut remaining = evicted.into_iter();
        while let Some(safe_ptr) = remaining.next() {
            deallocate_fn(safe_ptr.ptr, evicted_layout);
            freed_size += evicted_layout.size();
            let ptr = allocate_fn();
            if !ptr.is_null() {
                // The bucket was already removed from the cache; put back the buffers that
                // were not released so they stay reachable instead of being leaked.
                let unused: Vec<SafePtr> = remaining.collect();
                if !unused.is_empty() {
                    cache.put(evicted_layout, unused);
                }
                allocated.insert(SafePtr { ptr });
                return Ok(ptr);
            }
        }
    }

    // Last attempt, also reached when the cache had nothing to evict.
    let ptr = allocate_fn();
    if !ptr.is_null() {
        allocated.insert(SafePtr { ptr });
        Ok(ptr)
    } else {
        Err(TensorError::Memory(MemoryError::AllocationFailed {
            device: "cpu".to_string(),
            id: device_id,
            size: layout.size() / 1024 / 1024,
            source: None,
            location: Location::caller(),
        }))
    }
}
/// Hands out a buffer for `layout`, reusing a cached one when available.
///
/// Steps, in order:
/// 1. Look up the bucket for `layout` in `cache` and pop a pointer from it. A reused pointer
///    is passed to `zero_fn` before being returned; otherwise a buffer is requested through
///    `allocate_mem`.
/// 2. If the cache currently holds as many entries as its capacity, pop the least-recently-used
///    bucket and release each of its buffers with `deallocate_fn`.
/// 3. Call `increment_ref` for the pointer on the `CommonStorage` registered under
///    `device_id`, creating that storage first if it does not exist.
///
/// # Errors
///
/// Propagates the error returned by `allocate_mem` when no buffer could be obtained.
#[track_caller]
pub(crate) fn allocate_helper(
    cache: &mut LruCache<std::alloc::Layout, Vec<SafePtr>>,
    allocated: &mut HashSet<SafePtr>,
    storage: &mut HashMap<usize, CommonStorage>,
    allocate_fn: impl Fn() -> *mut u8,
    zero_fn: impl Fn(*mut u8, std::alloc::Layout),
    deallocate_fn: impl Fn(*mut u8, std::alloc::Layout),
    layout: std::alloc::Layout,
    device_id: usize,
) -> std::result::Result<*mut u8, TensorError> {
    // A missing bucket and an empty bucket are treated the same: both fall through to a
    // fresh allocation.
    let recycled = cache.get_mut(&layout).and_then(|bucket| bucket.pop());
    let ptr = match recycled {
        Some(reused) => {
            zero_fn(reused.ptr, layout);
            reused.ptr
        }
        None => allocate_mem(
            allocate_fn,
            &deallocate_fn,
            device_id,
            layout,
            allocated,
            cache,
        )?,
    };

    // When every cache slot is occupied, drop the least-recently-used bucket entirely.
    let cache_is_full = cache.len() == cache.cap().get();
    if cache_is_full {
        if let Some((evicted_layout, evicted)) = cache.pop_lru() {
            evicted
                .into_iter()
                .for_each(|stale| deallocate_fn(stale.ptr, evicted_layout));
        }
    }

    // Register the pointer with the per-device storage, creating it on first use.
    storage
        .entry(device_id)
        .or_insert_with(CommonStorage::new)
        .increment_ref(SafePtr { ptr });
    Ok(ptr)
}