1use std::{
4 marker::PhantomData,
5 pin::Pin,
6 task::{Context, Poll, ready},
7};
8
9use actix_service::{Service, Transform};
10use actix_utils::future::{Ready, ready};
11use actix_web::{
12 Error, HttpResponse,
13 body::EitherBody,
14 dev::{ServiceRequest, ServiceResponse},
15 http::{
16 StatusCode, header,
17 uri::{PathAndQuery, Uri},
18 },
19 middleware::TrailingSlash,
20};
21use bytes::Bytes;
22use pin_project_lite::pin_project;
23use regex::Regex;
24
/// Middleware for normalizing a request's path so that routes can be matched more flexibly.
///
/// Runs of consecutive slashes (`//`, `///`, ...) are merged into a single slash, and trailing
/// slashes are handled according to the configured [`TrailingSlash`] behavior. By default the
/// request is rewritten in place before it reaches the wrapped service. Optionally, a redirect to
/// the normalized URI is returned instead, but only when normalization actually changed the path.
#[derive(Debug, Clone, Copy)]
pub struct NormalizePath {
    /// How trailing slashes are treated: trimmed, always appended, or left as they are.
    trailing_slash_behavior: TrailingSlash,

    /// Status code of the redirect sent when the path was altered, or `None` to rewrite the
    /// request in place instead of redirecting.
    use_redirects: Option<StatusCode>,
}
81
82impl Default for NormalizePath {
83 fn default() -> Self {
84 Self {
85 trailing_slash_behavior: TrailingSlash::Trim,
86 use_redirects: None,
87 }
88 }
89}
90
91impl NormalizePath {
92 pub fn new(behavior: TrailingSlash) -> Self {
94 Self {
95 trailing_slash_behavior: behavior,
96 use_redirects: None,
97 }
98 }
99
100 pub fn trim() -> Self {
104 Self::new(TrailingSlash::Trim)
105 }
106
107 pub fn use_redirects(mut self) -> Self {
116 self.use_redirects = Some(StatusCode::TEMPORARY_REDIRECT);
117 self
118 }
119
120 pub fn use_redirects_with(mut self, status_code: StatusCode) -> Self {
128 assert!(status_code.is_redirection());
129 self.use_redirects = Some(status_code);
130 self
131 }
132}
133
134impl<S, B> Transform<S, ServiceRequest> for NormalizePath
135where
136 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
137 S::Future: 'static,
138{
139 type Response = ServiceResponse<EitherBody<B, ()>>;
140 type Error = Error;
141 type Transform = NormalizePathService<S>;
142 type InitError = ();
143 type Future = Ready<Result<Self::Transform, Self::InitError>>;
144
145 fn new_transform(&self, service: S) -> Self::Future {
146 ready(Ok(NormalizePathService {
147 service,
148 merge_slash: Regex::new("//+").unwrap(),
149 trailing_slash_behavior: self.trailing_slash_behavior,
150 use_redirects: self.use_redirects,
151 }))
152 }
153}
154
/// Service produced by the `NormalizePath` middleware.
///
/// It normalizes each request's path and then either forwards the request to the wrapped
/// service or answers with a redirect to the normalized URI.
#[doc(hidden)]
// The wrapped service `S` carries no `Debug` bound, so no `Debug` impl is provided.
#[allow(missing_debug_implementations)]
pub struct NormalizePathService<S> {
    /// The wrapped (inner) service that receives normalized requests.
    service: S,
    /// Regex matching runs of two or more slashes; used to collapse them into one.
    merge_slash: Regex,
    /// Trailing slash handling copied from the `NormalizePath` configuration.
    trailing_slash_behavior: TrailingSlash,
    /// Redirect status code to use when the path was altered, or `None` to rewrite in place.
    use_redirects: Option<StatusCode>,
}
164
165impl<S, B> Service<ServiceRequest> for NormalizePathService<S>
166where
167 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
168 S::Future: 'static,
169{
170 type Response = ServiceResponse<EitherBody<B, ()>>;
171 type Error = Error;
172 type Future = NormalizePathFuture<S, B>;
173
174 actix_service::forward_ready!(service);
175
176 fn call(&self, mut req: ServiceRequest) -> Self::Future {
177 let head = req.head_mut();
178
179 let mut path_altered = false;
180 let original_path = head.uri.path();
181
182 if !original_path.is_empty() {
185 let path = match self.trailing_slash_behavior {
188 TrailingSlash::Always => format!("{original_path}/"),
189 TrailingSlash::MergeOnly => original_path.to_string(),
190 TrailingSlash::Trim => original_path.trim_end_matches('/').to_string(),
191 ts_behavior => panic!("unknown trailing slash behavior: {ts_behavior:?}"),
192 };
193
194 let path = self.merge_slash.replace_all(&path, "/");
196
197 let path = if path.is_empty() { "/" } else { path.as_ref() };
200
201 if path != original_path {
213 let mut parts = head.uri.clone().into_parts();
214 let query = parts.path_and_query.as_ref().and_then(|pq| pq.query());
215
216 let path = match query {
217 Some(query) => Bytes::from(format!("{path}?{query}")),
218 None => Bytes::copy_from_slice(path.as_bytes()),
219 };
220 parts.path_and_query = Some(PathAndQuery::from_maybe_shared(path).unwrap());
221
222 let uri = Uri::from_parts(parts).unwrap();
223 req.match_info_mut().get_mut().update(&uri);
224 req.head_mut().uri = uri;
225
226 path_altered = true;
227 }
228 }
229
230 match self.use_redirects {
231 Some(code) if path_altered => {
232 let mut res = HttpResponse::with_body(code, ());
233 res.headers_mut().insert(
234 header::LOCATION,
235 req.head_mut().uri.to_string().parse().unwrap(),
236 );
237 NormalizePathFuture::redirect(req.into_response(res))
238 }
239
240 _ => NormalizePathFuture::service(self.service.call(req)),
241 }
242 }
243}
244
pin_project! {
    /// Future returned by the `NormalizePath` service.
    ///
    /// It resolves either to the wrapped service's response (left body) or to a prepared
    /// redirect response (right body).
    pub struct NormalizePathFuture<S: Service<ServiceRequest>, B> {
        // Either the pending inner-service future or the ready redirect response.
        #[pin] inner: Inner<S, B>,
    }
}
250
251impl<S: Service<ServiceRequest>, B> NormalizePathFuture<S, B> {
252 fn service(fut: S::Future) -> Self {
253 Self {
254 inner: Inner::Service {
255 fut,
256 _body: PhantomData,
257 },
258 }
259 }
260
261 fn redirect(res: ServiceResponse<()>) -> Self {
262 Self {
263 inner: Inner::Redirect { res: Some(res) },
264 }
265 }
266}
267
pin_project! {
    // Internal state of `NormalizePathFuture`.
    #[project = InnerProj]
    enum Inner<S: Service<ServiceRequest>, B> {
        // A redirect response that is ready immediately; it is taken out on the first poll.
        Redirect { res: Option<ServiceResponse<()>>, },
        // The wrapped service's response future; `_body` only records the body type `B`.
        Service {
            #[pin] fut: S::Future,
            _body: PhantomData<B>,
        },
    }
}
278
279impl<S, B> Future for NormalizePathFuture<S, B>
280where
281 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
282{
283 type Output = Result<ServiceResponse<EitherBody<B, ()>>, Error>;
284
285 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
286 let this = self.project();
287
288 match this.inner.project() {
289 InnerProj::Redirect { res } => {
290 Poll::Ready(Ok(res.take().unwrap().map_into_right_body()))
291 }
292
293 InnerProj::Service { fut, .. } => {
294 let res = ready!(fut.poll(cx))?;
295 Poll::Ready(Ok(res.map_into_left_body()))
296 }
297 }
298 }
299}
300
#[cfg(test)]
mod tests {
    use actix_service::IntoService;
    use actix_web::{
        App, HttpRequest, HttpResponse,
        dev::ServiceRequest,
        guard::fn_guard,
        test::{self, TestRequest, call_service, init_service},
        web,
    };

