// hyper_http_proxy/stream.rs
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::rt::{Read, ReadBufCursor, Write};
use hyper_util::client::legacy::connect::{Connected, Connection};

#[cfg(feature = "__tls")]
use hyper_util::rt::TokioIo;

#[cfg(feature = "__rustls")]
use tokio_rustls::client::TlsStream as RustlsStream;

#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
use tokio_native_tls::TlsStream as TokioNativeTlsStream;
16
/// TLS stream carried by `ProxyStream::Secured` when built with rustls.
///
/// The raw transport `R` is adapted with `TokioIo`, wrapped in a
/// `tokio_rustls` client stream, and the result is adapted with `TokioIo`
/// again so it can be driven through hyper's `Read`/`Write` traits.
#[cfg(feature = "__rustls")]
pub type TlsStream<R> = TokioIo<RustlsStream<TokioIo<R>>>;

/// TLS stream carried by `ProxyStream::Secured` when built with native-tls.
///
/// Only selected when the rustls feature is *not* enabled, so rustls takes
/// precedence if both backends are turned on. Layering mirrors the rustls
/// alias: `TokioIo` -> `tokio_native_tls::TlsStream` -> `TokioIo` -> `R`.
#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
pub type TlsStream<R> = TokioIo<TokioNativeTlsStream<TokioIo<R>>>;
22
/// Stream wrapper recording how a connection to the target was established.
///
/// All variants forward I/O to the transport they hold; they differ only in
/// the connection metadata reported through `Connection::connected`.
pub enum ProxyStream<R> {
    /// Direct connection: the inner transport's `Connected` info is reported
    /// unchanged.
    NoProxy(R),
    /// Plain (non-TLS) transport whose `Connected` info is marked as proxied.
    Regular(R),
    /// TLS-wrapped transport (only with the `__tls` feature); its `Connected`
    /// info is taken from the innermost `R` and marked as proxied. Boxed,
    /// presumably to keep the enum small since the TLS state is large.
    #[cfg(feature = "__tls")]
    Secured(Box<TlsStream<R>>),
}
30
/// Forwards a pinned method call to the stream wrapped by a `ProxyStream`.
///
/// `match_fn_pinned!(this, method, args...)` expects `this` to be a
/// `Pin<&mut ProxyStream<_>>`. It unwraps it with `get_mut` (so the
/// `ProxyStream` must be `Unpin`), re-pins the wrapped stream with
/// `Pin::new`, and calls `method(args...)` on it. The call expression's value
/// is the macro's value.
///
/// Any number of trailing arguments is accepted, so one arm serves both
/// `poll_read(cx, buf)` / `poll_write(cx, buf)` and `poll_flush(cx)` /
/// `poll_shutdown(cx)`. Previously these required two duplicated arms.
macro_rules! match_fn_pinned {
    ($this:expr, $method:ident $(, $arg:expr)*) => {
        match $this.get_mut() {
            // Both plain variants hold `R` directly, so one arm covers them.
            ProxyStream::NoProxy(inner) | ProxyStream::Regular(inner) => {
                Pin::new(inner).$method($($arg),*)
            }
            #[cfg(feature = "__tls")]
            ProxyStream::Secured(inner) => Pin::new(inner).$method($($arg),*),
        }
    };
}
50
51impl<R: Read + Write + Unpin> Read for ProxyStream<R> {
52 fn poll_read(
53 self: Pin<&mut Self>,
54 cx: &mut Context<'_>,
55 buf: ReadBufCursor<'_>,
56 ) -> Poll<io::Result<()>> {
57 match_fn_pinned!(self, poll_read, cx, buf)
58 }
59}
60
61impl<R: Read + Write + Unpin> Write for ProxyStream<R> {
62 fn poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &[u8],
66 ) -> Poll<io::Result<usize>> {
67 match_fn_pinned!(self, poll_write, cx, buf)
68 }
69
70 fn poll_write_vectored(
71 self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 bufs: &[io::IoSlice<'_>],
74 ) -> Poll<Result<usize, io::Error>> {
75 match_fn_pinned!(self, poll_write_vectored, cx, bufs)
76 }
77
78 fn is_write_vectored(&self) -> bool {
79 match self {
80 ProxyStream::NoProxy(s) => s.is_write_vectored(),
81 ProxyStream::Regular(s) => s.is_write_vectored(),
82 #[cfg(feature = "__tls")]
83 ProxyStream::Secured(s) => s.is_write_vectored(),
84 }
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88 match_fn_pinned!(self, poll_flush, cx)
89 }
90
91 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
92 match_fn_pinned!(self, poll_shutdown, cx)
93 }
94}
95
96impl<R: Read + Write + Connection + Unpin> Connection for ProxyStream<R> {
97 fn connected(&self) -> Connected {
98 match self {
99 ProxyStream::NoProxy(s) => s.connected(),
100
101 ProxyStream::Regular(s) => s.connected().proxy(true),
102 #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
103 ProxyStream::Secured(s) => s
104 .inner()
105 .get_ref()
106 .get_ref()
107 .get_ref()
108 .inner()
109 .connected()
110 .proxy(true),
111
112 #[cfg(feature = "__rustls")]
113 ProxyStream::Secured(s) => s.inner().get_ref().0.inner().connected().proxy(true),
114 }
115 }
116}