warp_reverse_proxy/
lib.rs

//! Fully composable [warp](https://github.com/seanmonstar/warp) filter that can be used as a reverse proxy. It forwards the request to the
//! desired address and replies with the response returned by that address.
3//!
4//!
5//! ```no_run
6//! use warp::{hyper::Body, Filter, Rejection, Reply, http::Response};
7//! use warp_reverse_proxy::reverse_proxy_filter;
8//!
9//! async fn log_response(response: Response<Body>) -> Result<impl Reply, Rejection> {
10//!     println!("{:?}", response);
11//!     Ok(response)
12//! }
13//!
14//! #[tokio::main]
15//! async fn main() {
16//!     let hello = warp::path!("hello" / String).map(|name| format!("Hello, {}!", name));
17//!
//!     // spawn base server
19//!     tokio::spawn(warp::serve(hello).run(([0, 0, 0, 0], 8080)));
20//!
21//!     // Forward request to localhost in other port
22//!     let app = warp::path!("hello" / ..).and(
23//!         reverse_proxy_filter("".to_string(), "http://127.0.0.1:8080/".to_string())
24//!             .and_then(log_response),
25//!     );
26//!
27//!     // spawn proxy server
28//!     warp::serve(app).run(([0, 0, 0, 0], 3030)).await;
29//! }
30//! ```
31pub mod errors;
32
33use once_cell::sync::{Lazy, OnceCell};
34use reqwest::redirect::Policy;
35use unicase::Ascii;
36use warp::filters::path::FullPath;
37use warp::http;
38use warp::http::{HeaderMap, HeaderValue, Method as RequestMethod};
39use warp::hyper::body::Bytes;
40use warp::hyper::Body;
41use warp::{Filter, Rejection};
42
43/// Reverse proxy internal client
44///
45/// It can be overridden if needed calling `OnceCell::set` as follows:
46/// # Examples
47/// ```
48/// use warp_reverse_proxy::CLIENT;
49/// use reqwest::Client;
50///
51/// let client = Client::builder().build().expect("client goes boom...");
52/// CLIENT.set(client).expect("client is set");
53/// ```
54pub static CLIENT: OnceCell<reqwest::Client> = OnceCell::new();
55
56/// Alias of warp `FullPath`
57pub type Uri = FullPath;
58
59/// Alias of query parameters.
60///
61/// This is the type that holds the request query parameters.
62pub type QueryParameters = Option<String>;
63
64/// Alias of warp `Method`
65pub type Method = RequestMethod;
66
67/// Alias of warp `HeaderMap`
68pub type Headers = HeaderMap;
69
70/// Wrapper around a request data tuple.
71///
72/// It is the type that holds the request data extracted by the [`extract_request_data_filter`](fn.extract_request_data_filter.html) filter.
73pub type Request = (Uri, QueryParameters, Method, Headers, Bytes);
74
75/// Reverse proxy filter
76///
77/// Forwards the request to the desired location. It maps one to one, meaning
78/// that a request to `https://www.bar.foo/handle/this/path` forwarding to `https://www.other.location`
79/// will result in a request to `https://www.other.location/handle/this/path`.
80///
81/// # Arguments
82///
83/// * `base_path` - A string with the initial relative path of the endpoint.
84/// For example a `foo/` applied for an endpoint `foo/bar/` will result on a proxy to `bar/` (hence `/foo` is removed)
85///
86/// * `proxy_address` - Base proxy address to forward request.
87/// # Examples
88///
89/// When making a filter with a path `/handle/this/path` combined with a filter built
90/// with `reverse_proxy_filter("handle".to_string(), "localhost:8080")`
91/// will make that request arriving to `https://www.bar.foo/handle/this/path` be forwarded to `localhost:8080/this/path`
92pub fn reverse_proxy_filter(
93    base_path: String,
94    proxy_address: String,
95) -> impl Filter<Extract = (http::Response<Body>,), Error = Rejection> + Clone {
96    let proxy_address = warp::any().map(move || proxy_address.clone());
97    let base_path = warp::any().map(move || base_path.clone());
98    let data_filter = extract_request_data_filter();
99
100    proxy_address
101        .and(base_path)
102        .and(data_filter)
103        .and_then(proxy_to_and_forward_response)
104        .boxed()
105}
106
107/// Warp filter that extracts query parameters from the request, if they exist.
108pub fn query_params_filter(
109) -> impl Filter<Extract = (QueryParameters,), Error = std::convert::Infallible> + Clone {
110    warp::query::raw()
111        .map(Some)
112        .or_else(|_| async { Ok::<(QueryParameters,), std::convert::Infallible>((None,)) })
113}
114
115/// Warp filter that extracts the relative request path, method, headers map and body of a request.
116pub fn extract_request_data_filter(
117) -> impl Filter<Extract = Request, Error = warp::Rejection> + Clone {
118    warp::path::full()
119        .and(query_params_filter())
120        .and(warp::method())
121        .and(warp::header::headers_cloned())
122        .and(warp::body::bytes())
123}
124
125/// Build a request and send to the requested address.
126///
127/// Wraps the response into a `warp::reply` compatible type (`http::Response`)
128///
129/// # Arguments
130///
131/// * `proxy_address` - A string containing the base proxy address where the request
132/// will be forwarded to.
133///
134/// * `base_path` - A string with the prepended sub-path to be stripped from the request uri path.
135///
136/// * `uri` -> The uri of the extracted request.
137///
138/// * `params` -> The URL query parameters
139///
140/// * `method` -> The request method.
141///
142/// * `headers` -> The request headers.
143///
144/// * `body` -> The request body.
145///
146/// # Examples
147/// Notice that this method usually need to be used in aggregation with
148/// the [`extract_request_data_filter`](fn.extract_request_data_filter.html) filter which already
149/// provides the `(Uri, QueryParameters, Method, Headers, Body)` needed for calling this method. But the `proxy_address`
150/// and the `base_path` arguments need to be provided too.
151/// ```rust, ignore
152/// use warp::{Filter, hyper::Body, Reply, Rejection, hyper::Response};
153/// use warp_reverse_proxy::{extract_request_data_filter, proxy_to_and_forward_response};
154///
155/// async fn log_response(response: Response<Body>) -> Result<impl Reply, Rejection> {
156///     println!("{:?}", response);
157///     Ok(response)
158/// }
159///
160/// let request_filter = extract_request_data_filter();
161/// let app = warp::path!("hello" / String)
162///     .map(|port| (format!("http://127.0.0.1:{}/", port), "".to_string()))
163///     .untuple_one()
164///     .and(request_filter)
165///     .and_then(proxy_to_and_forward_response)
166///     .and_then(log_response);
167/// ```
168pub async fn proxy_to_and_forward_response(
169    proxy_address: String,
170    base_path: String,
171    uri: FullPath,
172    params: QueryParameters,
173    method: Method,
174    headers: HeaderMap,
175    body: Bytes,
176) -> Result<http::Response<Body>, Rejection> {
177    let proxy_uri = remove_relative_path(&uri, base_path, proxy_address);
178    let request = filtered_data_to_request(proxy_uri, (uri, params, method, headers, body))
179        .map_err(warp::reject::custom)?;
180    let response = proxy_request(request).await.map_err(warp::reject::custom)?;
181    response_to_reply(response)
182        .await
183        .map_err(warp::reject::custom)
184}
185
186/// Converts a reqwest response into a http::Response
187async fn response_to_reply(
188    response: reqwest::Response,
189) -> Result<http::Response<Body>, errors::Error> {
190    let mut builder = http::Response::builder();
191    for (k, v) in remove_hop_headers(response.headers()).iter() {
192        builder = builder.header(k, v);
193    }
194    let status = response.status();
195    let body = Body::wrap_stream(response.bytes_stream());
196    builder
197        .status(status)
198        .body(body)
199        .map_err(errors::Error::Http)
200}
201
202fn remove_relative_path(uri: &FullPath, base_path: String, proxy_address: String) -> String {
203    let mut base_path = base_path;
204    if !base_path.starts_with('/') {
205        base_path = format!("/{}", base_path);
206    }
207    let relative_path = uri
208        .as_str()
209        .trim_start_matches(&base_path)
210        .trim_start_matches('/');
211
212    let proxy_address = proxy_address.trim_end_matches('/');
213    format!("{}/{}", proxy_address, relative_path)
214}
215
216/// Checker method to filter hop headers
217///
218/// Headers are checked using unicase to avoid case misfunctions
219fn is_hop_header(header_name: &str) -> bool {
220    static HOP_HEADERS: Lazy<Vec<Ascii<&'static str>>> = Lazy::new(|| {
221        vec![
222            Ascii::new("Connection"),
223            Ascii::new("Keep-Alive"),
224            Ascii::new("Proxy-Authenticate"),
225            Ascii::new("Proxy-Authorization"),
226            Ascii::new("Te"),
227            Ascii::new("Trailers"),
228            Ascii::new("Transfer-Encoding"),
229            Ascii::new("Upgrade"),
230        ]
231    });
232
233    HOP_HEADERS.iter().any(|h| h == &header_name)
234}
235
236fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
237    headers
238        .iter()
239        .filter_map(|(k, v)| {
240            if !is_hop_header(k.as_str()) {
241                Some((k.clone(), v.clone()))
242            } else {
243                None
244            }
245        })
246        .collect()
247}
248
249fn filtered_data_to_request(
250    proxy_address: String,
251    request: Request,
252) -> Result<reqwest::Request, errors::Error> {
253    let (_uri, params, method, headers, body) = request;
254
255    let proxy_uri = if let Some(params) = params {
256        format!("{}?{}", proxy_address, params)
257    } else {
258        proxy_address
259    };
260
261    let headers = remove_hop_headers(&headers);
262
263    CLIENT
264        .get_or_init(default_reqwest_client)
265        .request(method, proxy_uri)
266        .headers(headers)
267        .body(body)
268        .build()
269        .map_err(errors::Error::Request)
270}
271
272/// Build and send a request to the specified address and request data
273async fn proxy_request(request: reqwest::Request) -> Result<reqwest::Response, errors::Error> {
274    CLIENT
275        .get_or_init(default_reqwest_client)
276        .execute(request)
277        .await
278        .map_err(errors::Error::Request)
279}
280
281/// Build a default client with redirect policy set to none
282fn default_reqwest_client() -> reqwest::Client {
283    reqwest::Client::builder()
284        .redirect(Policy::none())
285        .build()
286        // we should panic here, it is enforce that the client is needed, and there is no error
287        // handling possible on function call, better to stop execution.
288        .expect("Default reqwest client couldn't build")
289}
290
#[cfg(test)]
pub mod test {
    use crate::{
        extract_request_data_filter, filtered_data_to_request, proxy_request, remove_relative_path,
        reverse_proxy_filter, Request,
    };
    use std::net::SocketAddr;
    use warp::http::StatusCode;
    use warp::hyper::body::Bytes;
    use warp::Filter;

