1pub mod errors;
32
33use once_cell::sync::{Lazy, OnceCell};
34use reqwest::redirect::Policy;
35use unicase::Ascii;
36use warp::filters::path::FullPath;
37use warp::http;
38use warp::http::{HeaderMap, HeaderValue, Method as RequestMethod};
39use warp::hyper::body::Bytes;
40use warp::hyper::Body;
41use warp::{Filter, Rejection};
42
43pub static CLIENT: OnceCell<reqwest::Client> = OnceCell::new();
55
56pub type Uri = FullPath;
58
59pub type QueryParameters = Option<String>;
63
64pub type Method = RequestMethod;
66
67pub type Headers = HeaderMap;
69
70pub type Request = (Uri, QueryParameters, Method, Headers, Bytes);
74
75pub fn reverse_proxy_filter(
93 base_path: String,
94 proxy_address: String,
95) -> impl Filter<Extract = (http::Response<Body>,), Error = Rejection> + Clone {
96 let proxy_address = warp::any().map(move || proxy_address.clone());
97 let base_path = warp::any().map(move || base_path.clone());
98 let data_filter = extract_request_data_filter();
99
100 proxy_address
101 .and(base_path)
102 .and(data_filter)
103 .and_then(proxy_to_and_forward_response)
104 .boxed()
105}
106
107pub fn query_params_filter(
109) -> impl Filter<Extract = (QueryParameters,), Error = std::convert::Infallible> + Clone {
110 warp::query::raw()
111 .map(Some)
112 .or_else(|_| async { Ok::<(QueryParameters,), std::convert::Infallible>((None,)) })
113}
114
115pub fn extract_request_data_filter(
117) -> impl Filter<Extract = Request, Error = warp::Rejection> + Clone {
118 warp::path::full()
119 .and(query_params_filter())
120 .and(warp::method())
121 .and(warp::header::headers_cloned())
122 .and(warp::body::bytes())
123}
124
125pub async fn proxy_to_and_forward_response(
169 proxy_address: String,
170 base_path: String,
171 uri: FullPath,
172 params: QueryParameters,
173 method: Method,
174 headers: HeaderMap,
175 body: Bytes,
176) -> Result<http::Response<Body>, Rejection> {
177 let proxy_uri = remove_relative_path(&uri, base_path, proxy_address);
178 let request = filtered_data_to_request(proxy_uri, (uri, params, method, headers, body))
179 .map_err(warp::reject::custom)?;
180 let response = proxy_request(request).await.map_err(warp::reject::custom)?;
181 response_to_reply(response)
182 .await
183 .map_err(warp::reject::custom)
184}
185
186async fn response_to_reply(
188 response: reqwest::Response,
189) -> Result<http::Response<Body>, errors::Error> {
190 let mut builder = http::Response::builder();
191 for (k, v) in remove_hop_headers(response.headers()).iter() {
192 builder = builder.header(k, v);
193 }
194 let status = response.status();
195 let body = Body::wrap_stream(response.bytes_stream());
196 builder
197 .status(status)
198 .body(body)
199 .map_err(errors::Error::Http)
200}
201
202fn remove_relative_path(uri: &FullPath, base_path: String, proxy_address: String) -> String {
203 let mut base_path = base_path;
204 if !base_path.starts_with('/') {
205 base_path = format!("/{}", base_path);
206 }
207 let relative_path = uri
208 .as_str()
209 .trim_start_matches(&base_path)
210 .trim_start_matches('/');
211
212 let proxy_address = proxy_address.trim_end_matches('/');
213 format!("{}/{}", proxy_address, relative_path)
214}
215
216fn is_hop_header(header_name: &str) -> bool {
220 static HOP_HEADERS: Lazy<Vec<Ascii<&'static str>>> = Lazy::new(|| {
221 vec![
222 Ascii::new("Connection"),
223 Ascii::new("Keep-Alive"),
224 Ascii::new("Proxy-Authenticate"),
225 Ascii::new("Proxy-Authorization"),
226 Ascii::new("Te"),
227 Ascii::new("Trailers"),
228 Ascii::new("Transfer-Encoding"),
229 Ascii::new("Upgrade"),
230 ]
231 });
232
233 HOP_HEADERS.iter().any(|h| h == &header_name)
234}
235
236fn remove_hop_headers(headers: &HeaderMap<HeaderValue>) -> HeaderMap<HeaderValue> {
237 headers
238 .iter()
239 .filter_map(|(k, v)| {
240 if !is_hop_header(k.as_str()) {
241 Some((k.clone(), v.clone()))
242 } else {
243 None
244 }
245 })
246 .collect()
247}
248
249fn filtered_data_to_request(
250 proxy_address: String,
251 request: Request,
252) -> Result<reqwest::Request, errors::Error> {
253 let (_uri, params, method, headers, body) = request;
254
255 let proxy_uri = if let Some(params) = params {
256 format!("{}?{}", proxy_address, params)
257 } else {
258 proxy_address
259 };
260
261 let headers = remove_hop_headers(&headers);
262
263 CLIENT
264 .get_or_init(default_reqwest_client)
265 .request(method, proxy_uri)
266 .headers(headers)
267 .body(body)
268 .build()
269 .map_err(errors::Error::Request)
270}
271
272async fn proxy_request(request: reqwest::Request) -> Result<reqwest::Response, errors::Error> {
274 CLIENT
275 .get_or_init(default_reqwest_client)
276 .execute(request)
277 .await
278 .map_err(errors::Error::Request)
279}
280
281fn default_reqwest_client() -> reqwest::Client {
283 reqwest::Client::builder()
284 .redirect(Policy::none())
285 .build()
286 .expect("Default reqwest client couldn't build")
289}
290
/// Unit tests for request extraction plus loopback tests that proxy to a
/// throwaway warp server.
#[cfg(test)]
pub mod test {
    use crate::{
        extract_request_data_filter, filtered_data_to_request, proxy_request, remove_relative_path,
        reverse_proxy_filter, Request,
    };
    use std::net::SocketAddr;
    use warp::http::StatusCode;
    use warp::Filter;

