use cubecl_common::backtrace::BackTrace;
use cubecl_core::server::IoError;
use cubecl_hip_sys::HIP_SUCCESS;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use std::{collections::HashMap, ffi::c_void};
/// Alignment, in bytes, that [`PinnedMemoryStorage`] reports for its allocations.
///
/// Defined as the size of a `u128`, i.e. 16 bytes.
pub const PINNED_MEMORY_ALIGNMENT: usize = std::mem::size_of::<u128>();
/// Storage backend whose buffers are host allocations obtained from `hipHostMalloc`.
///
/// Every live allocation is tracked by its [`StorageId`] so it can later be
/// looked up or released with `hipFreeHost`.
pub struct PinnedMemoryStorage {
/// Live allocations, keyed by the id handed out when they were allocated.
memory: HashMap<StorageId, PinnedMemory>,
/// Alignment, in bytes, reported by `alignment()`; set to
/// `PINNED_MEMORY_ALIGNMENT` by `new`.
mem_alignment: usize,
/// HIP stream that is synchronized after each successful allocation.
stream: cubecl_hip_sys::hipStream_t,
}
/// View into a pinned host allocation, as produced by `PinnedMemoryStorage::get`.
///
/// This is a plain pointer/length pair: it does not own the memory and carries
/// no lifetime, so nothing stops it from outliving the allocation it refers to.
#[derive(Debug)]
pub struct PinnedMemoryResource {
/// Start of the viewed region: the allocation's base pointer advanced by the
/// handle's byte offset.
pub ptr: *mut u8,
/// Size of the viewed region, taken from the handle's size.
pub size: usize,
}
/// Bookkeeping record for one allocation made through `hipHostMalloc`.
#[derive(Debug)]
struct PinnedMemory {
/// Base pointer written by `hipHostMalloc`; this is what gets passed to
/// `hipFreeHost` on deallocation.
ptr: *mut c_void,
// Not read by any code visible in this file (other than the derived `Debug`),
// hence the lint suppression.
#[allow(unused)]
/// Address of the out-parameter that was passed to `hipHostMalloc`.
///
/// That out-parameter is a local variable of `alloc`, so this pointer is
/// dangling as soon as `alloc` returns. It must never be dereferenced.
dev_ptr: *mut *mut c_void,
}
impl PinnedMemoryStorage {
pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
Self {
memory: HashMap::new(),
mem_alignment: PINNED_MEMORY_ALIGNMENT,
stream,
}
}
}
// Both types hold raw pointers, which suppresses the automatic `Send`
// implementation, so it is asserted manually here.
//
// SAFETY / NOTE(review): nothing in this file proves these impls sound.
// Presumably the pinned host allocation and the HIP stream handle are not tied
// to the thread that created them, and callers serialize access to the pointed-to
// memory. Confirm against the HIP runtime's threading guarantees and the callers.
unsafe impl Send for PinnedMemoryResource {}
unsafe impl Send for PinnedMemoryStorage {}
impl ComputeStorage for PinnedMemoryStorage {
    type Resource = PinnedMemoryResource;

    /// Returns the alignment, in bytes, this storage was configured with.
    fn alignment(&self) -> usize {
        self.mem_alignment
    }

    /// Resolves `handle` into a raw pointer/size view of its allocation.
    ///
    /// The returned pointer is the allocation's base advanced by the handle's
    /// byte offset; the returned size is the handle's size.
    ///
    /// # Panics
    ///
    /// Panics if `handle.id` does not refer to an allocation currently tracked
    /// by this storage.
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let memory = self
            .memory
            .get(&handle.id)
            .expect("Storage handle not found");

        let offset = handle.offset() as usize;
        let size = handle.size() as usize;

        // SAFETY: the handle is assumed to describe a range that lies inside the
        // allocation it names; this is not verified here.
        let ptr = unsafe { memory.ptr.cast::<u8>().add(offset) };

        PinnedMemoryResource { ptr, size }
    }

    /// Allocates `size` bytes with `hipHostMalloc` (flag `hipHostMallocMapped`)
    /// and registers the allocation under a fresh [`StorageId`].
    ///
    /// On success the returned handle covers the whole allocation (offset 0,
    /// length `size`).
    ///
    /// # Errors
    ///
    /// Returns [`IoError::Unknown`] carrying the HIP status code when
    /// `hipHostMalloc` does not report `HIP_SUCCESS`.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, size))
    )]
    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
        let resource = unsafe {
            let mut ptr: *mut c_void = std::ptr::null_mut();
            let dev_ptr: *mut *mut c_void = &mut ptr;

            // SAFETY: `dev_ptr` points at the live local `ptr`, which is a valid
            // out-parameter for the duration of this call.
            // NOTE(review): `size as usize` would truncate on targets where
            // `usize` is narrower than 64 bits — presumably never the case for
            // HIP builds; confirm.
            let result = cubecl_hip_sys::hipHostMalloc(
                dev_ptr,
                size as usize,
                cubecl_hip_sys::hipHostMallocMapped,
            );

            if result != HIP_SUCCESS {
                return Err(IoError::Unknown {
                    description: format!("hipHostMalloc failed with error code: {result:?}"),
                    backtrace: BackTrace::capture(),
                });
            }

            // The status of the synchronization is deliberately discarded: the
            // allocation itself already succeeded and is still registered below.
            let _ = cubecl_hip_sys::hipStreamSynchronize(self.stream);

            // `dev_ptr` refers to the local `ptr` above and is dangling once this
            // block ends; it is stored only to satisfy the struct's shape and
            // must never be dereferenced.
            PinnedMemory { ptr, dev_ptr }
        };

        let id = StorageId::new();
        self.memory.insert(id, resource);

        Ok(StorageHandle::new(
            id,
            StorageUtilization { offset: 0, size },
        ))
    }

    /// Releases the allocation registered under `id` with `hipFreeHost`.
    ///
    /// Does nothing when `id` is not tracked by this storage.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    fn dealloc(&mut self, id: StorageId) {
        if let Some(resource) = self.memory.remove(&id) {
            // SAFETY: `resource.ptr` was produced by `hipHostMalloc` in `alloc`
            // and, having just been removed from the map, is freed at most once.
            unsafe {
                // This method has no way to report failure, so the status code is
                // deliberately discarded.
                let _ = cubecl_hip_sys::hipFreeHost(resource.ptr);
            }
        }
    }

    /// No-op: this storage has no deferred work to flush.
    fn flush(&mut self) {}
}