1#![allow(missing_docs)]
60
61mod rt;
62mod stream;
63mod tunnel;
64
65use std::{fmt, io, sync::Arc};
66use std::{
67 future::Future,
68 pin::Pin,
69 task::{Context, Poll},
70};
71
72use futures_util::future::TryFutureExt;
73use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
74use http::header::{HeaderMap, HeaderName, HeaderValue};
75use hyper::rt::{Read, Write};
76use hyper::Uri;
77use tower_service::Service;
78
79pub use stream::ProxyStream;
80
81#[cfg(feature = "tls")]
82use native_tls::TlsConnector as NativeTlsConnector;
83
84#[cfg(feature = "tls")]
85use tokio_native_tls::TlsConnector;
86
87#[cfg(feature = "rustls-base")]
88use hyper_rustls::ConfigBuilderExt;
89
90#[cfg(feature = "rustls-base")]
91use tokio_rustls::TlsConnector;
92
93#[cfg(feature = "rustls-base")]
94use tokio_rustls::rustls::pki_types::ServerName;
95
96#[cfg(feature = "openssl-tls")]
97use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod};
98
99#[cfg(feature = "openssl-tls")]
100use tokio_openssl::SslStream;
101
/// Boxed, thread-safe error used to erase the inner connector's error type.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
103
104#[derive(Debug, Clone)]
106pub enum Intercept {
107 All,
109 Http,
111 Https,
113 None,
115 Custom(Custom),
117}
118
/// A destination whose scheme, host and port can be inspected when deciding
/// whether a connection should go through a proxy.
pub trait Dst {
    /// The scheme (e.g. `"http"`), if the destination has one.
    fn scheme(&self) -> Option<&str>;
    /// The host, if the destination has one.
    fn host(&self) -> Option<&str>;
    /// The port, if one is specified.
    fn port(&self) -> Option<u16>;
}
128
129impl Dst for Uri {
130 fn scheme(&self) -> Option<&str> {
131 self.scheme_str()
132 }
133
134 fn host(&self) -> Option<&str> {
135 self.host()
136 }
137
138 fn port(&self) -> Option<u16> {
139 self.port_u16()
140 }
141}
142
/// Wraps any boxable error in an `io::Error` of kind `Other`.
#[inline]
pub(crate) fn io_err<E>(error: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, error)
}
147
/// A user-supplied interception predicate.
///
/// It is called with the destination's scheme, host and port (each optional).
/// Its boolean result is what `Intercept::matches` returns for `Intercept::Custom`.
#[derive(Clone)]
pub struct Custom(Arc<dyn Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync>);

