hyper_reverse_proxy/lib.rs
1//!
2//! A simple reverse proxy, to be used with [Hyper].
3//!
4//! The implementation ensures that [Hop-by-hop headers] are stripped correctly in both directions,
5//! and adds the client's IP address to a comma-space-separated list of forwarding addresses in the
6//! `X-Forwarded-For` header.
7//!
8//! The implementation is based on Go's [`httputil.ReverseProxy`].
9//!
10//! [Hyper]: http://hyper.rs/
11//! [Hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
12//! [`httputil.ReverseProxy`]: https://golang.org/pkg/net/http/httputil/#ReverseProxy
13//!
14//! # Example
15//!
16//! Add these dependencies to your `Cargo.toml` file.
17//!
18//! ```toml
19//! [dependencies]
20//! hyper-reverse-proxy = "0.5"
21//! hyper = { version = "0.14", features = ["full"] }
22//! tokio = { version = "1", features = ["full"] }
23//! ```
24//!
//! The following example will set up a reverse proxy listening on `127.0.0.1:8000`,
26//! and will proxy these calls:
27//!
28//! * `"/target/first"` will be proxied to `http://127.0.0.1:13901`
29//!
30//! * `"/target/second"` will be proxied to `http://127.0.0.1:13902`
31//!
//! * All other URLs will be handled by the `debug_request` function, which displays the request information.
33//!
34//! ```rust,no_run
35//! use hyper::server::conn::AddrStream;
36//! use hyper::{Body, Request, Response, Server, StatusCode};
37//! use hyper::service::{service_fn, make_service_fn};
38//! use std::{convert::Infallible, net::SocketAddr};
39//! use std::net::IpAddr;
40//!
41//! fn debug_request(req: Request<Body>) -> Result<Response<Body>, Infallible> {
42//! let body_str = format!("{:?}", req);
43//! Ok(Response::new(Body::from(body_str)))
44//! }
45//!
46//! async fn handle(client_ip: IpAddr, req: Request<Body>) -> Result<Response<Body>, Infallible> {
47//! if req.uri().path().starts_with("/target/first") {
48//! // will forward requests to port 13901
49//! match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13901", req).await {
50//! Ok(response) => {Ok(response)}
51//! Err(_error) => {Ok(Response::builder()
52//! .status(StatusCode::INTERNAL_SERVER_ERROR)
53//! .body(Body::empty())
54//! .unwrap())}
55//! }
56//! } else if req.uri().path().starts_with("/target/second") {
57//! // will forward requests to port 13902
58//! match hyper_reverse_proxy::call(client_ip, "http://127.0.0.1:13902", req).await {
59//! Ok(response) => {Ok(response)}
60//! Err(_error) => {Ok(Response::builder()
61//! .status(StatusCode::INTERNAL_SERVER_ERROR)
62//! .body(Body::empty())
63//! .unwrap())}
64//! }
65//! } else {
66//! debug_request(req)
67//! }
68//! }
69//!
70//! #[tokio::main]
71//! async fn main() {
72//! let bind_addr = "127.0.0.1:8000";
73//! let addr:SocketAddr = bind_addr.parse().expect("Could not parse ip:port.");
74//!
75//! let make_svc = make_service_fn(|conn: &AddrStream| {
76//! let remote_addr = conn.remote_addr().ip();
77//! async move {
78//! Ok::<_, Infallible>(service_fn(move |req| handle(remote_addr, req)))
79//! }
80//! });
81//!
82//! let server = Server::bind(&addr).serve(make_svc);
83//!
84//! println!("Running server on {:?}", addr);
85//!
86//! if let Err(e) = server.await {
87//! eprintln!("server error: {}", e);
88//! }
89//! }
90//! ```
91//!
92
93use hyper::header::{HeaderMap, HeaderValue};
94use hyper::http::header::{InvalidHeaderValue, ToStrError};
95use hyper::http::uri::InvalidUri;
96use hyper::{Body, Client, Error, Request, Response, Uri};
97use lazy_static::lazy_static;
98use std::net::IpAddr;
99use std::str::FromStr;
100
101#[derive(Debug)]
102pub enum ProxyError {
103 InvalidUri(InvalidUri),
104 HyperError(Error),
105 ForwardHeaderError,
106}
107
108impl From<Error> for ProxyError {
109 fn from(err: Error) -> ProxyError {
110 ProxyError::HyperError(err)
111 }
112}
113
114impl From<InvalidUri> for ProxyError {
115 fn from(err: InvalidUri) -> ProxyError {
116 ProxyError::InvalidUri(err)
117 }
118}
119
120impl From<ToStrError> for ProxyError {
121 fn from(_err: ToStrError) -> ProxyError {
122 ProxyError::ForwardHeaderError
123 }
124}
125
126impl From<InvalidHeaderValue> for ProxyError {
127 fn from(_err: InvalidHeaderValue) -> ProxyError {
128 ProxyError::ForwardHeaderError
129 }
130}
131
132fn is_hop_header(name: &str) -> bool {
133 use unicase::Ascii;
134
135 // A list of the headers, using `unicase` to help us compare without
136 // worrying about the case, and `lazy_static!` to prevent reallocation
137 // of the vector.
138 lazy_static! {
139 static ref HOP_HEADERS: Vec<Ascii<&'static str>> = vec![
140 Ascii::new("Connection"),
141 Ascii::new("Keep-Alive"),
142 Ascii::new("Proxy-Authenticate"),
143 Ascii::new("Proxy-Authorization"),
144 Ascii::new("Te"),
145 Ascii::new("Trailers"),
146 Ascii::new("Transfer-Encoding"),
147 Ascii::new("Upgrade"),
148 ];
149 }
150
151 HOP_HEADERS.iter().any(|h| h == &name)
152}
153
154/// Returns a clone of the headers without the [hop-by-hop headers].
155///
156/// [hop-by-hop headers]: http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
157fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
158 let mut result = HeaderMap::new();
159 for (k, v) in headers.iter() {
160 if !is_hop_header(k.as_str()) {
161 result.insert(k.clone(), v.clone());
162 }
163 }
164 result
165}
166
167fn create_proxied_response<B>(mut response: Response<B>) -> Response<B> {
168 *response.headers_mut() = remove_hop_headers(response.headers());
169 response
170}
171
172fn forward_uri<B>(forward_url: &str, req: &Request<B>) -> Result<Uri, InvalidUri> {
173 let forward_uri = match req.uri().query() {
174 Some(query) => format!("{}{}?{}", forward_url, req.uri().path(), query),
175 None => format!("{}{}", forward_url, req.uri().path()),
176 };
177
178 Uri::from_str(forward_uri.as_str())
179}
180
181fn create_proxied_request<B>(
182 client_ip: IpAddr,
183 forward_url: &str,
184 mut request: Request<B>,
185) -> Result<Request<B>, ProxyError> {
186 *request.headers_mut() = remove_hop_headers(request.headers());
187 *request.uri_mut() = forward_uri(forward_url, &request)?;
188
189 let x_forwarded_for_header_name = "x-forwarded-for";
190
191 // Add forwarding information in the headers
192 match request.headers_mut().entry(x_forwarded_for_header_name) {
193 hyper::header::Entry::Vacant(entry) => {
194 entry.insert(client_ip.to_string().parse()?);
195 }
196
197 hyper::header::Entry::Occupied(mut entry) => {
198 let addr = format!("{}, {}", entry.get().to_str()?, client_ip);
199 entry.insert(addr.parse()?);
200 }
201 }
202
203 Ok(request)
204}
205
206pub async fn call(
207 client_ip: IpAddr,
208 forward_uri: &str,
209 request: Request<Body>,
210) -> Result<Response<Body>, ProxyError> {
211 let proxied_request = create_proxied_request(client_ip, &forward_uri, request)?;
212
213 let client = Client::new();
214 let response = client.request(proxied_request).await?;
215 let proxied_response = create_proxied_response(response);
216 Ok(proxied_response)
217}