normalize_path_except/
lib.rs1use http::{Request, Response, Uri};
43use std::{
44 borrow::Cow,
45 task::{Context, Poll},
46};
47use tower_layer::Layer;
48use tower_service::Service;
49
// Generates the standard `get_ref` / `get_mut` / `into_inner` accessors for a wrapper type
// that stores its wrapped service in a field named `inner` and is generic over `S`.
// NOTE(review): `unused_macros` is presumably allowed so the macro can stay defined even if
// every invocation is removed or feature-gated — confirm before dropping the attribute.
#[allow(unused_macros)]
macro_rules! define_inner_service_accessors {
    () => {
        /// Consumes the wrapper, giving back the wrapped service.
        pub fn into_inner(self) -> S {
            self.inner
        }

        /// Borrows the wrapped service immutably.
        pub fn get_ref(&self) -> &S {
            &self.inner
        }

        /// Borrows the wrapped service mutably.
        pub fn get_mut(&mut self) -> &mut S {
            &mut self.inner
        }
    };
}
69
/// Layer that wraps services in [`NormalizePath`], trimming trailing slashes from request
/// paths unless the path starts with one of the configured exception prefixes.
#[derive(Debug, Clone)]
pub struct NormalizePathLayer {
    // Owned copies of the exception prefixes, handed to each wrapped service.
    exceptions: Vec<String>,
}

impl NormalizePathLayer {
    /// Creates a layer that trims trailing slashes, skipping any request whose path starts
    /// with one of `exceptions`.
    ///
    /// The exceptions are copied into owned `String`s in the order given.
    pub fn trim_trailing_slash<S: AsRef<str>>(exceptions: &[S]) -> Self {
        let mut owned = Vec::with_capacity(exceptions.len());
        for exception in exceptions {
            owned.push(String::from(exception.as_ref()));
        }
        Self { exceptions: owned }
    }
}
88
89impl<S> Layer<S> for NormalizePathLayer {
90 type Service = NormalizePath<S>;
91
92 fn layer(&self, inner: S) -> Self::Service {
93 NormalizePath::trim_trailing_slash(inner, &self.exceptions)
94 }
95}
96
97#[derive(Debug, Clone)]
101pub struct NormalizePath<S> {
102 exceptions: Vec<String>,
103 inner: S,
104}
105
106impl<S> NormalizePath<S> {
107 pub fn trim_trailing_slash<P: AsRef<str>>(inner: S, exceptions: &[P]) -> Self {
112 let exceptions = exceptions.iter().map(|x| x.as_ref().to_string()).collect();
113 Self { exceptions, inner }
114 }
115
116 define_inner_service_accessors!();
117}
118
119impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for NormalizePath<S>
120where
121 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
122{
123 type Response = S::Response;
124 type Error = S::Error;
125 type Future = S::Future;
126
127 #[inline]
128 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129 self.inner.poll_ready(cx)
130 }
131
132 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
133 let path = req.uri().path();
134 if !self.exceptions.iter().any(|x| path.starts_with(x)) {
135 normalize_trailing_slash(req.uri_mut());
136 }
137 self.inner.call(req)
138 }
139}
140
141fn normalize_trailing_slash(uri: &mut Uri) {
142 if !uri.path().ends_with('/') && !uri.path().starts_with("//") {
143 return;
144 }
145
146 let new_path = format!("/{}", uri.path().trim_matches('/'));
147
148 let mut parts = uri.clone().into_parts();
149
150 let new_path_and_query = if let Some(path_and_query) = &parts.path_and_query {
151 let new_path_and_query = if let Some(query) = path_and_query.query() {
152 Cow::Owned(format!("{}?{}", new_path, query))
153 } else {
154 new_path.into()
155 }
156 .parse()
157 .unwrap();
158
159 Some(new_path_and_query)
160 } else {
161 None
162 };
163
164 parts.path_and_query = new_path_and_query;
165 if let Ok(new_uri) = Uri::from_parts(parts) {
166 *uri = new_uri;
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use tower::{ServiceBuilder, ServiceExt};

    /// Inner service that echoes back the URI it received, so tests can observe normalization.
    async fn echo_uri(request: Request<()>) -> Result<Response<String>, Infallible> {
        Ok(Response::new(request.uri().to_string()))
    }

    /// Waits for `svc` to be ready, sends one request for `uri`, and returns the response body.
    async fn body_for<S>(svc: &mut S, uri: &str) -> String
    where
        S: Service<Request<()>, Response = Response<String>>,
        S::Error: std::fmt::Debug,
    {
        let request = Request::builder().uri(uri).body(()).unwrap();
        svc.ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_body()
    }

    /// Parses `input` as a URI and runs it through `normalize_trailing_slash`.
    fn normalized(input: &str) -> Uri {
        let mut uri: Uri = input.parse().unwrap();
        normalize_trailing_slash(&mut uri);
        uri
    }

    #[tokio::test]
    async fn works() {
        let mut svc = ServiceBuilder::new()
            .layer(NormalizePathLayer::trim_trailing_slash(&["/bar"]))
            .service_fn(echo_uri);

        // Paths outside the exception list lose their trailing slash.
        assert_eq!(body_for(&mut svc, "/foo/").await, "/foo");
        assert_eq!(body_for(&mut svc, "/foo/bar/").await, "/foo/bar");

        // Paths under the `/bar` exception pass through untouched.
        assert_eq!(body_for(&mut svc, "/bar/").await, "/bar/");
        assert_eq!(body_for(&mut svc, "/bar/baz/").await, "/bar/baz/");
    }

    #[test]
    fn is_noop_if_no_trailing_slash() {
        assert_eq!(normalized("/foo"), "/foo");
    }

    #[test]
    fn maintains_query() {
        assert_eq!(normalized("/foo/?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_trailing_slashes() {
        assert_eq!(normalized("/foo////"), "/foo");
    }

    #[test]
    fn removes_multiple_trailing_slashes_even_with_query() {
        assert_eq!(normalized("/foo////?a=a"), "/foo?a=a");
    }

    #[test]
    fn is_noop_on_index() {
        assert_eq!(normalized("/"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index() {
        assert_eq!(normalized("////"), "/");
    }

    #[test]
    fn removes_multiple_trailing_slashes_on_index_even_with_query() {
        assert_eq!(normalized("////?a=a"), "/?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes_even_with_query() {
        assert_eq!(normalized("///foo//?a=a"), "/foo?a=a");
    }

    #[test]
    fn removes_multiple_preceding_slashes() {
        assert_eq!(normalized("///foo"), "/foo");
    }
}