1use super::{MemoryDescription, Result, StorageError, StorageKind, actions, nixl::NixlDescriptor};
7use cudarc::driver::CudaContext;
8use cudarc::driver::sys;
9use std::any::Any;
10use std::collections::HashMap;
11use std::sync::{Arc, Mutex, OnceLock};
12
13fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
15 static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
16 let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
17
18 if let Some(existing) = map.get(&device_id) {
19 return Ok(existing.clone());
20 }
21
22 let ctx = CudaContext::new(device_id as usize)?;
23 map.insert(device_id, ctx.clone());
24 Ok(ctx)
25}
26
27#[derive(Debug)]
29pub struct PinnedStorage {
30 ptr: usize,
31 len: usize,
32 ctx: Arc<CudaContext>,
33}
34
35unsafe impl Send for PinnedStorage {}
36unsafe impl Sync for PinnedStorage {}
37
38impl PinnedStorage {
39 pub fn new(len: usize) -> Result<Self> {
45 if len == 0 {
46 return Err(StorageError::AllocationFailed(
47 "zero-sized allocations are not supported".into(),
48 ));
49 }
50
51 let ctx = cuda_context(0)?;
52 let ptr = unsafe {
53 ctx.bind_to_thread().map_err(StorageError::Cuda)?;
54
55 let ptr = cudarc::driver::result::malloc_host(len, sys::CU_MEMHOSTALLOC_WRITECOMBINED)
56 .map_err(StorageError::Cuda)?;
57
58 let ptr = ptr as *mut u8;
59 assert!(!ptr.is_null(), "Failed to allocate pinned memory");
60 assert!(ptr.is_aligned(), "Pinned memory is not aligned");
61 assert!(len < isize::MAX as usize);
62
63 ptr as usize
64 };
65
66 Ok(Self { ptr, len, ctx })
67 }
68
69 pub unsafe fn as_ptr(&self) -> *const u8 {
74 self.ptr as *const u8
75 }
76
77 pub unsafe fn as_mut_ptr(&mut self) -> *mut u8 {
83 self.ptr as *mut u8
84 }
85}
86
87impl Drop for PinnedStorage {
88 fn drop(&mut self) {
89 if let Err(e) = self.ctx.bind_to_thread() {
90 tracing::debug!("failed to bind CUDA context for free: {e}");
91 }
92 unsafe {
93 if let Err(e) = cudarc::driver::result::free_host(self.ptr as _) {
94 tracing::debug!("failed to free pinned memory: {e}");
95 }
96 };
97 }
98}
99
100impl MemoryDescription for PinnedStorage {
101 fn addr(&self) -> usize {
102 unsafe { self.as_ptr() as usize }
103 }
104
105 fn size(&self) -> usize {
106 self.len
107 }
108
109 fn storage_kind(&self) -> StorageKind {
110 StorageKind::Pinned
111 }
112
113 fn as_any(&self) -> &dyn Any {
114 self
115 }
116
117 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
118 None
119 }
120}
121
122impl super::nixl::NixlCompatible for PinnedStorage {
124 fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
125 let ptr = unsafe { self.as_ptr() };
126 (ptr, self.len, nixl_sys::MemType::Dram, 0)
127 }
128}
129
130impl actions::Memset for PinnedStorage {
131 fn memset(&mut self, value: u8, offset: usize, size: usize) -> Result<()> {
132 let end = offset
133 .checked_add(size)
134 .ok_or_else(|| StorageError::OperationFailed("memset: offset overflow".into()))?;
135 if end > self.len {
136 return Err(StorageError::OperationFailed(
137 "memset: offset + size > storage size".into(),
138 ));
139 }
140 unsafe {
141 let ptr = (self.ptr as *mut u8).add(offset);
142 std::ptr::write_bytes(ptr, value, size);
143 }
144 Ok(())
145 }
146}