httproxide_hyper_reverse_proxy/
lib.rs

1//!
2//! A simple reverse proxy, to be used with [Hyper].
3//!
4//! The implementation ensures that [Hop-by-hop headers] are stripped correctly in both directions,
5//! and adds the client's IP address to a comma-space-separated list of forwarding addresses in the
6//! `X-Forwarded-For` header.
7//!
8//! The implementation is based on Go's [`httputil.ReverseProxy`].
9//!
10//! [Hyper]: http://hyper.rs/
11//! [Hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
12//! [`httputil.ReverseProxy`]: https://golang.org/pkg/net/http/httputil/#ReverseProxy
13//!
14//! # Example
15//!
16//! Add these dependencies to your `Cargo.toml` file.
17//!
18//! ```toml
19//! [dependencies]
20//! hyper-reverse-proxy = "0.5"
21//! hyper = { version = "0.14", features = ["full"] }
22//! tokio = { version = "1", features = ["full"] }
23//! ```
24//!
25//! To enable support for connecting to downstream HTTPS servers, enable the `https` feature:
26//!
27//! ```toml
//! hyper-reverse-proxy = { version = "0.5", features = ["https"] }
29//! ```
30//!
//! The following example sets up a reverse proxy listening on `127.0.0.1:8000`
//! (it also uses the `lazy_static` and `hyper_trust_dns` crates, which must be added as
//! dependencies) and proxies these calls:
33//!
34//! * `"/target/first"` will be proxied to `http://127.0.0.1:13901`
35//!
36//! * `"/target/second"` will be proxied to `http://127.0.0.1:13902`
37//!
//! * All other URLs will be handled by the `debug_request` function, which displays the request information.
39//!
40//! ```rust,no_run
41//! use hyper::server::conn::AddrStream;
42//! use hyper::service::{make_service_fn, service_fn};
43//! use hyper::{Body, Request, Response, Server, StatusCode};
44//! use hyper_reverse_proxy::ReverseProxy;
45//! use hyper_trust_dns::{RustlsHttpsConnector, TrustDnsResolver};
46//! use std::net::IpAddr;
47//! use std::{convert::Infallible, net::SocketAddr};
48//!
49//! lazy_static::lazy_static! {
//!     static ref PROXY_CLIENT: ReverseProxy<RustlsHttpsConnector, Body> = {
51//!         ReverseProxy::new(
52//!             hyper::Client::builder().build::<_, hyper::Body>(TrustDnsResolver::default().into_rustls_webpki_https_connector()),
53//!         )
54//!     };
55//! }
56//!
57//! fn debug_request(req: &Request<Body>) -> Result<Response<Body>, Infallible> {
58//!     let body_str = format!("{:?}", req);
59//!     Ok(Response::new(Body::from(body_str)))
60//! }
61//!
62//! async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
63//!     if req.uri().path().starts_with("/target/first") {
64//!         match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13901", req)
65//!             .await
66//!         {
67//!             Ok(response) => {
68//!                 Ok(response)
69//!             },
70//!             Err(_error) => {
71//!                 Ok(Response::builder()
72//!                 .status(StatusCode::INTERNAL_SERVER_ERROR)
73//!                 .body(Body::empty())
74//!                 .unwrap())},
75//!         }
76//!     } else if req.uri().path().starts_with("/target/second") {
77//!         match PROXY_CLIENT.call(client_ip, "http://127.0.0.1:13902", req)
78//!             .await
79//!         {
80//!             Ok(response) => Ok(response),
81//!             Err(_error) => Ok(Response::builder()
82//!                 .status(StatusCode::INTERNAL_SERVER_ERROR)
83//!                 .body(Body::empty())
84//!                 .unwrap()),
85//!         }
86//!     } else {
87//!         debug_request(&req)
88//!     }
89//! }
90//!
91//! #[tokio::main]
92//! async fn main() {
93//!     let bind_addr = "127.0.0.1:8000";
94//!     let addr: SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
95//!
96//!     let make_svc = make_service_fn(|conn: &AddrStream| {
97//!         let remote_addr = conn.remote_addr().ip();
98//!         async move { Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req))) }
99//!     });
100//!
101//!     let server = Server::bind(&addr).serve(make_svc);
102//!
103//!     println!("Running server on {:?}", addr);
104//!
105//!     if let Err(e) = server.await {
106//!         eprintln!("server error: {}", e);
107//!     }
108//! }
109//!
110//! ```
111#[macro_use]
112extern crate tracing;
113
114use hyper::body::HttpBody;
115use hyper::client::connect::Connect;
116use hyper::header::{HeaderMap, HeaderName, HeaderValue};
117use hyper::http::header::{InvalidHeaderValue, ToStrError};
118use hyper::http::uri::InvalidUri;
119use hyper::upgrade::OnUpgrade;
120use hyper::{Body, Client, Error, Request, Response, StatusCode};
121use lazy_static::lazy_static;
122use std::net::IpAddr;
123use thiserror::Error as ThisError;
124use tokio::io::copy_bidirectional;
125
126lazy_static! {
127    static ref TE_HEADER: HeaderName = HeaderName::from_static("te");
128    static ref CONNECTION_HEADER: HeaderName = HeaderName::from_static("connection");
129    static ref UPGRADE_HEADER: HeaderName = HeaderName::from_static("upgrade");
130    static ref TRAILER_HEADER: HeaderName = HeaderName::from_static("trailer");
131    static ref TRAILERS_HEADER: HeaderName = HeaderName::from_static("trailers");
132    // A list of the headers, using hypers actual HeaderName comparison
133    static ref HOP_HEADERS: [HeaderName; 9] = [
134        CONNECTION_HEADER.clone(),
135        TE_HEADER.clone(),
136        TRAILER_HEADER.clone(),
137        HeaderName::from_static("keep-alive"),
138        HeaderName::from_static("proxy-connection"),
139        HeaderName::from_static("proxy-authenticate"),
140        HeaderName::from_static("proxy-authorization"),
141        HeaderName::from_static("transfer-encoding"),
142        HeaderName::from_static("upgrade"),
143    ];
144
145    static ref X_FORWARDED_FOR: HeaderName = HeaderName::from_static("x-forwarded-for");
146}
147
148#[derive(Debug, ThisError)]
149pub enum ProxyError {
150    #[error("{0}")]
151    InvalidUri(#[from] InvalidUri),
152    #[error("{0}")]
153    HyperError(#[from] Error),
154    #[error("ForwardHeaderError")]
155    ForwardHeaderError,
156    #[error("UpgradeError: {0}")]
157    UpgradeError(String),
158}
159
160impl From<ToStrError> for ProxyError {
161    fn from(_err: ToStrError) -> ProxyError {
162        ProxyError::ForwardHeaderError
163    }
164}
165
166impl From<InvalidHeaderValue> for ProxyError {
167    fn from(_err: InvalidHeaderValue) -> ProxyError {
168        ProxyError::ForwardHeaderError
169    }
170}
171
172fn remove_hop_headers(headers: &mut HeaderMap) {
173    debug!("Removing hop headers");
174
175    for header in &*HOP_HEADERS {
176        headers.remove(header);
177    }
178}
179
180fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
181    #[allow(clippy::blocks_in_if_conditions)]
182    if headers
183        .get(&*CONNECTION_HEADER)
184        .map(|value| {
185            value
186                .to_str()
187                .unwrap()
188                .split(',')
189                .any(|e| e.trim() == *UPGRADE_HEADER)
190        })
191        .unwrap_or(false)
192    {
193        if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
194            debug!(
195                "Found upgrade header with value: {}",
196                upgrade_value.to_str().unwrap().to_owned()
197            );
198
199            return Some(upgrade_value.to_str().unwrap().to_owned());
200        }
201    }
202
203    None
204}
205
206fn remove_connection_headers(headers: &mut HeaderMap) {
207    if headers.get(&*CONNECTION_HEADER).is_some() {
208        debug!("Removing connection headers");
209
210        let value = headers.get(&*CONNECTION_HEADER).cloned().unwrap();
211
212        for name in value.to_str().unwrap().split(',') {
213            if !name.trim().is_empty() {
214                headers.remove(name.trim());
215            }
216        }
217    }
218}
219
220fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
221    info!("Creating proxied response");
222
223    remove_hop_headers(response.headers_mut());
224    remove_connection_headers(response.headers_mut());
225
226    response
227}
228
229fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> String {
230    debug!("Building forward uri");
231
232    let split_url = forward_url.split('?').collect::<Vec<&str>>();
233
234    let mut base_url: &str = split_url.get(0).unwrap_or(&"");
235    let forward_url_query: &str = split_url.get(1).unwrap_or(&"");
236
237    let path2 = req.uri().path();
238
239    if base_url.ends_with('/') {
240        let mut path1_chars = base_url.chars();
241        path1_chars.next_back();
242
243        base_url = path1_chars.as_str();
244    }
245
246    let total_length = base_url.len()
247        + path2.len()
248        + 1
249        + forward_url_query.len()
250        + req.uri().query().map(|e| e.len()).unwrap_or(0);
251
252    debug!("Creating url with capacity to {}", total_length);
253
254    let mut url = String::with_capacity(total_length);
255
256    url.push_str(base_url);
257    url.push_str(path2);
258
259    if !forward_url_query.is_empty() || req.uri().query().map(|e| !e.is_empty()).unwrap_or(false) {
260        debug!("Adding query parts to url");
261        url.push('?');
262        url.push_str(forward_url_query);
263
264        if forward_url_query.is_empty() {
265            debug!("Using request query");
266
267            url.push_str(req.uri().query().unwrap_or(""));
268        } else {
269            debug!("Merging request and forward_url query");
270
271            let request_query_items = req.uri().query().unwrap_or("").split('&').map(|el| {
272                let parts = el.split('=').collect::<Vec<&str>>();
273                (parts[0], if parts.len() > 1 { parts[1] } else { "" })
274            });
275
276            let forward_query_items = forward_url_query
277                .split('&')
278                .map(|el| {
279                    let parts = el.split('=').collect::<Vec<&str>>();
280                    parts[0]
281                })
282                .collect::<Vec<_>>();
283
284            for (key, value) in request_query_items {
285                if !forward_query_items.iter().any(|e| e == &key) {
286                    url.push('&');
287                    url.push_str(key);
288                    url.push('=');
289                    url.push_str(value);
290                }
291            }
292
293            if url.ends_with('&') {
294                let mut parts = url.chars();
295                parts.next_back();
296
297                url = parts.as_str().to_string();
298            }
299        }
300    }
301
302    debug!("Built forwarding url from request: {}", url);
303
304    url.parse().unwrap()
305}
306
307fn create_proxied_request<B>(
308    client_ip: IpAddr,
309    forward_url: &str,
310    mut request: Request<B>,
311    upgrade_type: Option<&String>,
312) -> Result<Request<B>, ProxyError> {
313    info!("Creating proxied request");
314
315    let contains_te_trailers_value = request
316        .headers()
317        .get(&*TE_HEADER)
318        .map(|value| {
319            value
320                .to_str()
321                .unwrap()
322                .split(',')
323                .any(|e| e.trim() == *TRAILERS_HEADER)
324        })
325        .unwrap_or(false);
326
327    let uri: hyper::Uri = forward_uri(forward_url, &request).parse()?;
328
329    debug!("Setting headers of proxied request");
330
331    *request.uri_mut() = uri;
332
333    remove_hop_headers(request.headers_mut());
334    remove_connection_headers(request.headers_mut());
335
336    if contains_te_trailers_value {
337        debug!("Setting up trailer headers");
338
339        request
340            .headers_mut()
341            .insert(&*TE_HEADER, HeaderValue::from_static("trailers"));
342    }
343
344    if let Some(value) = upgrade_type {
345        debug!("Repopulate upgrade headers");
346
347        request
348            .headers_mut()
349            .insert(&*UPGRADE_HEADER, value.parse().unwrap());
350        request
351            .headers_mut()
352            .insert(&*CONNECTION_HEADER, HeaderValue::from_static("UPGRADE"));
353    }
354
355    // Add forwarding information in the headers
356    match request.headers_mut().entry(&*X_FORWARDED_FOR) {
357        hyper::header::Entry::Vacant(entry) => {
358            debug!("X-Fowraded-for header was vacant");
359            entry.insert(client_ip.to_string().parse()?);
360        }
361
362        hyper::header::Entry::Occupied(entry) => {
363            debug!("X-Fowraded-for header was occupied");
364            let client_ip_str = client_ip.to_string();
365            let mut addr =
366                String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len());
367
368            addr.push_str(std::str::from_utf8(entry.get().as_bytes()).unwrap());
369            addr.push(',');
370            addr.push(' ');
371            addr.push_str(&client_ip_str);
372        }
373    }
374
375    debug!("Created proxied request");
376
377    Ok(request)
378}
379
380pub async fn call<'a, C, B>(
381    client_ip: IpAddr,
382    forward_uri: &str,
383    mut request: Request<B>,
384    client: &'a Client<C, B>,
385) -> Result<Response<Body>, ProxyError>
386where
387    C: Connect + Clone + Send + Sync + 'static,
388    B: HttpBody + Send + 'static,
389    B::Data: Send,
390    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
391{
392    info!(
393        "Received proxy call from {} to {}, client: {}",
394        request.uri().to_string(),
395        forward_uri,
396        client_ip
397    );
398
399    let request_upgrade_type = get_upgrade_type(request.headers());
400    let request_upgraded = request.extensions_mut().remove::<OnUpgrade>();
401
402    let proxied_request = create_proxied_request(
403        client_ip,
404        forward_uri,
405        request,
406        request_upgrade_type.as_ref(),
407    )?;
408    let mut response = client.request(proxied_request).await?;
409
410    if response.status() == StatusCode::SWITCHING_PROTOCOLS {
411        let response_upgrade_type = get_upgrade_type(response.headers());
412
413        if request_upgrade_type != response_upgrade_type {
414            return Err(ProxyError::UpgradeError(format!(
415                "backend tried to switch to protocol {:?} when {:?} was requested",
416                response_upgrade_type, request_upgrade_type
417            )));
418        };
419        let request_upgraded = match request_upgraded {
420            Some(v) => v,
421            None => {
422                return Err(ProxyError::UpgradeError(
423                    "request does not have an upgrade extension".to_string(),
424                ))
425            }
426        };
427        let mut response_upgraded = match response.extensions_mut().remove::<OnUpgrade>() {
428            Some(v) => v.await?,
429            None => {
430                return Err(ProxyError::UpgradeError(
431                    "response does not have an upgrade extension".to_string(),
432                ))
433            }
434        };
435
436        debug!("Responding to a connection upgrade response");
437        tokio::spawn(async move {
438            let mut request_upgraded = match request_upgraded.await {
439                Ok(v) => v,
440                Err(e) => {
441                    warn!("failed to upgrade request: {}", e);
442                    return;
443                }
444            };
445
446            if let Some(err) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
447                .await
448                .err()
449            {
450                if err.kind() != std::io::ErrorKind::UnexpectedEof {
451                    warn!("coping between upgraded connections failed: {}", err);
452                }
453            }
454        });
455
456        Ok(response)
457    } else {
458        let proxied_response = create_proxied_response(response);
459
460        debug!("Responding to call with response");
461        Ok(proxied_response)
462    }
463}
464
465pub struct ReverseProxy<C, B> {
466    client: Client<C, B>,
467}
468
469impl<C, B> ReverseProxy<C, B>
470where
471    C: Connect + Clone + Send + Sync + 'static,
472    B: HttpBody + Send + 'static,
473    B::Data: Send,
474    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
475{
476    pub fn new(client: Client<C, B>) -> Self {
477        Self { client }
478    }
479
480    pub async fn call(
481        &self,
482        client_ip: IpAddr,
483        forward_uri: &str,
484        request: Request<B>,
485    ) -> Result<Response<Body>, ProxyError> {
486        call::<C, B>(client_ip, forward_uri, request, &self.client).await
487    }
488}
489
#[cfg(feature = "__bench")]
pub mod benches {
    //! Public shims over crate-private helpers so the benchmarks can call them.
    //! Return values are discarded.

    use super::{HeaderName, IpAddr, Request, Response};

    /// The hop-by-hop header list used when stripping headers.
    pub fn hop_headers() -> &'static [HeaderName] {
        let headers: &'static [HeaderName; 9] = &super::HOP_HEADERS;
        headers
    }

    /// Runs `create_proxied_response` and drops the result.
    pub fn create_proxied_response<T>(response: Response<T>) {
        let _ = super::create_proxied_response(response);
    }

    /// Runs `forward_uri` and drops the built URL.
    pub fn forward_uri<B>(forward_url: &str, req: &Request<B>) {
        let _ = super::forward_uri(forward_url, req);
    }

    /// Runs `create_proxied_request`, panicking if it returns an error.
    pub fn create_proxied_request<B>(
        client_ip: IpAddr,
        forward_url: &str,
        request: Request<B>,
        upgrade_type: Option<&String>,
    ) {
        let result = super::create_proxied_request(client_ip, forward_url, request, upgrade_type);
        result.unwrap();
    }
}