cubecl_hip/compute/storage/cpu.rs

1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::server::IoError;
3use cubecl_hip_sys::HIP_SUCCESS;
4use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
5use std::{collections::HashMap, ffi::c_void};
6
/// Alignment, in bytes, reported for pinned host allocations.
///
/// Defined as the size of a `u128`, which is 16 bytes.
pub const PINNED_MEMORY_ALIGNMENT: usize = std::mem::size_of::<u128>();
9
10/// Manages pinned host memory for HIP operations.
11///
12/// This storage handles allocation and deallocation of pinned (page-locked) host memory,
13/// which is optimized for fast data transfers between host and GPU in HIP applications.
14pub struct PinnedMemoryStorage {
15    memory: HashMap<StorageId, PinnedMemory>,
16    mem_alignment: usize,
17}
18
/// A view of a region of pinned host memory handed out by the pinned storage.
///
/// This is plain data: it holds a raw pointer and a length and does not free anything
/// when dropped.
#[derive(Debug)]
pub struct PinnedMemoryResource {
    /// First byte of the region (allocation base plus the handle's offset).
    pub ptr: *mut u8,
    /// Length of the region, in bytes.
    pub size: usize,
}
27
28/// Internal representation of pinned memory with associated pointers.
29#[derive(Debug)]
30struct PinnedMemory {
31    /// Pointer to the pinned memory buffer.
32    ptr: *mut c_void,
33    /// Pointer-to-pointer for HIP allocation, kept alive for async operations.
34    #[allow(unused)]
35    ptr2ptr: *mut *mut c_void,
36}
37
38impl Default for PinnedMemoryStorage {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
44impl PinnedMemoryStorage {
45    /// Creates a new [PinnedMemoryStorage] instance.
46    ///
47    /// Initializes the storage with the default pinned memory alignment
48    /// defined by [PINNED_MEMORY_ALIGNMENT].
49    pub fn new() -> Self {
50        Self {
51            memory: HashMap::new(),
52            mem_alignment: PINNED_MEMORY_ALIGNMENT,
53        }
54    }
55}
56
57unsafe impl Send for PinnedMemoryResource {}
58unsafe impl Send for PinnedMemoryStorage {}
59
60impl ComputeStorage for PinnedMemoryStorage {
61    type Resource = PinnedMemoryResource;
62
63    fn alignment(&self) -> usize {
64        self.mem_alignment
65    }
66
67    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
68        let memory = self
69            .memory
70            .get(&handle.id)
71            .expect("Storage handle not found");
72
73        let offset = handle.offset() as usize;
74        let size = handle.size() as usize;
75
76        unsafe {
77            PinnedMemoryResource {
78                ptr: memory.ptr.cast::<u8>().add(offset),
79                size,
80            }
81        }
82    }
83
84    #[cfg_attr(
85        feature = "tracing",
86        tracing::instrument(level = "trace", skip(self, size))
87    )]
88    fn alloc(&mut self, size: u64) -> Result<StorageHandle, IoError> {
89        let resource = unsafe {
90            let mut ptr: *mut c_void = std::ptr::null_mut();
91            let ptr2ptr: *mut *mut c_void = &mut ptr;
92
93            let result = cubecl_hip_sys::hipMallocHost(ptr2ptr, size as usize);
94
95            if result != HIP_SUCCESS {
96                return Err(IoError::Unknown {
97                    description: format!("cuMemAllocHost_v2 failed with error code: {result:?}"),
98                    backtrace: BackTrace::capture(),
99                });
100            }
101
102            PinnedMemory { ptr, ptr2ptr }
103        };
104
105        let id = StorageId::new();
106        self.memory.insert(id, resource);
107        Ok(StorageHandle::new(
108            id,
109            StorageUtilization { offset: 0, size },
110        ))
111    }
112
113    #[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
114    fn dealloc(&mut self, id: StorageId) {
115        if let Some(resource) = self.memory.remove(&id) {
116            unsafe {
117                cubecl_hip_sys::hipFreeHost(resource.ptr);
118            }
119        }
120    }
121}