cubecl_hip/compute/
storage.rs1use cubecl_hip_sys::HIP_SUCCESS;
2use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
3use std::collections::HashMap;
4
5pub struct HipStorage {
7 memory: HashMap<StorageId, cubecl_hip_sys::hipDeviceptr_t>,
8 deallocations: Vec<StorageId>,
9 stream: cubecl_hip_sys::hipStream_t,
10 activate_slices: HashMap<ActiveResource, cubecl_hip_sys::hipDeviceptr_t>,
11}
12
13#[derive(new, Debug, Hash, PartialEq, Eq, Clone)]
14struct ActiveResource {
15 ptr: u64,
16}
17
18unsafe impl Send for HipStorage {}
19
20impl core::fmt::Debug for HipStorage {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 f.write_str(format!("HipStorage {{ device: {:?} }}", self.stream).as_str())
23 }
24}
25
26impl HipStorage {
28 pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
30 Self {
31 memory: HashMap::new(),
32 deallocations: Vec::new(),
33 stream,
34 activate_slices: HashMap::new(),
35 }
36 }
37
38 pub fn perform_deallocations(&mut self) {
40 for id in self.deallocations.drain(..) {
41 if let Some(ptr) = self.memory.remove(&id) {
42 unsafe {
43 cubecl_hip_sys::hipFreeAsync(ptr, self.stream);
44 }
45 }
46 }
47 }
48
49 pub fn flush(&mut self) {
50 self.activate_slices.clear();
51 }
52}
53
54pub type Binding = cubecl_hip_sys::hipDeviceptr_t;
55
56#[derive(new, Debug)]
58pub struct HipResource {
59 pub ptr: cubecl_hip_sys::hipDeviceptr_t,
61 pub binding: Binding,
62 pub offset: u64,
63 pub size: u64,
64}
65
66unsafe impl Send for HipResource {}
67
68impl ComputeStorage for HipStorage {
69 const ALIGNMENT: u64 = 32;
70
71 type Resource = HipResource;
72
73 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
74 let ptr = (*self.memory.get(&handle.id).unwrap()) as u64;
75
76 let offset = handle.offset();
77 let size = handle.size();
78
79 let ptr = ptr + offset;
80 let key = ActiveResource::new(ptr);
81
82 self.activate_slices
83 .insert(key.clone(), ptr as cubecl_hip_sys::hipDeviceptr_t);
84
85 let ptr = self.activate_slices.get(&key).unwrap();
87
88 HipResource::new(
89 *ptr,
90 ptr as *const cubecl_hip_sys::hipDeviceptr_t as *mut std::ffi::c_void,
91 offset,
92 size,
93 )
94 }
95
96 fn alloc(&mut self, size: u64) -> StorageHandle {
97 let id = StorageId::new();
98 unsafe {
99 let mut dptr: *mut ::std::os::raw::c_void = std::ptr::null_mut();
100 let status = cubecl_hip_sys::hipMallocAsync(&mut dptr, size as usize, self.stream);
101 assert_eq!(status, HIP_SUCCESS, "Should allocate memory");
102 self.memory.insert(id, dptr);
103 };
104 StorageHandle::new(id, StorageUtilization { offset: 0, size })
105 }
106
107 fn dealloc(&mut self, id: StorageId) {
108 self.deallocations.push(id);
109 }
110}