impl fmt::Debug for Custom {
    /// Closures have no meaningful debug form, so a `_` placeholder is printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        f.write_str("_")
    }
}
157
158impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
159 for Custom
160{
161 fn from(f: F) -> Custom {
162 Custom(Arc::new(f))
163 }
164}
165
166impl Intercept {
167 pub fn matches<D: Dst>(&self, uri: &D) -> bool {
169 match (self, uri.scheme()) {
170 (&Intercept::All, _)
171 | (&Intercept::Http, Some("http"))
172 | (&Intercept::Https, Some("https")) => true,
173 (&Intercept::Custom(Custom(ref f)), _) => f(uri.scheme(), uri.host(), uri.port()),
174 _ => false,
175 }
176 }
177}
178
179impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
180 for Intercept
181{
182 fn from(f: F) -> Intercept {
183 Intercept::Custom(f.into())
184 }
185}
186
187#[derive(Clone, Debug)]
189pub struct Proxy {
190 intercept: Intercept,
191 force_connect: bool,
192 headers: HeaderMap,
193 uri: Uri,
194}
195
196impl Proxy {
197 pub fn new<I: Into<Intercept>>(intercept: I, uri: Uri) -> Proxy {
199 Proxy {
200 intercept: intercept.into(),
201 uri: uri,
202 headers: HeaderMap::new(),
203 force_connect: false,
204 }
205 }
206
207 pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
209 match self.intercept {
210 Intercept::Http => {
211 self.headers.typed_insert(Authorization(credentials.0));
212 }
213 Intercept::Https => {
214 self.headers.typed_insert(ProxyAuthorization(credentials.0));
215 }
216 _ => {
217 self.headers
218 .typed_insert(Authorization(credentials.0.clone()));
219 self.headers.typed_insert(ProxyAuthorization(credentials.0));
220 }
221 }
222 }
223
224 pub fn force_connect(&mut self) {
226 self.force_connect = true;
227 }
228
229 pub fn set_header(&mut self, name: HeaderName, value: HeaderValue) {
231 self.headers.insert(name, value);
232 }
233
234 pub fn intercept(&self) -> &Intercept {
236 &self.intercept
237 }
238
239 pub fn headers(&self) -> &HeaderMap {
241 &self.headers
242 }
243
244 pub fn uri(&self) -> &Uri {
246 &self.uri
247 }
248}
249
250#[derive(Clone)]
252pub struct ProxyConnector<C> {
253 proxies: Vec<Proxy>,
254 connector: C,
255
256 #[cfg(feature = "tls")]
257 tls: Option<NativeTlsConnector>,
258
259 #[cfg(feature = "rustls-base")]
260 tls: Option<TlsConnector>,
261
262 #[cfg(feature = "openssl-tls")]
263 tls: Option<OpenSslConnector>,
264
265 #[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
266 tls: Option<()>,
267}
268
269impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {
270 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
271 write!(
272 f,
273 "ProxyConnector {}{{ proxies: {:?}, connector: {:?} }}",
274 if self.tls.is_some() {
275 ""
276 } else {
277 "(unsecured)"
278 },
279 self.proxies,
280 self.connector
281 )
282 }
283}
284
285impl<C> ProxyConnector<C> {
286 #[cfg(feature = "tls")]
288 pub fn new(connector: C) -> Result<Self, io::Error> {
289 let tls = NativeTlsConnector::builder()
290 .build()
291 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
292
293 Ok(ProxyConnector {
294 proxies: Vec::new(),
295 connector: connector,
296 tls: Some(tls),
297 })
298 }
299
300 #[cfg(feature = "rustls-base")]
302 pub fn new(connector: C) -> Result<Self, io::Error> {
303 let config = tokio_rustls::rustls::ClientConfig::builder();
304
305 #[cfg(feature = "rustls")]
306 let config = config.with_native_roots()?;
307
308 #[cfg(feature = "rustls-webpki")]
309 let config = config.with_webpki_roots();
310
311 let cfg = Arc::new(config.with_no_client_auth());
312 let tls = TlsConnector::from(cfg);
313
314 Ok(ProxyConnector {
315 proxies: Vec::new(),
316 connector,
317 tls: Some(tls),
318 })
319 }
320
321 #[allow(missing_docs)]
322 #[cfg(feature = "openssl-tls")]
323 pub fn new(connector: C) -> Result<Self, io::Error> {
324 let builder = OpenSslConnector::builder(SslMethod::tls())
325 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
326 let tls = builder.build();
327
328 Ok(ProxyConnector {
329 proxies: Vec::new(),
330 connector: connector,
331 tls: Some(tls),
332 })
333 }
334
335 pub fn unsecured(connector: C) -> Self {
337 ProxyConnector {
338 proxies: Vec::new(),
339 connector: connector,
340 tls: None,
341 }
342 }
343
344 #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
346 pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
347 let mut c = ProxyConnector::new(connector)?;
348 c.proxies.push(proxy);
349 Ok(c)
350 }
351
352 pub fn from_proxy_unsecured(connector: C, proxy: Proxy) -> Self {
354 let mut c = ProxyConnector::unsecured(connector);
355 c.proxies.push(proxy);
356 c
357 }
358
359 pub fn with_connector<CC>(self, connector: CC) -> ProxyConnector<CC> {
361 ProxyConnector {
362 connector: connector,
363 proxies: self.proxies,
364 tls: self.tls,
365 }
366 }
367
368 #[cfg(any(feature = "tls"))]
370 pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
371 self.tls = tls;
372 }
373
374 #[cfg(any(feature = "rustls-base"))]
376 pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
377 self.tls = tls;
378 }
379
380 #[cfg(any(feature = "openssl-tls"))]
382 pub fn set_tls(&mut self, tls: Option<OpenSslConnector>) {
383 self.tls = tls;
384 }
385
386 pub fn proxies(&self) -> &[Proxy] {
388 &self.proxies
389 }
390
391 pub fn add_proxy(&mut self, proxy: Proxy) {
393 self.proxies.push(proxy);
394 }
395
396 pub fn extend_proxies<I: IntoIterator<Item = Proxy>>(&mut self, proxies: I) {
398 self.proxies.extend(proxies)
399 }
400
401 pub fn http_headers(&self, uri: &Uri) -> Option<&HeaderMap> {
406 if uri.scheme_str().map_or(true, |s| s != "http") {
407 return None;
408 }
409
410 self.match_proxy(uri).map(|p| &p.headers)
411 }
412
413 fn match_proxy<D: Dst>(&self, uri: &D) -> Option<&Proxy> {
414 self.proxies.iter().find(|p| p.intercept.matches(uri))
415 }
416}
417
/// Unwraps an `Ok` value. On `Err`, breaks out of the enclosing `loop` with
/// the error converted via `Into`.
///
/// Meant for `loop { ... }` blocks whose value is a `Result`, where the loop
/// serves as a local early-exit scope.
macro_rules! mtry {
    ($result:expr) => {
        match $result {
            Err(err) => break Err(err.into()),
            Ok(value) => value,
        }
    };
}
426
427impl<C> Service<Uri> for ProxyConnector<C>
428where
429 C: Service<Uri>,
430 C::Response: Read + Write + Send + Unpin + 'static,
431 C::Future: Send + 'static,
432 C::Error: Into<BoxError>,
433{
434 type Response = ProxyStream<C::Response>;
435 type Error = io::Error;
436 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
437
438 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
439 match self.connector.poll_ready(cx) {
440 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
441 Poll::Ready(Err(e)) => Poll::Ready(Err(io_err(e.into()))),
442 Poll::Pending => Poll::Pending,
443 }
444 }
445
446 fn call(&mut self, uri: Uri) -> Self::Future {
447 if let (Some(p), Some(host)) = (self.match_proxy(&uri), uri.host()) {
448 if uri.scheme() == Some(&http::uri::Scheme::HTTPS) || p.force_connect {
449 let host = host.to_owned();
450 let port =
451 uri.port_u16()
452 .unwrap_or(if uri.scheme() == Some(&http::uri::Scheme::HTTP) {
453 80
454 } else {
455 443
456 });
457
458 let tunnel = tunnel::new(&host, port, &p.headers);
459 let connection =
460 proxy_dst(&uri, &p.uri).map(|proxy_url| self.connector.call(proxy_url));
461 let tls = if uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
462 self.tls.clone()
463 } else {
464 None
465 };
466
467 Box::pin(async move {
468 loop {
469 let proxy_stream = mtry!(mtry!(connection).await.map_err(io_err));
471 let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);
472
473 break match tls {
474 #[cfg(feature = "tls")]
475 Some(tls) => {
476 use hyper_util::rt::TokioIo;
477 let tls = TlsConnector::from(tls);
478 let secure_stream = mtry!(tls
479 .connect(&host, TokioIo::new(tunnel_stream))
480 .await
481 .map_err(io_err));
482
483 Ok(ProxyStream::Secured(TokioIo::new(secure_stream)))
484 }
485
486 #[cfg(feature = "rustls-base")]
487 Some(tls) => {
488 use hyper_util::rt::TokioIo;
489 let server_name =
490 mtry!(ServerName::try_from(host.to_string()).map_err(io_err));
491 let tls = TlsConnector::from(tls);
492 let secure_stream = mtry!(tls
493 .connect(server_name, TokioIo::new(tunnel_stream))
494 .await
495 .map_err(io_err));
496
497 Ok(ProxyStream::Secured(TokioIo::new(secure_stream)))
498 }
499
500 #[cfg(feature = "openssl-tls")]
501 Some(tls) => {
502 use hyper_util::rt::TokioIo;
503 let config = tls.configure().map_err(io_err)?;
504 let ssl = config.into_ssl(&host).map_err(io_err)?;
505
506 let mut stream =
507 mtry!(SslStream::new(ssl, TokioIo::new(tunnel_stream)));
508 mtry!(Pin::new(&mut stream).connect().await.map_err(io_err));
509
510 Ok(ProxyStream::Secured(TokioIo::new(stream)))
511 }
512
513 #[cfg(not(any(
514 feature = "tls",
515 feature = "rustls-base",
516 feature = "openssl-tls"
517 )))]
518 Some(_) => panic!("hyper-proxy was not built with TLS support"),
519
520 None => Ok(ProxyStream::Regular(tunnel_stream)),
521 };
522 }
523 })
524 } else {
525 match proxy_dst(&uri, &p.uri) {
526 Ok(proxy_uri) => Box::pin(
527 self.connector
528 .call(proxy_uri)
529 .map_ok(ProxyStream::Regular)
530 .map_err(|err| io_err(err.into())),
531 ),
532 Err(err) => Box::pin(futures_util::future::err(io_err(err))),
533 }
534 }
535 } else {
536 Box::pin(
537 self.connector
538 .call(uri)
539 .map_ok(ProxyStream::NoProxy)
540 .map_err(|err| io_err(err.into())),
541 )
542 }
543 }
544}
545
546fn proxy_dst(dst: &Uri, proxy: &Uri) -> io::Result<Uri> {
547 Uri::builder()
548 .scheme(
549 proxy
550 .scheme_str()
551 .ok_or_else(|| io_err(format!("proxy uri missing scheme: {}", proxy)))?,
552 )
553 .authority(
554 proxy
555 .authority()
556 .ok_or_else(|| io_err(format!("proxy uri missing host: {}", proxy)))?
557 .clone(),
558 )
559 .path_and_query(dst.path_and_query().unwrap().clone())
560 .build()
561 .map_err(|err| io_err(format!("other error: {}", err)))
562}