Skip to main content

dynamo_memory/
device.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! CUDA device memory storage.
5
use super::{MemoryDescriptor, Result, StorageError, StorageKind, nixl::NixlDescriptor};
use cudarc::driver::CudaContext;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock, PoisonError};
11
12/// Get or create a CUDA context for the given device.
13fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
14    static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
15    let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
16
17    if let Some(existing) = map.get(&device_id) {
18        return Ok(existing.clone());
19    }
20
21    let ctx = CudaContext::new(device_id as usize)?;
22    map.insert(device_id, ctx.clone());
23    Ok(ctx)
24}
25
26/// CUDA device memory allocated via cudaMalloc.
27#[derive(Debug)]
28pub struct DeviceStorage {
29    /// CUDA context used for allocation and deallocation.
30    ctx: Arc<CudaContext>,
31    /// Device pointer to the allocated memory.
32    ptr: u64,
33    /// CUDA device ID where memory is allocated.
34    device_id: u32,
35    /// Size of the allocation in bytes.
36    len: usize,
37}
38
39unsafe impl Send for DeviceStorage {}
40unsafe impl Sync for DeviceStorage {}
41
42impl DeviceStorage {
43    /// Allocate new device memory of the given size.
44    ///
45    /// # Arguments
46    /// * `len` - Size in bytes to allocate
47    /// * `device_id` - CUDA device on which to allocate
48    pub fn new(len: usize, device_id: u32) -> Result<Self> {
49        if len == 0 {
50            return Err(StorageError::AllocationFailed(
51                "zero-sized allocations are not supported".into(),
52            ));
53        }
54
55        let ctx = cuda_context(device_id)?;
56        ctx.bind_to_thread().map_err(StorageError::Cuda)?;
57        let ptr = unsafe { cudarc::driver::result::malloc_sync(len).map_err(StorageError::Cuda)? };
58
59        Ok(Self {
60            ctx,
61            ptr,
62            device_id,
63            len,
64        })
65    }
66
67    /// Get the device pointer value.
68    pub fn device_ptr(&self) -> u64 {
69        self.ptr
70    }
71
72    /// Get the CUDA device ID this memory is allocated on.
73    pub fn device_id(&self) -> u32 {
74        self.device_id
75    }
76}
77
78impl Drop for DeviceStorage {
79    fn drop(&mut self) {
80        if let Err(e) = self.ctx.bind_to_thread() {
81            tracing::debug!("failed to bind CUDA context for free: {e}");
82        }
83        unsafe {
84            if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
85                tracing::debug!("failed to free device memory: {e}");
86            }
87        };
88    }
89}
90
91impl MemoryDescriptor for DeviceStorage {
92    fn addr(&self) -> usize {
93        self.device_ptr() as usize
94    }
95
96    fn size(&self) -> usize {
97        self.len
98    }
99
100    fn storage_kind(&self) -> StorageKind {
101        StorageKind::Device(self.device_id)
102    }
103
104    fn as_any(&self) -> &dyn Any {
105        self
106    }
107
108    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
109        None
110    }
111}
112
113// Support for NIXL registration
114impl super::nixl::NixlCompatible for DeviceStorage {
115    fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
116        (
117            self.ptr as *const u8,
118            self.len,
119            nixl_sys::MemType::Vram,
120            self.device_id as u64,
121        )
122    }
123}