// nixl_sys/descriptors/xfer.rs

// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use super::*;
use super::sync_manager::{BackendSyncable, SyncManager};
use std::ops::{Index, IndexMut};
use serde::{Serialize, Deserialize};

/// Public transfer descriptor used for indexing and comparisons.
///
/// Plain-old-data view of one memory region in a transfer list:
/// a starting address, a length in bytes, and the owning device id.
/// Serializable so whole lists can round-trip through bincode.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct XferDescriptor {
    // Start address of the region (raw pointer value stored as usize).
    pub addr: usize,
    // Region length in bytes.
    pub len: usize,
    // Device identifier the region belongs to.
    pub dev_id: u64,
}

29/// Internal data structure for transfer descriptors
30#[derive(Debug, Serialize, Deserialize)]
31struct XferDescData {
32    mem_type: MemType,
33    descriptors: Vec<XferDescriptor>,
34}
35
36impl BackendSyncable for XferDescData {
37    type Backend = NonNull<bindings::nixl_capi_xfer_dlist_s>;
38    type Error = NixlError;
39
40    fn sync_to_backend(&self, backend: &Self::Backend) -> Result<(), Self::Error> {
41        // Clear backend
42        let status = unsafe { nixl_capi_xfer_dlist_clear(backend.as_ptr()) };
43        match status {
44            NIXL_CAPI_SUCCESS => {}
45            NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
46            _ => return Err(NixlError::BackendError),
47        }
48
49        // Re-add all descriptors
50        for desc in &self.descriptors {
51            let status = unsafe {
52                nixl_capi_xfer_dlist_add_desc(backend.as_ptr(), desc.addr as uintptr_t, desc.len, desc.dev_id)
53            };
54            match status {
55                NIXL_CAPI_SUCCESS => {}
56                NIXL_CAPI_ERROR_INVALID_PARAM => return Err(NixlError::InvalidParam),
57                _ => return Err(NixlError::BackendError),
58            }
59        }
60
61        Ok(())
62    }
63}
64
/// A safe wrapper around a NIXL transfer descriptor list.
///
/// The lifetime `'a` ties the list to any borrowed [`NixlDescriptor`]s
/// registered through `add_storage_desc`, keeping their memory regions
/// alive at least as long as the list.
pub struct XferDescList<'a> {
    // Owns the shadow data and the native handle; lazily re-syncs on access.
    sync_mgr: SyncManager<XferDescData>,
    // Lifetime anchor for borrowed descriptors; stores no data.
    _phantom: PhantomData<&'a dyn NixlDescriptor>,
    // Memory type chosen at construction.
    // NOTE(review): this duplicates XferDescData.mem_type — presumably kept
    // here for lock-free access; confirm the two can never diverge.
    mem_type: MemType,
}

