Skip to main content

rigetti_hyper_proxy/
lib.rs

1//! A Proxy Connector crate for Hyper based applications
2//!
3//! # Example
4//! ```rust,no_run
5//! use hyper::{Request, Uri, body::Body};
6//! use hyper_util::client::legacy::Client;
7//! use hyper_util::client::legacy::connect::HttpConnector;
8//! use hyper_util::rt::TokioExecutor;
9//! use bytes::Bytes;
10//! use futures_util::{TryFutureExt, TryStreamExt};
11//! use http_body_util::{BodyExt, Empty};
12//! use rigetti_hyper_proxy::{Proxy, ProxyConnector, Intercept};
13//! use headers::Authorization;
14//! use std::error::Error;
15//! use tokio::io::{stdout, AsyncWriteExt as _};
16//!
17//! #[tokio::main]
18//! async fn main() -> Result<(), Box<dyn Error>> {
//!     let proxy = {
//!         let proxy_uri = "http://my-proxy:8080".parse().unwrap();
//!         let mut proxy = Proxy::new(Intercept::All, proxy_uri);
//!         proxy.set_authorization(Authorization::basic("John Doe", "Agent1234"));
//!         let connector = HttpConnector::new();
//!         # #[cfg(not(feature = "__tls"))]
//!         # let proxy_connector = ProxyConnector::from_proxy_unsecured(connector, proxy);
//!         # #[cfg(feature = "__tls")]
//!         let proxy_connector = ProxyConnector::from_proxy(connector, proxy).unwrap();
//!         proxy_connector
//!     };
30//!
31//!     // Connecting to http will trigger regular GETs and POSTs.
32//!     // We need to manually append the relevant headers to the request
33//!     let uri: Uri = "http://my-remote-website.com".parse().unwrap();
34//!     let mut req = Request::get(uri.clone()).body(Empty::<Bytes>::new()).unwrap();
35//!
36//!     if let Some(headers) = proxy.http_headers(&uri) {
37//!         req.headers_mut().extend(headers.clone().into_iter());
38//!     }
39//!
40//!     let client = Client::builder(TokioExecutor::new()).build(proxy);
41//!     let mut resp = client.request(req).await?;
42//!     println!("Response: {}", resp.status());
//!     let body = resp.body_mut().collect().await?.to_bytes();
//!     stdout().write_all(&body).await?;
46//!
47//!     // Connecting to an https uri is straightforward (uses 'CONNECT' method underneath)
//!     let uri = "https://my-remote-website-secured.com".parse().unwrap();
//!     let mut resp = client.get(uri).await?;
//!     println!("Response: {}", resp.status());
//!     let body = resp.body_mut().collect().await?.to_bytes();
//!     stdout().write_all(&body).await?;
54//!
55//!     Ok(())
56//! }
57//! ```
58
59#![allow(missing_docs)]
60
61mod rt;
62mod stream;
63mod tunnel;
64
65use std::{fmt, io, sync::Arc};
66use std::{
67    future::Future,
68    pin::Pin,
69    task::{Context, Poll},
70};
71
72use futures_util::future::TryFutureExt;
73use headers::{authorization::Credentials, Authorization, HeaderMapExt, ProxyAuthorization};
74use http::header::{HeaderMap, HeaderName, HeaderValue};
75use hyper::rt::{Read, Write};
76use hyper::Uri;
77use tower_service::Service;
78
79pub use stream::ProxyStream;
80
81#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
82use native_tls::TlsConnector as NativeTlsConnector;
83
84#[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
85use tokio_native_tls::TlsConnector;
86
87#[cfg(feature = "__rustls")]
88use hyper_rustls::ConfigBuilderExt;
89
90#[cfg(feature = "__rustls")]
91use tokio_rustls::TlsConnector;
92
93#[cfg(feature = "__rustls")]
94use tokio_rustls::rustls::pki_types::ServerName;
95
/// Boxed, thread-safe error type that the wrapped connector's errors must convert into.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
97
/// Selects which outgoing connections are routed through a given proxy.
#[derive(Debug, Clone)]
pub enum Intercept {
    /// Every connection goes through the proxy, whatever its scheme.
    All,
    /// Only connections whose scheme is `http` go through the proxy.
    Http,
    /// Only connections whose scheme is `https` go through the proxy.
    Https,
    /// No connection goes through the proxy.
    None,
    /// Connections for which the wrapped callback returns `true` go through the proxy.
    Custom(Custom),
}
112
/// A connection destination that an [`Intercept`] can be matched against.
///
/// Implemented for `Uri` in this file; other destination types can implement it
/// to be usable with [`Intercept::matches`].
pub trait Dst {
    /// Returns the connection scheme, e.g. "http" or "https", if there is one.
    fn scheme(&self) -> Option<&str>;
    /// Returns the host of the connection, if there is one.
    fn host(&self) -> Option<&str>;
    /// Returns the explicit port of the connection, if there is one.
    fn port(&self) -> Option<u16>;
}
122
123impl Dst for Uri {
124    fn scheme(&self) -> Option<&str> {
125        self.scheme_str()
126    }
127
128    fn host(&self) -> Option<&str> {
129        self.host()
130    }
131
132    fn port(&self) -> Option<u16> {
133        self.port_u16()
134    }
135}
136
/// Wraps any boxable error into an `io::Error` of kind `Other`.
///
/// Used throughout the crate to funnel connector, TLS and URI-building
/// failures into the single `io::Error` type exposed by the connector.
#[inline]
pub(crate) fn io_err<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    io::Error::new(io::ErrorKind::Other, e)
}
141
/// Signature of a custom matching predicate.
///
/// The callback receives the destination's scheme, host and port (each
/// optional) and returns `true` when the connection should be proxied.
pub type CustomProxyCallback =
    dyn Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync;

