rama_http/matcher/path/
mod.rs

1use crate::{IntoResponse, Request, StatusCode};
2use rama_core::{context::Extensions, Context};
3use std::collections::HashMap;
4
5mod de;
6
/// Parameters that are inserted in the [`Context`] when the
/// [`PathMatcher`] found a match for the given [`Request`].
///
/// Named parameters (from `:name` fragments) and the glob capture
/// (from a trailing `*` fragment) are stored separately; both start out
/// as `None` and are only populated once something is captured.
#[derive(Debug, Clone, Default)]
pub struct UriParams {
    params: Option<HashMap<String, String>>,
    glob: Option<String>,
}
14
15impl UriParams {
16    fn insert(&mut self, name: String, value: String) {
17        self.params
18            .get_or_insert_with(HashMap::new)
19            .insert(name, value);
20    }
21
22    /// Some str slice will be returned in case a param could be found for the given name.
23    pub fn get(&self, name: impl AsRef<str>) -> Option<&str> {
24        self.params
25            .as_ref()
26            .and_then(|params| params.get(name.as_ref()))
27            .map(String::as_str)
28    }
29
30    fn append_glob(&mut self, value: &str) {
31        match self.glob {
32            Some(ref mut glob) => {
33                glob.push('/');
34                glob.push_str(value);
35            }
36            None => self.glob = Some(format!("/{}", value)),
37        }
38    }
39
40    /// Some str slice will be returned in case a glob value was captured
41    /// for the last part of the Path that was matched on.
42    pub fn glob(&self) -> Option<&str> {
43        self.glob.as_deref()
44    }
45
46    /// Deserialize the [`UriParams`] into a given type.
47    pub fn deserialize<T>(&self) -> Result<T, UriParamsDeserializeError>
48    where
49        T: serde::de::DeserializeOwned,
50    {
51        match self.params {
52            Some(ref params) => {
53                let params: Vec<_> = params
54                    .iter()
55                    .map(|(k, v)| (k.as_str(), v.as_str()))
56                    .collect();
57                let deserializer = de::PathDeserializer::new(&params);
58                T::deserialize(deserializer)
59            }
60            None => Err(de::PathDeserializationError::new(de::ErrorKind::NoParams)),
61        }
62        .map_err(UriParamsDeserializeError)
63    }
64}
65
66#[derive(Debug)]
67/// Error that can occur during the deserialization of the [`UriParams`].
68///
69/// See [`UriParams::deserialize`] for more information.
70pub struct UriParamsDeserializeError(de::PathDeserializationError);
71
72impl UriParamsDeserializeError {
73    /// Get the response body text used for this rejection.
74    pub fn body_text(&self) -> String {
75        use de::ErrorKind;
76        match self.0.kind {
77            ErrorKind::Message(_)
78            | ErrorKind::NoParams
79            | ErrorKind::ParseError { .. }
80            | ErrorKind::ParseErrorAtIndex { .. }
81            | ErrorKind::ParseErrorAtKey { .. } => format!("Invalid URL: {}", self.0.kind),
82            ErrorKind::WrongNumberOfParameters { .. } | ErrorKind::UnsupportedType { .. } => {
83                self.0.kind.to_string()
84            }
85        }
86    }
87
88    /// Get the status code used for this rejection.
89    pub fn status(&self) -> StatusCode {
90        use de::ErrorKind;
91        match self.0.kind {
92            ErrorKind::Message(_)
93            | ErrorKind::NoParams
94            | ErrorKind::ParseError { .. }
95            | ErrorKind::ParseErrorAtIndex { .. }
96            | ErrorKind::ParseErrorAtKey { .. } => StatusCode::BAD_REQUEST,
97            ErrorKind::WrongNumberOfParameters { .. } | ErrorKind::UnsupportedType { .. } => {
98                StatusCode::INTERNAL_SERVER_ERROR
99            }
100        }
101    }
102}
103
104impl std::fmt::Display for UriParamsDeserializeError {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        self.0.fmt(f)
107    }
108}
109
110impl std::error::Error for UriParamsDeserializeError {}
111
112impl IntoResponse for UriParamsDeserializeError {
113    fn into_response(self) -> crate::Response {
114        crate::utils::macros::log_http_rejection!(
115            rejection_type = UriParamsDeserializeError,
116            body_text = self.body_text(),
117            status = self.status(),
118        );
119        (self.status(), self.body_text()).into_response()
120    }
121}
122
/// One `/`-separated piece of a fragment-based [`PathMatcher`] pattern.
#[derive(Debug, Clone)]
enum PathFragment {
    /// A segment that has to be present as-is.
    Literal(String),
    /// A `:name` segment; captures one segment under the stored name.
    Param(String),
    /// A trailing `*` segment; captures the remaining segments.
    Glob,
}

