1pub mod actions;
13pub mod arena;
14pub mod nixl;
15pub mod offset;
16pub mod pool;
17pub mod prelude;
18
19mod device;
20mod disk;
21mod pinned;
22mod system;
23mod torch;
24
25#[cfg(test)]
26mod tests;
27
28pub use arena::{ArenaAllocator, ArenaBuffer, ArenaError};
29pub use device::DeviceStorage;
30pub use disk::DiskStorage;
31pub use pinned::PinnedStorage;
32pub use system::SystemStorage;
33pub use torch::{TorchDevice, TorchTensor};
34
35use serde::{Deserialize, Serialize};
36use std::any::Any;
37use std::fmt;
38use std::sync::Arc;
39use thiserror::Error;
40
41pub type Result<T> = std::result::Result<T, StorageError>;
43
/// Error type for the storage layer.
///
/// The `String`-carrying variants hold a free-form detail message that is
/// interpolated into their `Display` output. The `#[from]` variants wrap an
/// underlying library error. Because thiserror generates a `From` impl for them,
/// `?` converts those errors into `StorageError` automatically.
#[derive(Debug, Error)]
pub enum StorageError {
    /// An allocation could not be completed; the payload describes why.
    #[error("allocation failed: {0}")]
    AllocationFailed(String),

    /// A registration step failed; the payload describes why.
    #[error("registration failed: {0}")]
    RegistrationFailed(String),

    /// A storage operation failed; the payload describes why.
    #[error("operation failed: {0}")]
    OperationFailed(String),

    /// The requested operation is not supported; the payload names or explains it.
    #[error("unsupported operation: {0}")]
    Unsupported(String),

    /// Wraps a `std::io::Error` (convertible via `?`).
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Wraps a `cudarc` driver error (convertible via `?`).
    #[error("CUDA error: {0}")]
    Cuda(#[from] cudarc::driver::DriverError),

    /// Wraps a `nixl_sys` error (convertible via `?`).
    #[error("NIXL error: {0}")]
    Nixl(#[from] nixl_sys::NixlError),
}
69
70#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72pub enum StorageKind {
73 System,
75
76 Pinned,
79
80 Device(u32),
83
84 Disk(u64),
86}
87
/// Common interface for a memory region that is described by a start address
/// and a size.
///
/// Implementors must be `Send + Sync + Debug`. The thread-safety bounds are what
/// allow any implementor to be shared behind the `Arc<dyn MemoryDescription>`
/// held by [`Buffer`].
pub trait MemoryDescription: Send + Sync + fmt::Debug {
    /// Start address of the region.
    ///
    /// NOTE(review): for non-system kinds (e.g. `StorageKind::Device`) this is
    /// presumably an address in that memory's own space rather than a
    /// dereferenceable host pointer — confirm with implementors.
    fn addr(&self) -> usize;

    /// Size of the region (presumably in bytes — confirm with implementors).
    fn size(&self) -> usize;

    /// Which kind of memory backs this region.
    fn storage_kind(&self) -> StorageKind;

    /// Type-erased view used for downcasting to the concrete storage type.
    ///
    /// By convention (not enforced here) implementors return `self`.
    fn as_any(&self) -> &dyn Any;

    /// NIXL descriptor for this region, or `None` if the implementor does not
    /// provide one.
    fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor>;
}
108
109#[derive(Clone)]
111pub struct Buffer(Arc<dyn MemoryDescription>);
112
113impl MemoryDescription for Buffer {
114 fn addr(&self) -> usize {
115 self.0.addr()
116 }
117 fn size(&self) -> usize {
118 self.0.size()
119 }
120 fn storage_kind(&self) -> StorageKind {
121 self.0.storage_kind()
122 }
123 fn as_any(&self) -> &dyn Any {
124 self.0.as_any()
125 }
126 fn nixl_descriptor(&self) -> Option<nixl::NixlDescriptor> {
127 self.0.nixl_descriptor()
128 }
129}
130
131impl std::ops::Deref for Buffer {
132 type Target = dyn MemoryDescription;
133
134 fn deref(&self) -> &Self::Target {
135 self.0.as_ref()
136 }
137}
138
139impl std::fmt::Debug for Buffer {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 f.debug_struct("Buffer")
142 .field("addr", &self.addr())
143 .field("size", &self.size())
144 .field("kind", &self.storage_kind())
145 .finish()
146 }
147}
148
149pub fn create_buffer<S: MemoryDescription + 'static>(memory: S) -> Buffer {
151 Buffer(Arc::new(memory))
152}
153
154#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
156pub struct MemoryRegion {
157 pub addr: usize,
159
160 pub size: usize,
162}
163
164impl MemoryRegion {
165 pub fn new(addr: usize, size: usize) -> Self {
166 Self { addr, size }
167 }
168
169 #[inline]
170 pub fn addr(&self) -> usize {
171 self.addr
172 }
173
174 #[inline]
175 pub fn size(&self) -> usize {
176 self.size
177 }
178}