1use super::{MemoryDescriptor, Result, StorageError, StorageKind, nixl::NixlDescriptor};
7use cudarc::driver::CudaContext;
8use std::any::Any;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, OnceLock};
11
12fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
14 static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
15 let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
16
17 if let Some(existing) = map.get(&device_id) {
18 return Ok(existing.clone());
19 }
20
21 let ctx = CudaContext::new(device_id as usize)?;
22 map.insert(device_id, ctx.clone());
23 Ok(ctx)
24}
25
26#[derive(Debug)]
28pub struct DeviceStorage {
29 ctx: Arc<CudaContext>,
31 ptr: u64,
33 device_id: u32,
35 len: usize,
37}
38
39unsafe impl Send for DeviceStorage {}
40unsafe impl Sync for DeviceStorage {}
41
42impl DeviceStorage {
43 pub fn new(len: usize, device_id: u32) -> Result<Self> {
49 if len == 0 {
50 return Err(StorageError::AllocationFailed(
51 "zero-sized allocations are not supported".into(),
52 ));
53 }
54
55 let ctx = cuda_context(device_id)?;
56 ctx.bind_to_thread().map_err(StorageError::Cuda)?;
57 let ptr = unsafe { cudarc::driver::result::malloc_sync(len).map_err(StorageError::Cuda)? };
58
59 Ok(Self {
60 ctx,
61 ptr,
62 device_id,
63 len,
64 })
65 }
66
67 pub fn device_ptr(&self) -> u64 {
69 self.ptr
70 }
71
72 pub fn device_id(&self) -> u32 {
74 self.device_id
75 }
76}
77
78impl Drop for DeviceStorage {
79 fn drop(&mut self) {
80 if let Err(e) = self.ctx.bind_to_thread() {
81 tracing::debug!("failed to bind CUDA context for free: {e}");
82 }
83 unsafe {
84 if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
85 tracing::debug!("failed to free device memory: {e}");
86 }
87 };
88 }
89}
90
91impl MemoryDescriptor for DeviceStorage {
92 fn addr(&self) -> usize {
93 self.device_ptr() as usize
94 }
95
96 fn size(&self) -> usize {
97 self.len
98 }
99
100 fn storage_kind(&self) -> StorageKind {
101 StorageKind::Device(self.device_id)
102 }
103
104 fn as_any(&self) -> &dyn Any {
105 self
106 }
107
108 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
109 None
110 }
111}
112
113impl super::nixl::NixlCompatible for DeviceStorage {
115 fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
116 (
117 self.ptr as *const u8,
118 self.len,
119 nixl_sys::MemType::Vram,
120 self.device_id as u64,
121 )
122 }
123}