rama_http/layer/
normalize_path.rs

1//! Middleware that normalizes paths.
2//!
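//! Any trailing slashes are removed from request paths, and a run of leading slashes is collapsed
//! into a single `/`. For example, a request with `/foo/` (or `//foo`) will be changed to `/foo`
//! before reaching the inner service. Slashes inside the path, such as in `/foo//bar`, are kept.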
5//!
6//! # Example
7//!
8//! ```
9//! use std::{iter::once, convert::Infallible};
10//! use rama_core::error::BoxError;
11//! use rama_core::service::service_fn;
12//! use rama_core::{Context, Layer, Service};
13//! use rama_http::{Body, Request, Response, StatusCode};
14//! use rama_http::layer::normalize_path::NormalizePathLayer;
15//!
16//! # #[tokio::main]
17//! # async fn main() -> Result<(), BoxError> {
18//! async fn handle(req: Request) -> Result<Response, Infallible> {
19//!     // `req.uri().path()` will not have trailing slashes
20//!     # Ok(Response::new(Body::default()))
21//! }
22//!
23//! let mut service = (
24//!     // trim trailing slashes from paths
25//!     NormalizePathLayer::trim_trailing_slash(),
26//! ).into_layer(service_fn(handle));
27//!
28//! // call the service
29//! let request = Request::builder()
30//!     // `handle` will see `/foo`
31//!     .uri("/foo/")
32//!     .body(Body::default())?;
33//!
34//! service.serve(Context::default(), request).await?;
35//! #
36//! # Ok(())
37//! # }
38//! ```
39
40use crate::{Request, Response, Uri};
41use rama_core::{Context, Layer, Service};
42use rama_utils::macros::define_inner_service_accessors;
43use std::borrow::Cow;
44use std::fmt;
45
46/// Layer that applies [`NormalizePath`] which normalizes paths.
47///
48/// See the [module docs](self) for more details.
49#[derive(Debug, Clone, Default)]
50#[non_exhaustive]
51pub struct NormalizePathLayer;
52
53impl NormalizePathLayer {
54    /// Create a new [`NormalizePathLayer`].
55    ///
56    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
57    /// will be changed to `/foo` before reaching the inner service.
58    pub fn trim_trailing_slash() -> Self {
59        NormalizePathLayer
60    }
61}
62
63impl<S> Layer<S> for NormalizePathLayer {
64    type Service = NormalizePath<S>;
65
66    fn layer(&self, inner: S) -> Self::Service {
67        NormalizePath::trim_trailing_slash(inner)
68    }
69}
70
/// Middleware that normalizes request paths before forwarding to the wrapped service.
///
/// Trailing slashes are stripped and repeated leading slashes are collapsed into one;
/// the query string is carried over unchanged.
///
/// See the [module docs](self) for more details.
#[derive(Debug, Clone)]
pub struct NormalizePath<S> {
    inner: S,
}
93
94impl<S> NormalizePath<S> {
95    /// Create a new [`NormalizePath`].
96    ///
97    /// Alias for [`Self::trim_trailing_slash`].
98    #[inline]
99    pub fn new(inner: S) -> Self {
100        Self::trim_trailing_slash(inner)
101    }
102
103    /// Create a new [`NormalizePath`].
104    ///
105    /// Any trailing slashes from request paths will be removed. For example, a request with `/foo/`
106    /// will be changed to `/foo` before reaching the inner service.
107    pub fn trim_trailing_slash(inner: S) -> Self {
108        Self { inner }
109    }
110
111    define_inner_service_accessors!();
112}
113
114impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for NormalizePath<S>
115where
116    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
117    State: Clone + Send + Sync + 'static,
118    ReqBody: Send + 'static,
119    ResBody: Send + 'static,
120{
121    type Response = S::Response;
122    type Error = S::Error;
123
124    fn serve(
125        &self,
126        ctx: Context<State>,
127        mut req: Request<ReqBody>,
128    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
129        normalize_trailing_slash(req.uri_mut());
130        self.inner.serve(ctx, req)
131    }
132}
133
134fn normalize_trailing_slash(uri: &mut Uri) {
135    if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
136        return;
137    }
138
139    let new_path = format!("/{}", uri.path().trim_matches('/'));
140
141    let mut parts = uri.clone().into_parts();
142
143    let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
144        let new_path = if new_path.is_empty() {
145            "/"
146        } else {
147            new_path.as_str()
148        };
149
150        let new_path_and_query = if let Some(query) = path_and_query.query() {
151            Cow::Owned(format!("{}?{}", new_path, query))
152        } else {
153            new_path.into()
154        }
155        .parse()
156        .unwrap();
157
158        Some(new_path_and_query)
159    } else {
160        None
161    };
162
163    parts.path_and_query = new_path_and_query;
164    if let Ok(new_uri) = Uri::from_parts(parts) {
165        *uri = new_uri;
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use rama_core::Layer;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// Parses `input`, runs it through [`normalize_trailing_slash`] and returns the result.
    fn normalized(input: &str) -> Uri {
        let mut uri = input.parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        uri
    }

    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            Ok(Response::new(request.uri().to_string()))
        }

        let svc = NormalizePathLayer::trim_trailing_slash().into_layer(service_fn(handle));

        let body = svc
            .serve(
                Context::default(),
                Request::builder().uri("/foo/").body(()).unwrap(),
            )
            .await
            .unwrap()
            .into_body();

        assert_eq!(body, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_eq!(normalized("/foo"), "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_eq!(normalized("/foo/?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_eq!(normalized("/foo////"), "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_eq!(normalized("/foo////?a=a"), "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_eq!(normalized("/"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_eq!(normalized("////"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_eq!(normalized("////?a=a"), "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_eq!(normalized("///foo//?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_eq!(normalized("///foo"), "/foo");
    }

    #[test]
    fn keeps_interior_slashes() {
        assert_eq!(normalized("/foo//bar/"), "/foo//bar");
    }

    #[test]
    fn collapses_scheme_relative_looking_path() {
        assert_eq!(normalized("//example.com"), "/example.com");
    }

    #[test]
    fn keeps_scheme_and_authority_of_absolute_uri() {
        assert_eq!(
            normalized("http://example.com/foo/?a=a"),
            "http://example.com/foo?a=a"
        );
    }
}