1mod agent;
7mod config;
8
9use super::{MemoryDescription, StorageKind};
10use std::any::Any;
11use std::fmt;
12
13pub use agent::NixlAgent;
14pub use config::NixlBackendConfig;
15
16pub use nixl_sys::{MemType, OptArgs, RegistrationHandle};
17pub use serde::{Deserialize, Serialize};
18
/// Storage that can describe its backing memory region to NIXL.
///
/// The values returned by [`NixlCompatible::nixl_params`] are packed into a
/// [`NixlDescriptor`] both when registering (see [`register_with_nixl`]) and when a
/// registered value reports its descriptor (see [`RegisteredView::descriptor`]).
pub trait NixlCompatible {
    /// Returns `(ptr, size, mem_type, device_id)` for the backing memory region.
    ///
    /// `ptr` becomes the descriptor's `addr` (cast to `u64`); `size`, `mem_type`
    /// and `device_id` are copied into the descriptor unchanged.
    // NOTE(review): presumably `size` is a byte count and `ptr` must remain valid
    // for as long as the storage stays registered — confirm with implementors.
    fn nixl_params(&self) -> (*const u8, usize, MemType, u64);
}
26
/// Plain-data description of a memory region, in the form handed to a NIXL agent.
///
/// Built from [`NixlCompatible::nixl_params`].
// NOTE(review): the `Serialize`/`Deserialize` derives presumably exist so that
// descriptors can be sent to other processes/agents — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NixlDescriptor {
    /// Start address of the region (the storage's raw pointer cast to `u64`).
    pub addr: u64,
    /// Length of the region as reported by the storage.
    pub size: usize,
    /// NIXL memory type of the region.
    pub mem_type: MemType,
    /// Device identifier reported by the storage.
    pub device_id: u64,
}
35
36impl nixl_sys::MemoryRegion for NixlDescriptor {
37 unsafe fn as_ptr(&self) -> *const u8 {
38 self.addr as *const u8
39 }
40
41 fn size(&self) -> usize {
42 self.size
43 }
44}
45
46impl nixl_sys::NixlDescriptor for NixlDescriptor {
47 fn mem_type(&self) -> MemType {
48 self.mem_type
49 }
50
51 fn device_id(&self) -> u64 {
52 self.device_id
53 }
54}
55
/// Read-only view of storage that has been registered with a NIXL agent.
pub trait RegisteredView {
    /// Name of the agent that performed the registration.
    fn agent_name(&self) -> &str;

    /// Descriptor of the registered memory region.
    fn descriptor(&self) -> NixlDescriptor;
}
64
/// Storage `S` bundled with the NIXL registration handle that covers it.
///
/// Produced by [`register_with_nixl`]. The handle is dropped when this value is
/// dropped, or when [`NixlRegistered::into_storage`] hands the storage back.
pub struct NixlRegistered<S: NixlCompatible> {
    /// The registered storage. Struct fields drop in declaration order, so on its
    /// own this field would be dropped before `handle`; the `Drop` impl takes and
    /// drops `handle` first to avoid that.
    storage: S,
    /// Registration handle; `Some` from construction until it is released.
    handle: Option<RegistrationHandle>,
    /// Name of the agent that performed the registration.
    agent_name: String,
}
75
76impl<S: NixlCompatible> Drop for NixlRegistered<S> {
77 fn drop(&mut self) {
78 drop(self.handle.take());
80 }
82}
83
84impl<S: NixlCompatible + fmt::Debug> fmt::Debug for NixlRegistered<S> {
85 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86 f.debug_struct("NixlRegistered")
87 .field("storage", &self.storage)
88 .field("agent_name", &self.agent_name)
89 .field("handle", &self.handle.is_some())
90 .finish()
91 }
92}
93
94impl<S: MemoryDescription + NixlCompatible + 'static> MemoryDescription for NixlRegistered<S> {
95 fn addr(&self) -> usize {
96 self.storage.addr()
97 }
98
99 fn size(&self) -> usize {
100 self.storage.size()
101 }
102
103 fn storage_kind(&self) -> StorageKind {
104 self.storage.storage_kind()
105 }
106
107 fn as_any(&self) -> &dyn Any {
108 self
109 }
110
111 fn nixl_descriptor(&self) -> Option<NixlDescriptor> {
112 Some(self.descriptor())
113 }
114}
115
116impl<S: MemoryDescription + NixlCompatible> RegisteredView for NixlRegistered<S> {
117 fn agent_name(&self) -> &str {
118 &self.agent_name
119 }
120
121 fn descriptor(&self) -> NixlDescriptor {
122 let (ptr, size, mem_type, device_id) = self.storage.nixl_params();
123 NixlDescriptor {
124 addr: ptr as u64,
125 size,
126 mem_type,
127 device_id,
128 }
129 }
130}
131
132impl<S: MemoryDescription + NixlCompatible> NixlRegistered<S> {
133 pub fn storage(&self) -> &S {
135 &self.storage
136 }
137
138 pub fn storage_mut(&mut self) -> &mut S {
140 &mut self.storage
141 }
142
143 pub fn is_registered(&self) -> bool {
145 self.handle.is_some()
146 }
147
148 pub fn into_storage(mut self) -> S {
152 drop(self.handle.take());
153 let mut this = std::mem::ManuallyDrop::new(self);
154 unsafe {
155 let storage = std::ptr::read(&this.storage);
156 std::ptr::drop_in_place(&mut this.agent_name);
157 storage
158 }
159 }
160}
161
162pub fn register_with_nixl<S>(
177 storage: S,
178 agent: &NixlAgent,
179 opt: Option<&OptArgs>,
180) -> std::result::Result<NixlRegistered<S>, S>
181where
182 S: MemoryDescription + NixlCompatible,
183{
184 let (ptr, size, mem_type, device_id) = storage.nixl_params();
186
187 let descriptor = NixlDescriptor {
189 addr: ptr as u64,
190 size,
191 mem_type,
192 device_id,
193 };
194
195 match agent.register_memory(&descriptor, opt) {
196 Ok(handle) => Ok(NixlRegistered {
197 storage,
198 handle: Some(handle),
199 agent_name: agent.name().to_string(),
200 }),
201 Err(_) => Err(storage),
202 }
203}