rama_http/matcher/path/
mod.rs

1use crate::{IntoResponse, Request, StatusCode};
2use rama_core::{Context, context::Extensions};
3use std::collections::HashMap;
4
5mod de;
6
/// Parameters that are inserted in the [`Context`],
/// in case the [`PathMatcher`] found a match for the given [`Request`].
///
/// Both fields start out as `None` (see the derived [`Default`]) and are
/// only populated once a named parameter or a glob segment is captured.
#[derive(Clone, Debug, Default)]
pub struct UriParams {
    /// Named parameters captured from the path, keyed by parameter name.
    params: Option<HashMap<String, String>>,
    /// Text captured by a glob fragment; segments are joined with `/`
    /// and the value always starts with a leading `/`.
    glob: Option<String>,
}
14
15impl UriParams {
16    fn insert(&mut self, name: String, value: String) {
17        self.params
18            .get_or_insert_with(HashMap::new)
19            .insert(name, value);
20    }
21
22    /// Some str slice will be returned in case a param could be found for the given name.
23    pub fn get(&self, name: impl AsRef<str>) -> Option<&str> {
24        self.params
25            .as_ref()
26            .and_then(|params| params.get(name.as_ref()))
27            .map(String::as_str)
28    }
29
30    fn append_glob(&mut self, value: &str) {
31        match self.glob {
32            Some(ref mut glob) => {
33                glob.push('/');
34                glob.push_str(value);
35            }
36            None => self.glob = Some(format!("/{}", value)),
37        }
38    }
39
40    /// Some str slice will be returned in case a glob value was captured
41    /// for the last part of the Path that was matched on.
42    pub fn glob(&self) -> Option<&str> {
43        self.glob.as_deref()
44    }
45
46    /// Deserialize the [`UriParams`] into a given type.
47    pub fn deserialize<T>(&self) -> Result<T, UriParamsDeserializeError>
48    where
49        T: serde::de::DeserializeOwned,
50    {
51        match self.params {
52            Some(ref params) => {
53                let params: Vec<_> = params
54                    .iter()
55                    .map(|(k, v)| (k.as_str(), v.as_str()))
56                    .collect();
57                let deserializer = de::PathDeserializer::new(&params);
58                T::deserialize(deserializer)
59            }
60            None => Err(de::PathDeserializationError::new(de::ErrorKind::NoParams)),
61        }
62        .map_err(UriParamsDeserializeError)
63    }
64
65    /// Extend the [`UriParams`] with the given iterator.
66    pub fn extend<I, K, V>(&mut self, iter: I) -> &mut Self
67    where
68        I: IntoIterator<Item = (K, V)>,
69        K: Into<String>,
70        V: Into<String>,
71    {
72        let params = self.params.get_or_insert_with(HashMap::new);
73        for (k, v) in iter {
74            params.insert(k.into(), v.into());
75        }
76        self
77    }
78
79    pub fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
80        self.params
81            .as_ref()
82            .map(|params| params.iter().map(|(k, v)| (k.as_str(), v.as_str())))
83            .into_iter()
84            .flatten()
85    }
86}
87
88impl<'a> FromIterator<(&'a str, &'a str)> for UriParams {
89    fn from_iter<T: IntoIterator<Item = (&'a str, &'a str)>>(iter: T) -> Self {
90        let mut params = UriParams::default();
91        for (k, v) in iter {
92            params.insert(k.to_owned(), v.to_owned());
93        }
94        params
95    }
96}
97
98#[derive(Debug)]
99/// Error that can occur during the deserialization of the [`UriParams`].
100///
101/// See [`UriParams::deserialize`] for more information.
102pub struct UriParamsDeserializeError(de::PathDeserializationError);
103
104impl UriParamsDeserializeError {
105    /// Get the response body text used for this rejection.
106    pub fn body_text(&self) -> String {
107        use de::ErrorKind;
108        match self.0.kind {
109            ErrorKind::Message(_)
110            | ErrorKind::NoParams
111            | ErrorKind::ParseError { .. }
112            | ErrorKind::ParseErrorAtIndex { .. }
113            | ErrorKind::ParseErrorAtKey { .. } => format!("Invalid URL: {}", self.0.kind),
114            ErrorKind::WrongNumberOfParameters { .. } | ErrorKind::UnsupportedType { .. } => {
115                self.0.kind.to_string()
116            }
117        }
118    }
119
120    /// Get the status code used for this rejection.
121    pub fn status(&self) -> StatusCode {
122        use de::ErrorKind;
123        match self.0.kind {
124            ErrorKind::Message(_)
125            | ErrorKind::NoParams
126            | ErrorKind::ParseError { .. }
127            | ErrorKind::ParseErrorAtIndex { .. }
128            | ErrorKind::ParseErrorAtKey { .. } => StatusCode::BAD_REQUEST,
129            ErrorKind::WrongNumberOfParameters { .. } | ErrorKind::UnsupportedType { .. } => {
130                StatusCode::INTERNAL_SERVER_ERROR
131            }
132        }
133    }
134}
135
136impl std::fmt::Display for UriParamsDeserializeError {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        self.0.fmt(f)
139    }
140}
141
142impl std::error::Error for UriParamsDeserializeError {}
143
144impl IntoResponse for UriParamsDeserializeError {
145    fn into_response(self) -> crate::Response {
146        crate::utils::macros::log_http_rejection!(
147            rejection_type = UriParamsDeserializeError,
148            body_text = self.body_text(),
149            status = self.status(),
150        );
151        (self.status(), self.body_text()).into_response()
152    }
153}
154
/// A single `/`-separated piece of a matcher pattern.
#[derive(Clone, Debug)]
enum PathFragment {
    /// Fixed text the path segment has to equal (ASCII case-insensitive).
    Literal(String),
    /// Named capture; the payload is the parameter name.
    Param(String),
    /// Trailing wildcard that captures the rest of the path.
    Glob,
}
161
162#[derive(Debug, Clone)]
163enum PathMatcherKind {
164    Literal(String),
165    FragmentList(Vec<PathFragment>),
166}
167
168#[derive(Debug, Clone)]
169/// Matcher based on the URI path.
170pub struct PathMatcher {
171    kind: PathMatcherKind,
172}
173
174impl PathMatcher {
175    /// Create a new [`PathMatcher`] for the given path.
176    pub fn new(path: impl AsRef<str>) -> Self {
177        let path = path.as_ref();
178        let path = path.trim().trim_matches('/');
179
180        if !path.contains([':', '*', '{', '}']) {
181            return Self {
182                kind: PathMatcherKind::Literal(path.to_lowercase()),
183            };
184        }
185
186        let path_parts: Vec<_> = path.split('/').filter(|s| !s.is_empty()).collect();
187        let fragment_length = path_parts.len();
188        if fragment_length == 1 && path_parts[0].is_empty() {
189            return Self {
190                kind: PathMatcherKind::FragmentList(vec![PathFragment::Glob]),
191            };
192        }
193
194        let fragments: Vec<PathFragment> = path_parts
195            .into_iter()
196            .enumerate()
197            .filter_map(|(index, s)| {
198                if s.is_empty() {
199                    return None;
200                }
201                if s.starts_with(':') {
202                    Some(PathFragment::Param(
203                        s.trim_start_matches(':').to_lowercase(),
204                    ))
205                } else if s.starts_with('{') && s.ends_with('}') && s.len() > 2 {
206                    let param_name = s[1..s.len() - 1].to_lowercase();
207                    Some(PathFragment::Param(param_name))
208                } else if s == "*" && index == fragment_length - 1 {
209                    Some(PathFragment::Glob)
210                } else {
211                    Some(PathFragment::Literal(s.to_lowercase()))
212                }
213            })
214            .collect();
215
216        Self {
217            kind: PathMatcherKind::FragmentList(fragments),
218        }
219    }
220
221    pub(crate) fn matches_path(&self, path: &str) -> Option<UriParams> {
222        let path = path.trim().trim_matches('/');
223        match &self.kind {
224            PathMatcherKind::Literal(literal) => {
225                if literal.eq_ignore_ascii_case(path) {
226                    Some(UriParams::default())
227                } else {
228                    None
229                }
230            }
231            PathMatcherKind::FragmentList(fragments) => {
232                let fragments_iter = fragments.iter().map(Some).chain(std::iter::repeat(None));
233                let mut params = UriParams::default();
234                for (segment, fragment) in path
235                    .split('/')
236                    .map(Some)
237                    .chain(std::iter::repeat(None))
238                    .zip(fragments_iter)
239                {
240                    match (segment, fragment) {
241                        (Some(segment), Some(fragment)) => match fragment {
242                            PathFragment::Literal(literal) => {
243                                if !literal.eq_ignore_ascii_case(segment) {
244                                    return None;
245                                }
246                            }
247                            PathFragment::Param(name) => {
248                                if segment.is_empty() {
249                                    return None;
250                                }
251                                let segment = percent_encoding::percent_decode(segment.as_bytes())
252                                    .decode_utf8()
253                                    .map(|s| s.to_string())
254                                    .unwrap_or_else(|_| segment.to_owned());
255                                params.insert(name.to_owned(), segment);
256                            }
257                            PathFragment::Glob => {
258                                params.append_glob(segment);
259                            }
260                        },
261                        (None, None) => {
262                            break;
263                        }
264                        (Some(segment), None) => {
265                            params.glob()?;
266                            params.append_glob(segment);
267                        }
268                        _ => {
269                            return None;
270                        }
271                    }
272                }
273
274                Some(params)
275            }
276        }
277    }
278}
279
280impl<State, Body> rama_core::matcher::Matcher<State, Request<Body>> for PathMatcher {
281    fn matches(
282        &self,
283        ext: Option<&mut Extensions>,
284        _ctx: &Context<State>,
285        req: &Request<Body>,
286    ) -> bool {
287        match self.matches_path(req.uri().path()) {
288            None => false,
289            Some(params) => {
290                if let Some(ext) = ext {
291                    ext.insert(params);
292                }
293                true
294            }
295        }
296    }
297}
298
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a [`UriParams`] from named pairs and an optional glob value.
    ///
    /// With no pairs the parameter map stays `None`, exactly like
    /// `UriParams::default()`.
    fn uri_params(pairs: &[(&str, &str)], glob: Option<&str>) -> UriParams {
        let mut built = UriParams::default();
        for &(name, value) in pairs {
            built.insert(name.to_owned(), value.to_owned());
        }
        built.glob = glob.map(str::to_owned);
        built
    }

    #[test]
    fn test_path_matcher_match_path() {
        // (request path, matcher pattern, expected outcome)
        type Case = (&'static str, &'static str, Option<UriParams>);

        // A match that captures nothing.
        let plain = || Some(UriParams::default());
        let title = [("title", "oxford-dictionary")];
        let local = [("local", "eu")];

        let cases: Vec<Case> = vec![
            ("/", "/", plain()),
            ("", "/", plain()),
            ("/", "", plain()),
            ("", "", plain()),
            ("/foo", "/foo", plain()),
            ("/foo", "//foo//", plain()),
            ("/*foo", "/*foo", plain()),
            ("/foo/*bar/baz", "/foo/*bar/baz", plain()),
            ("/foo/*bar/baz", "/foo/*bar", None),
            ("/", "/:foo", None),
            ("/", "/*", Some(uri_params(&[], Some("/")))),
            ("/", "//:foo", None),
            ("", "/:foo", None),
            ("/foo", "/bar", None),
            (
                "/person/glen%20dc/age",
                "/person/:name/age",
                Some(uri_params(&[("name", "glen dc")], None)),
            ),
            ("/foo", "/bar", None),
            ("/foo", "foo", plain()),
            ("/foo/bar/", "foo/bar", plain()),
            ("/foo/bar/", "foo/baz", None),
            ("/foo/bar/", "/foo/bar", plain()),
            ("/foo/bar", "/foo/bar", plain()),
            ("/foo/bar", "foo/bar", plain()),
            (
                "/book/oxford-dictionary/author",
                "/book/:title/author",
                Some(uri_params(&title, None)),
            ),
            (
                "/book/oxford-dictionary/author",
                "/book/{title}/author",
                Some(uri_params(&title, None)),
            ),
            (
                "/book/oxford-dictionary/author/0",
                "/book/:title/author/:index",
                Some(uri_params(
                    &[("title", "oxford-dictionary"), ("index", "0")],
                    None,
                )),
            ),
            (
                "/book/oxford-dictionary/author/1",
                "/book/{title}/author/{index}",
                Some(uri_params(
                    &[("title", "oxford-dictionary"), ("index", "1")],
                    None,
                )),
            ),
            ("/book/oxford-dictionary", "/book/:title/author", None),
            (
                "/book/oxford-dictionary/author/birthdate",
                "/book/:title/author",
                None,
            ),
            ("oxford-dictionary/author", "/book/:title/author", None),
            ("/foo", "/", None),
            ("/foo", "/*f", None),
            ("/foo", "/*", Some(uri_params(&[], Some("/foo")))),
            (
                "/assets/css/reset.css",
                "/assets/*",
                Some(uri_params(&[], Some("/css/reset.css"))),
            ),
            (
                "/assets/eu/css/reset.css",
                "/assets/:local/*",
                Some(uri_params(&local, Some("/css/reset.css"))),
            ),
            (
                "/assets/eu/css/reset.css",
                "/assets/:local/css/*",
                Some(uri_params(&local, Some("/reset.css"))),
            ),
        ];

        for (path, matcher_path, expected) in cases {
            let result = PathMatcher::new(matcher_path).matches_path(path);
            match (result.as_ref(), expected.as_ref()) {
                (None, None) => (),
                (Some(actual), Some(wanted)) => {
                    assert_eq!(
                        actual.params, wanted.params,
                        "unexpected result params: ({}).matcher({}) => {:?} != {:?}",
                        matcher_path, path, actual.params, wanted.params,
                    );
                    assert_eq!(
                        actual.glob, wanted.glob,
                        "unexpected result glob: ({}).matcher({}) => {:?} != {:?}",
                        matcher_path, path, actual.glob, wanted.glob,
                    );
                }
                _ => {
                    panic!(
                        "unexpected result: ({}).matcher({}) => {:?} != {:?}",
                        matcher_path, path, result, expected
                    )
                }
            }
        }
    }

    #[test]
    fn test_deserialize_uri_params() {
        #[derive(serde::Deserialize)]
        struct Person {
            name: String,
            age: u8,
        }

        // The glob value is present but takes no part in deserialization.
        let params = uri_params(&[("name", "glen dc"), ("age", "42")], Some("/age"));

        let person: Person = params.deserialize().unwrap();
        assert_eq!(person.name, "glen dc");
        assert_eq!(person.age, 42);
    }
}