//! cubecl-hip — `compute/storage/cpu.rs`
//!
//! Pinned (page-locked) host memory storage for the HIP runtime.

use cubecl_common::backtrace::BackTrace;
use cubecl_core::server::IoError;
use cubecl_hip_sys::HIP_SUCCESS;
use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
use std::{collections::HashMap, ffi::c_void};
6
/// Memory alignment for pinned host memory, set to the size of `u128` (16 bytes)
/// for optimal performance.
pub const PINNED_MEMORY_ALIGNMENT: usize = core::mem::size_of::<u128>();
9
/// Manages pinned host memory for HIP operations.
///
/// This storage handles allocation and deallocation of pinned (page-locked) host memory,
/// which is optimized for fast data transfers between host and GPU in HIP applications.
pub struct PinnedMemoryStorage {
    /// Live allocations, keyed by the id handed out from `alloc`.
    memory: HashMap<StorageId, PinnedMemory>,
    /// Alignment reported to callers; initialized to [`PINNED_MEMORY_ALIGNMENT`].
    mem_alignment: usize,
    /// HIP stream that is synchronized after each allocation.
    stream: cubecl_hip_sys::hipStream_t,
}
19
/// A pinned memory resource allocated on the host.
///
/// Dropping the resource does not free anything; the underlying allocation is
/// owned by [`PinnedMemoryStorage`] and released on `dealloc`.
#[derive(Debug)]
pub struct PinnedMemoryResource {
    /// Pointer to the pinned memory buffer (already offset for the handle it was built from).
    pub ptr: *mut u8,
    /// Size of the memory resource in bytes.
    pub size: usize,
}
28
/// Internal representation of pinned memory with associated pointers.
#[derive(Debug)]
struct PinnedMemory {
    /// Pointer to the pinned memory buffer returned by `hipHostMalloc`.
    ptr: *mut c_void,
    /// NOTE(review): this stores `&mut ptr` taken from a *stack local* inside `alloc`,
    /// so it dangles as soon as `alloc` returns — it cannot actually be "kept alive
    /// for async operations". It is never read (hence `#[allow(unused)]`); consider
    /// removing the field together with its initialization in `alloc`.
    #[allow(unused)]
    dev_ptr: *mut *mut c_void,
}
38
39impl PinnedMemoryStorage {
40    /// Creates a new [`PinnedMemoryStorage`] instance.
41    ///
42    /// Initializes the storage with the default pinned memory alignment
43    /// defined by [`PINNED_MEMORY_ALIGNMENT`].
44    pub fn new(stream: cubecl_hip_sys::hipStream_t) -> Self {
45        Self {
46            memory: HashMap::new(),
47            mem_alignment: PINNED_MEMORY_ALIGNMENT,
48            stream,
49        }
50    }
51}
52
// SAFETY: `PinnedMemoryResource` contains a raw pointer to page-locked host memory.
// It is safe to send between threads because the memory remains valid and pinned
// regardless of which thread accesses it, and access is serialized by the `DeviceHandle`.
unsafe impl Send for PinnedMemoryResource {}
// SAFETY: `PinnedMemoryStorage` is only accessed from one thread at a time via the
// `DeviceHandle`, which serializes all server access. The HIP stream and pinned memory
// it manages are never shared across threads without synchronization.
unsafe impl Send for PinnedMemoryStorage {}
61
62impl ComputeStorage for PinnedMemoryStorage {
63    type Resource = PinnedMemoryResource;
64
65    fn alignment(&self) -> usize {
66        self.mem_alignment
67    }
68
69    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
70        let memory = self
71            .memory
72            .get(&handle.id)
73            .expect("Storage handle not found");
74
75        let offset = handle.offset() as usize;
76        let size = handle.size() as usize;
77
78        // SAFETY: `memory.ptr` was allocated by `hipHostMalloc` with at least `offset + size`
79        // bytes. The `add(offset)` produces a pointer within the allocation bounds as
80        // guaranteed by the storage handle's offset/size validation.
81        unsafe {
82            PinnedMemoryResource {
83                ptr: memory.ptr.cast::<u8>().add(offset),
84                size,
85            }
86        }
87    }
88
89    #[cfg_attr(
90        feature = "tracing",
91        tracing::instrument(level = "trace", skip(self, size))
92    )]
93    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
94        // SAFETY: Calling HIP FFI to allocate page-locked (pinned) host memory. The
95        // `hipHostMallocMapped` flag makes the memory accessible from both host and device.
96        // We synchronize the stream afterward to ensure the allocation is visible.
97        // The returned pointer is stored and freed via `hipFreeHost` on deallocation.
98        let resource = unsafe {
99            let mut ptr: *mut c_void = std::ptr::null_mut();
100            let dev_ptr: *mut *mut c_void = &mut ptr;
101
102            let result = cubecl_hip_sys::hipHostMalloc(
103                dev_ptr,
104                size as usize,
105                cubecl_hip_sys::hipHostMallocMapped,
106            );
107
108            if result != HIP_SUCCESS {
109                return Err(IoError::Unknown {
110                    description: format!("cuMemAllocHost_v2 failed with error code: {result:?}"),
111                    backtrace: BackTrace::capture(),
112                });
113            }
114
115            // For safety, reducing the odds of missing mapped memory page.
116            cubecl_hip_sys::hipStreamSynchronize(self.stream);
117
118            PinnedMemory { ptr, dev_ptr }
119        };
120
121        let id = StorageId::new();
122        self.memory.insert(id, resource);
123        Ok(StorageHandle::new(
124            id,
125            StorageUtilization { offset: 0, size },
126        ))
127    }
128
129    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
130    fn dealloc(&mut self, id: StorageId) {
131        if let Some(resource) = self.memory.remove(&id) {
132            // SAFETY: `resource.ptr` was allocated by `hipHostMalloc` and has not been freed
133            // yet. After this call, the pointer is invalid and removed from `self.memory`.
134            unsafe {
135                cubecl_hip_sys::hipFreeHost(resource.ptr);
136            }
137        }
138    }
139
140    fn flush(&mut self) {
141        // We don't wait for dealloc.
142    }
143}