hyper_proxy2/
stream.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use hyper::rt::{Read, ReadBufCursor, Write};
6use hyper_util::client::legacy::connect::{Connected, Connection};
7
8#[cfg(any(feature = "rustls-base", feature = "tls", feature = "openssl-tls"))]
9use hyper_util::rt::TokioIo;
10
11#[cfg(feature = "rustls-base")]
12use tokio_rustls::client::TlsStream as RustlsStream;
13
14#[cfg(feature = "tls")]
15use tokio_native_tls::TlsStream as TokioNativeTlsStream;
16
17#[cfg(feature = "openssl-tls")]
18use tokio_openssl::SslStream as OpenSslStream;
19
// Exactly one TLS backend may be enabled. Each backend below defines its own
// `TlsStream` alias, and `ProxyStream::Secured` carries a single payload type, so
// enabling two backends would otherwise surface as a confusing duplicate-definition
// error. Fail early with an explicit message instead.
#[cfg(any(
    all(feature = "rustls-base", feature = "tls"),
    all(feature = "rustls-base", feature = "openssl-tls"),
    all(feature = "tls", feature = "openssl-tls"),
))]
compile_error!("the `tls`, `rustls-base` and `openssl-tls` features are mutually exclusive; enable only one TLS backend");

/// TLS transport carried by `ProxyStream::Secured` when `rustls-base` is enabled:
/// a rustls client stream over the `TokioIo`-adapted inner transport `R`.
#[cfg(feature = "rustls-base")]
pub type TlsStream<R> = TokioIo<RustlsStream<TokioIo<R>>>;

/// TLS transport carried by `ProxyStream::Secured` when `tls` (native-tls) is enabled:
/// a tokio-native-tls stream over the `TokioIo`-adapted inner transport `R`.
#[cfg(feature = "tls")]
pub type TlsStream<R> = TokioIo<TokioNativeTlsStream<TokioIo<R>>>;

/// TLS transport carried by `ProxyStream::Secured` when `openssl-tls` is enabled:
/// a tokio-openssl stream over the `TokioIo`-adapted inner transport `R`.
#[cfg(feature = "openssl-tls")]
pub type TlsStream<R> = TokioIo<OpenSslStream<TokioIo<R>>>;
28
/// A client transport tagged with how it was established.
///
/// The variant determines what `Connection::connected` reports to hyper: only
/// `NoProxy` is reported without the proxy flag.
pub enum ProxyStream<R> {
    /// The inner transport, used as-is and reported without the proxy flag.
    NoProxy(R),
    /// The inner transport, reported to hyper as a proxied connection.
    Regular(R),
    /// The inner transport wrapped by the enabled TLS backend (see `TlsStream`),
    /// also reported as proxied. Presumably established through a CONNECT tunnel —
    /// confirm against the connector.
    #[cfg(any(feature = "rustls-base", feature = "tls", feature = "openssl-tls"))]
    Secured(TlsStream<R>),
}
36
/// Dispatches a pinned poll method to whichever transport a `ProxyStream` holds.
///
/// Usage: `match_fn_pinned!(pinned, method, arg1, arg2, ...)`, where `pinned` is a
/// `Pin<&mut ProxyStream<_>>`. The pin is unwrapped with `Pin::get_mut`, so the stream
/// must be `Unpin`. The inner transport is re-pinned with `Pin::new`, and `method` is
/// called on it with the remaining arguments forwarded unchanged.
///
/// One variadic arm serves every arity. The variant list, including the
/// feature-gated `Secured` arm, therefore lives in exactly one place.
macro_rules! match_fn_pinned {
    ($self:expr, $fn:ident $(, $arg:expr)*) => {
        match $self.get_mut() {
            // Both plain variants hold an `R`, so a single or-pattern arm covers them.
            ProxyStream::NoProxy(s) | ProxyStream::Regular(s) => Pin::new(s).$fn($($arg),*),
            #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
            ProxyStream::Secured(s) => Pin::new(s).$fn($($arg),*),
        }
    };
}
56
57impl<R: Read + Write + Unpin> Read for ProxyStream<R> {
58    fn poll_read(
59        self: Pin<&mut Self>,
60        cx: &mut Context<'_>,
61        buf: ReadBufCursor<'_>,
62    ) -> Poll<io::Result<()>> {
63        match_fn_pinned!(self, poll_read, cx, buf)
64    }
65}
66
67impl<R: Read + Write + Unpin> Write for ProxyStream<R> {
68    fn poll_write(
69        self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &[u8],
72    ) -> Poll<io::Result<usize>> {
73        match_fn_pinned!(self, poll_write, cx, buf)
74    }
75
76    fn poll_write_vectored(
77        self: Pin<&mut Self>,
78        cx: &mut Context<'_>,
79        bufs: &[io::IoSlice<'_>],
80    ) -> Poll<Result<usize, io::Error>> {
81        match_fn_pinned!(self, poll_write_vectored, cx, bufs)
82    }
83
84    fn is_write_vectored(&self) -> bool {
85        match self {
86            ProxyStream::NoProxy(s) => s.is_write_vectored(),
87            ProxyStream::Regular(s) => s.is_write_vectored(),
88            #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
89            ProxyStream::Secured(s) => s.is_write_vectored(),
90        }
91    }
92
93    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
94        match_fn_pinned!(self, poll_flush, cx)
95    }
96
97    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
98        match_fn_pinned!(self, poll_shutdown, cx)
99    }
100}
101
102impl<R: Read + Write + Connection + Unpin> Connection for ProxyStream<R> {
103    fn connected(&self) -> Connected {
104        match self {
105            ProxyStream::NoProxy(s) => s.connected(),
106
107            ProxyStream::Regular(s) => s.connected().proxy(true),
108            #[cfg(feature = "tls")]
109            ProxyStream::Secured(s) => s
110                .inner()
111                .get_ref()
112                .get_ref()
113                .get_ref()
114                .inner()
115                .connected()
116                .proxy(true),
117
118            #[cfg(feature = "rustls-base")]
119            ProxyStream::Secured(s) => s.inner().get_ref().0.inner().connected().proxy(true),
120
121            #[cfg(feature = "openssl-tls")]
122            ProxyStream::Secured(s) => s.inner().get_ref().inner().connected().proxy(true),
123        }
124    }
125}