async_ucx/ucp/endpoint/
mod.rs1use super::*;
2use std::cell::Cell;
3use std::future::Future;
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::rc::Weak;
7use std::sync::atomic::AtomicBool;
8use std::task::Poll;
9
10#[cfg(feature = "am")]
11mod am;
12mod rma;
13mod stream;
14mod tag;
15
16#[cfg(feature = "am")]
17pub use self::am::*;
18pub use self::rma::*;
19pub use self::stream::*;
20pub use self::tag::*;
21
22#[derive(Debug)]
25struct EndpointInner {
26 closed: AtomicBool,
27 status: Cell<ucs_status_t>,
28 worker: Rc<Worker>,
29}
30
31impl EndpointInner {
32 fn new(worker: Rc<Worker>) -> Self {
33 EndpointInner {
34 closed: AtomicBool::new(false),
35 status: Cell::new(ucs_status_t::UCS_OK),
36 worker,
37 }
38 }
39
40 fn closed(self: &Rc<Self>) {
41 if self
42 .closed
43 .compare_exchange(
44 false,
45 true,
46 std::sync::atomic::Ordering::SeqCst,
47 std::sync::atomic::Ordering::SeqCst,
48 )
49 .is_ok()
50 {
51 let _weak = unsafe { Weak::from_raw(Rc::as_ptr(self)) };
53 self.set_status(ucs_status_t::UCS_ERR_CONNECTION_RESET);
54 }
55 }
56
57 fn is_closed(&self) -> bool {
58 self.closed.load(std::sync::atomic::Ordering::SeqCst)
59 }
60
61 #[inline]
63 fn set_status(&self, status: ucs_status_t) {
64 if status != ucs_status_t::UCS_OK {
65 self.status.set(status)
66 }
67 }
68
69 #[inline]
70 fn check(&self) -> Result<(), Error> {
71 let status = self.status.get();
72 Error::from_status(status)
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct Endpoint {
79 handle: ucp_ep_h,
80 inner: Rc<EndpointInner>,
81}
82
83impl Endpoint {
84 fn create(worker: &Rc<Worker>, mut params: ucp_ep_params) -> Result<Self, Error> {
85 let inner = Rc::new(EndpointInner::new(worker.clone()));
86 let weak = Rc::downgrade(&inner);
87
88 let ptr = Weak::into_raw(weak);
91 unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) {
92 let weak: Weak<EndpointInner> = Weak::from_raw(arg as _);
93 if let Some(inner) = weak.upgrade() {
94 inner.set_status(status);
95 std::mem::forget(weak);
97 } else {
98 let status = ucp_ep_close_nb(ep, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as _);
100 let _ = Error::from_ptr(status)
101 .map_err(|err| error!("Force close endpoint failed, {}", err));
102 }
103 }
104
105 params.field_mask |= (ucp_ep_params_field::UCP_EP_PARAM_FIELD_USER_DATA
106 | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLER)
107 .0 as u64;
108 params.user_data = ptr as _;
109 params.err_handler = ucp_err_handler {
110 cb: Some(callback),
111 arg: std::ptr::null_mut(), };
113
114 let mut handle = MaybeUninit::uninit();
115 let status = unsafe { ucp_ep_create(worker.handle, ¶ms, handle.as_mut_ptr()) };
116 if let Err(err) = Error::from_status(status) {
117 let _weak = unsafe { Weak::from_raw(ptr as _) };
119 return Err(err);
120 }
121
122 let handle = unsafe { handle.assume_init() };
123 trace!("create endpoint={:?}", handle);
124 Ok(Self { handle, inner })
125 }
126
127 pub(super) async fn connect_socket(
128 worker: &Rc<Worker>,
129 addr: SocketAddr,
130 ) -> Result<Self, Error> {
131 let sockaddr = socket2::SockAddr::from(addr);
132 #[allow(invalid_value)]
133 #[allow(clippy::uninit_assumed_init)]
134 let params = ucp_ep_params {
135 field_mask: (ucp_ep_params_field::UCP_EP_PARAM_FIELD_FLAGS
136 | ucp_ep_params_field::UCP_EP_PARAM_FIELD_SOCK_ADDR
137 | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE)
138 .0 as u64,
139 flags: ucp_ep_params_flags_field::UCP_EP_PARAMS_FLAGS_CLIENT_SERVER.0,
140 sockaddr: ucs_sock_addr {
141 addr: sockaddr.as_ptr() as _,
142 addrlen: sockaddr.len(),
143 },
144 err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER,
145 ..unsafe { MaybeUninit::uninit().assume_init() }
146 };
147 let endpoint = Endpoint::create(worker, params)?;
148
149 let buf = [0, 1, 2, 3];
151 endpoint.stream_send(&buf).await?;
152
153 Ok(endpoint)
154 }
155
156 pub(super) fn connect_addr(
157 worker: &Rc<Worker>,
158 addr: *const ucp_address_t,
159 ) -> Result<Self, Error> {
160 #[allow(invalid_value)]
161 #[allow(clippy::uninit_assumed_init)]
162 let params = ucp_ep_params {
163 field_mask: (ucp_ep_params_field::UCP_EP_PARAM_FIELD_REMOTE_ADDRESS
164 | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE)
165 .0 as u64,
166 address: addr,
167 err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER,
168 ..unsafe { MaybeUninit::uninit().assume_init() }
169 };
170 Endpoint::create(worker, params)
171 }
172
173 pub(super) async fn accept(
174 worker: &Rc<Worker>,
175 connection: ConnectionRequest,
176 ) -> Result<Self, Error> {
177 #[allow(invalid_value)]
178 #[allow(clippy::uninit_assumed_init)]
179 let params = ucp_ep_params {
180 field_mask: ucp_ep_params_field::UCP_EP_PARAM_FIELD_CONN_REQUEST.0 as u64,
181 conn_request: connection.handle,
182 ..unsafe { MaybeUninit::uninit().assume_init() }
183 };
184 let endpoint = Endpoint::create(worker, params)?;
185
186 let mut buf = [MaybeUninit::<u8>::uninit(); 4];
188 endpoint.stream_recv(buf.as_mut()).await?;
189
190 Ok(endpoint)
191 }
192
193 pub fn is_closed(&self) -> bool {
195 self.inner.is_closed()
196 }
197
198 pub fn get_status(&self) -> Result<(), Error> {
200 self.inner.check()
201 }
202
203 #[inline]
204 fn get_handle(&self) -> Result<ucp_ep_h, Error> {
205 self.inner.check()?;
206 Ok(self.handle)
207 }
208
209 pub fn print_to_stderr(&self) {
211 if !self.inner.is_closed() {
212 unsafe { ucp_ep_print_info(self.handle, stderr) };
213 }
214 }
215
216 pub async fn flush(&self) -> Result<(), Error> {
218 let handle = self.get_handle()?;
219 trace!("flush: endpoint={:?}", handle);
220 unsafe extern "C" fn callback(request: *mut c_void, _status: ucs_status_t) {
221 trace!("flush: complete");
222 let request = &mut *(request as *mut Request);
223 request.waker.wake();
224 }
225 let status = unsafe { ucp_ep_flush_nb(handle, 0, Some(callback)) };
226 if status.is_null() {
227 trace!("flush: complete");
228 Ok(())
229 } else if UCS_PTR_IS_PTR(status) {
230 RequestHandle {
231 ptr: status,
232 poll_fn: poll_normal,
233 }
234 .await
235 } else {
236 Error::from_ptr(status)
237 }
238 }
239
240 pub async fn close(&self, force: bool) -> Result<(), Error> {
242 if force && self.is_closed() {
243 return Ok(());
244 } else if !force {
245 self.get_status()?;
246 }
247
248 trace!("close: endpoint={:?}", self.handle);
249 let mode = if force {
250 ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32
251 } else {
252 ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FLUSH as u32
253 };
254 let status = unsafe { ucp_ep_close_nb(self.handle, mode) };
255 if status.is_null() {
256 trace!("close: complete");
257 self.inner.closed();
258 Ok(())
259 } else if UCS_PTR_IS_PTR(status) {
260 let result = loop {
261 if let Poll::Ready(result) = unsafe { poll_normal(status) } {
262 unsafe { ucp_request_free(status as _) };
263 break result;
264 } else {
265 futures_lite::future::yield_now().await;
266 }
267 };
268 if result.is_ok() {
269 self.inner.closed();
270 }
271
272 result
273 } else {
274 let status = UCS_PTR_RAW_STATUS(status);
276 warn!("close endpoint get error: {:?}", status);
277 Error::from_status(status)
278 }
279 }
280
281 pub fn worker(&self) -> &Rc<Worker> {
283 &self.inner.worker
284 }
285
286 #[allow(unused)]
287 #[cfg(test)]
288 fn get_rc(&self) -> (usize, usize) {
289 (Rc::strong_count(&self.inner), Rc::weak_count(&self.inner))
290 }
291}
292
293impl Drop for Endpoint {
294 fn drop(&mut self) {
295 if !self.inner.is_closed() {
296 trace!("destroy endpoint={:?}", self.handle);
297 let status = unsafe {
298 ucp_ep_close_nb(
299 self.handle,
300 ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32,
301 )
302 };
303 let _ = Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err));
304 self.inner.closed();
305 }
306 }
307}
308
309struct RequestHandle<T> {
311 ptr: ucs_status_ptr_t,
312 poll_fn: unsafe fn(ucs_status_ptr_t) -> Poll<T>,
313}
314
315impl<T> Future for RequestHandle<T> {
316 type Output = T;
317 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
318 if let ret @ Poll::Ready(_) = unsafe { (self.poll_fn)(self.ptr) } {
319 return ret;
320 }
321 let request = unsafe { &mut *(self.ptr as *mut Request) };
322 request.waker.register(cx.waker());
323 unsafe { (self.poll_fn)(self.ptr) }
324 }
325}
326
327impl<T> Drop for RequestHandle<T> {
328 fn drop(&mut self) {
329 trace!("request free: {:?}", self.ptr);
330 unsafe { ucp_request_free(self.ptr as _) };
331 }
332}
333
334unsafe fn poll_normal(ptr: ucs_status_ptr_t) -> Poll<Result<(), Error>> {
335 let status = ucp_request_check_status(ptr as _);
336 if status == ucs_status_t::UCS_INPROGRESS {
337 Poll::Pending
338 } else {
339 Poll::Ready(Error::from_status(status))
340 }
341}