    use super::*;

    #[actix_web::test]
    async fn default_is_trim_no_redirect() {
        let app = init_service(App::new().wrap(NormalizePath::default()).service(
            web::resource("/test").to(|req: HttpRequest| async move { req.path().to_owned() }),
        ))
        .await;

        let req = TestRequest::with_uri("/test/").to_request();
        let res = call_service(&app, req).await;
        assert!(res.status().is_success());
        assert_eq!(test::read_body(res).await, "/test");
    }

    #[actix_web::test]
    async fn trim_trailing_slashes() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::trim())
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = [
            "/",
            "/?query=test",
            "///",
            "/v1//something",
            "/v1//something////",
            "//v1/something",
            "//v1//////something",
            "/v2//something?query=test",
            "/v2//something////?query=test",
            "//v2/something?query=test",
            "//v2//////something?query=test",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn always_trailing_slashes() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something/").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = [
            "/",
            "///",
            "/v1/something",
            "/v1/something/",
            "/v1/something////",
            "//v1//something",
            "//v1//something//",
            "/v2/something?query=test",
            "/v2/something/?query=test",
            "/v2/something////?query=test",
            "//v2//something?query=test",
            "//v2//something//?query=test",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn trim_root_trailing_slashes_with_query() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Trim))
                .service(
                    web::resource("/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = ["/?query=test", "//?query=test", "///?query=test"];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn ensure_root_trailing_slash_with_query() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::Always))
                .service(
                    web::resource("/")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let test_uris = ["/?query=test", "//?query=test", "///?query=test"];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn keep_trailing_slash_unchanged() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::new(TrailingSlash::MergeOnly))
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(web::resource("/v1/").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        let tests = [
            ("/", true),
            ("/?query=test", true),
            ("///", true),
            ("/v1/something////", false),
            ("/v1/something/", false),
            ("//v1//something", true),
            ("/v1/", true),
            ("/v1", false),
            ("/v1////", true),
            ("//v1//", true),
            ("///v1", false),
            ("/v2/something?query=test", true),
            ("/v2/something/?query=test", false),
            ("/v2/something//?query=test", false),
            ("//v2//something?query=test", true),
        ];

        for (uri, success) in tests {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;
            assert_eq!(res.status().is_success(), success, "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn no_path() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::default())
                .service(web::resource("/").to(HttpResponse::Ok)),
        )
        .await;

        // `eh` is parsed as an authority-form target (no scheme and no path), so normalization
        // is skipped and no route matches.
        let req = TestRequest::with_uri("eh").to_request();
        let res = call_service(&app, req).await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
    }

    #[actix_web::test]
    async fn test_in_place_normalization() {
        let srv = |req: ServiceRequest| {
            assert_eq!("/v1/something", req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let test_uris = [
            "/v1//something////",
            "///v1/something",
            "//v1///something",
            "/v1//something",
        ];

        for uri in test_uris {
            let req = TestRequest::with_uri(uri).to_srv_request();
            let res = normalize.call(req).await.unwrap();
            assert!(res.status().is_success(), "Failed uri: {uri}");
        }
    }

    #[actix_web::test]
    async fn should_normalize_nothing() {
        const URI: &str = "/v1/something";

        let srv = |req: ServiceRequest| {
            assert_eq!(URI, req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri(URI).to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn should_normalize_no_trail() {
        let srv = |req: ServiceRequest| {
            assert_eq!("/v1/something", req.path());
            ready(Ok(req.into_response(HttpResponse::Ok().finish())))
        };

        let normalize = NormalizePath::default()
            .new_transform(srv.into_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert!(res.status().is_success());
    }

    #[actix_web::test]
    async fn should_return_redirects_when_configured() {
        let normalize = NormalizePath::trim()
            .use_redirects()
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::TEMPORARY_REDIRECT);

        let normalize = NormalizePath::trim()
            .use_redirects_with(StatusCode::PERMANENT_REDIRECT)
            .new_transform(test::ok_service())
            .await
            .unwrap();

        let req = TestRequest::with_uri("/v1/something/").to_srv_request();
        let res = normalize.call(req).await.unwrap();
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
    }

    #[actix_web::test]
    async fn trim_with_redirect() {
        let app = init_service(
            App::new()
                .wrap(NormalizePath::trim().use_redirects())
                .service(web::resource("/").to(HttpResponse::Ok))
                .service(web::resource("/v1/something").to(HttpResponse::Ok))
                .service(
                    web::resource("/v2/something")
                        .guard(fn_guard(|ctx| ctx.head().uri.query() == Some("query=test")))
                        .to(HttpResponse::Ok),
                ),
        )
        .await;

        // (uri, should_redirect): already-normalized URIs are served directly.
        let test_uris = [
            ("/", false),
            ("///", true),
            ("/v1/something", false),
            ("/v1/something/", true),
            ("/v1/something////", true),
            ("//v1//something", true),
            ("//v1//something//", true),
            ("/v2/something?query=test", false),
            ("/v2/something/?query=test", true),
            ("/v2/something////?query=test", true),
            ("//v2//something?query=test", true),
            ("//v2//something//?query=test", true),
        ];

        for (uri, should_redirect) in test_uris {
            let req = TestRequest::with_uri(uri).to_request();
            let res = call_service(&app, req).await;

            if should_redirect {
                assert!(res.status().is_redirection(), "URI did not redirect: {uri}");
            } else {
                assert!(res.status().is_success(), "Failed URI: {uri}");
            }
        }
    }
}
665}