1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper_util::client::legacy::connect::{Connected, Connection};
7
8#[cfg(any(feature = "rustls-base", feature = "tls", feature = "openssl-tls"))]
9use hyper_util::rt::TokioIo;
10
11#[cfg(feature = "rustls-base")]
12use tokio_rustls::client::TlsStream as RustlsStream;
13
14#[cfg(feature = "tls")]
15use tokio_native_tls::TlsStream as TokioNativeTlsStream;
16
17#[cfg(feature = "openssl-tls")]
18use tokio_openssl::SslStream as OpenSslStream;
19
// Concrete TLS stream type for the selected TLS backend. Each alias below is gated on a
// different Cargo feature; enabling more than one of `rustls-base`, `tls` and
// `openssl-tls` defines `TlsStream` more than once and the crate fails to compile.
// In every backend the raw transport `R` is wrapped in `TokioIo`, that is wrapped by
// the backend's TLS stream, and the result is wrapped in `TokioIo` again (presumably to
// bridge hyper's and tokio's I/O traits in each direction — confirm against hyper-util).

/// TLS stream type used when the `rustls-base` feature is enabled (tokio-rustls client).
#[cfg(feature = "rustls-base")]
pub type TlsStream<R> = TokioIo<RustlsStream<TokioIo<R>>>;

/// TLS stream type used when the `tls` feature is enabled (tokio-native-tls).
#[cfg(feature = "tls")]
pub type TlsStream<R> = TokioIo<TokioNativeTlsStream<TokioIo<R>>>;

/// TLS stream type used when the `openssl-tls` feature is enabled (tokio-openssl).
#[cfg(feature = "openssl-tls")]
pub type TlsStream<R> = TokioIo<OpenSslStream<TokioIo<R>>>;
28
/// A client transport stream that is either used as-is, marked as proxied, or
/// TLS-wrapped and marked as proxied.
///
/// All variants forward hyper's `Read`/`Write` I/O to the stream they hold. The variant
/// decides what [`Connection::connected`] reports: `NoProxy` passes the inner
/// connection metadata through unchanged, while `Regular` and `Secured` mark it with
/// `Connected::proxy(true)`.
pub enum ProxyStream<R> {
    /// Transport `R` whose connection metadata is reported unchanged
    /// (presumably a direct connection that bypasses any proxy).
    NoProxy(R),
    /// Plain (non-TLS) transport `R` whose connection metadata is reported as proxied.
    Regular(R),
    /// `R` wrapped in the feature-selected `TlsStream`; reported as proxied.
    /// Only present when one of the TLS backend features is enabled.
    #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
    Secured(TlsStream<R>),
}
36
/// Forwards a pinned method call to whichever stream the active `ProxyStream`
/// variant holds.
///
/// Invoked as `match_fn_pinned!(pinned_self, method, arg1, arg2, ...)`, where
/// `pinned_self` is a `Pin<&mut ProxyStream<_>>`. The pin is unwrapped with
/// `get_mut`, and the inner stream is re-pinned with `Pin::new`. Both of these
/// require the involved stream types to be `Unpin`. The arguments are spliced into
/// every arm, but only the matching arm runs, so by-value arguments are moved once.
macro_rules! match_fn_pinned {
    ($pinned:expr, $method:ident $(, $arg:expr)*) => {
        match $pinned.get_mut() {
            ProxyStream::NoProxy(inner) => Pin::new(inner).$method($($arg),*),
            ProxyStream::Regular(inner) => Pin::new(inner).$method($($arg),*),
            #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
            ProxyStream::Secured(inner) => Pin::new(inner).$method($($arg),*),
        }
    };
}
56
57impl<R: Read + Write + Unpin> Read for ProxyStream<R> {
58 fn poll_read(
59 self: Pin<&mut Self>,
60 cx: &mut Context<'_>,
61 buf: ReadBufCursor<'_>,
62 ) -> Poll<io::Result<()>> {
63 match_fn_pinned!(self, poll_read, cx, buf)
64 }
65}
66
67impl<R: Read + Write + Unpin> Write for ProxyStream<R> {
68 fn poll_write(
69 self: Pin<&mut Self>,
70 cx: &mut Context<'_>,
71 buf: &[u8],
72 ) -> Poll<io::Result<usize>> {
73 match_fn_pinned!(self, poll_write, cx, buf)
74 }
75
76 fn poll_write_vectored(
77 self: Pin<&mut Self>,
78 cx: &mut Context<'_>,
79 bufs: &[io::IoSlice<'_>],
80 ) -> Poll<Result<usize, io::Error>> {
81 match_fn_pinned!(self, poll_write_vectored, cx, bufs)
82 }
83
84 fn is_write_vectored(&self) -> bool {
85 match self {
86 ProxyStream::NoProxy(s) => s.is_write_vectored(),
87 ProxyStream::Regular(s) => s.is_write_vectored(),
88 #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
89 ProxyStream::Secured(s) => s.is_write_vectored(),
90 }
91 }
92
93 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
94 match_fn_pinned!(self, poll_flush, cx)
95 }
96
97 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
98 match_fn_pinned!(self, poll_shutdown, cx)
99 }
100}
101
102impl<R: Read + Write + Connection + Unpin> Connection for ProxyStream<R> {
103 fn connected(&self) -> Connected {
104 match self {
105 ProxyStream::NoProxy(s) => s.connected(),
106
107 ProxyStream::Regular(s) => s.connected().proxy(true),
108 #[cfg(feature = "tls")]
109 ProxyStream::Secured(s) => s
110 .inner()
111 .get_ref()
112 .get_ref()
113 .get_ref()
114 .inner()
115 .connected()
116 .proxy(true),
117
118 #[cfg(feature = "rustls-base")]
119 ProxyStream::Secured(s) => s.inner().get_ref().0.inner().connected().proxy(true),
120
121 #[cfg(feature = "openssl-tls")]
122 ProxyStream::Secured(s) => s.inner().get_ref().inner().connected().proxy(true),
123 }
124 }
125}