cubecl_cuda/compute/
storage.rs1use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
2use cudarc::driver::sys::CUstream;
3use std::collections::HashMap;
4
5use super::uninit_vec;
6
7pub struct CudaStorage {
9 memory: HashMap<StorageId, cudarc::driver::sys::CUdeviceptr>,
10 deallocations: Vec<StorageId>,
11 stream: cudarc::driver::sys::CUstream,
12 ptr_bindings: PtrBindings,
13 mem_alignment: usize,
14}
15
16struct PtrBindings {
17 slots: Vec<cudarc::driver::sys::CUdeviceptr>,
18 cursor: usize,
19}
20
21impl PtrBindings {
22 fn new() -> Self {
23 Self {
24 slots: uninit_vec(crate::device::CUDA_MAX_BINDINGS as usize),
25 cursor: 0,
26 }
27 }
28
29 fn register(&mut self, ptr: u64) -> &u64 {
30 self.slots[self.cursor] = ptr;
31 let ptr = self.slots.get(self.cursor).unwrap();
32
33 self.cursor += 1;
34
35 if self.cursor >= self.slots.len() {
37 self.cursor = 0;
38 }
39
40 ptr
41 }
42}
43
44unsafe impl Send for CudaStorage {}
45
46impl core::fmt::Debug for CudaStorage {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.write_str(format!("CudaStorage {{ device: {:?} }}", self.stream).as_str())
49 }
50}
51
52impl CudaStorage {
54 pub fn new(mem_alignment: usize, stream: CUstream) -> Self {
56 Self {
57 memory: HashMap::new(),
58 deallocations: Vec::new(),
59 stream,
60 ptr_bindings: PtrBindings::new(),
61 mem_alignment,
62 }
63 }
64
65 pub fn perform_deallocations(&mut self) {
67 for id in self.deallocations.drain(..) {
68 if let Some(ptr) = self.memory.remove(&id) {
69 unsafe {
70 cudarc::driver::result::free_async(ptr, self.stream).unwrap();
71 }
72 }
73 }
74 }
75}
76
77#[derive(new, Debug)]
79pub struct CudaResource {
80 pub ptr: u64,
82 pub binding: *mut std::ffi::c_void,
83 offset: u64,
84 size: u64,
85}
86
87unsafe impl Send for CudaResource {}
88
89pub type Binding = *mut std::ffi::c_void;
90
91impl CudaResource {
92 pub fn as_binding(&self) -> Binding {
94 self.binding
95 }
96
97 pub fn size(&self) -> u64 {
99 self.size
100 }
101
102 pub fn offset(&self) -> u64 {
104 self.offset
105 }
106}
107
108impl ComputeStorage for CudaStorage {
109 type Resource = CudaResource;
110 fn alignment(&self) -> usize {
111 self.mem_alignment
112 }
113
114 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
115 let ptr = self.memory.get(&handle.id).unwrap();
116
117 let offset = handle.offset();
118 let size = handle.size();
119 let ptr = self.ptr_bindings.register(ptr + offset);
120
121 CudaResource::new(
122 *ptr,
123 ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void,
124 offset,
125 size,
126 )
127 }
128
129 fn alloc(&mut self, size: u64) -> StorageHandle {
130 let id = StorageId::new();
131 let ptr =
132 unsafe { cudarc::driver::result::malloc_async(self.stream, size as usize).unwrap() };
133 self.memory.insert(id, ptr);
134 StorageHandle::new(id, StorageUtilization { offset: 0, size })
135 }
136
137 fn dealloc(&mut self, id: StorageId) {
138 self.deallocations.push(id);
139 }
140}