1use super::*;
17use crate::descriptors::{QueryResponseList, RegDescList};
18use crate::bindings::{
19 nixl_capi_agent_config_s as nixl_capi_agent_config_t,
20 nixl_capi_thread_sync_t, nixl_capi_create_configured_agent};
21
22impl From<ThreadSync> for nixl_capi_thread_sync_t {
23 fn from(value: ThreadSync) -> Self {
24 match value {
25 ThreadSync::None => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_NONE,
26 ThreadSync::Strict => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_STRICT,
27 ThreadSync::Rw => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_RW,
28 ThreadSync::Default => crate::bindings::nixl_capi_thread_sync_t_NIXL_CAPI_THREAD_SYNC_DEFAULT,
29 }
30 }
31}
32
/// Handle to a NIXL agent.
///
/// The agent state lives in an [`AgentInner`] behind `Arc<RwLock<..>>`, so cloning
/// an `Agent` produces another handle to the same underlying state rather than a
/// new agent.
#[derive(Debug, Clone)]
pub struct Agent {
    /// Shared, lock-protected agent state (native handle, backends, known remotes).
    inner: Arc<RwLock<AgentInner>>,
}
38
/// Outcome reported for a transfer request when its status is queried.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum XferStatus {
    /// The transfer has finished successfully.
    Success,
    /// The transfer has been posted but has not finished yet.
    InProgress,
}

