// cubecl_cuda/compute/storage.rs
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use cudarc::driver::sys::CUstream;
use std::collections::HashMap;

use super::uninit_vec;
6
7pub struct CudaStorage {
9    memory: HashMap<StorageId, cudarc::driver::sys::CUdeviceptr>,
10    deallocations: Vec<StorageId>,
11    stream: cudarc::driver::sys::CUstream,
12    ptr_bindings: PtrBindings,
13    mem_alignment: usize,
14}
15
16struct PtrBindings {
17    slots: Vec<cudarc::driver::sys::CUdeviceptr>,
18    cursor: usize,
19}
20
21impl PtrBindings {
22    fn new() -> Self {
23        Self {
24            slots: uninit_vec(crate::device::CUDA_MAX_BINDINGS as usize),
25            cursor: 0,
26        }
27    }
28
29    fn register(&mut self, ptr: u64) -> &u64 {
30        self.slots[self.cursor] = ptr;
31        let ptr = self.slots.get(self.cursor).unwrap();
32
33        self.cursor += 1;
34
35        if self.cursor >= self.slots.len() {
37            self.cursor = 0;
38        }
39
40        ptr
41    }
42}
43
44unsafe impl Send for CudaStorage {}
45
46impl core::fmt::Debug for CudaStorage {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.write_str(format!("CudaStorage {{ device: {:?} }}", self.stream).as_str())
49    }
50}
51
52impl CudaStorage {
54    pub fn new(mem_alignment: usize, stream: CUstream) -> Self {
56        Self {
57            memory: HashMap::new(),
58            deallocations: Vec::new(),
59            stream,
60            ptr_bindings: PtrBindings::new(),
61            mem_alignment,
62        }
63    }
64
65    pub fn perform_deallocations(&mut self) {
67        for id in self.deallocations.drain(..) {
68            if let Some(ptr) = self.memory.remove(&id) {
69                unsafe {
70                    cudarc::driver::result::free_async(ptr, self.stream).unwrap();
71                }
72            }
73        }
74    }
75}
76
77#[derive(new, Debug)]
79pub struct CudaResource {
80    pub ptr: u64,
82    pub binding: *mut std::ffi::c_void,
83    offset: u64,
84    size: u64,
85}
86
87unsafe impl Send for CudaResource {}
88
/// Raw pointer type in which a [`CudaResource`] exposes its binding
/// (see `CudaResource::as_binding`).
pub type Binding = *mut std::ffi::c_void;
90
91impl CudaResource {
92    pub fn as_binding(&self) -> Binding {
94        self.binding
95    }
96
97    pub fn size(&self) -> u64 {
99        self.size
100    }
101
102    pub fn offset(&self) -> u64 {
104        self.offset
105    }
106}
107
108impl ComputeStorage for CudaStorage {
109    type Resource = CudaResource;
110    fn alignment(&self) -> usize {
111        self.mem_alignment
112    }
113
114    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
115        let ptr = self.memory.get(&handle.id).unwrap();
116
117        let offset = handle.offset();
118        let size = handle.size();
119        let ptr = self.ptr_bindings.register(ptr + offset);
120
121        CudaResource::new(
122            *ptr,
123            ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void,
124            offset,
125            size,
126        )
127    }
128
129    fn alloc(&mut self, size: u64) -> StorageHandle {
130        let id = StorageId::new();
131        let ptr =
132            unsafe { cudarc::driver::result::malloc_async(self.stream, size as usize).unwrap() };
133        self.memory.insert(id, ptr);
134        StorageHandle::new(id, StorageUtilization { offset: 0, size })
135    }
136
137    fn dealloc(&mut self, id: StorageId) {
138        self.deallocations.push(id);
139    }
140}