axum_strangler/
lib.rs

1//! # Axum Strangler
//! A `tower_service::Service` for use in the `axum` web framework to apply the Strangler Fig pattern.
3//! This makes "strangling" a bit easier, as everything that is handled by the "strangler" will
4//! automatically no longer be forwarded to the "stranglee" or "strangled application" (a.k.a. the old application).
5//!
6//! ## Example
7//! ```rust
8//! #[tokio::main]
9//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
10//!     let strangler = axum_strangler::Strangler::new(
11//!         axum::http::uri::Authority::from_static("127.0.0.1:3333"),
12//!     );
13//!     let router = axum::Router::new().fallback_service(strangler);
14//!     axum::Server::bind(&"127.0.0.1:0".parse()?)
15//!         .serve(router.into_make_service())
16//!         # .with_graceful_shutdown(async {
17//!         # // Shut down immediately
18//!         # })
19//!         .await?;
20//!     Ok(())
21//! }
22//! ```
23//!
24//! ## Caveats
//! Note that when registering a route with `axum`, all requests for that path will be handled by it, even if you don't
//! register anything for the specific method.
//! This means that in the following snippet, requests for `/test` with the method
//! POST, PUT, DELETE, OPTIONS, HEAD, PATCH, or TRACE will no longer be forwarded to the strangled application:
28//! ```rust
29//! async fn handler() {}
30//!
31//! #[tokio::main]
32//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
33//!     let strangler = axum_strangler::Strangler::new(
34//!         axum::http::uri::Authority::from_static("127.0.0.1:3333"),
35//!     );
36//!     let router = axum::Router::new()
37//!         .route(
38//!             "/test",
39//!              axum::routing::get(handler)
40//!         )
41//!         .fallback_service(strangler);
42//!     axum::Server::bind(&"127.0.0.1:0".parse()?)
43//!         .serve(router.into_make_service())
44//!         # .with_graceful_shutdown(async {
45//!         # // Shut down immediately
46//!         # })
47//!         .await?;
48//!     Ok(())
49//! }
50//! ```
51//!
52//! If you only want to implement a single method and still forward the rest, you can do so by adding the strangler as the fallback
53//! for that specific `MethodRouter`:
54//! ```rust
55//! async fn handler() {}
56//!
57//! #[tokio::main]
58//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
59//!     let strangler = axum_strangler::Strangler::new(
60//!         axum::http::uri::Authority::from_static("127.0.0.1:3333"),
61//!     );
62//!     let router = axum::Router::new()
63//!         .route(
64//!             "/test",
65//!             axum::routing::get(handler)
66//!                 .fallback_service(strangler.clone())
67//!         )
68//!         .fallback_service(strangler);
69//!     axum::Server::bind(&"127.0.0.1:0".parse()?)
70//!         .serve(router.into_make_service())
71//!         # .with_graceful_shutdown(async {
72//!         # // Shut down immediately
73//!         # })
74//!         .await?;
75//!     Ok(())
76//! }
77//! ```
78//!
79//! ## Websocket support
//! If you enable the `websocket` feature (and possibly one of the supporting TLS ones: `websocket-native-tls`,
//! `websocket-rustls-tls-native-roots`, `websocket-rustls-tls-webpki-roots`), a websocket will be set up and each
//! websocket message will be relayed.
83//!
84//! ## Tracing propagation
//! Enabling the `tracing-opentelemetry-text-map-propagation` feature will cause the `traceparent` header to be set on
//! forwarded requests, based on the current `tracing` (& `tracing-opentelemetry`) context.
87//!
88//! Note that this requires the `opentelemetry` `TextMapPropagator` to be installed.
89
90#![cfg_attr(docsrs, feature(doc_cfg))]
91
92use std::{
93    convert::Infallible,
94    future::Future,
95    pin::Pin,
96    sync::Arc,
97    task::{Context, Poll},
98};
99
100use tower_service::Service;
101
102mod builder;
103mod inner;
104
105pub enum HttpScheme {
106    HTTP,
107    #[cfg(any(docsrs, feature = "https"))]
108    #[cfg_attr(docsrs, doc(cfg(feature = "https")))]
109    HTTPS,
110}
111
112#[cfg(any(docsrs, feature = "websocket"))]
113#[cfg_attr(docsrs, doc(cfg(feature = "websocket")))]
114pub enum WebSocketScheme {
115    WS,
116    #[cfg(any(
117        feature = "websocket-native-tls",
118        feature = "websocket-rustls-tls-native-roots",
119        feature = "websocket-rustls-tls-webpki-roots"
120    ))]
121    #[cfg_attr(
122        docsrs,
123        doc(cfg(any(
124            feature = "websocket-native-tls",
125            feature = "websocket-rustls-tls-native-roots",
126            feature = "websocket-rustls-tls-webpki-roots"
127        )))
128    )]
129    WSS,
130}
131
132/// Forwards all requests to another application.
133/// Can be used in a lot of places, but the most common one would be as a `.fallback` on an `axum` `Router`.
134/// # Example
135/// ```rust
136/// #[tokio::main]
137/// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
138///     let strangler_svc = axum_strangler::Strangler::new(
139///         axum::http::uri::Authority::from_static("127.0.0.1:3333"),
140///     );
141///     let router = axum::Router::new().fallback_service(strangler_svc);
142///     axum::Server::bind(&"127.0.0.1:0".parse()?)
143///         .serve(router.into_make_service())
144///         # .with_graceful_shutdown(async {
145///         # // Shut down immediately
146///         # })
147///         .await?;
148///     Ok(())
149/// }
150/// ```
151#[derive(Clone)]
152pub struct Strangler {
153    inner: Arc<dyn inner::InnerStrangler + Send + Sync>,
154}
155
156impl Strangler {
157    /// Creates a new `Strangler` with the default options.
158    /// For more control, see [`builder::StranglerBuilder`]
159    pub fn new(strangled_authority: http::uri::Authority) -> Self {
160        Strangler::builder(strangled_authority).build()
161    }
162
163    pub fn builder(strangled_authority: http::uri::Authority) -> builder::StranglerBuilder {
164        builder::StranglerBuilder::new(strangled_authority)
165    }
166
167    /// Forwards the request to the strangled service.
168    /// Meant to be used when you want to send something to the strangled application
169    /// based on some custom logic.
170    pub async fn forward_to_strangled(
171        &self,
172        req: http::Request<hyper::body::Body>,
173    ) -> axum_core::response::Response {
174        self.inner.forward_call_to_strangled(req).await
175    }
176}
177
178impl Service<http::Request<hyper::body::Body>> for Strangler {
179    type Response = axum_core::response::Response;
180    type Error = Infallible;
181    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
182
183    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184        Poll::Ready(Ok(()))
185    }
186
187    fn call(&mut self, req: http::Request<hyper::body::Body>) -> Self::Future {
188        let inner = self.inner.clone();
189
190        let fut = async move { Ok(inner.forward_call_to_strangled(req).await) };
191        Box::pin(fut)
192    }
193}
194
195#[cfg(test)]
196mod tests {
197    use super::*;
198    use axum::{body::HttpBody, Router};
199    use wiremock::{
200        matchers::{method, path},
201        Mock, ResponseTemplate,
202    };
203
204    /// Create a mock service that's not connecting to anything.
205    fn make_svc() -> Strangler {
206        Strangler::new(axum::http::uri::Authority::from_static("127.0.0.1:0"))
207    }
208
209    #[tokio::test]
210    async fn can_be_used_as_fallback() {
211        let router = Router::new().fallback_service(make_svc());
212        axum::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(router.into_make_service());
213    }
214
215    #[tokio::test]
216    async fn can_be_used_for_a_route() {
217        let router = Router::new().route_service("/api", make_svc());
218        axum::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(router.into_make_service());
219    }
220
221    #[tokio::test]
222    async fn proxies_strangled_http_service() {
223        let mock_server = wiremock::MockServer::start().await;
224
225        Mock::given(method("GET"))
226            .and(path("/api/something"))
227            .respond_with(ResponseTemplate::new(200).set_body_string("I'm being strangled"))
228            .mount(&mock_server)
229            .await;
230
231        let strangler_svc = Strangler::new(
232            axum::http::uri::Authority::try_from(format!(
233                "127.0.0.1:{}",
234                mock_server.address().port()
235            ))
236            .unwrap(),
237        );
238
239        let router = Router::new().fallback_service(strangler_svc);
240
241        let req = http::Request::get("/api/something")
242            .body(hyper::body::Body::empty())
243            .unwrap();
244        let mut res = router.clone().call(req).await.unwrap();
245
246        assert_eq!(res.status(), http::StatusCode::OK);
247
248        assert_eq!(
249            res.body_mut().data().await.unwrap().unwrap(),
250            "I'm being strangled".as_bytes()
251        );
252    }
253
254    #[cfg(feature = "nested-routers")]
255    #[tokio::test]
256    async fn handles_nested_routers() {
257        let mock_server = wiremock::MockServer::start().await;
258
259        Mock::given(method("GET"))
260            .and(path("/api/something"))
261            .respond_with(ResponseTemplate::new(200))
262            .mount(&mock_server)
263            .await;
264
265        Mock::given(method("GET"))
266            .and(path("/api/something-else"))
267            .respond_with(ResponseTemplate::new(418))
268            .mount(&mock_server)
269            .await;
270
271        let strangler_svc = Strangler::new(
272            axum::http::uri::Authority::try_from(format!(
273                "127.0.0.1:{}",
274                mock_server.address().port()
275            ))
276            .unwrap(),
277        );
278
279        let nested_router = Router::new()
280            .route_service("/something", strangler_svc.clone())
281            .route_service("/something-else", strangler_svc.clone());
282        let router = Router::new()
283            .nest("/api", nested_router)
284            .fallback_service(strangler_svc);
285
286        let req = http::Request::get("/api/something")
287            .body(hyper::body::Body::empty())
288            .unwrap();
289        let res = router.clone().call(req).await.unwrap();
290        assert_eq!(res.status(), http::StatusCode::OK);
291
292        let req = http::Request::get("/api/something-else")
293            .body(hyper::body::Body::empty())
294            .unwrap();
295        let res = router.clone().call(req).await.unwrap();
296        assert_eq!(res.status(), http::StatusCode::IM_A_TEAPOT);
297
298        let req = http::Request::get("/not-api/something-else")
299            .body(hyper::body::Body::empty())
300            .unwrap();
301        let res = router.clone().call(req).await.unwrap();
302        assert_eq!(res.status(), http::StatusCode::NOT_FOUND);
303    }
304}