hyper_util/client/legacy/connect/proxy/
tunnel.rs1use std::error::Error as StdError;
2use std::future::Future;
3use std::marker::{PhantomData, Unpin};
4use std::pin::Pin;
5use std::task::{self, ready, Poll};
6
7use http::{HeaderMap, HeaderValue, Uri};
8use hyper::rt::{Read, Write};
9use pin_project_lite::pin_project;
10use tower_service::Service;
11
12#[derive(Debug, Clone)]
18pub struct Tunnel<C> {
19 headers: Headers,
20 inner: C,
21 proxy_dst: Uri,
22}
23
24#[derive(Clone, Debug)]
25enum Headers {
26 Empty,
27 Auth(HeaderValue),
28 Extra(HeaderMap),
29}
30
/// Errors that can occur while establishing a `CONNECT` tunnel.
#[derive(Debug)]
pub enum TunnelError {
    /// The inner connector failed, either while becoming ready or while
    /// connecting to the proxy.
    ConnectFailed(Box<dyn std::error::Error + Send + Sync>),
    /// An I/O error occurred while writing the request or reading the reply.
    Io(std::io::Error),
    /// The destination URI has no host component.
    MissingHost,
    /// The proxy answered with status `407`.
    ProxyAuthRequired,
    /// A successful response head did not fit in the read buffer.
    ProxyHeadersTooLong,
    /// The proxy closed the connection before a full response head arrived.
    TunnelUnexpectedEof,
    /// The proxy answered with a status the tunnel does not accept.
    TunnelUnsuccessful,
}
41
42pin_project! {
43 #[must_use = "futures do nothing unless polled"]
49 #[allow(missing_debug_implementations)]
50 pub struct Tunneling<F, T> {
51 #[pin]
52 fut: BoxTunneling<T>,
53 _marker: PhantomData<F>,
54 }
55}
56
57type BoxTunneling<T> = Pin<Box<dyn Future<Output = Result<T, TunnelError>> + Send>>;
58
59impl<C> Tunnel<C> {
60 pub fn new(proxy_dst: Uri, connector: C) -> Self {
69 Self {
70 headers: Headers::Empty,
71 inner: connector,
72 proxy_dst,
73 }
74 }
75
76 pub fn with_auth(mut self, mut auth: HeaderValue) -> Self {
78 auth.set_sensitive(true);
80 match self.headers {
81 Headers::Empty => {
82 self.headers = Headers::Auth(auth);
83 }
84 Headers::Auth(ref mut existing) => {
85 *existing = auth;
86 }
87 Headers::Extra(ref mut extra) => {
88 extra.insert(http::header::PROXY_AUTHORIZATION, auth);
89 }
90 }
91
92 self
93 }
94
95 pub fn with_headers(mut self, mut headers: HeaderMap) -> Self {
99 match self.headers {
100 Headers::Empty => {
101 self.headers = Headers::Extra(headers);
102 }
103 Headers::Auth(auth) => {
104 headers
105 .entry(http::header::PROXY_AUTHORIZATION)
106 .or_insert(auth);
107 self.headers = Headers::Extra(headers);
108 }
109 Headers::Extra(ref mut extra) => {
110 extra.extend(headers);
111 }
112 }
113
114 self
115 }
116}
117
118impl<C> Service<Uri> for Tunnel<C>
119where
120 C: Service<Uri>,
121 C::Future: Send + 'static,
122 C::Response: Read + Write + Unpin + Send + 'static,
123 C::Error: Into<Box<dyn StdError + Send + Sync>>,
124{
125 type Response = C::Response;
126 type Error = TunnelError;
127 type Future = Tunneling<C::Future, C::Response>;
128
129 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
130 ready!(self.inner.poll_ready(cx)).map_err(|e| TunnelError::ConnectFailed(e.into()))?;
131 Poll::Ready(Ok(()))
132 }
133
134 fn call(&mut self, dst: Uri) -> Self::Future {
135 let connecting = self.inner.call(self.proxy_dst.clone());
136 let headers = self.headers.clone();
137
138 Tunneling {
139 fut: Box::pin(async move {
140 let conn = connecting
141 .await
142 .map_err(|e| TunnelError::ConnectFailed(e.into()))?;
143 tunnel(
144 conn,
145 dst.host().ok_or(TunnelError::MissingHost)?,
146 dst.port().map(|p| p.as_u16()).unwrap_or(443),
147 &headers,
148 )
149 .await
150 }),
151 _marker: PhantomData,
152 }
153 }
154}
155
156impl<F, T, E> Future for Tunneling<F, T>
157where
158 F: Future<Output = Result<T, E>>,
159{
160 type Output = Result<T, TunnelError>;
161
162 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
163 self.project().fut.poll(cx)
164 }
165}
166
167async fn tunnel<T>(mut conn: T, host: &str, port: u16, headers: &Headers) -> Result<T, TunnelError>
168where
169 T: Read + Write + Unpin,
170{
171 let mut buf = format!(
172 "\
173 CONNECT {host}:{port} HTTP/1.1\r\n\
174 Host: {host}:{port}\r\n\
175 "
176 )
177 .into_bytes();
178
179 match headers {
180 Headers::Auth(auth) => {
181 buf.extend_from_slice(b"Proxy-Authorization: ");
182 buf.extend_from_slice(auth.as_bytes());
183 buf.extend_from_slice(b"\r\n");
184 }
185 Headers::Extra(extra) => {
186 for (name, value) in extra {
187 buf.extend_from_slice(name.as_str().as_bytes());
188 buf.extend_from_slice(b": ");
189 buf.extend_from_slice(value.as_bytes());
190 buf.extend_from_slice(b"\r\n");
191 }
192 }
193 Headers::Empty => (),
194 }
195
196 buf.extend_from_slice(b"\r\n");
198
199 crate::rt::write_all(&mut conn, &buf)
200 .await
201 .map_err(TunnelError::Io)?;
202
203 let mut buf = [0; 8192];
204 let mut pos = 0;
205
206 loop {
207 let n = crate::rt::read(&mut conn, &mut buf[pos..])
208 .await
209 .map_err(TunnelError::Io)?;
210
211 if n == 0 {
212 return Err(TunnelError::TunnelUnexpectedEof);
213 }
214 pos += n;
215
216 let recvd = &buf[..pos];
217 if recvd.starts_with(b"HTTP/1.1 200") || recvd.starts_with(b"HTTP/1.0 200") {
218 if recvd.ends_with(b"\r\n\r\n") {
219 return Ok(conn);
220 }
221 if pos == buf.len() {
222 return Err(TunnelError::ProxyHeadersTooLong);
223 }
224 } else if recvd.starts_with(b"HTTP/1.1 407") {
226 return Err(TunnelError::ProxyAuthRequired);
227 } else {
228 return Err(TunnelError::TunnelUnsuccessful);
229 }
230 }
231}
232
233impl std::fmt::Display for TunnelError {
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 f.write_str("tunnel error: ")?;
236
237 f.write_str(match self {
238 TunnelError::MissingHost => "missing destination host",
239 TunnelError::ProxyAuthRequired => "proxy authorization required",
240 TunnelError::ProxyHeadersTooLong => "proxy response headers too long",
241 TunnelError::TunnelUnexpectedEof => "unexpected end of file",
242 TunnelError::TunnelUnsuccessful => "unsuccessful",
243 TunnelError::ConnectFailed(_) => "failed to create underlying connection",
244 TunnelError::Io(_) => "io error establishing tunnel",
245 })
246 }
247}
248
249impl std::error::Error for TunnelError {
250 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
251 match self {
252 TunnelError::Io(ref e) => Some(e),
253 TunnelError::ConnectFailed(ref e) => Some(&**e),
254 _ => None,
255 }
256 }
257}