// async_stream_connection/stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use tokio::io::{AsyncRead, AsyncWrite, Error, ReadBuf};
4use tokio::net::TcpStream;
5#[cfg(unix)]
6use tokio::net::UnixStream;
7
8use std::io;
9
10use crate::Addr;
11
12/// A socket connected to an endpoint
13#[derive(Debug)]
14pub enum Stream {
15    /// A TCP stream between a local and a remote socket.
16    Inet(TcpStream),
17    #[cfg(unix)]
18    /// A connected Unix socket
19    Unix(UnixStream),
20}
21
22impl From<TcpStream> for Stream {
23    fn from(s: TcpStream) -> Stream {
24        Stream::Inet(s)
25    }
26}
27
28#[cfg(unix)]
29impl From<UnixStream> for Stream {
30    fn from(s: UnixStream) -> Stream {
31        Stream::Unix(s)
32    }
33}
34
35impl Stream {
36    /// Opens a connection to a remote host.
37    pub async fn connect(s: &Addr) -> io::Result<Stream> {
38        match s {
39            Addr::Inet(s) => TcpStream::connect(s).await.map(Stream::Inet),
40            #[cfg(unix)]
41            Addr::Unix(s) => UnixStream::connect(s).await.map(Stream::Unix),
42        }
43    }
44
45    /// Returns the local address that this stream is bound to.
46    pub fn local_addr(&self) -> io::Result<Addr> {
47        match self {
48            Stream::Inet(s) => s.local_addr().map(Addr::Inet),
49            #[cfg(unix)]
50            Stream::Unix(s) => s.local_addr().map(|e| e.into()),
51        }
52    }
53
54    /// Returns the remote address that this stream is connected to.
55    pub fn peer_addr(&self) -> io::Result<Addr> {
56        match self {
57            Stream::Inet(s) => s.peer_addr().map(Addr::Inet),
58            #[cfg(unix)]
59            Stream::Unix(s) => s.peer_addr().map(|e| e.into()),
60        }
61    }
62}
63impl AsyncRead for Stream {
64    fn poll_read(
65        mut self: Pin<&mut Self>,
66        cx: &mut Context,
67        buf: &mut ReadBuf<'_>,
68    ) -> Poll<Result<(), Error>> {
69        match &mut *self {
70            Stream::Inet(s) => Pin::new(s).as_mut().poll_read(cx, buf),
71            #[cfg(unix)]
72            Stream::Unix(s) => Pin::new(s).as_mut().poll_read(cx, buf),
73        }
74    }
75}
76impl AsyncWrite for Stream {
77    fn poll_write(
78        mut self: Pin<&mut Self>,
79        cx: &mut Context,
80        buf: &[u8],
81    ) -> Poll<Result<usize, Error>> {
82        match &mut *self {
83            Stream::Inet(s) => Pin::new(s).as_mut().poll_write(cx, buf),
84            #[cfg(unix)]
85            Stream::Unix(s) => Pin::new(s).as_mut().poll_write(cx, buf),
86        }
87    }
88
89    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
90        match &mut *self {
91            Stream::Inet(s) => Pin::new(s).as_mut().poll_flush(cx),
92            #[cfg(unix)]
93            Stream::Unix(s) => Pin::new(s).as_mut().poll_flush(cx),
94        }
95    }
96
97    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
98        match &mut *self {
99            Stream::Inet(s) => Pin::new(s).as_mut().poll_shutdown(cx),
100            #[cfg(unix)]
101            Stream::Unix(s) => Pin::new(s).as_mut().poll_shutdown(cx),
102        }
103    }
104}
105
106
107
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::runtime::Builder;

    /// Binds a TCP listener to an OS-assigned port on 127.0.0.1 and returns it
    /// together with its bound address converted to an [`Addr`].
    pub(crate) async fn local_socket_pair() -> Result<(TcpListener, Addr), std::io::Error> {
        let a: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let app_listener = TcpListener::bind(a).await?;
        let a: Addr = app_listener.local_addr()?.into();
        Ok((app_listener, a))
    }

    #[test]
    fn tcp_connect() {
        let rt = Builder::new_current_thread().enable_all().build().unwrap();

        // Accepts one connection and echoes back the bytes of a single read.
        async fn mock_app(app_listener: TcpListener) {
            let (mut app_socket, _) = app_listener.accept().await.unwrap();
            let mut buf = [0u8; 32];
            let i = app_socket.read(&mut buf).await.unwrap();
            app_socket.write_all(&buf[..i]).await.unwrap();
        }

        async fn con() {
            let (app_listener, a) = local_socket_pair().await.unwrap();
            tokio::spawn(mock_app(app_listener));

            let mut s = Stream::connect(&a).await.expect("tcp connect failed");

            let data = b"1234";
            s.write_all(&data[..]).await.expect("tcp write failed");

            let mut buf = [0u8; 32];
            let i = s.read(&mut buf).await.expect("tcp read failed");
            assert_eq!(&buf[..i], &data[..]);
        }

        rt.block_on(con());
    }

    #[cfg(unix)]
    #[test]
    fn unix_connect() {
        // Removes the socket file when dropped, so it is cleaned up even if
        // the test body panics.
        struct RemoveOnDrop<'a>(&'a str);

        impl Drop for RemoveOnDrop<'_> {
            fn drop(&mut self) {
                // Best-effort: the file may legitimately be gone already.
                let _ = std::fs::remove_file(self.0);
            }
        }

        let rt = Builder::new_current_thread().enable_all().build().unwrap();

        // Accepts one connection and echoes back the bytes of a single read.
        async fn mock_app(app_listener: UnixListener) {
            let (mut app_socket, _) = app_listener.accept().await.unwrap();
            let mut buf = [0u8; 32];
            let i = app_socket.read(&mut buf).await.unwrap();
            app_socket.write_all(&buf[..i]).await.unwrap();
        }

        async fn con(sock_path: &str) {
            let app_listener = UnixListener::bind(sock_path).unwrap();
            tokio::spawn(mock_app(app_listener));

            let a: Addr = sock_path.parse().expect("unix parse failed");
            let mut s = Stream::connect(&a).await.expect("unix connect failed");

            let data = b"1234";
            s.write_all(&data[..]).await.expect("unix write failed");

            let mut buf = [0u8; 32];
            let i = s.read(&mut buf).await.expect("unix read failed");
            assert_eq!(&buf[..i], &data[..]);
        }

        // A per-process path keeps concurrently running test binaries from
        // fighting over one socket file.
        let sock_path = format!("/tmp/afcgi-{}.sock", std::process::id());
        // A socket file left behind by an aborted run would make `bind` fail
        // with "address in use", so clear it before binding.
        let _ = std::fs::remove_file(&sock_path);
        let _cleanup = RemoveOnDrop(&sock_path);

        rt.block_on(con(&sock_path));
    }
}