use std::fmt::{Display, Write};
use schemars::JsonSchema;
use serde::{de::Error, Deserialize, Serialize};
use thiserror::Error;
use crate::{segment::DomainSegment, FullyQualifiedDomainName};
/// Error type named for pattern-level failures.
///
/// The enum declares no variants, so no value of this type can be constructed.
/// The parsing code in this file reports failures through
/// [`PatternSegmentError`] instead.
#[derive(Error, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum PatternError {}
/// A domain-name pattern: an ordered list of [`PatternSegment`]s.
///
/// Segments are stored in the order they are written, i.e. the leftmost label
/// comes first. The default value holds no segments at all.
#[derive(Default, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Pattern(Vec<PatternSegment>);

impl Pattern {
    /// Returns the pattern without any segments.
    pub fn origin() -> Self {
        Pattern::default()
    }

    /// Iterates over the segments, leftmost segment first.
    pub fn iter(&self) -> impl Iterator<Item = &PatternSegment> + '_ {
        self.0.iter()
    }

    /// Returns a copy of this pattern with the segments of `origin` appended
    /// on the right-hand side. `self` is left untouched.
    pub fn with_origin(&self, origin: &FullyQualifiedDomainName) -> Pattern {
        let mut cloned = self.clone();
        cloned.0.extend(origin.iter().map(PatternSegment::from));
        cloned
    }

    /// Reports whether `domain` is matched by this pattern.
    ///
    /// Pattern and domain are aligned at their rightmost segment and compared
    /// right to left:
    ///
    /// * a domain with fewer segments than the pattern never matches;
    /// * a domain with more segments than the pattern matches only when the
    ///   leftmost pattern segment is exactly `*`, in which case the surplus
    ///   leading domain segments are accepted without further checks;
    /// * every pattern segment must match the domain segment it is aligned
    ///   with (see [`PatternSegment::matches`]).
    ///
    /// A `*` segment that is *not* the leftmost one matches exactly one domain
    /// segment; the segments to its left are still checked.
    pub fn matches(&self, domain: &FullyQualifiedDomainName) -> bool {
        let domain_segments = domain.as_ref().iter().rev();
        let pattern_segments = self.0.iter().rev();

        if domain_segments.len() < pattern_segments.len() {
            return false;
        }

        // Only a leading `*` may absorb additional leading domain segments.
        if domain_segments.len() > pattern_segments.len()
            && !self.0.first().is_some_and(|pattern| pattern.as_ref() == "*")
        {
            return false;
        }

        // `zip` stops once the pattern is exhausted, which is what lets a
        // leading `*` swallow any remaining domain segments. There is no
        // short-circuit on `*` here: bailing out early would skip the
        // segments to the left of an interior wildcard.
        pattern_segments
            .zip(domain_segments)
            .all(|(pattern, label)| pattern.matches(label))
    }
}
impl FromIterator<PatternSegment> for Pattern {
fn from_iter<T: IntoIterator<Item = PatternSegment>>(iter: T) -> Self {
Pattern(iter.into_iter().collect())
}
}
impl TryFrom<&str> for Pattern {
type Error = PatternSegmentError;
fn try_from(value: &str) -> Result<Self, Self::Error> {
let segments = Result::from_iter(
value
.trim_end_matches('.')
.split('.')
.map(PatternSegment::try_from),
)?;
Ok(Pattern(segments))
}
}
impl TryFrom<String> for Pattern {
type Error = PatternSegmentError;
fn try_from(value: String) -> Result<Self, Self::Error> {
Self::try_from(value.as_ref())
}
}
/// Writes every segment followed by a dot, e.g. `www.example.org.`.
///
/// A pattern without segments produces an empty string.
impl Display for Pattern {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.iter().try_for_each(|segment| {
            write!(f, "{segment}")?;
            f.write_char('.')
        })
    }
}
/// Describes a [`Pattern`] with the same JSON schema as a plain `String`.
impl JsonSchema for Pattern {
    /// Reuses the schema name of `String`.
    fn schema_name() -> String {
        <String as JsonSchema>::schema_name()
    }

