use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};

use tokio::io::{AsyncRead, AsyncWrite, Error, ReadBuf};
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;

use crate::Addr;

12#[derive(Debug)]
14pub enum Stream {
15 Inet(TcpStream),
17 #[cfg(unix)]
18 Unix(UnixStream),
20}
21
22impl From<TcpStream> for Stream {
23 fn from(s: TcpStream) -> Stream {
24 Stream::Inet(s)
25 }
26}
27
28#[cfg(unix)]
29impl From<UnixStream> for Stream {
30 fn from(s: UnixStream) -> Stream {
31 Stream::Unix(s)
32 }
33}
34
35impl Stream {
36 pub async fn connect(s: &Addr) -> io::Result<Stream> {
38 match s {
39 Addr::Inet(s) => TcpStream::connect(s).await.map(Stream::Inet),
40 #[cfg(unix)]
41 Addr::Unix(s) => UnixStream::connect(s).await.map(Stream::Unix),
42 }
43 }
44
45 pub fn local_addr(&self) -> io::Result<Addr> {
47 match self {
48 Stream::Inet(s) => s.local_addr().map(Addr::Inet),
49 #[cfg(unix)]
50 Stream::Unix(s) => s.local_addr().map(|e| e.into()),
51 }
52 }
53
54 pub fn peer_addr(&self) -> io::Result<Addr> {
56 match self {
57 Stream::Inet(s) => s.peer_addr().map(Addr::Inet),
58 #[cfg(unix)]
59 Stream::Unix(s) => s.peer_addr().map(|e| e.into()),
60 }
61 }
62}
63impl AsyncRead for Stream {
64 fn poll_read(
65 mut self: Pin<&mut Self>,
66 cx: &mut Context,
67 buf: &mut ReadBuf<'_>,
68 ) -> Poll<Result<(), Error>> {
69 match &mut *self {
70 Stream::Inet(s) => Pin::new(s).as_mut().poll_read(cx, buf),
71 #[cfg(unix)]
72 Stream::Unix(s) => Pin::new(s).as_mut().poll_read(cx, buf),
73 }
74 }
75}
76impl AsyncWrite for Stream {
77 fn poll_write(
78 mut self: Pin<&mut Self>,
79 cx: &mut Context,
80 buf: &[u8],
81 ) -> Poll<Result<usize, Error>> {
82 match &mut *self {
83 Stream::Inet(s) => Pin::new(s).as_mut().poll_write(cx, buf),
84 #[cfg(unix)]
85 Stream::Unix(s) => Pin::new(s).as_mut().poll_write(cx, buf),
86 }
87 }
88
89 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
90 match &mut *self {
91 Stream::Inet(s) => Pin::new(s).as_mut().poll_flush(cx),
92 #[cfg(unix)]
93 Stream::Unix(s) => Pin::new(s).as_mut().poll_flush(cx),
94 }
95 }
96
97 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Error>> {
98 match &mut *self {
99 Stream::Inet(s) => Pin::new(s).as_mut().poll_shutdown(cx),
100 #[cfg(unix)]
101 Stream::Unix(s) => Pin::new(s).as_mut().poll_shutdown(cx),
102 }
103 }
104}
105
106
107
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;
    #[cfg(unix)]
    use tokio::net::UnixListener;
    use tokio::runtime::Builder;

    /// Binds a TCP listener on an ephemeral loopback port and returns it together
    /// with its bound address converted to an [`Addr`].
    pub(crate) async fn local_socket_pair() -> Result<(TcpListener, Addr), std::io::Error> {
        let a: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let app_listener = TcpListener::bind(a).await?;
        let a: Addr = app_listener.local_addr()?.into();
        Ok((app_listener, a))
    }

    #[test]
    fn tcp_connect() {
        let rt = Builder::new_current_thread().enable_all().build().unwrap();

        // Echo server: accepts one connection, reads one chunk, writes it back.
        async fn mock_app(app_listener: TcpListener) {
            let (mut app_socket, _) = app_listener.accept().await.unwrap();
            let mut buf = [0u8; 32];
            let i = app_socket.read(&mut buf).await.unwrap();
            app_socket.write_all(&buf[..i]).await.unwrap();
        }

        async fn con() {
            let (app_listener, a) = local_socket_pair().await.unwrap();
            let app = tokio::spawn(mock_app(app_listener));

            let mut s = Stream::connect(&a).await.expect("tcp connect failed");

            let data = b"1234";
            s.write_all(&data[..]).await.expect("tcp write failed");

            let mut buf = [0u8; 32];
            let i = s.read(&mut buf).await.expect("tcp read failed");
            assert_eq!(&buf[..i], &data[..]);

            // `tokio::spawn` swallows panics; awaiting the handle surfaces them.
            app.await.expect("mock app failed");
        }
        rt.block_on(con());
    }

    /// Deletes a socket file when dropped, so a failing test does not leave a
    /// stale file behind that would make the next run's `bind` fail.
    #[cfg(unix)]
    struct RemoveOnDrop(String);

    #[cfg(unix)]
    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best effort: the file may already be gone.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[cfg(unix)]
    #[test]
    fn unix_connect() {
        let rt = Builder::new_current_thread().enable_all().build().unwrap();

        // Echo server: accepts one connection, reads one chunk, writes it back.
        async fn mock_app(app_listener: UnixListener) {
            let (mut app_socket, _) = app_listener.accept().await.unwrap();
            let mut buf = [0u8; 32];
            let i = app_socket.read(&mut buf).await.unwrap();
            app_socket.write_all(&buf[..i]).await.unwrap();
        }

        async fn con(path: &str) {
            let app_listener = UnixListener::bind(path).unwrap();
            let app = tokio::spawn(mock_app(app_listener));

            let a: Addr = path.parse().expect("unix parse failed");
            let mut s = Stream::connect(&a).await.expect("unix connect failed");

            let data = b"1234";
            s.write_all(&data[..]).await.expect("unix write failed");

            let mut buf = [0u8; 32];
            let i = s.read(&mut buf).await.expect("unix read failed");
            assert_eq!(&buf[..i], &data[..]);

            // `tokio::spawn` swallows panics; awaiting the handle surfaces them.
            app.await.expect("mock app failed");
        }

        // A per-process path keeps concurrent test runs from fighting over one file.
        let path = format!("/tmp/afcgi-{}.sock", std::process::id());
        // Clear any leftover from an earlier aborted run that reused this pid.
        let _ = std::fs::remove_file(&path);
        let _cleanup = RemoveOnDrop(path.clone());
        rt.block_on(con(&path));
    }
}