hyper_proxy/
lib.rs

1//! A proxy connector crate for Hyper-based applications.
2//!
3//! # Example
4//! ```rust,no_run
5//! use hyper::{Client, Request, Uri, body::HttpBody};
6//! use hyper::client::HttpConnector;
7//! use futures::{TryFutureExt, TryStreamExt};
8//! use hyper_proxy::{Proxy, ProxyConnector, Intercept};
9//! use headers::Authorization;
10//! use std::error::Error;
11//! use tokio::io::{stdout, AsyncWriteExt as _};
12//!
13//! #[tokio::main]
14//! async fn main() -> Result<(), Box<dyn Error>> {
15//!     let proxy = {
16//!         let proxy_uri = "http://my-proxy:8080".parse().unwrap();
17//!         let mut proxy = Proxy::new(Intercept::All, proxy_uri);
18//!         proxy.set_authorization(Authorization::basic("John Doe", "Agent1234"));
19//!         let connector = HttpConnector::new();
20//!         # #[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
21//!         # let proxy_connector = ProxyConnector::from_proxy_unsecured(connector, proxy);
22//!         # #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
23//!         let proxy_connector = ProxyConnector::from_proxy(connector, proxy).unwrap();
24//!         proxy_connector
25//!     };
26//!
27//!     // Plain http requests are sent to the proxy as regular GETs and POSTs,
28//!     // so the proxy's headers must be appended to each request manually.
29//!     let uri: Uri = "http://my-remote-website.com".parse().unwrap();
30//!     let mut req = Request::get(uri.clone()).body(hyper::Body::empty()).unwrap();
31//!
32//!     if let Some(headers) = proxy.http_headers(&uri) {
33//!         req.headers_mut().extend(headers.clone().into_iter());
34//!     }
35//!
36//!     let client = Client::builder().build(proxy);
37//!     let mut resp = client.request(req).await?;
38//!     println!("Response: {}", resp.status());
39//!     while let Some(chunk) = resp.body_mut().data().await {
40//!         stdout().write_all(&chunk?).await?;
41//!     }
42//!
43//!     // Connecting to an https uri is straightforward (uses 'CONNECT' method underneath)
44//!     let uri = "https://my-remote-website-secured.com".parse().unwrap();
45//!     let mut resp = client.get(uri).await?;
46//!     println!("Response: {}", resp.status());
47//!     while let Some(chunk) = resp.body_mut().data().await {
48//!         stdout().write_all(&chunk?).await?;
49//!     }
50//!
51//!     Ok(())
52//! }
53//! ```
54
55#![allow(missing_docs)]
56
57mod stream;
58mod tunnel;
59
60use http::header::{HeaderMap, HeaderName, HeaderValue};
61use hyper::{service::Service, Uri};
62
63use futures::future::TryFutureExt;
64use std::{fmt, io, sync::Arc};
65use std::{
66    future::Future,
67    pin::Pin,
68    task::{Context, Poll},
69};
70
71pub use stream::ProxyStream;
72use tokio::io::{AsyncRead, AsyncWrite};
73
74#[cfg(feature = "tls")]
75use native_tls::TlsConnector as NativeTlsConnector;
76
77#[cfg(feature = "tls")]
78use tokio_native_tls::TlsConnector;
79#[cfg(feature = "rustls-base")]
80use tokio_rustls::TlsConnector;
81
82use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
83#[cfg(feature = "openssl-tls")]
84use openssl::ssl::{SslConnector as OpenSslConnector, SslMethod};
85#[cfg(feature = "openssl-tls")]
86use tokio_openssl::SslStream;
87#[cfg(feature = "rustls-base")]
88use webpki::DNSNameRef;
89
/// Boxed, thread-safe error used to erase the inner connector's error type.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
91
92/// The Intercept enum to filter connections
93#[derive(Debug, Clone)]
94pub enum Intercept {
95    /// All incoming connection will go through proxy
96    All,
97    /// Only http connections will go through proxy
98    Http,
99    /// Only https connections will go through proxy
100    Https,
101    /// No connection will go through this proxy
102    None,
103    /// A custom intercept
104    Custom(Custom),
105}
106
/// A connection destination that an `Intercept` can inspect.
///
/// Each component may be absent, so every accessor returns an `Option`.
pub trait Dst {
    /// The destination scheme, such as `"http"` or `"https"`, if any.
    fn scheme(&self) -> Option<&str>;
    /// The destination host, if any.
    fn host(&self) -> Option<&str>;
    /// The destination port, if any.
    fn port(&self) -> Option<u16>;
}
116
117impl Dst for Uri {
118    fn scheme(&self) -> Option<&str> {
119        self.scheme_str()
120    }
121
122    fn host(&self) -> Option<&str> {
123        self.host()
124    }
125
126    fn port(&self) -> Option<u16> {
127        self.port_u16()
128    }
129}
130
/// Wraps any error convertible into a boxed `Error + Send + Sync` into an
/// `io::Error` of kind `Other`.
#[inline]
pub(crate) fn io_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let kind = io::ErrorKind::Other;
    io::Error::new(kind, e)
}
135
/// A user-defined predicate deciding whether a destination is proxied.
///
/// The predicate receives the destination's scheme, host and port. It is
/// stored behind an `Arc`, so cloning a `Custom` shares the same closure.
#[derive(Clone)]
pub struct Custom(Arc<dyn Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync>);

