1use super::*;
17use super::sync_manager::{BackendSyncable, SyncManager};
18use std::ops::{Index, IndexMut};
19use serde::{Serialize, Deserialize};
20
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub struct RegDescriptor {
24 pub addr: usize,
25 pub len: usize,
26 pub dev_id: u64,
27 pub metadata: Vec<u8>,
28}
29
30#[derive(Debug, Serialize, Deserialize)]
32struct RegDescData {
33 mem_type: MemType,
34 descriptors: Vec<RegDescriptor>,
35}
36
37impl BackendSyncable for RegDescData {
38 type Backend = NonNull<bindings::nixl_capi_reg_dlist_s>;
39 type Error = NixlError;
40
41 fn sync_to_backend(&self, backend: &Self::Backend) -> Result<(), Self::Error> {
42 let status = unsafe { nixl_capi_reg_dlist_clear(backend.as_ptr()) };
44 match status {
45 NIXL_CAPI_SUCCESS => {}
46 NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
47 _ => return Err(NixlError::BackendError),
48 }
49
50 for desc in &self.descriptors {
52 let status = unsafe {
53 nixl_capi_reg_dlist_add_desc(
54 backend.as_ptr(),
55 desc.addr as uintptr_t,
56 desc.len,
57 desc.dev_id,
58 desc.metadata.as_ptr() as *const std::ffi::c_void,
59 desc.metadata.len(),
60 )
61 };
62 match status {
63 NIXL_CAPI_SUCCESS => {}
64 NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
65 _ => return Err(NixlError::BackendError),
66 }
67 }
68
69 Ok(())
70 }
71}
72
73pub struct RegDescList<'a> {
75 sync_mgr: SyncManager<RegDescData>,
76 _phantom: PhantomData<&'a dyn NixlDescriptor>,
77 mem_type: MemType,
78}
79
80impl<'a> RegDescList<'a> {
81 pub fn new(mem_type: MemType) -> Result<Self, NixlError> {
83 let mut dlist = ptr::null_mut();
84 let status = unsafe {
85 nixl_capi_create_reg_dlist(mem_type as nixl_capi_mem_type_t, &mut dlist)
86 };
87
88 match status {
89 NIXL_CAPI_SUCCESS => {
90 if dlist.is_null() {
91 tracing::error!("Failed to create registration descriptor list");
92 return Err(NixlError::RegDescListCreationFailed);
93 }
94 let backend = NonNull::new(dlist).ok_or(NixlError::RegDescListCreationFailed)?;
95
96 let data = RegDescData {
97 mem_type,
98 descriptors: Vec::new(),
99 };
100 let sync_mgr = SyncManager::new(data, backend);
101
102 Ok(Self {
103 sync_mgr,
104 _phantom: PhantomData,
105 mem_type,
106 })
107 }
108 _ => Err(NixlError::RegDescListCreationFailed),
109 }
110 }
111
112 pub fn get_type(&self) -> Result<MemType, NixlError> { Ok(self.mem_type) }
113
114 pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) {
116 self.add_desc_with_meta(addr, len, dev_id, &[])
117 }
118
119 pub fn add_desc_with_meta(
121 &mut self,
122 addr: usize,
123 len: usize,
124 dev_id: u64,
125 metadata: &[u8],
126 ) {
127 self.sync_mgr.data_mut().descriptors.push(RegDescriptor {
128 addr,
129 len,
130 dev_id,
131 metadata: metadata.to_vec(),
132 });
133 }
134
135 pub fn is_empty(&self) -> Result<bool, NixlError> {
137 Ok(self.len()? == 0)
138 }
139
140 pub fn desc_count(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
142
143 pub fn len(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
145
146 pub fn trim(&mut self) {
148 self.sync_mgr.data_mut().descriptors.shrink_to_fit();
149 }
150
151 pub fn rem_desc(&mut self, index: i32) -> Result<(), NixlError> {
153 if index < 0 { return Err(NixlError::InvalidParam); }
154 let idx = index as usize;
155
156 let data = self.sync_mgr.data_mut();
157 if idx >= data.descriptors.len() {
158 return Err(NixlError::InvalidParam);
159 }
160 data.descriptors.remove(idx);
161 Ok(())
162 }
163
164 pub fn print(&self) -> Result<(), NixlError> {
166 let backend = self.sync_mgr.backend()?;
167 let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) };
168 match status {
169 NIXL_CAPI_SUCCESS => Ok(()),
170 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
171 _ => Err(NixlError::BackendError),
172 }
173 }
174
175 pub fn clear(&mut self) {
177 self.sync_mgr.data_mut().descriptors.clear();
178 }
179
180 pub fn resize(&mut self, new_size: usize) {
182 self.sync_mgr.data_mut().descriptors.resize(new_size, RegDescriptor {
183 addr: 0,
184 len: 0,
185 dev_id: 0,
186 metadata: Vec::new(),
187 });
188 }
189
190 pub fn get(&self, index: usize) -> Result<&RegDescriptor, NixlError> {
192 self.sync_mgr.data().descriptors
193 .get(index)
194 .ok_or(NixlError::InvalidParam)
195 }
196
197 pub fn get_mut(&mut self, index: usize) -> Result<&mut RegDescriptor, NixlError> {
199 self.sync_mgr.data_mut().descriptors
200 .get_mut(index)
201 .ok_or(NixlError::InvalidParam)
202 }
203
204 pub fn add_storage_desc(&mut self, desc: &'a dyn NixlDescriptor) -> Result<(), NixlError> {
211 let desc_mem_type = desc.mem_type();
213 let list_mem_type = if self.len()? > 0 {
214 self.get_type()?
215 } else {
216 desc_mem_type
217 };
218
219 if desc_mem_type != list_mem_type && list_mem_type != MemType::Unknown {
220 return Err(NixlError::InvalidParam);
221 }
222
223 let addr = unsafe { desc.as_ptr() } as usize;
225 let len = desc.size();
226 let dev_id = desc.device_id();
227
228 self.add_desc(addr, len, dev_id);
230 Ok(())
231 }
232
233 pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_reg_dlist_s {
234 self.sync_mgr.backend().map(|b| b.as_ptr()).unwrap_or(ptr::null_mut())
235 }
236
237 pub fn serialize(&self) -> Result<Vec<u8>, NixlError> {
239 bincode::serialize(self.sync_mgr.data()).map_err(|_| NixlError::BackendError)
241 }
242
243 pub fn deserialize(bytes: &[u8]) -> Result<Self, NixlError> {
245 let data: RegDescData = bincode::deserialize(bytes)
246 .map_err(|_| NixlError::RegDescListCreationFailed)?;
247
248 let mut list = RegDescList::new(data.mem_type)?;
249 for desc in data.descriptors {
250 list.add_desc_with_meta(desc.addr, desc.len, desc.dev_id, &desc.metadata);
251 }
252
253 list.sync_mgr.backend()?;
255
256 Ok(list)
257 }
258}
259
260impl std::fmt::Debug for RegDescList<'_> {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 let mem_type = self.get_type().unwrap_or(MemType::Unknown);
263 let len = self.len().unwrap_or(0);
264 let desc_count = self.desc_count().unwrap_or(0);
265
266 f.debug_struct("RegDescList")
267 .field("mem_type", &mem_type)
268 .field("len", &len)
269 .field("desc_count", &desc_count)
270 .finish()
271 }
272}
273
274impl PartialEq for RegDescList<'_> {
275 fn eq(&self, other: &Self) -> bool {
276 if self.mem_type != other.mem_type {
278 return false;
279 }
280
281 self.sync_mgr.data().descriptors == other.sync_mgr.data().descriptors
283 }
284}
285
286impl Index<usize> for RegDescList<'_> {
288 type Output = RegDescriptor;
289
290 fn index(&self, index: usize) -> &Self::Output {
291 &self.sync_mgr.data().descriptors[index]
292 }
293}
294
295impl IndexMut<usize> for RegDescList<'_> {
297 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
298 &mut self.sync_mgr.data_mut().descriptors[index]
300 }
301}
302
303impl Drop for RegDescList<'_> {
304 fn drop(&mut self) {
305 tracing::trace!("Dropping registration descriptor list");
306 if let Ok(backend) = self.sync_mgr.backend() {
307 unsafe {
308 nixl_capi_destroy_reg_dlist(backend.as_ptr());
309 }
310 }
311 tracing::trace!("Registration descriptor list dropped");
312 }
313}