1use http::{uri::InvalidUri, Uri};
2use serde::{Deserialize, Serialize};
3use std::str::FromStr;
4use urlpattern::{
5 UrlPattern, UrlPatternInit, UrlPatternMatchInput, UrlPatternOptions, UrlPatternResult,
6};
7
8#[derive(Debug, thiserror::Error)]
10pub enum PatternError {
11 #[error("invalid url: {0}")]
13 InvalidUrl(#[from] InvalidUri),
14
15 #[error("invalid url pattern: {0}")]
17 InvalidPattern(#[from] urlpattern::Error),
18}
19
20#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
24pub struct Pattern {
25 raw: String,
26 #[serde(with = "http_serde_ext::uri")]
27 uri: Uri,
28}
29
30impl Pattern {
31 fn url_pattern(&self) -> Option<UrlPattern> {
32 <UrlPattern>::parse(self.pattern_init(), UrlPatternOptions::default()).ok()
33 }
34
35 pub fn test(&self, other: impl TryInto<Pattern>) -> Result<bool, PatternError> {
39 let other = match other.try_into() {
40 Ok(other) => other,
41 Err(_) => return Ok(false),
42 };
43
44 let matching_pattern = self.url_pattern();
45
46 match matching_pattern {
47 None => Ok(false),
48 Some(matching_pattern) => Ok(matching_pattern.test(other.pattern_input())?),
49 }
50 }
51
52 pub fn captures(&self, other: impl TryInto<Pattern>) -> Option<UrlPatternResult> {
56 let other = match other.try_into() {
57 Ok(other) => other,
58 Err(_) => return None,
59 };
60
61 let matching_pattern = self.url_pattern();
62
63 if let Some(matching_pattern) = matching_pattern {
64 return matching_pattern
65 .exec(other.pattern_input())
66 .unwrap_or_default();
67 }
68
69 None
70 }
71
72 pub fn get_from_pathname(&self, other: impl TryInto<Pattern>, key: &str) -> Option<String> {
75 if let Ok(other) = other.try_into() {
76 let captures = self.captures(other)?;
77
78 match captures.pathname.groups.get(key) {
79 Some(Some(value)) => Some(value.clone()),
80 _ => None,
81 }
82 } else {
83 None
84 }
85 }
86
87 fn pattern_init(&self) -> UrlPatternInit {
88 UrlPatternInit {
89 pathname: Some(self.uri.path().to_owned()),
90 search: self.uri.query().map(Into::into),
91 ..Default::default()
92 }
93 }
94
95 fn pattern_input(&self) -> UrlPatternMatchInput {
96 UrlPatternMatchInput::Init(self.pattern_init())
97 }
98
99 fn try_from_str(input: &str) -> Result<Self, PatternError> {
100 Ok(Pattern {
101 raw: input.into(),
102 uri: get_uri(input)?,
103 })
104 }
105}
106
107fn get_uri(input: &str) -> Result<Uri, PatternError> {
108 let mut uri = input.to_owned();
109
110 if !uri.starts_with('/') && !uri.contains("://") {
111 uri.insert(0, '/')
112 }
113
114 Ok(uri.parse()?)
115}
116
117impl FromStr for Pattern {
118 type Err = PatternError;
119
120 fn from_str(input: &str) -> Result<Self, Self::Err> {
121 Self::try_from_str(input)
122 }
123}
124
125impl TryFrom<&str> for Pattern {
126 type Error = PatternError;
127
128 fn try_from(input: &str) -> Result<Self, Self::Error> {
129 input.parse()
130 }
131}
132
#[test]
fn accepts_urls() {
    // A scheme-qualified pattern with a named path segment must parse.
    let parsed = "frame://accounts/profile/:id".parse::<Pattern>();
    assert!(parsed.is_ok());
}
139
#[test]
fn describes_capture_results() {
    use std::collections::HashMap;
    use urlpattern::UrlPatternComponentResult;

    let pattern = "frame://accounts/profile/:id".parse::<Pattern>().unwrap();

    // The host (`accounts`) is not part of the pathname; only `/profile/123`
    // is matched and `id` is captured from it.
    let expected = UrlPatternComponentResult {
        input: String::from("/profile/123"),
        groups: HashMap::from([(String::from("id"), Some(String::from("123")))]),
    };

    let result = pattern.captures("frame://accounts/profile/123").unwrap();
    assert_eq!(result.pathname, expected);
}
155
#[test]
fn partials() {
    // A path-only pattern should match a URL with any scheme and host, even
    // when that URL also carries a query and a fragment.
    const URL: &str = "anyprotocol://localhost/profile/123?query=stuff#andfragments";

    let partial: Pattern = "/profile/:id".parse().unwrap();

    // `get_from_pathname` already returns an owned `String`; no clone needed.
    let id = partial.get_from_pathname(URL, "id").expect("Expected an ID");

    assert!(partial.test(URL).unwrap());
    assert_eq!(id, "123");
}