    /// Spawns a warp server on an OS-assigned loopback port and returns its address.
    ///
    /// The server answers with `warp::reply()`. If `path` is empty it handles
    /// every request; otherwise it only handles requests whose first path
    /// segment is `path`. `bind_ephemeral` binds the listener before
    /// returning, so there is no start-up race with the request that follows.
    /// Port 0 avoids clashes with other tests or processes.
    fn serve_test_response(path: String) -> SocketAddr {
        let loopback = ([127, 0, 0, 1], 0);
        if path.is_empty() {
            let (address, server) =
                warp::serve(warp::any().map(warp::reply)).bind_ephemeral(loopback);
            tokio::spawn(server);
            address
        } else {
            let (address, server) =
                warp::serve(warp::path(path).map(warp::reply)).bind_ephemeral(loopback);
            tokio::spawn(server);
            address
        }
    }

    #[tokio::test]
    async fn request_data_match() {
        let filter = extract_request_data_filter();

        let (path, query, method, body, header) =
            ("/foo/bar", "foo=bar", "POST", b"foo bar", ("foo", "bar"));
        let path_with_query = format!("{}?{}", path, query);

        let result = warp::test::request()
            .path(path_with_query.as_str())
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .filter(&filter)
            .await;

        let (result_path, result_query, result_method, result_headers, result_body): Request =
            result.unwrap();

        assert_eq!(path, result_path.as_str());
        assert_eq!(Some(query.to_string()), result_query);
        assert_eq!(method, result_method.as_str());
        assert_eq!(bytes::Bytes::from(body.to_vec()), result_body);
        assert_eq!(result_headers.get(header.0).unwrap(), header.1);
    }

    #[tokio::test]
    async fn proxy_forward_response() {
        let filter = extract_request_data_filter();
        let (path_with_params, method, body, header) =
            ("/foo/bar?foo=bar", "GET", b"foo bar", ("foo", "bar"));

        let result = warp::test::request()
            .path(path_with_params)
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .filter(&filter)
            .await;

        let request: Request = result.unwrap();

        let address = serve_test_response("".to_string());

        let request = filtered_data_to_request(
            remove_relative_path(&request.0, "".to_string(), format!("http://{}", address)),
            request,
        )
        .unwrap();
        let response = proxy_request(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn full_reverse_proxy_filter_forward_response() {
        let address = serve_test_response("foo".to_string());
        let filter = warp::path!("relative_path" / ..).and(reverse_proxy_filter(
            "relative_path".to_string(),
            format!("http://{}", address),
        ));
        let (path, method, body, header) =
            ("/relative_path/foo", "GET", b"foo bar", ("foo", "bar"));

        let response = warp::test::request()
            .path(path)
            .method(method)
            .body(body)
            .header(header.0, header.1)
            .reply(&filter)
            .await;

        assert_eq!(response.status(), StatusCode::OK);
    }
}