/// A shareable, clonable wrapper around a custom matching predicate.
#[derive(Clone)]
pub struct Custom(Arc<CustomProxyCallback>);

impl fmt::Debug for Custom {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The wrapped closure has no printable form; emit a placeholder.
        f.write_str("_")
    }
}

impl<F> From<F> for Custom
where
    F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static,
{
    fn from(f: F) -> Custom {
        Self(Arc::new(f))
    }
}
162
163impl Intercept {
164    /// A function to check if given `Uri` is proxied
165    pub fn matches<D: Dst>(&self, uri: &D) -> bool {
166        match (self, uri.scheme()) {
167            (&Intercept::All, _)
168            | (&Intercept::Http, Some("http"))
169            | (&Intercept::Https, Some("https")) => true,
170            (&Intercept::Custom(Custom(ref f)), _) => f(uri.scheme(), uri.host(), uri.port()),
171            _ => false,
172        }
173    }
174}
175
176impl<F: Fn(Option<&str>, Option<&str>, Option<u16>) -> bool + Send + Sync + 'static> From<F>
177    for Intercept
178{
179    fn from(f: F) -> Intercept {
180        Intercept::Custom(f.into())
181    }
182}
183
/// A single proxy: where it lives, which connections it applies to, and
/// which headers must accompany requests sent through it.
#[derive(Clone, Debug)]
pub struct Proxy {
    // Decides which destinations are routed through this proxy.
    intercept: Intercept,
    // When set, the connector tunnels through this proxy even for
    // destinations that are not https.
    force_connect: bool,
    // Headers to send to the proxy (authorization and any custom headers).
    headers: HeaderMap,
    // Address of the proxy itself.
    uri: Uri,
}
192
193impl Proxy {
194    /// Create a new `Proxy`
195    pub fn new<I: Into<Intercept>>(intercept: I, uri: Uri) -> Proxy {
196        let mut proxy = Proxy {
197            intercept: intercept.into(),
198            uri: uri.clone(),
199            headers: HeaderMap::new(),
200            force_connect: false,
201        };
202
203        if let Some((user, pass)) = extract_user_pass(&uri) {
204            proxy.set_authorization(Authorization::basic(user, pass));
205        }
206
207        proxy
208    }
209
210    /// Set `Proxy` authorization
211    pub fn set_authorization<C: Credentials + Clone>(&mut self, credentials: Authorization<C>) {
212        match self.intercept {
213            Intercept::Http => {
214                self.headers.typed_insert(Authorization(credentials.0));
215            }
216            Intercept::Https => {
217                self.headers.typed_insert(ProxyAuthorization(credentials.0));
218            }
219            _ => {
220                self.headers
221                    .typed_insert(Authorization(credentials.0.clone()));
222                self.headers.typed_insert(ProxyAuthorization(credentials.0));
223            }
224        }
225    }
226
227    /// Forces the use of the CONNECT method.
228    pub fn force_connect(&mut self) {
229        self.force_connect = true;
230    }
231
232    /// Set a custom header
233    pub fn set_header(&mut self, name: HeaderName, value: HeaderValue) {
234        self.headers.insert(name, value);
235    }
236
237    /// Get current intercept
238    pub fn intercept(&self) -> &Intercept {
239        &self.intercept
240    }
241
242    /// Get current `Headers` which must be sent to proxy
243    pub fn headers(&self) -> &HeaderMap {
244        &self.headers
245    }
246
247    /// Get proxy uri
248    pub fn uri(&self) -> &Uri {
249        &self.uri
250    }
251}
252
/// A wrapper around `Proxy`s with a connector.
///
/// The type of the `tls` field depends on the enabled TLS backend; without
/// TLS support it is a unit placeholder so the rest of the code can still
/// refer to `self.tls`.
#[derive(Clone)]
pub struct ProxyConnector<C> {
    // Candidate proxies; the first one whose intercept matches a destination is used.
    proxies: Vec<Proxy>,
    // The underlying connector used to open the actual connections.
    connector: C,

