// cubecl_hip/compute/storage/gpu.rs
use cubecl_common::backtrace::BackTrace;
2use cubecl_core::server::IoError;
3use cubecl_hip_sys::HIP_SUCCESS;
4use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
5use std::collections::HashMap;
6
7use crate::AMD_MAX_BINDINGS;
8
/// Compute storage backed by HIP device memory, allocated and freed
/// asynchronously on a single HIP stream.
pub struct GpuStorage {
    /// Alignment (in bytes) reported to the memory manager via `alignment()`.
    mem_alignment: usize,
    /// Live device allocations, keyed by their storage id.
    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
    /// Ids queued for release; actually freed in `perform_deallocations`
    /// (invoked from `flush`), so `dealloc` itself stays cheap.
    deallocations: Vec<StorageId>,
    /// Ring buffer that keeps registered pointer values alive so the
    /// `binding` pointers handed out by `get` remain valid for a while.
    ptr_bindings: PtrBindings,
    /// Stream used for all async alloc/free calls.
    stream: cubecl_hip_sys::hipStream_t,
}
20
/// A view over a HIP device allocation as handed out by [`GpuStorage::get`].
#[derive(new, Debug)]
pub struct GpuResource {
    /// Device pointer, already offset to the start of the view.
    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
    /// Pointer to the ring-buffer slot that stores `ptr` (see `PtrBindings`);
    /// presumably passed to kernel launches as a `void**` argument — confirm
    /// against the launch code. Valid only until the ring buffer wraps.
    pub binding: cubecl_hip_sys::hipDeviceptr_t,
    /// Size of the view in bytes.
    pub size: u64,
}
31
32impl GpuStorage {
33 pub fn new(mem_alignment: usize, stream: cubecl_hip_sys::hipStream_t) -> Self {
39 Self {
40 mem_alignment,
41 memory: HashMap::new(),
42 deallocations: Vec::new(),
43 ptr_bindings: PtrBindings::new(),
44 stream,
45 }
46 }
47
48 pub fn perform_deallocations(&mut self) {
52 for id in self.deallocations.drain(..) {
53 if let Some(ptr) = self.memory.remove(&id) {
54 unsafe {
57 cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
58 }
59 }
60 }
61 }
62}
63
/// Fixed-size ring buffer of device-pointer values (`u64`).
///
/// `register` writes a pointer into the next slot and returns a reference to
/// that slot; the reference stays meaningful until the cursor wraps back
/// around, i.e. for `AMD_MAX_BINDINGS` further registrations.
struct PtrBindings {
    // Slot storage; length fixed at AMD_MAX_BINDINGS, never resized.
    slots: Vec<u64>,
    // Index of the next slot to overwrite.
    cursor: usize,
}
71
72impl PtrBindings {
73 fn new() -> Self {
75 Self {
76 slots: vec![0; AMD_MAX_BINDINGS as usize],
77 cursor: 0,
78 }
79 }
80
81 fn register(&mut self, ptr: u64) -> &u64 {
91 self.slots[self.cursor] = ptr;
92 let ptr_ref = self.slots.get(self.cursor).unwrap();
93
94 self.cursor += 1;
95
96 if self.cursor >= self.slots.len() {
98 self.cursor = 0;
99 }
100
101 ptr_ref
102 }
103}
104
105impl ComputeStorage for GpuStorage {
106 type Resource = GpuResource;
107
108 fn alignment(&self) -> usize {
109 self.mem_alignment
110 }
111
112 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
113 let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
114
115 let offset = handle.offset();
116 let size = handle.size();
117 let ptr = self.ptr_bindings.register(ptr + offset);
118
119 GpuResource::new(
120 *ptr as cubecl_hip_sys::hipDeviceptr_t,
121 std::ptr::from_ref(ptr) as *mut std::ffi::c_void,
122 size,
123 )
124 }
125
126 #[cfg_attr(
127 feature = "tracing",
128 tracing::instrument(level = "trace", skip(self, size))
129 )]
130 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
131 let id = StorageId::new();
132 unsafe {
136 let mut ptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
137 let status = cubecl_hip_sys::hipMallocAsync(&mut ptr, size as usize, self.stream);
138
139 match status {
140 HIP_SUCCESS => {}
141 other => {
142 return Err(IoError::Unknown {
143 description: format!("HIP allocation error: {other}"),
144 backtrace: BackTrace::capture(),
145 });
146 }
147 }
148
149 cubecl_hip_sys::hipStreamSynchronize(self.stream);
151
152 self.memory.insert(id, ptr);
153 };
154
155 Ok(StorageHandle::new(
156 id,
157 StorageUtilization { offset: 0, size },
158 ))
159 }
160
161 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
162 fn dealloc(&mut self, id: StorageId) {
163 self.deallocations.push(id);
164 }
165
166 fn flush(&mut self) {
167 self.perform_deallocations();
168 }
169}
170
// SAFETY(review): sending `GpuStorage` to another thread assumes the wrapped
// HIP stream and device pointers may be used from a thread other than the one
// that created them — confirm against the HIP runtime's threading guarantees.
unsafe impl Send for GpuStorage {}
// SAFETY(review): `GpuResource` carries raw device/binding pointers; its
// `binding` points into the owning storage's ring buffer, so cross-thread use
// is presumed coordinated by the caller — confirm.
unsafe impl Send for GpuResource {}
178
179impl core::fmt::Debug for GpuStorage {
180 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181 f.write_str("GpuStorage".to_string().as_str())
182 }
183}