1use crate::tls::TlsStream;
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::net::TcpStream;
12
13pub trait AsyncStream: AsyncRead + AsyncWrite + Unpin + Send {}
19
20impl<T> AsyncStream for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
22
23#[derive(Debug)]
29pub enum ConnectionStream {
30 Plain(TcpStream),
32 Tls(Box<TlsStream<TcpStream>>),
34}
35
36impl ConnectionStream {
37 pub fn plain(stream: TcpStream) -> Self {
39 Self::Plain(stream)
40 }
41
42 pub fn tls(stream: TlsStream<TcpStream>) -> Self {
44 Self::Tls(Box::new(stream))
45 }
46
47 #[must_use]
49 pub fn connection_type(&self) -> &'static str {
50 match self {
51 Self::Plain(_) => "TCP",
52 Self::Tls(_) => "TLS",
53 }
54 }
55
56 #[inline]
58 #[must_use]
59 pub const fn is_encrypted(&self) -> bool {
60 matches!(self, Self::Tls(_))
61 }
62
63 #[inline]
65 #[must_use]
66 pub const fn is_unencrypted(&self) -> bool {
67 matches!(self, Self::Plain(_))
68 }
69
70 #[must_use]
75 pub fn as_tcp_stream(&self) -> Option<&TcpStream> {
76 match self {
77 Self::Plain(tcp) => Some(tcp),
78 Self::Tls(_) => None,
79 }
80 }
81
82 pub fn as_tcp_stream_mut(&mut self) -> Option<&mut TcpStream> {
84 match self {
85 Self::Plain(tcp) => Some(tcp),
86 Self::Tls(_) => None,
87 }
88 }
89
90 #[must_use]
92 pub fn as_tls_stream(&self) -> Option<&TlsStream<TcpStream>> {
93 match self {
94 Self::Tls(tls) => Some(tls.as_ref()),
95 Self::Plain(_) => None,
96 }
97 }
98
99 pub fn as_tls_stream_mut(&mut self) -> Option<&mut TlsStream<TcpStream>> {
101 match self {
102 Self::Tls(tls) => Some(tls.as_mut()),
103 Self::Plain(_) => None,
104 }
105 }
106
107 #[must_use]
112 pub fn underlying_tcp_stream(&self) -> &TcpStream {
113 match self {
114 Self::Plain(tcp) => tcp,
115 Self::Tls(tls) => tls.get_ref().0,
116 }
117 }
118}
119
120impl AsyncRead for ConnectionStream {
121 fn poll_read(
122 mut self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 buf: &mut ReadBuf<'_>,
125 ) -> Poll<io::Result<()>> {
126 match &mut *self {
127 Self::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
128 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_read(cx, buf),
129 }
130 }
131}
132
133impl AsyncWrite for ConnectionStream {
134 fn poll_write(
135 mut self: Pin<&mut Self>,
136 cx: &mut Context<'_>,
137 buf: &[u8],
138 ) -> Poll<io::Result<usize>> {
139 match &mut *self {
140 Self::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
141 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_write(cx, buf),
142 }
143 }
144
145 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
146 match &mut *self {
147 Self::Plain(stream) => Pin::new(stream).poll_flush(cx),
148 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_flush(cx),
149 }
150 }
151
152 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
153 match &mut *self {
154 Self::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
155 Self::Tls(stream) => Pin::new(stream.as_mut()).poll_shutdown(cx),
156 }
157 }
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Opens a loopback TCP connection and wraps the client side as a plain
    /// [`ConnectionStream`].
    ///
    /// The listener is returned so the caller can keep it alive for the whole
    /// test. Must run inside a Tokio runtime, because `TcpStream::from_std`
    /// registers the socket with the runtime's I/O driver.
    fn plain_client_connection() -> (std::net::TcpListener, ConnectionStream) {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();

        let tcp = std::net::TcpStream::connect(addr).unwrap();
        tcp.set_nonblocking(true).unwrap();
        let stream = TcpStream::from_std(tcp).unwrap();

        (listener, ConnectionStream::plain(stream))
    }

    #[tokio::test]
    async fn test_connection_stream_plain_tcp() {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let client_handle = tokio::spawn(async move { TcpStream::connect(addr).await.unwrap() });

        let (server_stream, _) = listener.accept().await.unwrap();
        let client_stream = client_handle.await.unwrap();

        let mut server_conn = ConnectionStream::plain(server_stream);
        let mut client_conn = ConnectionStream::plain(client_stream);

        client_conn.write_all(b"Hello").await.unwrap();

        let mut buf = [0u8; 5];
        server_conn.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"Hello");

        assert!(client_conn.is_unencrypted());
        assert!(!client_conn.is_encrypted());
        assert_eq!(client_conn.connection_type(), "TCP");
        assert!(client_conn.as_tcp_stream().is_some());
    }

    #[test]
    fn test_async_stream_trait() {
        fn assert_async_stream<T: AsyncStream>() {}

        assert_async_stream::<TcpStream>();
        assert_async_stream::<ConnectionStream>();
    }

    #[tokio::test]
    async fn test_connection_stream_tcp_access() {
        let (listener, mut conn_stream) = plain_client_connection();
        let addr = listener.local_addr().unwrap();

        assert!(conn_stream.as_tcp_stream().is_some());
        assert!(conn_stream.as_tcp_stream_mut().is_some());

        assert!(conn_stream.is_unencrypted());
        assert!(!conn_stream.is_encrypted());
        assert_eq!(conn_stream.connection_type(), "TCP");

        // The underlying socket must be the one connected to our listener.
        assert_eq!(
            conn_stream.underlying_tcp_stream().peer_addr().unwrap(),
            addr
        );
    }

    #[tokio::test]
    async fn test_connection_type_methods() {
        let (_listener, conn) = plain_client_connection();

        assert!(conn.is_unencrypted(), "Plain TCP should be unencrypted");
        assert!(!conn.is_encrypted(), "Plain TCP should not be encrypted");
        assert_eq!(conn.connection_type(), "TCP");
    }

    // NOTE(review): no test here builds a `Tls` variant. This test only checks
    // that the TLS accessors return `None` for a plain connection.
    #[tokio::test]
    async fn test_plain_connection_exposes_no_tls_stream() {
        let (_listener, conn) = plain_client_connection();

        assert!(conn.is_unencrypted());
        assert!(!conn.is_encrypted());
        assert_eq!(conn.connection_type(), "TCP");

        assert!(conn.as_tcp_stream().is_some());
        assert!(conn.as_tls_stream().is_none());
    }

    #[tokio::test]
    async fn test_plain_tcp_stream_accessors() {
        let (listener, mut conn) = plain_client_connection();
        let addr = listener.local_addr().unwrap();

        assert!(conn.as_tcp_stream().is_some());
        assert!(conn.as_tls_stream().is_none());

        assert!(conn.as_tcp_stream_mut().is_some());
        assert!(conn.as_tls_stream_mut().is_none());

        assert_eq!(conn.underlying_tcp_stream().peer_addr().unwrap(), addr);
    }

    #[tokio::test]
    async fn test_connection_stream_enum_variants() {
        let (_listener, plain_conn) = plain_client_connection();

        assert!(
            matches!(plain_conn, ConnectionStream::Plain(_)),
            "Should be Plain, not Tls"
        );
    }

    #[tokio::test]
    async fn test_connection_type_string_values() {
        let (_listener, conn) = plain_client_connection();

        assert_eq!(conn.connection_type(), "TCP");

        let type_str = conn.connection_type();
        assert!(!type_str.is_empty());
        assert!(type_str.len() < 10);
    }

    #[tokio::test]
    async fn test_plain_connection_debug_format() {
        let (_listener, conn) = plain_client_connection();

        let debug_str = format!("{:?}", conn);
        assert!(
            debug_str.contains("Plain"),
            "Debug output should indicate Plain TCP"
        );
    }
}