cubecl_hip/compute/storage/
cpu.rs1use cubecl_core::server::IoError;
2use cubecl_hip_sys::HIP_SUCCESS;
3use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
4use std::{collections::HashMap, ffi::c_void};
5
/// Alignment, in bytes, that [`PinnedMemoryStorage::new`] configures for the storage.
///
/// Equal to the size of a `u128`, i.e. 16 bytes.
pub const PINNED_MEMORY_ALIGNMENT: usize = core::mem::size_of::<u128>();
8
9pub struct PinnedMemoryStorage {
14 memory: HashMap<StorageId, PinnedMemory>,
15 mem_alignment: usize,
16}
17
/// Raw view of a region inside a host allocation, as produced by `PinnedMemoryStorage::get`.
#[derive(Debug)]
pub struct PinnedMemoryResource {
    /// Address of the first byte of the region (allocation base plus the handle's offset).
    pub ptr: *mut u8,
    /// Size of the region, as recorded in the storage handle it was resolved from.
    pub size: usize,
}
26
/// Bookkeeping record for a single host allocation owned by the storage.
#[derive(Debug)]
struct PinnedMemory {
    /// Base address of the allocation; this is the pointer later passed to `hipFreeHost`.
    ptr: *mut c_void,
    /// Address of the pointer slot that was handed to the allocator as its out-parameter.
    ///
    /// NOTE(review): `alloc` fills this with the address of one of its stack locals, so the
    /// value dangles as soon as `alloc` returns. Nothing reads it (hence the lint allowance)
    /// and it must never be dereferenced.
    #[allow(unused)]
    ptr2ptr: *mut *mut c_void,
}
36
37impl Default for PinnedMemoryStorage {
38 fn default() -> Self {
39 Self::new()
40 }
41}
42
43impl PinnedMemoryStorage {
44 pub fn new() -> Self {
49 Self {
50 memory: HashMap::new(),
51 mem_alignment: PINNED_MEMORY_ALIGNMENT,
52 }
53 }
54}
55
56unsafe impl Send for PinnedMemoryResource {}
57unsafe impl Send for PinnedMemoryStorage {}
58
59impl ComputeStorage for PinnedMemoryStorage {
60 type Resource = PinnedMemoryResource;
61
62 fn alignment(&self) -> usize {
63 self.mem_alignment
64 }
65
66 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
67 let memory = self
68 .memory
69 .get(&handle.id)
70 .expect("Storage handle not found");
71
72 let offset = handle.offset() as usize;
73 let size = handle.size() as usize;
74
75 unsafe {
76 PinnedMemoryResource {
77 ptr: memory.ptr.cast::<u8>().add(offset),
78 size,
79 }
80 }
81 }
82
83 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
84 let resource = unsafe {
85 let mut ptr: *mut c_void = std::ptr::null_mut();
86 let ptr2ptr: *mut *mut c_void = &mut ptr;
87
88 let result = cubecl_hip_sys::hipMallocHost(ptr2ptr, size as usize);
89
90 if result != HIP_SUCCESS {
91 return Err(IoError::Unknown(format!(
92 "cuMemAllocHost_v2 failed with error code: {:?}",
93 result
94 )));
95 }
96
97 PinnedMemory { ptr, ptr2ptr }
98 };
99
100 let id = StorageId::new();
101 self.memory.insert(id, resource);
102 Ok(StorageHandle::new(
103 id,
104 StorageUtilization { offset: 0, size },
105 ))
106 }
107
108 fn dealloc(&mut self, id: StorageId) {
109 if let Some(resource) = self.memory.remove(&id) {
110 unsafe {
111 cubecl_hip_sys::hipFreeHost(resource.ptr);
112 }
113 }
114 }
115}