rama_http/layer/
normalize_path.rs1use crate::{Request, Response, Uri};
41use rama_core::{Context, Layer, Service};
42use rama_utils::macros::define_inner_service_accessors;
43use std::borrow::Cow;
44use std::fmt;
45
/// Layer that wraps a service in [`NormalizePath`].
///
/// The wrapped service sees request paths with their trailing slashes removed
/// (and repeated leading slashes collapsed), e.g. `/foo/` becomes `/foo`.
#[derive(Debug, Clone, Default)]
// `non_exhaustive` keeps the door open for adding configuration fields later
// without breaking downstream crates that construct this type.
#[non_exhaustive]
pub struct NormalizePathLayer;
52
53impl NormalizePathLayer {
54 pub fn trim_trailing_slash() -> Self {
59 NormalizePathLayer
60 }
61}
62
63impl<S> Layer<S> for NormalizePathLayer {
64 type Service = NormalizePath<S>;
65
66 fn layer(&self, inner: S) -> Self::Service {
67 NormalizePath::trim_trailing_slash(inner)
68 }
69}
70
/// Middleware that normalizes the path of each request before handing the
/// request to the wrapped service.
///
/// Trailing slashes are removed and repeated leading slashes are collapsed,
/// so `/foo/` and `///foo//` both reach the inner service as `/foo`.
pub struct NormalizePath<S> {
    /// The wrapped service that receives the request after normalization.
    inner: S,
}
77
78impl<S: fmt::Debug> fmt::Debug for NormalizePath<S> {
79 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
80 f.debug_struct("NormalizePath")
81 .field("inner", &self.inner)
82 .finish()
83 }
84}
85
86impl<S: Clone> Clone for NormalizePath<S> {
87 fn clone(&self) -> Self {
88 Self {
89 inner: self.inner.clone(),
90 }
91 }
92}
93
94impl<S> NormalizePath<S> {
95 #[inline]
99 pub fn new(inner: S) -> Self {
100 Self::trim_trailing_slash(inner)
101 }
102
103 pub fn trim_trailing_slash(inner: S) -> Self {
108 Self { inner }
109 }
110
111 define_inner_service_accessors!();
112}
113
114impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for NormalizePath<S>
115where
116 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
117 State: Clone + Send + Sync + 'static,
118 ReqBody: Send + 'static,
119 ResBody: Send + 'static,
120{
121 type Response = S::Response;
122 type Error = S::Error;
123
124 fn serve(
125 &self,
126 ctx: Context<State>,
127 mut req: Request<ReqBody>,
128 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
129 normalize_trailing_slash(req.uri_mut());
130 self.inner.serve(ctx, req)
131 }
132}
133
134fn normalize_trailing_slash(uri: &mut Uri) {
135 if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
136 return;
137 }
138
139 let new_path = format!("/{}", uri.path().trim_matches('/'));
140
141 let mut parts = uri.clone().into_parts();
142
143 let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
144 let new_path = if new_path.is_empty() {
145 "/"
146 } else {
147 new_path.as_str()
148 };
149
150 let new_path_and_query = if let Some(query) = path_and_query.query() {
151 Cow::Owned(format!("{}?{}", new_path, query))
152 } else {
153 new_path.into()
154 }
155 .parse()
156 .unwrap();
157
158 Some(new_path_and_query)
159 } else {
160 None
161 };
162
163 parts.path_and_query = new_path_and_query;
164 if let Ok(new_uri) = Uri::from_parts(parts) {
165 *uri = new_uri;
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use rama_core::Layer;
    use rama_core::service::service_fn;
    use std::convert::Infallible;

    /// Parses `input` as a [`Uri`], runs the normalization on it and returns
    /// the resulting URI.
    fn normalized(input: &str) -> Uri {
        let mut uri = input.parse::<Uri>().unwrap();
        normalize_trailing_slash(&mut uri);
        uri
    }

    /// End-to-end check: the inner service echoes the URI it received, which
    /// must already be normalized.
    #[tokio::test]
    async fn works() {
        async fn handle(request: Request<()>) -> Result<Response<String>, Infallible> {
            let seen_uri = request.uri().to_string();
            Ok(Response::new(seen_uri))
        }

        let svc = NormalizePathLayer::trim_trailing_slash().into_layer(service_fn(handle));
        let request = Request::builder().uri("/foo/").body(()).unwrap();

        let response = svc.serve(Context::default(), request).await.unwrap();
        let body = response.into_body();

        assert_eq!(body, "/foo");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_eq!(normalized("/foo"), "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_eq!(normalized("/foo/?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_eq!(normalized("/foo////"), "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_eq!(normalized("/foo////?a=a"), "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_eq!(normalized("/"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_eq!(normalized("////"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_eq!(normalized("////?a=a"), "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_eq!(normalized("///foo//?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_eq!(normalized("///foo"), "/foo");
    }
}