1use axum::{
4 extract::{OriginalUri, Request},
5 response::{IntoResponse, Redirect, Response},
6 routing::{any, MethodRouter},
7 Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28#[rustversion::since(1.80)]
/// Check that a static route path starts with `/`, returning it unchanged.
///
/// This is an implementation detail of the `vpath!` macro, which evaluates it inside an
/// inline `const` block. It is `const` so that an invalid path panics during constant
/// evaluation.
///
/// # Panics
///
/// Panics if `path` is empty or its first byte is not `/`.
#[doc(hidden)]
#[must_use]
pub const fn __private_validate_static_path(path: &'static str) -> &'static str {
    match path.as_bytes() {
        // The empty string gets a dedicated hint about the root route.
        [] => panic!("Paths must start with a `/`. Use \"/\" for root routes"),
        [b'/', ..] => path,
        _ => panic!("Paths must start with /"),
    }
}
41
42#[rustversion::since(1.80)]
72#[macro_export]
73macro_rules! vpath {
74 ($e:expr) => {
75 const { $crate::routing::__private_validate_static_path($e) }
76 };
77}
78
79#[allow(clippy::return_self_not_must_use)]
81pub trait RouterExt<S>: sealed::Sealed {
82 #[cfg(feature = "typed-routing")]
89 fn typed_get<H, T, P>(self, handler: H) -> Self
90 where
91 H: axum::handler::Handler<T, S>,
92 T: SecondElementIs<P> + 'static,
93 P: TypedPath;
94
95 #[cfg(feature = "typed-routing")]
102 fn typed_delete<H, T, P>(self, handler: H) -> Self
103 where
104 H: axum::handler::Handler<T, S>,
105 T: SecondElementIs<P> + 'static,
106 P: TypedPath;
107
108 #[cfg(feature = "typed-routing")]
115 fn typed_head<H, T, P>(self, handler: H) -> Self
116 where
117 H: axum::handler::Handler<T, S>,
118 T: SecondElementIs<P> + 'static,
119 P: TypedPath;
120
121 #[cfg(feature = "typed-routing")]
128 fn typed_options<H, T, P>(self, handler: H) -> Self
129 where
130 H: axum::handler::Handler<T, S>,
131 T: SecondElementIs<P> + 'static,
132 P: TypedPath;
133
134 #[cfg(feature = "typed-routing")]
141 fn typed_patch<H, T, P>(self, handler: H) -> Self
142 where
143 H: axum::handler::Handler<T, S>,
144 T: SecondElementIs<P> + 'static,
145 P: TypedPath;
146
147 #[cfg(feature = "typed-routing")]
154 fn typed_post<H, T, P>(self, handler: H) -> Self
155 where
156 H: axum::handler::Handler<T, S>,
157 T: SecondElementIs<P> + 'static,
158 P: TypedPath;
159
160 #[cfg(feature = "typed-routing")]
167 fn typed_put<H, T, P>(self, handler: H) -> Self
168 where
169 H: axum::handler::Handler<T, S>,
170 T: SecondElementIs<P> + 'static,
171 P: TypedPath;
172
173 #[cfg(feature = "typed-routing")]
180 fn typed_trace<H, T, P>(self, handler: H) -> Self
181 where
182 H: axum::handler::Handler<T, S>,
183 T: SecondElementIs<P> + 'static,
184 P: TypedPath;
185
186 #[cfg(feature = "typed-routing")]
193 fn typed_connect<H, T, P>(self, handler: H) -> Self
194 where
195 H: axum::handler::Handler<T, S>,
196 T: SecondElementIs<P> + 'static,
197 P: TypedPath;
198
199 fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
225 where
226 Self: Sized;
227
228 fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
232 where
233 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
234 T::Response: IntoResponse,
235 T::Future: Send + 'static,
236 Self: Sized;
237}
238
239impl<S> RouterExt<S> for Router<S>
240where
241 S: Clone + Send + Sync + 'static,
242{
243 #[cfg(feature = "typed-routing")]
244 fn typed_get<H, T, P>(self, handler: H) -> Self
245 where
246 H: axum::handler::Handler<T, S>,
247 T: SecondElementIs<P> + 'static,
248 P: TypedPath,
249 {
250 self.route(P::PATH, axum::routing::get(handler))
251 }
252
253 #[cfg(feature = "typed-routing")]
254 fn typed_delete<H, T, P>(self, handler: H) -> Self
255 where
256 H: axum::handler::Handler<T, S>,
257 T: SecondElementIs<P> + 'static,
258 P: TypedPath,
259 {
260 self.route(P::PATH, axum::routing::delete(handler))
261 }
262
263 #[cfg(feature = "typed-routing")]
264 fn typed_head<H, T, P>(self, handler: H) -> Self
265 where
266 H: axum::handler::Handler<T, S>,
267 T: SecondElementIs<P> + 'static,
268 P: TypedPath,
269 {
270 self.route(P::PATH, axum::routing::head(handler))
271 }
272
273 #[cfg(feature = "typed-routing")]
274 fn typed_options<H, T, P>(self, handler: H) -> Self
275 where
276 H: axum::handler::Handler<T, S>,
277 T: SecondElementIs<P> + 'static,
278 P: TypedPath,
279 {
280 self.route(P::PATH, axum::routing::options(handler))
281 }
282
283 #[cfg(feature = "typed-routing")]
284 fn typed_patch<H, T, P>(self, handler: H) -> Self
285 where
286 H: axum::handler::Handler<T, S>,
287 T: SecondElementIs<P> + 'static,
288 P: TypedPath,
289 {
290 self.route(P::PATH, axum::routing::patch(handler))
291 }
292
293 #[cfg(feature = "typed-routing")]
294 fn typed_post<H, T, P>(self, handler: H) -> Self
295 where
296 H: axum::handler::Handler<T, S>,
297 T: SecondElementIs<P> + 'static,
298 P: TypedPath,
299 {
300 self.route(P::PATH, axum::routing::post(handler))
301 }
302
303 #[cfg(feature = "typed-routing")]
304 fn typed_put<H, T, P>(self, handler: H) -> Self
305 where
306 H: axum::handler::Handler<T, S>,
307 T: SecondElementIs<P> + 'static,
308 P: TypedPath,
309 {
310 self.route(P::PATH, axum::routing::put(handler))
311 }
312
313 #[cfg(feature = "typed-routing")]
314 fn typed_trace<H, T, P>(self, handler: H) -> Self
315 where
316 H: axum::handler::Handler<T, S>,
317 T: SecondElementIs<P> + 'static,
318 P: TypedPath,
319 {
320 self.route(P::PATH, axum::routing::trace(handler))
321 }
322
323 #[cfg(feature = "typed-routing")]
324 fn typed_connect<H, T, P>(self, handler: H) -> Self
325 where
326 H: axum::handler::Handler<T, S>,
327 T: SecondElementIs<P> + 'static,
328 P: TypedPath,
329 {
330 self.route(P::PATH, axum::routing::connect(handler))
331 }
332
333 #[track_caller]
334 fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
335 where
336 Self: Sized,
337 {
338 validate_tsr_path(path);
339 self = self.route(path, method_router);
340 add_tsr_redirect_route(self, path)
341 }
342
343 #[track_caller]
344 fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
345 where
346 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
347 T::Response: IntoResponse,
348 T::Future: Send + 'static,
349 Self: Sized,
350 {
351 validate_tsr_path(path);
352 self = self.route_service(path, service);
353 add_tsr_redirect_route(self, path)
354 }
355}
356
/// Reject paths for which a trailing-slash redirect route cannot be added.
///
/// Only the root path `/` is refused (stripping its slash would leave an empty path).
///
/// # Panics
///
/// Panics if `path` is exactly `/`. Because of `#[track_caller]`, the panic location is
/// reported at the caller.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
363
364fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
365where
366 S: Clone + Send + Sync + 'static,
367{
368 async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response {
369 let new_uri = map_path(uri, |path| {
370 path.strip_suffix('/')
371 .map(Cow::Borrowed)
372 .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
373 });
374
375 if let Some(new_uri) = new_uri {
376 Redirect::permanent(&new_uri.to_string()).into_response()
377 } else {
378 StatusCode::BAD_REQUEST.into_response()
379 }
380 }
381
382 if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
383 router.route(path_without_trailing_slash, any(redirect_handler))
384 } else {
385 router.route(&format!("{path}/"), any(redirect_handler))
386 }
387}
388
389fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
393where
394 F: FnOnce(&str) -> Cow<'_, str>,
395{
396 let mut parts = original_uri.into_parts();
397 let path_and_query = parts.path_and_query.as_ref()?;
398
399 let new_path = f(path_and_query.path());
400
401 let new_path_and_query = if let Some(query) = &path_and_query.query() {
402 format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
403 } else {
404 new_path.parse::<PathAndQuery>().ok()?
405 };
406 parts.path_and_query = Some(new_path_and_query);
407
408 Uri::from_parts(parts).ok()
409}
410
411mod sealed {
412 pub trait Sealed {}
413 impl<S> Sealed for axum::Router<S> {}
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    #[tokio::test]
    async fn test_tsr() {
        let app = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));

        let client = TestClient::new(app);

        let res = client.get("/foo").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo");

        let res = client.get("/bar/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/bar").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/bar/");
    }

    #[tokio::test]
    async fn tsr_with_params() {
        let app = Router::new()
            .route_with_tsr(
                "/a/{a}",
                get(|Path(param): Path<String>| async move { param }),
            )
            .route_with_tsr(
                "/b/{b}/",
                get(|Path(param): Path<String>| async move { param }),
            );

        let client = TestClient::new(app);

        let res = client.get("/a/foo").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");

        let res = client.get("/a/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/a/foo");

        let res = client.get("/b/foo/").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");

        let res = client.get("/b/foo").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/b/foo/");
    }

    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let app = Router::new().route_with_tsr("/foo", get(|| async {}));

        let client = TestClient::new(app);

        let res = client.get("/foo/?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo?a=a");
    }

    // Same as above, but in the direction that appends the slash.
    #[tokio::test]
    async fn tsr_maintains_query_params_when_adding_slash() {
        let app = Router::new().route_with_tsr("/foo/", get(|| async {}));

        let client = TestClient::new(app);

        let res = client.get("/foo?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo/?a=a");
    }

    #[tokio::test]
    async fn tsr_works_in_nested_router() {
        let app = Router::new().nest(
            "/neko",
            Router::new().route_with_tsr("/nyan/", get(|| async {})),
        );

        let client = TestClient::new(app);
        let res = client.get("/neko/nyan/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/neko/nyan").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/neko/nyan/");
    }

    // `route_service_with_tsr` registers the endpoint via `Router::route_service` but must
    // add the same redirect routes as `route_with_tsr`.
    #[tokio::test]
    async fn tsr_with_route_service() {
        let service: MethodRouter = get(|| async {});
        let app = Router::new()
            .route_service_with_tsr("/foo", service.clone())
            .route_service_with_tsr("/bar/", service);

        let client = TestClient::new(app);

        let res = client.get("/foo").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo");

        let res = client.get("/bar/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/bar").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/bar/");
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_service_at_root() {
        let service: MethodRouter = get(|| async {});
        let _: Router = Router::new().route_service_with_tsr("/", service);
    }
}
508}