1#![allow(missing_docs)]
56
57mod stream;
58mod tunnel;
59
60use http::header::{HeaderMap, HeaderName, HeaderValue};
61use hyper::{service::Service, Uri};
62
63use futures::future::TryFutureExt;
64use std::{fmt, io, sync::Arc};
65use std::{
66 future::Future,
67 pin::Pin,
68 task::{Context, Poll},
69};
70
71pub use stream::ProxyStream;
72use tokio::io::{AsyncRead, AsyncWrite};
73
74#[cfg(feature = "tls")]
75use native_tls::TlsConnector as NativeTlsConnector;
76
77#[cfg(feature = "tls")]
78use tokio_native_tls::TlsConnector;
79#[cfg(feature = "rustls-base")]
80use tokio_rustls::TlsConnector;
81
82use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
83#[cfg(feature = "openssl-tls")]
84use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod};
85#[cfg(feature = "openssl-tls")]
86use tokio_openssl::SslStream;
87#[cfg(feature = "rustls-base")]
88use webpki::DNSNameRef;
89
/// Boxed, thread-safe error type; the wrapped connector's error type must
/// convert into it (see the `Service` impl bounds).
type BoxError = Box<dyn std::error::Error + Send + Sync>;
91
92#[derive(Debug, Clone)]
94pub enum Intercept {
95 All,
97 Http,
99 Https,
101 None,
103 Custom(Custom),
105}
106
/// Minimal view of a connection destination, used to decide whether a
/// [`Proxy`] applies to it.
pub trait Dst {
    /// The destination's scheme (e.g. `"http"`), if any.
    fn scheme(&self) -> Option<&str>;

    /// The destination's host, if any.
    fn host(&self) -> Option<&str>;

    /// The destination's explicit port, if any.
    fn port(&self) -> Option<u16>;
}
116
117impl Dst for Uri {
118 fn scheme(&self) -> Option<&str> {
119 self.scheme_str()
120 }
121
122 fn host(&self) -> Option<&str> {
123 self.host()
124 }
125
126 fn port(&self) -> Option<u16> {
127 self.port_u16()
128 }
129}
130
/// Wraps any boxable error (including plain strings) into an [`io::Error`]
/// of kind [`io::ErrorKind::Other`].
#[inline]
pub(crate) fn io_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
135
/// A shareable user predicate deciding whether a destination is proxied.
///
/// The closure receives the destination's scheme, host and port (each
/// optional) and returns `true` to route the connection through the proxy.
#[derive(Clone)]
pub struct Custom(Arc<dyn Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync>);

impl fmt::Debug for Custom {
    // Closures are opaque, so the predicate is rendered as a placeholder.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        f.write_str("_")
    }
}

