use std::{fmt::Display, str::FromStr, sync::Arc};
use derive_more::{AsRef, From, Into};
use dhttp_identity::name::DhttpName;
use regex::{Error as RegexError, Regex, RegexBuilder};
use serde::{Deserialize, Serialize};
use snafu::ResultExt;
use crate::expr::eval::Evaluable;
/// How a [`NormalPattern`] was written, as selected by its leading modifier.
///
/// The explicit discriminants double as the value returned by `priority`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum NormalPatternKind {
    /// `= literal`: the whole input must equal `literal` (case-sensitive).
    Exact = 0,
    /// `* glob` (case-insensitive glob); any input without a recognised modifier is
    /// also parsed as a (case-sensitive) glob.
    Glob = 1,
    /// `~ regex` (case-sensitive) or `~* regex` (case-insensitive).
    Regex = 2,
}
impl NormalPatternKind {
    /// Returns the discriminant of this kind as its priority.
    ///
    /// NOTE(review): presumably a lower value wins when several patterns match;
    /// confirm with the code that consumes `Pattern::priority`.
    const fn priority(&self) -> usize {
        *self as usize
    }
}
/// Kind of a [`ClientNamePattern`].
///
/// A newtype over [`NormalPatternKind`] (converted both ways through the derived
/// `From`/`Into`) so client-name patterns are a distinct type. `#[serde(transparent)]`
/// makes it serialize exactly like the wrapped kind.
#[derive(
    Debug, Clone, Copy, From, Into, AsRef, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
)]
#[serde(transparent)]
pub struct ClientNamePatternKind(NormalPatternKind);
impl ClientNamePatternKind {
    /// Returns the priority of the wrapped [`NormalPatternKind`].
    const fn priority(&self) -> usize {
        let Self(inner) = self;
        inner.priority()
    }
}
#[derive(
Debug, Clone, Copy, From, Into, AsRef, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
)]
#[serde(transparent)]
pub struct DomainPatternKind(NormalPatternKind);
impl DomainPatternKind {
const fn priority(&self) -> usize {
self.0.priority()
}
}
/// How a [`LocationPattern`] was written, as selected by its leading modifier.
///
/// The explicit discriminants double as the value returned by `priority`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum LocationPatternKind {
    /// `= path`: the location must equal `path` exactly.
    Exact = 0,
    /// `^~ path`: the location must start with `path`.
    Prefix = 1,
    /// `~ regex` (case-sensitive) or `~* regex` (case-insensitive).
    Regex = 2,
    /// A bare pattern starting with `/` (other than `/` itself): literal prefix match.
    NormalPrefix = 3,
    /// The bare pattern `/`, matching every location that starts with `/`.
    Common = 4,
}
impl LocationPatternKind {
    /// Returns the discriminant of this kind as its priority.
    ///
    /// NOTE(review): presumably a lower value wins when several locations match;
    /// confirm with the code that consumes `Pattern::priority`.
    const fn priority(&self) -> usize {
        *self as usize
    }
}
/// A compiled string pattern together with its kind and original source text.
///
/// `Kind` selects which grammar is used to parse the pattern and how matching
/// treats the input (see the kind-specific `is_match` implementations).
#[derive(Debug, Clone)]
pub struct Pattern<Kind> {
    /// How the pattern was written.
    kind: Kind,
    /// Matcher compiled from the source text when the pattern was parsed.
    regex: Regex,
    /// Original source text, returned by `as_str` and used for display,
    /// serialization, equality and ordering.
    pattern: Arc<str>,
}
/// Exact / glob / regex pattern; see [`NormalPatternKind`].
pub type NormalPattern = Pattern<NormalPatternKind>;
/// Location (path) pattern; see [`LocationPatternKind`].
pub type LocationPattern = Pattern<LocationPatternKind>;
/// Client-name pattern, matched against the name with `DhttpName::SUFFIX` removed.
pub type ClientNamePattern = Pattern<ClientNamePatternKind>;
/// Domain pattern, matched against the name with `DhttpName::SUFFIX` removed.
pub type DomainPattern = Pattern<DomainPatternKind>;
impl<Kind> Pattern<Kind> {
    /// Parses `pattern` with this pattern type's [`FromStr`] implementation.
    ///
    /// Accepts anything string-like (`&str`, `String`, `Arc<str>`, ...).
    ///
    /// # Errors
    ///
    /// Returns this pattern type's `FromStr::Err` when `pattern` cannot be parsed.
    #[inline]
    pub fn new(pattern: impl AsRef<str>) -> Result<Self, <Self as FromStr>::Err>
    where
        Self: FromStr,
    {
        let text: &str = pattern.as_ref();
        Self::from_str(text)
    }

