axum_strangler/lib.rs
1//! # Axum Strangler
2//! A `tower_service::Service` for use in the `axum` web framework to apply the Strangler Fig pattern.
3//! This makes "strangling" a bit easier, as everything that is handled by the "strangler" will
4//! automatically no longer be forwarded to the "stranglee" or "strangled application" (a.k.a. the old application).
5//!
6//! ## Example
7//! ```rust
8//! #[tokio::main]
9//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
10//! let strangler = axum_strangler::Strangler::new(
11//! axum::http::uri::Authority::from_static("127.0.0.1:3333"),
12//! );
13//! let router = axum::Router::new().fallback_service(strangler);
14//! axum::Server::bind(&"127.0.0.1:0".parse()?)
15//! .serve(router.into_make_service())
16//! # .with_graceful_shutdown(async {
17//! # // Shut down immediately
18//! # })
19//! .await?;
20//! Ok(())
21//! }
22//! ```
23//!
24//! ## Caveats
//! Note that when you register a route with `axum`, all requests for that path are handled by the router,
//! even for methods you did not register anything for.
//! This means that in the following snippet, requests for `/test` with the method
//! POST, PUT, DELETE, OPTIONS, HEAD, PATCH, or TRACE will no longer be forwarded to the strangled application:
28//! ```rust
29//! async fn handler() {}
30//!
31//! #[tokio::main]
32//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
33//! let strangler = axum_strangler::Strangler::new(
34//! axum::http::uri::Authority::from_static("127.0.0.1:3333"),
35//! );
36//! let router = axum::Router::new()
37//! .route(
38//! "/test",
39//! axum::routing::get(handler)
40//! )
41//! .fallback_service(strangler);
42//! axum::Server::bind(&"127.0.0.1:0".parse()?)
43//! .serve(router.into_make_service())
44//! # .with_graceful_shutdown(async {
45//! # // Shut down immediately
46//! # })
47//! .await?;
48//! Ok(())
49//! }
50//! ```
51//!
52//! If you only want to implement a single method and still forward the rest, you can do so by adding the strangler as the fallback
53//! for that specific `MethodRouter`:
54//! ```rust
55//! async fn handler() {}
56//!
57//! #[tokio::main]
58//! async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
59//! let strangler = axum_strangler::Strangler::new(
60//! axum::http::uri::Authority::from_static("127.0.0.1:3333"),
61//! );
62//! let router = axum::Router::new()
63//! .route(
64//! "/test",
65//! axum::routing::get(handler)
66//! .fallback_service(strangler.clone())
67//! )
68//! .fallback_service(strangler);
69//! axum::Server::bind(&"127.0.0.1:0".parse()?)
70//! .serve(router.into_make_service())
71//! # .with_graceful_shutdown(async {
72//! # // Shut down immediately
73//! # })
74//! .await?;
75//! Ok(())
76//! }
77//! ```
78//!
79//! ## Websocket support
//! If you enable the `websocket` feature (optionally together with one of the supporting TLS features:
//! `websocket-native-tls`, `websocket-rustls-tls-native-roots` or `websocket-rustls-tls-webpki-roots`),
//! a websocket will be set up and each websocket message will be relayed.
83//!
84//! ## Tracing propagation
//! Enabling the `tracing-opentelemetry-text-map-propagation` feature causes the `traceparent` header to be set on
//! forwarded requests, based on the current `tracing` (and `tracing-opentelemetry`) context.
//!
//! Note that this requires an `opentelemetry` `TextMapPropagator` to be installed.
89
90#![cfg_attr(docsrs, feature(doc_cfg))]
91
92use std::{
93 convert::Infallible,
94 future::Future,
95 pin::Pin,
96 sync::Arc,
97 task::{Context, Poll},
98};
99
100use tower_service::Service;
101
102mod builder;
103mod inner;
104
/// The HTTP scheme that can be selected for plain HTTP traffic.
///
/// NOTE(review): this enum is not consumed in this file; it is presumably used when
/// configuring a [`Strangler`] through its builder — confirm against the `builder` module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpScheme {
    /// Plain, unencrypted HTTP.
    HTTP,
    /// HTTP over TLS. Only available with the `https` feature.
    #[cfg(any(docsrs, feature = "https"))]
    #[cfg_attr(docsrs, doc(cfg(feature = "https")))]
    HTTPS,
}
111
/// The websocket scheme that can be selected for websocket traffic.
///
/// Only available with the `websocket` feature. The secure `WSS` variant additionally requires
/// one of the websocket TLS features.
///
/// NOTE(review): this enum is not consumed in this file; it is presumably used when
/// configuring a [`Strangler`] through its builder — confirm against the `builder` module.
#[cfg(any(docsrs, feature = "websocket"))]
#[cfg_attr(docsrs, doc(cfg(feature = "websocket")))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WebSocketScheme {
    /// Plain, unencrypted websocket (`ws`).
    WS,
    /// Websocket over TLS (`wss`). Requires `websocket-native-tls`,
    /// `websocket-rustls-tls-native-roots` or `websocket-rustls-tls-webpki-roots`.
    #[cfg(any(
        feature = "websocket-native-tls",
        feature = "websocket-rustls-tls-native-roots",
        feature = "websocket-rustls-tls-webpki-roots"
    ))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(
            feature = "websocket-native-tls",
            feature = "websocket-rustls-tls-native-roots",
            feature = "websocket-rustls-tls-webpki-roots"
        )))
    )]
    WSS,
}
131
132/// Forwards all requests to another application.
133/// Can be used in a lot of places, but the most common one would be as a `.fallback` on an `axum` `Router`.
134/// # Example
135/// ```rust
136/// #[tokio::main]
137/// async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
138/// let strangler_svc = axum_strangler::Strangler::new(
139/// axum::http::uri::Authority::from_static("127.0.0.1:3333"),
140/// );
141/// let router = axum::Router::new().fallback_service(strangler_svc);
142/// axum::Server::bind(&"127.0.0.1:0".parse()?)
143/// .serve(router.into_make_service())
144/// # .with_graceful_shutdown(async {
145/// # // Shut down immediately
146/// # })
147/// .await?;
148/// Ok(())
149/// }
150/// ```
151#[derive(Clone)]
152pub struct Strangler {
153 inner: Arc<dyn inner::InnerStrangler + Send + Sync>,
154}
155
156impl Strangler {
157 /// Creates a new `Strangler` with the default options.
158 /// For more control, see [`builder::StranglerBuilder`]
159 pub fn new(strangled_authority: http::uri::Authority) -> Self {
160 Strangler::builder(strangled_authority).build()
161 }
162
163 pub fn builder(strangled_authority: http::uri::Authority) -> builder::StranglerBuilder {
164 builder::StranglerBuilder::new(strangled_authority)
165 }
166
167 /// Forwards the request to the strangled service.
168 /// Meant to be used when you want to send something to the strangled application
169 /// based on some custom logic.
170 pub async fn forward_to_strangled(
171 &self,
172 req: http::Request<hyper::body::Body>,
173 ) -> axum_core::response::Response {
174 self.inner.forward_call_to_strangled(req).await
175 }
176}
177
178impl Service<http::Request<hyper::body::Body>> for Strangler {
179 type Response = axum_core::response::Response;
180 type Error = Infallible;
181 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
182
183 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
184 Poll::Ready(Ok(()))
185 }
186
187 fn call(&mut self, req: http::Request<hyper::body::Body>) -> Self::Future {
188 let inner = self.inner.clone();
189
190 let fut = async move { Ok(inner.forward_call_to_strangled(req).await) };
191 Box::pin(fut)
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use wiremock::{
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    /// Create a mock service that's not connecting to anything.
    fn make_svc() -> Strangler {
        Strangler::new(axum::http::uri::Authority::from_static("127.0.0.1:0"))
    }

    /// Create a `Strangler` that forwards to the given wiremock server on localhost.
    fn make_svc_for(mock_server: &MockServer) -> Strangler {
        let authority = axum::http::uri::Authority::try_from(format!(
            "127.0.0.1:{}",
            mock_server.address().port()
        ))
        .unwrap();
        Strangler::new(authority)
    }

    // The next two tests only check that a `Strangler` satisfies the bounds `axum::Server`
    // needs to serve the router; the server future is intentionally never polled.
    #[tokio::test]
    async fn can_be_used_as_fallback() {
        let router = Router::new().fallback_service(make_svc());
        axum::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(router.into_make_service());
    }

    #[tokio::test]
    async fn can_be_used_for_a_route() {
        let router = Router::new().route_service("/api", make_svc());
        axum::Server::bind(&"0.0.0.0:0".parse().unwrap()).serve(router.into_make_service());
    }

    #[tokio::test]
    async fn proxies_strangled_http_service() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/api/something"))
            .respond_with(ResponseTemplate::new(200).set_body_string("I'm being strangled"))
            .mount(&mock_server)
            .await;

        let router = Router::new().fallback_service(make_svc_for(&mock_server));

        let req = http::Request::get("/api/something")
            .body(hyper::body::Body::empty())
            .unwrap();
        let res = router.clone().call(req).await.unwrap();

        assert_eq!(res.status(), http::StatusCode::OK);

        // Collect the whole body: asserting on only the first frame would break as soon as
        // the proxied body arrives in more than one chunk.
        let body = hyper::body::to_bytes(res.into_body()).await.unwrap();
        assert_eq!(body, "I'm being strangled".as_bytes());
    }

    #[cfg(feature = "nested-routers")]
    #[tokio::test]
    async fn handles_nested_routers() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/api/something"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&mock_server)
            .await;

        Mock::given(method("GET"))
            .and(path("/api/something-else"))
            .respond_with(ResponseTemplate::new(418))
            .mount(&mock_server)
            .await;

        let strangler_svc = make_svc_for(&mock_server);

        let nested_router = Router::new()
            .route_service("/something", strangler_svc.clone())
            .route_service("/something-else", strangler_svc.clone());
        let router = Router::new()
            .nest("/api", nested_router)
            .fallback_service(strangler_svc);

        let req = http::Request::get("/api/something")
            .body(hyper::body::Body::empty())
            .unwrap();
        let res = router.clone().call(req).await.unwrap();
        assert_eq!(res.status(), http::StatusCode::OK);

        let req = http::Request::get("/api/something-else")
            .body(hyper::body::Body::empty())
            .unwrap();
        let res = router.clone().call(req).await.unwrap();
        assert_eq!(res.status(), http::StatusCode::IM_A_TEAPOT);

        let req = http::Request::get("/not-api/something-else")
            .body(hyper::body::Body::empty())
            .unwrap();
        let res = router.clone().call(req).await.unwrap();
        assert_eq!(res.status(), http::StatusCode::NOT_FOUND);
    }
}