use crate::compute::uninit_vec;
use cubecl_common::backtrace::BackTrace;
use cubecl_core::server::IoError;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use cudarc::driver::DriverError;
use std::collections::HashMap;
#[derive(Debug, Clone, Copy)]
enum AllocationKind {
Async,
Sync,
}
pub struct GpuStorage {
memory: HashMap<StorageId, (cudarc::driver::sys::CUdeviceptr, AllocationKind)>,
deallocations: Vec<StorageId>,
ptr_bindings: PtrBindings,
stream: cudarc::driver::sys::CUstream,
mem_alignment: usize,
}
#[derive(Debug)]
pub struct GpuResource {
pub ptr: u64,
pub binding: *mut std::ffi::c_void,
pub size: u64,
}
impl GpuResource {
pub fn new(ptr: u64, binding: *mut std::ffi::c_void, size: u64) -> Self {
Self { ptr, binding, size }
}
}
impl GpuStorage {
pub fn new(mem_alignment: usize, stream: cudarc::driver::sys::CUstream) -> Self {
Self {
memory: HashMap::new(),
deallocations: Vec::new(),
ptr_bindings: PtrBindings::new(),
stream,
mem_alignment,
}
}
fn perform_deallocations(&mut self) {
self.deallocations
.drain(..)
.filter_map(|id| self.memory.remove(&id))
.for_each(|(ptr, kind)| unsafe {
match kind {
AllocationKind::Async => {
let _ = cudarc::driver::result::free_async(ptr, self.stream);
}
AllocationKind::Sync => {
if let Err(e) = cudarc::driver::result::free_sync(ptr) {
eprintln!("CUDA free error: {}", e);
}
}
}
});
}
}
unsafe impl Send for GpuResource {}
unsafe impl Send for GpuStorage {}
impl core::fmt::Debug for GpuStorage {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("GpuStorage").finish()
}
}
struct PtrBindings {
slots: Vec<cudarc::driver::sys::CUdeviceptr>,
cursor: usize,
}
impl PtrBindings {
fn new() -> Self {
Self {
slots: unsafe { uninit_vec(crate::device::CUDA_MAX_BINDINGS as usize) },
cursor: 0,
}
}
fn register(&mut self, ptr: u64) -> &u64 {
self.slots[self.cursor] = ptr;
let ptr_ref = self.slots.get(self.cursor).unwrap();
self.cursor += 1;
if self.cursor >= self.slots.len() {
self.cursor = 0;
}
ptr_ref
}
}
impl ComputeStorage for GpuStorage {
type Resource = GpuResource;
fn alignment(&self) -> usize {
self.mem_alignment
}
fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
let (ptr, _) = self
.memory
.get(&handle.id)
.expect("Storage handle not found");
let offset = handle.offset();
let size = handle.size();
let ptr = self.ptr_bindings.register(ptr + offset);
GpuResource::new(
*ptr,
ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void,
size,
)
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(level = "trace", skip(self, size))
)]
fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
let id = StorageId::new();
let ptr = unsafe { cudarc::driver::result::malloc_async(self.stream, size as usize) };
let (ptr, kind) = match ptr {
Ok(ptr) => (ptr, AllocationKind::Async),
Err(_) => unsafe {
match cudarc::driver::result::malloc_sync(size as usize) {
Ok(ptr) => (ptr, AllocationKind::Sync),
Err(DriverError(cudarc::driver::sys::CUresult::CUDA_ERROR_OUT_OF_MEMORY)) => {
return Err(IoError::BufferTooBig {
size,
backtrace: BackTrace::capture(),
});
}
Err(other) => {
return Err(IoError::Unknown {
description: format!("CUDA allocation error: {other}"),
backtrace: BackTrace::capture(),
});
}
}
},
};
self.memory.insert(id, (ptr, kind));
Ok(StorageHandle::new(
id,
StorageUtilization { offset: 0, size },
))
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
fn dealloc(&mut self, id: StorageId) {
self.deallocations.push(id);
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
fn flush(&mut self) {
self.perform_deallocations();
}
}