    /// Returns the original source text this pattern was parsed from.
    #[inline]
    pub fn as_str(&self) -> &str {
        self.pattern.as_ref()
    }

    /// Returns the kind of this pattern.
    #[inline]
    pub const fn kind(&self) -> &Kind {
        let Self { kind, .. } = self;
        kind
    }
}
impl Pattern<NormalPatternKind> {
    /// Returns `true` if the compiled regex matches `s`.
    #[inline]
    pub fn is_match(&self, s: &str) -> bool {
        let Self { regex: matcher, .. } = self;
        matcher.is_match(s)
    }

    /// Returns the part of `s` found by the compiled regex, if any.
    #[inline]
    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
        self.regex.find(s).map(|found| found.as_str())
    }
}
impl Pattern<LocationPatternKind> {
#[inline]
pub fn is_match(&self, s: &str) -> bool {
self.regex.is_match(s)
}
#[inline]
pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
self.regex.find(s).map(|m| &s[m.range()])
}
}
/// Removes `suffix` from the end of `s`, exactly once.
///
/// Returns the part of `s` preceding `suffix` when `s` ends with it, or `None`
/// otherwise. An empty `suffix` always matches and yields `s` unchanged.
#[inline]
pub fn trim_suffix_once<'s>(s: &'s str, suffix: &str) -> Option<&'s str> {
    // `str::strip_suffix` compares only the tail of `s`, so a missing suffix is
    // detected in O(len(suffix)) instead of scanning the whole string.
    s.strip_suffix(suffix)
}
impl Pattern<ClientNamePatternKind> {
    /// Returns `true` when `s` ends with `DhttpName::SUFFIX` and the name left after
    /// removing that suffix matches the pattern. Names without the suffix never match.
    #[inline]
    pub fn is_match(&self, s: &str) -> bool {
        match trim_suffix_once(s, DhttpName::SUFFIX) {
            Some(name) => self.regex.is_match(name),
            None => false,
        }
    }

    /// Like `is_match`, but returns the matched part of the suffix-stripped name.
    #[inline]
    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
        let name = trim_suffix_once(s, DhttpName::SUFFIX)?;
        self.regex.find(name).map(|found| found.as_str())
    }
}
impl Pattern<DomainPatternKind> {
    /// Returns `true` when `s` ends with `DhttpName::SUFFIX` and the name left after
    /// removing that suffix matches the pattern. Names without the suffix never match.
    #[inline]
    pub fn is_match(&self, s: &str) -> bool {
        let Some(name) = trim_suffix_once(s, DhttpName::SUFFIX) else {
            return false;
        };
        self.regex.is_match(name)
    }

