1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
5
6#[cfg(feature = "rustls-base")]
7use tokio_rustls::client::TlsStream as RustlsStream;
8
9#[cfg(feature = "tls")]
10use tokio_native_tls::TlsStream;
11
12#[cfg(feature = "openssl-tls")]
13use tokio_openssl::SslStream as OpenSslStream;
14
15use hyper::client::connect::{Connected, Connection};
16
/// TLS stream type used by [`ProxyStream::Secured`] when the `rustls-base`
/// feature is enabled: the client-side stream from `tokio_rustls`.
#[cfg(feature = "rustls-base")]
pub type TlsStream<R> = RustlsStream<R>;
19
/// TLS stream type used by [`ProxyStream::Secured`] when the `openssl-tls`
/// feature is enabled: the `SslStream` from `tokio_openssl`.
#[cfg(feature = "openssl-tls")]
pub type TlsStream<R> = OpenSslStream<R>;
22
/// A stream that is either a plain `R` or a TLS stream layered over an `R`.
///
/// The trait implementations in this file forward every I/O call to whichever
/// inner stream the active variant holds.
pub enum ProxyStream<R> {
    /// Plain stream. The `Connection` impl returns the inner stream's
    /// connection info unchanged (no proxy flag is set).
    NoProxy(R),
    /// Plain stream. The `Connection` impl returns the inner stream's
    /// connection info with `proxy(true)` applied.
    Regular(R),
    /// TLS stream wrapping an `R`. Only compiled in when one of the `tls`,
    /// `rustls-base` or `openssl-tls` features is enabled. The `Connection`
    /// impl reports the underlying `R`'s info with `proxy(true)` applied.
    #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
    Secured(TlsStream<R>),
}
30
// Forwards a pinned poll-style call to the stream held by the active
// `ProxyStream` variant.
//
// Usage: `match_fn_pinned!(pinned_self, method, arg1 [, arg2 ...])`
//
// * `pinned_self` - an expression of type `Pin<&mut ProxyStream<R>>`; it is
//   unwrapped with `get_mut()`, so the enum must be `Unpin`.
// * `method`      - name of the method to call on the re-pinned inner stream.
// * `arg...`      - one or more arguments passed through unchanged.
//
// A single variadic rule replaces the former pair of copy-pasted rules (one
// per argument count), so the variant list is written only once.
macro_rules! match_fn_pinned {
    ($self:expr, $method:ident, $($arg:expr),+) => {
        match $self.get_mut() {
            // Both plain variants hold an `R` directly and are handled alike.
            ProxyStream::NoProxy(inner) | ProxyStream::Regular(inner) => {
                Pin::new(inner).$method($($arg),+)
            }
            #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
            ProxyStream::Secured(inner) => Pin::new(inner).$method($($arg),+),
        }
    };
}
50
51impl<R: AsyncRead + AsyncWrite + Unpin> AsyncRead for ProxyStream<R> {
52 fn poll_read(
53 self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 buf: &mut ReadBuf<'_>,
56 ) -> Poll<io::Result<()>> {
57 match_fn_pinned!(self, poll_read, cx, buf)
58 }
59}
60
61impl<R: AsyncRead + AsyncWrite + Unpin> AsyncWrite for ProxyStream<R> {
62 fn poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &[u8],
66 ) -> Poll<io::Result<usize>> {
67 match_fn_pinned!(self, poll_write, cx, buf)
68 }
69
70 fn poll_write_vectored(
71 self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 bufs: &[io::IoSlice<'_>],
74 ) -> Poll<Result<usize, io::Error>> {
75 match_fn_pinned!(self, poll_write_vectored, cx, bufs)
76 }
77
78 fn is_write_vectored(&self) -> bool {
79 match self {
80 ProxyStream::NoProxy(s) => s.is_write_vectored(),
81 ProxyStream::Regular(s) => s.is_write_vectored(),
82 #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
83 ProxyStream::Secured(s) => s.is_write_vectored(),
84 }
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88 match_fn_pinned!(self, poll_flush, cx)
89 }
90
91 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
92 match_fn_pinned!(self, poll_shutdown, cx)
93 }
94}
95
96impl<R: AsyncRead + AsyncWrite + Connection + Unpin> Connection for ProxyStream<R> {
97 fn connected(&self) -> Connected {
98 match self {
99 ProxyStream::NoProxy(s) => s.connected(),
100
101 ProxyStream::Regular(s) => s.connected().proxy(true),
102 #[cfg(feature = "tls")]
103 ProxyStream::Secured(s) => s.get_ref().get_ref().get_ref().connected().proxy(true),
104
105 #[cfg(feature = "rustls-base")]
106 ProxyStream::Secured(s) => s.get_ref().0.connected().proxy(true),
107
108 #[cfg(feature = "openssl-tls")]
109 ProxyStream::Secured(s) => s.get_ref().connected().proxy(true),
110 }
111 }
112}