normalize_path_except/
lib.rs

1//! Middleware that normalizes paths, with exceptions.
2//!
3//! Forked with minimal changes from tower_http::NormalizePathLayer.
4//!
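//! Any trailing slashes from request paths will be removed, unless the path starts with one of
//! the configured exception prefixes. For example, a request with `/foo/` will be changed to
//! `/foo` before reaching the inner service, while `/swagger-ui/` is forwarded unchanged when
//! `/swagger-ui` is an exception. Runs of leading slashes are collapsed to a single `/` as well.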
7//!
8//! # Example
9//!
10//! ```
11//! use normalize_path_except::NormalizePathLayer;
12//! use http::{Request, Response, StatusCode};
13//! use http_body_util::Full;
14//! use bytes::Bytes;
15//! use std::{iter::once, convert::Infallible};
16//! use tower::{ServiceBuilder, Service, ServiceExt};
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! async fn handle(req: Request<Full<Bytes>>) -> Result<Response<Full<Bytes>>, Infallible> {
21//!     // `req.uri().path()` will not have trailing slashes
22//!     # Ok(Response::new(Full::default()))
23//! }
24//!
25//! let mut service = ServiceBuilder::new()
//!     // trim trailing slashes, except from paths starting with `/swagger-ui`
27//!     .layer(NormalizePathLayer::trim_trailing_slash(&["/swagger-ui"]))
28//!     .service_fn(handle);
29//!
30//! // call the service
31//! let request = Request::builder()
32//!     // `handle` will see `/foo`
33//!     .uri("/foo/")
34//!     .body(Full::default())?;
35//!
36//! service.ready().await?.call(request).await?;
37//! #
38//! # Ok(())
39//! # }
40//! ```
41
42use http::{Request, Response, Uri};
43use std::{
44    borrow::Cow,
45    task::{Context, Poll},
46};
47use tower_layer::Layer;
48use tower_service::Service;
49
/// Generates `get_ref`, `get_mut` and `into_inner` accessors for a middleware wrapper.
///
/// Invoke inside an `impl<S> Wrapper<S>` block: the generated methods name the type parameter
/// `S` and read the wrapped service from a field called `inner`. The methods are `pub`.
macro_rules! define_inner_service_accessors {
    () => {
        /// Gets a reference to the underlying service.
        pub fn get_ref(&self) -> &S {
            &self.inner
        }

        /// Gets a mutable reference to the underlying service.
        pub fn get_mut(&mut self) -> &mut S {
            &mut self.inner
        }

        /// Consumes `self`, returning the underlying service.
        pub fn into_inner(self) -> S {
            self.inner
        }
    };
}
69
/// Layer that wraps services in [`NormalizePath`], trimming trailing slashes from request paths
/// unless they start with one of the configured exception prefixes.
///
/// See the [module docs](self) for more details.
#[derive(Clone, Debug)]
pub struct NormalizePathLayer {
    /// Path prefixes whose requests are forwarded without normalization.
    exceptions: Vec<String>,
}

impl NormalizePathLayer {
    /// Create a new [`NormalizePathLayer`].
    ///
    /// Trailing slashes are removed from every request path that does not start with one of
    /// `exceptions` (a plain string-prefix comparison). For example, a request with `/foo/` will
    /// be changed to `/foo` before reaching the inner service.
    pub fn trim_trailing_slash<S: AsRef<str>>(exceptions: &[S]) -> Self {
        let mut owned = Vec::with_capacity(exceptions.len());
        for prefix in exceptions {
            owned.push(String::from(prefix.as_ref()));
        }
        Self { exceptions: owned }
    }
}
88
89impl<S> Layer<S> for NormalizePathLayer {
90    type Service = NormalizePath<S>;
91
92    fn layer(&self, inner: S) -> Self::Service {
93        NormalizePath::trim_trailing_slash(inner, &self.exceptions)
94    }
95}
96
97/// Middleware that normalizes paths.
98///
99/// See the [module docs](self) for more details.
100#[derive(Debug, Clone)]
101pub struct NormalizePath<S> {
102    exceptions: Vec<String>,
103    inner: S,
104}
105
106impl<S> NormalizePath<S> {
107    /// Create a new [`NormalizePath`].
108    ///
109    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
110    /// will be changed to `/foo` before reaching the inner service.
111    pub fn trim_trailing_slash<P: AsRef<str>>(inner: S, exceptions: &[P]) -> Self {
112        let exceptions = exceptions.iter().map(|x| x.as_ref().to_string()).collect();
113        Self { exceptions, inner }
114    }
115
116    define_inner_service_accessors!();
117}
118
119impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
120where
121    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
122{
123    type Response = S::Response;
124    type Error = S::Error;
125    type Future = S::Future;
126
127    #[inline]
128    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129        self.inner.poll_ready(cx)
130    }
131
132    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
133        let path = req.uri().path();
134        if !self.exceptions.iter().any(|x| path.starts_with(x)) {
135            normalize_trailing_slash(req.uri_mut());
136        }
137        self.inner.call(req)
138    }
139}
140
141fn normalize_trailing_slash(uri: &mut Uri) {
142    if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
143        return;
144    }
145
146    let new_path = format!("/{}", uri.path().trim_matches('/'));
147
148    let mut parts = uri.clone().into_parts();
149
150    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
151        let new_path_and_query = if let Some(query) = path_and_query.query() {
152            Cow::Owned(format!("{}?{}", new_path, query))
153        } else {
154            new_path.into()
155        }
156        .parse()
157        .unwrap();
158
159        Some(new_path_and_query)
160    } else {
161        None
162    };
163
164    parts.path_and_query = new_path_and_query;
165    if let Ok(new_uri) = Uri::from_parts(parts) {
166        *uri = new_uri;
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Inner service that echoes back the URI it received, so tests can observe normalization.
    async fn echo_uri(request: Request<()>) -> Result<Response<String>, Infallible> {
        Ok(Response::new(request.uri().to_string()))
    }

    /// Sends a request for `uri` through `svc` and returns the URI the inner service saw.
    async fn roundtrip<S>(svc: &mut S, uri: &str) -> String
    where
        S: Service<Request<()>, Response = Response<String>, Error = Infallible>,
    {
        svc.ready()
            .await
            .unwrap()
            .call(Request::builder().uri(uri).body(()).unwrap())
            .await
            .unwrap()
            .into_body()
    }

    #[tokio::test]
    async fn works() {
        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash(&["/bar"]))
            .service_fn(echo_uri);

        assert_eq!(roundtrip(&mut svc, "/foo/").await, "/foo");
        assert_eq!(roundtrip(&mut svc, "/foo/bar/").await, "/foo/bar");
        assert_eq!(roundtrip(&mut svc, "/bar/").await, "/bar/");
        assert_eq!(roundtrip(&mut svc, "/bar/baz/").await, "/bar/baz/");
    }

    #[tokio::test]
    async fn preserves_query_through_service() {
        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash(&["/bar"]))
            .service_fn(echo_uri);

        assert_eq!(roundtrip(&mut svc, "/foo/?a=b").await, "/foo?a=b");
        assert_eq!(roundtrip(&mut svc, "/bar/?a=b").await, "/bar/?a=b");
    }

    #[tokio::test]
    async fn exceptions_are_plain_string_prefixes() {
        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash(&["/bar"]))
            .service_fn(echo_uri);

        // The check is `starts_with`, not segment-aware, so `/barbaz/` is excepted too.
        assert_eq!(roundtrip(&mut svc, "/barbaz/").await, "/barbaz/");
    }

    #[tokio::test]
    async fn no_exceptions_normalizes_everything() {
        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash::<&str>(&[]))
            .service_fn(echo_uri);

        assert_eq!(roundtrip(&mut svc, "/bar/").await, "/bar");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        let mut uri = "/foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn maintains_query() {
        let mut uri = "/foo/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        let mut uri = "/foo////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        let mut uri = "/foo////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        let mut uri = "/".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        let mut uri = "////".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        let mut uri = "////?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        let mut uri = "///foo//?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        let mut uri = "///foo".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "/foo");
    }

    #[test]
    fn preserves_scheme_and_authority() {
        let mut uri = "http://example.com/foo/?a=a".parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        assert_eq!(uri, "http://example.com/foo?a=a");
    }
}