rustygear/
wrappedstream.rs

1use std::pin::Pin;
2
3use tokio::{
4    io::{AsyncRead, AsyncWrite},
5    net::TcpStream,
6};
7#[cfg(feature = "tls")]
8use tokio_rustls::{client, server};
9
10#[derive(Debug)]
11pub enum WrappedStream {
12    #[cfg(feature = "tls")]
13    ClientTls(Box<client::TlsStream<TcpStream>>),
14    #[cfg(feature = "tls")]
15    ServerTls(Box<server::TlsStream<TcpStream>>),
16    Plain(TcpStream),
17}
18
19impl Unpin for WrappedStream {}
20
#[cfg(feature = "tls")]
impl From<client::TlsStream<TcpStream>> for WrappedStream {
    /// Boxes a client-side TLS stream and wraps it as [`WrappedStream::ClientTls`].
    fn from(tls: client::TlsStream<TcpStream>) -> Self {
        let boxed = Box::new(tls);
        Self::ClientTls(boxed)
    }
}
27
#[cfg(feature = "tls")]
impl From<server::TlsStream<TcpStream>> for WrappedStream {
    /// Boxes a server-side TLS stream and wraps it as [`WrappedStream::ServerTls`].
    fn from(tls: server::TlsStream<TcpStream>) -> Self {
        let boxed = Box::new(tls);
        Self::ServerTls(boxed)
    }
}
34
35impl From<TcpStream> for WrappedStream {
36    fn from(value: TcpStream) -> Self {
37        WrappedStream::Plain(value)
38    }
39}
40
41impl AsyncRead for WrappedStream {
42    fn poll_read(
43        mut self: std::pin::Pin<&mut Self>,
44        cx: &mut std::task::Context<'_>,
45        buf: &mut tokio::io::ReadBuf<'_>,
46    ) -> std::task::Poll<std::io::Result<()>> {
47        match &mut *self {
48            #[cfg(feature = "tls")]
49            WrappedStream::ClientTls(stream) => Pin::new(stream).poll_read(cx, buf),
50            #[cfg(feature = "tls")]
51            WrappedStream::ServerTls(stream) => Pin::new(stream).poll_read(cx, buf),
52            WrappedStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
53        }
54    }
55}
56
57impl AsyncWrite for WrappedStream {
58    fn poll_write(
59        mut self: std::pin::Pin<&mut Self>,
60        cx: &mut std::task::Context<'_>,
61        buf: &[u8],
62    ) -> std::task::Poll<Result<usize, std::io::Error>> {
63        match &mut *self {
64            WrappedStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
65            #[cfg(feature = "tls")]
66            WrappedStream::ClientTls(stream) => Pin::new(stream).poll_write(cx, buf),
67            #[cfg(feature = "tls")]
68            WrappedStream::ServerTls(stream) => Pin::new(stream).poll_write(cx, buf),
69        }
70    }
71
72    fn poll_flush(
73        mut self: std::pin::Pin<&mut Self>,
74        cx: &mut std::task::Context<'_>,
75    ) -> std::task::Poll<Result<(), std::io::Error>> {
76        match &mut *self {
77            WrappedStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
78            #[cfg(feature = "tls")]
79            WrappedStream::ClientTls(stream) => Pin::new(stream).poll_flush(cx),
80            #[cfg(feature = "tls")]
81            WrappedStream::ServerTls(stream) => Pin::new(stream).poll_flush(cx),
82        }
83    }
84
85    fn poll_shutdown(
86        mut self: std::pin::Pin<&mut Self>,
87        cx: &mut std::task::Context<'_>,
88    ) -> std::task::Poll<Result<(), std::io::Error>> {
89        match &mut *self {
90            WrappedStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
91            #[cfg(feature = "tls")]
92            WrappedStream::ClientTls(stream) => Pin::new(stream).poll_shutdown(cx),
93            #[cfg(feature = "tls")]
94            WrappedStream::ServerTls(stream) => Pin::new(stream).poll_shutdown(cx),
95        }
96    }
97}