1use super::*;
2use derivative::*;
3#[cfg(feature = "am")]
4use std::collections::HashMap;
5use std::net::SocketAddr;
6use std::os::unix::io::AsRawFd;
7#[cfg(feature = "am")]
8use std::sync::RwLock;
9#[cfg(feature = "event")]
10use tokio::io::unix::AsyncFd;
11
/// A UCP worker wrapping a raw `ucp_worker_h`.
///
/// Constructed by `Worker::new`, which hands it out as `Rc<Worker>`; the
/// underlying UCX worker is destroyed in this type's `Drop` impl.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct Worker {
    /// Raw UCX worker handle; passed to every `ucp_worker_*` call on this type.
    pub(super) handle: ucp_worker_h,
    /// Context the worker was created from. Presumably held so the context is
    /// not destroyed while this worker is alive (struct fields are dropped
    /// after `Drop::drop` has destroyed the worker handle) — confirm.
    context: Arc<Context>,
    /// Active-message streams keyed by a `u16` (presumably the AM id — verify
    /// against the `am` module). Excluded from the derived `Debug` output.
    #[cfg(feature = "am")]
    #[derivative(Debug = "ignore")]
    pub(crate) am_streams: RwLock<HashMap<u16, Rc<AmStreamInner>>>,
}
22
23impl Drop for Worker {
24 fn drop(&mut self) {
25 unsafe { ucp_worker_destroy(self.handle) }
26 }
27}
28
29impl Worker {
30 pub(super) fn new(context: &Arc<Context>) -> Result<Rc<Self>, Error> {
31 let mut params = MaybeUninit::<ucp_worker_params_t>::uninit();
32 unsafe {
33 (*params.as_mut_ptr()).field_mask =
34 ucp_worker_params_field::UCP_WORKER_PARAM_FIELD_THREAD_MODE.0 as _;
35 (*params.as_mut_ptr()).thread_mode = ucs_thread_mode_t::UCS_THREAD_MODE_SINGLE;
36 };
37 let mut handle = MaybeUninit::uninit();
38 let status =
39 unsafe { ucp_worker_create(context.handle, params.as_ptr(), handle.as_mut_ptr()) };
40 Error::from_status(status)?;
41
42 Ok(Rc::new(Worker {
43 handle: unsafe { handle.assume_init() },
44 context: context.clone(),
45 #[cfg(feature = "am")]
46 am_streams: RwLock::new(HashMap::new()),
47 }))
48 }
49
50 pub async fn polling(self: Rc<Self>) {
52 while Rc::strong_count(&self) > 1 {
53 while self.progress() != 0 {}
54 futures_lite::future::yield_now().await;
55 }
56 }
57
58 #[cfg(feature = "event")]
63 pub async fn event_poll(self: Rc<Self>) -> Result<(), Error> {
64 let fd = self.event_fd()?;
65 let wait_fd = AsyncFd::new(fd).unwrap();
66 while Rc::strong_count(&self) > 1 {
67 while self.progress() != 0 {}
68 if self.arm().unwrap() {
69 let mut ready = wait_fd.readable().await.unwrap();
70 ready.clear_ready();
71 }
72 }
73
74 Ok(())
75 }
76
77 pub fn print_to_stderr(&self) {
82 unsafe { ucp_worker_print_info(self.handle, stderr) };
83 }
84
85 pub fn thread_mode(&self) -> ucs_thread_mode_t {
87 let mut attr = MaybeUninit::<ucp_worker_attr>::uninit();
88 unsafe { &mut *attr.as_mut_ptr() }.field_mask =
89 ucp_worker_attr_field::UCP_WORKER_ATTR_FIELD_THREAD_MODE.0 as u64;
90 let status = unsafe { ucp_worker_query(self.handle, attr.as_mut_ptr()) };
91 assert_eq!(status, ucs_status_t::UCS_OK);
92 let attr = unsafe { attr.assume_init() };
93 attr.thread_mode
94 }
95
96 pub fn address(&self) -> Result<WorkerAddress<'_>, Error> {
101 let mut handle = MaybeUninit::uninit();
102 let mut length = MaybeUninit::uninit();
103 let status = unsafe {
104 ucp_worker_get_address(self.handle, handle.as_mut_ptr(), length.as_mut_ptr())
105 };
106 Error::from_status(status)?;
107
108 Ok(WorkerAddress {
109 handle: unsafe { handle.assume_init() },
110 length: unsafe { length.assume_init() } as usize,
111 worker: self,
112 })
113 }
114
115 pub fn create_listener(self: &Rc<Self>, addr: SocketAddr) -> Result<Listener, Error> {
117 Listener::new(self, addr)
118 }
119
120 pub fn connect_addr(self: &Rc<Self>, addr: &WorkerAddress) -> Result<Endpoint, Error> {
122 Endpoint::connect_addr(self, addr.handle)
123 }
124
125 pub async fn connect_socket(self: &Rc<Self>, addr: SocketAddr) -> Result<Endpoint, Error> {
127 Endpoint::connect_socket(self, addr).await
128 }
129
130 pub async fn accept(self: &Rc<Self>, connection: ConnectionRequest) -> Result<Endpoint, Error> {
132 Endpoint::accept(self, connection).await
133 }
134
135 pub fn wait(&self) -> Result<(), Error> {
137 let status = unsafe { ucp_worker_wait(self.handle) };
138 Error::from_status(status)
139 }
140
141 pub fn arm(&self) -> Result<bool, Error> {
145 let status = unsafe { ucp_worker_arm(self.handle) };
146 match status {
147 ucs_status_t::UCS_OK => Ok(true),
148 ucs_status_t::UCS_ERR_BUSY => Ok(false),
149 status => Err(Error::from_error(status)),
150 }
151 }
152
153 pub fn progress(&self) -> u32 {
155 unsafe { ucp_worker_progress(self.handle) }
156 }
157
158 pub fn event_fd(&self) -> Result<i32, Error> {
160 let mut fd = MaybeUninit::uninit();
161 let status = unsafe { ucp_worker_get_efd(self.handle, fd.as_mut_ptr()) };
162 Error::from_status(status)?;
163
164 unsafe { Ok(fd.assume_init()) }
165 }
166
167 pub fn flush(&self) {
169 let status = unsafe { ucp_worker_flush(self.handle) };
170 assert_eq!(status, ucs_status_t::UCS_OK);
171 }
172}
173
174impl AsRawFd for Worker {
175 fn as_raw_fd(&self) -> i32 {
176 self.event_fd().unwrap()
177 }
178}
179
/// A worker address obtained from `ucp_worker_get_address`.
///
/// Borrows the [`Worker`] it came from; the address buffer is handed back via
/// `ucp_worker_release_address` when this value is dropped.
#[derive(Debug)]
pub struct WorkerAddress<'a> {
    /// Pointer to the address buffer returned by `ucp_worker_get_address`.
    handle: *mut ucp_address_t,
    /// Length of the address buffer in bytes.
    length: usize,
    /// Worker the address belongs to; its handle is needed to release the address.
    worker: &'a Worker,
}
187
188impl<'a> AsRef<[u8]> for WorkerAddress<'a> {
189 fn as_ref(&self) -> &[u8] {
190 unsafe { std::slice::from_raw_parts(self.handle as *const u8, self.length) }
191 }
192}
193
194impl<'a> Drop for WorkerAddress<'a> {
195 fn drop(&mut self) {
196 unsafe { ucp_worker_release_address(self.worker.handle, self.handle) }
197 }
198}