cubecl_hip/compute/storage/gpu.rs

1use crate::compute::uninit_vec;
2use cubecl_core::server::IoError;
3use cubecl_hip_sys::HIP_SUCCESS;
4use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
5use std::collections::HashMap;
6
/// Buffer storage for AMD GPUs.
///
/// This struct manages memory resources for HIP kernels, allowing them to be used as bindings
/// for launching kernels.
///
/// Device memory is allocated eagerly by `alloc` (through `hipMalloc`), whereas `dealloc` only
/// queues the id; the memory is actually released by [`GpuStorage::perform_deallocations`].
pub struct GpuStorage {
    /// Alignment requirement in bytes, returned as-is by `ComputeStorage::alignment`.
    mem_alignment: usize,
    /// Device pointers obtained from successful `hipMalloc` calls, keyed by the id that `alloc`
    /// handed out for them.
    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
    /// Ids queued by `dealloc` whose memory has not been freed yet.
    deallocations: Vec<StorageId>,
    /// Ring buffer whose slot addresses are handed out as `GpuResource::binding` by `get`.
    ptr_bindings: PtrBindings,
}
17
/// A GPU memory resource allocated for HIP using [GpuStorage].
///
/// Produced by `GpuStorage::get` for a storage handle. Note that the field order below is also
/// the argument order of the derived `GpuResource::new`.
#[derive(new, Debug)]
pub struct GpuResource {
    /// The GPU memory pointer: the allocation's base address plus the handle's offset.
    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
    /// The HIP binding pointer: the host address of the `PtrBindings` slot that stores `ptr`.
    /// NOTE(review): presumably passed to HIP as a kernel-argument pointer — confirm at the
    /// launch site. The slot is overwritten once the ring buffer wraps around to it.
    pub binding: cubecl_hip_sys::hipDeviceptr_t,
    /// The size of the resource, taken from the storage handle (presumably bytes, the same unit
    /// `alloc` passes to `hipMalloc`).
    pub size: u64,
}
28
29impl GpuStorage {
30    /// Creates a new [GpuStorage] instance for the specified HIP stream.
31    ///
32    /// # Arguments
33    ///
34    /// * `mem_alignment` - The memory alignment requirement in bytes.
35    pub fn new(mem_alignment: usize) -> Self {
36        Self {
37            mem_alignment,
38            memory: HashMap::new(),
39            deallocations: Vec::new(),
40            ptr_bindings: PtrBindings::new(),
41        }
42    }
43
44    /// Deallocates buffers marked for deallocation.
45    ///
46    /// This method processes all pending deallocations by freeing the associated GPU memory.
47    pub fn perform_deallocations(&mut self) {
48        for id in self.deallocations.drain(..) {
49            if let Some(ptr) = self.memory.remove(&id) {
50                unsafe {
51                    cubecl_hip_sys::hipFree(ptr);
52                }
53            }
54        }
55    }
56}
57
/// Manages active HIP buffer bindings in a ring buffer.
///
/// `GpuStorage::get` hands out the *address of a slot* in this buffer as a binding pointer.
/// Because the slots live in this long-lived buffer rather than in a temporary, that address
/// stays valid after `get` returns. This is meant to avoid use-after-free of binding pointers
/// during kernel launches. The guarantee is bounded, though: once the cursor wraps around, a slot
/// is overwritten, so a binding is only stable until `slots.len()` further pointers have been
/// registered.
struct PtrBindings {
    /// Fixed set of slots (sized from `AMD_MAX_BINDINGS` in `new`). Slots are only written by
    /// index and the vector is never grown, so slot addresses do not move.
    slots: Vec<u64>,
    /// Index of the slot the next `register` call will overwrite.
    cursor: usize,
}
65
66impl PtrBindings {
67    /// Creates a new [PtrBindings] instance with a fixed-size ring buffer.
68    fn new() -> Self {
69        Self {
70            slots: uninit_vec(crate::device::AMD_MAX_BINDINGS as usize),
71            cursor: 0,
72        }
73    }
74
75    /// Registers a new pointer in the ring buffer.
76    ///
77    /// # Arguments
78    ///
79    /// * `ptr` - The HIP device pointer to register.
80    ///
81    /// # Returns
82    ///
83    /// A reference to the registered pointer.
84    fn register(&mut self, ptr: u64) -> &u64 {
85        self.slots[self.cursor] = ptr;
86        let ptr_ref = self.slots.get(self.cursor).unwrap();
87
88        self.cursor += 1;
89
90        // Reset the cursor when the ring buffer is full.
91        if self.cursor >= self.slots.len() {
92            self.cursor = 0;
93        }
94
95        ptr_ref
96    }
97}
98
99impl ComputeStorage for GpuStorage {
100    type Resource = GpuResource;
101
102    fn alignment(&self) -> usize {
103        self.mem_alignment
104    }
105
106    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
107        let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
108
109        let offset = handle.offset();
110        let size = handle.size();
111        let ptr = self.ptr_bindings.register(ptr + offset);
112
113        GpuResource::new(
114            *ptr as cubecl_hip_sys::hipDeviceptr_t,
115            std::ptr::from_ref(ptr) as *mut std::ffi::c_void,
116            size,
117        )
118    }
119
120    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
121        let id = StorageId::new();
122        unsafe {
123            let mut dptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
124            let status = cubecl_hip_sys::hipMalloc(&mut dptr, size as usize);
125
126            match status {
127                HIP_SUCCESS => {}
128                other => {
129                    return Err(IoError::Unknown(format!("HIP allocation error: {}", other)));
130                }
131            }
132            self.memory.insert(id, dptr);
133        };
134        Ok(StorageHandle::new(
135            id,
136            StorageUtilization { offset: 0, size },
137        ))
138    }
139
140    fn dealloc(&mut self, id: StorageId) {
141        self.deallocations.push(id);
142    }
143}
144
// SAFETY: `GpuStorage` is not auto-`Send` only because it stores raw `hipDeviceptr_t` values.
// Within this file those pointers are never dereferenced on the host; they are only cast to
// integers or passed to `hipFree`. NOTE(review): this relies on HIP permitting device pointers
// to be used/freed from a different host thread than the one that allocated them — confirm.
unsafe impl Send for GpuStorage {}
// SAFETY: `GpuResource` holds a device pointer plus `binding`, a host pointer into the
// `PtrBindings` ring buffer owned by `GpuStorage`. NOTE(review): sending it to another thread
// assumes the owning storage outlives it and does not overwrite that slot concurrently — confirm.
unsafe impl Send for GpuResource {}
147
148impl core::fmt::Debug for GpuStorage {
149    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
150        f.write_str("GpuStorage".to_string().as_str())
151    }
152}