nixl_sys/descriptors/
reg.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17use super::sync_manager::{BackendSyncable, SyncManager};
18use std::ops::{Index, IndexMut};
19use serde::{Serialize, Deserialize};
20
21/// Public registration descriptor used for indexing and comparisons
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub struct RegDescriptor {
24    pub addr: usize,
25    pub len: usize,
26    pub dev_id: u64,
27    pub metadata: Vec<u8>,
28}
29
/// Internal data structure for registration descriptors.
///
/// This is the Rust-side source of truth that `SyncManager` mirrors into the C
/// backend list, and it is also the exact payload written by
/// `RegDescList::serialize` / read by `RegDescList::deserialize`. Because
/// bincode encodes fields in declaration order, the field order here is part of
/// the wire format.
#[derive(Debug, Serialize, Deserialize)]
struct RegDescData {
    // Memory type the list was created with; not pushed to the backend during
    // sync, only carried for (de)serialization.
    mem_type: MemType,
    // Descriptors in insertion order; re-added to the backend in this order.
    descriptors: Vec<RegDescriptor>,
}
36
37impl BackendSyncable for RegDescData {
38    type Backend = NonNull<bindings::nixl_capi_reg_dlist_s>;
39    type Error = NixlError;
40
41    fn sync_to_backend(&self, backend: &Self::Backend) -> Result<(), Self::Error> {
42        // Clear backend
43        let status = unsafe { nixl_capi_reg_dlist_clear(backend.as_ptr()) };
44        match status {
45            NIXL_CAPI_SUCCESS => {}
46            NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
47            _ => return Err(NixlError::BackendError),
48        }
49
50        // Re-add all descriptors
51        for desc in &self.descriptors {
52            let status = unsafe {
53                nixl_capi_reg_dlist_add_desc(
54                    backend.as_ptr(),
55                    desc.addr as uintptr_t,
56                    desc.len,
57                    desc.dev_id,
58                    desc.metadata.as_ptr() as *const std::ffi::c_void,
59                    desc.metadata.len(),
60                )
61            };
62            match status {
63                NIXL_CAPI_SUCCESS => {}
64                NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
65                _ => return Err(NixlError::BackendError),
66            }
67        }
68
69        Ok(())
70    }
71}
72
73/// A safe wrapper around a NIXL registration descriptor list
74pub struct RegDescList<'a> {
75    sync_mgr: SyncManager<RegDescData>,
76    _phantom: PhantomData<&'a dyn NixlDescriptor>,
77    mem_type: MemType,
78}
79
80impl<'a> RegDescList<'a> {
81    /// Creates a new registration descriptor list for the given memory type
82    pub fn new(mem_type: MemType) -> Result<Self, NixlError> {
83        let mut dlist = ptr::null_mut();
84        let status = unsafe {
85            nixl_capi_create_reg_dlist(mem_type as nixl_capi_mem_type_t, &mut dlist)
86        };
87
88        match status {
89            NIXL_CAPI_SUCCESS => {
90                if dlist.is_null() {
91                    tracing::error!("Failed to create registration descriptor list");
92                    return Err(NixlError::RegDescListCreationFailed);
93                }
94                let backend = NonNull::new(dlist).ok_or(NixlError::RegDescListCreationFailed)?;
95
96                let data = RegDescData {
97                    mem_type,
98                    descriptors: Vec::new(),
99                };
100                let sync_mgr = SyncManager::new(data, backend);
101
102                Ok(Self {
103                    sync_mgr,
104                    _phantom: PhantomData,
105                    mem_type,
106                })
107            }
108            _ => Err(NixlError::RegDescListCreationFailed),
109        }
110    }
111
112    pub fn get_type(&self) -> Result<MemType, NixlError> { Ok(self.mem_type) }
113
114    /// Adds a descriptor to the list
115    pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) {
116        self.add_desc_with_meta(addr, len, dev_id, &[])
117    }
118
119    /// Add a descriptor with metadata
120    pub fn add_desc_with_meta(
121        &mut self,
122        addr: usize,
123        len: usize,
124        dev_id: u64,
125        metadata: &[u8],
126    ) {
127        self.sync_mgr.data_mut().descriptors.push(RegDescriptor {
128            addr,
129            len,
130            dev_id,
131            metadata: metadata.to_vec(),
132        });
133    }
134
135    /// Returns true if the list is empty
136    pub fn is_empty(&self) -> Result<bool, NixlError> {
137        Ok(self.len()? == 0)
138    }
139
140    /// Returns the number of descriptors in the list
141    pub fn desc_count(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
142
143    /// Returns the number of descriptors in the list
144    pub fn len(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
145
146    /// Trims the list to the given size
147    pub fn trim(&mut self) {
148        self.sync_mgr.data_mut().descriptors.shrink_to_fit();
149    }
150
151    /// Removes the descriptor at the given index
152    pub fn rem_desc(&mut self, index: i32) -> Result<(), NixlError> {
153        if index < 0 { return Err(NixlError::InvalidParam); }
154        let idx = index as usize;
155
156        let data = self.sync_mgr.data_mut();
157        if idx >= data.descriptors.len() {
158            return Err(NixlError::InvalidParam);
159        }
160        data.descriptors.remove(idx);
161        Ok(())
162    }
163
164    /// Prints the list contents
165    pub fn print(&self) -> Result<(), NixlError> {
166        let backend = self.sync_mgr.backend()?;
167        let status = unsafe { nixl_capi_reg_dlist_print(backend.as_ptr()) };
168        match status {
169            NIXL_CAPI_SUCCESS => Ok(()),
170            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
171            _ => Err(NixlError::BackendError),
172        }
173    }
174
175    /// Clears all descriptors from the list
176    pub fn clear(&mut self) {
177        self.sync_mgr.data_mut().descriptors.clear();
178    }
179
180    /// Resizes the list to the given size
181    pub fn resize(&mut self, new_size: usize) {
182        self.sync_mgr.data_mut().descriptors.resize(new_size, RegDescriptor {
183            addr: 0,
184            len: 0,
185            dev_id: 0,
186            metadata: Vec::new(),
187        });
188    }
189
190    /// Safe immutable access to descriptor by index
191    pub fn get(&self, index: usize) -> Result<&RegDescriptor, NixlError> {
192        self.sync_mgr.data().descriptors
193            .get(index)
194            .ok_or(NixlError::InvalidParam)
195    }
196
197    /// Safe mutable access to descriptor by index
198    pub fn get_mut(&mut self, index: usize) -> Result<&mut RegDescriptor, NixlError> {
199        self.sync_mgr.data_mut().descriptors
200            .get_mut(index)
201            .ok_or(NixlError::InvalidParam)
202    }
203
204    /// Add a descriptor from a type implementing NixlDescriptor
205    ///
206    /// # Safety
207    /// The caller must ensure that:
208    /// - The descriptor remains valid for the lifetime of the list
209    /// - The memory region pointed to by the descriptor remains valid
210    pub fn add_storage_desc(&mut self, desc: &'a dyn NixlDescriptor) -> Result<(), NixlError> {
211        // Validate memory type matches
212        let desc_mem_type = desc.mem_type();
213        let list_mem_type = if self.len()? > 0 {
214            self.get_type()?
215        } else {
216            desc_mem_type
217        };
218
219        if desc_mem_type != list_mem_type && list_mem_type != MemType::Unknown {
220            return Err(NixlError::InvalidParam);
221        }
222
223        // Get descriptor details
224        let addr = unsafe { desc.as_ptr() } as usize;
225        let len = desc.size();
226        let dev_id = desc.device_id();
227
228        // Add to list
229        self.add_desc(addr, len, dev_id);
230        Ok(())
231    }
232
233    pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_reg_dlist_s {
234        self.sync_mgr.backend().map(|b| b.as_ptr()).unwrap_or(ptr::null_mut())
235    }
236
237    /// Serializes the descriptor list to a byte vector using bincode
238    pub fn serialize(&self) -> Result<Vec<u8>, NixlError> {
239        // Serialize the RegDescData directly (contains mem_type + descriptors)
240        bincode::serialize(self.sync_mgr.data()).map_err(|_| NixlError::BackendError)
241    }
242
243    /// Deserializes a descriptor list from a byte slice using bincode
244    pub fn deserialize(bytes: &[u8]) -> Result<Self, NixlError> {
245        let data: RegDescData = bincode::deserialize(bytes)
246            .map_err(|_| NixlError::RegDescListCreationFailed)?;
247
248        let mut list = RegDescList::new(data.mem_type)?;
249        for desc in data.descriptors {
250            list.add_desc_with_meta(desc.addr, desc.len, desc.dev_id, &desc.metadata);
251        }
252
253        // Force synchronization to validate backend can handle the data
254        list.sync_mgr.backend()?;
255
256        Ok(list)
257    }
258}
259
260impl std::fmt::Debug for RegDescList<'_> {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        let mem_type = self.get_type().unwrap_or(MemType::Unknown);
263        let len = self.len().unwrap_or(0);
264        let desc_count = self.desc_count().unwrap_or(0);
265
266        f.debug_struct("RegDescList")
267            .field("mem_type", &mem_type)
268            .field("len", &len)
269            .field("desc_count", &desc_count)
270            .finish()
271    }
272}
273
274impl PartialEq for RegDescList<'_> {
275    fn eq(&self, other: &Self) -> bool {
276        // Compare memory types first
277        if self.mem_type != other.mem_type {
278            return false;
279        }
280
281        // Compare internal descriptor tracking
282        self.sync_mgr.data().descriptors == other.sync_mgr.data().descriptors
283    }
284}
285
286// Implement Index trait for immutable indexing (list[i])
287impl Index<usize> for RegDescList<'_> {
288    type Output = RegDescriptor;
289
290    fn index(&self, index: usize) -> &Self::Output {
291        &self.sync_mgr.data().descriptors[index]
292    }
293}
294
295// Implement IndexMut trait for mutable indexing (list[i] = value)
296impl IndexMut<usize> for RegDescList<'_> {
297    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
298        // data_mut() automatically marks dirty
299        &mut self.sync_mgr.data_mut().descriptors[index]
300    }
301}
302
303impl Drop for RegDescList<'_> {
304    fn drop(&mut self) {
305        tracing::trace!("Dropping registration descriptor list");
306        if let Ok(backend) = self.sync_mgr.backend() {
307            unsafe {
308                nixl_capi_destroy_reg_dlist(backend.as_ptr());
309            }
310        }
311        tracing::trace!("Registration descriptor list dropped");
312    }
313}