hyper_reverse_proxy/
lib.rs

1//!
2//! A simple reverse proxy, to be used with [Hyper].
3//!
4//! The implementation ensures that [Hop-by-hop headers] are stripped correctly in both directions,
5//! and adds the client's IP address to a comma-space-separated list of forwarding addresses in the
6//! `X-Forwarded-For` header.
7//!
8//! The implementation is based on Go's [`httputil.ReverseProxy`].
9//!
10//! [Hyper]: http://hyper.rs/
11//! [Hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
12//! [`httputil.ReverseProxy`]: https://golang.org/pkg/net/http/httputil/#ReverseProxy
13//!
14//! # Example
15//!
16//! Add these dependencies to your `Cargo.toml` file.
17//!
18//! ```toml
19//! [dependencies]
20//! hyper-reverse-proxy = "0.5"
21//! hyper = { version = "0.14", features = ["full"] }
22//! tokio = { version = "1", features = ["full"] }
23//! ```
24//!
//! The following example will set up a reverse proxy listening on `127.0.0.1:8000`,
26//! and will proxy these calls:
27//!
28//! * `"/target/first"` will be proxied to `http://127.0.0.1:13901`
29//!
30//! * `"/target/second"` will be proxied to `http://127.0.0.1:13902`
31//!
//! * All other URLs will be handled by the `debug_request` function, which displays the request information.
33//!
34//! ```rust,no_run
35//! use hyper::server::conn::AddrStream;
36//! use hyper::{Body, Request, Response, Server, StatusCode};
37//! use hyper::service::{service_fn, make_service_fn};
38//! use std::{convert::Infallible, net::SocketAddr};
39//! use std::net::IpAddr;
40//!
41//! fn debug_request(req: Request<Body>) -> Result<Response<Body>, Infallible>  {
42//!     let body_str = format!("{:?}", req);
43//!     Ok(Response::new(Body::from(body_str)))
44//! }
45//!
46//! async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
47//!     if req.uri().path().starts_with("/target/first") {
48//!         // will forward requests to port 13901
49//!         match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13901", req).await {
50//!             Ok(response) => {Ok(response)}
51//!             Err(_error) => {Ok(Response::builder()
52//!                                   .status(StatusCode::INTERNAL_SERVER_ERROR)
53//!                                   .body(Body::empty())
54//!                                   .unwrap())}
55//!         }
56//!     } else if req.uri().path().starts_with("/target/second") {
57//!         // will forward requests to port 13902
58//!         match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13902", req).await {
59//!             Ok(response) => {Ok(response)}
60//!             Err(_error) => {Ok(Response::builder()
61//!                                   .status(StatusCode::INTERNAL_SERVER_ERROR)
62//!                                   .body(Body::empty())
63//!                                   .unwrap())}
64//!         }
65//!     } else {
66//!         debug_request(req)
67//!     }
68//! }
69//!
70//! #[tokio::main]
71//! async fn main() {
72//!     let bind_addr = "127.0.0.1:8000";
73//!     let addr:SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
74//!
75//!     let make_svc = make_service_fn(|conn: &AddrStream| {
76//!         let remote_addr = conn.remote_addr().ip();
77//!         async move {
78//!             Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req)))
79//!         }
80//!     });
81//!
82//!     let server = Server::bind(&addr).serve(make_svc);
83//!
84//!     println!("Running server on {:?}", addr);
85//!
86//!     if let Err(e) = server.await {
87//!         eprintln!("server error: {}", e);
88//!     }
89//! }
90//! ```
91//!
92
93use hyper::header::{HeaderMap, HeaderValue};
94use hyper::http::header::{InvalidHeaderValue, ToStrError};
95use hyper::http::uri::InvalidUri;
96use hyper::{Body, Client, Error, Request, Response, Uri};
97use lazy_static::lazy_static;
98use std::net::IpAddr;
99use std::str::FromStr;
100
101#[derive(Debug)]
102pub enum ProxyError {
103    InvalidUri(InvalidUri),
104    HyperError(Error),
105    ForwardHeaderError,
106}
107
108impl From<Error> for ProxyError {
109    fn from(err: Error) -> ProxyError {
110        ProxyError::HyperError(err)
111    }
112}
113
114impl From<InvalidUri> for ProxyError {
115    fn from(err: InvalidUri) -> ProxyError {
116        ProxyError::InvalidUri(err)
117    }
118}
119
120impl From<ToStrError> for ProxyError {
121    fn from(_err: ToStrError) -> ProxyError {
122        ProxyError::ForwardHeaderError
123    }
124}
125
126impl From<InvalidHeaderValue> for ProxyError {
127    fn from(_err: InvalidHeaderValue) -> ProxyError {
128        ProxyError::ForwardHeaderError
129    }
130}
131
/// Returns `true` when `name` is a hop-by-hop header, i.e. one that is only
/// meaningful for a single connection and must not be forwarded by a proxy.
///
/// The comparison ignores ASCII case.
fn is_hop_header(name: &str) -> bool {
    // A constant table needs no lazy initialisation and no heap allocation.
    const HOP_HEADERS: &[&str] = &[
        "Connection",
        // Non-standard, but still sent by some clients.
        "Proxy-Connection",
        "Keep-Alive",
        "Proxy-Authenticate",
        "Proxy-Authorization",
        "Te",
        // The actual header name is `Trailer`; RFC 2616 section 13.5.1
        // misspells it as `Trailers` (see erratum 4522).
        "Trailer",
        // Kept so that headers using the RFC 2616 spelling are still removed,
        // as they were before.
        "Trailers",
        "Transfer-Encoding",
        "Upgrade",
    ];

    HOP_HEADERS
        .iter()
        .any(|hop| hop.eq_ignore_ascii_case(name))
}
153
154/// Returns a clone of the headers without the [hop-by-hop headers].
155///
156/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
157fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
158    let mut result = HeaderMap::new();
159    for (k, v) in headers.iter() {
160        if !is_hop_header(k.as_str()) {
161            result.insert(k.clone(), v.clone());
162        }
163    }
164    result
165}
166
167fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
168    *response.headers_mut() = remove_hop_headers(response.headers());
169    response
170}
171
172fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri, InvalidUri> {
173    let forward_uri = match req.uri().query() {
174        Some(query) => format!("{}{}?{}", forward_url, req.uri().path(), query),
175        None => format!("{}{}", forward_url, req.uri().path()),
176    };
177
178    Uri::from_str(forward_uri.as_str())
179}
180
181fn create_proxied_request<B>(
182    client_ip: IpAddr,
183    forward_url: &str,
184    mut request: Request<B>,
185) -> Result<Request<B>, ProxyError> {
186    *request.headers_mut() = remove_hop_headers(request.headers());
187    *request.uri_mut() = forward_uri(forward_url, &request)?;
188
189    let x_forwarded_for_header_name = "x-forwarded-for";
190
191    // Add forwarding information in the headers
192    match request.headers_mut().entry(x_forwarded_for_header_name) {
193        hyper::header::Entry::Vacant(entry) => {
194            entry.insert(client_ip.to_string().parse()?);
195        }
196
197        hyper::header::Entry::Occupied(mut entry) => {
198            let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
199            entry.insert(addr.parse()?);
200        }
201    }
202
203    Ok(request)
204}
205
206pub async fn call(
207    client_ip: IpAddr,
208    forward_uri: &str,
209    request: Request<Body>,
210) -> Result<Response<Body>, ProxyError> {
211    let proxied_request = create_proxied_request(client_ip, &forward_uri, request)?;
212
213    let client = Client::new();
214    let response = client.request(proxied_request).await?;
215    let proxied_response = create_proxied_response(response);
216    Ok(proxied_response)
217}