async_ucx/ucp/endpoint/
mod.rs

1use super::*;
2use std::cell::Cell;
3use std::future::Future;
4use std::net::SocketAddr;
5use std::pin::Pin;
6use std::rc::Weak;
7use std::sync::atomic::AtomicBool;
8use std::task::Poll;
9
10#[cfg(feature = "am")]
11mod am;
12mod rma;
13mod stream;
14mod tag;
15
16#[cfg(feature = "am")]
17pub use self::am::*;
18pub use self::rma::*;
19pub use self::stream::*;
20pub use self::tag::*;
21
22// State associate with ucp_ep_h
23// todo: Add a `get_user_data` to UCX
24#[derive(Debug)]
25struct EndpointInner {
26    closed: AtomicBool,
27    status: Cell<ucs_status_t>,
28    worker: Rc<Worker>,
29}
30
31impl EndpointInner {
32    fn new(worker: Rc<Worker>) -> Self {
33        EndpointInner {
34            closed: AtomicBool::new(false),
35            status: Cell::new(ucs_status_t::UCS_OK),
36            worker,
37        }
38    }
39
40    fn closed(self: &Rc<Self>) {
41        if self
42            .closed
43            .compare_exchange(
44                false,
45                true,
46                std::sync::atomic::Ordering::SeqCst,
47                std::sync::atomic::Ordering::SeqCst,
48            )
49            .is_ok()
50        {
51            // release a weak reference
52            let _weak = unsafe { Weak::from_raw(Rc::as_ptr(self)) };
53            self.set_status(ucs_status_t::UCS_ERR_CONNECTION_RESET);
54        }
55    }
56
57    fn is_closed(&self) -> bool {
58        self.closed.load(std::sync::atomic::Ordering::SeqCst)
59    }
60
61    // call from `err_handler` or `close`
62    #[inline]
63    fn set_status(&self, status: ucs_status_t) {
64        if status != ucs_status_t::UCS_OK {
65            self.status.set(status)
66        }
67    }
68
69    #[inline]
70    fn check(&self) -> Result<(), Error> {
71        let status = self.status.get();
72        Error::from_status(status)
73    }
74}
75
/// Communication endpoint.
///
/// Pairs a raw UCX endpoint handle with reference-counted state that is
/// shared by every clone of this value.
// NOTE(review): `Clone` is derived while `Drop` force-closes the UCX endpoint
// whenever the shared state is not yet marked closed, so dropping any single
// clone closes the endpoint for all remaining clones — confirm this is intended.
#[derive(Debug, Clone)]
pub struct Endpoint {
    // Raw handle produced by `ucp_ep_create`.
    handle: ucp_ep_h,
    // Closed flag, last recorded status and the owning worker.
    inner: Rc<EndpointInner>,
}
82
83impl Endpoint {
84    fn create(worker: &Rc<Worker>, mut params: ucp_ep_params) -> Result<Self, Error> {
85        let inner = Rc::new(EndpointInner::new(worker.clone()));
86        let weak = Rc::downgrade(&inner);
87
88        // ucp endpoint keep a weak reference to inner
89        // this reference will drop when endpoint is closed
90        let ptr = Weak::into_raw(weak);
91        unsafe extern "C" fn callback(arg: *mut c_void, ep: ucp_ep_h, status: ucs_status_t) {
92            let weak: Weak<EndpointInner> = Weak::from_raw(arg as _);
93            if let Some(inner) = weak.upgrade() {
94                inner.set_status(status);
95                // don't drop weak reference
96                std::mem::forget(weak);
97            } else {
98                // no strong rc, force close endpoint here
99                let status = ucp_ep_close_nb(ep, ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as _);
100                let _ = Error::from_ptr(status)
101                    .map_err(|err| error!("Force close endpoint failed, {}", err));
102            }
103        }
104
105        params.field_mask |= (ucp_ep_params_field::UCP_EP_PARAM_FIELD_USER_DATA
106            | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLER)
107            .0 as u64;
108        params.user_data = ptr as _;
109        params.err_handler = ucp_err_handler {
110            cb: Some(callback),
111            arg: std::ptr::null_mut(), // override by user_data
112        };
113
114        let mut handle = MaybeUninit::uninit();
115        let status = unsafe { ucp_ep_create(worker.handle, &params, handle.as_mut_ptr()) };
116        if let Err(err) = Error::from_status(status) {
117            // error happened, drop reference
118            let _weak = unsafe { Weak::from_raw(ptr as _) };
119            return Err(err);
120        }
121
122        let handle = unsafe { handle.assume_init() };
123        trace!("create endpoint={:?}", handle);
124        Ok(Self { handle, inner })
125    }
126
127    pub(super) async fn connect_socket(
128        worker: &Rc<Worker>,
129        addr: SocketAddr,
130    ) -> Result<Self, Error> {
131        let sockaddr = socket2::SockAddr::from(addr);
132        #[allow(invalid_value)]
133        #[allow(clippy::uninit_assumed_init)]
134        let params = ucp_ep_params {
135            field_mask: (ucp_ep_params_field::UCP_EP_PARAM_FIELD_FLAGS
136                | ucp_ep_params_field::UCP_EP_PARAM_FIELD_SOCK_ADDR
137                | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE)
138                .0 as u64,
139            flags: ucp_ep_params_flags_field::UCP_EP_PARAMS_FLAGS_CLIENT_SERVER.0,
140            sockaddr: ucs_sock_addr {
141                addr: sockaddr.as_ptr() as _,
142                addrlen: sockaddr.len(),
143            },
144            err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER,
145            ..unsafe { MaybeUninit::uninit().assume_init() }
146        };
147        let endpoint = Endpoint::create(worker, params)?;
148
149        // Workaround for UCX bug: https://github.com/openucx/ucx/issues/6872
150        let buf = [0, 1, 2, 3];
151        endpoint.stream_send(&buf).await?;
152
153        Ok(endpoint)
154    }
155
156    pub(super) fn connect_addr(
157        worker: &Rc<Worker>,
158        addr: *const ucp_address_t,
159    ) -> Result<Self, Error> {
160        #[allow(invalid_value)]
161        #[allow(clippy::uninit_assumed_init)]
162        let params = ucp_ep_params {
163            field_mask: (ucp_ep_params_field::UCP_EP_PARAM_FIELD_REMOTE_ADDRESS
164                | ucp_ep_params_field::UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE)
165                .0 as u64,
166            address: addr,
167            err_mode: ucp_err_handling_mode_t::UCP_ERR_HANDLING_MODE_PEER,
168            ..unsafe { MaybeUninit::uninit().assume_init() }
169        };
170        Endpoint::create(worker, params)
171    }
172
173    pub(super) async fn accept(
174        worker: &Rc<Worker>,
175        connection: ConnectionRequest,
176    ) -> Result<Self, Error> {
177        #[allow(invalid_value)]
178        #[allow(clippy::uninit_assumed_init)]
179        let params = ucp_ep_params {
180            field_mask: ucp_ep_params_field::UCP_EP_PARAM_FIELD_CONN_REQUEST.0 as u64,
181            conn_request: connection.handle,
182            ..unsafe { MaybeUninit::uninit().assume_init() }
183        };
184        let endpoint = Endpoint::create(worker, params)?;
185
186        // Workaround for UCX bug: https://github.com/openucx/ucx/issues/6872
187        let mut buf = [MaybeUninit::<u8>::uninit(); 4];
188        endpoint.stream_recv(buf.as_mut()).await?;
189
190        Ok(endpoint)
191    }
192
193    /// Whether the endpoint is closed.
194    pub fn is_closed(&self) -> bool {
195        self.inner.is_closed()
196    }
197
198    /// Get the endpoint status.
199    pub fn get_status(&self) -> Result<(), Error> {
200        self.inner.check()
201    }
202
203    #[inline]
204    fn get_handle(&self) -> Result<ucp_ep_h, Error> {
205        self.inner.check()?;
206        Ok(self.handle)
207    }
208
209    /// Print endpoint information to stderr.
210    pub fn print_to_stderr(&self) {
211        if !self.inner.is_closed() {
212            unsafe { ucp_ep_print_info(self.handle, stderr) };
213        }
214    }
215
216    /// This routine flushes all outstanding AMO and RMA communications on the endpoint.
217    pub async fn flush(&self) -> Result<(), Error> {
218        let handle = self.get_handle()?;
219        trace!("flush: endpoint={:?}", handle);
220        unsafe extern "C" fn callback(request: *mut c_void, _status: ucs_status_t) {
221            trace!("flush: complete");
222            let request = &mut *(request as *mut Request);
223            request.waker.wake();
224        }
225        let status = unsafe { ucp_ep_flush_nb(handle, 0, Some(callback)) };
226        if status.is_null() {
227            trace!("flush: complete");
228            Ok(())
229        } else if UCS_PTR_IS_PTR(status) {
230            RequestHandle {
231                ptr: status,
232                poll_fn: poll_normal,
233            }
234            .await
235        } else {
236            Error::from_ptr(status)
237        }
238    }
239
240    /// This routine close connection.
241    pub async fn close(&self, force: bool) -> Result<(), Error> {
242        if force && self.is_closed() {
243            return Ok(());
244        } else if !force {
245            self.get_status()?;
246        }
247
248        trace!("close: endpoint={:?}", self.handle);
249        let mode = if force {
250            ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32
251        } else {
252            ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FLUSH as u32
253        };
254        let status = unsafe { ucp_ep_close_nb(self.handle, mode) };
255        if status.is_null() {
256            trace!("close: complete");
257            self.inner.closed();
258            Ok(())
259        } else if UCS_PTR_IS_PTR(status) {
260            let result = loop {
261                if let Poll::Ready(result) = unsafe { poll_normal(status) } {
262                    unsafe { ucp_request_free(status as _) };
263                    break result;
264                } else {
265                    futures_lite::future::yield_now().await;
266                }
267            };
268            if result.is_ok() {
269                self.inner.closed();
270            }
271
272            result
273        } else {
274            // todo: maybe this shouldn't treat as error ...
275            let status = UCS_PTR_RAW_STATUS(status);
276            warn!("close endpoint get error: {:?}", status);
277            Error::from_status(status)
278        }
279    }
280
281    /// Get the worker of the endpoint.
282    pub fn worker(&self) -> &Rc<Worker> {
283        &self.inner.worker
284    }
285
286    #[allow(unused)]
287    #[cfg(test)]
288    fn get_rc(&self) -> (usize, usize) {
289        (Rc::strong_count(&self.inner), Rc::weak_count(&self.inner))
290    }
291}
292
293impl Drop for Endpoint {
294    fn drop(&mut self) {
295        if !self.inner.is_closed() {
296            trace!("destroy endpoint={:?}", self.handle);
297            let status = unsafe {
298                ucp_ep_close_nb(
299                    self.handle,
300                    ucp_ep_close_mode::UCP_EP_CLOSE_MODE_FORCE as u32,
301                )
302            };
303            let _ = Error::from_ptr(status).map_err(|err| error!("Failed to force close, {}", err));
304            self.inner.closed();
305        }
306    }
307}
308
309/// A handle to the request returned from async IO functions.
310struct RequestHandle<T> {
311    ptr: ucs_status_ptr_t,
312    poll_fn: unsafe fn(ucs_status_ptr_t) -> Poll<T>,
313}
314
315impl<T> Future for RequestHandle<T> {
316    type Output = T;
317    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
318        if let ret @ Poll::Ready(_) = unsafe { (self.poll_fn)(self.ptr) } {
319            return ret;
320        }
321        let request = unsafe { &mut *(self.ptr as *mut Request) };
322        request.waker.register(cx.waker());
323        unsafe { (self.poll_fn)(self.ptr) }
324    }
325}
326
327impl<T> Drop for RequestHandle<T> {
328    fn drop(&mut self) {
329        trace!("request free: {:?}", self.ptr);
330        unsafe { ucp_request_free(self.ptr as _) };
331    }
332}
333
334unsafe fn poll_normal(ptr: ucs_status_ptr_t) -> Poll<Result<(), Error>> {
335    let status = ucp_request_check_status(ptr as _);
336    if status == ucs_status_t::UCS_INPROGRESS {
337        Poll::Pending
338    } else {
339        Poll::Ready(Error::from_status(status))
340    }
341}