/// Strategy used by a [`PathMatcher`]: compare the whole path against a
/// single literal, or walk it segment by segment against a fragment list.
#[derive(Debug, Clone)]
enum PathMatcherKind {
    Literal(String),
    FragmentList(Vec<PathFragment>),
}

/// Matcher based on the URI path.
#[derive(Debug, Clone)]
pub struct PathMatcher {
    kind: PathMatcherKind,
}
141
142impl PathMatcher {
143    /// Create a new [`PathMatcher`] for the given path.
144    pub fn new(path: impl AsRef<str>) -> Self {
145        let path = path.as_ref();
146        let path = path.trim().trim_matches('/');
147
148        if !path.contains([':', '*']) {
149            return Self {
150                kind: PathMatcherKind::Literal(path.to_lowercase()),
151            };
152        }
153
154        let path_parts: Vec<_> = path.split('/').filter(|s| !s.is_empty()).collect();
155        let fragment_length = path_parts.len();
156        if fragment_length == 1 && path_parts[0].is_empty() {
157            return Self {
158                kind: PathMatcherKind::FragmentList(vec![PathFragment::Glob]),
159            };
160        }
161
162        let fragments: Vec<PathFragment> = path_parts
163            .into_iter()
164            .enumerate()
165            .filter_map(|(index, s)| {
166                if s.is_empty() {
167                    return None;
168                }
169                if s.starts_with(':') {
170                    Some(PathFragment::Param(
171                        s.trim_start_matches(':').to_lowercase(),
172                    ))
173                } else if s == "*" && index == fragment_length - 1 {
174                    Some(PathFragment::Glob)
175                } else {
176                    Some(PathFragment::Literal(s.to_lowercase()))
177                }
178            })
179            .collect();
180
181        Self {
182            kind: PathMatcherKind::FragmentList(fragments),
183        }
184    }
185
186    pub(crate) fn matches_path(&self, path: &str) -> Option<UriParams> {
187        let path = path.trim().trim_matches('/');
188        match &self.kind {
189            PathMatcherKind::Literal(literal) => {
190                if literal.eq_ignore_ascii_case(path) {
191                    Some(UriParams::default())
192                } else {
193                    None
194                }
195            }
196            PathMatcherKind::FragmentList(fragments) => {
197                let fragments_iter = fragments.iter().map(Some).chain(std::iter::repeat(None));
198                let mut params = UriParams::default();
199                for (segment, fragment) in path
200                    .split('/')
201                    .map(Some)
202                    .chain(std::iter::repeat(None))
203                    .zip(fragments_iter)
204                {
205                    match (segment, fragment) {
206                        (Some(segment), Some(fragment)) => match fragment {
207                            PathFragment::Literal(literal) => {
208                                if !literal.eq_ignore_ascii_case(segment) {
209                                    return None;
210                                }
211                            }
212                            PathFragment::Param(name) => {
213                                if segment.is_empty() {
214                                    return None;
215                                }
216                                let segment = percent_encoding::percent_decode(segment.as_bytes())
217                                    .decode_utf8()
218                                    .map(|s| s.to_string())
219                                    .unwrap_or_else(|_| segment.to_owned());
220                                params.insert(name.to_owned(), segment);
221                            }
222                            PathFragment::Glob => {
223                                params.append_glob(segment);
224                            }
225                        },
226                        (None, None) => {
227                            break;
228                        }
229                        (Some(segment), None) => {
230                            params.glob()?;
231                            params.append_glob(segment);
232                        }
233                        _ => {
234                            return None;
235                        }
236                    }
237                }
238
239                Some(params)
240            }
241        }
242    }
243}
244
245impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for PathMatcher {
246    fn matches(
247        &self,
248        ext: Option<&mut Extensions>,
249        _ctx: &Context<State>,
250        req: &Request<Body>,
251    ) -> bool {
252        match self.matches_path(req.uri().path()) {
253            None => false,
254            Some(params) => {
255                if let Some(ext) = ext {
256                    ext.insert(params);
257                }
258                true
259            }
260        }
261    }
262}
263
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_path_matcher_match_path() {
        /// A request `path`, the `matcher_path` it is matched against,
        /// and the expected outcome (`None` = no match).
        struct TestCase {
            path: &'static str,
            matcher_path: &'static str,
            result: Option<UriParams>,
        }

        impl TestCase {
            fn some(path: &'static str, matcher_path: &'static str, result: UriParams) -> Self {
                Self {
                    path,
                    matcher_path,
                    result: Some(result),
                }
            }

            fn none(path: &'static str, matcher_path: &'static str) -> Self {
                Self {
                    path,
                    matcher_path,
                    result: None,
                }
            }
        }

        let test_cases = vec![
            // root and empty paths, with and without slashes
            TestCase::some("/", "/", UriParams::default()),
            TestCase::some("", "/", UriParams::default()),
            TestCase::some("/", "", UriParams::default()),
            TestCase::some("", "", UriParams::default()),
            // plain literals and `*` that is not a trailing glob
            TestCase::some("/foo", "/foo", UriParams::default()),
            TestCase::some("/foo", "//foo//", UriParams::default()),
            TestCase::some("/*foo", "/*foo", UriParams::default()),
            TestCase::some("/foo/*bar/baz", "/foo/*bar/baz", UriParams::default()),
            TestCase::none("/foo/*bar/baz", "/foo/*bar"),
            TestCase::none("/", "/:foo"),
            TestCase::some(
                "/",
                "/*",
                UriParams {
                    glob: Some("/".to_owned()),
                    ..UriParams::default()
                },
            ),
            TestCase::none("/", "//:foo"),
            TestCase::none("", "/:foo"),
            TestCase::none("/foo", "/bar"),
            // params are percent-decoded
            TestCase::some(
                "/person/glen%20dc/age",
                "/person/:name/age",
                UriParams {
                    params: Some({
                        let mut params = HashMap::new();
                        params.insert("name".to_owned(), "glen dc".to_owned());
                        params
                    }),
                    ..UriParams::default()
                },
            ),
            // leading/trailing slashes are irrelevant
            TestCase::some("/foo", "foo", UriParams::default()),
            TestCase::some("/foo/bar/", "foo/bar", UriParams::default()),
            TestCase::none("/foo/bar/", "foo/baz"),
            TestCase::some("/foo/bar/", "/foo/bar", UriParams::default()),
            TestCase::some("/foo/bar", "/foo/bar", UriParams::default()),
            TestCase::some("/foo/bar", "foo/bar", UriParams::default()),
            // literals compare ASCII case-insensitively
            TestCase::some("/FOO/Bar", "/foo/bar", UriParams::default()),
            // literal fragments ignore case, param values keep theirs
            TestCase::some("/Person/GLEN/Age", "/person/:name/age", {
                let mut params = UriParams::default();
                params.insert("name".to_owned(), "GLEN".to_owned());
                params
            }),
            // param names are lowercased
            TestCase::some("/user/42", "/user/:userId", {
                let mut params = UriParams::default();
                params.insert("userid".to_owned(), "42".to_owned());
                params
            }),
            TestCase::some("/book/oxford-dictionary/author", "/book/:title/author", {
                let mut params = UriParams::default();
                params.insert("title".to_owned(), "oxford-dictionary".to_owned());
                params
            }),
            TestCase::some(
                "/book/oxford-dictionary/author/0",
                "/book/:title/author/:index",
                {
                    let mut params = UriParams::default();
                    params.insert("title".to_owned(), "oxford-dictionary".to_owned());
                    params.insert("index".to_owned(), "0".to_owned());
                    params
                },
            ),
            // too few / too many segments
            TestCase::none("/book/oxford-dictionary", "/book/:title/author"),
            TestCase::none(
                "/book/oxford-dictionary/author/birthdate",
                "/book/:title/author",
            ),
            TestCase::none("oxford-dictionary/author", "/book/:title/author"),
            TestCase::none("/foo", "/"),
            TestCase::none("/foo", "/*f"),
            // trailing globs
            TestCase::some(
                "/foo",
                "/*",
                UriParams {
                    glob: Some("/foo".to_owned()),
                    ..UriParams::default()
                },
            ),
            TestCase::some(
                "/assets/css/reset.css",
                "/assets/*",
                UriParams {
                    glob: Some("/css/reset.css".to_owned()),
                    ..UriParams::default()
                },
            ),
            TestCase::some("/assets/eu/css/reset.css", "/assets/:local/*", {
                let mut params = UriParams::default();
                params.insert("local".to_owned(), "eu".to_owned());
                params.glob = Some("/css/reset.css".to_owned());
                params
            }),
            TestCase::some("/assets/eu/css/reset.css", "/assets/:local/css/*", {
                let mut params = UriParams::default();
                params.insert("local".to_owned(), "eu".to_owned());
                params.glob = Some("/reset.css".to_owned());
                params
            }),
        ];
        for test_case in test_cases {
            let matcher = PathMatcher::new(test_case.matcher_path);
            let result = matcher.matches_path(test_case.path);
            match (result.as_ref(), test_case.result.as_ref()) {
                (None, None) => (),
                (Some(result), Some(expected_result)) => {
                    assert_eq!(
                        result.params,
                        expected_result.params,
                        "unexpected result params: ({}).matcher({}) => {:?} != {:?}",
                        test_case.matcher_path,
                        test_case.path,
                        result.params,
                        expected_result.params,
                    );
                    assert_eq!(
                        result.glob, expected_result.glob,
                        "unexpected result glob: ({}).matcher({}) => {:?} != {:?}",
                        test_case.matcher_path, test_case.path, result.glob, expected_result.glob,
                    );
                }
                _ => {
                    panic!(
                        "unexpected result: ({}).matcher({}) => {:?} != {:?}",
                        test_case.matcher_path, test_case.path, result, test_case.result
                    )
                }
            }
        }
    }

    #[test]
    fn test_deserialize_uri_params() {
        let params = UriParams {
            params: Some({
                let mut params = HashMap::new();
                params.insert("name".to_owned(), "glen dc".to_owned());
                params.insert("age".to_owned(), "42".to_owned());
                params
            }),
            glob: Some("/age".to_owned()),
        };

        #[derive(serde::Deserialize)]
        struct Person {
            name: String,
            age: u8,
        }

        let person: Person = params.deserialize().unwrap();
        assert_eq!(person.name, "glen dc");
        assert_eq!(person.age, 42);
    }

    #[test]
    fn test_deserialize_uri_params_without_params() {
        // No captured params at all is reported as a client error.
        let err = UriParams::default()
            .deserialize::<HashMap<String, String>>()
            .unwrap_err();
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
    }
}