use http::{uri::InvalidUri, Uri};
use serde::{Deserialize, Serialize};
use std::str::FromStr;
use urlpattern::{
UrlPattern, UrlPatternInit, UrlPatternMatchInput, UrlPatternOptions, UrlPatternResult,
};
/// Errors produced while parsing a [`Pattern`] or matching against one.
#[derive(Debug, thiserror::Error)]
pub enum PatternError {
    /// The input string could not be parsed as an [`http::Uri`].
    #[error("invalid url: {0}")]
    InvalidUrl(#[from] InvalidUri),
    /// An error reported by the `urlpattern` crate.
    #[error("invalid url pattern: {0}")]
    InvalidPattern(#[from] urlpattern::Error),
}
/// A URL pattern parsed from a string.
///
/// Only the path and query of the parsed URI are fed to `urlpattern` (see
/// `pattern_init`). The original input string is kept alongside the parsed URI.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Pattern {
    /// The input string the pattern was parsed from, stored verbatim.
    raw: String,
    /// The parsed URI form of `raw`.
    // NOTE(review): (de)serialized through `http_serde_ext::uri`; presumably as its
    // string form — confirm against that crate.
    #[serde(with = "http_serde_ext::uri")]
    uri: Uri,
}
impl Pattern {
fn url_pattern(&self) -> Option<UrlPattern> {
<UrlPattern>::parse(self.pattern_init(), UrlPatternOptions::default()).ok()
}
pub fn test(&self, other: impl TryInto<Pattern>) -> Result<bool, PatternError> {
let other = match other.try_into() {
Ok(other) => other,
Err(_) => return Ok(false),
};
let matching_pattern = self.url_pattern();
match matching_pattern {
None => Ok(false),
Some(matching_pattern) => Ok(matching_pattern.test(other.pattern_input())?),
}
}
pub fn captures(&self, other: impl TryInto<Pattern>) -> Option<UrlPatternResult> {
let other = match other.try_into() {
Ok(other) => other,
Err(_) => return None,
};
let matching_pattern = self.url_pattern();
if let Some(matching_pattern) = matching_pattern {
return matching_pattern
.exec(other.pattern_input())
.unwrap_or_default();
}
None
}
pub fn get_from_pathname(&self, other: impl TryInto<Pattern>, key: &str) -> Option<String> {
if let Ok(other) = other.try_into() {
let captures = self.captures(other)?;
match captures.pathname.groups.get(key) {
Some(Some(value)) => Some(value.clone()),
_ => None,
}
} else {
None
}
}
fn pattern_init(&self) -> UrlPatternInit {
UrlPatternInit {
pathname: Some(self.uri.path().to_owned()),
search: self.uri.query().map(Into::into),
..Default::default()
}
}
fn pattern_input(&self) -> UrlPatternMatchInput {
UrlPatternMatchInput::Init(self.pattern_init())
}
fn try_from_str(input: &str) -> Result<Self, PatternError> {
Ok(Pattern {
raw: input.into(),
uri: get_uri(input)?,
})
}
}
/// Parses `input` into a [`Uri`], treating input without a scheme as a path.
///
/// Input that neither starts with `/` nor begins with a `scheme://` prefix gets a
/// leading `/`. For example, `"profile/:id"` is parsed as the path `/profile/:id`
/// rather than as an authority.
///
/// # Errors
///
/// Returns [`PatternError::InvalidUrl`] if the (possibly prefixed) string is not
/// a valid URI.
fn get_uri(input: &str) -> Result<Uri, PatternError> {
    // Only a `://` that appears before any path, query or fragment delimiter marks
    // a scheme. One inside a later component (e.g. `profile?next=app://home`)
    // does not, and such input must still be treated as a path.
    let has_scheme = matches!(
        input.find("://"),
        Some(end) if !input[..end].contains(|c: char| matches!(c, '/' | '?' | '#'))
    );
    if input.starts_with('/') || has_scheme {
        Ok(input.parse()?)
    } else {
        Ok(format!("/{input}").parse()?)
    }
}
impl FromStr for Pattern {
type Err = PatternError;
fn from_str(input: &str) -> Result<Self, Self::Err> {
Self::try_from_str(input)
}
}
impl TryFrom<&str> for Pattern {
type Error = PatternError;
fn try_from(input: &str) -> Result<Self, Self::Error> {
input.parse()
}
}
#[test]
fn accepts_urls() {
    // A full URL with a scheme and authority must be accepted as a pattern.
    let parsed = "frame://accounts/profile/:id".parse::<Pattern>();
    assert!(parsed.is_ok());
}
#[test]
fn describes_capture_results() {
    use std::collections::HashMap;
    use urlpattern::UrlPatternComponentResult;

    let pattern = "frame://accounts/profile/:id".parse::<Pattern>().unwrap();
    // Only the path takes part in matching, so the recorded input is the path alone.
    let pathname = pattern
        .captures("frame://accounts/profile/123")
        .unwrap()
        .pathname;

    let expected = UrlPatternComponentResult {
        input: String::from("/profile/123"),
        groups: HashMap::from([("id".into(), Some("123".into()))]),
    };
    assert_eq!(pathname, expected);
}
#[test]
fn partials() {
    // A path-only pattern should match a full URL regardless of scheme, host,
    // query, or fragment.
    const URL: &str = "anyprotocol://localhost/profile/123?query=stuff#andfragments";
    let partial: Pattern = "/profile/:id".parse().unwrap();

    let id = partial.get_from_pathname(URL, "id").expect("Expected an ID");
    assert!(partial.test(URL).unwrap());
    assert_eq!(id, "123");
}