72impl<'a> XferDescList<'a> {
73    /// Creates a new transfer descriptor list for the given memory type
74    pub fn new(mem_type: MemType) -> Result<Self, NixlError> {
75        let mut dlist = ptr::null_mut();
76        let status = unsafe {
77            nixl_capi_create_xfer_dlist(mem_type as nixl_capi_mem_type_t, &mut dlist)
78        };
79
80        match status {
81            NIXL_CAPI_SUCCESS => {
82                // SAFETY: If status is NIXL_CAPI_SUCCESS, dlist is non-null
83                let backend = unsafe { NonNull::new_unchecked(dlist) };
84                let data = XferDescData {
85                    mem_type,
86                    descriptors: Vec::new(),
87                };
88                let sync_mgr = SyncManager::new(data, backend);
89
90                Ok(Self {
91                    sync_mgr,
92                    _phantom: PhantomData,
93                    mem_type,
94                })
95            }
96            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
97            _ => Err(NixlError::FailedToCreateXferDlistHandle),
98        }
99    }
100
101    pub fn as_ptr(&self) -> *mut bindings::nixl_capi_xfer_dlist_s {
102        self.sync_mgr.backend().map(|b| b.as_ptr()).unwrap_or(ptr::null_mut())
103    }
104
105    /// Returns the memory type of the transfer descriptor list
106    pub fn get_type(&self) -> Result<MemType, NixlError> { Ok(self.mem_type) }
107
108    /// Adds a descriptor to the list
109    pub fn add_desc(&mut self, addr: usize, len: usize, dev_id: u64) {
110        self.sync_mgr.data_mut().descriptors.push(XferDescriptor { addr, len, dev_id });
111    }
112
113    /// Returns true if the list is empty
114    pub fn is_empty(&self) -> Result<bool, NixlError> {
115        Ok(self.len()? == 0)
116    }
117
118    /// Returns the number of descriptors in the list
119    pub fn desc_count(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
120
121    /// Returns the number of descriptors in the list
122    pub fn len(&self) -> Result<usize, NixlError> { Ok(self.sync_mgr.data().descriptors.len()) }
123
124    /// Trims the list to the given size
125    pub fn trim(&mut self) {
126        self.sync_mgr.data_mut().descriptors.shrink_to_fit();
127    }
128
129    /// Removes the descriptor at the given index
130    pub fn rem_desc(&mut self, index: i32) -> Result<(), NixlError> {
131        if index < 0 { return Err(NixlError::InvalidParam); }
132        let idx = index as usize;
133
134        let data = self.sync_mgr.data_mut();
135        if idx >= data.descriptors.len() {
136            return Err(NixlError::InvalidParam);
137        }
138        data.descriptors.remove(idx);
139        Ok(())
140    }
141
142    /// Clears all descriptors from the list
143    pub fn clear(&mut self) {
144        self.sync_mgr.data_mut().descriptors.clear();
145    }
146
147    /// Prints the list contents
148    pub fn print(&self) -> Result<(), NixlError> {
149        let backend = self.sync_mgr.backend()?;
150        let status = unsafe { nixl_capi_xfer_dlist_print(backend.as_ptr()) };
151        match status {
152            NIXL_CAPI_SUCCESS => Ok(()),
153            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
154            _ => Err(NixlError::BackendError),
155        }
156    }
157
158    /// Resizes the list to the given size
159    pub fn resize(&mut self, new_size: usize) {
160        self.sync_mgr.data_mut().descriptors.resize(new_size, XferDescriptor {
161            addr: 0,
162            len: 0,
163            dev_id: 0,
164        });
165    }
166
167    /// Safe immutable access to descriptor by index
168    pub fn get(&self, index: usize) -> Result<&XferDescriptor, NixlError> {
169        self.sync_mgr.data().descriptors
170            .get(index)
171            .ok_or(NixlError::InvalidParam)
172    }
173
174    /// Safe mutable access to descriptor by index
175    pub fn get_mut(&mut self, index: usize) -> Result<&mut XferDescriptor, NixlError> {
176        self.sync_mgr.data_mut().descriptors
177            .get_mut(index)
178            .ok_or(NixlError::InvalidParam)
179    }
180
181    /// Add a descriptor from a type implementing NixlDescriptor
182    ///
183    /// # Safety
184    /// The caller must ensure that:
185    /// - The descriptor remains valid for the lifetime of the list
186    /// - The memory region pointed to by the descriptor remains valid
187    pub fn add_storage_desc<D: NixlDescriptor + 'a>(
188        &mut self,
189        desc: &'a D,
190    ) -> Result<(), NixlError> {
191        // Validate memory type matches
192        let desc_mem_type = desc.mem_type();
193        let list_mem_type = if self.len().unwrap_or(0) > 0 { self.get_type().unwrap() } else { desc_mem_type };
194
195        if desc_mem_type != list_mem_type && list_mem_type != MemType::Unknown {
196            return Err(NixlError::InvalidParam);
197        }
198
199        // Get descriptor details
200        let addr = unsafe { desc.as_ptr() } as usize;
201        let len = desc.size();
202        let dev_id = desc.device_id();
203
204        // Add to list
205        self.add_desc(addr, len, dev_id);
206        Ok(())
207    }
208
209    pub(crate) fn handle(&self) -> *mut bindings::nixl_capi_xfer_dlist_s {
210        self.sync_mgr.backend().map(|b| b.as_ptr()).unwrap_or(ptr::null_mut())
211    }
212
213    /// Serializes the descriptor list to a byte vector using bincode
214    pub fn serialize(&self) -> Result<Vec<u8>, NixlError> {
215        // Serialize the XferDescData directly (contains mem_type + descriptors)
216        bincode::serialize(self.sync_mgr.data()).map_err(|_| NixlError::BackendError)
217    }
218
219    /// Deserializes a descriptor list from a byte slice using bincode
220    pub fn deserialize(bytes: &[u8]) -> Result<Self, NixlError> {
221        let data: XferDescData = bincode::deserialize(bytes)
222            .map_err(|_| NixlError::FailedToCreateXferDlistHandle)?;
223
224        let mut list = XferDescList::new(data.mem_type)?;
225        for desc in data.descriptors {
226            list.add_desc(desc.addr, desc.len, desc.dev_id);
227        }
228
229        // Force synchronization to validate backend can handle the data
230        list.sync_mgr.backend()?;
231
232        Ok(list)
233    }
234}
235
236impl std::fmt::Debug for XferDescList<'_> {
237    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238        let mem_type = self.get_type().unwrap_or(MemType::Unknown);
239        let len = self.len().unwrap_or(0);
240        let desc_count = self.desc_count().unwrap_or(0);
241
242        f.debug_struct("XferDescList")
243            .field("mem_type", &mem_type)
244            .field("len", &len)
245            .field("desc_count", &desc_count)
246            .finish()
247    }
248}
249
250impl PartialEq for XferDescList<'_> {
251    fn eq(&self, other: &Self) -> bool {
252        // Compare memory types first
253        if self.mem_type != other.mem_type {
254            return false;
255        }
256
257        // Compare internal descriptor tracking
258        self.sync_mgr.data().descriptors == other.sync_mgr.data().descriptors
259    }
260}
261
262// Implement Index trait for immutable indexing (list[i])
263impl Index<usize> for XferDescList<'_> {
264    type Output = XferDescriptor;
265
266    fn index(&self, index: usize) -> &Self::Output {
267        &self.sync_mgr.data().descriptors[index]
268    }
269}
270
271// Implement IndexMut trait for mutable indexing (list[i] = value)
272impl IndexMut<usize> for XferDescList<'_> {
273    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
274        // data_mut() automatically marks dirty
275        &mut self.sync_mgr.data_mut().descriptors[index]
276    }
277}
278
279impl Drop for XferDescList<'_> {
280    fn drop(&mut self) {
281        if let Ok(backend) = self.sync_mgr.backend() {
282            unsafe {
283                nixl_capi_destroy_xfer_dlist(backend.as_ptr());
284            }
285        }
286    }
287}