async_ucx/ucp/
listener.rs1use crate::Error;
2
3use super::*;
4use futures::channel::mpsc;
5use futures::stream::StreamExt;
6use std::mem::MaybeUninit;
7use std::net::SocketAddr;
8
9#[derive(Debug)]
11pub struct Listener {
12 handle: ucp_listener_h,
13 #[allow(unused)]
14 sender: Rc<mpsc::UnboundedSender<ConnectionRequest>>,
15 recver: mpsc::UnboundedReceiver<ConnectionRequest>,
16}
17
18#[derive(Debug)]
22#[must_use = "connection must be accepted or rejected"]
23pub struct ConnectionRequest {
24 pub(super) handle: ucp_conn_request_h,
25}
26
27unsafe impl Send for ConnectionRequest {}
29
30impl ConnectionRequest {
31 pub fn remote_addr(&self) -> Result<SocketAddr, Error> {
33 #[allow(clippy::uninit_assumed_init)]
34 let mut attr = ucp_conn_request_attr {
35 field_mask: ucp_conn_request_attr_field::UCP_CONN_REQUEST_ATTR_FIELD_CLIENT_ADDR.0
36 as u64,
37 ..unsafe { MaybeUninit::uninit().assume_init() }
38 };
39 let status = unsafe { ucp_conn_request_query(self.handle, &mut attr) };
40 Error::from_status(status)?;
41
42 let sockaddr =
43 unsafe { socket2::SockAddr::new(std::mem::transmute(attr.client_address), 8) };
44 Ok(sockaddr.as_socket().unwrap())
45 }
46}
47
48impl Listener {
49 pub(super) fn new(worker: &Rc<Worker>, addr: SocketAddr) -> Result<Self, Error> {
50 unsafe extern "C" fn connect_handler(conn_request: ucp_conn_request_h, arg: *mut c_void) {
51 trace!("connect request={:?}", conn_request);
52 let sender = &*(arg as *const mpsc::UnboundedSender<ConnectionRequest>);
53 let connection = ConnectionRequest {
54 handle: conn_request,
55 };
56 sender.unbounded_send(connection).unwrap();
57 }
58 let (sender, recver) = mpsc::unbounded();
59 let sender = Rc::new(sender);
60 let sockaddr = socket2::SockAddr::from(addr);
61 let params = ucp_listener_params_t {
62 field_mask: (ucp_listener_params_field::UCP_LISTENER_PARAM_FIELD_SOCK_ADDR
63 | ucp_listener_params_field::UCP_LISTENER_PARAM_FIELD_CONN_HANDLER)
64 .0 as u64,
65 sockaddr: ucs_sock_addr {
66 addr: sockaddr.as_ptr() as _,
67 addrlen: sockaddr.len(),
68 },
69 accept_handler: ucp_listener_accept_handler_t {
70 cb: None,
71 arg: null_mut(),
72 },
73 conn_handler: ucp_listener_conn_handler_t {
74 cb: Some(connect_handler),
75 arg: &*sender as *const mpsc::UnboundedSender<ConnectionRequest> as _,
76 },
77 };
78 let mut handle = MaybeUninit::uninit();
79 let status = unsafe { ucp_listener_create(worker.handle, ¶ms, handle.as_mut_ptr()) };
80 Error::from_status(status)?;
81 trace!("create listener={:?}", handle);
82 Ok(Listener {
83 handle: unsafe { handle.assume_init() },
84 sender,
85 recver,
86 })
87 }
88
89 pub fn socket_addr(&self) -> Result<SocketAddr, Error> {
91 #[allow(clippy::uninit_assumed_init)]
92 let mut attr = ucp_listener_attr_t {
93 field_mask: ucp_listener_attr_field::UCP_LISTENER_ATTR_FIELD_SOCKADDR.0 as u64,
94 sockaddr: unsafe { MaybeUninit::uninit().assume_init() },
95 };
96 let status = unsafe { ucp_listener_query(self.handle, &mut attr) };
97 Error::from_status(status)?;
98 let sockaddr = unsafe { socket2::SockAddr::new(std::mem::transmute(attr.sockaddr), 8) };
99
100 Ok(sockaddr.as_socket().unwrap())
101 }
102
103 pub async fn next(&mut self) -> ConnectionRequest {
105 self.recver.next().await.unwrap()
106 }
107
108 pub fn reject(&self, conn: ConnectionRequest) -> Result<(), Error> {
110 let status = unsafe { ucp_listener_reject(self.handle, conn.handle) };
111 Error::from_status(status)
112 }
113}
114
115impl Drop for Listener {
116 fn drop(&mut self) {
117 trace!("destroy listener={:?}", self.handle);
118 unsafe { ucp_listener_destroy(self.handle) }
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check across two threads. The server binds to an ephemeral port, publishes
    /// it, and accepts the first request. The client connects to that port on localhost.
    #[test_log::test]
    fn accept() {
        let (port_tx, port_rx) = tokio::sync::oneshot::channel();

        let server = spawn_thread!(async move {
            let context = Context::new().unwrap();
            let worker = context.create_worker().unwrap();
            tokio::task::spawn_local(worker.clone().polling());

            // Port 0 lets the OS pick a free port; report it to the client thread.
            let mut listener = worker
                .create_listener("0.0.0.0:0".parse().unwrap())
                .unwrap();
            let port = listener.socket_addr().unwrap().port();
            port_tx.send(port).unwrap();

            let request = listener.next().await;
            let _endpoint = worker.accept(request).await.unwrap();
        });

        spawn_thread!(async move {
            let context = Context::new().unwrap();
            let worker = context.create_worker().unwrap();
            tokio::task::spawn_local(worker.clone().polling());

            let port = port_rx.await.unwrap();
            let server_addr = SocketAddr::from(([127, 0, 0, 1], port));
            let _endpoint = worker.connect_socket(server_addr).await.unwrap();
        });

        server.join().unwrap();
    }
}