// dynamo_memory/lib.rs
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! Clean, minimal storage API for v2 block manager.
//!
//! This module provides a simplified storage abstraction with:
//! - Single trait for type erasure (`MemoryDescriptor`)
//! - Concrete storage types (no trait implementations required)
//! - Composition-based NIXL registration via `NixlRegistered<T>` wrapper
//! - RAII with proper drop ordering (registration handle drops before memory)
#![deny(missing_docs)]

/// Memory action helpers.
pub mod actions;
/// Arena allocation (see [`ArenaAllocator`]).
pub mod arena;
/// NIXL registration and descriptor types.
pub mod nixl;
/// NUMA topology helpers (Linux only).
#[cfg(target_os = "linux")]
pub mod numa;

/// Offset-based buffer views into underlying storage.
pub mod offset;

/// CUDA memory pool utilities.
pub mod pool;

/// Common imports for working with memory types.
pub mod prelude;

// Concrete storage backends — private modules; their types are
// re-exported at the crate root below.
mod device;
#[cfg(target_os = "linux")]
mod disk;
mod external;
mod pinned;
mod system;
mod tensor;

#[cfg(test)]
mod tests;
// Re-export the concrete storage types and common helpers at the crate root.
pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
pub use device::DeviceStorage;
#[cfg(target_os = "linux")]
pub use disk::DiskStorage;
pub use external::ExternalDeviceMemory;
#[cfg(target_os = "linux")]
pub use numa::{NumaNode, is_numa_enabled};
pub use offset::OffsetBuffer;
pub use pinned::PinnedStorage;
pub use pool::{CudaMemPool, CudaMemPoolBuilder};
pub use system::SystemStorage;
pub use tensor::{TensorDescriptor, TensorDescriptorExt};

use serde::{Deserialize, Serialize};
use std::any::Any;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;
58
/// Result type for storage operations.
///
/// Shorthand for `std::result::Result<T, StorageError>`, used throughout
/// this crate.
pub type Result<T> = std::result::Result<T, StorageError>;
61
/// Core trait for memory regions that can be type-erased.
///
/// This is the only trait in the storage API. Concrete storage types
/// implement this trait to enable type erasure via `Arc<dyn MemoryDescriptor>`.
pub trait MemoryDescriptor: Send + Sync + fmt::Debug {
    /// Base address of the memory region.
    fn addr(&self) -> usize;

    /// Size of the memory region in bytes.
    fn size(&self) -> usize;

    /// Type of storage backing this region (see [`StorageKind`]).
    fn storage_kind(&self) -> StorageKind;

    /// Enable downcasting to concrete type via [`Any`].
    fn as_any(&self) -> &dyn Any;

