cubecl_hip/compute/storage.rs

1use cubecl_hip_sys::HIP_SUCCESS;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use std::collections::HashMap;
4
5/// Buffer storage for HIP.
6pub struct HipStorage {
7    mem_alignment: usize,
8    memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
9    deallocations: Vec<StorageId>,
10    stream: cubecl_hip_sys::hipStream_t,
11    activate_slices: HashMap<ActiveResource, cubecl_hip_sys::hipDeviceptr_t>,
12}
13
14#[derive(new, Debug, Hash, PartialEq, Eq, Clone)]
15struct ActiveResource {
16    ptr: u64,
17}
18
19unsafe impl Send for HipStorage {}
20
21impl core::fmt::Debug for HipStorage {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        f.write_str(format!("HipStorage {{ device: {:?} }}", self.stream).as_str())
24    }
25}
26
27/// Keeps actual HIP buffer references in a hashmap with ids as key.
28impl HipStorage {
29    /// Create a new storage on the given stream.
30    pub fn new(mem_alignment: usize, stream: cubecl_hip_sys::hipStream_t) -> Self {
31        Self {
32            mem_alignment,
33            memory: HashMap::new(),
34            deallocations: Vec::new(),
35            stream,
36            activate_slices: HashMap::new(),
37        }
38    }
39
40    /// Actually deallocates buffers tagged to be deallocated.
41    pub fn perform_deallocations(&mut self) {
42        for id in self.deallocations.drain(..) {
43            if let Some(ptr) = self.memory.remove(&id) {
44                unsafe {
45                    cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
46                }
47            }
48        }
49    }
50
51    pub fn flush(&mut self) {
52        self.activate_slices.clear();
53    }
54}
55
56pub type Binding = cubecl_hip_sys::hipDeviceptr_t;
57
58/// The memory resource that can be allocated for the device.
59#[derive(new, Debug)]
60pub struct HipResource {
61    /// The buffer.
62    pub ptr: cubecl_hip_sys::hipDeviceptr_t,
63    pub binding: Binding,
64    pub offset: u64,
65    pub size: u64,
66}
67
68unsafe impl Send for HipResource {}
69
70impl ComputeStorage for HipStorage {
71    type Resource = HipResource;
72
73    fn alignment(&self) -> usize {
74        self.mem_alignment
75    }
76
77    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
78        let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
79
80        let offset = handle.offset();
81        let size = handle.size();
82
83        let ptr = ptr + offset;
84        let key = ActiveResource::new(ptr);
85
86        self.activate_slices
87            .insert(key.clone(), ptr as cubecl_hip_sys::hipDeviceptr_t);
88
89        // The ptr needs to stay alive until we send the task to the server.
90        let ptr = self.activate_slices.get(&key).unwrap();
91
92        HipResource::new(
93            *ptr,
94            ptr as *const cubecl_hip_sys::hipDeviceptr_t as *mut std::ffi::c_void,
95            offset,
96            size,
97        )
98    }
99
100    fn alloc(&mut self, size: u64) -> StorageHandle {
101        let id = StorageId::new();
102        unsafe {
103            let mut dptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
104            let status = cubecl_hip_sys::hipMallocAsync(&mut dptr, size as usize, self.stream);
105            assert_eq!(status, HIP_SUCCESS, "Should allocate memory");
106            self.memory.insert(id, dptr);
107        };
108        StorageHandle::new(id, StorageUtilization { offset: 0, size })
109    }
110
111    fn dealloc(&mut self, id: StorageId) {
112        self.deallocations.push(id);
113    }
114}