cubecl_hip/compute/
storage.rs

1use cubecl_hip_sys::HIP_SUCCESS;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use std::collections::HashMap;
4
5/// Buffer storage for HIP.
6pub struct HipStorage {
7    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
8    deallocations: Vec<StorageId>,
9    stream: cubecl_hip_sys::hipStream_t,
10    activate_slices: HashMap<ActiveResource, cubecl_hip_sys::hipDeviceptr_t>,
11}
12
13#[derive(new, Debug, Hash, PartialEq, Eq, Clone)]
14struct ActiveResource {
15    ptr: u64,
16}
17
18unsafe impl Send for HipStorage {}
19
20impl core::fmt::Debug for HipStorage {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        f.write_str(format!("HipStorage {{ device: {:?} }}", self.stream).as_str())
23    }
24}
25
26/// Keeps actual HIP buffer references in a hashmap with ids as key.
27impl HipStorage {
28    /// Create a new storage on the given stream.
29    pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
30        Self {
31            memory: HashMap::new(),
32            deallocations: Vec::new(),
33            stream,
34            activate_slices: HashMap::new(),
35        }
36    }
37
38    /// Actually deallocates buffers tagged to be deallocated.
39    pub fn perform_deallocations(&mut self) {
40        for id in self.deallocations.drain(..) {
41            if let Some(ptr) = self.memory.remove(&id) {
42                unsafe {
43                    cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
44                }
45            }
46        }
47    }
48
49    pub fn flush(&mut self) {
50        self.activate_slices.clear();
51    }
52}
53
54pub type Binding = cubecl_hip_sys::hipDeviceptr_t;
55
56/// The memory resource that can be allocated for the device.
57#[derive(new, Debug)]
58pub struct HipResource {
59    /// The buffer.
60    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
61    pub binding: Binding,
62    pub offset: u64,
63    pub size: u64,
64}
65
66unsafe impl Send for HipResource {}
67
68impl ComputeStorage for HipStorage {
69    const ALIGNMENT: u64 = 32;
70
71    type Resource = HipResource;
72
73    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
74        let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
75
76        let offset = handle.offset();
77        let size = handle.size();
78
79        let ptr = ptr + offset;
80        let key = ActiveResource::new(ptr);
81
82        self.activate_slices
83            .insert(key.clone(), ptr as cubecl_hip_sys::hipDeviceptr_t);
84
85        // The ptr needs to stay alive until we send the task to the server.
86        let ptr = self.activate_slices.get(&key).unwrap();
87
88        HipResource::new(
89            *ptr,
90            ptr as *const cubecl_hip_sys::hipDeviceptr_t as *mut std::ffi::c_void,
91            offset,
92            size,
93        )
94    }
95
96    fn alloc(&mut self, size: u64) -> StorageHandle {
97        let id = StorageId::new();
98        unsafe {
99            let mut dptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
100            let status = cubecl_hip_sys::hipMallocAsync(&mut dptr, size as usize, self.stream);
101            assert_eq!(status, HIP_SUCCESS, "Should allocate memory");
102            self.memory.insert(id, dptr);
103        };
104        StorageHandle::new(id, StorageUtilization { offset: 0, size })
105    }
106
107    fn dealloc(&mut self, id: StorageId) {
108        self.deallocations.push(id);
109    }
110}