// axum_proxy/lib.rs

// Turn on rustdoc's `doc_cfg` feature only when the crate is built with `--cfg docsrs`.
#![cfg_attr(docsrs, feature(doc_cfg))]
2
//! `axum-proxy` provides tower [`Service`s](tower_service::Service) that act as a "reverse
//! proxy" with various rewriting rules.
5//!
//! Internally these services use [`hyper::Client`] to send an incoming request to another
//! server. The [`connector`](hyper::client::connect::Connect) for a client can be
//! [`HttpConnector`](hyper::client::HttpConnector), [`HttpsConnector`](hyper_tls::HttpsConnector),
//! or any other connector you want.
10//!
11//! # Examples
12//!
13//! There are two types of services, [`OneshotService`] and [`ReusedService`]. The
14//! [`OneshotService`] *owns* the `Client`, while the [`ReusedService`] *shares* the `Client`
15//! via [`Arc`](std::sync::Arc).
16//!
17//!
18//! ## General usage
19//!
20//! ```
21//! # async fn run_test() {
22//! use axum_proxy::ReusedServiceBuilder;
23//! use axum_proxy::{ReplaceAll, ReplaceN};
24//!
25//! use hyper::body::Bytes;
26//! use http_body_util::Full;
27//! use http::Request;
28//! use tower_service::Service as _;
29//!
30//! let svc_builder = axum_proxy::builder_http("example.com:1234").unwrap();
31//!
32//! let req1 = Request::builder()
33//!     .method("GET")
34//!     .uri("https://myserver.com/foo/bar/foo")
35//!     .body(Full::new(Bytes::new()))
36//!     .unwrap();
37//!
38//! // Clones Arc<Client>
39//! let mut svc1 = svc_builder.build(ReplaceAll("foo", "baz"));
40//! // http://example.com:1234/baz/bar/baz
41//! let _res = svc1.call(req1).await.unwrap();
42//!
43//! let req2 = Request::builder()
44//!     .method("POST")
45//!     .uri("https://myserver.com/foo/bar/foo")
46//!     .header("Content-Type", "application/x-www-form-urlencoded")
47//!     .body(Full::new(Bytes::from("key=value")))
48//!     .unwrap();
49//!
50//! let mut svc2 = svc_builder.build(ReplaceN("foo", "baz", 1));
51//! // http://example.com:1234/baz/bar/foo
52//! let _res = svc2.call(req2).await.unwrap();
53//! # }
54//! ```
55//!
//! In this example, `svc1` and `svc2` share the same `Client`, each holding an `Arc<Client>`
//! inside.
58//!
//! For more information about the rewriting rules (`ReplaceAll`, `ReplaceN` *etc.*), see the
//! documentation of [`rewrite`].
61//!
62//!
63//! ## With axum
64//!
65//! ```
66//! # #[cfg(feature = "axum")] {
67//! use axum_proxy::ReusedServiceBuilder;
68//! use axum_proxy::{TrimPrefix, AppendSuffix, Static};
69//!
70//! use axum::Router;
71//!
72//! #[tokio::main]
73//! async fn main() {
74//!     let host1 = axum_proxy::builder_http("example.com").unwrap();
75//!     let host2 = axum_proxy::builder_http("example.net:1234").unwrap();
76//!
77//!     let app = Router::new()
78//!         .route_service("/healthcheck", host1.build(Static("/")))
79//!         .route_service("/users/{*path}", host1.build(TrimPrefix("/users")))
80//!         .route_service("/posts", host2.build(AppendSuffix("/")));
81//!
//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
//!         .await
//!         .unwrap();
//!
//!     axum::serve(listener, app).await.unwrap();
87//! }
88//! # }
89//! ```
90//!
91//!
92//! # Return Types
93//!
//! The return type ([`Future::Output`](std::future::Future::Output)) of [`ReusedService`] and
//! [`OneshotService`] is `Result<Result<Response, ProxyError>, Infallible>`. This is because
//! axum's [`Router`](axum::Router) accepts only such `Service`s.
97//!
//! The [`ProxyError`] type implements [`IntoResponse`](axum::response::IntoResponse) if you
//! enable the `axum` feature.
//! It returns an empty body, with the status code `INTERNAL_SERVER_ERROR`. The description of this
//! error will be logged out at [error](`log::error`) level in the
//! [`into_response()`](axum::response::IntoResponse::into_response()) method.
103//!
104//!
105//! # Features
106//!
107//! By default only `http1` is enabled.
108//!
109//! - `http1`: uses `hyper/http1`
110//! - `http2`: uses `hyper/http2`
111//! - `https`: alias to `nativetls`
112//! - `nativetls`: uses the `hyper-tls` crate
113//! - `rustls`: alias to `rustls-webpki-roots`
114//! - `rustls-webpki-roots`: uses the `hyper-rustls` crate, with the feature `webpki-roots`
115//! - `rustls-native-roots`: uses the `hyper-rustls` crate, with the feature `rustls-native-certs`
116//! - `rustls-http2`: `http2` plus `rustls`, and `rustls/http2` is enabled
//! - `axum`: implements [`IntoResponse`](axum::response::IntoResponse) for [`ProxyError`]
118//!
//! The `http1` feature is mandatory: the crate refuses to compile without it, and `http2` can be
//! enabled in addition. You cannot use the services if, for example, only the `https` feature
//! is on.
121//!
122//! Through this document, we use `rustls` to mean *any* of `rustls*` features unless otherwise
123//! specified.
124
// Error type of the crate, re-exported at the root.
mod error;
pub use error::ProxyError;