    /// Delegates schema generation to `String`.
    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
        <String as JsonSchema>::json_schema(gen)
    }
}
impl<'de> Deserialize<'de> for Pattern {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let value = String::deserialize(deserializer)?;
Self::try_from(value).map_err(D::Error::custom)
}
}
/// Serializes a [`Pattern`] as the string produced by its `Display` impl.
impl Serialize for Pattern {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_string();
        serializer.serialize_str(&rendered)
    }
}
/// One dot-separated piece of a [`Pattern`]: either a literal label or a label
/// containing a `*` wildcard.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PatternSegment(String);

impl PatternSegment {
    /// Reports whether `domain_segment` is matched by this pattern segment.
    ///
    /// A segment matches when it is identical to the domain segment, or when
    /// it contains a `*` and the domain segment starts with the text before
    /// the first `*` and ends with the text after it. The wildcard stands for
    /// zero or more characters, so the prefix and the suffix must fit into the
    /// domain segment without overlapping: `ab*ba` matches `abba` but not
    /// `aba`.
    pub fn matches(&self, domain_segment: &DomainSegment) -> bool {
        let label: &str = domain_segment.as_ref();
        if self.0 == label {
            return true;
        }
        match self.0.split_once('*') {
            Some((head, tail)) => {
                // Without the length guard the prefix and suffix could share
                // characters of a label that is too short for both.
                label.len() >= head.len() + tail.len()
                    && label.starts_with(head)
                    && label.ends_with(tail)
            }
            None => false,
        }
    }

    /// Length of the segment text in bytes.
    // No `is_empty` companion is provided; the string parser in this file
    // rejects empty segments.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.0.len()
    }
}
/// Reasons a string is rejected when it is parsed as a [`PatternSegment`].
#[derive(Error, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum PatternSegmentError {
/// A hyphen sits at a disallowed position. The parser in this file reports
/// position 1 for a leading hyphen, the segment length for a trailing hyphen,
/// and 3 when the third and fourth characters are both hyphens.
#[error("illegal hyphen at position {0}")]
IllegalHyphen(usize),
/// The segment contains a character outside the allowed set; carries the
/// first such character.
#[error("invalid character {0}")]
InvalidCharacter(char),
/// The segment is longer than 63 bytes; carries the actual length.
#[error("pattern too long {0} > 63")]
TooLong(usize),
/// The segment is empty.
#[error("pattern is an empty string")]
EmptyString,
/// The segment contains more than one `*`.
#[error("patterns can only have one wildcard")]
MultipleWildcards,
}
/// Every character a pattern segment may contain after lowercasing; `*` is the
/// wildcard.
const VALID_CHARACTERS: &str = "_-0123456789abcdefghijklmnopqrstuvwxyz*";

/// Validates a single segment and stores it in ASCII-lowercased form.
impl TryFrom<&str> for PatternSegment {
    type Error = PatternSegmentError;

    /// Checks are applied in this order and the first failure is returned:
    /// empty input, length above 63 bytes, first character outside
    /// `VALID_CHARACTERS`, leading hyphen, trailing hyphen, hyphens in both
    /// the third and fourth position, more than one `*`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let lowered = value.to_ascii_lowercase();
        let length = lowered.len();

        if length == 0 {
            return Err(PatternSegmentError::EmptyString);
        }
        if length > 63 {
            return Err(PatternSegmentError::TooLong(length));
        }
        for character in lowered.chars() {
            if !VALID_CHARACTERS.contains(character) {
                return Err(PatternSegmentError::InvalidCharacter(character));
            }
        }

        // Only ASCII is left past this point, so byte offsets and character
        // offsets coincide.
        let bytes = lowered.as_bytes();
        let hyphen_at = |index: usize| bytes.get(index) == Some(&b'-');

        if hyphen_at(0) {
            return Err(PatternSegmentError::IllegalHyphen(1));
        }
        if hyphen_at(length - 1) {
            return Err(PatternSegmentError::IllegalHyphen(length));
        }
        if hyphen_at(2) && hyphen_at(3) {
            return Err(PatternSegmentError::IllegalHyphen(3));
        }

        let wildcards = bytes.iter().filter(|&&byte| byte == b'*').count();
        if wildcards > 1 {
            return Err(PatternSegmentError::MultipleWildcards);
        }