impl XferStatus {
    /// Returns `true` only for [`XferStatus::Success`].
    pub fn is_success(&self) -> bool {
        matches!(self, XferStatus::Success)
    }
}
50
51impl Agent {
52 pub fn new(name: &str) -> Result<Self, NixlError> {
54 tracing::trace!(agent.name = %name, "Creating new NIXL agent");
55 let c_name = CString::new(name)?;
56 let mut agent = ptr::null_mut();
57 let status = unsafe { nixl_capi_create_agent(c_name.as_ptr(), &mut agent) };
58
59 match status {
60 NIXL_CAPI_SUCCESS => {
61 let handle = unsafe { NonNull::new_unchecked(agent) };
63 tracing::trace!(agent.name = %name, "Successfully created NIXL agent");
64 Ok(Self {
65 inner: Arc::new(RwLock::new(AgentInner::new(handle, name.to_string()))),
66 })
67 }
68 NIXL_CAPI_ERROR_INVALID_PARAM => {
69 tracing::error!(agent.name = %name, error = "invalid_param", "Failed to create NIXL agent");
70 Err(NixlError::InvalidParam)
71 }
72 _ => {
73 tracing::error!(agent.name = %name, error = "backend_error", "Failed to create NIXL agent");
74 Err(NixlError::BackendError)
75 }
76 }
77 }
78
79 pub fn new_configured(name: &str, cfg: &AgentConfig) -> Result<Self, NixlError> {
81 tracing::trace!(agent.name = %name, "Creating configured NIXL agent");
82 let c_name = CString::new(name)?;
83
84 let mut c_cfg = nixl_capi_agent_config_t {
86 enable_prog_thread: cfg.enable_prog_thread,
87 enable_listen_thread: cfg.enable_listen_thread,
88 listen_port: cfg.listen_port,
89 thread_sync: cfg.thread_sync.into(),
90 num_workers: cfg.num_workers,
91 pthr_delay_us: cfg.pthr_delay_us,
92 lthr_delay_us: cfg.lthr_delay_us,
93 capture_telemetry: cfg.capture_telemetry,
94 };
95
96 let mut agent = ptr::null_mut();
97 let status = unsafe {
98 nixl_capi_create_configured_agent(c_name.as_ptr(), &mut c_cfg, &mut agent)
99 };
100
101 match status {
102 NIXL_CAPI_SUCCESS => {
103 let handle = unsafe { NonNull::new_unchecked(agent) };
105 tracing::trace!(agent.name = %name, "Successfully created configured NIXL agent");
106 Ok(Self {
107 inner: Arc::new(RwLock::new(AgentInner::new(handle, name.to_string()))),
108 })
109 }
110 NIXL_CAPI_ERROR_INVALID_PARAM => {
111 tracing::error!(agent.name = %name, error = "invalid_param", "Failed to create configured NIXL agent");
112 Err(NixlError::InvalidParam)
113 }
114 _ => {
115 tracing::error!(agent.name = %name, error = "backend_error", "Failed to create configured NIXL agent");
116 Err(NixlError::BackendError)
117 }
118 }
119 }
120
121 pub fn name(&self) -> String {
123 self.inner.read().unwrap().name.clone()
124 }
125
126 pub fn get_available_plugins(&self) -> Result<utils::StringList, NixlError> {
128 tracing::trace!("Getting available NIXL plugins");
129 let mut plugins = ptr::null_mut();
130
131 let status = unsafe {
133 nixl_capi_get_available_plugins(
134 self.inner.write().unwrap().handle.as_ptr(),
135 &mut plugins,
136 )
137 };
138
139 match status {
140 0 => {
141 let inner = unsafe { NonNull::new_unchecked(plugins) };
143 tracing::trace!("Successfully retrieved NIXL plugins");
144 Ok(utils::StringList::new(inner))
145 }
146 -1 => {
147 tracing::error!(error = "invalid_param", "Failed to get NIXL plugins");
148 Err(NixlError::InvalidParam)
149 }
150 _ => {
151 tracing::error!(error = "backend_error", "Failed to get NIXL plugins");
152 Err(NixlError::BackendError)
153 }
154 }
155 }
156
157 pub fn get_plugin_params(
170 &self,
171 plugin_name: &str,
172 ) -> Result<(MemList, utils::Params), NixlError> {
173 let plugin_name = CString::new(plugin_name)?;
174 let mut mems = ptr::null_mut();
175 let mut params = ptr::null_mut();
176
177 let status = unsafe {
179 nixl_capi_get_plugin_params(
180 self.inner.read().unwrap().handle.as_ptr(),
181 plugin_name.as_ptr(),
182 &mut mems,
183 &mut params,
184 )
185 };
186
187 match status {
188 0 => {
189 let mems_inner = unsafe { NonNull::new_unchecked(mems) };
191 let params_inner = unsafe { NonNull::new_unchecked(params) };
192 Ok((
193 MemList { inner: mems_inner },
194 utils::Params::new(params_inner),
195 ))
196 }
197 -1 => Err(NixlError::InvalidParam),
198 _ => Err(NixlError::BackendError),
199 }
200 }
201
202 pub fn create_backend(
204 &self,
205 plugin: &str,
206 params: &utils::Params,
207 ) -> Result<Backend, NixlError> {
208 tracing::trace!(plugin.name = %plugin, "Creating new NIXL backend");
209 let c_plugin = CString::new(plugin).map_err(|_| NixlError::InvalidParam)?;
210 let name = c_plugin.to_string_lossy().to_string();
211 let mut backend = ptr::null_mut();
212 let status = unsafe {
213 nixl_capi_create_backend(
214 self.inner.write().unwrap().handle.as_ptr(),
215 c_plugin.as_ptr(),
216 params.handle(),
217 &mut backend,
218 )
219 };
220
221 match status {
222 NIXL_CAPI_SUCCESS => {
223 let backend_handle = NonNull::new(backend).ok_or(NixlError::BackendError)?;
224 self.inner
225 .write()
226 .unwrap()
227 .backends
228 .insert(name.clone(), backend_handle);
229 tracing::trace!(plugin.name = %plugin, "Successfully created NIXL backend");
230 Ok(Backend {
231 inner: backend_handle,
232 })
233 }
234 NIXL_CAPI_ERROR_INVALID_PARAM => {
235 tracing::error!(plugin.name = %plugin, error = "invalid_param", "Failed to create NIXL backend");
236 Err(NixlError::InvalidParam)
237 }
238 _ => {
239 tracing::error!(plugin.name = %plugin, error = "backend_error", "Failed to create NIXL backend");
240 Err(NixlError::BackendError)
241 }
242 }
243 }
244
245 pub fn get_backend(&self, name: &str) -> Option<Backend> {
247 self.inner
248 .read()
249 .unwrap()
250 .get_backend(name)
251 .map(|backend| Backend { inner: backend })
252 }
253
254 pub fn get_backend_params(
256 &self,
257 backend: &Backend,
258 ) -> Result<(MemList, utils::Params), NixlError> {
259 let mut mem_list = ptr::null_mut();
260 let mut params = ptr::null_mut();
261
262 let status = unsafe {
263 nixl_capi_get_backend_params(
264 self.inner.read().unwrap().handle.as_ptr(),
265 backend.inner.as_ptr(),
266 &mut mem_list,
267 &mut params,
268 )
269 };
270
271 if status != NIXL_CAPI_SUCCESS {
272 return Err(NixlError::BackendError);
273 }
274
275 unsafe {
277 Ok((
278 MemList {
279 inner: NonNull::new_unchecked(mem_list),
280 },
281 utils::Params::new(NonNull::new_unchecked(params)),
282 ))
283 }
284 }
285
286 pub fn register_memory(
292 &self,
293 descriptor: &impl NixlDescriptor,
294 opt_args: Option<&OptArgs>,
295 ) -> Result<RegistrationHandle, NixlError> {
296 let mut reg_dlist = RegDescList::new(descriptor.mem_type())?;
297 unsafe {
298 reg_dlist.add_storage_desc(descriptor)?;
299
300 nixl_capi_register_mem(
301 self.inner.write().unwrap().handle.as_ptr(),
302 reg_dlist.handle(),
303 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
304 );
305 }
306 Ok(RegistrationHandle {
307 agent: Some(self.inner.clone()),
308 ptr: unsafe { descriptor.as_ptr() } as usize,
309 size: descriptor.size(),
310 dev_id: descriptor.device_id(),
311 mem_type: descriptor.mem_type(),
312 })
313 }
314
315 pub fn query_mem(
325 &self,
326 descs: &RegDescList,
327 opt_args: Option<&OptArgs>,
328 ) -> Result<QueryResponseList, NixlError> {
329 let resp = QueryResponseList::new()?;
330
331 let status = {
332 let inner_guard = self.inner.write().unwrap();
333 unsafe {
334 nixl_capi_query_mem(
335 inner_guard.handle.as_ptr(),
336 descs.handle(),
337 resp.handle(),
338 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
339 )
340 }
341 };
342
343 match status {
344 NIXL_CAPI_SUCCESS => Ok(resp),
345 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
346 _ => Err(NixlError::BackendError),
347 }
348 }
349
350 pub fn get_local_md(&self) -> Result<Vec<u8>, NixlError> {
352 tracing::trace!("Getting local metadata");
353 let mut data = std::ptr::null_mut();
354 let mut len = 0;
355
356 let status = unsafe {
357 nixl_capi_get_local_md(
358 self.inner.write().unwrap().handle.as_ptr(),
359 &mut data as *mut *mut _,
360 &mut len,
361 )
362 };
363
364 let data = data as *const u8;
365
366 if data.is_null() {
367 tracing::trace!(
368 error = "invalid_data_pointer",
369 "Failed to get local metadata"
370 );
371 return Err(NixlError::InvalidDataPointer);
372 }
373
374 match status {
375 NIXL_CAPI_SUCCESS => {
376 let bytes = unsafe {
377 let slice = std::slice::from_raw_parts(data, len);
378 let vec = slice.to_vec();
379 libc::free(data as *mut libc::c_void);
380 vec
381 };
382 tracing::trace!(metadata.size = len, "Successfully retrieved local metadata");
383 Ok(bytes)
384 }
385 NIXL_CAPI_ERROR_INVALID_PARAM => {
386 tracing::error!(error = "invalid_param", "Failed to get local metadata");
387 Err(NixlError::InvalidParam)
388 }
389 _ => {
390 tracing::error!(error = "backend_error", "Failed to get local metadata");
391 Err(NixlError::BackendError)
392 }
393 }
394 }
395
396 pub fn get_local_partial_md(&self, descs: &RegDescList, opt_args: Option<&OptArgs>) -> Result<Vec<u8>, NixlError> {
406 tracing::trace!("Getting local partial metadata");
407 let mut data = std::ptr::null_mut();
408 let mut len: usize = 0;
409 let inner_guard = self.inner.write().unwrap();
410
411 let status = unsafe {
412 nixl_capi_get_local_partial_md(
413 inner_guard.handle.as_ptr(),
414 descs.handle(),
415 &mut data as *mut *mut _,
416 &mut len,
417 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
418 )
419 };
420 match status {
421 NIXL_CAPI_SUCCESS => {
422 let bytes = unsafe {
423 let slice = std::slice::from_raw_parts(data as *const u8, len);
424 let vec = slice.to_vec();
425 libc::free(data as *mut libc::c_void);
426 vec
427 };
428 tracing::trace!(metadata.size = len, "Successfully retrieved local partial metadata");
429 Ok(bytes)
430 }
431 NIXL_CAPI_ERROR_INVALID_PARAM => {
432 tracing::error!(error = "invalid_param", "Failed to get local partial metadata");
433 Err(NixlError::InvalidParam)
434 }
435 _ => {
436 tracing::error!(error = "backend_error", "Failed to get local partial metadata");
437 Err(NixlError::BackendError)
438 }
439 }
440 }
441
442 pub fn load_remote_md(&self, metadata: &[u8]) -> Result<String, NixlError> {
444 tracing::trace!(metadata.size = metadata.len(), "Loading remote metadata");
445 let mut agent_name = std::ptr::null_mut();
446
447 let status = unsafe {
448 nixl_capi_load_remote_md(
449 self.inner.write().unwrap().handle.as_ptr(),
450 metadata.as_ptr() as *const std::ffi::c_void,
451 metadata.len(),
452 &mut agent_name,
453 )
454 };
455
456 match status {
457 NIXL_CAPI_SUCCESS => {
458 let name = unsafe {
459 let c_str = std::ffi::CStr::from_ptr(agent_name);
460 let s = c_str.to_str().unwrap().to_string();
461 libc::free(agent_name as *mut libc::c_void);
462 s
463 };
464 self.inner.write().unwrap().remotes.insert(name.clone());
465 tracing::trace!(remote.agent = %name, "Successfully loaded remote metadata");
466 Ok(name)
467 }
468 NIXL_CAPI_ERROR_INVALID_PARAM => {
469 tracing::error!(error = "invalid_param", "Failed to load remote metadata");
470 Err(NixlError::InvalidParam)
471 }
472 _ => {
473 tracing::error!(error = "backend_error", "Failed to load remote metadata");
474 Err(NixlError::BackendError)
475 }
476 }
477 }
478
479 pub fn make_connection(&self, remote_agent: &str, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
480 let remote_agent = CString::new(remote_agent)?;
481 let inner_guard = self.inner.write().unwrap();
482
483 let status = unsafe {
484 nixl_capi_agent_make_connection(
485 inner_guard.handle.as_ptr(),
486 remote_agent.as_ptr(),
487 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
488 )
489 };
490
491 match status {
492 NIXL_CAPI_SUCCESS => Ok(()),
493 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
494 _ => Err(NixlError::BackendError),
495 }
496 }
497
498 pub fn prepare_xfer_dlist(
499 &self,
500 agent_name: &str,
501 descs: &XferDescList,
502 opt_args: Option<&OptArgs>,
503 ) -> Result<XferDlistHandle, NixlError> {
504 let c_agent_name = CString::new(agent_name)?;
505 let mut dlist_hndl = std::ptr::null_mut();
506 let inner_guard = self.inner.read().unwrap();
507
508 let status = unsafe {
509 nixl_capi_prep_xfer_dlist(
510 inner_guard.handle.as_ptr(),
511 c_agent_name.as_ptr(),
512 descs.handle(),
513 &mut dlist_hndl,
514 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
515 )
516 };
517
518 match status {
519 NIXL_CAPI_SUCCESS => Ok(XferDlistHandle::new(dlist_hndl, inner_guard.handle)),
520 _ => Err(NixlError::BackendError),
521 }
522 }
523
524 pub fn make_xfer_req(&self, operation: XferOp,
525 local_descs: &XferDlistHandle, local_indices: &[i32],
526 remote_descs: &XferDlistHandle, remote_indices: &[i32],
527 opt_args: Option<&OptArgs>) -> Result<XferRequest, NixlError> {
528 let mut req = std::ptr::null_mut();
529 let inner_guard = self.inner.read().unwrap();
530
531 let status = unsafe {
532 nixl_capi_make_xfer_req(
533 inner_guard.handle.as_ptr(),
534 operation as bindings::nixl_capi_xfer_op_t,
535 local_descs.handle(),
536 local_indices.as_ptr(),
537 local_indices.len() as usize,
538 remote_descs.handle(),
539 remote_indices.as_ptr(),
540 remote_indices.len() as usize,
541 &mut req,
542 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr())
543 )
544 };
545
546 match status {
547 NIXL_CAPI_SUCCESS => Ok(XferRequest::new(NonNull::new(req)
548 .ok_or(NixlError::FailedToCreateXferRequest)?,
549 self.inner.clone(),
550 )),
551 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
552 _ => Err(NixlError::BackendError),
553 }
554 }
555
556 pub fn check_remote_metadata(&self, remote_agent: &str, descs: Option<&XferDescList>) -> bool {
570 tracing::trace!(remote_agent = %remote_agent, "Checking remote metadata");
571
572 let c_remote_name = match CString::new(remote_agent) {
573 Ok(name) => name,
574 Err(_) => {
575 tracing::trace!(
576 error = "invalid_param",
577 remote_agent = %remote_agent,
578 "Invalid remote agent name"
579 );
580 return false;
581 }
582 };
583
584 let status = unsafe {
585 bindings::nixl_capi_check_remote_md(
586 self.inner.read().unwrap().handle.as_ptr(),
587 c_remote_name.as_ptr(),
588 descs.map_or(std::ptr::null_mut(), |d| d.as_ptr()),
589 )
590 };
591
592 match status {
593 NIXL_CAPI_SUCCESS => {
594 tracing::trace!(remote_agent = %remote_agent, "Remote metadata is available");
595 true
596 }
597 _ => {
598 tracing::trace!(remote_agent = %remote_agent, "Remote metadata is not available");
599 false
600 }
601 }
602 }
603
604 pub fn invalidate_remote_md(&self, remote_agent: &str) -> Result<(), NixlError> {
606 self.inner
607 .write()
608 .unwrap()
609 .invalidate_remote_md(remote_agent)
610 }
611
612 pub fn invalidate_all_remotes(&self) -> Result<(), NixlError> {
614 self.inner.write().unwrap().invalidate_all_remotes()
615 }
616
617 pub fn send_local_md(&self, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
624 tracing::trace!("Sending local metadata to etcd");
625 let inner_guard = self.inner.write().unwrap();
626 let status = unsafe {
627 bindings::nixl_capi_send_local_md(
628 inner_guard.handle.as_ptr(),
629 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
630 )
631 };
632
633 match status {
634 NIXL_CAPI_SUCCESS => {
635 tracing::trace!("Successfully sent local metadata to etcd");
636 Ok(())
637 }
638 NIXL_CAPI_ERROR_INVALID_PARAM => {
639 tracing::error!(
640 error = "invalid_param",
641 "Failed to send local metadata to etcd"
642 );
643 Err(NixlError::InvalidParam)
644 }
645 _ => {
646 tracing::error!(
647 error = "backend_error",
648 "Failed to send local metadata to etcd"
649 );
650 Err(NixlError::BackendError)
651 }
652 }
653 }
654
655 pub fn send_local_partial_md(&self, descs: &RegDescList, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
661 tracing::trace!("Sending local partial metadata to etcd");
662 let inner_guard = self.inner.write().unwrap();
663 let status = unsafe {
664 nixl_capi_send_local_partial_md(
665 inner_guard.handle.as_ptr(),
666 descs.handle(),
667 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
668 )
669 };
670 match status {
671 NIXL_CAPI_SUCCESS => {
672 tracing::trace!("Successfully sent local partial metadata to etcd");
673 Ok(())
674 }
675 NIXL_CAPI_ERROR_INVALID_PARAM => {
676 tracing::error!(error = "invalid_param", "Failed to send local partial metadata to etcd");
677 Err(NixlError::InvalidParam)
678 }
679 _ => Err(NixlError::BackendError)
680 }
681 }
682
683
684 pub fn fetch_remote_md(
693 &self,
694 remote_name: &str,
695 opt_args: Option<&OptArgs>,
696 ) -> Result<(), NixlError> {
697 tracing::trace!(remote_agent = %remote_name, "Fetching remote metadata from etcd");
698
699 let c_remote_name = CString::new(remote_name)?;
700 let inner_guard = self.inner.write().unwrap();
701
702 let status = unsafe {
703 bindings::nixl_capi_fetch_remote_md(
704 inner_guard.handle.as_ptr(),
705 c_remote_name.as_ptr(),
706 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
707 )
708 };
709
710 match status {
711 NIXL_CAPI_SUCCESS => {
712 self.inner
713 .write()
714 .unwrap()
715 .remotes
716 .insert(remote_name.to_string());
717 tracing::trace!(remote_agent = %remote_name, "Successfully fetched remote metadata from etcd");
718 Ok(())
719 }
720 NIXL_CAPI_ERROR_INVALID_PARAM => {
721 tracing::error!(error = "invalid_param", remote_agent = %remote_name, "Failed to fetch remote metadata from etcd");
722 Err(NixlError::InvalidParam)
723 }
724 _ => {
725 tracing::error!(error = "backend_error", remote_agent = %remote_name, "Failed to fetch remote metadata from etcd");
726 Err(NixlError::BackendError)
727 }
728 }
729 }
730
731 pub fn invalidate_local_md(&self, opt_args: Option<&OptArgs>) -> Result<(), NixlError> {
738 tracing::trace!("Invalidating local metadata in etcd");
739 let inner_guard = self.inner.write().unwrap();
740 let status = unsafe {
741 bindings::nixl_capi_invalidate_local_md(
742 inner_guard.handle.as_ptr(),
743 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
744 )
745 };
746
747 match status {
748 NIXL_CAPI_SUCCESS => {
749 tracing::trace!("Successfully invalidated local metadata in etcd");
750 Ok(())
751 }
752 NIXL_CAPI_ERROR_INVALID_PARAM => {
753 tracing::error!(
754 error = "invalid_param",
755 "Failed to invalidate local metadata in etcd"
756 );
757 Err(NixlError::InvalidParam)
758 }
759 _ => {
760 tracing::error!(
761 error = "backend_error",
762 "Failed to invalidate local metadata in etcd"
763 );
764 Err(NixlError::BackendError)
765 }
766 }
767 }
768
769 pub fn send_notification(
779 &self,
780 remote_agent: &str,
781 message: &[u8],
782 backend: Option<&Backend>,
783 ) -> Result<(), NixlError> {
784 tracing::trace!(remote_agent = %remote_agent, "Sending notification");
785
786 let c_remote_name = CString::new(remote_agent)?;
787 let inner_guard = self.inner.write().unwrap();
788
789 let opt_args = if backend.is_some() {
790 let mut args = OptArgs::new()?;
791 if let Some(b) = backend {
792 args.add_backend(b)?;
793 }
794 Some(args)
795 } else {
796 None
797 };
798
799 let status = unsafe {
800 nixl_capi_gen_notif(
801 inner_guard.handle.as_ptr(),
802 c_remote_name.as_ptr(),
803 message.as_ptr() as *const std::ffi::c_void,
804 message.len(),
805 opt_args
806 .as_ref()
807 .map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
808 )
809 };
810
811 match status {
812 NIXL_CAPI_SUCCESS => {
813 tracing::trace!(remote_agent = %remote_agent, "Successfully sent notification");
814 Ok(())
815 }
816 NIXL_CAPI_ERROR_INVALID_PARAM => {
817 tracing::error!(error = "invalid_param", remote_agent = %remote_agent, "Failed to send notification");
818 Err(NixlError::InvalidParam)
819 }
820 _ => {
821 tracing::error!(error = "backend_error", remote_agent = %remote_agent, "Failed to send notification");
822 Err(NixlError::BackendError)
823 }
824 }
825 }
826
827 pub fn create_xfer_req(
842 &self,
843 operation: XferOp,
844 local_descs: &XferDescList,
845 remote_descs: &XferDescList,
846 remote_agent: &str,
847 opt_args: Option<&OptArgs>,
848 ) -> Result<XferRequest, NixlError> {
849 let remote_agent = CString::new(remote_agent)?;
850 let mut req = std::ptr::null_mut();
851
852 let status = unsafe {
854 bindings::nixl_capi_create_xfer_req(
855 self.inner.read().unwrap().handle.as_ptr(),
856 operation as bindings::nixl_capi_xfer_op_t,
857 local_descs.handle(),
858 remote_descs.handle(),
859 remote_agent.as_ptr(),
860 &mut req,
861 opt_args.map_or(std::ptr::null_mut(), |args| args.inner.as_ptr()),
862 )
863 };
864
865 match status {
866 NIXL_CAPI_SUCCESS => {
867 let inner = NonNull::new(req).ok_or(NixlError::FailedToCreateXferRequest)?;
869 Ok(XferRequest::new(inner, self.inner.clone()))
870 }
871 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
872 _ => Err(NixlError::FailedToCreateXferRequest),
873 }
874 }
875
876 pub fn estimate_xfer_cost(
888 &self,
889 req: &XferRequest,
890 opt_args: Option<&OptArgs>,
891 ) -> Result<(i64, i64, CostMethod), NixlError> {
892 let mut duration_us: i64 = 0;
893 let mut err_margin_us: i64 = 0;
894 let mut method: u32 = 0;
895
896 let status = unsafe {
897 nixl_capi_estimate_xfer_cost(
898 self.inner.write().unwrap().handle.as_ptr(),
899 req.handle(),
900 opt_args.map_or(ptr::null_mut(), |args| args.inner.as_ptr()),
901 &mut duration_us,
902 &mut err_margin_us,
903 &mut method as *mut u32 as *mut bindings::nixl_capi_cost_t,
904 )
905 };
906
907 match status {
908 NIXL_CAPI_SUCCESS => Ok((duration_us, err_margin_us, CostMethod::from(method))),
909 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
910 _ => Err(NixlError::BackendError),
911 }
912 }
913
914 pub fn post_xfer_req(
929 &self,
930 req: &XferRequest,
931 opt_args: Option<&OptArgs>,
932 ) -> Result<bool, NixlError> {
933 tracing::trace!("Posting transfer request");
934 let status = unsafe {
935 nixl_capi_post_xfer_req(
936 self.inner.write().unwrap().handle.as_ptr(),
937 req.handle(),
938 opt_args.map_or(ptr::null_mut(), |args| args.inner.as_ptr()),
939 )
940 };
941
942 match status {
943 NIXL_CAPI_SUCCESS => {
944 tracing::trace!(
945 status = "completed",
946 "Transfer request completed immediately"
947 );
948 Ok(false)
949 }
950 NIXL_CAPI_IN_PROG => {
951 tracing::trace!(status = "in_progress", "Transfer request in progress");
952 Ok(true)
953 }
954 NIXL_CAPI_ERROR_INVALID_PARAM => {
955 tracing::error!(error = "invalid_param", "Failed to post transfer request");
956 Err(NixlError::InvalidParam)
957 }
958 _ => {
959 tracing::error!(error = "backend_error", "Failed to post transfer request");
960 Err(NixlError::BackendError)
961 }
962 }
963 }
964
965 pub fn get_xfer_status(&self, req: &XferRequest) -> Result<XferStatus, NixlError> {
972 let status = unsafe {
973 nixl_capi_get_xfer_status(self.inner.write().unwrap().handle.as_ptr(), req.handle())
974 };
975
976 match status {
977 NIXL_CAPI_SUCCESS => Ok(XferStatus::Success), NIXL_CAPI_IN_PROG => Ok(XferStatus::InProgress), NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
980 _ => Err(NixlError::BackendError),
981 }
982 }
983
984 pub fn query_xfer_backend(&self, req: &XferRequest) -> Result<Backend, NixlError> {
995 let mut backend = std::ptr::null_mut();
996 let inner_guard = self.inner.write().unwrap();
997 let status = unsafe {
998 nixl_capi_query_xfer_backend(
999 inner_guard.handle.as_ptr(),
1000 req.handle(),
1001 &mut backend
1002 )
1003 };
1004 match status {
1005 NIXL_CAPI_SUCCESS => {
1006 Ok(Backend{ inner: NonNull::new(backend).ok_or(NixlError::FailedToCreateBackend)? })
1007 }
1008 NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
1009 _ => Err(NixlError::BackendError),
1010 }
1011 }
1012
1013
1014 pub fn get_notifications(
1020 &self,
1021 notifs: &mut NotificationMap,
1022 opt_args: Option<&OptArgs>,
1023 ) -> Result<(), NixlError> {
1024 tracing::trace!("Getting notifications");
1025 let status = unsafe {
1026 nixl_capi_get_notifs(
1027 self.inner.write().unwrap().handle.as_ptr(),
1028 notifs.inner.as_ptr(),
1029 opt_args.map_or(ptr::null_mut(), |args| args.inner.as_ptr()),
1030 )
1031 };
1032
1033 match status {
1034 NIXL_CAPI_SUCCESS => {
1035 tracing::trace!("Successfully retrieved notifications");
1036 Ok(())
1037 }
1038 NIXL_CAPI_ERROR_INVALID_PARAM => {
1039 tracing::error!(error = "invalid_param", "Failed to get notifications");
1040 Err(NixlError::InvalidParam)
1041 }
1042 _ => {
1043 tracing::error!(error = "backend_error", "Failed to get notifications");
1044 Err(NixlError::BackendError)
1045 }
1046 }
1047 }
1048}
1049
/// Shared state behind an [`Agent`]: the native agent handle plus the bookkeeping
/// needed to clean up when the last handle is dropped.
#[derive(Debug)]
pub(crate) struct AgentInner {
    /// Name the agent was created with.
    pub(crate) name: String,
    /// Non-null pointer to the native agent; destroyed in `Drop`.
    pub(crate) handle: NonNull<bindings::nixl_capi_agent_s>,
    /// Backends created through this agent, keyed by plugin name; destroyed in `Drop`.
    pub(crate) backends: HashMap<String, NonNull<bindings::nixl_capi_backend_s>>,
    /// Names of remote agents whose metadata has been loaded or fetched.
    pub(crate) remotes: HashSet<String>,
}
1058
/// Thread-synchronization mode requested for an agent.
///
/// Each variant maps to the identically named `NIXL_CAPI_THREAD_SYNC_*` constant
/// of the C API; the precise semantics are defined by the C library.
#[derive(Clone, Copy, Debug)]
pub enum ThreadSync {
    /// Maps to `NIXL_CAPI_THREAD_SYNC_NONE`.
    None,
    /// Maps to `NIXL_CAPI_THREAD_SYNC_STRICT`.
    Strict,
    /// Maps to `NIXL_CAPI_THREAD_SYNC_RW`.
    Rw,
    /// Maps to `NIXL_CAPI_THREAD_SYNC_DEFAULT`.
    Default,
}
1066
/// Configuration passed to [`Agent::new_configured`].
///
/// Every field is copied verbatim into the C configuration struct of the same
/// field name; the field descriptions below are inferred from those names and
/// should be confirmed against the C API documentation.
#[derive(Clone, Debug)]
pub struct AgentConfig {
    /// Presumably enables the agent's progress thread — TODO confirm.
    pub enable_prog_thread: bool,
    /// Presumably enables the agent's listener thread — TODO confirm.
    pub enable_listen_thread: bool,
    /// Presumably the port used by the listener thread — TODO confirm.
    pub listen_port: i32,
    /// Thread-synchronization mode requested for the agent.
    pub thread_sync: ThreadSync,
    /// Presumably the number of worker threads — TODO confirm.
    pub num_workers: u32,
    /// Presumably the progress-thread delay in microseconds — TODO confirm.
    pub pthr_delay_us: u64,
    /// Presumably the listener-thread delay in microseconds — TODO confirm.
    pub lthr_delay_us: u64,
    /// Presumably toggles telemetry capture — TODO confirm.
    pub capture_telemetry: bool,
}
1078
1079impl Default for AgentConfig {
1080 fn default() -> Self {
1081 Self {
1082 enable_prog_thread: true,
1083 enable_listen_thread: false,
1084 listen_port: 0,
1085 thread_sync: ThreadSync::None,
1086 num_workers: 1,
1087 pthr_delay_us: 0,
1088 lthr_delay_us: 100_000,
1089 capture_telemetry: false,
1090 }
1091 }
1092}
1093
// SAFETY (NOTE(review)): `AgentInner` holds raw `NonNull` pointers, which are not
// `Send`/`Sync` by default. These impls assert that the native agent and backend
// handles may be moved to and shared between threads; in this file all access goes
// through `Arc<RwLock<AgentInner>>`. Whether the C library itself tolerates
// cross-thread use is not visible here — confirm against the C API's guarantees.
unsafe impl Send for AgentInner {}
unsafe impl Sync for AgentInner {}
1096
1097impl AgentInner {
1098 fn new(handle: NonNull<bindings::nixl_capi_agent_s>, name: String) -> Self {
1099 Self {
1100 name,
1101 handle,
1102 backends: HashMap::new(),
1103 remotes: HashSet::new(),
1104 }
1105 }
1106
1107 fn get_backend(&self, name: &str) -> Option<NonNull<bindings::nixl_capi_backend_s>> {
1108 self.backends.get(name).cloned()
1109 }
1110
1111 fn invalidate_remote_md(&mut self, remote_agent: &str) -> Result<(), NixlError> {
1112 unsafe {
1113 if self.remotes.remove(remote_agent) {
1114 nixl_capi_invalidate_remote_md(self.handle.as_ptr(), remote_agent.as_ptr().cast());
1115 } else {
1116 return Err(NixlError::InvalidParam);
1117 }
1118 }
1119 Ok(())
1120 }
1121
1122 fn invalidate_all_remotes(&mut self) -> Result<(), NixlError> {
1123 unsafe {
1124 for remote in self.remotes.drain() {
1125 nixl_capi_invalidate_remote_md(self.handle.as_ptr(), remote.as_ptr().cast());
1126 }
1127 }
1128 Ok(())
1129 }
1130}
1131
1132impl Drop for AgentInner {
1133 fn drop(&mut self) {
1134 tracing::trace!("Dropping NIXL agent");
1135 unsafe {
1136 for remote in self.remotes.iter() {
1138 tracing::trace!(remote.agent = %remote, "Invalidating remote agent");
1139 nixl_capi_invalidate_remote_md(self.handle.as_ptr(), remote.as_ptr().cast());
1140 }
1141
1142 for backend in self.backends.values() {
1144 tracing::trace!("Destroying backend");
1145 nixl_capi_destroy_backend(backend.as_ptr());
1146 }
1147
1148 nixl_capi_destroy_agent(self.handle.as_ptr());
1149 }
1150 tracing::trace!("NIXL agent dropped");
1151 }
1152}