    // TLS connector (native-tls backend) applied to tunneled https destinations.
    #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
    tls: Option<NativeTlsConnector>,

    // TLS connector (rustls backend) applied to tunneled https destinations.
    #[cfg(feature = "__rustls")]
    tls: Option<TlsConnector>,

    // Placeholder used when the crate is built without TLS support.
    #[cfg(not(feature = "__tls"))]
    tls: Option<()>,
}
268
269impl<C: fmt::Debug> fmt::Debug for ProxyConnector<C> {
270    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
271        write!(
272            f,
273            "ProxyConnector {}{{ proxies: {:?}, connector: {:?} }}",
274            if self.tls.is_some() {
275                ""
276            } else {
277                "(unsecured)"
278            },
279            self.proxies,
280            self.connector
281        )
282    }
283}
284
285impl<C> ProxyConnector<C> {
286    /// Create a new secured Proxies
287    #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
288    pub fn new(connector: C) -> Result<Self, io::Error> {
289        let tls = NativeTlsConnector::builder()
290            .build()
291            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
292
293        Ok(ProxyConnector {
294            proxies: Vec::new(),
295            connector: connector,
296            tls: Some(tls),
297        })
298    }
299
300    /// Create a new secured Proxies
301    #[cfg(feature = "__rustls")]
302    pub fn new(connector: C) -> Result<Self, io::Error> {
303        let config = tokio_rustls::rustls::ClientConfig::builder();
304
305        #[cfg(all(
306            feature = "rustls-tls-native-roots",
307            not(feature = "rustls-tls-webpki-roots")
308        ))]
309        let config = config.with_native_roots()?;
310
311        #[cfg(feature = "rustls-tls-webpki-roots")]
312        let config = config.with_webpki_roots();
313
314        let cfg = Arc::new(config.with_no_client_auth());
315        let tls = TlsConnector::from(cfg);
316
317        Ok(ProxyConnector {
318            proxies: Vec::new(),
319            connector,
320            tls: Some(tls),
321        })
322    }
323
324    /// Create a new unsecured Proxy
325    pub fn unsecured(connector: C) -> Self {
326        ProxyConnector {
327            proxies: Vec::new(),
328            connector,
329            tls: None,
330        }
331    }
332
333    /// Create a proxy connector and attach a particular proxy
334    #[cfg(feature = "__tls")]
335    pub fn from_proxy(connector: C, proxy: Proxy) -> Result<Self, io::Error> {
336        let mut c = ProxyConnector::new(connector)?;
337        c.proxies.push(proxy);
338        Ok(c)
339    }
340
341    /// Create a proxy connector and attach a particular proxy
342    pub fn from_proxy_unsecured(connector: C, proxy: Proxy) -> Self {
343        let mut c = ProxyConnector::unsecured(connector);
344        c.proxies.push(proxy);
345        c
346    }
347
348    /// Change proxy connector
349    pub fn with_connector<CC>(self, connector: CC) -> ProxyConnector<CC> {
350        ProxyConnector {
351            connector,
352            proxies: self.proxies,
353            tls: self.tls,
354        }
355    }
356
357    /// Set or unset tls when tunneling
358    #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
359    pub fn set_tls(&mut self, tls: Option<NativeTlsConnector>) {
360        self.tls = tls;
361    }
362
363    /// Set or unset tls when tunneling
364    #[cfg(feature = "__rustls")]
365    pub fn set_tls(&mut self, tls: Option<TlsConnector>) {
366        self.tls = tls;
367    }
368
369    /// Get the current proxies
370    pub fn proxies(&self) -> &[Proxy] {
371        &self.proxies
372    }
373
374    /// Add a new additional proxy
375    pub fn add_proxy(&mut self, proxy: Proxy) {
376        self.proxies.push(proxy);
377    }
378
379    /// Extend the list of proxies
380    pub fn extend_proxies<I: IntoIterator<Item = Proxy>>(&mut self, proxies: I) {
381        self.proxies.extend(proxies)
382    }
383
384    /// Get http headers for a matching uri
385    ///
386    /// These headers must be appended to the hyper Request for the proxy to work properly.
387    /// This is needed only for http requests.
388    pub fn http_headers(&self, uri: &Uri) -> Option<&HeaderMap> {
389        if uri.scheme_str() != Some("http") {
390            return None;
391        }
392
393        self.match_proxy(uri).map(|p| &p.headers)
394    }
395
396    fn match_proxy<D: Dst>(&self, uri: &D) -> Option<&Proxy> {
397        self.proxies.iter().find(|p| p.intercept.matches(uri))
398    }
399}
400
// Poor man's `?` for use inside a `loop` that stands in for a `try` block:
// yields the `Ok` value, or `break`s out of the enclosing loop with the
// converted error.
macro_rules! mtry {
    ($result:expr) => {
        match $result {
            Err(err) => break Err(err.into()),
            Ok(value) => value,
        }
    };
}
409
/// Connector service: resolves a destination `Uri` to a [`ProxyStream`],
/// going through the first matching proxy when there is one.
impl<C> Service<Uri> for ProxyConnector<C>
where
    C: Service<Uri>,
    C::Response: Read + Write + Send + Unpin + 'static,
    C::Future: Send + 'static,
    C::Error: Into<BoxError>,
{
    type Response = ProxyStream<C::Response>;
    type Error = io::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated to the inner connector; its error is wrapped
    /// into an `io::Error`.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        match self.connector.poll_ready(cx) {
            Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
            Poll::Ready(Err(e)) => Poll::Ready(Err(io_err(e.into()))),
            Poll::Pending => Poll::Pending,
        }
    }

    /// Opens a connection for `uri`.
    ///
    /// * No matching proxy, or `uri` has no host: connect directly and
    ///   return `ProxyStream::NoProxy`.
    /// * Matching proxy and the destination is https or the proxy has
    ///   `force_connect` set: connect to the proxy, establish a tunnel to
    ///   the destination and, for https destinations with a TLS connector
    ///   configured, wrap the tunnel in TLS (`ProxyStream::Secured`);
    ///   otherwise return the raw tunnel (`ProxyStream::Regular`).
    /// * Matching proxy otherwise: connect to the proxy and return the
    ///   stream as `ProxyStream::Regular`.
    fn call(&mut self, uri: Uri) -> Self::Future {
        if let (Some(p), Some(host)) = (self.match_proxy(&uri), uri.host()) {
            if uri.scheme() == Some(&http::uri::Scheme::HTTPS) || p.force_connect {
                let host = host.to_owned();
                // Without an explicit port, fall back to 80 for http and 443 otherwise.
                let port =
                    uri.port_u16()
                        .unwrap_or(if uri.scheme() == Some(&http::uri::Scheme::HTTP) {
                            80
                        } else {
                            443
                        });

                // NOTE(review): `tunnel::new` presumably prepares the CONNECT
                // handshake for host:port with the proxy headers; confirm in
                // the `tunnel` module.
                let tunnel = tunnel::new(&host, port, &p.headers);
                // Start connecting to the proxy now; a URI-building error is
                // kept and surfaced inside the returned future.
                let connection =
                    proxy_dst(&uri, &p.uri).map(|proxy_url| self.connector.call(proxy_url));
                // TLS is only layered on top of the tunnel for https destinations.
                let tls = if uri.scheme() == Some(&http::uri::Scheme::HTTPS) {
                    self.tls.clone()
                } else {
                    None
                };

                Box::pin(async move {
                    // This single-iteration loop emulates a `try` block (via
                    // `mtry!`); it can go away once `try_blocks` is stabilized.
                    #[allow(clippy::never_loop)]
                    loop {
                        let proxy_stream = mtry!(mtry!(connection).await.map_err(io_err));
                        let tunnel_stream = mtry!(tunnel.with_stream(proxy_stream).await);

                        break match tls {
                            #[cfg(all(not(feature = "__rustls"), feature = "native-tls"))]
                            Some(tls) => {
                                use hyper_util::rt::TokioIo;
                                let tls = TlsConnector::from(tls);
                                let secure_stream = mtry!(tls
                                    .connect(&host, TokioIo::new(tunnel_stream))
                                    .await
                                    .map_err(io_err));

                                Ok(ProxyStream::Secured(Box::new(TokioIo::new(secure_stream))))
                            }

                            #[cfg(feature = "__rustls")]
                            Some(tls) => {
                                use hyper_util::rt::TokioIo;
                                let server_name =
                                    mtry!(ServerName::try_from(host.to_string()).map_err(io_err));
                                let secure_stream = mtry!(tls
                                    .connect(server_name, TokioIo::new(tunnel_stream))
                                    .await
                                    .map_err(io_err));

                                Ok(ProxyStream::Secured(Box::new(TokioIo::new(secure_stream))))
                            }

                            // Without TLS support `tls` has type `Option<()>`;
                            // this arm panics if it is ever `Some`.
                            #[cfg(not(feature = "__tls",))]
                            Some(_) => panic!("hyper-proxy was not built with TLS support"),

                            None => Ok(ProxyStream::Regular(tunnel_stream)),
                        };
                    }
                })
            } else {
                // No tunnel: connect to the proxy and hand its stream back as is.
                match proxy_dst(&uri, &p.uri) {
                    Ok(proxy_uri) => Box::pin(
                        self.connector
                            .call(proxy_uri)
                            .map_ok(ProxyStream::Regular)
                            .map_err(|err| io_err(err.into())),
                    ),
                    Err(err) => Box::pin(futures_util::future::err(io_err(err))),
                }
            }
        } else {
            // No proxy applies: connect straight to the destination.
            Box::pin(
                self.connector
                    .call(uri)
                    .map_ok(ProxyStream::NoProxy)
                    .map_err(|err| io_err(err.into())),
            )
        }
    }
}
511
512fn proxy_dst(dst: &Uri, proxy: &Uri) -> io::Result<Uri> {
513    Uri::builder()
514        .scheme(
515            proxy
516                .scheme_str()
517                .ok_or_else(|| io_err(format!("proxy uri missing scheme: {}", proxy)))?,
518        )
519        .authority(
520            proxy
521                .authority()
522                .ok_or_else(|| io_err(format!("proxy uri missing host: {}", proxy)))?
523                .clone(),
524        )
525        .path_and_query(
526            dst.path_and_query()
527                .ok_or_else(|| io_err(format!("dst uri missing path: {}", proxy)))?
528                .clone(),
529        )
530        .build()
531        .map_err(|err| io_err(format!("other error: {}", err)))
532}
533
534/// Extracts the username and password from the URI
535fn extract_user_pass(uri: &Uri) -> Option<(&str, &str)> {
536    let authority = uri.authority()?.as_str();
537    let (userinfo, _) = authority.rsplit_once('@')?;
538    let (username, password) = userinfo.split_once(':')?;
539
540    Some((username, password))
541}
542
#[cfg(test)]
mod tests {
    use http::Uri;

    use crate::{Intercept, Proxy};

    /// Credentials embedded in the proxy URI become a basic authorization header.
    #[test]
    fn test_new_proxy_with_authorization() {
        let uri = Uri::from_static("https://bob:secret@my-proxy:8080");
        let proxy = Proxy::new(Intercept::All, uri);

        let header = proxy.headers().get("authorization").unwrap();
        assert_eq!(header.to_str().unwrap(), "Basic Ym9iOnNlY3JldA==");
    }

    /// A proxy URI without userinfo leaves the authorization header unset.
    #[test]
    fn test_new_proxy_without_authorization() {
        let uri = Uri::from_static("https://my-proxy:8080");
        let proxy = Proxy::new(Intercept::All, uri);

        assert!(proxy.headers().get("authorization").is_none());
    }
}