//! GPU buffer storage for AMD devices (`cubecl_hip/compute/storage/gpu.rs`).

1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::server::IoError;
3use cubecl_hip_sys::HIP_SUCCESS;
4use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
5use std::collections::HashMap;
6
7use crate::AMD_MAX_BINDINGS;
8
/// Buffer storage for AMD GPUs.
///
/// This struct manages memory resources for HIP kernels, allowing them to be used as bindings
/// for launching kernels.
pub struct GpuStorage {
    // Alignment requirement, in bytes, reported via `ComputeStorage::alignment`.
    mem_alignment: usize,
    // Live device allocations, keyed by their storage id.
    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
    // Ids queued for release; actually freed in `perform_deallocations`.
    deallocations: Vec<StorageId>,
    // Ring buffer that keeps binding pointers alive during kernel execution.
    ptr_bindings: PtrBindings,
    // HIP stream on which async allocations and frees are issued.
    stream: cubecl_hip_sys::hipStream_t,
}
20
/// A GPU memory resource allocated for HIP using [`GpuStorage`].
#[derive(new, Debug)]
pub struct GpuResource {
    /// The GPU memory pointer.
    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
    /// The HIP binding pointer. Points at the ring-buffer slot holding `ptr`,
    /// so it stays valid only until that slot is overwritten.
    pub binding: cubecl_hip_sys::hipDeviceptr_t,
    /// The size of the resource in bytes.
    pub size: u64,
}
31
32impl GpuStorage {
33    /// Creates a new [`GpuStorage`] instance for the specified HIP stream.
34    ///
35    /// # Arguments
36    ///
37    /// * `mem_alignment` - The memory alignment requirement in bytes.
38    pub fn new(mem_alignment: usize, stream: cubecl_hip_sys::hipStream_t) -> Self {
39        Self {
40            mem_alignment,
41            memory: HashMap::new(),
42            deallocations: Vec::new(),
43            ptr_bindings: PtrBindings::new(),
44            stream,
45        }
46    }
47
48    /// Deallocates buffers marked for deallocation.
49    ///
50    /// This method processes all pending deallocations by freeing the associated GPU memory.
51    pub fn perform_deallocations(&mut self) {
52        for id in self.deallocations.drain(..) {
53            if let Some(ptr) = self.memory.remove(&id) {
54                // SAFETY: `ptr` was obtained from a prior `hipMallocAsync` call and has not
55                // been freed yet. `self.stream` is the same stream used for allocation.
56                unsafe {
57                    cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
58                }
59            }
60        }
61    }
62}
63
/// Manages active HIP buffer bindings in a ring buffer.
///
/// This ensures that pointers remain valid during kernel execution, preventing use-after-free errors.
struct PtrBindings {
    // Fixed-size slots holding the registered device pointers.
    slots: Vec<u64>,
    // Index of the slot the next `register` call will overwrite.
    cursor: usize,
}
71
72impl PtrBindings {
73    /// Creates a new [`PtrBindings`] instance with a fixed-size ring buffer.
74    fn new() -> Self {
75        Self {
76            slots: vec![0; AMD_MAX_BINDINGS as usize],
77            cursor: 0,
78        }
79    }
80
81    /// Registers a new pointer in the ring buffer.
82    ///
83    /// # Arguments
84    ///
85    /// * `ptr` - The HIP device pointer to register.
86    ///
87    /// # Returns
88    ///
89    /// A reference to the registered pointer.
90    fn register(&mut self, ptr: u64) -> &u64 {
91        self.slots[self.cursor] = ptr;
92        let ptr_ref = self.slots.get(self.cursor).unwrap();
93
94        self.cursor += 1;
95
96        // Reset the cursor when the ring buffer is full.
97        if self.cursor >= self.slots.len() {
98            self.cursor = 0;
99        }
100
101        ptr_ref
102    }
103}
104
105impl ComputeStorage for GpuStorage {
106    type Resource = GpuResource;
107
108    fn alignment(&self) -> usize {
109        self.mem_alignment
110    }
111
112    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
113        let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
114
115        let offset = handle.offset();
116        let size = handle.size();
117        let ptr = self.ptr_bindings.register(ptr + offset);
118
119        GpuResource::new(
120            *ptr as cubecl_hip_sys::hipDeviceptr_t,
121            std::ptr::from_ref(ptr) as *mut std::ffi::c_void,
122            size,
123        )
124    }
125
126    #[cfg_attr(
127        feature = "tracing",
128        tracing::instrument(level = "trace", skip(self, size))
129    )]
130    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
131        let id = StorageId::new();
132        // SAFETY: Calling HIP FFI to allocate device memory asynchronously. The returned
133        // pointer is valid after stream synchronization (performed below). The pointer is
134        // stored in `self.memory` and will be freed via `hipFreeAsync` on deallocation.
135        unsafe {
136            let mut ptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
137            let status = cubecl_hip_sys::hipMallocAsync(&mut ptr, size as usize, self.stream);
138
139            match status {
140                HIP_SUCCESS => {}
141                other => {
142                    return Err(IoError::Unknown {
143                        description: format!("HIP allocation error: {other}"),
144                        backtrace: BackTrace::capture(),
145                    });
146                }
147            }
148
149            // For safety, reducing the odds of missing mapped memory page.
150            cubecl_hip_sys::hipStreamSynchronize(self.stream);
151
152            self.memory.insert(id, ptr);
153        };
154
155        Ok(StorageHandle::new(
156            id,
157            StorageUtilization { offset: 0, size },
158        ))
159    }
160
161    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
162    fn dealloc(&mut self, id: StorageId) {
163        self.deallocations.push(id);
164    }
165
166    fn flush(&mut self) {
167        self.perform_deallocations();
168    }
169}
170
// SAFETY: `GpuStorage` is only accessed from one thread at a time via the `DeviceHandle`,
// which serializes all server access. The raw HIP pointers it contains are never shared
// across threads without synchronization.
unsafe impl Send for GpuStorage {}
// SAFETY: `GpuResource` contains raw HIP device pointers that are safe to send between
// threads as long as proper stream synchronization is maintained by the caller.
unsafe impl Send for GpuResource {}
178
179impl core::fmt::Debug for GpuStorage {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        f.write_str("GpuStorage".to_string().as_str())
182    }
183}