Skip to main content

async_acceptor/
async_acceptable.rs

//! Abstraction `AsyncAcceptable` over tokio listeners with an async `accept()` method.
//
// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
4
5use std::future::poll_fn;
6use std::io;
7use std::net::{Ipv4Addr, SocketAddr};
8use std::task::{Context, Poll, ready};
9use tokio::io::{AsyncRead, AsyncWrite};
10
11/// A Listener that can accept connections asynchronously.
12pub trait AsyncAcceptable {
13    /// The type of stream that will be returned by `accept()`
14    type Stream: AsyncRead + AsyncWrite + Unpin + Send + 'static;
15
16    /// Poll accept a connection asynchronously.
17    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>>;
18
19    /// Poll accept a connection asynchronously, returning the stream and the peer address.
20    fn poll_accept_with_sockaddr(
21        &self,
22        cx: &mut Context<'_>,
23    ) -> Poll<io::Result<(Self::Stream, SocketAddr)>> {
24        let stream = ready!(self.poll_accept(cx))?;
25        let peer = SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0));
26        Poll::Ready(Ok((stream, peer)))
27    }
28}
29
/// TCP listeners accept `TcpStream`s and report the real peer address.
#[cfg(feature = "tokio-net")]
impl AsyncAcceptable for tokio::net::TcpListener {
    type Stream = tokio::net::TcpStream;

    /// Accepts a TCP connection and drops the peer address.
    ///
    /// Calls tokio's inherent `TcpListener::poll_accept` through its full path. Inherent
    /// methods take precedence over trait methods, so this is delegation, not recursion.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>> {
        tokio::net::TcpListener::poll_accept(self, cx).map_ok(|(stream, _peer)| stream)
    }

    /// Accepts a TCP connection, forwarding tokio's `(stream, peer address)` pair unchanged.
    fn poll_accept_with_sockaddr(
        &self,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<(Self::Stream, SocketAddr)>> {
        tokio::net::TcpListener::poll_accept(self, cx)
    }
}
46
/// Unix-domain listeners accept `UnixStream`s.
///
/// This impl does not override `poll_accept_with_sockaddr`, so the trait's default
/// (placeholder address `0.0.0.0:0`) is what callers receive.
#[cfg(unix)]
#[cfg(feature = "tokio-net")]
impl AsyncAcceptable for tokio::net::UnixListener {
    type Stream = tokio::net::UnixStream;

    /// Accepts a Unix-domain connection and discards the address tokio reports for it.
    ///
    /// The full path selects tokio's inherent `UnixListener::poll_accept` rather than
    /// recursing into this trait method.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<io::Result<Self::Stream>> {
        let (stream, _addr) = ready!(tokio::net::UnixListener::poll_accept(self, cx))?;
        Poll::Ready(Ok(stream))
    }
}
57
58/// Extension trait for `AsyncAcceptable` that provides async methods.
59pub trait AsyncAcceptableExt: AsyncAcceptable + Send + Sync {
60    /// Accept a connection asynchronously.
61    fn accept(&self) -> impl Future<Output = io::Result<Self::Stream>> + Send {
62        poll_fn(|cx| self.poll_accept(cx))
63    }
64
65    /// Accept a connection asynchronously, returning the stream and the peer address.
66    fn accept_with_sockaddr(
67        &self,
68    ) -> impl Future<Output = io::Result<(Self::Stream, SocketAddr)>> + Send {
69        poll_fn(|cx| self.poll_accept_with_sockaddr(cx))
70    }
71}
72
73impl<T: AsyncAcceptable + Send + Sync> AsyncAcceptableExt for T {}
74
#[cfg(test)]
mod tests {
    use super::*;

    /// TCP: the accepted stream carries the client's bytes, and the reported peer address
    /// equals the client socket's local address.
    #[cfg(feature = "tokio-net")]
    #[tokio::test]
    async fn test_async_acceptable() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connector_task = tokio::spawn(async move {
            let mut stream = tokio::net::TcpStream::connect(addr).await.unwrap();
            stream.write_all(b"test").await.unwrap();
            stream
        });
        // UFCS: tokio's inherent `accept` would otherwise shadow the extension method.
        let (mut s, a) = AsyncAcceptableExt::accept_with_sockaddr(&listener)
            .await
            .unwrap();
        let stream = connector_task.await.unwrap();
        let mut buf = [0u8; 4];
        s.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"test");
        assert_eq!(a, stream.local_addr().unwrap());
    }

    /// Unix sockets: exercises the `UnixListener` impl and the trait's default
    /// `poll_accept_with_sockaddr`, which must report the `0.0.0.0:0` placeholder.
    #[cfg(unix)]
    #[cfg(feature = "tokio-net")]
    #[tokio::test]
    async fn test_unix_listener_uses_placeholder_sockaddr() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let path = std::env::temp_dir().join(format!(
            "async_acceptable_test_{}.sock",
            std::process::id()
        ));
        // A stale socket file left by an aborted earlier run would make `bind` fail.
        let _ = std::fs::remove_file(&path);
        let listener = tokio::net::UnixListener::bind(&path).unwrap();

        let connect_path = path.clone();
        let connector_task = tokio::spawn(async move {
            let mut stream = tokio::net::UnixStream::connect(connect_path).await.unwrap();
            stream.write_all(b"test").await.unwrap();
            stream
        });
        let (mut s, a) = AsyncAcceptableExt::accept_with_sockaddr(&listener)
            .await
            .unwrap();
        // Keep the client end alive until the server side has read the payload.
        let _client = connector_task.await.unwrap();
        let mut buf = [0u8; 4];
        s.read_exact(&mut buf).await.unwrap();
        let _ = std::fs::remove_file(&path);

        assert_eq!(&buf, b"test");
        assert_eq!(a, SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)));
    }
}