1use super::{MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor};
7use cudarc::driver::CudaContext;
8use std::any::Any;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, OnceLock};
11
12fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
14 static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
15 let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
16
17 if let Some(existing) = map.get(&device_id) {
18 return Ok(existing.clone());
19 }
20
21 let ctx = CudaContext::new(device_id as usize)?;
22 map.insert(device_id, ctx.clone());
23 Ok(ctx)
24}
25
26#[derive(Debug)]
28pub struct DeviceStorage {
29 ctx: Arc<CudaContext>,
30 ptr: u64,
31 device_id: u32,
32 len: usize,
33}
34
35unsafe impl Send for DeviceStorage {}
36unsafe impl Sync for DeviceStorage {}
37
38impl DeviceStorage {
39 pub fn new(len: usize, device_id: u32) -> Result<Self> {
45 if len == 0 {
46 return Err(StorageError::AllocationFailed(
47 "zero-sized allocations are not supported".into(),
48 ));
49 }
50
51 let ctx = cuda_context(device_id)?;
52 ctx.bind_to_thread().map_err(StorageError::Cuda)?;
53 let ptr = unsafe { cudarc::driver::result::malloc_sync(len).map_err(StorageError::Cuda)? };
54
55 Ok(Self {
56 ctx,
57 ptr,
58 device_id,
59 len,
60 })
61 }
62
63 pub fn device_ptr(&self) -> u64 {
65 self.ptr
66 }
67
68 pub fn device_id(&self) -> u32 {
70 self.device_id
71 }
72}
73
74impl Drop for DeviceStorage {
75 fn drop(&mut self) {
76 if let Err(e) = self.ctx.bind_to_thread() {
77 tracing::debug!("failed to bind CUDA context for free: {e}");
78 }
79 unsafe {
80 if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
81 tracing::debug!("failed to free device memory: {e}");
82 }
83 };
84 }
85}
86
87impl MemoryDescription for DeviceStorage {
88 fn addr(&self) -> usize {
89 self.device_ptr() as usize
90 }
91
92 fn size(&self) -> usize {
93 self.len
94 }
95
96 fn storage_kind(&self) -> StorageKind {
97 StorageKind::Device(self.device_id)
98 }
99
100 fn as_any(&self) -> &dyn Any {
101 self
102 }
103
104 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
105 None
106 }
107}
108
109impl super::nixl::NixlCompatible for DeviceStorage {
111 fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
112 (
113 self.ptr as *const u8,
114 self.len,
115 nixl_sys::MemType::Vram,
116 self.device_id as u64,
117 )
118 }
119}