1use super::{MemoryDescriptor, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor};
7use cudarc::driver::CudaContext;
8use cudarc::driver::sys;
9use std::any::Any;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex, OnceLock};
12
13fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
15 static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
16 let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
17
18 if let Some(existing) = map.get(&device_id) {
19 return Ok(existing.clone());
20 }
21
22 let ctx = CudaContext::new(device_id as usize)?;
23 map.insert(device_id, ctx.clone());
24 Ok(ctx)
25}
26
/// Page-locked (pinned) host memory allocated through the CUDA driver.
///
/// The allocation is created by [`PinnedStorage::new`] /
/// [`PinnedStorage::new_for_device`] and released when the value is dropped.
#[derive(Debug)]
pub struct PinnedStorage {
    // Host virtual address of the allocation, stored as an integer and cast
    // back to a pointer at each use site.
    ptr: usize,
    // Allocation size in bytes; `new_for_device` rejects zero-sized requests.
    len: usize,
    // Context the allocation belongs to; held so the driver context stays
    // alive at least as long as the memory and can be re-bound in `Drop`.
    ctx: Arc<CudaContext>,
}

// SAFETY: `ptr` is the sole handle to a heap allocation owned by this struct
// (no aliasing references are handed out except via the unsafe accessors, whose
// callers take on the aliasing obligations), and `Arc<CudaContext>` is shared
// by value across threads elsewhere in this module — presumably the cudarc
// context is itself thread-safe; confirm against cudarc's documentation.
unsafe impl Send for PinnedStorage {}
unsafe impl Sync for PinnedStorage {}
40
41impl PinnedStorage {
42 pub fn new(len: usize) -> Result<Self> {
49 Self::new_for_device(len, None)
50 }
51
52 pub fn new_for_device(len: usize, device_id: Option<u32>) -> Result<Self> {
71 if len == 0 {
72 return Err(StorageError::AllocationFailed(
73 "zero-sized allocations are not supported".into(),
74 ));
75 }
76
77 let gpu_id = device_id.unwrap_or(0);
78 let ctx = cuda_context(gpu_id)?;
79
80 let ptr = match device_id {
81 #[cfg(target_os = "linux")]
82 Some(gpu_id) if super::numa::is_numa_enabled() => {
83 tracing::debug!(
84 "Using NUMA-aware allocation for {} bytes on GPU {}",
85 len,
86 gpu_id
87 );
88 super::numa::worker_pool::NumaWorkerPool::global()
89 .allocate_pinned_for_gpu(len, gpu_id)
90 .map_err(StorageError::AllocationFailed)? as usize
91 }
92 _ => unsafe {
93 ctx.bind_to_thread().map_err(StorageError::Cuda)?;
94
95 let ptr = cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_DEVICEMAP)
96 .map_err(StorageError::Cuda)?;
97
98 let ptr = ptr as *mut u8;
99 assert!(!ptr.is_null(), "Failed to allocate pinned memory");
100 assert!(ptr.is_aligned(), "Pinned memory is not aligned");
101 assert!(len < isize::MAX as usize);
102
103 ptr as usize
104 },
105 };
106
107 Ok(Self { ptr, len, ctx })
108 }
109
110 pub unsafe fn as_ptr(&self) -> *const u8 {
115 self.ptr as *const u8
116 }
117
118 pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
124 self.ptr as *mut u8
125 }
126}
127
impl Drop for PinnedStorage {
    /// Releases the pinned allocation back to the CUDA driver.
    fn drop(&mut self) {
        // Freeing driver memory needs a current context on this thread. A
        // failed bind is only logged; we still attempt the free (best effort —
        // Drop cannot propagate errors).
        if let Err(e) = self.ctx.bind_to_thread() {
            tracing::debug!("failed to bind CUDA context for free: {e}");
        }
        // NOTE(review): pointers obtained via the NUMA worker-pool path in
        // `new_for_device` are also released here with `free_host` — confirm
        // that `NumaWorkerPool::allocate_pinned_for_gpu` allocates with
        // cuMemHostAlloc so `free_host` is the matching deallocator; if not,
        // the struct needs to record which path produced the pointer.
        unsafe {
            if let Err(e) = cudarc::driver::result::free_host(self.ptr as _) {
                tracing::debug!("failed to free pinned memory: {e}");
            }
        };
    }
}
140
141impl MemoryDescriptor for PinnedStorage {
142 fn addr(&self) -> usize {
143 unsafe { self.as_ptr() as usize }
144 }
145
146 fn size(&self) -> usize {
147 self.len
148 }
149
150 fn storage_kind(&self) -> StorageKind {
151 StorageKind::Pinned
152 }
153
154 fn as_any(&self) -> &dyn Any {
155 self
156 }
157
158 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
159 None
160 }
161}
162
163impl super::nixl::NixlCompatible for PinnedStorage {
165 fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
166 let ptr = unsafe { self.as_ptr() };
167 (ptr, self.len, nixl_sys::MemType::Dram, 0)
168 }
169}
170
171impl actions::Memset for PinnedStorage {
172 fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
173 let end = offset
174 .checked_add(size)
175 .ok_or_else(|| StorageError::OperationFailed("memset: offset overflow".into()))?;
176 if end > self.len {
177 return Err(StorageError::OperationFailed(
178 "memset: offset + size > storage size".into(),
179 ));
180 }
181 unsafe {
182 let ptr = (self.ptr as *mut u8).add(offset);
183 std::ptr::write_bytes(ptr, value, size);
184 }
185 Ok(())
186 }
187}