async_ucx/ucp/
listener.rs

1use crate::Error;
2
3use super::*;
4use futures::channel::mpsc;
5use futures::stream::StreamExt;
6use std::mem::MaybeUninit;
7use std::net::SocketAddr;
8
9/// Listening on a specific address and accepting connections from clients.
10#[derive(Debug)]
11pub struct Listener {
12    handle: ucp_listener_h,
13    #[allow(unused)]
14    sender: Rc<mpsc::UnboundedSender<ConnectionRequest>>,
15    recver: mpsc::UnboundedReceiver<ConnectionRequest>,
16}
17
18/// An incoming connection request.
19///
20/// The request must be explicitly accepted by [Worker::accept] or rejected by [Listener::reject].
21#[derive(Debug)]
22#[must_use = "connection must be accepted or rejected"]
23pub struct ConnectionRequest {
24    pub(super) handle: ucp_conn_request_h,
25}
26
27// connection can be send to other thread and accepted on its worker
28unsafe impl Send for ConnectionRequest {}
29
30impl ConnectionRequest {
31    /// The address of the remote client that sent the connection request to the server.
32    pub fn remote_addr(&self) -> Result<SocketAddr, Error> {
33        #[allow(clippy::uninit_assumed_init)]
34        let mut attr = ucp_conn_request_attr {
35            field_mask: ucp_conn_request_attr_field::UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR.0
36                as u64,
37            ..unsafe { MaybeUninit::uninit().assume_init() }
38        };
39        let status = unsafe { ucp_conn_request_query(self.handle, &mut attr) };
40        Error::from_status(status)?;
41
42        let sockaddr =
43            unsafe { socket2::SockAddr::new(std::mem::transmute(attr.client_address), 8) };
44        Ok(sockaddr.as_socket().unwrap())
45    }
46}
47
48impl Listener {
49    pub(super) fn new(worker: &Rc<Worker>, addr: SocketAddr) -> Result<Self, Error> {
50        unsafe extern "C" fn connect_handler(conn_request: ucp_conn_request_h, arg: *mut c_void) {
51            trace!("connect request={:?}", conn_request);
52            let sender = &*(arg as *const mpsc::UnboundedSender<ConnectionRequest>);
53            let connection = ConnectionRequest {
54                handle: conn_request,
55            };
56            sender.unbounded_send(connection).unwrap();
57        }
58        let (sender, recver) = mpsc::unbounded();
59        let sender = Rc::new(sender);
60        let sockaddr = socket2::SockAddr::from(addr);
61        let params = ucp_listener_params_t {
62            field_mask: (ucp_listener_params_field::UCP_LISTENER_PARAM_FIELD_SOCK_ADDR
63                | ucp_listener_params_field::UCP_LISTENER_PARAM_FIELD_CONN_HANDLER)
64                .0 as u64,
65            sockaddr: ucs_sock_addr {
66                addr: sockaddr.as_ptr() as _,
67                addrlen: sockaddr.len(),
68            },
69            accept_handler: ucp_listener_accept_handler_t {
70                cb: None,
71                arg: null_mut(),
72            },
73            conn_handler: ucp_listener_conn_handler_t {
74                cb: Some(connect_handler),
75                arg: &*sender as *const mpsc::UnboundedSender<ConnectionRequest> as _,
76            },
77        };
78        let mut handle = MaybeUninit::uninit();
79        let status = unsafe { ucp_listener_create(worker.handle, &params, handle.as_mut_ptr()) };
80        Error::from_status(status)?;
81        trace!("create listener={:?}", handle);
82        Ok(Listener {
83            handle: unsafe { handle.assume_init() },
84            sender,
85            recver,
86        })
87    }
88
89    /// Returns the local socket address of this listener.
90    pub fn socket_addr(&self) -> Result<SocketAddr, Error> {
91        #[allow(clippy::uninit_assumed_init)]
92        let mut attr = ucp_listener_attr_t {
93            field_mask: ucp_listener_attr_field::UCP_LISTENER_ATTR_FIELD_SOCKADDR.0 as u64,
94            sockaddr: unsafe { MaybeUninit::uninit().assume_init() },
95        };
96        let status = unsafe { ucp_listener_query(self.handle, &mut attr) };
97        Error::from_status(status)?;
98        let sockaddr = unsafe { socket2::SockAddr::new(std::mem::transmute(attr.sockaddr), 8) };
99
100        Ok(sockaddr.as_socket().unwrap())
101    }
102
103    /// Waiting for the next connection request.
104    pub async fn next(&mut self) -> ConnectionRequest {
105        self.recver.next().await.unwrap()
106    }
107
108    /// Reject a connection.
109    pub fn reject(&self, conn: ConnectionRequest) -> Result<(), Error> {
110        let status = unsafe { ucp_listener_reject(self.handle, conn.handle) };
111        Error::from_status(status)
112    }
113}
114
115impl Drop for Listener {
116    fn drop(&mut self) {
117        trace!("destroy listener={:?}", self.handle);
118        unsafe { ucp_listener_destroy(self.handle) }
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check: a client on one thread connects to a listener on another thread,
    /// and the server accepts the resulting connection request.
    #[test_log::test]
    fn accept() {
        let (sender, recver) = tokio::sync::oneshot::channel();
        let server = spawn_thread!(async move {
            let context = Context::new().unwrap();
            let worker = context.create_worker().unwrap();
            tokio::task::spawn_local(worker.clone().polling());
            let mut listener = worker
                .create_listener("0.0.0.0:0".parse().unwrap())
                .unwrap();
            let listen_port = listener.socket_addr().unwrap().port();
            sender.send(listen_port).unwrap();
            let conn = listener.next().await;
            let _endpoint = worker.accept(conn).await.unwrap();
        });
        let client = spawn_thread!(async move {
            let context = Context::new().unwrap();
            let worker = context.create_worker().unwrap();
            tokio::task::spawn_local(worker.clone().polling());
            let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
            let listen_port = recver.await.unwrap();
            addr.set_port(listen_port);
            let _endpoint = worker.connect_socket(addr).await.unwrap();
        });
        // Join both threads so a client-side failure fails the test. Before, only the server
        // was joined, so client panics were silently ignored. The client is joined first: if
        // it fails before connecting, the server would block forever in `listener.next()`.
        client.join().unwrap();
        server.join().unwrap();
    }
}