// async_acceptor/async_acceptable.rs

use std::future::{Future, poll_fn};
use std::io;
use std::net::{Ipv4Addr, SocketAddr};
use std::task::{Context, Poll, ready};

use tokio::io::{AsyncRead, AsyncWrite};

11pub trait AsyncAcceptable {
13 type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
15
16 fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>>;
18
19 fn poll_accept_with_sockaddr(
21 &self,
22 cx: &mut Context<'_>,
23 ) -> Poll<io::Result<(Self::Stream, SocketAddr)>> {
24 let stream = ready!(self.poll_accept(cx))?;
25 let peer = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0));
26 Poll::Ready(Ok((stream, peer)))
27 }
28}
29
/// [`AsyncAcceptable`] for tokio TCP listeners, reporting the real peer address.
#[cfg(feature = "tokio-net")]
impl AsyncAcceptable for tokio::net::TcpListener {
    type Stream = tokio::net::TcpStream;

    /// Polls tokio's accept and discards the peer address.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>> {
        // Call tokio's *inherent* `TcpListener::poll_accept` by path. A bare
        // `self.poll_accept(cx)` resolves to the same inherent method (inherent
        // methods take precedence over trait methods), but inside this trait
        // method of the same name it reads like unbounded self-recursion.
        tokio::net::TcpListener::poll_accept(self, cx).map_ok(|(stream, _peer)| stream)
    }

    /// Polls tokio's accept and returns the stream with the peer address that
    /// tokio reports, overriding the trait's placeholder default.
    fn poll_accept_with_sockaddr(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<(Self::Stream, SocketAddr)>> {
        // Inherent tokio method, named explicitly for the same reason as above.
        tokio::net::TcpListener::poll_accept(self, cx)
    }
}

/// [`AsyncAcceptable`] for tokio Unix-domain listeners.
///
/// tokio reports Unix peers with its own address type rather than
/// `std::net::SocketAddr`, so `poll_accept_with_sockaddr` is not overridden
/// here and falls back to the trait's default placeholder address.
#[cfg(unix)]
#[cfg(feature = "tokio-net")]
impl AsyncAcceptable for tokio::net::UnixListener {
    type Stream = tokio::net::UnixStream;

    /// Polls tokio's accept and discards the Unix peer address.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>> {
        // Call tokio's *inherent* `UnixListener::poll_accept` by path; a bare
        // `self.poll_accept(cx)` resolves to it only via inherent-method
        // precedence and reads like recursion into this trait method.
        tokio::net::UnixListener::poll_accept(self, cx).map_ok(|(stream, _peer)| stream)
    }
}

58pub trait AsyncAcceptableExt: AsyncAcceptable + Send + Sync {
60 fn accept(&self) -> impl Future<Output = io::Result<Self::Stream>> + Send {
62 poll_fn(|cx| self.poll_accept(cx))
63 }
64
65 fn accept_with_sockaddr(
67 &self,
68 ) -> impl Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send {
69 poll_fn(|cx| self.poll_accept_with_sockaddr(cx))
70 }
71}
72
73impl<T: AsyncAcceptable + Send + Sync> AsyncAcceptableExt for T {}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// A TCP connection accepted through the extension trait carries the
    /// client's bytes, and its reported peer equals the client's local address.
    #[cfg(feature = "tokio-net")]
    #[tokio::test]
    async fn test_async_acceptable() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();

        // Connect from a separate task so this task is free to accept.
        let client = tokio::spawn(async move {
            let mut conn = tokio::net::TcpStream::connect(server_addr).await.unwrap();
            conn.write_all(b"test").await.unwrap();
            conn
        });

        let accepted = AsyncAcceptableExt::accept_with_sockaddr(&listener).await;
        let (mut server_side, peer) = accepted.unwrap();
        let client_side = client.await.unwrap();

        let mut received = [0u8; 4];
        server_side.read_exact(&mut received).await.unwrap();
        assert_eq!(&received, b"test");
        assert_eq!(peer, client_side.local_addr().unwrap());
    }
}