impl<F> From<F> for Custom
where
    F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static,
{
    fn from(f: F) -> Custom {
        let shared = Arc::new(f);
        Custom(shared)
    }
}
153
154impl Intercept {
155 pub fn matches<D: Dst>(&self, uri: &D) -> bool {
157 match (self, uri.scheme()) {
158 (&Intercept::All, _)
159 | (&Intercept::Http, Some("http"))
160 | (&Intercept::Https, Some("https")) => true,
161 (&Intercept::Custom(Custom(ref f)), _) => f(uri.scheme(), uri.host(), uri.port()),
162 _ => false,
163 }
164 }
165}
166
167impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
168 for Intercept
169{
170 fn from(f: F) -> Intercept {
171 Intercept::Custom(f.into())
172 }
173}
174
175#[derive(Clone, Debug)]
177pub struct Proxy {
178 intercept: Intercept,
179 force_connect: bool,
180 headers: HeaderMap,
181 uri: Uri,
182}
183
184impl Proxy {
185 pub fn new<I: Into<Intercept>>(intercept: I, uri: Uri) -> Proxy {
187 Proxy {
188 intercept: intercept.into(),
189 uri: uri,
190 headers: HeaderMap::new(),
191 force_connect: false,
192 }
193 }
194
195 pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
197 match self.intercept {
198 Intercept::Http => {
199 self.headers.typed_insert(Authorization(credentials.0));
200 }
201 Intercept::Https => {
202 self.headers.typed_insert(ProxyAuthorization(credentials.0));
203 }
204 _ => {
205 self.headers
206 .typed_insert(Authorization(credentials.0.clone()));
207 self.headers.typed_insert(ProxyAuthorization(credentials.0));
208 }
209 }
210 }
211
212 pub fn force_connect(&mut self) {
214 self.force_connect = true;
215 }
216
217 pub fn set_header(&mut self, name: HeaderName, value: HeaderValue) {
219 self.headers.insert(name, value);
220 }
221
222 pub fn intercept(&self) -> &Intercept {
224 &self.intercept
225 }
226
227 pub fn headers(&self) -> &HeaderMap {
229 &self.headers
230 }
231
232 pub fn uri(&self) -> &Uri {
234 &self.uri
235 }
236}
237
238#[derive(Clone)]
240pub struct ProxyConnector<C> {
241 proxies: Vec<Proxy>,
242 connector: C,
243
244 #[cfg(feature = "tls")]
245 tls: Option<NativeTlsConnector>,
246
247 #[cfg(feature = "rustls-base")]
248 tls: Option<TlsConnector>,
249
250 #[cfg(feature = "openssl-tls")]
251 tls: Option<OpenSslConnector>,
252
253 #[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
254 tls: Option<()>,
255}
256
257impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {
258 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
259 write!(
260 f,
261 "ProxyConnector {}{{ proxies: {:?}, connector: {:?} }}",
262 if self.tls.is_some() {
263 ""
264 } else {
265 "(unsecured)"
266 },
267 self.proxies,
268 self.connector
269 )
270 }
271}
272
273impl<C> ProxyConnector<C> {
274 #[cfg(feature = "tls")]
276 pub fn new(connector: C) -> Result<Self, io::Error> {
277 let tls = NativeTlsConnector::builder()
278 .build()
279 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
280
281 Ok(ProxyConnector {
282 proxies: Vec::new(),
283 connector: connector,
284 tls: Some(tls),
285 })
286 }
287
288 #[cfg(feature = "rustls-base")]
290 pub fn new(connector: C) -> Result<Self, io::Error> {
291 let mut config = tokio_rustls::rustls::ClientConfig::new();
292
293 #[cfg(feature = "rustls")]
294 {
295 config.root_store =
296 rustls_native_certs::load_native_certs().map_err(|(_store, io)| io)?;
297 }
298
299 #[cfg(feature = "rustls-webpki")]
300 {
301 config
302 .root_store
303 .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
304 }
305
306 let cfg = Arc::new(config);
307 let tls = TlsConnector::from(cfg);
308
309 Ok(ProxyConnector {
310 proxies: Vec::new(),
311 connector: connector,
312 tls: Some(tls),
313 })
314 }
315
316 #[allow(missing_docs)]
317 #[cfg(feature = "openssl-tls")]
318 pub fn new(connector: C) -> Result<Self, io::Error> {
319 let builder = OpenSslConnector::builder(SslMethod::tls())
320 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
321 let tls = builder.build();
322
323 Ok(ProxyConnector {
324 proxies: Vec::new(),
325 connector: connector,
326 tls: Some(tls),
327 })
328 }
329
330 pub fn unsecured(connector: C) -> Self {
332 ProxyConnector {
333 proxies: Vec::new(),
334 connector: connector,
335 tls: None,
336 }
337 }
338
339 #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
341 pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
342 let mut c = ProxyConnector::new(connector)?;
343 c.proxies.push(proxy);
344 Ok(c)
345 }
346
347 pub fn from_proxy_unsecured(connector: C, proxy: Proxy) -> Self {
349 let mut c = ProxyConnector::unsecured(connector);
350 c.proxies.push(proxy);
351 c
352 }
353
354 pub fn with_connector<CC>(self, connector: CC) -> ProxyConnector<CC> {
356 ProxyConnector {
357 connector: connector,
358 proxies: self.proxies,
359 tls: self.tls,
360 }
361 }
362
363 #[cfg(any(feature = "tls"))]
365 pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
366 self.tls = tls;
367 }
368
369 #[cfg(any(feature = "rustls-base"))]
371 pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
372 self.tls = tls;
373 }
374
375 #[cfg(any(feature = "openssl-tls"))]
377 pub fn set_tls(&mut self, tls: Option<OpenSslConnector>) {
378 self.tls = tls;
379 }
380
381 pub fn proxies(&self) -> &[Proxy] {
383 &self.proxies
384 }
385
386 pub fn add_proxy(&mut self, proxy: Proxy) {
388 self.proxies.push(proxy);
389 }
390
391 pub fn extend_proxies<I: IntoIterator<Item = Proxy>>(&mut self, proxies: I) {
393 self.proxies.extend(proxies)
394 }
395
396 pub fn http_headers(&self, uri: &Uri) -> Option<&HeaderMap> {
401 if uri.scheme_str().map_or(true, |s| s != "http") {
402 return None;
403 }
404
405 self.match_proxy(uri).map(|p| &p.headers)
406 }
407
408 fn match_proxy<D: Dst>(&self, uri: &D) -> Option<&Proxy> {
409 self.proxies.iter().find(|p| p.intercept.matches(uri))
410 }
411}
412
/// Unwraps `Ok(v)` to `v`; on `Err(e)` it `break`s out of the enclosing
/// `loop` with `Err(e.into())`.
///
/// It can only be used inside a `loop` whose break value is a `Result`. The
/// `async` block in `ProxyConnector::call` uses it to bail out early with a
/// converted error.
macro_rules! mtry {
    ($e:expr) => {
        match $e {
            Ok(v) => v,
            // `break` with a value: the enclosing loop evaluates to this error.
            Err(e) => break Err(e.into()),
        }
    };
}
421
422impl<C> Service<Uri> for ProxyConnector<C>
423where
424 C: Service<Uri>,
425 C::Response: AsyncRead + AsyncWrite + Send + Unpin + 'static,
426 C::Future: Send + 'static,
427 C::Error: Into<BoxError>,
428{
429 type Response = ProxyStream<C::Response>;
430 type Error = io::Error;
431 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
432
433 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
434 match self.connector.poll_ready(cx) {
435 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
436 Poll::Ready(Err(e)) => Poll::Ready(Err(io_err(e.into()))),
437 Poll::Pending => Poll::Pending,
438 }
439 }
440
441 fn call(&mut self, uri: Uri) -> Self::Future {
442 if let (Some(p), Some(host)) = (self.match_proxy(&uri), uri.host()) {
443 if uri.scheme() == Some(&http::uri::Scheme::HTTPS) || p.force_connect {
444 let host = host.to_owned();
445 let port = uri.port_u16().unwrap_or(443);
446 let tunnel = tunnel::new(&host, port, &p.headers);
447 let connection =
448 proxy_dst(&uri, &p.uri).map(|proxy_url| self.connector.call(proxy_url));
449 let tls = if uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
450 self.tls.clone()
451 } else {
452 None
453 };
454
455 Box::pin(async move {
456 loop {
457 let proxy_stream = mtry!(mtry!(connection).await.map_err(io_err));
459 let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);
460
461 break match tls {
462 #[cfg(feature = "tls")]
463 Some(tls) => {
464 let tls = TlsConnector::from(tls);
465 let secure_stream =
466 mtry!(tls.connect(&host, tunnel_stream).await.map_err(io_err));
467
468 Ok(ProxyStream::Secured(secure_stream))
469 }
470
471 #[cfg(feature = "rustls-base")]
472 Some(tls) => {
473 let dnsref =
474 mtry!(DNSNameRef::try_from_ascii_str(&host).map_err(io_err));
475 let tls = TlsConnector::from(tls);
476 let secure_stream =
477 mtry!(tls.connect(dnsref, tunnel_stream).await.map_err(io_err));
478
479 Ok(ProxyStream::Secured(secure_stream))
480 }
481
482 #[cfg(feature = "openssl-tls")]
483 Some(tls) => {
484 let config = tls.configure().map_err(io_err)?;
485 let ssl = config.into_ssl(&host).map_err(io_err)?;
486
487 let mut stream = mtry!(SslStream::new(ssl, tunnel_stream));
488 mtry!(Pin::new(&mut stream).connect().await.map_err(io_err));
489
490 Ok(ProxyStream::Secured(stream))
491 }
492
493 #[cfg(not(any(
494 feature = "tls",
495 feature = "rustls-base",
496 feature = "openssl-tls"
497 )))]
498 Some(_) => panic!("hyper-proxy was not built with TLS support"),
499
500 None => Ok(ProxyStream::Regular(tunnel_stream)),
501 };
502 }
503 })
504 } else {
505 match proxy_dst(&uri, &p.uri) {
506 Ok(proxy_uri) => Box::pin(
507 self.connector
508 .call(proxy_uri)
509 .map_ok(ProxyStream::Regular)
510 .map_err(|err| io_err(err.into())),
511 ),
512 Err(err) => Box::pin(futures::future::err(io_err(err))),
513 }
514 }
515 } else {
516 Box::pin(
517 self.connector
518 .call(uri)
519 .map_ok(ProxyStream::NoProxy)
520 .map_err(|err| io_err(err.into())),
521 )
522 }
523 }
524}
525
526fn proxy_dst(dst: &Uri, proxy: &Uri) -> io::Result<Uri> {
527 Uri::builder()
528 .scheme(
529 proxy
530 .scheme_str()
531 .ok_or_else(|| io_err(format!("proxy uri missing scheme: {}", proxy)))?,
532 )
533 .authority(
534 proxy
535 .authority()
536 .ok_or_else(|| io_err(format!("proxy uri missing host: {}", proxy)))?
537 .clone(),
538 )
539 .path_and_query(dst.path_and_query().unwrap().clone())
540 .build()
541 .map_err(|err| io_err(format!("other error: {}", err)))
542}