nixl_sys/
agent.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17use crate::descriptors::{QueryResponseList, RegDescList};
18use crate::bindings::{
19    nixl_capi_agent_config_s as nixl_capi_agent_config_t,
20    nixl_capi_thread_sync_t, nixl_capi_create_configured_agent};
21
22impl From<ThreadSync> for nixl_capi_thread_sync_t {
23    fn from(value: ThreadSync) -> Self {
24        match value {
25            ThreadSync::None => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_NONE,
26            ThreadSync::Strict => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_STRICT,
27            ThreadSync::Rw => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_RW,
28            ThreadSync::Default => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_DEFAULT,
29        }
30    }
31}
32
33/// A NIXL agent that can create backends and manage memory
34#[derive(Debug, Clone)]
35pub struct Agent {
36    inner: Arc<RwLock<AgentInner>>,
37}
38
/// Status of a transfer.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum XferStatus {
    /// The transfer succeeded.
    Success,
    /// The transfer is still in progress.
    InProgress,
}

impl XferStatus {
    /// Returns `true` when the status is [`XferStatus::Success`], `false` otherwise.
    pub fn is_success(&self) -> bool {
        matches!(self, XferStatus::Success)
    }
}
50
51impl Agent {
52    /// Creates a new agent with the given name
53    pub fn new(name: &str) -> Result<Self, NixlError> {
54        tracing::trace!(agent.name = %name, "Creating new NIXL agent");
55        let c_name = CString::new(name)?;
56        let mut agent = ptr::null_mut();
57        let status = unsafe { nixl_capi_create_agent(c_name.as_ptr(), &mut agent) };
58
59        match status {
60            NIXL_CAPI_SUCCESS => {
61                // SAFETY: If status is NIXL_CAPI_SUCCESS, agent is non-null
62                let handle = unsafe { NonNull::new_unchecked(agent) };
63                tracing::trace!(agent.name = %name, "Successfully created NIXL agent");
64                Ok(Self {
65                    inner: Arc::new(RwLock::new(AgentInner::new(handle, name.to_string()))),
66                })
67            }
68            NIXL_CAPI_ERROR_INVALID_PARAM => {
69                tracing::error!(agent.name = %name, error = "invalid_param", "Failed to create NIXL agent");
70                Err(NixlError::InvalidParam)
71            }
72            _ => {
73                tracing::error!(agent.name = %name, error = "backend_error", "Failed to create NIXL agent");
74                Err(NixlError::BackendError)
75            }
76        }
77    }
78
79    /// Creates a new agent with the given configuration
80    pub fn new_configured(name: &str, cfg: &AgentConfig) -> Result<Self, NixlError> {
81        tracing::trace!(agent.name = %name, "Creating configured NIXL agent");
82        let c_name = CString::new(name)?;
83
84        // Prepare C ABI config
85        let mut c_cfg = nixl_capi_agent_config_t {
86            enable_prog_thread: cfg.enable_prog_thread,
87            enable_listen_thread: cfg.enable_listen_thread,
88            listen_port: cfg.listen_port,
89            thread_sync: cfg.thread_sync.into(),
90            num_workers: cfg.num_workers,
91            pthr_delay_us: cfg.pthr_delay_us,
92            lthr_delay_us: cfg.lthr_delay_us,
93            capture_telemetry: cfg.capture_telemetry,
94        };
95
96        let mut agent = ptr::null_mut();
97        let status = unsafe {
98            nixl_capi_create_configured_agent(c_name.as_ptr(), &mut c_cfg, &mut agent)
99        };
100
101        match status {
102            NIXL_CAPI_SUCCESS => {
103                // SAFETY: If status is NIXL_CAPI_SUCCESS, agent is non-null
104                let handle = unsafe { NonNull::new_unchecked(agent) };
105                tracing::trace!(agent.name = %name, "Successfully created configured NIXL agent");
106                Ok(Self {
107                    inner: Arc::new(RwLock::new(AgentInner::new(handle, name.to_string()))),
108                })
109            }
110            NIXL_CAPI_ERROR_INVALID_PARAM => {
111                tracing::error!(agent.name = %name, error = "invalid_param", "Failed to create configured NIXL agent");
112                Err(NixlError::InvalidParam)
113            }
114            _ => {
115                tracing::error!(agent.name = %name, error = "backend_error", "Failed to create configured NIXL agent");
116                Err(NixlError::BackendError)
117            }
118        }
119    }
120
121    /// Gets the name of the agent
122    pub fn name(&self) -> String {
123        self.inner.read().unwrap().name.clone()
124    }
125
126    /// Gets the list of available plugins
127    pub fn get_available_plugins(&self) -> Result<utils::StringList, NixlError> {
128        tracing::trace!("Getting available NIXL plugins");
129        let mut plugins = ptr::null_mut();
130
131        // SAFETY: self.inner is guaranteed to be valid by NonNull
132        let status = unsafe {
133            nixl_capi_get_available_plugins(
134                self.inner.write().unwrap().handle.as_ptr(),
135                &mut plugins,
136            )
137        };
138
139        match status {
140            0 => {
141                // SAFETY: If status is 0, plugins was successfully created and is non-null
142                let inner = unsafe { NonNull::new_unchecked(plugins) };
143                tracing::trace!("Successfully retrieved NIXL plugins");
144                Ok(utils::StringList::new(inner))
145            }
146            -1 => {
147                tracing::error!(error = "invalid_param", "Failed to get NIXL plugins");
148                Err(NixlError::InvalidParam)
149            }
150            _ => {
151                tracing::error!(error = "backend_error", "Failed to get NIXL plugins");
152                Err(NixlError::BackendError)
153            }
154        }
155    }
156
157    /// Gets the parameters for a plugin
158    ///
159    /// # Arguments
160    /// * `plugin_name` - The name of the plugin
161    ///
162    /// # Returns
163    /// The plugin's memory list and parameters
164    ///
165    /// # Errors
166    /// Returns a NixlError if:
167    /// * The plugin name contains interior nul bytes
168    /// * The operation fails
169    pub fn get_plugin_params(
170        &self,
171        plugin_name: &str,
172    ) -> Result<(MemList, utils::Params), NixlError> {
173        let plugin_name = CString::new(plugin_name)?;
174        let mut mems = ptr::null_mut();
175        let mut params = ptr::null_mut();
176
177        // SAFETY: self.inner is guaranteed to be valid by NonNull
178        let status = unsafe {
179            nixl_capi_get_plugin_params(
180                self.inner.read().unwrap().handle.as_ptr(),
181                plugin_name.as_ptr(),
182                &mut mems,
183                &mut params,
184            )
185        };
186
187        match status {
188            0 => {
189                // SAFETY: If status is 0, both pointers were successfully created and are non-null
190                let mems_inner = unsafe { NonNull::new_unchecked(mems) };
191                let params_inner = unsafe { NonNull::new_unchecked(params) };
192                Ok((
193                    MemList { inner: mems_inner },
194                    utils::Params::new(params_inner),
195                ))
196            }
197            -1 => Err(NixlError::InvalidParam),
198            _ => Err(NixlError::BackendError),
199        }
200    }
201
202    /// Creates a new backend for the given plugin using the provided parameters
203    pub fn create_backend(
204        &self,
205        plugin: &str,
206        params: &utils::Params,
207    ) -> Result<Backend, NixlError> {
208        tracing::trace!(plugin.name = %plugin, "Creating new NIXL backend");
209        let c_plugin = CString::new(plugin).map_err(|_| NixlError::InvalidParam)?;
210        let name = c_plugin.to_string_lossy().to_string();
211        let mut backend = ptr::null_mut();
212        let status = unsafe {
213            nixl_capi_create_backend(
214                self.inner.write().unwrap().handle.as_ptr(),
215                c_plugin.as_ptr(),
216                params.handle(),
217                &mut backend,
218            )
219        };
220
221        match status {
222            NIXL_CAPI_SUCCESS => {
223                let backend_handle = NonNull::new(backend).ok_or(NixlError::BackendError)?;
224                self.inner
225                    .write()
226                    .unwrap()
227                    .backends
228                    .insert(name.clone(), backend_handle);
229                tracing::trace!(plugin.name = %plugin, "Successfully created NIXL backend");
230                Ok(Backend {
231                    inner: backend_handle,
232                })
233            }
234            NIXL_CAPI_ERROR_INVALID_PARAM => {
235                tracing::error!(plugin.name = %plugin, error = "invalid_param", "Failed to create NIXL backend");
236                Err(NixlError::InvalidParam)
237            }
238            _ => {
239                tracing::error!(plugin.name = %plugin, error = "backend_error", "Failed to create NIXL backend");
240                Err(NixlError::BackendError)
241            }
242        }
243    }
244
245    /// Gets a backend by name
246    pub fn get_backend(&self, name: &str) -> Option<Backend> {
247        self.inner
248            .read()
249            .unwrap()
250            .get_backend(name)
251            .map(|backend| Backend { inner: backend })
252    }
253
254    /// Gets the parameters and memory types for a backend after initialization
255    pub fn get_backend_params(
256        &self,
257        backend: &Backend,
258    ) -> Result<(MemList, utils::Params), NixlError> {
259        let mut mem_list = ptr::null_mut();
260        let mut params = ptr::null_mut();
261
262        let status = unsafe {
263            nixl_capi_get_backend_params(
264                self.inner.read().unwrap().handle.as_ptr(),
265                backend.inner.as_ptr(),
266                &mut mem_list,
267                &mut params,
268            )
269        };
270
271        if status != NIXL_CAPI_SUCCESS {
272            return Err(NixlError::BackendError);
273        }
274
275        // SAFETY: If status is NIXL_CAPI_SUCCESS, both pointers are non-null
276        unsafe {
277            Ok((
278                MemList {
279                    inner: NonNull::new_unchecked(mem_list),
280                },
281                utils::Params::new(NonNull::new_unchecked(params)),
282            ))
283        }
284    }
285
286    /// Registers a memory descriptor with the agent
287    ///
288    /// # Arguments
289    /// * `descriptor` - The memory descriptor to register
290    /// * `opt_args` - Optional arguments for the registration
291    pub fn register_memory(
292        &self,
293        descriptor: &impl NixlDescriptor,
294        opt_args: Option<&OptArgs>,
295    ) -> Result<RegistrationHandle, NixlError> {
296        let mut reg_dlist = RegDescList::new(descriptor.mem_type())?;
297        unsafe {
298            reg_dlist.add_storage_desc(descriptor)?;
299
300            nixl_capi_register_mem(
301                self.inner.write().unwrap().handle.as_ptr(),
302                reg_dlist.handle(),
303                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
304            );
305        }
306        Ok(RegistrationHandle {
307            agent: Some(self.inner.clone()),
308            ptr: unsafe { descriptor.as_ptr() } as usize,
309            size: descriptor.size(),
310            dev_id: descriptor.device_id(),
311            mem_type: descriptor.mem_type(),
312        })
313    }
314
315    /// Query information about memory/storage
316    ///
317    /// # Arguments
318    /// * `descs` - Registration descriptor list to query
319    /// * `opt_args` - Optional arguments specifying backends
320    ///
321    /// # Returns
322    /// A list of query responses, where each response may contain parameters
323    /// describing the memory/storage characteristics.
324    pub fn query_mem(
325        &self,
326        descs: &RegDescList,
327        opt_args: Option<&OptArgs>,
328    ) -> Result<QueryResponseList, NixlError> {
329        let resp = QueryResponseList::new()?;
330
331        let status = {
332            let inner_guard = self.inner.write().unwrap();
333            unsafe {
334                nixl_capi_query_mem(
335                    inner_guard.handle.as_ptr(),
336                    descs.handle(),
337                    resp.handle(),
338                    opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
339                )
340            }
341        };
342
343        match status {
344            NIXL_CAPI_SUCCESS => Ok(resp),
345            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
346            _ => Err(NixlError::BackendError),
347        }
348    }
349
350    /// Gets the local metadata for this agent as a byte array
351    pub fn get_local_md(&self) -> Result<Vec<u8>, NixlError> {
352        tracing::trace!("Getting local metadata");
353        let mut data = std::ptr::null_mut();
354        let mut len = 0;
355
356        let status = unsafe {
357            nixl_capi_get_local_md(
358                self.inner.write().unwrap().handle.as_ptr(),
359                &mut data as *mut *mut _,
360                &mut len,
361            )
362        };
363
364        let data = data as *const u8;
365
366        if data.is_null() {
367            tracing::trace!(
368                error = "invalid_data_pointer",
369                "Failed to get local metadata"
370            );
371            return Err(NixlError::InvalidDataPointer);
372        }
373
374        match status {
375            NIXL_CAPI_SUCCESS => {
376                let bytes = unsafe {
377                    let slice = std::slice::from_raw_parts(data, len);
378                    let vec = slice.to_vec();
379                    libc::free(data as *mut libc::c_void);
380                    vec
381                };
382                tracing::trace!(metadata.size = len, "Successfully retrieved local metadata");
383                Ok(bytes)
384            }
385            NIXL_CAPI_ERROR_INVALID_PARAM => {
386                tracing::error!(error = "invalid_param", "Failed to get local metadata");
387                Err(NixlError::InvalidParam)
388            }
389            _ => {
390                tracing::error!(error = "backend_error", "Failed to get local metadata");
391                Err(NixlError::BackendError)
392            }
393        }
394    }
395
396    /// Gets the local partial metadata as a byte array
397    ///
398    /// # Arguments
399    /// * `descs` - Registration descriptor list to get metadata for
400    /// * `opt_args` - Optional arguments for getting metadata
401    ///
402    /// # Returns
403    /// A byte array containing the local partial metadata
404    ///
405    pub fn get_local_partial_md(&self, descs: &RegDescList, opt_args: Option<&OptArgs>) -> Result<Vec<u8>, NixlError> {
406        tracing::trace!("Getting local partial metadata");
407        let mut data = std::ptr::null_mut();
408        let mut len: usize = 0;
409        let inner_guard = self.inner.write().unwrap();
410
411        let status = unsafe {
412            nixl_capi_get_local_partial_md(
413                inner_guard.handle.as_ptr(),
414                descs.handle(),
415                &mut data as *mut *mut _,
416                &mut len,
417                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
418            )
419        };
420        match status {
421            NIXL_CAPI_SUCCESS => {
422                let bytes = unsafe {
423                    let slice = std::slice::from_raw_parts(data as *const u8, len);
424                    let vec = slice.to_vec();
425                    libc::free(data as *mut libc::c_void);
426                    vec
427                };
428                tracing::trace!(metadata.size = len, "Successfully retrieved local partial metadata");
429                Ok(bytes)
430            }
431            NIXL_CAPI_ERROR_INVALID_PARAM => {
432                tracing::error!(error = "invalid_param", "Failed to get local partial metadata");
433                Err(NixlError::InvalidParam)
434            }
435            _ => {
436                tracing::error!(error = "backend_error", "Failed to get local partial metadata");
437                Err(NixlError::BackendError)
438            }
439        }
440    }
441
442    /// Loads remote metadata from a byte slice
443    pub fn load_remote_md(&self, metadata: &[u8]) -> Result<String, NixlError> {
444        tracing::trace!(metadata.size = metadata.len(), "Loading remote metadata");
445        let mut agent_name = std::ptr::null_mut();
446
447        let status = unsafe {
448            nixl_capi_load_remote_md(
449                self.inner.write().unwrap().handle.as_ptr(),
450                metadata.as_ptr() as *const std::ffi::c_void,
451                metadata.len(),
452                &mut agent_name,
453            )
454        };
455
456        match status {
457            NIXL_CAPI_SUCCESS => {
458                let name = unsafe {
459                    let c_str = std::ffi::CStr::from_ptr(agent_name);
460                    let s = c_str.to_str().unwrap().to_string();
461                    libc::free(agent_name as *mut libc::c_void);
462                    s
463                };
464                self.inner.write().unwrap().remotes.insert(name.clone());
465                tracing::trace!(remote.agent = %name, "Successfully loaded remote metadata");
466                Ok(name)
467            }
468            NIXL_CAPI_ERROR_INVALID_PARAM => {
469                tracing::error!(error = "invalid_param", "Failed to load remote metadata");
470                Err(NixlError::InvalidParam)
471            }
472            _ => {
473                tracing::error!(error = "backend_error", "Failed to load remote metadata");
474                Err(NixlError::BackendError)
475            }
476        }
477    }
478
479    pub fn make_connection(&self, remote_agent: &str, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
480        let remote_agent = CString::new(remote_agent)?;
481        let inner_guard = self.inner.write().unwrap();
482
483        let status = unsafe {
484            nixl_capi_agent_make_connection(
485                inner_guard.handle.as_ptr(),
486                remote_agent.as_ptr(),
487                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
488            )
489        };
490
491        match status {
492            NIXL_CAPI_SUCCESS => Ok(()),
493            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
494            _ => Err(NixlError::BackendError),
495        }
496    }
497
498    pub fn prepare_xfer_dlist(
499        &self,
500        agent_name: &str,
501        descs: &XferDescList,
502        opt_args: Option<&OptArgs>,
503    ) -> Result<XferDlistHandle, NixlError> {
504        let c_agent_name = CString::new(agent_name)?;
505        let mut dlist_hndl = std::ptr::null_mut();
506        let inner_guard = self.inner.read().unwrap();
507
508        let status = unsafe {
509            nixl_capi_prep_xfer_dlist(
510                inner_guard.handle.as_ptr(),
511                c_agent_name.as_ptr(),
512                descs.handle(),
513                &mut dlist_hndl,
514                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
515            )
516        };
517
518        match status {
519            NIXL_CAPI_SUCCESS => Ok(XferDlistHandle::new(dlist_hndl, inner_guard.handle)),
520            _ => Err(NixlError::BackendError),
521        }
522    }
523
524    pub fn make_xfer_req(&self, operation: XferOp,
525                         local_descs: &XferDlistHandle, local_indices: &[i32],
526                         remote_descs: &XferDlistHandle, remote_indices: &[i32],
527                         opt_args: Option<&OptArgs>) -> Result<XferRequest, NixlError> {
528        let mut req = std::ptr::null_mut();
529        let inner_guard = self.inner.read().unwrap();
530
531        let status = unsafe {
532            nixl_capi_make_xfer_req(
533                inner_guard.handle.as_ptr(),
534                operation as bindings::nixl_capi_xfer_op_t,
535                local_descs.handle(),
536                local_indices.as_ptr(),
537                local_indices.len() as usize,
538                remote_descs.handle(),
539                remote_indices.as_ptr(),
540                remote_indices.len() as usize,
541                &mut req,
542                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr())
543            )
544        };
545
546        match status {
547            NIXL_CAPI_SUCCESS => Ok(XferRequest::new(NonNull::new(req)
548                .ok_or(NixlError::FailedToCreateXferRequest)?,
549                self.inner.clone(),
550            )),
551            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
552            _ => Err(NixlError::BackendError),
553        }
554    }
555
556    /// Check if remote metadata for a specific agent is available
557    ///
558    /// This function checks if the metadata for the specified remote agent has been
559    /// loaded and if specific descriptors can be found in the metadata.
560    ///
561    /// # Arguments
562    /// * `remote_agent` - Name of the remote agent to check
563    /// * `descs` - Optional descriptor list to check against the remote metadata.
564    ///            If None, only checks if any metadata exists for the agent.
565    ///
566    /// # Returns
567    /// `true` if the remote agent's metadata is available (and descriptors are found if provided),
568    /// `false` otherwise
569    pub fn check_remote_metadata(&self, remote_agent: &str, descs: Option<&XferDescList>) -> bool {
570        tracing::trace!(remote_agent = %remote_agent, "Checking remote metadata");
571
572        let c_remote_name = match CString::new(remote_agent) {
573            Ok(name) => name,
574            Err(_) => {
575                tracing::trace!(
576                    error = "invalid_param",
577                    remote_agent = %remote_agent,
578                    "Invalid remote agent name"
579                );
580                return false;
581            }
582        };
583
584        let status = unsafe {
585            bindings::nixl_capi_check_remote_md(
586                self.inner.read().unwrap().handle.as_ptr(),
587                c_remote_name.as_ptr(),
588                descs.map_or(std::ptr::null_mut(), |d| d.as_ptr()),
589            )
590        };
591
592        match status {
593            NIXL_CAPI_SUCCESS => {
594                tracing::trace!(remote_agent = %remote_agent, "Remote metadata is available");
595                true
596            }
597            _ => {
598                tracing::trace!(remote_agent = %remote_agent, "Remote metadata is not available");
599                false
600            }
601        }
602    }
603
604    /// Invalidates a remote metadata for this agent
605    pub fn invalidate_remote_md(&self, remote_agent: &str) -> Result<(), NixlError> {
606        self.inner
607            .write()
608            .unwrap()
609            .invalidate_remote_md(remote_agent)
610    }
611
612    /// Invalidates all remote metadata for this agent
613    pub fn invalidate_all_remotes(&self) -> Result<(), NixlError> {
614        self.inner.write().unwrap().invalidate_all_remotes()
615    }
616
617    /// Send this agent's metadata to etcdAdd commentMore actions
618    ///
619    /// This enables other agents to discover this agent's metadata via etcd.
620    ///
621    /// # Arguments
622    /// * `opt_args` - Optional arguments for sending metadata
623    pub fn send_local_md(&self, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
624        tracing::trace!("Sending local metadata to etcd");
625        let inner_guard = self.inner.write().unwrap();
626        let status = unsafe {
627            bindings::nixl_capi_send_local_md(
628                inner_guard.handle.as_ptr(),
629                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
630            )
631        };
632
633        match status {
634            NIXL_CAPI_SUCCESS => {
635                tracing::trace!("Successfully sent local metadata to etcd");
636                Ok(())
637            }
638            NIXL_CAPI_ERROR_INVALID_PARAM => {
639                tracing::error!(
640                    error = "invalid_param",
641                    "Failed to send local metadata to etcd"
642                );
643                Err(NixlError::InvalidParam)
644            }
645            _ => {
646                tracing::error!(
647                    error = "backend_error",
648                    "Failed to send local metadata to etcd"
649                );
650                Err(NixlError::BackendError)
651            }
652        }
653    }
654
655    /// Send this agent's partial metadata
656    ///
657    /// # Arguments
658    /// * `descs` - Registration descriptor list to send
659    /// * `opt_args` - Optional arguments for sending metadata
660    pub fn send_local_partial_md(&self, descs: &RegDescList, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
661        tracing::trace!("Sending local partial metadata to etcd");
662        let inner_guard = self.inner.write().unwrap();
663        let status = unsafe {
664            nixl_capi_send_local_partial_md(
665                inner_guard.handle.as_ptr(),
666                descs.handle(),
667                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
668            )
669        };
670        match status {
671            NIXL_CAPI_SUCCESS => {
672                tracing::trace!("Successfully sent local partial metadata to etcd");
673                Ok(())
674            }
675            NIXL_CAPI_ERROR_INVALID_PARAM => {
676                tracing::error!(error = "invalid_param", "Failed to send local partial metadata to etcd");
677                Err(NixlError::InvalidParam)
678            }
679            _ => Err(NixlError::BackendError)
680        }
681    }
682
683
684    /// Fetch a remote agent's metadata from etcd
685    ///
686    /// Once fetched, the metadata will be loaded and cached locally, enabling
687    /// communication with the remote agent.
688    ///
689    /// # Arguments
690    /// * `remote_name` - Name of the remote agent to fetch metadata for
691    /// * `opt_args` - Optional arguments for fetching metadata
692    pub fn fetch_remote_md(
693        &self,
694        remote_name: &str,
695        opt_args: Option<&OptArgs>,
696    ) -> Result<(), NixlError> {
697        tracing::trace!(remote_agent = %remote_name, "Fetching remote metadata from etcd");
698
699        let c_remote_name = CString::new(remote_name)?;
700        let inner_guard = self.inner.write().unwrap();
701
702        let status = unsafe {
703            bindings::nixl_capi_fetch_remote_md(
704                inner_guard.handle.as_ptr(),
705                c_remote_name.as_ptr(),
706                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
707            )
708        };
709
710        match status {
711            NIXL_CAPI_SUCCESS => {
712                self.inner
713                    .write()
714                    .unwrap()
715                    .remotes
716                    .insert(remote_name.to_string());
717                tracing::trace!(remote_agent = %remote_name, "Successfully fetched remote metadata from etcd");
718                Ok(())
719            }
720            NIXL_CAPI_ERROR_INVALID_PARAM => {
721                tracing::error!(error = "invalid_param", remote_agent = %remote_name, "Failed to fetch remote metadata from etcd");
722                Err(NixlError::InvalidParam)
723            }
724            _ => {
725                tracing::error!(error = "backend_error", remote_agent = %remote_name, "Failed to fetch remote metadata from etcd");
726                Err(NixlError::BackendError)
727            }
728        }
729    }
730
731    /// Invalidate this agent's metadata in etcd
732    ///
733    /// This signals to other agents that this agent's metadata is no longer valid.
734    ///
735    /// # Arguments
736    /// * `opt_args` - Optional arguments for invalidating metadata
737    pub fn invalidate_local_md(&self, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
738        tracing::trace!("Invalidating local metadata in etcd");
739        let inner_guard = self.inner.write().unwrap();
740        let status = unsafe {
741            bindings::nixl_capi_invalidate_local_md(
742                inner_guard.handle.as_ptr(),
743                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
744            )
745        };
746
747        match status {
748            NIXL_CAPI_SUCCESS => {
749                tracing::trace!("Successfully invalidated local metadata in etcd");
750                Ok(())
751            }
752            NIXL_CAPI_ERROR_INVALID_PARAM => {
753                tracing::error!(
754                    error = "invalid_param",
755                    "Failed to invalidate local metadata in etcd"
756                );
757                Err(NixlError::InvalidParam)
758            }
759            _ => {
760                tracing::error!(
761                    error = "backend_error",
762                    "Failed to invalidate local metadata in etcd"
763                );
764                Err(NixlError::BackendError)
765            }
766        }
767    }
768
769    /// Send a notification to a remote agent
770    ///
771    /// # Arguments
772    /// * `remote_agent` - Name of the remote agent to send notification to
773    /// * `message` - The notification message to send
774    /// * `backend` - Optional backend to use for sending the notification
775    ///
776    /// # Returns
777    /// `Ok(())` if the notification was sent successfully
778    pub fn send_notification(
779        &self,
780        remote_agent: &str,
781        message: &[u8],
782        backend: Option<&Backend>,
783    ) -> Result<(), NixlError> {
784        tracing::trace!(remote_agent = %remote_agent, "Sending notification");
785
786        let c_remote_name = CString::new(remote_agent)?;
787        let inner_guard = self.inner.write().unwrap();
788
789        let opt_args = if backend.is_some() {
790            let mut args = OptArgs::new()?;
791            if let Some(b) = backend {
792                args.add_backend(b)?;
793            }
794            Some(args)
795        } else {
796            None
797        };
798
799        let status = unsafe {
800            nixl_capi_gen_notif(
801                inner_guard.handle.as_ptr(),
802                c_remote_name.as_ptr(),
803                message.as_ptr() as *const std::ffi::c_void,
804                message.len(),
805                opt_args
806                    .as_ref()
807                    .map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
808            )
809        };
810
811        match status {
812            NIXL_CAPI_SUCCESS => {
813                tracing::trace!(remote_agent = %remote_agent, "Successfully sent notification");
814                Ok(())
815            }
816            NIXL_CAPI_ERROR_INVALID_PARAM => {
817                tracing::error!(error = "invalid_param", remote_agent = %remote_agent, "Failed to send notification");
818                Err(NixlError::InvalidParam)
819            }
820            _ => {
821                tracing::error!(error = "backend_error", remote_agent = %remote_agent, "Failed to send notification");
822                Err(NixlError::BackendError)
823            }
824        }
825    }
826
827    /// Creates a transfer request between local and remote descriptors
828    ///
829    /// # Arguments
830    /// * `operation` - The transfer operation (read or write)
831    /// * `local_descs` - The local descriptor list
832    /// * `remote_descs` - The remote descriptor list
833    /// * `remote_agent` - The name of the remote agent
834    /// * `opt_args` - Optional arguments for the transfer
835    ///
836    /// # Returns
837    /// A handle to the transfer request
838    ///
839    /// # Errors
840    /// Returns a NixlError if the operation fails
841    pub fn create_xfer_req(
842        &self,
843        operation: XferOp,
844        local_descs: &XferDescList,
845        remote_descs: &XferDescList,
846        remote_agent: &str,
847        opt_args: Option<&OptArgs>,
848    ) -> Result<XferRequest, NixlError> {
849        let remote_agent = CString::new(remote_agent)?;
850        let mut req = std::ptr::null_mut();
851
852        // SAFETY: All pointers are guaranteed to be valid
853        let status = unsafe {
854            bindings::nixl_capi_create_xfer_req(
855                self.inner.read().unwrap().handle.as_ptr(),
856                operation as bindings::nixl_capi_xfer_op_t,
857                local_descs.handle(),
858                remote_descs.handle(),
859                remote_agent.as_ptr(),
860                &mut req,
861                opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
862            )
863        };
864
865        match status {
866            NIXL_CAPI_SUCCESS => {
867                // SAFETY: If status is NIXL_CAPI_SUCCESS, req is guaranteed to be non-null
868                let inner = NonNull::new(req).ok_or(NixlError::FailedToCreateXferRequest)?;
869                Ok(XferRequest::new(inner, self.inner.clone()))
870            }
871            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
872            _ => Err(NixlError::FailedToCreateXferRequest),
873        }
874    }
875
876    /// Estimates the cost of a transfer request
877    ///
878    /// # Arguments
879    /// * `req` - Transfer request handle
880    /// * `opt_args` - Optional arguments for the estimation
881    ///
882    /// # Returns
883    /// A tuple containing (duration in microseconds, error margin in microseconds, cost method)
884    ///
885    /// # Errors
886    /// Returns a NixlError if the operation fails
887    pub fn estimate_xfer_cost(
888        &self,
889        req: &XferRequest,
890        opt_args: Option<&OptArgs>,
891    ) -> Result<(i64, i64, CostMethod), NixlError> {
892        let mut duration_us: i64 = 0;
893        let mut err_margin_us: i64 = 0;
894        let mut method: u32 = 0;
895
896        let status = unsafe {
897            nixl_capi_estimate_xfer_cost(
898                self.inner.write().unwrap().handle.as_ptr(),
899                req.handle(),
900                opt_args.map_or(ptr::null_mut(), |args| args.inner.as_ptr()),
901                &mut duration_us,
902                &mut err_margin_us,
903                &mut method as *mut u32 as *mut bindings::nixl_capi_cost_t,
904            )
905        };
906
907        match status {
908            NIXL_CAPI_SUCCESS => Ok((duration_us, err_margin_us, CostMethod::from(method))),
909            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
910            _ => Err(NixlError::BackendError),
911        }
912    }
913
914    /// Posts a transfer request to initiate a transfer
915    ///
916    /// After this, the transfer state can be checked asynchronously until completion.
917    /// For small transfers that complete within the call, the function returns `Ok(false)`.
918    /// Otherwise, it returns `Ok(true)` to indicate the transfer is in progress.
919    ///
920    /// # Arguments
921    /// * `req` - Transfer request handle obtained from `create_xfer_req`
922    /// * `opt_args` - Optional arguments for the transfer request
923    ///
924    /// # Returns
925    /// * `Ok(false)` - If the transfer completed immediately
926    /// * `Ok(true)` - If the transfer is in progress
927    /// * `Err` - If there was an error posting the transfer request
928    pub fn post_xfer_req(
929        &self,
930        req: &XferRequest,
931        opt_args: Option<&OptArgs>,
932    ) -> Result<bool, NixlError> {
933        tracing::trace!("Posting transfer request");
934        let status = unsafe {
935            nixl_capi_post_xfer_req(
936                self.inner.write().unwrap().handle.as_ptr(),
937                req.handle(),
938                opt_args.map_or(ptr::null_mut(), |args| args.inner.as_ptr()),
939            )
940        };
941
942        match status {
943            NIXL_CAPI_SUCCESS => {
944                tracing::trace!(
945                    status = "completed",
946                    "Transfer request completed immediately"
947                );
948                Ok(false)
949            }
950            NIXL_CAPI_IN_PROG => {
951                tracing::trace!(status = "in_progress", "Transfer request in progress");
952                Ok(true)
953            }
954            NIXL_CAPI_ERROR_INVALID_PARAM => {
955                tracing::error!(error = "invalid_param", "Failed to post transfer request");
956                Err(NixlError::InvalidParam)
957            }
958            _ => {
959                tracing::error!(error = "backend_error", "Failed to post transfer request");
960                Err(NixlError::BackendError)
961            }
962        }
963    }
964
965    /// Checks the status of a transfer request
966    ///
967    /// Returns `Ok(true)` if the transfer is still in progress, `Ok(false)` if it completed successfully.
968    ///
969    /// # Arguments
970    /// * `req` - Transfer request handle after `post_xfer_req`
971    pub fn get_xfer_status(&self, req: &XferRequest) -> Result<XferStatus, NixlError> {
972        let status = unsafe {
973            nixl_capi_get_xfer_status(self.inner.write().unwrap().handle.as_ptr(), req.handle())
974        };
975
976        match status {
977            NIXL_CAPI_SUCCESS => Ok(XferStatus::Success), // Transfer completed
978            NIXL_CAPI_IN_PROG => Ok(XferStatus::InProgress),  // Transfer in progress
979            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
980            _ => Err(NixlError::BackendError),
981        }
982    }
983
984    /// Queries the backend for a transfer request
985    ///
986    /// # Arguments
987    /// * `req` - Transfer request handle after `post_xfer_req`
988    ///
989    /// # Returns
990    /// A handle to the backend used for the transfer
991    ///
992    /// # Errors
993    /// Returns a NixlError if the operation fails
994    pub fn query_xfer_backend(&self, req: &XferRequest) -> Result<Backend, NixlError> {
995        let mut backend = std::ptr::null_mut();
996        let inner_guard = self.inner.write().unwrap();
997        let status = unsafe {
998            nixl_capi_query_xfer_backend(
999                inner_guard.handle.as_ptr(),
1000                req.handle(),
1001                &mut backend
1002            )
1003        };
1004        match status {
1005            NIXL_CAPI_SUCCESS => {
1006                Ok(Backend{ inner: NonNull::new(backend).ok_or(NixlError::FailedToCreateBackend)? })
1007            }
1008            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
1009            _ => Err(NixlError::BackendError),
1010        }
1011    }
1012
1013
1014    /// Gets notifications from other agents
1015    ///
1016    /// # Arguments
1017    /// * `notifs` - Notification map to populate with notifications
1018    /// * `opt_args` - Optional arguments to filter notifications by backend
1019    pub fn get_notifications(
1020        &self,
1021        notifs: &mut NotificationMap,
1022        opt_args: Option<&OptArgs>,
1023    ) -> Result<(), NixlError> {
1024        tracing::trace!("Getting notifications");
1025        let status = unsafe {
1026            nixl_capi_get_notifs(
1027                self.inner.write().unwrap().handle.as_ptr(),
1028                notifs.inner.as_ptr(),
1029                opt_args.map_or(ptr::null_mut(), |args| args.inner.as_ptr()),
1030            )
1031        };
1032
1033        match status {
1034            NIXL_CAPI_SUCCESS => {
1035                tracing::trace!("Successfully retrieved notifications");
1036                Ok(())
1037            }
1038            NIXL_CAPI_ERROR_INVALID_PARAM => {
1039                tracing::error!(error = "invalid_param", "Failed to get notifications");
1040                Err(NixlError::InvalidParam)
1041            }
1042            _ => {
1043                tracing::error!(error = "backend_error", "Failed to get notifications");
1044                Err(NixlError::BackendError)
1045            }
1046        }
1047    }
1048}
1049
/// Inner state for an agent that manages the raw pointer
///
/// `Agent` shares one instance of this behind `Arc<RwLock<..>>`. The `Drop`
/// implementation invalidates every entry in `remotes`, destroys every handle in
/// `backends`, and finally destroys `handle`.
#[derive(Debug)]
pub(crate) struct AgentInner {
    /// Name stored for this agent when the inner state is constructed.
    pub(crate) name: String,
    /// Non-null pointer to the C API agent object; destroyed when this value is dropped.
    pub(crate) handle: NonNull<bindings::nixl_capi_agent_s>,
    /// C API backend handles keyed by name; each one is destroyed when this value is dropped.
    pub(crate) backends: HashMap<String, NonNull<bindings::nixl_capi_backend_s>>,
    /// Names of remote agents whose metadata is invalidated on removal and on drop.
    pub(crate) remotes: HashSet<String>,
}
1058
/// Thread synchronization mode selectable in `AgentConfig`.
///
/// Each variant converts into the matching `nixl_capi_thread_sync_t` constant via
/// the `From<ThreadSync>` implementation in this file.
#[derive(Clone, Copy, Debug)]
pub enum ThreadSync {
    /// Converts to `NIXL_CAPI_THREAD_SYNC_NONE`.
    None,
    /// Converts to `NIXL_CAPI_THREAD_SYNC_STRICT`.
    Strict,
    /// Converts to `NIXL_CAPI_THREAD_SYNC_RW`.
    Rw,
    /// Converts to `NIXL_CAPI_THREAD_SYNC_DEFAULT`.
    Default,
}
1066
/// Configuration values for an agent.
///
/// `AgentConfig::default()` supplies the default noted on each field.
// NOTE(review): the per-field meanings below are inferred from the field names only;
// confirm them against the C API's agent configuration before relying on them.
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// Defaults to `true`. Presumably enables the progress thread (TODO confirm).
    pub enable_prog_thread: bool,
    /// Defaults to `false`. Presumably enables the listener thread (TODO confirm).
    pub enable_listen_thread: bool,
    /// Defaults to `0`. Presumably the port used by the listener thread (TODO confirm).
    pub listen_port: i32,
    /// Defaults to `ThreadSync::None`.
    pub thread_sync: ThreadSync,
    /// Defaults to `1`.
    pub num_workers: u32,
    /// Defaults to `0`. Presumably a progress-thread delay in microseconds (TODO confirm).
    pub pthr_delay_us: u64,
    /// Defaults to `100_000`. Presumably a listener-thread delay in microseconds (TODO confirm).
    pub lthr_delay_us: u64,
    /// Defaults to `false`.
    pub capture_telemetry: bool,
}
1078
1079impl Default for AgentConfig {
1080    fn default() -> Self {
1081        Self {
1082            enable_prog_thread: true,
1083            enable_listen_thread: false,
1084            listen_port: 0,
1085            thread_sync: ThreadSync::None,
1086            num_workers: 1,
1087            pthr_delay_us: 0,
1088            lthr_delay_us: 100_000,
1089            capture_telemetry: false,
1090        }
1091    }
1092}
1093
// `AgentInner` holds `NonNull` pointers, which are neither `Send` nor `Sync` on
// their own, so both traits are asserted manually here.
// NOTE(review): soundness depends on the C API objects being usable from any
// thread; within this file access goes through `Arc<RwLock<AgentInner>>` — confirm
// the C library's threading guarantees.
unsafe impl Send for AgentInner {}
unsafe impl Sync for AgentInner {}
1096
1097impl AgentInner {
1098    fn new(handle: NonNull<bindings::nixl_capi_agent_s>, name: String) -> Self {
1099        Self {
1100            name,
1101            handle,
1102            backends: HashMap::new(),
1103            remotes: HashSet::new(),
1104        }
1105    }
1106
1107    fn get_backend(&self, name: &str) -> Option<NonNull<bindings::nixl_capi_backend_s>> {
1108        self.backends.get(name).cloned()
1109    }
1110
1111    fn invalidate_remote_md(&mut self, remote_agent: &str) -> Result<(), NixlError> {
1112        unsafe {
1113            if self.remotes.remove(remote_agent) {
1114                nixl_capi_invalidate_remote_md(self.handle.as_ptr(), remote_agent.as_ptr().cast());
1115            } else {
1116                return Err(NixlError::InvalidParam);
1117            }
1118        }
1119        Ok(())
1120    }
1121
1122    fn invalidate_all_remotes(&mut self) -> Result<(), NixlError> {
1123        unsafe {
1124            for remote in self.remotes.drain() {
1125                nixl_capi_invalidate_remote_md(self.handle.as_ptr(), remote.as_ptr().cast());
1126            }
1127        }
1128        Ok(())
1129    }
1130}
1131
1132impl Drop for AgentInner {
1133    fn drop(&mut self) {
1134        tracing::trace!("Dropping NIXL agent");
1135        unsafe {
1136            // invalidate all remotes
1137            for remote in self.remotes.iter() {
1138                tracing::trace!(remote.agent = %remote, "Invalidating remote agent");
1139                nixl_capi_invalidate_remote_md(self.handle.as_ptr(), remote.as_ptr().cast());
1140            }
1141
1142            // destroy all backends
1143            for backend in self.backends.values() {
1144                tracing::trace!("Destroying backend");
1145                nixl_capi_destroy_backend(backend.as_ptr());
1146            }
1147
1148            nixl_capi_destroy_agent(self.handle.as_ptr());
1149        }
1150        tracing::trace!("NIXL agent dropped");
1151    }
1152}