Skip to main content

axum_proxy/
lib.rs

1#![cfg_attr(docsrs, feature(doc_cfg))]
2
3//! `axum-proxy` is tower [`Service`s](tower_service::Service) that performs "reverse
4//! proxy" with various rewriting rules.
5//!
6//! Internally these services use [`hyper::Client`] to send an incoming request to another
7//! server. The [`connector`](hyper::client::connect::Connect) for a client can be
8//! [`HttpConnector`](hyper::client::HttpConnector), [`HttpsConnector`](hyper_tls::HttpsConnector),
9//! or any other connector you want.
10//!
11//! # Examples
12//!
13//! There are two types of services, [`OneshotService`] and [`ReusedService`]. The
14//! [`OneshotService`] *owns* the `Client`, while the [`ReusedService`] *shares* the `Client`
15//! via [`Arc`](std::sync::Arc).
16//!
17//!
18//! ## General usage
19//!
20//! ```
21//! # async fn run_test() {
22//! use axum_proxy::ReusedServiceBuilder;
23//! use axum_proxy::{ReplaceAll, ReplaceN};
24//!
25//! use hyper::body::Bytes;
26//! use http_body_util::Full;
27//! use http::Request;
28//! use tower_service::Service as _;
29//!
30//! let svc_builder = axum_proxy::builder_http("example.com:1234").unwrap();
31//!
32//! let req1 = Request::builder()
33//!     .method("GET")
34//!     .uri("https://myserver.com/foo/bar/foo")
35//!     .body(Full::new(Bytes::new()))
36//!     .unwrap();
37//!
38//! // Clones Arc<Client>
39//! let mut svc1 = svc_builder.build(ReplaceAll("foo", "baz"));
40//! // http://example.com:1234/baz/bar/baz
41//! let _res = svc1.call(req1).await.unwrap();
42//!
43//! let req2 = Request::builder()
44//!     .method("POST")
45//!     .uri("https://myserver.com/foo/bar/foo")
46//!     .header("Content-Type", "application/x-www-form-urlencoded")
47//!     .body(Full::new(Bytes::from("key=value")))
48//!     .unwrap();
49//!
50//! let mut svc2 = svc_builder.build(ReplaceN("foo", "baz", 1));
51//! // http://example.com:1234/baz/bar/foo
52//! let _res = svc2.call(req2).await.unwrap();
53//! # }
54//! ```
55//!
56//! In this example, `svc1` and `svc2` share the same `Client`, each holding an `Arc<Client>`
57//! inside it.
58//!
59//! For more information on the rewriting rules (`ReplaceAll`, `ReplaceN` *etc.*), see the
60//! documentation of [`rewrite`].
61//!
62//!
63//! ## With axum
64//!
65//! ```
66//! # #[cfg(feature = "axum")] {
67//! use axum_proxy::ReusedServiceBuilder;
68//! use axum_proxy::{TrimPrefix, AppendSuffix, Static};
69//!
70//! use axum::Router;
71//!
72//! #[tokio::main]
73//! async fn main() {
74//!     let host1 = axum_proxy::builder_http("example.com").unwrap();
75//!     let host2 = axum_proxy::builder_http("example.net:1234").unwrap();
76//!
77//!     let app = Router::new()
78//!         .route_service("/healthcheck", host1.build(Static("/")))
79//!         .route_service("/users/{*path}", host1.build(TrimPrefix("/users")))
80//!         .route_service("/posts", host2.build(AppendSuffix("/")));
81//!
82//!     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
83//!         .await
84//!         .unwrap();
85//!
86//!     axum::serve(listener, app).await.unwrap();
87//! }
88//! # }
89//! ```
90//!
91//!
92//! # Return Types
93//!
94//! The return type ([`Future::Output`](std::future::Future::Output)) of [`ReusedService`] and
95//! [`OneshotService`] is `Result<Result<Response, ProxyError>, Infallible>`. This is because axum's
96//! [`Router`](axum::Router) accepts only such `Service`s.
97//!
98//! The [`ProxyError`] type implements [`IntoResponse`](axum::response::IntoResponse) if you enable
99//! the `axum` feature.
100//! It returns an empty body, with the status code `INTERNAL_SERVER_ERROR`. The description of this
101//! error will be logged at [error](`log::error`) level in the
102//! [`into_response()`](axum::response::IntoResponse::into_response()) method.
103//!
104//!
105//! # Features
106//!
107//! By default only `http1` is enabled.
108//!
109//! - `http1`: uses `hyper/http1`
110//! - `http2`: uses `hyper/http2`
111//! - `https`: alias to `nativetls`
112//! - `nativetls`: uses the `hyper-tls` crate
113//! - `rustls`: alias to `rustls-webpki-roots`
114//! - `rustls-webpki-roots`: uses the `hyper-rustls` crate, with the feature `webpki-roots`
115//! - `rustls-native-roots`: uses the `hyper-rustls` crate, with the feature `rustls-native-certs`
116//! - `rustls-http2`: `http2` plus `rustls`, and `rustls/http2` is enabled
117//! - `axum`: implements [`IntoResponse`](axum::response::IntoResponse) for [`ProxyError`]
118//!
119//! You must turn on `http1`; `http2` can be enabled in addition to it. You cannot use the
120//! services if, for example, only the `https` feature is on.
121//!
122//! Throughout this document, we use `rustls` to mean *any* of the `rustls*` features unless
123//! otherwise specified.
124
125mod error;
126pub use error::ProxyError;
127
128#[cfg(any(feature = "http1", feature = "http2"))]
129#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
130pub mod client;
131
132pub mod rewrite;
133pub use rewrite::*;
134
135mod future;
136pub use future::RevProxyFuture;
137
138#[cfg(any(feature = "http1", feature = "http2"))]
139mod oneshot;
140#[cfg(any(feature = "http1", feature = "http2"))]
141#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
142pub use oneshot::OneshotService;
143
144#[cfg(any(feature = "http1", feature = "http2"))]
145mod reused;
146#[cfg(any(feature = "http1", feature = "http2"))]
147#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
148pub use reused::Builder as ReusedServiceBuilder;
149#[cfg(any(feature = "http1", feature = "http2"))]
150#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
151pub use reused::ReusedService;
152#[cfg(all(
153    any(feature = "http1", feature = "http2"),
154    any(feature = "https", feature = "nativetls")
155))]
156#[cfg_attr(
157    docsrs,
158    doc(cfg(all(
159        any(feature = "http1", feature = "http2"),
160        any(feature = "https", feature = "nativetls")
161    )))
162)]
163pub use reused::builder_https;
164#[cfg(all(any(feature = "http1", feature = "http2"), feature = "nativetls"))]
165#[cfg_attr(
166    docsrs,
167    doc(cfg(all(any(feature = "http1", feature = "http2"), feature = "nativetls")))
168)]
169pub use reused::builder_nativetls;
170#[cfg(all(any(feature = "http1", feature = "http2"), feature = "__rustls"))]
171#[cfg_attr(
172    docsrs,
173    doc(cfg(all(any(feature = "http1", feature = "http2"), feature = "rustls")))
174)]
175pub use reused::builder_rustls;
176#[cfg(any(feature = "http1", feature = "http2"))]
177#[cfg_attr(docsrs, doc(cfg(any(feature = "http1", feature = "http2"))))]
178pub use reused::{builder, builder_http};
179
// Build-time feature validation: each guard below stops compilation with an
// explicit message when the enabled Cargo features are inconsistent.