        Ok(PatternSegment(lowered))
    }
}
impl From<DomainSegment> for PatternSegment {
fn from(value: DomainSegment) -> Self {
PatternSegment(value.to_string())
}
}
impl From<&DomainSegment> for PatternSegment {
fn from(value: &DomainSegment) -> Self {
PatternSegment(value.to_string())
}
}
/// Owned-string counterpart of the `&str` conversion; both apply the same checks.
impl TryFrom<String> for PatternSegment {
    type Error = PatternSegmentError;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        <Self as TryFrom<&str>>::try_from(&value)
    }
}
impl Display for PatternSegment {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}
/// Borrows the segment text as a string slice.
impl AsRef<str> for PatternSegment {
    fn as_ref(&self) -> &str {
        &self.0
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        error::PatternSegmentError, pattern::PatternSegment, segment::DomainSegment,
        FullyQualifiedDomainName, Pattern,
    };

    /// Parses `text` as a pattern, panicking when it is invalid.
    fn parse_pattern(text: &str) -> Pattern {
        Pattern::try_from(text).unwrap()
    }

    /// Parses `text` as a fully qualified domain name, panicking when it is invalid.
    fn parse_fqdn(text: &str) -> FullyQualifiedDomainName {
        FullyQualifiedDomainName::try_from(text).unwrap()
    }

    /// Parses both arguments as single segments and reports whether the
    /// pattern segment matches the domain segment.
    fn segment_matches(pattern: &str, label: &str) -> bool {
        let pattern = PatternSegment::try_from(pattern).unwrap();
        let label = DomainSegment::try_from(label).unwrap();
        pattern.matches(&label)
    }

    #[test]
    fn literal_matches() {
        assert!(segment_matches("example", "example"));
    }

    #[test]
    fn wildcard() {
        assert!(segment_matches("*", "example"));
    }

    #[test]
    fn leading_wildcard() {
        assert!(segment_matches("*ample", "example"));
    }

    #[test]
    fn trailing_wildcard() {
        assert!(segment_matches("examp*", "example"));
    }

    #[test]
    fn splitting_wildcard() {
        assert!(segment_matches("ex*le", "example"));
    }

    #[test]
    fn multiple_wildcards() {
        let parsed = PatternSegment::try_from("*amp*");
        assert_eq!(parsed, Err(PatternSegmentError::MultipleWildcards));
    }

    #[test]
    fn simple_pattern_match() {
        let pattern = parse_pattern("*.example.org");
        assert!(pattern.matches(&parse_fqdn("www.example.org.")));
    }

    #[test]
    fn longer_pattern_than_domain() {
        let pattern = parse_pattern("*.*.example.org");
        assert!(!pattern.matches(&parse_fqdn("www.example.org.")));
    }

    #[test]
    fn longer_domain_than_pattern() {
        let pattern = parse_pattern("*.example.org");
        assert!(pattern.matches(&parse_fqdn("www.sub.test.dev.example.org.")));
    }

    #[test]
    fn wildcard_segments() {
        let pattern = parse_pattern("dev*.example.org");

        for accepted in [
            "dev.example.org.",
            "dev-1.example.org.",
            "dev-hello.example.org.",
        ] {
            assert!(pattern.matches(&parse_fqdn(accepted)));
        }

        for rejected in ["de.example.org.", "www.dev-1.example.org."] {
            assert!(!pattern.matches(&parse_fqdn(rejected)));
        }
    }

    #[test]
    fn patterns_assumed_wildcard() {
        let fqdn = parse_pattern("example.org.");
        let pqdn = parse_pattern("example.org");
        assert_eq!(fqdn, pqdn);

        let domain = parse_fqdn("example.org.");
        assert_eq!(fqdn.matches(&domain), pqdn.matches(&domain));
    }

    #[test]
    fn origin_insertion() {
        let pattern = parse_pattern("example");
        let domain = parse_fqdn("example.org.");
        assert!(!pattern.matches(&domain));

        let anchored = pattern.with_origin(&parse_fqdn("org."));
        assert!(anchored.matches(&parse_fqdn("example.org.")));
    }
}