axum_extra/routing/
mod.rs

1//! Additional types for defining routes.
2
3use axum::{
4    extract::{OriginalUri, Request},
5    response::{IntoResponse, Redirect, Response},
6    routing::{any, MethodRouter},
7    Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28// Validates a path at compile time, used with the vpath macro.
29#[rustversion::since(1.80)]
30#[doc(hidden)]
31#[must_use]
32pub const fn __private_validate_static_path(path: &'static str) -> &'static str {
33    if path.is_empty() {
34        panic!("Paths must start with a `/`. Use \"/\" for root routes")
35    }
36    if path.as_bytes()[0] != b'/' {
37        panic!("Paths must start with /");
38    }
39    path
40}
41
42/// This macro aborts compilation if the path is invalid.
43///
44/// This example will fail to compile:
45///
46/// ```compile_fail
47/// use axum::routing::{Router, get};
48/// use axum_extra::vpath;
49///
50/// let router = axum::Router::<()>::new()
51///     .route(vpath!("invalid_path"), get(root))
52///     .to_owned();
53///
54/// async fn root() {}
55/// ```
56///
57/// This one will compile without problems:
58///
59/// ```no_run
60/// use axum::routing::{Router, get};
61/// use axum_extra::vpath;
62///
63/// let router = axum::Router::<()>::new()
64///     .route(vpath!("/valid_path"), get(root))
65///     .to_owned();
66///
67/// async fn root() {}
68/// ```
69///
70/// This macro is available only on rust versions 1.80 and above.
71#[rustversion::since(1.80)]
72#[macro_export]
73macro_rules! vpath {
74    ($e:expr) => {
75        const { $crate::routing::__private_validate_static_path($e) }
76    };
77}
78
/// Extension trait that adds additional methods to [`Router`].
///
/// This trait is sealed and is only implemented for [`Router`].
// Every method hands back the updated router by value (builder style); the lint would
// otherwise require `#[must_use]` on each of them individually.
#[allow(clippy::return_self_not_must_use)]
pub trait RouterExt<S>: sealed::Sealed {
    /// Add a typed `GET` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_get<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `DELETE` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_delete<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `HEAD` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_head<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `OPTIONS` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_options<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `PATCH` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_patch<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `POST` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_post<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `PUT` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_put<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `TRACE` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_trace<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add a typed `CONNECT` route to the router.
    ///
    /// The path will be inferred from the first argument to the handler function which must
    /// implement [`TypedPath`].
    ///
    /// See [`TypedPath`] for more details and examples.
    #[cfg(feature = "typed-routing")]
    fn typed_connect<H, T, P>(self, handler: H) -> Self
    where
        H: axum::handler::Handler<T, S>,
        T: SecondElementIs<P> + 'static,
        P: TypedPath;

    /// Add another route to the router with an additional "trailing slash redirect" route.
    ///
    /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a
    /// route for `/foo/` that redirects to `/foo`.
    ///
    /// If you add a route _with_ a trailing slash, such as `/bar/`, this method will also add a
    /// route for `/bar` that redirects to `/bar/`.
    ///
    /// This is similar to what axum 0.5.x did by default, except this explicitly adds another
    /// route, so trying to add a `/foo/` route after calling `.route_with_tsr("/foo", /* ... */)`
    /// will result in a panic due to route overlap.
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/`, and in any situation where [`Router::route`] would panic for
    /// either `path` or its redirect counterpart (for example on route overlap).
    ///
    /// # Example
    ///
    /// ```
    /// use axum::{Router, routing::get};
    /// use axum_extra::routing::RouterExt;
    ///
    /// let app = Router::new()
    ///     // `/foo/` will redirect to `/foo`
    ///     .route_with_tsr("/foo", get(|| async {}))
    ///     // `/bar` will redirect to `/bar/`
    ///     .route_with_tsr("/bar/", get(|| async {}));
    /// # let _: Router = app;
    /// ```
    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
    where
        Self: Sized;

    /// Add another route to the router with an additional "trailing slash redirect" route.
    ///
    /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`].
    ///
    /// # Panics
    ///
    /// Panics if `path` is `/`, if [`Router::route_service`] panics for `path`, or if
    /// [`Router::route`] panics for the redirect counterpart (for example on route overlap).
    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
    where
        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
        Self: Sized;
}
238
239impl<S> RouterExt<S> for Router<S>
240where
241    S: Clone + Send + Sync + 'static,
242{
243    #[cfg(feature = "typed-routing")]
244    fn typed_get<H, T, P>(self, handler: H) -> Self
245    where
246        H: axum::handler::Handler<T, S>,
247        T: SecondElementIs<P> + 'static,
248        P: TypedPath,
249    {
250        self.route(P::PATH, axum::routing::get(handler))
251    }
252
253    #[cfg(feature = "typed-routing")]
254    fn typed_delete<H, T, P>(self, handler: H) -> Self
255    where
256        H: axum::handler::Handler<T, S>,
257        T: SecondElementIs<P> + 'static,
258        P: TypedPath,
259    {
260        self.route(P::PATH, axum::routing::delete(handler))
261    }
262
263    #[cfg(feature = "typed-routing")]
264    fn typed_head<H, T, P>(self, handler: H) -> Self
265    where
266        H: axum::handler::Handler<T, S>,
267        T: SecondElementIs<P> + 'static,
268        P: TypedPath,
269    {
270        self.route(P::PATH, axum::routing::head(handler))
271    }
272
273    #[cfg(feature = "typed-routing")]
274    fn typed_options<H, T, P>(self, handler: H) -> Self
275    where
276        H: axum::handler::Handler<T, S>,
277        T: SecondElementIs<P> + 'static,
278        P: TypedPath,
279    {
280        self.route(P::PATH, axum::routing::options(handler))
281    }
282
283    #[cfg(feature = "typed-routing")]
284    fn typed_patch<H, T, P>(self, handler: H) -> Self
285    where
286        H: axum::handler::Handler<T, S>,
287        T: SecondElementIs<P> + 'static,
288        P: TypedPath,
289    {
290        self.route(P::PATH, axum::routing::patch(handler))
291    }
292
293    #[cfg(feature = "typed-routing")]
294    fn typed_post<H, T, P>(self, handler: H) -> Self
295    where
296        H: axum::handler::Handler<T, S>,
297        T: SecondElementIs<P> + 'static,
298        P: TypedPath,
299    {
300        self.route(P::PATH, axum::routing::post(handler))
301    }
302
303    #[cfg(feature = "typed-routing")]
304    fn typed_put<H, T, P>(self, handler: H) -> Self
305    where
306        H: axum::handler::Handler<T, S>,
307        T: SecondElementIs<P> + 'static,
308        P: TypedPath,
309    {
310        self.route(P::PATH, axum::routing::put(handler))
311    }
312
313    #[cfg(feature = "typed-routing")]
314    fn typed_trace<H, T, P>(self, handler: H) -> Self
315    where
316        H: axum::handler::Handler<T, S>,
317        T: SecondElementIs<P> + 'static,
318        P: TypedPath,
319    {
320        self.route(P::PATH, axum::routing::trace(handler))
321    }
322
323    #[cfg(feature = "typed-routing")]
324    fn typed_connect<H, T, P>(self, handler: H) -> Self
325    where
326        H: axum::handler::Handler<T, S>,
327        T: SecondElementIs<P> + 'static,
328        P: TypedPath,
329    {
330        self.route(P::PATH, axum::routing::connect(handler))
331    }
332
333    #[track_caller]
334    fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
335    where
336        Self: Sized,
337    {
338        validate_tsr_path(path);
339        self = self.route(path, method_router);
340        add_tsr_redirect_route(self, path)
341    }
342
343    #[track_caller]
344    fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
345    where
346        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
347        T::Response: IntoResponse,
348        T::Future: Send + 'static,
349        Self: Sized,
350    {
351        validate_tsr_path(path);
352        self = self.route_service(path, service);
353        add_tsr_redirect_route(self, path)
354    }
355}
356
/// Panics when asked to add a trailing-slash redirect for the root path `/`.
///
/// `#[track_caller]` (together with the same attribute on the `RouterExt` methods that call
/// this) makes the panic report the user's call site instead of this helper.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
363
364fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
365where
366    S: Clone + Send + Sync + 'static,
367{
368    async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response {
369        let new_uri = map_path(uri, |path| {
370            path.strip_suffix('/')
371                .map(Cow::Borrowed)
372                .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
373        });
374
375        if let Some(new_uri) = new_uri {
376            Redirect::permanent(&new_uri.to_string()).into_response()
377        } else {
378            StatusCode::BAD_REQUEST.into_response()
379        }
380    }
381
382    if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
383        router.route(path_without_trailing_slash, any(redirect_handler))
384    } else {
385        router.route(&format!("{path}/"), any(redirect_handler))
386    }
387}
388
389/// Map the path of a `Uri`.
390///
391/// Returns `None` if the `Uri` cannot be put back together with the new path.
392fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
393where
394    F: FnOnce(&str) -> Cow<'_, str>,
395{
396    let mut parts = original_uri.into_parts();
397    let path_and_query = parts.path_and_query.as_ref()?;
398
399    let new_path = f(path_and_query.path());
400
401    let new_path_and_query = if let Some(query) = &path_and_query.query() {
402        format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
403    } else {
404        new_path.parse::<PathAndQuery>().ok()?
405    };
406    parts.path_and_query = Some(new_path_and_query);
407
408    Uri::from_parts(parts).ok()
409}
410
411mod sealed {
412    pub trait Sealed {}
413    impl<S> Sealed for axum::Router<S> {}
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    // GET `$uri` and expect `200 OK`; the three-argument form also checks the body text.
    macro_rules! expect_ok {
        ($client:expr, $uri:expr) => {{
            let response = $client.get($uri).await;
            assert_eq!(response.status(), StatusCode::OK);
        }};
        ($client:expr, $uri:expr, $body:expr) => {{
            let response = $client.get($uri).await;
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(response.text().await, $body);
        }};
    }

    // GET `$uri` and expect a permanent redirect whose `location` header is `$location`.
    macro_rules! expect_redirect {
        ($client:expr, $uri:expr, $location:expr) => {{
            let response = $client.get($uri).await;
            assert_eq!(response.status(), StatusCode::PERMANENT_REDIRECT);
            assert_eq!(response.headers()["location"], $location);
        }};
    }

    #[tokio::test]
    async fn test_tsr() {
        let client = TestClient::new(
            Router::new()
                .route_with_tsr("/foo", get(|| async {}))
                .route_with_tsr("/bar/", get(|| async {})),
        );

        expect_ok!(client, "/foo");
        expect_redirect!(client, "/foo/", "/foo");

        expect_ok!(client, "/bar/");
        expect_redirect!(client, "/bar", "/bar/");
    }

    #[tokio::test]
    async fn tsr_with_params() {
        let client = TestClient::new(
            Router::new()
                .route_with_tsr(
                    "/a/{a}",
                    get(|Path(param): Path<String>| async move { param }),
                )
                .route_with_tsr(
                    "/b/{b}/",
                    get(|Path(param): Path<String>| async move { param }),
                ),
        );

        expect_ok!(client, "/a/foo", "foo");
        expect_redirect!(client, "/a/foo/", "/a/foo");

        expect_ok!(client, "/b/foo/", "foo");
        expect_redirect!(client, "/b/foo", "/b/foo/");
    }

    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let client = TestClient::new(Router::new().route_with_tsr("/foo", get(|| async {})));

        expect_redirect!(client, "/foo/?a=a", "/foo?a=a");
    }

    #[tokio::test]
    async fn tsr_works_in_nested_router() {
        let client = TestClient::new(Router::new().nest(
            "/neko",
            Router::new().route_with_tsr("/nyan/", get(|| async {})),
        ));

        expect_ok!(client, "/neko/nyan/");
        expect_redirect!(client, "/neko/nyan", "/neko/nyan/");
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }
}