    /// Spawns a warp server answering `200 OK` and returns the address it listens on.
    ///
    /// An empty `path` answers every request; otherwise only requests matching `path` are
    /// answered. `bind_ephemeral` binds the listener to an OS-assigned port on 127.0.0.1
    /// before returning. Callers can therefore send requests immediately: there is no
    /// start-up race, and concurrently running tests cannot collide on a fixed port.
    fn serve_test_response(path: String) -> SocketAddr {
        let any_port: SocketAddr = ([127, 0, 0, 1], 0).into();
        if path.is_empty() {
            let (address, server) =
                warp::serve(warp::any().map(warp::reply)).bind_ephemeral(any_port);
            tokio::spawn(server);
            address
        } else {
            let (address, server) =
                warp::serve(warp::path(path).map(warp::reply)).bind_ephemeral(any_port);
            tokio::spawn(server);
            address
        }
    }

    #[tokio::test]
    async fn request_data_match() {
        let filter = extract_request_data_filter();

        let (path, query, method, body, header) =
            ("/foo/bar", "foo=bar", "POST", b"foo bar", ("foo", "bar"));
        let path_with_query = format!("{}?{}", path, query);

        let result = warp::test::request()
            .path(path_with_query.as_str())
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .filter(&filter)
            .await;

        let (result_path, result_query, result_method, result_headers, result_body): Request =
            result.unwrap();

        assert_eq!(path, result_path.as_str());
        assert_eq!(Some(query.to_string()), result_query);
        assert_eq!(method, result_method.as_str());
        assert_eq!(Bytes::from(body.to_vec()), result_body);
        assert_eq!(result_headers.get(header.0).unwrap(), header.1);
    }

    #[tokio::test]
    async fn proxy_forward_response() {
        let filter = extract_request_data_filter();
        let (path_with_params, method, body, header) =
            ("/foo/bar?foo=bar", "GET", b"foo bar", ("foo", "bar"));

        let result = warp::test::request()
            .path(path_with_params)
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .filter(&filter)
            .await;

        let request: Request = result.unwrap();

        let address = serve_test_response("".to_string());

        // transform request data into an actual request
        let request = filtered_data_to_request(
            remove_relative_path(&request.0, "".to_string(), format!("http://{}", address)),
            request,
        )
        .unwrap();
        let response = proxy_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn full_reverse_proxy_filter_forward_response() {
        let address = serve_test_response("foo".to_string());
        let filter = warp::path!("relative_path" / ..).and(reverse_proxy_filter(
            "relative_path".to_string(),
            format!("http://{}", address),
        ));
        let (path, method, body, header) =
            ("/relative_path/foo", "GET", b"foo bar", ("foo", "bar"));

        let response = warp::test::request()
            .path(path)
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .reply(&filter)
            .await;

        assert_eq!(response.status(), StatusCode::OK);
    }
}