// cubecl_hip/compute/storage.rs
use cubecl_hip_sys::HIP_SUCCESS;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use std::collections::HashMap;

5pub struct HipStorage {
7 mem_alignment: usize,
8 memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
9 deallocations: Vec<StorageId>,
10 stream: cubecl_hip_sys::hipStream_t,
11 activate_slices: HashMap<ActiveResource, cubecl_hip_sys::hipDeviceptr_t>,
12}
13
14#[derive(new, Debug, Hash, PartialEq, Eq, Clone)]
15struct ActiveResource {
16 ptr: u64,
17}
18
19unsafe impl Send for HipStorage {}
20
21impl core::fmt::Debug for HipStorage {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.write_str(format!("HipStorage {{ device: {:?} }}", self.stream).as_str())
24 }
25}
26
27impl HipStorage {
29 pub fn new(mem_alignment: usize, stream: cubecl_hip_sys::hipStream_t) -> Self {
31 Self {
32 mem_alignment,
33 memory: HashMap::new(),
34 deallocations: Vec::new(),
35 stream,
36 activate_slices: HashMap::new(),
37 }
38 }
39
40 pub fn perform_deallocations(&mut self) {
42 for id in self.deallocations.drain(..) {
43 if let Some(ptr) = self.memory.remove(&id) {
44 unsafe {
45 cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
46 }
47 }
48 }
49 }
50
51 pub fn flush(&mut self) {
52 self.activate_slices.clear();
53 }
54}
55
56pub type Binding = cubecl_hip_sys::hipDeviceptr_t;
57
58#[derive(new, Debug)]
60pub struct HipResource {
61 pub ptr: cubecl_hip_sys::hipDeviceptr_t,
63 pub binding: Binding,
64 pub offset: u64,
65 pub size: u64,
66}
67
68unsafe impl Send for HipResource {}
69
70impl ComputeStorage for HipStorage {
71 type Resource = HipResource;
72
73 fn alignment(&self) -> usize {
74 self.mem_alignment
75 }
76
77 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
78 let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
79
80 let offset = handle.offset();
81 let size = handle.size();
82
83 let ptr = ptr + offset;
84 let key = ActiveResource::new(ptr);
85
86 self.activate_slices
87 .insert(key.clone(), ptr as cubecl_hip_sys::hipDeviceptr_t);
88
89 let ptr = self.activate_slices.get(&key).unwrap();
91
92 HipResource::new(
93 *ptr,
94 ptr as *const cubecl_hip_sys::hipDeviceptr_t as *mut std::ffi::c_void,
95 offset,
96 size,
97 )
98 }
99
100 fn alloc(&mut self, size: u64) -> StorageHandle {
101 let id = StorageId::new();
102 unsafe {
103 let mut dptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
104 let status = cubecl_hip_sys::hipMallocAsync(&mut dptr, size as usize, self.stream);
105 assert_eq!(status, HIP_SUCCESS, "Should allocate memory");
106 self.memory.insert(id, dptr);
107 };
108 StorageHandle::new(id, StorageUtilization { offset: 0, size })
109 }
110
111 fn dealloc(&mut self, id: StorageId) {
112 self.deallocations.push(id);
113 }
114}