    /// Get the NIXL descriptor for this memory region.
    ///
    /// Returns `None` when no descriptor is available — presumably when the
    /// region is not NIXL-registered; confirm against implementors.
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
}
82
83/// Errors that can occur during storage operations.
84#[derive(Debug, Error)]
85#[allow(missing_docs)]
86pub enum StorageError {
87    #[error("allocation failed: {0}")]
88    AllocationFailed(String),
89
90    #[error("registration failed: {0}")]
91    RegistrationFailed(String),
92
93    #[error("operation failed: {0}")]
94    OperationFailed(String),
95
96    #[error("unsupported operation: {0}")]
97    Unsupported(String),
98
99    #[error("I/O error: {0}")]
100    Io(#[from] std::io::Error),
101
102    // #[cfg(feature = "cuda")]
103    #[error("CUDA error: {0}")]
104    Cuda(#[from] cudarc::driver::DriverError),
105
106    #[error("NIXL error: {0}")]
107    Nixl(#[from] nixl_sys::NixlError),
108}
109
110/// Storage type classification.
111#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
112pub enum StorageKind {
113    /// System memory (malloc)
114    System,
115
116    /// CUDA pinned host memory
117    // #[cfg(feature = "cuda")]
118    Pinned,
119
120    /// CUDA device memory with device ID
121    // #[cfg(feature = "cuda")]
122    Device(u32),
123
124    /// Disk-backed memory (mmap)
125    Disk(u64),
126}
127
128impl StorageKind {
129    /// Returns the CUDA device index if this is device memory.
130    pub fn cuda_device_index(&self) -> Option<u32> {
131        match self {
132            StorageKind::Device(idx) => Some(*idx),
133            _ => None,
134        }
135    }
136
137    /// Returns true if this is CUDA device memory.
138    pub fn is_cuda(&self) -> bool {
139        matches!(self, StorageKind::Device(_))
140    }
141
142    /// Returns true if this is system memory (malloc).
143    pub fn is_system(&self) -> bool {
144        matches!(self, StorageKind::System)
145    }
146
147    /// Returns true if this is CUDA pinned host memory.
148    pub fn is_pinned(&self) -> bool {
149        matches!(self, StorageKind::Pinned)
150    }
151
152    /// Returns true if this is disk-backed memory.
153    pub fn is_disk(&self) -> bool {
154        matches!(self, StorageKind::Disk(_))
155    }
156}
157
/// Type-erased memory region for use in layouts.
///
/// Wraps an `Arc<dyn MemoryDescriptor>`, so cloning a `Buffer` is a cheap
/// reference-count bump and clones share the same underlying memory.
#[derive(Clone)]
pub struct Buffer(Arc<dyn MemoryDescriptor>);
161
162impl Buffer {
163    /// Wraps a concrete storage type into a type-erased [`Buffer`].
164    ///
165    /// This is the primary way to create a `Buffer` from any type that
166    /// implements [`MemoryDescriptor`].
167    pub fn new<S: MemoryDescriptor + 'static>(memory: S) -> Self {
168        Buffer(Arc::new(memory))
169    }
170}
171
172impl MemoryDescriptor for Buffer {
173    fn addr(&self) -> usize {
174        self.0.addr()
175    }
176    fn size(&self) -> usize {
177        self.0.size()
178    }
179    fn storage_kind(&self) -> StorageKind {
180        self.0.storage_kind()
181    }
182    fn as_any(&self) -> &dyn Any {
183        self.0.as_any()
184    }
185    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor> {
186        self.0.nixl_descriptor()
187    }
188}
189
190impl std::ops::Deref for Buffer {
191    type Target = dyn MemoryDescriptor;
192
193    fn deref(&self) -> &Self::Target {
194        self.0.as_ref()
195    }
196}
197
198impl std::fmt::Debug for Buffer {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        f.debug_struct("Buffer")
201            .field("addr", &self.addr())
202            .field("size", &self.size())
203            .field("kind", &self.storage_kind())
204            .finish()
205    }
206}
207
208/// Helper function to convert concrete storage to type-erased form.
209pub fn create_buffer<S: MemoryDescriptor + 'static>(memory: S) -> Buffer {
210    Buffer(Arc::new(memory))
211}
212
213impl Buffer {
214    /// Create a Buffer from an existing Arc<dyn MemoryDescriptor>.
215    pub fn from_arc(arc: Arc<dyn MemoryDescriptor>) -> Self {
216        Buffer(arc)
217    }
218}
219
220// From implementations for ergonomic Buffer creation
221impl From<Arc<dyn MemoryDescriptor>> for Buffer {
222    fn from(arc: Arc<dyn MemoryDescriptor>) -> Self {
223        Buffer::from_arc(arc)
224    }
225}
226
227impl From<Arc<dyn nixl::NixlMemory + Send + Sync>> for Buffer {
228    fn from(arc: Arc<dyn nixl::NixlMemory + Send + Sync>) -> Self {
229        // Arc<dyn NixlMemory> implements MemoryDescriptor, so we can wrap it
230        Buffer::new(arc)
231    }
232}
233
/// An unowned contiguous chunk of memory, not storage specific.
///
/// Plain address/size pair with no ownership or lifetime attached; the
/// caller is responsible for ensuring the region stays valid while used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct MemoryRegion {
    /// Start address of the memory region.
    pub addr: usize,

    /// Size of the memory region in bytes.
    pub size: usize,
}
243
244impl MemoryRegion {
245    /// Creates a new memory region with the given base address and size.
246    pub fn new(addr: usize, size: usize) -> Self {
247        Self { addr, size }
248    }
249
250    /// Returns the base address of this memory region.
251    #[inline]
252    pub fn addr(&self) -> usize {
253        self.addr
254    }
255
256    /// Returns the size of this memory region in bytes.
257    #[inline]
258    pub fn size(&self) -> usize {
259        self.size
260    }
261
262    /// Get a slice view of this memory region.
263    ///
264    /// # Safety
265    /// This is unsafe because:
266    /// - The caller must ensure the memory region is valid and properly initialized
267    /// - The caller must ensure no mutable references exist to this memory
268    /// - The caller must ensure the memory remains valid for the lifetime of the slice
269    #[cfg(feature = "unsafe-slices")]
270    pub unsafe fn as_slice(&self) -> Result<&[u8]> {
271        if self.size == 0 {
272            return Ok(&[]);
273        }
274        // SAFETY: Caller guarantees memory is valid
275        unsafe {
276            Ok(std::slice::from_raw_parts(
277                self.addr as *const u8,
278                self.size,
279            ))
280        }
281    }
282
283    /// Get a mutable slice view of this memory region.
284    ///
285    /// # Safety
286    /// This is unsafe because:
287    /// - The caller must ensure the memory region is valid and properly initialized
288    /// - The caller must ensure no other references (mutable or immutable) exist to this memory
289    /// - The caller must ensure the memory remains valid for the lifetime of the slice
290    #[cfg(feature = "unsafe-slices")]
291    pub unsafe fn as_slice_mut(&mut self) -> Result<&mut [u8]> {
292        if self.size == 0 {
293            return Ok(&mut []);
294        }
295        // SAFETY: Caller guarantees memory is valid and exclusively accessible
296        unsafe {
297            Ok(std::slice::from_raw_parts_mut(
298                self.addr as *mut u8,
299                self.size,
300            ))
301        }
302    }
303}