// The crate refuses to build unless the `http1` feature is enabled.
180#[cfg(not(feature = "http1"))]
181compile_error!("http1 is a mandatory feature");
182
// Enabling `rustls-ring` and/or `rustls-aws-lc` without at least one of
// `rustls-webpki-roots` / `rustls-native-roots` is rejected.
183#[cfg(all(
184    any(feature = "rustls-ring", feature = "rustls-aws-lc"),
185    not(any(feature = "rustls-webpki-roots", feature = "rustls-native-roots"))
186))]
187compile_error!(
188    "When enabling rustls-ring and/or rustls-aws-lc, you must enable rustls-webpki-roots and/or rustls-native-roots"
189);
190
/// Shared assertions for the crate's service tests.
///
/// Each `match_*` helper registers one mock on a [`ServerGuard`], sends requests through the
/// service under test, and checks the status code and body that come back. The requests use
/// `/foo…` paths while the mocks are registered on `/goo…` paths, so the service passed in is
/// expected to rewrite `foo` to `goo` and to forward to the mock server.
///
/// Requests that the mock does not match are expected to come back as `501 Not Implemented`
/// with an empty body (presumably mockito's reply for unmatched requests — confirm against the
/// mockito documentation).
#[cfg(test)]
mod test_helper {
    use std::convert::Infallible;

    use http::{Request, Response, StatusCode};
    use http_body_util::BodyExt as _;
    use hyper::body::Incoming;
    use mockito::{Matcher, ServerGuard};
    use pretty_assertions::assert_eq;
    use tower_service::Service;

    use super::{ProxyError, RevProxyFuture};

