cubecl_hip/compute/storage/cpu.rs
1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::server::IoError;
3use cubecl_hip_sys::HIP_SUCCESS;
4use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
5use std::{collections::HashMap, ffi::c_void};
6
/// Alignment, in bytes, that [`PinnedMemoryStorage`] reports for pinned host allocations.
///
/// The value is the size of a `u128` (16 bytes).
pub const PINNED_MEMORY_ALIGNMENT: usize = std::mem::size_of::<u128>();
9
10/// Manages pinned host memory for HIP operations.
11///
12/// This storage handles allocation and deallocation of pinned (page-locked) host memory,
13/// which is optimized for fast data transfers between host and GPU in HIP applications.
14pub struct PinnedMemoryStorage {
15 memory: HashMap<StorageId, PinnedMemory>,
16 mem_alignment: usize,
17 stream: cubecl_hip_sys::hipStream_t,
18}
19
/// A region of pinned host memory, described by its start address and byte length.
///
/// This is the resource type handed out by [`PinnedMemoryStorage`]; it does not own the
/// memory it points to.
#[derive(Debug)]
pub struct PinnedMemoryResource {
    /// Address of the first byte of the region.
    pub ptr: *mut u8,
    /// Length of the region, in bytes.
    pub size: usize,
}
28
/// Bookkeeping record for one allocation made with `hipHostMalloc`.
#[derive(Debug)]
struct PinnedMemory {
    /// Base address of the pinned host buffer; this is what gets passed to `hipFreeHost`.
    ptr: *mut c_void,
    /// Out-parameter address that was handed to `hipHostMalloc`.
    ///
    /// In `alloc` this is the address of a local variable, so it must not be dereferenced
    /// once `alloc` has returned. It is stored but never read, hence the lint allowance.
    #[allow(unused)]
    dev_ptr: *mut *mut c_void,
}
38
39impl PinnedMemoryStorage {
40 /// Creates a new [`PinnedMemoryStorage`] instance.
41 ///
42 /// Initializes the storage with the default pinned memory alignment
43 /// defined by [`PINNED_MEMORY_ALIGNMENT`].
44 pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
45 Self {
46 memory: HashMap::new(),
47 mem_alignment: PINNED_MEMORY_ALIGNMENT,
48 stream,
49 }
50 }
51}
52
// SAFETY: `PinnedMemoryResource` only holds a raw pointer (which is why `Send` is not
// derived automatically) into page-locked host memory plus a length. The pointed-to memory
// is not tied to the thread that allocated it: it stays valid and pinned until the
// allocation is released with `hipFreeHost`, regardless of which thread accesses it.
// Access is serialized by the `DeviceHandle` (defined outside this file).
unsafe impl Send for PinnedMemoryResource {}

// SAFETY: `PinnedMemoryStorage` is only accessed from one thread at a time via the
// `DeviceHandle` (defined outside this file), which serializes all server access. The HIP
// stream and the pinned memory it manages are never shared across threads without
// synchronization.
unsafe impl Send for PinnedMemoryStorage {}
61
62impl ComputeStorage for PinnedMemoryStorage {
63 type Resource = PinnedMemoryResource;
64
65 fn alignment(&self) -> usize {
66 self.mem_alignment
67 }
68
69 fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
70 let memory = self
71 .memory
72 .get(&handle.id)
73 .expect("Storage handle not found");
74
75 let offset = handle.offset() as usize;
76 let size = handle.size() as usize;
77
78 // SAFETY: `memory.ptr` was allocated by `hipHostMalloc` with at least `offset + size`
79 // bytes. The `add(offset)` produces a pointer within the allocation bounds as
80 // guaranteed by the storage handle's offset/size validation.
81 unsafe {
82 PinnedMemoryResource {
83 ptr: memory.ptr.cast::<u8>().add(offset),
84 size,
85 }
86 }
87 }
88
89 #[cfg_attr(
90 feature = "tracing",
91 tracing::instrument(level = "trace", skip(self, size))
92 )]
93 fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
94 // SAFETY: Calling HIP FFI to allocate page-locked (pinned) host memory. The
95 // `hipHostMallocMapped` flag makes the memory accessible from both host and device.
96 // We synchronize the stream afterward to ensure the allocation is visible.
97 // The returned pointer is stored and freed via `hipFreeHost` on deallocation.
98 let resource = unsafe {
99 let mut ptr: *mut c_void = std::ptr::null_mut();
100 let dev_ptr: *mut *mut c_void = &mut ptr;
101
102 let result = cubecl_hip_sys::hipHostMalloc(
103 dev_ptr,
104 size as usize,
105 cubecl_hip_sys::hipHostMallocMapped,
106 );
107
108 if result != HIP_SUCCESS {
109 return Err(IoError::Unknown {
110 description: format!("cuMemAllocHost_v2 failed with error code: {result:?}"),
111 backtrace: BackTrace::capture(),
112 });
113 }
114
115 // For safety, reducing the odds of missing mapped memory page.
116 cubecl_hip_sys::hipStreamSynchronize(self.stream);
117
118 PinnedMemory { ptr, dev_ptr }
119 };
120
121 let id = StorageId::new();
122 self.memory.insert(id, resource);
123 Ok(StorageHandle::new(
124 id,
125 StorageUtilization { offset: 0, size },
126 ))
127 }
128
129 #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
130 fn dealloc(&mut self, id: StorageId) {
131 if let Some(resource) = self.memory.remove(&id) {
132 // SAFETY: `resource.ptr` was allocated by `hipHostMalloc` and has not been freed
133 // yet. After this call, the pointer is invalid and removed from `self.memory`.
134 unsafe {
135 cubecl_hip_sys::hipFreeHost(resource.ptr);
136 }
137 }
138 }
139
140 fn flush(&mut self) {
141 // We don't wait for dealloc.
142 }
143}