impl fmt::Debug for Custom {
    // The closure itself cannot be printed, so a placeholder is emitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        f.write_str("_")
    }
}

impl<F> From<F> for Custom
where
    F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static,
{
    fn from(f: F) -> Custom {
        let shared = Arc::new(f);
        Custom(shared)
    }
}
153
154impl Intercept {
155    /// A function to check if given `Uri` is proxied
156    pub fn matches<D: Dst>(&self, uri: &D) -> bool {
157        match (self, uri.scheme()) {
158            (&Intercept::All, _)
159            | (&Intercept::Http, Some("http"))
160            | (&Intercept::Https, Some("https")) => true,
161            (&Intercept::Custom(Custom(ref f)), _) => f(uri.scheme(), uri.host(), uri.port()),
162            _ => false,
163        }
164    }
165}
166
167impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
168    for Intercept
169{
170    fn from(f: F) -> Intercept {
171        Intercept::Custom(f.into())
172    }
173}
174
175/// A Proxy strcut
176#[derive(Clone, Debug)]
177pub struct Proxy {
178    intercept: Intercept,
179    force_connect: bool,
180    headers: HeaderMap,
181    uri: Uri,
182}
183
184impl Proxy {
185    /// Create a new `Proxy`
186    pub fn new<I: Into<Intercept>>(intercept: I, uri: Uri) -> Proxy {
187        Proxy {
188            intercept: intercept.into(),
189            uri: uri,
190            headers: HeaderMap::new(),
191            force_connect: false,
192        }
193    }
194
195    /// Set `Proxy` authorization
196    pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
197        match self.intercept {
198            Intercept::Http => {
199                self.headers.typed_insert(Authorization(credentials.0));
200            }
201            Intercept::Https => {
202                self.headers.typed_insert(ProxyAuthorization(credentials.0));
203            }
204            _ => {
205                self.headers
206                    .typed_insert(Authorization(credentials.0.clone()));
207                self.headers.typed_insert(ProxyAuthorization(credentials.0));
208            }
209        }
210    }
211
212    /// Forces the use of the CONNECT method.
213    pub fn force_connect(&mut self) {
214        self.force_connect = true;
215    }
216
217    /// Set a custom header
218    pub fn set_header(&mut self, name: HeaderName, value: HeaderValue) {
219        self.headers.insert(name, value);
220    }
221
222    /// Get current intercept
223    pub fn intercept(&self) -> &Intercept {
224        &self.intercept
225    }
226
227    /// Get current `Headers` which must be sent to proxy
228    pub fn headers(&self) -> &HeaderMap {
229        &self.headers
230    }
231
232    /// Get proxy uri
233    pub fn uri(&self) -> &Uri {
234        &self.uri
235    }
236}
237
238/// A wrapper around `Proxy`s with a connector.
239#[derive(Clone)]
240pub struct ProxyConnector<C> {
241    proxies: Vec<Proxy>,
242    connector: C,
243
244    #[cfg(feature = "tls")]
245    tls: Option<NativeTlsConnector>,
246
247    #[cfg(feature = "rustls-base")]
248    tls: Option<TlsConnector>,
249
250    #[cfg(feature = "openssl-tls")]
251    tls: Option<OpenSslConnector>,
252
253    #[cfg(not(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls")))]
254    tls: Option<()>,
255}
256
257impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {
258    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
259        write!(
260            f,
261            "ProxyConnector {}{{ proxies: {:?}, connector: {:?} }}",
262            if self.tls.is_some() {
263                ""
264            } else {
265                "(unsecured)"
266            },
267            self.proxies,
268            self.connector
269        )
270    }
271}
272
273impl<C> ProxyConnector<C> {
274    /// Create a new secured Proxies
275    #[cfg(feature = "tls")]
276    pub fn new(connector: C) -> Result<Self, io::Error> {
277        let tls = NativeTlsConnector::builder()
278            .build()
279            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
280
281        Ok(ProxyConnector {
282            proxies: Vec::new(),
283            connector: connector,
284            tls: Some(tls),
285        })
286    }
287
288    /// Create a new secured Proxies
289    #[cfg(feature = "rustls-base")]
290    pub fn new(connector: C) -> Result<Self, io::Error> {
291        let mut config = tokio_rustls::rustls::ClientConfig::new();
292
293        #[cfg(feature = "rustls")]
294        {
295            config.root_store =
296                rustls_native_certs::load_native_certs().map_err(|(_store, io)| io)?;
297        }
298
299        #[cfg(feature = "rustls-webpki")]
300        {
301            config
302                .root_store
303                .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
304        }
305
306        let cfg = Arc::new(config);
307        let tls = TlsConnector::from(cfg);
308
309        Ok(ProxyConnector {
310            proxies: Vec::new(),
311            connector: connector,
312            tls: Some(tls),
313        })
314    }
315
316    #[allow(missing_docs)]
317    #[cfg(feature = "openssl-tls")]
318    pub fn new(connector: C) -> Result<Self, io::Error> {
319        let builder = OpenSslConnector::builder(SslMethod::tls())
320            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
321        let tls = builder.build();
322
323        Ok(ProxyConnector {
324            proxies: Vec::new(),
325            connector: connector,
326            tls: Some(tls),
327        })
328    }
329
330    /// Create a new unsecured Proxy
331    pub fn unsecured(connector: C) -> Self {
332        ProxyConnector {
333            proxies: Vec::new(),
334            connector: connector,
335            tls: None,
336        }
337    }
338
339    /// Create a proxy connector and attach a particular proxy
340    #[cfg(any(feature = "tls", feature = "rustls-base", feature = "openssl-tls"))]
341    pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
342        let mut c = ProxyConnector::new(connector)?;
343        c.proxies.push(proxy);
344        Ok(c)
345    }
346
347    /// Create a proxy connector and attach a particular proxy
348    pub fn from_proxy_unsecured(connector: C, proxy: Proxy) -> Self {
349        let mut c = ProxyConnector::unsecured(connector);
350        c.proxies.push(proxy);
351        c
352    }
353
354    /// Change proxy connector
355    pub fn with_connector<CC>(self, connector: CC) -> ProxyConnector<CC> {
356        ProxyConnector {
357            connector: connector,
358            proxies: self.proxies,
359            tls: self.tls,
360        }
361    }
362
363    /// Set or unset tls when tunneling
364    #[cfg(any(feature = "tls"))]
365    pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
366        self.tls = tls;
367    }
368
369    /// Set or unset tls when tunneling
370    #[cfg(any(feature = "rustls-base"))]
371    pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
372        self.tls = tls;
373    }
374
375    /// Set or unset tls when tunneling
376    #[cfg(any(feature = "openssl-tls"))]
377    pub fn set_tls(&mut self, tls: Option<OpenSslConnector>) {
378        self.tls = tls;
379    }
380
381    /// Get the current proxies
382    pub fn proxies(&self) -> &[Proxy] {
383        &self.proxies
384    }
385
386    /// Add a new additional proxy
387    pub fn add_proxy(&mut self, proxy: Proxy) {
388        self.proxies.push(proxy);
389    }
390
391    /// Extend the list of proxies
392    pub fn extend_proxies<I: IntoIterator<Item = Proxy>>(&mut self, proxies: I) {
393        self.proxies.extend(proxies)
394    }
395
396    /// Get http headers for a matching uri
397    ///
398    /// These headers must be appended to the hyper Request for the proxy to work properly.
399    /// This is needed only for http requests.
400    pub fn http_headers(&self, uri: &Uri) -> Option<&HeaderMap> {
401        if uri.scheme_str().map_or(true, |s| s != "http") {
402            return None;
403        }
404
405        self.match_proxy(uri).map(|p| &p.headers)
406    }
407
408    fn match_proxy<D: Dst>(&self, uri: &D) -> Option<&Proxy> {
409        self.proxies.iter().find(|p| p.intercept.matches(uri))
410    }
411}
412
// Unwraps an `Ok` value, or `break`s out of the enclosing `loop` with the
// error converted through `Into`. This stands in for `?` when a `loop` is
// used as a makeshift `try` block.
macro_rules! mtry {
    ($result:expr) => {
        match $result {
            Err(err) => break Err(err.into()),
            Ok(value) => value,
        }
    };
}
421
422impl<C> Service<Uri> for ProxyConnector<C>
423where
424    C: Service<Uri>,
425    C::Response: AsyncRead + AsyncWrite + Send + Unpin + 'static,
426    C::Future: Send + 'static,
427    C::Error: Into<BoxError>,
428{
429    type Response = ProxyStream<C::Response>;
430    type Error = io::Error;
431    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
432
433    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
434        match self.connector.poll_ready(cx) {
435            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
436            Poll::Ready(Err(e)) => Poll::Ready(Err(io_err(e.into()))),
437            Poll::Pending => Poll::Pending,
438        }
439    }
440
441    fn call(&mut self, uri: Uri) -> Self::Future {
442        if let (Some(p), Some(host)) = (self.match_proxy(&uri), uri.host()) {
443            if uri.scheme() == Some(&http::uri::Scheme::HTTPS) || p.force_connect {
444                let host = host.to_owned();
445                let port = uri.port_u16().unwrap_or(443);
446                let tunnel = tunnel::new(&host, port, &p.headers);
447                let connection =
448                    proxy_dst(&uri, &p.uri).map(|proxy_url| self.connector.call(proxy_url));
449                let tls = if uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
450                    self.tls.clone()
451                } else {
452                    None
453                };
454
455                Box::pin(async move {
456                    loop {
457                        // this hack will gone once `try_blocks` will eventually stabilized
458                        let proxy_stream = mtry!(mtry!(connection).await.map_err(io_err));
459                        let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);
460
461                        break match tls {
462                            #[cfg(feature = "tls")]
463                            Some(tls) => {
464                                let tls = TlsConnector::from(tls);
465                                let secure_stream =
466                                    mtry!(tls.connect(&host, tunnel_stream).await.map_err(io_err));
467
468                                Ok(ProxyStream::Secured(secure_stream))
469                            }
470
471                            #[cfg(feature = "rustls-base")]
472                            Some(tls) => {
473                                let dnsref =
474                                    mtry!(DNSNameRef::try_from_ascii_str(&host).map_err(io_err));
475                                let tls = TlsConnector::from(tls);
476                                let secure_stream =
477                                    mtry!(tls.connect(dnsref, tunnel_stream).await.map_err(io_err));
478
479                                Ok(ProxyStream::Secured(secure_stream))
480                            }
481
482                            #[cfg(feature = "openssl-tls")]
483                            Some(tls) => {
484                                let config = tls.configure().map_err(io_err)?;
485                                let ssl = config.into_ssl(&host).map_err(io_err)?;
486
487                                let mut stream = mtry!(SslStream::new(ssl, tunnel_stream));
488                                mtry!(Pin::new(&mut stream).connect().await.map_err(io_err));
489
490                                Ok(ProxyStream::Secured(stream))
491                            }
492
493                            #[cfg(not(any(
494                                feature = "tls",
495                                feature = "rustls-base",
496                                feature = "openssl-tls"
497                            )))]
498                            Some(_) => panic!("hyper-proxy was not built with TLS support"),
499
500                            None => Ok(ProxyStream::Regular(tunnel_stream)),
501                        };
502                    }
503                })
504            } else {
505                match proxy_dst(&uri, &p.uri) {
506                    Ok(proxy_uri) => Box::pin(
507                        self.connector
508                            .call(proxy_uri)
509                            .map_ok(ProxyStream::Regular)
510                            .map_err(|err| io_err(err.into())),
511                    ),
512                    Err(err) => Box::pin(futures::future::err(io_err(err))),
513                }
514            }
515        } else {
516            Box::pin(
517                self.connector
518                    .call(uri)
519                    .map_ok(ProxyStream::NoProxy)
520                    .map_err(|err| io_err(err.into())),
521            )
522        }
523    }
524}
525
526fn proxy_dst(dst: &Uri, proxy: &Uri) -> io::Result<Uri> {
527    Uri::builder()
528        .scheme(
529            proxy
530                .scheme_str()
531                .ok_or_else(|| io_err(format!("proxy uri missing scheme: {}", proxy)))?,
532        )
533        .authority(
534            proxy
535                .authority()
536                .ok_or_else(|| io_err(format!("proxy uri missing host: {}", proxy)))?
537                .clone(),
538        )
539        .path_and_query(dst.path_and_query().unwrap().clone())
540        .build()
541        .map_err(|err| io_err(format!("other error: {}", err)))
542}