    /// Sends one request through `service` and asserts on the proxied response.
    ///
    /// The request tuple is `(method, suffix, content_type, body)`: `suffix` (path and optional
    /// query) is appended to `https://test.com`, and `content_type`, when present, is sent as
    /// the `Content-Type` header. `expected` holds the status code and the exact response body.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be built, if the service yields a [`ProxyError`], if the
    /// response body cannot be collected, or if the status or body differ from `expected`.
    /// The underlying error is included in the panic message.
    async fn call<S, B>(
        service: &mut S,
        (method, suffix, content_type, body): (&str, &str, Option<&str>, B),
        expected: (StatusCode, &str),
    ) where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
        B: Into<String>,
    {
        let (expected_status, expected_body) = expected;

        let mut builder = Request::builder()
            .method(method)
            .uri(format!("https://test.com{suffix}"));

        if let Some(content_type) = content_type {
            builder = builder.header("Content-Type", content_type);
        }

        let request = builder
            .body(body.into())
            .expect("test request must be well-formed");

        // The service's error type is `Infallible`, so the outer `Result` can never be `Err`;
        // only the inner `Result` carries a real failure, and `expect` reports it verbatim.
        let response = service
            .call(request)
            .await
            .unwrap_or_else(|never| match never {})
            .expect("proxy returned an error instead of a response");
        assert_eq!(response.status(), expected_status);

        let body = response
            .into_body()
            .collect()
            .await
            .expect("failed to read the proxied response body");
        assert_eq!(body.to_bytes(), expected_body);
    }

    /// Checks path handling: the mock answers `GET /goo/bar/goo/baz/goo`.
    ///
    /// A request for `/foo/bar/foo/baz/foo` must reach the mock and return `200 "ok"`, while
    /// `/foo/bar/foo/baz` must not match it and is expected to return `501` with an empty body.
    pub async fn match_path<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        // Keep the mock handle bound for the rest of the function.
        let _mk = server
            .mock("GET", "/goo/bar/goo/baz/goo")
            .with_body("ok")
            .create_async()
            .await;

        call(
            svc,
            ("GET", "/foo/bar/foo/baz/foo", None, ""),
            (StatusCode::OK, "ok"),
        )
        .await;

        call(
            svc,
            ("GET", "/foo/bar/foo/baz", None, ""),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }

    /// Checks query-string forwarding: the mock answers `GET /goo` only when the query
    /// contains `greeting=good day` (URL-encoded).
    ///
    /// The request with `?greeting=good%20day` must return `200 "ok"`; the same path without
    /// a query is expected to return `501` with an empty body.
    pub async fn match_query<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        let _mk = server
            .mock("GET", "/goo")
            .match_query(Matcher::UrlEncoded("greeting".into(), "good day".into()))
            .with_body("ok")
            .create_async()
            .await;

        call(
            svc,
            ("GET", "/foo?greeting=good%20day", None, ""),
            (StatusCode::OK, "ok"),
        )
        .await;

        call(
            svc,
            ("GET", "/foo", None, ""),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }

    /// Checks method and body forwarding: the mock answers `POST /goo` with body `test`.
    ///
    /// The matching `POST` must return `200 "ok"`; a `PUT` with the same body and a `POST`
    /// with a different body (`tests`) are both expected to return `501` with an empty body.
    pub async fn match_post<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        let _mk = server
            .mock("POST", "/goo")
            .match_body("test")
            .with_body("ok")
            .create_async()
            .await;

        call(svc, ("POST", "/foo", None, "test"), (StatusCode::OK, "ok")).await;

        call(
            svc,
            ("PUT", "/foo", None, "test"),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;

        call(
            svc,
            ("POST", "/foo", None, "tests"),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }

    /// Checks header forwarding: the mock answers `POST /goo` only with
    /// `content-type: application/json` and the body `{"key":"value"}`.
    ///
    /// The fully matching request must return `200 "ok"`; omitting the header or changing the
    /// body is expected to return `501` with an empty body.
    pub async fn match_header<S>(server: &mut ServerGuard, svc: &mut S)
    where
        S: Service<
                Request<String>,
                Response = Result<Response<Incoming>, ProxyError>,
                Error = Infallible,
                Future = RevProxyFuture,
            >,
    {
        let _mk = server
            .mock("POST", "/goo")
            .match_header("content-type", "application/json")
            .match_body(r#"{"key":"value"}"#)
            .with_body("ok")
            .create_async()
            .await;

        call(
            svc,
            (
                "POST",
                "/foo",
                Some("application/json"),
                r#"{"key":"value"}"#,
            ),
            (StatusCode::OK, "ok"),
        )
        .await;

        call(
            svc,
            ("POST", "/foo", None, r#"{"key":"value"}"#),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;

        call(
            svc,
            (
                "POST",
                "/foo",
                Some("application/json"),
                r#"{"key":"values"}"#,
            ),
            (StatusCode::NOT_IMPLEMENTED, ""),
        )
        .await;
    }
}