1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
use http::{uri::InvalidUri, Uri};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use urlpattern::{
    UrlPattern, UrlPatternInit, UrlPatternMatchInput, UrlPatternOptions, UrlPatternResult,
};

/// Errors produced while building a [`Pattern`] from text or while matching
/// with one.
#[derive(Debug, thiserror::Error)]
pub enum PatternError {
    /// The input text could not be parsed as a URI (raised when a `Pattern`
    /// is created from a string).
    #[error("invalid url: {0}")]
    InvalidUrl(#[from] InvalidUri),

    /// The `urlpattern` crate reported an error. In this file it is surfaced
    /// by `Pattern::test` when the compiled matcher fails; compilation errors
    /// themselves are discarded by `Pattern::url_pattern`.
    #[error("invalid url pattern: {0}")]
    InvalidPattern(#[from] urlpattern::Error),
}

/// A wrapper for a URL pattern, which can be used to match URLs and extract
/// data from them.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pattern {
    /// The text the pattern was created from, kept verbatim.
    raw: String,
    /// The pattern text parsed as a URI; (de)serialized via
    /// `http_serde_ext::uri`. When built from a string, bare paths without a
    /// leading `/` or a `://` scheme separator get a `/` prepended first.
    #[serde(with = "http_serde_ext::uri")]
    uri: Uri,
}

impl Pattern {
    /// Compiles this pattern into a [`UrlPattern`] built from its path and
    /// query.
    ///
    /// Returns `None` when `urlpattern` rejects the pattern; the underlying
    /// error is discarded. The pattern is recompiled on every call.
    fn url_pattern(&self) -> Option<UrlPattern> {
        <UrlPattern>::parse(self.pattern_init(), UrlPatternOptions::default()).ok()
    }

    /// Checks whether this pattern matches `other`.
    ///
    /// Only the path and query of each side are compared, so the scheme and
    /// authority of `other` do not affect the result.
    ///
    /// Returns `Ok(false)` when `other` cannot be converted into a [`Pattern`],
    /// or when this pattern cannot be compiled by `urlpattern`.
    ///
    /// # Errors
    ///
    /// Returns [`PatternError::InvalidPattern`] if the compiled matcher reports
    /// an error while testing `other`.
    pub fn test(&self, other: impl TryInto<Pattern>) -> Result<bool, PatternError> {
        let other = match other.try_into() {
            Ok(other) => other,
            Err(_) => return Ok(false),
        };

        match self.url_pattern() {
            Some(compiled) => Ok(compiled.test(other.pattern_input())?),
            None => Ok(false),
        }
    }

    /// Matches `other` against this pattern and returns the captured data.
    ///
    /// Returns `None` when `other` cannot be converted into a [`Pattern`], when
    /// this pattern cannot be compiled, when the matcher reports an error, or
    /// when `other` simply does not match.
    pub fn captures(&self, other: impl TryInto<Pattern>) -> Option<UrlPatternResult> {
        let other: Pattern = other.try_into().ok()?;

        // Matcher errors are deliberately folded into "no captures".
        self.url_pattern()?
            .exec(other.pattern_input())
            .ok()
            .flatten()
    }

    /// Returns the value of the named pathname group `key` captured when
    /// matching `other` against this pattern.
    ///
    /// For example, the pattern `/profile/:id` matched against
    /// `scheme://host/profile/123` yields `Some("123")` for `key == "id"`.
    /// Returns `None` if there is no match, the group does not exist, or the
    /// group matched nothing.
    pub fn get_from_pathname(&self, other: impl TryInto<Pattern>, key: &str) -> Option<String> {
        // `captures` performs the conversion of `other`; doing it here too
        // would only convert twice.
        self.captures(other)?.pathname.groups.get(key)?.clone()
    }

    /// Builds the `urlpattern` init used both to compile this pattern and to
    /// match another pattern against it.
    ///
    /// Only the URI's path and query are copied; every other component is left
    /// at its `Default` value, so scheme and authority never take part in
    /// matching.
    fn pattern_init(&self) -> UrlPatternInit {
        UrlPatternInit {
            pathname: Some(self.uri.path().to_owned()),
            search: self.uri.query().map(Into::into),
            ..Default::default()
        }
    }

    /// Wraps [`Self::pattern_init`] as match input for `urlpattern`.
    fn pattern_input(&self) -> UrlPatternMatchInput {
        UrlPatternMatchInput::Init(self.pattern_init())
    }

    /// Creates a pattern from `input`, keeping the original text alongside
    /// its parsed URI.
    ///
    /// # Errors
    ///
    /// Returns [`PatternError::InvalidUrl`] if `input` cannot be parsed as a
    /// URI.
    fn try_from_str(input: &str) -> Result<Self, PatternError> {
        Ok(Pattern {
            raw: input.into(),
            uri: get_uri(input)?,
        })
    }
}

fn get_uri(input: &str) -> Result<Uri, PatternError> {
    let mut uri = input.to_owned();

    if !uri.starts_with('/') && !uri.contains("://") {
        uri.insert(0, '/')
    }

    Ok(uri.parse()?)
}

impl FromStr for Pattern {
    type Err = PatternError;

    fn from_str(input: &str) -> Result<Self, Self::Err> {
        Self::try_from_str(input)
    }
}

impl TryFrom<&str> for Pattern {
    type Error = PatternError;

    fn try_from(input: &str) -> Result<Self, Self::Error> {
        input.parse()
    }
}

#[test]
fn accepts_urls() {
    // A full URL with a custom scheme must be accepted as a pattern.
    let parsed = "frame://accounts/profile/:id".parse::<Pattern>();
    assert!(parsed.is_ok());
}

#[test]
fn describes_capture_results() {
    use std::collections::HashMap;
    use urlpattern::UrlPatternComponentResult;

    // The authority ("accounts") is not part of the path, so only
    // "/profile/123" is reported as the matched pathname input.
    let expected_pathname = UrlPatternComponentResult {
        input: "/profile/123".to_string(),
        groups: HashMap::from([("id".into(), Some("123".into()))]),
    };

    let pattern = "frame://accounts/profile/:id".parse::<Pattern>().unwrap();
    let result = pattern.captures("frame://accounts/profile/123").unwrap();

    assert_eq!(result.pathname, expected_pathname);
}

/// A path-only pattern matches URLs with any scheme and host, and its named
/// path segments can be read back out of the match.
#[test]
fn partials() {
    const URL: &str = "anyprotocol://localhost/profile/123?query=stuff#andfragments";

    let partial: Pattern = "/profile/:id".parse().unwrap();

    // `get_from_pathname` already returns an owned `String`; no clone needed.
    let id = match partial.get_from_pathname(URL, "id") {
        Some(id) => id,
        None => panic!("Expected an ID"),
    };

    assert!(partial.test(URL).unwrap());
    assert_eq!(id, "123");
}