use cubecl_common::backtrace::BackTrace;
use cubecl_core::server::IoError;
use cubecl_hip_sys::HIP_SUCCESS;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use std::collections::HashMap;
use crate::AMD_MAX_BINDINGS;
/// Buffer storage for AMD GPUs, backed by the HIP driver API.
///
/// Allocations are performed asynchronously on a dedicated HIP stream and
/// deallocations are deferred until [`GpuStorage::perform_deallocations`] runs.
pub struct GpuStorage {
    // Alignment (in bytes) reported through `ComputeStorage::alignment`.
    mem_alignment: usize,
    // Live device allocations, keyed by their storage id.
    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
    // Ids queued by `dealloc`; freed lazily on the next `perform_deallocations`/`flush`.
    deallocations: Vec<StorageId>,
    // Ring buffer that keeps registered pointer values at stable addresses
    // while they are handed out as kernel bindings.
    ptr_bindings: PtrBindings,
    // HIP stream on which async alloc/free calls are enqueued.
    stream: cubecl_hip_sys::hipStream_t,
}
/// A resolved GPU buffer: the device pointer plus the address of its stored
/// binding slot, as produced by [`GpuStorage::get`].
#[derive(new, Debug)]
pub struct GpuResource {
    /// Raw device pointer, already offset into the underlying allocation.
    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
    /// Address of the slot holding `ptr`'s value inside `PtrBindings::slots`.
    /// NOTE(review): presumably passed as a kernel-argument entry at launch —
    /// it is only valid until that ring-buffer slot is recycled; confirm the
    /// launch code consumes it before `AMD_MAX_BINDINGS` further registrations.
    pub binding: cubecl_hip_sys::hipDeviceptr_t,
    /// Size of the buffer view in bytes.
    pub size: u64,
}
impl GpuStorage {
    /// Creates a new storage bound to `stream`.
    ///
    /// * `mem_alignment` - alignment in bytes later reported by `ComputeStorage::alignment`.
    /// * `stream` - HIP stream used for all async allocation and deallocation calls.
    pub fn new(mem_alignment: usize, stream: cubecl_hip_sys::hipStream_t) -> Self {
        Self {
            mem_alignment,
            memory: HashMap::new(),
            deallocations: Vec::new(),
            ptr_bindings: PtrBindings::new(),
            stream,
        }
    }

    /// Frees every buffer previously queued by `dealloc`, asynchronously on
    /// the storage's stream. Ids with no live allocation are skipped silently.
    pub fn perform_deallocations(&mut self) {
        for id in self.deallocations.drain(..) {
            if let Some(ptr) = self.memory.remove(&id) {
                // SAFETY: `ptr` was produced by `hipMallocAsync` in `alloc` and is
                // removed from the map first, so it is freed at most once.
                unsafe {
                    // NOTE(review): the status returned by `hipFreeAsync` is ignored,
                    // so a failed free is silent — confirm this is intentional.
                    cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
                }
            }
        }
    }
}
/// Fixed-size ring buffer of device-pointer values.
///
/// Registering a pointer stores its value at a stable address, which callers
/// can hand to the HIP launch API. A slot stays valid until the cursor wraps
/// back around and overwrites it (after `AMD_MAX_BINDINGS` more registrations).
struct PtrBindings {
    // Slot storage; length is fixed at `AMD_MAX_BINDINGS`.
    slots: Vec<u64>,
    // Index of the next slot to write; wraps to 0 at the end.
    cursor: usize,
}
impl PtrBindings {
    /// Creates a ring buffer with `AMD_MAX_BINDINGS` zero-initialized slots.
    fn new() -> Self {
        Self {
            slots: vec![0; AMD_MAX_BINDINGS as usize],
            cursor: 0,
        }
    }

    /// Stores `ptr` in the next slot and returns a reference to the stored value.
    ///
    /// The reference is only guaranteed to hold `ptr` until the cursor wraps
    /// around and the slot is overwritten by a later registration.
    fn register(&mut self, ptr: u64) -> &u64 {
        let slot = self.cursor;
        self.slots[slot] = ptr;
        // Advance the write cursor, wrapping back to the first slot.
        self.cursor = (slot + 1) % self.slots.len();
        &self.slots[slot]
    }
}
impl ComputeStorage for GpuStorage {
    type Resource = GpuResource;

    /// Alignment (in bytes) guaranteed for allocations from this storage.
    fn alignment(&self) -> usize {
        self.mem_alignment
    }

    /// Resolves a handle into a [`GpuResource`] pointing at the handle's
    /// offset within its allocation.
    ///
    /// Panics if the handle's allocation was never made or was already freed.
    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
        let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
        let offset = handle.offset();
        let size = handle.size();
        // Store the offset pointer in the ring buffer so its value lives at a
        // stable address for the duration of the upcoming launch.
        let ptr = self.ptr_bindings.register(ptr + offset);
        GpuResource::new(
            *ptr as cubecl_hip_sys::hipDeviceptr_t,
            // Address of the ring-buffer slot holding the pointer value.
            // NOTE(review): casting `&u64` to `*mut c_void` discards constness;
            // any write through `binding` would be UB — confirm consumers only read.
            std::ptr::from_ref(ptr) as *mut std::ffi::c_void,
            size,
        )
    }

    /// Allocates `size` bytes asynchronously on the storage's stream.
    ///
    /// Returns `IoError::Unknown` with the HIP status code if the allocation fails.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(level = "trace", skip(self, size))
    )]
    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
        let id = StorageId::new();
        unsafe {
            let mut ptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
            let status = cubecl_hip_sys::hipMallocAsync(&mut ptr, size as usize, self.stream);
            match status {
                HIP_SUCCESS => {}
                other => {
                    return Err(IoError::Unknown {
                        description: format!("HIP allocation error: {other}"),
                        backtrace: BackTrace::capture(),
                    });
                }
            }
            // NOTE(review): synchronizing right after the async malloc; the returned
            // status is ignored and the sync partly defeats `hipMallocAsync` — confirm
            // whether this barrier is required by downstream users of `ptr`.
            cubecl_hip_sys::hipStreamSynchronize(self.stream);
            self.memory.insert(id, ptr);
        };
        Ok(StorageHandle::new(
            id,
            StorageUtilization { offset: 0, size },
        ))
    }

    /// Queues `id` for deallocation; the actual free happens in `flush` /
    /// `perform_deallocations`.
    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
    fn dealloc(&mut self, id: StorageId) {
        self.deallocations.push(id);
    }

    /// Performs all pending deallocations.
    fn flush(&mut self) {
        self.perform_deallocations();
    }
}
// SAFETY: NOTE(review): raw HIP pointers/streams are not `Send` automatically;
// these impls assert the runtime moves the storage/resource across threads but
// only uses them from one thread at a time — confirm against the scheduler's
// stream-ownership model.
unsafe impl Send for GpuStorage {}
unsafe impl Send for GpuResource {}
impl core::fmt::Debug for GpuStorage {
    /// Writes a fixed tag; the fields are raw FFI handles with no useful
    /// textual representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fix: write the static string directly. The original built a heap
        // `String` (`"GpuStorage".to_string().as_str()`) only to immediately
        // borrow it — a needless allocation on every format call.
        f.write_str("GpuStorage")
    }
}