rustygear/
wrappedstream.rs1use std::pin::Pin;
2
3use tokio::{
4 io::{AsyncRead, AsyncWrite},
5 net::TcpStream,
6};
7#[cfg(feature = "tls")]
8use tokio_rustls::{client, server};
9
10#[derive(Debug)]
11pub enum WrappedStream {
12 #[cfg(feature = "tls")]
13 ClientTls(Box<client::TlsStream<TcpStream>>),
14 #[cfg(feature = "tls")]
15 ServerTls(Box<server::TlsStream<TcpStream>>),
16 Plain(TcpStream),
17}
18
19impl Unpin for WrappedStream {}
20
/// Wraps a client-side TLS stream, moving it onto the heap.
#[cfg(feature = "tls")]
impl From<client::TlsStream<TcpStream>> for WrappedStream {
    fn from(stream: client::TlsStream<TcpStream>) -> Self {
        let boxed = Box::new(stream);
        Self::ClientTls(boxed)
    }
}
27
/// Wraps a server-side TLS stream, moving it onto the heap.
#[cfg(feature = "tls")]
impl From<server::TlsStream<TcpStream>> for WrappedStream {
    fn from(stream: server::TlsStream<TcpStream>) -> Self {
        let boxed = Box::new(stream);
        Self::ServerTls(boxed)
    }
}
34
35impl From<TcpStream> for WrappedStream {
36 fn from(value: TcpStream) -> Self {
37 WrappedStream::Plain(value)
38 }
39}
40
41impl AsyncRead for WrappedStream {
42 fn poll_read(
43 mut self: std::pin::Pin<&mut Self>,
44 cx: &mut std::task::Context<'_>,
45 buf: &mut tokio::io::ReadBuf<'_>,
46 ) -> std::task::Poll<std::io::Result<()>> {
47 match &mut *self {
48 #[cfg(feature = "tls")]
49 WrappedStream::ClientTls(stream) => Pin::new(stream).poll_read(cx, buf),
50 #[cfg(feature = "tls")]
51 WrappedStream::ServerTls(stream) => Pin::new(stream).poll_read(cx, buf),
52 WrappedStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
53 }
54 }
55}
56
57impl AsyncWrite for WrappedStream {
58 fn poll_write(
59 mut self: std::pin::Pin<&mut Self>,
60 cx: &mut std::task::Context<'_>,
61 buf: &[u8],
62 ) -> std::task::Poll<Result<usize, std::io::Error>> {
63 match &mut *self {
64 WrappedStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
65 #[cfg(feature = "tls")]
66 WrappedStream::ClientTls(stream) => Pin::new(stream).poll_write(cx, buf),
67 #[cfg(feature = "tls")]
68 WrappedStream::ServerTls(stream) => Pin::new(stream).poll_write(cx, buf),
69 }
70 }
71
72 fn poll_flush(
73 mut self: std::pin::Pin<&mut Self>,
74 cx: &mut std::task::Context<'_>,
75 ) -> std::task::Poll<Result<(), std::io::Error>> {
76 match &mut *self {
77 WrappedStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
78 #[cfg(feature = "tls")]
79 WrappedStream::ClientTls(stream) => Pin::new(stream).poll_flush(cx),
80 #[cfg(feature = "tls")]
81 WrappedStream::ServerTls(stream) => Pin::new(stream).poll_flush(cx),
82 }
83 }
84
85 fn poll_shutdown(
86 mut self: std::pin::Pin<&mut Self>,
87 cx: &mut std::task::Context<'_>,
88 ) -> std::task::Poll<Result<(), std::io::Error>> {
89 match &mut *self {
90 WrappedStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
91 #[cfg(feature = "tls")]
92 WrappedStream::ClientTls(stream) => Pin::new(stream).poll_shutdown(cx),
93 #[cfg(feature = "tls")]
94 WrappedStream::ServerTls(stream) => Pin::new(stream).poll_shutdown(cx),
95 }
96 }
97}