Skip to main content

dynamo_memory/
device.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! CUDA device memory storage.

6use super::{MemoryDescription, Result, StorageError, StorageKind, nixl::NixlDescriptor};
7use cudarc::driver::CudaContext;
8use std::any::Any;
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex, OnceLock};
11
12/// Get or create a CUDA context for the given device.
13fn cuda_context(device_id: u32) -> Result<Arc<CudaContext>> {
14    static CONTEXTS: OnceLock<Mutex<HashMap<u32, Arc<CudaContext>>>> = OnceLock::new();
15    let mut map = CONTEXTS.get_or_init(Default::default).lock().unwrap();
16
17    if let Some(existing) = map.get(&device_id) {
18        return Ok(existing.clone());
19    }
20
21    let ctx = CudaContext::new(device_id as usize)?;
22    map.insert(device_id, ctx.clone());
23    Ok(ctx)
24}
25
26/// CUDA device memory allocated via cudaMalloc.
27#[derive(Debug)]
28pub struct DeviceStorage {
29    ctx: Arc<CudaContext>,
30    ptr: u64,
31    device_id: u32,
32    len: usize,
33}
34
35unsafe impl Send for DeviceStorage {}
36unsafe impl Sync for DeviceStorage {}
37
38impl DeviceStorage {
39    /// Allocate new device memory of the given size.
40    ///
41    /// # Arguments
42    /// * `len` - Size in bytes to allocate
43    /// * `device_id` - CUDA device on which to allocate
44    pub fn new(len: usize, device_id: u32) -> Result<Self> {
45        if len == 0 {
46            return Err(StorageError::AllocationFailed(
47                "zero-sized allocations are not supported".into(),
48            ));
49        }
50
51        let ctx = cuda_context(device_id)?;
52        ctx.bind_to_thread().map_err(StorageError::Cuda)?;
53        let ptr = unsafe { cudarc::driver::result::malloc_sync(len).map_err(StorageError::Cuda)? };
54
55        Ok(Self {
56            ctx,
57            ptr,
58            device_id,
59            len,
60        })
61    }
62
63    /// Get the device pointer value.
64    pub fn device_ptr(&self) -> u64 {
65        self.ptr
66    }
67
68    /// Get the CUDA device ID this memory is allocated on.
69    pub fn device_id(&self) -> u32 {
70        self.device_id
71    }
72}
73
74impl Drop for DeviceStorage {
75    fn drop(&mut self) {
76        if let Err(e) = self.ctx.bind_to_thread() {
77            tracing::debug!("failed to bind CUDA context for free: {e}");
78        }
79        unsafe {
80            if let Err(e) = cudarc::driver::result::free_sync(self.ptr) {
81                tracing::debug!("failed to free device memory: {e}");
82            }
83        };
84    }
85}
86
87impl MemoryDescription for DeviceStorage {
88    fn addr(&self) -> usize {
89        self.device_ptr() as usize
90    }
91
92    fn size(&self) -> usize {
93        self.len
94    }
95
96    fn storage_kind(&self) -> StorageKind {
97        StorageKind::Device(self.device_id)
98    }
99
100    fn as_any(&self) -> &dyn Any {
101        self
102    }
103
104    fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
105        None
106    }
107}
108
109// Support for NIXL registration
110impl super::nixl::NixlCompatible for DeviceStorage {
111    fn nixl_params(&self) -> (*const u8, usize, nixl_sys::MemType, u64) {
112        (
113            self.ptr as *const u8,
114            self.len,
115            nixl_sys::MemType::Vram,
116            self.device_id as u64,
117        )
118    }
119}