// The `client` module exists only when at least one HTTP version feature is enabled.
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
pub mod client;

// Rewriting rules; every public item of the module is also re-exported at the crate root.
pub mod rewrite;
pub use rewrite::*;

// Future type re-exported at the root.
mod future;
pub use future::RevProxyFuture;

// `OneshotService` is available with `http1` or `http2`.
#[cfg(any(feature = "http1", feature = "http2"))]
mod oneshot;
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
pub use oneshot::OneshotService;

// `ReusedService`, its builder type and the `builder*` constructors. Each re-export repeats the
// HTTP-version gate and adds the TLS feature its constructor depends on; the matching
// `doc(cfg(..))` makes docs.rs display the requirement.
#[cfg(any(feature = "http1", feature = "http2"))]
mod reused;
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
pub use reused::Builder as ReusedServiceBuilder;
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
pub use reused::ReusedService;
#[cfg(all(
    any(feature = "http1", feature = "http2"),
    any(feature = "https", feature = "nativetls")
))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(
        any(feature = "http1", feature = "http2"),
        any(feature = "https", feature = "nativetls")
    )))
)]
pub use reused::builder_https;
#[cfg(all(any(feature = "http1", feature = "http2"), feature = "nativetls"))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(any(feature = "http1", feature = "http2"), feature = "nativetls")))
)]
pub use reused::builder_nativetls;
// NOTE(review): the item is gated on `__rustls` while the docs.rs annotation names `rustls`;
// presumably `__rustls` is an internal feature switched on by the `rustls*` features — confirm
// against Cargo.toml.
#[cfg(all(any(feature = "http1", feature = "http2"), feature = "__rustls"))]
#[cfg_attr(
    docsrs,
    doc(cfg(all(any(feature = "http1", feature = "http2"), feature = "rustls")))
)]
pub use reused::builder_rustls;
#[cfg(any(feature = "http1", feature = "http2"))]
#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
pub use reused::{builder, builder_http};

// Fail the build early when `http1` is disabled.
// NOTE(review): `http1` is enforced as mandatory here although the gates above would also accept
// `http2` on its own — confirm this is intended.
#[cfg(not(feature = "http1"))]
compile_error!("http1 is a mandatory feature");

// Reject builds that enable `rustls-ring` or `rustls-aws-lc` without also enabling
// `rustls-webpki-roots` or `rustls-native-roots`.
#[cfg(all(
    any(feature = "rustls-ring", feature = "rustls-aws-lc"),
    not(any(feature = "rustls-webpki-roots", feature = "rustls-native-roots"))
))]
compile_error!(
    "When enabling rustls-ring and/or rustls-aws-lc, you must enable rustls-webpki-roots and/or rustls-native-roots"
);
190
/// Shared assertions for the proxy service tests.
///
/// Each `match_*` helper registers a single mock endpoint on a `mockito` server and then drives
/// requests through the service under test, checking the status code and body that come back.
/// Every request is sent to `/foo…` while the mocks live under `/goo…`, so the service passed in
/// is expected to rewrite `foo` to `goo` and to forward to the mock server.
#[cfg(test)]
mod test_helper {
    use std::convert::Infallible;

    use http::{Request, Response, StatusCode};
    use http_body_util::BodyExt as _;
    use hyper::body::Incoming;
    use mockito::{Matcher, ServerGuard};
    use tower_service::Service;

    use super::{ProxyError, RevProxyFuture};

