// async_ucx/ucp/worker.rs

1use super::*;
2use derivative::*;
3#[cfg(feature = "am")]
4use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::os::unix::io::AsRawFd;
7#[cfg(feature = "am")]
8use std::sync::RwLock;
9#[cfg(feature = "event")]
10use tokio::io::unix::AsyncFd;
11
12/// An object representing the communication context.
13#[derive(Derivative)]
14#[derivative(Debug)]
15pub struct Worker {
16    pub(super) handle: ucp_worker_h,
17    context: Arc<Context>,
18    #[cfg(feature = "am")]
19    #[derivative(Debug = "ignore")]
20    pub(crate) am_streams: RwLock<HashMap<u16, Rc<AmStreamInner>>>,
21}
22
23impl Drop for Worker {
24    fn drop(&mut self) {
25        unsafe { ucp_worker_destroy(self.handle) }
26    }
27}
28
29impl Worker {
30    pub(super) fn new(context: &Arc<Context>) -> Result<Rc<Self>, Error> {
31        let mut params = MaybeUninit::<ucp_worker_params_t>::uninit();
32        unsafe {
33            (*params.as_mut_ptr()).field_mask =
34                ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_THREAD_MODE.0 as _;
35            (*params.as_mut_ptr()).thread_mode = ucs_thread_mode_t::UCS_THREAD_MODE_SINGLE;
36        };
37        let mut handle = MaybeUninit::uninit();
38        let status =
39            unsafe { ucp_worker_create(context.handle, params.as_ptr(), handle.as_mut_ptr()) };
40        Error::from_status(status)?;
41
42        Ok(Rc::new(Worker {
43            handle: unsafe { handle.assume_init() },
44            context: context.clone(),
45            #[cfg(feature = "am")]
46            am_streams: RwLock::new(HashMap::new()),
47        }))
48    }
49
50    /// Make progress on the worker.
51    pub async fn polling(self: Rc<Self>) {
52        while Rc::strong_count(&self) > 1 {
53            while self.progress() != 0 {}
54            futures_lite::future::yield_now().await;
55        }
56    }
57
58    /// Wait event then make progress.
59    ///
60    /// This function register `event_fd` on tokio's event loop and wait `event_fd` become readable,
61    ////  then call progress function.
62    #[cfg(feature = "event")]
63    pub async fn event_poll(self: Rc<Self>) -> Result<(), Error> {
64        let fd = self.event_fd()?;
65        let wait_fd = AsyncFd::new(fd).unwrap();
66        while Rc::strong_count(&self) > 1 {
67            while self.progress() != 0 {}
68            if self.arm().unwrap() {
69                let mut ready = wait_fd.readable().await.unwrap();
70                ready.clear_ready();
71            }
72        }
73
74        Ok(())
75    }
76
77    /// Prints information about the worker.
78    ///
79    /// Including protocols being used, thresholds, UCT transport methods,
80    /// and other useful information associated with the worker.
81    pub fn print_to_stderr(&self) {
82        unsafe { ucp_worker_print_info(self.handle, stderr) };
83    }
84
85    /// Thread safe level of the context.
86    pub fn thread_mode(&self) -> ucs_thread_mode_t {
87        let mut attr = MaybeUninit::<ucp_worker_attr>::uninit();
88        unsafe { &mut *attr.as_mut_ptr() }.field_mask =
89            ucp_worker_attr_field::UCP_WORKER_ATTR_FIELD_THREAD_MODE.0 as u64;
90        let status = unsafe { ucp_worker_query(self.handle, attr.as_mut_ptr()) };
91        assert_eq!(status, ucs_status_t::UCS_OK);
92        let attr = unsafe { attr.assume_init() };
93        attr.thread_mode
94    }
95
96    /// Get the address of the worker object.
97    ///
98    /// This address can be passed to remote instances of the UCP library
99    /// in order to connect to this worker.
100    pub fn address(&self) -> Result<WorkerAddress<'_>, Error> {
101        let mut handle = MaybeUninit::uninit();
102        let mut length = MaybeUninit::uninit();
103        let status = unsafe {
104            ucp_worker_get_address(self.handle, handle.as_mut_ptr(), length.as_mut_ptr())
105        };
106        Error::from_status(status)?;
107
108        Ok(WorkerAddress {
109            handle: unsafe { handle.assume_init() },
110            length: unsafe { length.assume_init() } as usize,
111            worker: self,
112        })
113    }
114
115    /// Create a new [`Listener`].
116    pub fn create_listener(self: &Rc<Self>, addr: SocketAddr) -> Result<Listener, Error> {
117        Listener::new(self, addr)
118    }
119
120    /// Connect to a remote worker by address.
121    pub fn connect_addr(self: &Rc<Self>, addr: &WorkerAddress) -> Result<Endpoint, Error> {
122        Endpoint::connect_addr(self, addr.handle)
123    }
124
125    /// Connect to a remote listener.
126    pub async fn connect_socket(self: &Rc<Self>, addr: SocketAddr) -> Result<Endpoint, Error> {
127        Endpoint::connect_socket(self, addr).await
128    }
129
130    /// Accept a connection request.
131    pub async fn accept(self: &Rc<Self>, connection: ConnectionRequest) -> Result<Endpoint, Error> {
132        Endpoint::accept(self, connection).await
133    }
134
135    /// Waits (blocking) until an event has happened.
136    pub fn wait(&self) -> Result<(), Error> {
137        let status = unsafe { ucp_worker_wait(self.handle) };
138        Error::from_status(status)
139    }
140
141    /// This needs to be called before waiting on each notification on this worker.
142    ///
143    /// Returns 'true' if one can wait for events (sleep mode).
144    pub fn arm(&self) -> Result<bool, Error> {
145        let status = unsafe { ucp_worker_arm(self.handle) };
146        match status {
147            ucs_status_t::UCS_OK => Ok(true),
148            ucs_status_t::UCS_ERR_BUSY => Ok(false),
149            status => Err(Error::from_error(status)),
150        }
151    }
152
153    /// Explicitly progresses all communication operations on a worker.
154    pub fn progress(&self) -> u32 {
155        unsafe { ucp_worker_progress(self.handle) }
156    }
157
158    /// Returns a valid file descriptor for polling functions.
159    pub fn event_fd(&self) -> Result<i32, Error> {
160        let mut fd = MaybeUninit::uninit();
161        let status = unsafe { ucp_worker_get_efd(self.handle, fd.as_mut_ptr()) };
162        Error::from_status(status)?;
163
164        unsafe { Ok(fd.assume_init()) }
165    }
166
167    /// This routine flushes all outstanding AMO and RMA communications on the worker.
168    pub fn flush(&self) {
169        let status = unsafe { ucp_worker_flush(self.handle) };
170        assert_eq!(status, ucs_status_t::UCS_OK);
171    }
172}
173
174impl AsRawFd for Worker {
175    fn as_raw_fd(&self) -> i32 {
176        self.event_fd().unwrap()
177    }
178}
179
180/// The address of the worker object.
181#[derive(Debug)]
182pub struct WorkerAddress<'a> {
183    handle: *mut ucp_address_t,
184    length: usize,
185    worker: &'a Worker,
186}
187
188impl<'a> AsRef<[u8]> for WorkerAddress<'a> {
189    fn as_ref(&self) -> &[u8] {
190        unsafe { std::slice::from_raw_parts(self.handle as *const u8, self.length) }
191    }
192}
193
194impl<'a> Drop for WorkerAddress<'a> {
195    fn drop(&mut self) {
196        unsafe { ucp_worker_release_address(self.worker.handle, self.handle) }
197    }
198}