    /// Like `is_match`, but returns the matched part of the suffix-stripped name.
    #[inline]
    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
        let name = trim_suffix_once(s, DhttpName::SUFFIX)?;
        let found = self.regex.find(name)?;
        Some(found.as_str())
    }
}
// Generates per-kind trait impls and inherent methods for `Pattern<Kind>`.
//
// This is a tt-muncher. Each arm matches one pseudo-item written as
// `impl ... { ... }` (the `{ ... }` is literal matcher syntax, not elided code),
// emits the real implementation for it, then recurses on the remaining tokens.
// The empty arm ends the recursion.
macro_rules! impl_pattern {
    // `Evaluable<&str>`: evaluating a pattern against a string delegates to the
    // inherent `is_match` of the concrete `Pattern<$kind>`.
    (impl Evaluable<&str> for Pattern<$kind:ident> { ... } $($tt:tt)*) => {
        impl Evaluable<&str> for Pattern<$kind> {
            type Value = bool;
            fn eval(&self, argument: &&str) -> Self::Value {
                self.is_match(argument)
            }
        }
        impl_pattern!($($tt)*);
    };
    // Exposes the kind's private `priority` as a public `const fn` on the pattern.
    (impl Pattern<$kind:ident> { pub const fn priority(&self) -> usize { ... } } $($tt:tt)*) => {
        impl Pattern<$kind> {
            #[inline]
            pub const fn priority(&self) -> usize {
                self.kind.priority()
            }
        }
        impl_pattern!($($tt)*);
    };
    // Converts between pattern types: the kind goes through `Into`, while the
    // compiled regex and the source text are reused unchanged.
    (impl From<Pattern<$from:ident>> for Pattern<$into:ident> { ... } $($tt:tt)*) => {
        impl From<Pattern<$from>> for Pattern<$into> {
            fn from(value: Pattern<$from>) -> Self {
                Self {
                    kind: value.kind.into(),
                    regex: value.regex,
                    pattern: value.pattern,
                }
            }
        }
        impl_pattern!($($tt)*);
    };
    // Parses `Pattern<$into>` by parsing a `Pattern<$from>` and converting it.
    // The error type is therefore the one of `Pattern<$from>`.
    (impl FromStr for Pattern<$into:ident> from Pattern<$from:ident> { ... } $($tt:tt)*) => {
        impl FromStr for Pattern<$into> {
            type Err = <Pattern<$from> as FromStr>::Err;
            #[inline]
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                <Pattern<$from>>::from_str(s).map(Into::into)
            }
        }
        impl_pattern!($($tt)*);
    };
    // Registers the pattern type with `crate::orm_new_type!` using its `@json` form.
    // The type is aliased to a plain identifier inside an anonymous `const` scope.
    // NOTE(review): presumably because `orm_new_type!` expects a single identifier
    // rather than a generic path; confirm against that macro's definition.
    (impl Orm for Pattern<$kind:ident> from json { ... } $($tt:tt)*) => {
        const _: () = {
            type __PatternType = Pattern<$kind>;
            crate::orm_new_type!(@json __PatternType);
        };
        impl_pattern!($($tt)*);
    };
    // No tokens left: stop recursing.
    () => {}
}
// Instantiate the per-kind impls.
//
// `ClientNamePattern` and `DomainPattern` are parsed by parsing a `NormalPattern`
// and converting it, so they accept exactly the normal-pattern syntax.
impl_pattern! {
    // NormalPattern
    impl Evaluable<&str> for Pattern<NormalPatternKind> { ... }
    impl Pattern<NormalPatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl Orm for Pattern<NormalPatternKind> from json { ... }
    // LocationPattern
    impl Evaluable<&str> for Pattern<LocationPatternKind> { ... }
    impl Pattern<LocationPatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl Orm for Pattern<LocationPatternKind> from json { ... }
    // ClientNamePattern
    impl Evaluable<&str> for Pattern<ClientNamePatternKind> { ... }
    impl Pattern<ClientNamePatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl From<Pattern<NormalPatternKind>> for Pattern<ClientNamePatternKind> { ... }
    impl FromStr for Pattern<ClientNamePatternKind> from Pattern<NormalPatternKind> { ... }
    impl Orm for Pattern<ClientNamePatternKind> from json { ... }
    // DomainPattern
    impl Evaluable<&str> for Pattern<DomainPatternKind> { ... }
    impl Pattern<DomainPatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl From<Pattern<NormalPatternKind>> for Pattern<DomainPatternKind> { ... }
    impl FromStr for Pattern<DomainPatternKind> from Pattern<NormalPatternKind> { ... }
    impl Orm for Pattern<DomainPatternKind> from json { ... }
}
mod regex_utils {
use super::*;
pub(super) fn case_insensitive_regex(pat: &str) -> Result<Regex, regex::Error> {
RegexBuilder::new(pat).case_insensitive(true).build()
}
pub(super) fn glob_to_regex(glob: &globset::Glob) -> Result<Regex, regex::Error> {
glob.regex()
.strip_prefix("(?-u)")
.unwrap_or(glob.regex())
.parse()
}
}
mod parse_pattern {
use globset::{Glob, GlobBuilder};
use super::{regex_utils, *};
#[derive(snafu::Snafu, Debug, Clone)]
pub enum ParsePatternError {
#[snafu(display("invalid regex pattern `{pattern}`"))]
InvalidRegex {
pattern: Arc<str>,
source: RegexError,
},
#[snafu(display("invalid glob pattern"))]
InvalidGlob { source: globset::Error },
}
impl FromStr for Pattern<NormalPatternKind> {
type Err = ParsePatternError;
fn from_str(pattern: &str) -> Result<Self, Self::Err> {
let pattern: Arc<str> = Arc::from(pattern);
let (kind, regex) = match pattern.split_once(' ') {
Some(("=", pat)) => (
NormalPatternKind::Exact,
Regex::new(&format!("^{}$", regex::escape(pat)))
.context(InvalidRegexSnafu { pattern: pat })?,
),
Some(("*", pattern)) => {
let glob = GlobBuilder::new(pattern)
.case_insensitive(true)
.build()
.context(InvalidGlobSnafu)?;
(
NormalPatternKind::Glob,
regex_utils::glob_to_regex(&glob).context(InvalidRegexSnafu { pattern })?,
)
}
Some(("~", pattern)) => (
NormalPatternKind::Regex,
Regex::new(pattern).context(InvalidRegexSnafu { pattern })?,
),
Some(("~*", pattern)) => (
NormalPatternKind::Regex,
regex_utils::case_insensitive_regex(pattern)
.context(InvalidRegexSnafu { pattern })?,
),
_ => {
let glob = Glob::new(&pattern).context(InvalidGlobSnafu)?;
(
NormalPatternKind::Glob,
regex_utils::glob_to_regex(&glob).context(InvalidRegexSnafu {
pattern: pattern.clone(),
})?,
)
}
};
Ok(Self {
kind,
regex,
pattern,
})
}
}
}
pub use parse_pattern::ParsePatternError;
mod parse_location_pattern {
use super::{regex_utils, *};
#[derive(snafu::Snafu, Debug, Clone)]
pub enum ParseLocationPatternError {
#[snafu(display("unknown symbol `{symbol}`, expected one of {expect:?}"))]
UnknownSymbol {
symbol: String,
expect: &'static [&'static str],
},
#[snafu(display("invalid regex pattern `{pattern}`"))]
InvalidRegex {
pattern: Arc<str>,
source: RegexError,
},
#[snafu(display("expected common pattern or normal prefix starting with `{prefix}`"))]
UndefinedPrefixOrCommon { prefix: &'static str },
}
impl FromStr for Pattern<LocationPatternKind> {
type Err = ParseLocationPatternError;
fn from_str(pattern: &str) -> Result<Self, Self::Err> {
let pattern: Arc<str> = Arc::from(pattern);
let (kind, regex) = match pattern.split_once(' ') {
None if pattern.as_ref() == "/" => (
LocationPatternKind::Common,
Regex::new(r"^/").context(InvalidRegexSnafu {
pattern: pattern.clone(),
})?,
),
None if pattern.starts_with("/") => (
LocationPatternKind::NormalPrefix,
Regex::new(format!("^{}", regex::escape(&pattern)).as_str()).context(
InvalidRegexSnafu {
pattern: pattern.clone(),
},
)?,
),
None => return UndefinedPrefixOrCommonSnafu { prefix: "/" }.fail(),
Some(("=", pattern)) => (
LocationPatternKind::Exact,
Regex::new(&format!("^{}$", regex::escape(pattern)))
.context(InvalidRegexSnafu { pattern })?,
),
Some(("^~", pattern)) => (
LocationPatternKind::Prefix,
Regex::new(format!("^{}", regex::escape(pattern)).as_str())
.context(InvalidRegexSnafu { pattern })?,
),
Some(("~", pattern)) => (
LocationPatternKind::Regex,
Regex::new(pattern).context(InvalidRegexSnafu { pattern })?,
),
Some(("~*", pattern)) => (
LocationPatternKind::Regex,
regex_utils::case_insensitive_regex(pattern)
.context(InvalidRegexSnafu { pattern })?,
),
Some((symbol, ..)) => {
return UnknownSymbolSnafu::fail(UnknownSymbolSnafu {
symbol: symbol.to_string(),
expect: &["=", "^~", "~", "~*"] as &'static [&'static str],
});
}
};
Ok(Self {
kind,
regex,
pattern,
})
}
}
}
pub use parse_location_pattern::ParseLocationPatternError;
impl<Kind> Display for Pattern<Kind> {
    /// Writes the pattern's original source text.
    ///
    /// `Formatter::pad` is used so width, fill, alignment and precision flags
    /// (e.g. `{:>20}`) apply the same way they do for `str`. With no flags, the
    /// output is exactly the source text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
impl<Kind> Serialize for Pattern<Kind> {
    /// Serializes the pattern as its original source text (a plain string).
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_str(self.as_str())
    }
}
impl<'de, Kind> Deserialize<'de> for Pattern<Kind>
where
Self: FromStr<Err: Display>,
{
#[inline]
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
impl<Kind: PartialEq> PartialEq for Pattern<Kind> {
    /// Patterns are equal when kind and source text are equal; the compiled regex
    /// is derived from the source text and is not compared.
    fn eq(&self, other: &Self) -> bool {
        (&self.kind, &self.pattern) == (&other.kind, &other.pattern)
    }
}
impl<Kind: Eq> Eq for Pattern<Kind> {}
impl<Kind: PartialOrd> PartialOrd for Pattern<Kind> {
    /// Orders by kind first, then by source text (lexicographically).
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        (&self.kind, &self.pattern).partial_cmp(&(&other.kind, &other.pattern))
    }
}
impl<Kind: Ord> Ord for Pattern<Kind> {
    /// Orders by kind first, then by source text (lexicographically).
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        (&self.kind, &self.pattern).cmp(&(&other.kind, &other.pattern))
    }
}
#[cfg(test)]
mod dhttp_suffix_tests {
    use super::*;

    /// Client-name patterns see the name without the dhttp suffix, so a name that
    /// carries a different suffix must not match.
    #[test]
    fn client_name_pattern_uses_dhttp_name_suffix() {
        let pilot =
            ClientNamePattern::new("~ ^reimu\\.pilot$").expect("valid client name pattern");

        assert!(pilot.is_match("reimu.pilot.dhttp.net"));
        assert!(!pilot.is_match("reimu.pilot.genmeta.net"));
    }
}