    /// Sends one request through `service` and asserts on the response.
    ///
    /// * `(method, suffix, content_type, body)` describe the request: it is addressed to
    ///   `https://test.com{suffix}`, and a `Content-Type` header is set only when
    ///   `content_type` is `Some`.
    /// * `expected` is the `(status, body)` pair the response must match exactly.
    ///
    /// # Panics
    ///
    /// Panics (failing the test) when the request cannot be built, the service yields a
    /// `ProxyError`, the body cannot be collected, or the status/body differ from `expected`.
    /// The underlying error is included in the panic message.
    async fn call<S, B>(
        service: &mut S,
        (method, suffix, content_type, body): (&str, &str, Option<&str>, B),
        expected: (StatusCode, &str),
    ) where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
        B: Into<String>,
    {
        let (expected_status, expected_body) = expected;

        let mut builder = Request::builder()
            .method(method)
            .uri(format!("https://test.com{suffix}"));

        if let Some(content_type) = content_type {
            builder = builder.header("Content-Type", content_type);
        }

        let request = builder
            .body(body.into())
            .expect("test request should be well-formed");

        // The service error is `Infallible`, so only the inner `Result` can actually fail.
        let response = service
            .call(request)
            .await
            .unwrap_or_else(|never: Infallible| match never {})
            .expect("proxy service should produce a response");
        assert_eq!(response.status(), expected_status);

        let collected = response
            .into_body()
            .collect()
            .await
            .expect("response body should be readable");

        assert_eq!(collected.to_bytes(), expected_body);
    }

    /// Checks path rewriting.
    ///
    /// Mocks `GET /goo/bar/goo/baz/goo` (body `ok`). A request to `/foo/bar/foo/baz/foo` must
    /// come back as `200 ok`; a request to `/foo/bar/foo/baz` has no matching mock and must come
    /// back as `501` with an empty body (presumably mockito's answer for unmatched requests).
    pub async fn match_path<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        // Bound to a name so the mock stays alive for the calls below.
        let _mk = server
            .mock("GET", "/goo/bar/goo/baz/goo")
            .with_body("ok")
            .create_async()
            .await;

        call(
            svc,
            ("GET", "/foo/bar/foo/baz/foo", None, ""),
            (StatusCode::OK, "ok"),
        )
        .await;

        call(
            svc,
            ("GET", "/foo/bar/foo/baz", None, ""),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }

    /// Checks that the query string reaches the upstream server.
    ///
    /// Mocks `GET /goo` restricted to the URL-encoded query `greeting=good day` (body `ok`).
    /// `/foo?greeting=good%20day` must come back as `200 ok`; `/foo` without the query must come
    /// back as `501` with an empty body.
    pub async fn match_query<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        // Bound to a name so the mock stays alive for the calls below.
        let _mk = server
            .mock("GET", "/goo")
            .match_query(Matcher::UrlEncoded("greeting".into(), "good day".into()))
            .with_body("ok")
            .create_async()
            .await;

        call(
            svc,
            ("GET", "/foo?greeting=good%20day", None, ""),
            (StatusCode::OK, "ok"),
        )
        .await;

        call(
            svc,
            ("GET", "/foo", None, ""),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }

    /// Checks that the method and request body reach the upstream server.
    ///
    /// Mocks `POST /goo` restricted to the body `test` (body `ok`). `POST /foo` with body `test`
    /// must come back as `200 ok`; a `PUT` with the same body, or a `POST` with the body
    /// `tests`, must come back as `501` with an empty body.
    pub async fn match_post<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        // Bound to a name so the mock stays alive for the calls below.
        let _mk = server
            .mock("POST", "/goo")
            .match_body("test")
            .with_body("ok")
            .create_async()
            .await;

        call(svc, ("POST", "/foo", None, "test"), (StatusCode::OK, "ok")).await;

        call(
            svc,
            ("PUT", "/foo", None, "test"),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;

        call(
            svc,
            ("POST", "/foo", None, "tests"),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }

    /// Checks that request headers reach the upstream server.
    ///
    /// Mocks `POST /goo` restricted to `content-type: application/json` and the body
    /// `{"key":"value"}` (body `ok`). The matching request must come back as `200 ok`; the same
    /// body without the header, or the header with a different body, must come back as `501`
    /// with an empty body.
    pub async fn match_header<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        // Bound to a name so the mock stays alive for the calls below.
        let _mk = server
            .mock("POST", "/goo")
            .match_header("content-type", "application/json")
            .match_body(r#"{"key":"value"}"#)
            .with_body("ok")
            .create_async()
            .await;

        call(
            svc,
            (
                "POST",
                "/foo",
                Some("application/json"),
                r#"{"key":"value"}"#,
            ),
            (StatusCode::OK, "ok"),
        )
        .await;

        call(
            svc,
            ("POST", "/foo", None, r#"{"key":"value"}"#),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;

        call(
            svc,
            (
                "POST",
                "/foo",
                Some("application/json"),
                r#"{"key":"values"}"#,
            ),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }
}