1use std::io;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio::net::TcpStream;
11use tokio_rustls::client::TlsStream;
12
13pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
19
20impl<T> AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
22
/// A byte stream that is either plain TCP or a client-side TLS session over TCP.
///
/// The `AsyncRead`/`AsyncWrite` impls for this type dispatch to whichever
/// variant is active, so callers can use either transport through one type.
#[derive(Debug)]
pub enum ConnectionStream {
    /// Unencrypted TCP stream.
    Plain(TcpStream),
    /// TLS session (tokio-rustls client side) layered over a TCP stream.
    /// NOTE(review): presumably boxed to keep the enum's size close to that of
    /// `TcpStream`, since the TLS state is much larger — confirm.
    Tls(Box<TlsStream<TcpStream>>),
}
35
36impl ConnectionStream {
37 pub fn plain(stream: TcpStream) -> Self {
39 Self::Plain(stream)
40 }
41
42 pub fn tls(stream: TlsStream<TcpStream>) -> Self {
44 Self::Tls(Box::new(stream))
45 }
46
47 pub fn connection_type(&self) -> &'static str {
49 match self {
50 Self::Plain(_) => "TCP",
51 Self::Tls(_) => "TLS",
52 }
53 }
54
55 pub fn is_encrypted(&self) -> bool {
57 matches!(self, Self::Tls(_))
58 }
59
60 pub fn is_unencrypted(&self) -> bool {
62 matches!(self, Self::Plain(_))
63 }
64
65 pub fn as_tcp_stream(&self) -> Option<&TcpStream> {
70 match self {
71 Self::Plain(tcp) => Some(tcp),
72 Self::Tls(_) => None,
73 }
74 }
75
76 pub fn as_tcp_stream_mut(&mut self) -> Option<&mut TcpStream> {
78 match self {
79 Self::Plain(tcp) => Some(tcp),
80 Self::Tls(_) => None,
81 }
82 }
83
84 pub fn as_tls_stream(&self) -> Option<&TlsStream<TcpStream>> {
86 match self {
87 Self::Tls(tls) => Some(tls.as_ref()),
88 Self::Plain(_) => None,
89 }
90 }
91
92 pub fn as_tls_stream_mut(&mut self) -> Option<&mut TlsStream<TcpStream>> {
94 match self {
95 Self::Tls(tls) => Some(tls.as_mut()),
96 Self::Plain(_) => None,
97 }
98 }
99
100 pub fn underlying_tcp_stream(&self) -> &TcpStream {
105 match self {
106 Self::Plain(tcp) => tcp,
107 Self::Tls(tls) => tls.get_ref().0,
108 }
109 }
110}
111
112impl AsyncRead for ConnectionStream {
113 fn poll_read(
114 mut self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 buf: &mut ReadBuf<'_>,
117 ) -> Poll<io::Result<()>> {
118 match &mut *self {
119 Self::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
120 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_read(cx, buf),
121 }
122 }
123}
124
125impl AsyncWrite for ConnectionStream {
126 fn poll_write(
127 mut self: Pin<&mut Self>,
128 cx: &mut Context<'_>,
129 buf: &[u8],
130 ) -> Poll<io::Result<usize>> {
131 match &mut *self {
132 Self::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
133 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_write(cx, buf),
134 }
135 }
136
137 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
138 match &mut *self {
139 Self::Plain(stream) => Pin::new(stream).poll_flush(cx),
140 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_flush(cx),
141 }
142 }
143
144 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
145 match &mut *self {
146 Self::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
147 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_shutdown(cx),
148 }
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Returns a connected `(server, client)` pair of loopback TCP streams.
    ///
    /// Accept and connect are driven concurrently with `tokio::join!` on Tokio's
    /// non-blocking sockets, so no blocking `std::net` call runs on the runtime.
    async fn connected_pair() -> (TcpStream, TcpStream) {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let (accepted, connected) = tokio::join!(listener.accept(), TcpStream::connect(addr));
        let (server, _) = accepted.unwrap();
        (server, connected.unwrap())
    }

    #[tokio::test]
    async fn test_connection_stream_plain_tcp() {
        let (server_stream, client_stream) = connected_pair().await;
        let mut server_conn = ConnectionStream::plain(server_stream);
        let mut client_conn = ConnectionStream::plain(client_stream);

        client_conn.write_all(b"Hello").await.unwrap();
        client_conn.flush().await.unwrap();

        let mut buf = [0u8; 5];
        server_conn.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"Hello");

        // Shutting down the client's write half must surface as EOF on the peer,
        // which exercises the `poll_shutdown` forwarding.
        client_conn.shutdown().await.unwrap();
        assert_eq!(server_conn.read(&mut buf).await.unwrap(), 0);

        assert!(client_conn.is_unencrypted());
        assert!(!client_conn.is_encrypted());
        assert_eq!(client_conn.connection_type(), "TCP");
        assert!(client_conn.as_tcp_stream().is_some());
    }

    #[test]
    fn test_async_stream_trait() {
        fn assert_async_stream<T: AsyncStream>() {}
        assert_async_stream::<TcpStream>();
        assert_async_stream::<ConnectionStream>();
    }

    #[tokio::test]
    async fn test_connection_stream_tcp_access() {
        // Keep the server end alive so the client socket stays connected.
        let (_server, client) = connected_pair().await;
        let client_local = client.local_addr().unwrap();

        let mut conn_stream = ConnectionStream::plain(client);

        assert!(conn_stream.as_tcp_stream().is_some());
        assert!(conn_stream.as_tcp_stream_mut().is_some());
        assert!(conn_stream.as_tls_stream().is_none());
        assert!(conn_stream.as_tls_stream_mut().is_none());

        assert!(conn_stream.is_unencrypted());
        assert!(!conn_stream.is_encrypted());
        assert_eq!(conn_stream.connection_type(), "TCP");

        // The underlying socket must be the very stream that was wrapped.
        let underlying = conn_stream.underlying_tcp_stream();
        assert_eq!(underlying.local_addr().unwrap(), client_local);
    }

    #[tokio::test]
    async fn test_connection_type_methods() {
        let (_server, client) = connected_pair().await;

        let conn = ConnectionStream::plain(client);

        assert!(conn.is_unencrypted(), "Plain TCP should be unencrypted");
        assert!(!conn.is_encrypted(), "Plain TCP should not be encrypted");
        assert_eq!(conn.connection_type(), "TCP");
    }
}