1#![deny(missing_docs)]
13
14pub mod actions;
15pub mod arena;
16pub mod nixl;
17#[cfg(target_os = "linux")]
18pub mod numa;
19
20pub mod offset;
22
23pub mod pool;
25
26pub mod prelude;
28
29mod device;
30#[cfg(target_os = "linux")]
31mod disk;
32mod external;
33mod pinned;
34mod system;
35mod tensor;
36
37#[cfg(test)]
38mod tests;
39
40pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
41pub use device::DeviceStorage;
42#[cfg(target_os = "linux")]
43pub use disk::DiskStorage;
44pub use external::ExternalDeviceMemory;
45#[cfg(target_os = "linux")]
46pub use numa::{NumaNode, is_numa_enabled};
47pub use offset::OffsetBuffer;
48pub use pinned::PinnedStorage;
49pub use pool::{CudaMemPool, CudaMemPoolBuilder};
50pub use system::SystemStorage;
51pub use tensor::{TensorDescriptor, TensorDescriptorExt};
52
53use serde::{Deserialize, Serialize};
54use std::any::Any;
55use std::fmt;
56use std::sync::Arc;
57use thiserror::Error;
58
59pub type Result<T> = std::result::Result<T, StorageError>;
61
62pub trait MemoryDescriptor: Send + Sync + fmt::Debug {
67 fn addr(&self) -> usize;
69
70 fn size(&self) -> usize;
72
73 fn storage_kind(&self) -> StorageKind;
75
76 fn as_any(&self) -> &dyn Any;
78
79 fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
81}
82
83#[derive(Debug, Error)]
85#[allow(missing_docs)]
86pub enum StorageError {
87 #[error("allocation failed: {0}")]
88 AllocationFailed(String),
89
90 #[error("registration failed: {0}")]
91 RegistrationFailed(String),
92
93 #[error("operation failed: {0}")]
94 OperationFailed(String),
95
96 #[error("unsupported operation: {0}")]
97 Unsupported(String),
98
99 #[error("I/O error: {0}")]
100 Io(#[from] std::io::Error),
101
102 #[error("CUDA error: {0}")]
104 Cuda(#[from] cudarc::driver::DriverError),
105
106 #[error("NIXL error: {0}")]
107 Nixl(#[from] nixl_sys::NixlError),
108}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
112pub enum StorageKind {
113 System,
115
116 Pinned,
119
120 Device(u32),
123
124 Disk(u64),
126}
127
128impl StorageKind {
129 pub fn cuda_device_index(&self) -> Option<u32> {
131 match self {
132 StorageKind::Device(idx) => Some(*idx),
133 _ => None,
134 }
135 }
136
137 pub fn is_cuda(&self) -> bool {
139 matches!(self, StorageKind::Device(_))
140 }
141
142 pub fn is_system(&self) -> bool {
144 matches!(self, StorageKind::System)
145 }
146
147 pub fn is_pinned(&self) -> bool {
149 matches!(self, StorageKind::Pinned)
150 }
151
152 pub fn is_disk(&self) -> bool {
154 matches!(self, StorageKind::Disk(_))
155 }
156}
157
158#[derive(Clone)]
160pub struct Buffer(Arc<dyn MemoryDescriptor>);
161
162impl Buffer {
163 pub fn new<S: MemoryDescriptor + 'static>(memory: S) -> Self {
168 Buffer(Arc::new(memory))
169 }
170}
171
172impl MemoryDescriptor for Buffer {
173 fn addr(&self) -> usize {
174 self.0.addr()
175 }
176 fn size(&self) -> usize {
177 self.0.size()
178 }
179 fn storage_kind(&self) -> StorageKind {
180 self.0.storage_kind()
181 }
182 fn as_any(&self) -> &dyn Any {
183 self.0.as_any()
184 }
185 fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor> {
186 self.0.nixl_descriptor()
187 }
188}
189
190impl std::ops::Deref for Buffer {
191 type Target = dyn MemoryDescriptor;
192
193 fn deref(&self) -> &Self::Target {
194 self.0.as_ref()
195 }
196}
197
198impl std::fmt::Debug for Buffer {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 f.debug_struct("Buffer")
201 .field("addr", &self.addr())
202 .field("size", &self.size())
203 .field("kind", &self.storage_kind())
204 .finish()
205 }
206}
207
208pub fn create_buffer<S: MemoryDescriptor + 'static>(memory: S) -> Buffer {
210 Buffer(Arc::new(memory))
211}
212
213impl Buffer {
214 pub fn from_arc(arc: Arc<dyn MemoryDescriptor>) -> Self {
216 Buffer(arc)
217 }
218}
219
220impl From<Arc<dyn MemoryDescriptor>> for Buffer {
222 fn from(arc: Arc<dyn MemoryDescriptor>) -> Self {
223 Buffer::from_arc(arc)
224 }
225}
226
227impl From<Arc<dyn nixl::NixlMemory + Send + Sync>> for Buffer {
228 fn from(arc: Arc<dyn nixl::NixlMemory + Send + Sync>) -> Self {
229 Buffer::new(arc)
231 }
232}
233
234#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
236pub struct MemoryRegion {
237 pub addr: usize,
239
240 pub size: usize,
242}
243
244impl MemoryRegion {
245 pub fn new(addr: usize, size: usize) -> Self {
247 Self { addr, size }
248 }
249
250 #[inline]
252 pub fn addr(&self) -> usize {
253 self.addr
254 }
255
256 #[inline]
258 pub fn size(&self) -> usize {
259 self.size
260 }
261
262 #[cfg(feature = "unsafe-slices")]
270 pub unsafe fn as_slice(&self) -> Result<&[u8]> {
271 if self.size == 0 {
272 return Ok(&[]);
273 }
274 unsafe {
276 Ok(std::slice::from_raw_parts(
277 self.addr as *const u8,
278 self.size,
279 ))
280 }
281 }
282
283 #[cfg(feature = "unsafe-slices")]
291 pub unsafe fn as_slice_mut(&mut self) -> Result<&mut [u8]> {
292 if self.size == 0 {
293 return Ok(&mut []);
294 }
295 unsafe {
297 Ok(std::slice::from_raw_parts_mut(
298 self.addr as *mut u8,
299 self.size,
300 ))
301 }
302 }
303}