cubecl_hip/compute/storage/
cpu.rs1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::server::IoError;
3use cubecl_hip_sys::HIP_SUCCESS;
4use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
5use std::{collections::HashMap, ffi::c_void};
6
/// Alignment, in bytes, that [`PinnedMemoryStorage`] reports for its allocations.
///
/// Defined as the size of a `u128`, i.e. 16 bytes.
pub const PINNED_MEMORY_ALIGNMENT: usize = std::mem::size_of::<u128>();
9
10pub struct PinnedMemoryStorage {
15 memory: HashMap<StorageId, PinnedMemory>,
16 mem_alignment: usize,
17}
18
/// Byte-addressed view of a region inside a pinned host allocation.
///
/// The view is a plain pointer/length pair: it carries no lifetime and does not
/// free anything when dropped.
#[derive(Debug)]
pub struct PinnedMemoryResource {
    /// Address of the first byte of the region.
    pub ptr: *mut u8,
    /// Length of the region, in bytes.
    pub size: usize,
}
27
/// Bookkeeping record for a single host allocation.
#[derive(Debug)]
struct PinnedMemory {
    /// Base address of the allocation; this is the pointer handed to `hipFreeHost`.
    ptr: *mut c_void,
    /// Address of the out-parameter slot that was passed to the allocator.
    ///
    /// In `alloc` this is taken from a local variable, so it dangles once `alloc`
    /// returns and must never be dereferenced. It is not read anywhere in this
    /// file as shown, hence the lint allowance.
    #[allow(unused)]
    ptr2ptr: *mut *mut c_void,
}
37
38impl Default for PinnedMemoryStorage {
39 fn default() -> Self {
40 Self::new()
41 }
42}
43
44impl PinnedMemoryStorage {
45 pub fn new() -> Self {
50 Self {
51 memory: HashMap::new(),
52 mem_alignment: PINNED_MEMORY_ALIGNMENT,
53 }
54 }
55}
56
57unsafe impl Send for PinnedMemoryResource {}
58unsafe impl Send for PinnedMemoryStorage {}
59
60impl ComputeStorage for PinnedMemoryStorage {
61 type Resource = PinnedMemoryResource;
62
63 fn alignment(&self) -> usize {
64 self.mem_alignment
65 }
66
67 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
68 let memory = self
69 .memory
70 .get(&handle.id)
71 .expect("Storage handle not found");
72
73 let offset = handle.offset() as usize;
74 let size = handle.size() as usize;
75
76 unsafe {
77 PinnedMemoryResource {
78 ptr: memory.ptr.cast::<u8>().add(offset),
79 size,
80 }
81 }
82 }
83
84 #[cfg_attr(
85 feature = "tracing",
86 tracing::instrument(level = "trace", skip(self, size))
87 )]
88 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
89 let resource = unsafe {
90 let mut ptr: *mut c_void = std::ptr::null_mut();
91 let ptr2ptr: *mut *mut c_void = &mut ptr;
92
93 let result = cubecl_hip_sys::hipMallocHost(ptr2ptr, size as usize);
94
95 if result != HIP_SUCCESS {
96 return Err(IoError::Unknown {
97 description: format!("cuMemAllocHost_v2 failed with error code: {result:?}"),
98 backtrace: BackTrace::capture(),
99 });
100 }
101
102 PinnedMemory { ptr, ptr2ptr }
103 };
104
105 let id = StorageId::new();
106 self.memory.insert(id, resource);
107 Ok(StorageHandle::new(
108 id,
109 StorageUtilization { offset: 0, size },
110 ))
111 }
112
113 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
114 fn dealloc(&mut self, id: StorageId) {
115 if let Some(resource) = self.memory.remove(&id) {
116 unsafe {
117 cubecl_hip_sys::hipFreeHost(resource.ptr);
118 }
119 }
120 }
121}