Skip to main content

dhttp_access/
pattern.rs

1use std::{fmt::Display, str::FromStr, sync::Arc};
2
3use derive_more::{AsRef, From, Into};
4use dhttp_identity::name::DhttpName;
5use regex::{Error as RegexError, Regex, RegexBuilder};
6use serde::{Deserialize, Serialize};
7use snafu::ResultExt;
8
9use crate::expr::eval::Evaluable;
10
/// Kind of a general-purpose pattern: exact match, glob, or regular expression.
///
/// The kind is chosen by the prefix of the pattern text: `= ` selects an exact match,
/// `~ ` / `~* ` a case-sensitive / case-insensitive regular expression, `* ` a
/// case-insensitive glob, and any other text is parsed as a whole as a glob.
/// The explicit discriminants are the priorities returned by `priority()`
/// (a smaller value means a higher priority).
///
/// # Examples
///
/// ```
/// use dhttp_access::pattern::{NormalPattern, NormalPatternKind};
///
/// // Exact match
/// let pattern: NormalPattern = "= hello".parse().unwrap();
/// assert_eq!(pattern.kind(), &NormalPatternKind::Exact);
/// assert!(pattern.is_match("hello"));
/// assert!(!pattern.is_match("hello world"));
///
/// // Glob pattern (default)
/// let pattern: NormalPattern = "*.txt".parse().unwrap();
/// assert_eq!(pattern.kind(), &NormalPatternKind::Glob);
/// assert!(pattern.is_match("file.txt"));
/// assert!(!pattern.is_match("file.doc"));
///
/// // Glob pattern (case-insensitive)
/// let pattern: NormalPattern = "* *.TXT".parse().unwrap();
/// assert_eq!(pattern.kind(), &NormalPatternKind::Glob);
/// assert!(pattern.is_match("file.txt"));
/// assert!(pattern.is_match("FILE.TXT"));
///
/// // Regular expression
/// let pattern: NormalPattern = r"~ \d+".parse().unwrap();
/// assert_eq!(pattern.kind(), &NormalPatternKind::Regex);
/// assert!(pattern.is_match("123"));
/// assert!(!pattern.is_match("abc"));
///
/// // Regular expression (case-insensitive)
/// let pattern: NormalPattern = "~* hello".parse().unwrap();
/// assert_eq!(pattern.kind(), &NormalPatternKind::Regex);
/// assert!(pattern.is_match("HELLO"));
/// assert!(pattern.is_match("hello"));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum NormalPatternKind {
    /// Exact match - syntax: `= pattern`
    ///
    /// The text after `= ` is regex-escaped and anchored at both ends, so it is
    /// compared literally and case-sensitively against the whole input.
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{NormalPattern, NormalPatternKind};
    ///
    /// let pattern: NormalPattern = "= hello".parse().unwrap();
    /// assert!(matches!(pattern.kind(), NormalPatternKind::Exact));
    /// assert!(pattern.is_match("hello"));
    /// assert!(!pattern.is_match("Hello"));
    /// assert!(!pattern.is_match("hello world"));
    /// ```
    Exact = 0,
    /// Glob match - syntax: `pattern` (default) or `* pattern` (case-insensitive)
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{NormalPattern, NormalPatternKind};
    ///
    /// // Default glob pattern
    /// let pattern: NormalPattern = "*.txt".parse().unwrap();
    /// assert!(matches!(pattern.kind(), NormalPatternKind::Glob));
    /// assert!(pattern.is_match("test.txt"));
    /// assert!(pattern.is_match("hello.txt"));
    /// assert!(!pattern.is_match("test.doc"));
    ///
    /// // Case-insensitive glob pattern
    /// let pattern: NormalPattern = "* *.TXT".parse().unwrap();
    /// assert!(pattern.is_match("test.txt"));
    /// assert!(pattern.is_match("TEST.TXT"));
    /// ```
    Glob = 1,
    /// Regular-expression match - syntax: `~ regex` or `~* regex` (case-insensitive)
    ///
    /// The regex is compiled as written and is not anchored implicitly: it matches
    /// anywhere in the input unless it contains its own anchors.
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{NormalPattern, NormalPatternKind};
    ///
    /// // Case-sensitive regex
    /// let pattern: NormalPattern = "~ test\\d+".parse().unwrap();
    /// assert!(matches!(pattern.kind(), NormalPatternKind::Regex));
    /// assert!(pattern.is_match("test123"));
    /// assert!(!pattern.is_match("Test123"));
    ///
    /// // Case-insensitive regex
    /// let pattern: NormalPattern = "~* test\\d+".parse().unwrap();
    /// assert!(pattern.is_match("test123"));
    /// assert!(pattern.is_match("Test123"));
    /// assert!(pattern.is_match("TEST123"));
    /// ```
    Regex = 2,
}
105
106impl NormalPatternKind {
107    const fn priority(&self) -> usize {
108        *self as usize
109    }
110}
111
112#[derive(
113    Debug, Clone, Copy, From, Into, AsRef, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
114)]
115#[serde(transparent)]
116pub struct ClientNamePatternKind(NormalPatternKind);
117
118impl ClientNamePatternKind {
119    const fn priority(&self) -> usize {
120        self.0.priority()
121    }
122}
123
124#[derive(
125    Debug, Clone, Copy, From, Into, AsRef, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize,
126)]
127#[serde(transparent)]
128pub struct DomainPatternKind(NormalPatternKind);
129
130impl DomainPatternKind {
131    const fn priority(&self) -> usize {
132        self.0.priority()
133    }
134}
135
/// Kind of a location pattern, modelled on Nginx `location` modifiers.
///
/// The kind is chosen by the modifier before the first space (`=`, `^~`, `~`, `~*`).
/// Text without a modifier must start with `/`: it is the catch-all kind when it is
/// exactly `/`, and a normal prefix otherwise. The explicit discriminants are the
/// priorities returned by `priority()` (a smaller value means a higher priority).
///
/// # Examples
///
/// ```
/// use dhttp_access::pattern::{LocationPattern, LocationPatternKind};
///
/// // Exact match
/// let pattern: LocationPattern = "= /api/v1".parse().unwrap();
/// assert_eq!(pattern.kind(), &LocationPatternKind::Exact);
/// assert!(pattern.is_match("/api/v1"));
/// assert!(!pattern.is_match("/api/v1/users"));
///
/// // Literal prefix match
/// let pattern: LocationPattern = "^~ /static/".parse().unwrap();
/// assert_eq!(pattern.kind(), &LocationPatternKind::Prefix);
/// assert!(pattern.is_match("/static/css/style.css"));
/// assert!(!pattern.is_match("/images/logo.png"));
///
/// // Regular-expression match
/// let pattern: LocationPattern = r"~ ^/api/\d+$".parse().unwrap();
/// assert_eq!(pattern.kind(), &LocationPatternKind::Regex);
/// assert!(pattern.is_match("/api/123"));
/// assert!(!pattern.is_match("/api/abc"));
///
/// // Normal prefix match
/// let pattern: LocationPattern = "/uploads".parse().unwrap();
/// assert_eq!(pattern.kind(), &LocationPatternKind::NormalPrefix);
/// assert!(pattern.is_match("/uploads/file.jpg"));
///
/// // Catch-all match (root path)
/// let pattern: LocationPattern = "/".parse().unwrap();
/// assert_eq!(pattern.kind(), &LocationPatternKind::Common);
/// assert!(pattern.is_match("/anything"));
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum LocationPatternKind {
    /// Exact match - syntax: `= pattern`
    ///
    /// The path is regex-escaped and anchored at both ends, so only the identical
    /// path matches.
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{LocationPattern, LocationPatternKind};
    ///
    /// let pattern: LocationPattern = "= /home".parse().unwrap();
    /// assert!(matches!(pattern.kind(), LocationPatternKind::Exact));
    /// assert!(pattern.is_match("/home"));
    /// assert!(!pattern.is_match("/home/"));
    /// assert!(!pattern.is_match("/home/user"));
    /// ```
    Exact = 0,
    /// Literal prefix match - syntax: `^~ pattern`
    ///
    /// Compiles to the same start-anchored, escaped regex as `NormalPrefix`; the two
    /// differ only in kind and therefore in priority.
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{LocationPattern, LocationPatternKind};
    ///
    /// let pattern: LocationPattern = "^~ /api".parse().unwrap();
    /// assert!(matches!(pattern.kind(), LocationPatternKind::Prefix));
    /// assert!(pattern.is_match("/api"));
    /// assert!(pattern.is_match("/api/"));
    /// assert!(pattern.is_match("/api/users"));
    /// assert!(!pattern.is_match("/app"));
    /// ```
    Prefix = 1,
    /// Regular-expression match - syntax: `~ regex` or `~* regex` (case-insensitive)
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{LocationPattern, LocationPatternKind};
    ///
    /// // Case-sensitive regex
    /// let pattern: LocationPattern = "~ /api/\\d+".parse().unwrap();
    /// assert!(matches!(pattern.kind(), LocationPatternKind::Regex));
    /// assert!(pattern.is_match("/api/123"));
    /// assert!(!pattern.is_match("/API/123"));
    ///
    /// // Case-insensitive regex
    /// let pattern: LocationPattern = "~* /api/\\d+".parse().unwrap();
    /// assert!(pattern.is_match("/api/123"));
    /// assert!(pattern.is_match("/API/123"));
    /// ```
    Regex = 2,
    /// Normal prefix match - syntax: `/xxx` (a path starting with `/`, no modifier)
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{LocationPattern, LocationPatternKind};
    ///
    /// let pattern: LocationPattern = "/admin".parse().unwrap();
    /// assert!(matches!(pattern.kind(), LocationPatternKind::NormalPrefix));
    /// assert!(pattern.is_match("/admin"));
    /// assert!(pattern.is_match("/admin/"));
    /// assert!(pattern.is_match("/admin/users"));
    /// assert!(!pattern.is_match("/app"));
    /// ```
    NormalPrefix = 3,
    /// Catch-all match - syntax: `/` (the root path)
    ///
    /// Compiles to `^/`, so it matches every path that starts with `/`.
    ///
    /// # Examples
    ///
    /// ```
    /// use dhttp_access::pattern::{LocationPattern, LocationPatternKind};
    ///
    /// let pattern: LocationPattern = "/".parse().unwrap();
    /// assert!(matches!(pattern.kind(), LocationPatternKind::Common));
    /// assert!(pattern.is_match("/"));
    /// assert!(pattern.is_match("/anything"));
    /// assert!(pattern.is_match("/deeply/nested/path"));
    /// ```
    Common = 4,
}
251
252impl LocationPatternKind {
253    const fn priority(&self) -> usize {
254        *self as usize
255    }
256}
257
/// Generic compiled pattern.
///
/// `Kind` records how the pattern was written and selects which matching methods are
/// available. The type aliases in this module instantiate it for the different matching
/// scenarios (general names, locations, client names, domains).
///
/// # Examples
///
/// ```
/// use dhttp_access::pattern::{NormalPattern, LocationPattern};
///
/// // General pattern
/// let pattern: NormalPattern = "*.log".parse().unwrap();
/// assert!(pattern.is_match("app.log"));
/// assert_eq!(pattern.as_str(), "*.log");
///
/// // Location pattern
/// let pattern: LocationPattern = "/api".parse().unwrap();
/// assert!(pattern.is_match("/api/users"));
/// assert_eq!(pattern.as_str(), "/api");
///
/// // Matching a substring
/// let pattern: NormalPattern = "~ test".parse().unwrap();
/// assert_eq!(pattern.r#match("this is a test"), Some("test"));
/// ```
#[derive(Debug, Clone)]
pub struct Pattern<Kind> {
    // How the pattern was written (exact / glob / regex / prefix ...).
    kind: Kind,
    // The regular expression every kind is compiled to; all matching goes through it.
    regex: Regex,
    // The original pattern text, used by `as_str`, `Display`, serialization and the
    // equality/ordering impls.
    pattern: Arc<str>,
}
287
/// General-purpose pattern (exact / glob / regex), see [`NormalPatternKind`].
pub type NormalPattern = Pattern<NormalPatternKind>;

/// Location pattern in the style of Nginx `location` modifiers, see [`LocationPatternKind`].
pub type LocationPattern = Pattern<LocationPatternKind>;

/// Pattern over client names.
///
/// Parsed with the same syntax as [`NormalPattern`], but matched against the name with
/// `DhttpName::SUFFIX` removed; names that do not end with that suffix never match.
pub type ClientNamePattern = Pattern<ClientNamePatternKind>;

/// Pattern over domain names.
///
/// Parsed with the same syntax as [`NormalPattern`], but matched against the name with
/// `DhttpName::SUFFIX` removed; names that do not end with that suffix never match.
pub type DomainPattern = Pattern<DomainPatternKind>;
297
298impl<Kind> Pattern<Kind> {
299    /// 创建新的模式实例
300    ///
301    /// # 示例
302    ///
303    /// ```
304    /// use dhttp_access::pattern::NormalPattern;
305    ///
306    /// let pattern = NormalPattern::new("*.txt").unwrap();
307    /// assert!(pattern.is_match("file.txt"));
308    /// ```
309    #[inline]
310    pub fn new(pattern: impl AsRef<str>) -> Result<Self, <Self as FromStr>::Err>
311    where
312        Self: FromStr,
313    {
314        pattern.as_ref().parse()
315    }
316
317    /// 获取原始模式字符串
318    #[inline]
319    pub fn as_str(&self) -> &str {
320        &self.pattern
321    }
322
323    /// 获取模式类型
324    #[inline]
325    pub const fn kind(&self) -> &Kind {
326        &self.kind
327    }
328}
329
330impl Pattern<NormalPatternKind> {
331    /// 测试字符串是否匹配模式
332    #[inline]
333    pub fn is_match(&self, s: &str) -> bool {
334        self.regex.is_match(s)
335    }
336
337    /// 获取匹配的子字符串
338    #[inline]
339    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
340        self.regex.find(s).map(|m| &s[m.range()])
341    }
342}
343
344impl Pattern<LocationPatternKind> {
345    /// 测试字符串是否匹配模式
346    #[inline]
347    pub fn is_match(&self, s: &str) -> bool {
348        self.regex.is_match(s)
349    }
350
351    /// 获取匹配的子字符串
352    #[inline]
353    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
354        self.regex.find(s).map(|m| &s[m.range()])
355    }
356}
357
/// Removes `suffix` from the end of `s` exactly once.
///
/// Returns the remaining prefix when `s` ends with `suffix`, and `None` otherwise.
/// An empty `suffix` always matches and yields `s` unchanged.
///
/// # Examples
///
/// ```
/// use dhttp_access::pattern::trim_suffix_once;
///
/// assert_eq!(trim_suffix_once("a.b.net", ".net"), Some("a.b"));
/// assert_eq!(trim_suffix_once("a.net.b", ".net"), None);
/// ```
#[inline]
pub fn trim_suffix_once<'s>(s: &'s str, suffix: &str) -> Option<&'s str> {
    // `strip_suffix` performs exactly the "ends with, then slice off once" check that was
    // previously hand-rolled with `rfind` plus a length comparison.
    s.strip_suffix(suffix)
}
366
367impl Pattern<ClientNamePatternKind> {
368    /// 测试字符串是否匹配模式
369    #[inline]
370    pub fn is_match(&self, s: &str) -> bool {
371        trim_suffix_once(s, DhttpName::SUFFIX).is_some_and(|s| self.regex.is_match(s))
372    }
373
374    /// 获取匹配的子字符串
375    #[inline]
376    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
377        trim_suffix_once(s, DhttpName::SUFFIX)
378            .and_then(|s| self.regex.find(s).map(|m| &s[m.range()]))
379    }
380}
381
382impl Pattern<DomainPatternKind> {
383    /// 测试字符串是否匹配模式
384    #[inline]
385    pub fn is_match(&self, s: &str) -> bool {
386        trim_suffix_once(s, DhttpName::SUFFIX).is_some_and(|s| self.regex.is_match(s))
387    }
388
389    /// 获取匹配的子字符串
390    #[inline]
391    pub fn r#match<'s>(&self, s: &'s str) -> Option<&'s str> {
392        trim_suffix_once(s, DhttpName::SUFFIX)
393            .and_then(|s| self.regex.find(s).map(|m| &s[m.range()]))
394    }
395}
396
// Expands a list of pseudo-`impl` items into real implementations for `Pattern<Kind>`.
//
// Every arm matches one item whose body is written literally as `{ ... }`, emits the
// corresponding implementation, then recurses on the remaining tokens (`$($tt)*`).
// The empty arm ends the recursion.
macro_rules! impl_pattern {
    // `Evaluable<&str>`: evaluating a pattern against a string is the kind's `is_match`.
    (impl Evaluable<&str> for Pattern<$kind:ident> { ... } $($tt:tt)*) => {
        impl Evaluable<&str> for Pattern<$kind> {
            type Value = bool;

            fn eval(&self, argument: &&str) -> Self::Value {
                self.is_match(argument)
            }
        }
        impl_pattern!($($tt)*);
    };
    // Public `priority()`, forwarded to the kind's private `priority()`.
    (impl Pattern<$kind:ident> { pub const fn priority(&self) -> usize { ... } } $($tt:tt)*) => {
        impl Pattern<$kind> {
            /// Returns the pattern priority; a smaller value means a higher priority.
            #[inline]
            pub const fn priority(&self) -> usize {
                self.kind.priority()
            }
        }
        impl_pattern!($($tt)*);
    };
    // `From` between patterns whose kinds convert with `Into`; the compiled regex and the
    // source text are moved over unchanged.
    (impl From<Pattern<$from:ident>> for Pattern<$into:ident> { ... } $($tt:tt)*) => {
        impl From<Pattern<$from>> for Pattern<$into> {
            fn from(value: Pattern<$from>) -> Self {
                Self {
                    kind: value.kind.into(),
                    regex: value.regex,
                    pattern: value.pattern,
                }
            }
        }
        impl_pattern!($($tt)*);
    };
    // `FromStr` that parses as `Pattern<$from>` and converts the result with `Into`,
    // reusing `$from`'s syntax and error type.
    (impl FromStr for Pattern<$into:ident> from Pattern<$from:ident> { ... } $($tt:tt)*) => {
        impl FromStr for Pattern<$into> {
            type Err = <Pattern<$from> as FromStr>::Err;

            #[inline]
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                <Pattern<$from>>::from_str(s).map(Into::into)
            }
        }
        impl_pattern!($($tt)*);
    };
    // ORM support via the crate's `orm_new_type!` macro in its `@json` mode (its expansion
    // is not visible here). NOTE(review): the alias presumably exists because
    // `orm_new_type!` expects a single identifier rather than a generic path; the
    // anonymous `const _` scope keeps that alias out of this module. Confirm.
    (impl Orm for Pattern<$kind:ident> from json { ... } $($tt:tt)*) => {
        const _: () = {
            type __PatternType = Pattern<$kind>;
            crate::orm_new_type!(@json __PatternType);
        };
        impl_pattern!($($tt)*);
    };
    // No items left: stop recursing.
    () => {}

}
451
// One group of items per pattern alias. Client-name and domain patterns have no parser
// of their own: they parse as `Pattern<NormalPatternKind>` and convert through `From`.
impl_pattern! {
    // NormalPattern
    impl Evaluable<&str> for Pattern<NormalPatternKind> { ... }
    impl Pattern<NormalPatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl Orm for Pattern<NormalPatternKind> from json { ... }

    // LocationPattern
    impl Evaluable<&str> for Pattern<LocationPatternKind> { ... }
    impl Pattern<LocationPatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl Orm for Pattern<LocationPatternKind> from json { ... }

    // ClientNamePattern (parsed as a NormalPattern)
    impl Evaluable<&str> for Pattern<ClientNamePatternKind> { ... }
    impl Pattern<ClientNamePatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl From<Pattern<NormalPatternKind>> for Pattern<ClientNamePatternKind> { ... }
    impl FromStr for Pattern<ClientNamePatternKind> from Pattern<NormalPatternKind> { ... }
    impl Orm for Pattern<ClientNamePatternKind> from json { ... }

    // DomainPattern (parsed as a NormalPattern)
    impl Evaluable<&str> for Pattern<DomainPatternKind> { ... }
    impl Pattern<DomainPatternKind> { pub const fn priority(&self) -> usize { ... } }
    impl From<Pattern<NormalPatternKind>> for Pattern<DomainPatternKind> { ... }
    impl FromStr for Pattern<DomainPatternKind> from Pattern<NormalPatternKind> { ... }
    impl Orm for Pattern<DomainPatternKind> from json { ... }
}
473
474/// 共同的正则表达式构建工具
475mod regex_utils {
476    use super::*;
477
478    /// 创建不区分大小写的正则表达式
479    pub(super) fn case_insensitive_regex(pat: &str) -> Result<Regex, regex::Error> {
480        RegexBuilder::new(pat).case_insensitive(true).build()
481    }
482
483    /// 将 Glob 模式转换为支持非 UTF-8 字符串的正则表达式
484    ///
485    /// 这是处理 Glob 模式的核心函数,配置了特殊的正则表达式设置:
486    /// - utf8(false): 支持非 UTF-8 字节序列匹配
487    /// - dot_matches_new_line(true): 允许 . 匹配换行符
488    /// - 设置了合理的内存限制防止 DoS 攻击
489    pub(super) fn glob_to_regex(glob: &globset::Glob) -> Result<Regex, regex::Error> {
490        glob.regex()
491            .strip_prefix("(?-u)")
492            .unwrap_or(glob.regex())
493            .parse()
494    }
495}
496
497mod parse_pattern {
498    use globset::{Glob, GlobBuilder};
499
500    use super::{regex_utils, *};
501
502    /// 普通模式解析错误
503    ///
504    /// # 示例
505    ///
506    /// ```
507    /// use dhttp_access::pattern::{NormalPattern, ParsePatternError};
508    ///
509    /// // 无效的正则表达式
510    /// let result: Result<NormalPattern, _> = "~ [".parse();
511    /// assert!(matches!(result, Err(ParsePatternError::InvalidRegex { .. })));
512    ///
513    /// // 注意:大多数 glob 模式实际上是有效的,这里只是示例
514    /// // 实际的 InvalidGlob 错误比较难构造,通常发生在内部处理时
515    /// ```
516    #[derive(snafu::Snafu, Debug, Clone)]
517    pub enum ParsePatternError {
518        /// 无效的正则表达式
519        ///
520        /// # Examples
521        ///
522        /// ```
523        /// use dhttp_access::pattern::{NormalPattern, ParsePatternError};
524        ///
525        /// let result: Result<NormalPattern, _> = "~ [invalid".parse();
526        /// assert!(matches!(result, Err(ParsePatternError::InvalidRegex { .. })));
527        ///
528        /// let result: Result<NormalPattern, _> = "~* (?P<invalid".parse();
529        /// assert!(matches!(result, Err(ParsePatternError::InvalidRegex { .. })));
530        /// ```
531        #[snafu(display("invalid regex pattern `{pattern}`"))]
532        InvalidRegex {
533            pattern: Arc<str>,
534            source: RegexError,
535        },
536
537        /// 无效的 Glob 模式
538        ///
539        /// # Examples
540        ///
541        /// ```
542        /// use dhttp_access::pattern::{NormalPattern, ParsePatternError};
543        ///
544        /// // 注意:实际上多数 glob 模式是有效的,这里用简化的示例
545        /// let result = NormalPattern::new("***/invalid");
546        /// // 由于这个例子可能不会失败,我们使用 expect 来说明预期的错误类型
547        /// // assert!(matches!(result, Err(ParsePatternError::InvalidGlob { .. })));
548        /// ```
549        #[snafu(display("invalid glob pattern"))]
550        InvalidGlob { source: globset::Error },
551    }
552
553    impl FromStr for Pattern<NormalPatternKind> {
554        type Err = ParsePatternError;
555
556        fn from_str(pattern: &str) -> Result<Self, Self::Err> {
557            let pattern: Arc<str> = Arc::from(pattern);
558            let (kind, regex) = match pattern.split_once(' ') {
559                Some(("=", pat)) => (
560                    NormalPatternKind::Exact,
561                    Regex::new(&format!("^{}$", regex::escape(pat)))
562                        .context(InvalidRegexSnafu { pattern: pat })?,
563                ),
564                Some(("*", pattern)) => {
565                    let glob = GlobBuilder::new(pattern)
566                        .case_insensitive(true)
567                        .build()
568                        .context(InvalidGlobSnafu)?;
569                    (
570                        NormalPatternKind::Glob,
571                        regex_utils::glob_to_regex(&glob).context(InvalidRegexSnafu { pattern })?,
572                    )
573                }
574                Some(("~", pattern)) => (
575                    NormalPatternKind::Regex,
576                    Regex::new(pattern).context(InvalidRegexSnafu { pattern })?,
577                ),
578                Some(("~*", pattern)) => (
579                    NormalPatternKind::Regex,
580                    regex_utils::case_insensitive_regex(pattern)
581                        .context(InvalidRegexSnafu { pattern })?,
582                ),
583                _ => {
584                    // 对于默认的 Glob 模式
585                    let glob = Glob::new(&pattern).context(InvalidGlobSnafu)?;
586                    (
587                        NormalPatternKind::Glob,
588                        regex_utils::glob_to_regex(&glob).context(InvalidRegexSnafu {
589                            pattern: pattern.clone(),
590                        })?,
591                    )
592                }
593            };
594            Ok(Self {
595                kind,
596                regex,
597                pattern,
598            })
599        }
600    }
601}
602
603pub use parse_pattern::ParsePatternError;
604
605mod parse_location_pattern {
606    use super::{regex_utils, *};
607
608    /// 位置模式解析错误
609    ///
610    /// # 示例
611    ///
612    /// ```
613    /// use dhttp_access::pattern::{LocationPattern, ParseLocationPatternError};
614    ///
615    /// // 未知符号
616    /// let result: Result<LocationPattern, _> = "@ /invalid".parse();
617    /// assert!(matches!(result, Err(ParseLocationPatternError::UnknownSymbol { .. })));
618    ///
619    /// // 无效的正则表达式
620    /// let result: Result<LocationPattern, _> = "~ [".parse();
621    /// assert!(matches!(result, Err(ParseLocationPatternError::InvalidRegex { .. })));
622    ///
623    /// // 未定义的前缀或通用模式
624    /// let result: Result<LocationPattern, _> = "invalid".parse();
625    /// assert!(matches!(result, Err(ParseLocationPatternError::UndefinedPrefixOrCommon { .. })));
626    /// ```
627    #[derive(snafu::Snafu, Debug, Clone)]
628    pub enum ParseLocationPatternError {
629        /// 未知的符号
630        ///
631        /// # Examples
632        ///
633        /// ```
634        /// use dhttp_access::pattern::{LocationPattern, ParseLocationPatternError};
635        ///
636        /// let result: Result<LocationPattern, _> = "@ /invalid".parse();
637        /// assert!(matches!(result, Err(ParseLocationPatternError::UnknownSymbol { .. })));
638        ///
639        /// let result: Result<LocationPattern, _> = "! /bad".parse();
640        /// assert!(matches!(result, Err(ParseLocationPatternError::UnknownSymbol { .. })));
641        /// ```
642        #[snafu(display("unknown symbol `{symbol}`, expected one of {expect:?}"))]
643        UnknownSymbol {
644            symbol: String,
645            expect: &'static [&'static str],
646        },
647
648        /// 无效的正则表达式
649        ///
650        /// # Examples
651        ///
652        /// ```
653        /// use dhttp_access::pattern::{LocationPattern, ParseLocationPatternError};
654        ///
655        /// let result: Result<LocationPattern, _> = "~ [invalid".parse();
656        /// assert!(matches!(result, Err(ParseLocationPatternError::InvalidRegex { .. })));
657        ///
658        /// let result: Result<LocationPattern, _> = "~* (?P<bad".parse();
659        /// assert!(matches!(result, Err(ParseLocationPatternError::InvalidRegex { .. })));
660        /// ```
661        #[snafu(display("invalid regex pattern `{pattern}`"))]
662        InvalidRegex {
663            pattern: Arc<str>,
664            source: RegexError,
665        },
666
667        /// 未定义的前缀或通用模式
668        ///
669        /// # Examples
670        ///
671        /// ```
672        /// use dhttp_access::pattern::{LocationPattern, ParseLocationPatternError};
673        ///
674        /// let result: Result<LocationPattern, _> = "invalid".parse();
675        /// assert!(matches!(result, Err(ParseLocationPatternError::UndefinedPrefixOrCommon { .. })));
676        ///
677        /// let result: Result<LocationPattern, _> = "not_starting_with_slash".parse();
678        /// assert!(matches!(result, Err(ParseLocationPatternError::UndefinedPrefixOrCommon { .. })));
679        /// ```
680        #[snafu(display("expected common pattern or normal prefix starting with `{prefix}`"))]
681        UndefinedPrefixOrCommon { prefix: &'static str },
682    }
683
684    impl FromStr for Pattern<LocationPatternKind> {
685        type Err = ParseLocationPatternError;
686
687        fn from_str(pattern: &str) -> Result<Self, Self::Err> {
688            let pattern: Arc<str> = Arc::from(pattern);
689            let (kind, regex) = match pattern.split_once(' ') {
690                None if pattern.as_ref() == "/" => (
691                    LocationPatternKind::Common,
692                    Regex::new(r"^/").context(InvalidRegexSnafu {
693                        pattern: pattern.clone(),
694                    })?,
695                ),
696                None if pattern.starts_with("/") => (
697                    LocationPatternKind::NormalPrefix,
698                    Regex::new(format!("^{}", regex::escape(&pattern)).as_str()).context(
699                        InvalidRegexSnafu {
700                            pattern: pattern.clone(),
701                        },
702                    )?,
703                ),
704                None => return UndefinedPrefixOrCommonSnafu { prefix: "/" }.fail(),
705                Some(("=", pattern)) => (
706                    LocationPatternKind::Exact,
707                    Regex::new(&format!("^{}$", regex::escape(pattern)))
708                        .context(InvalidRegexSnafu { pattern })?,
709                ),
710                Some(("^~", pattern)) => (
711                    LocationPatternKind::Prefix,
712                    Regex::new(format!("^{}", regex::escape(pattern)).as_str())
713                        .context(InvalidRegexSnafu { pattern })?,
714                ),
715                Some(("~", pattern)) => (
716                    LocationPatternKind::Regex,
717                    Regex::new(pattern).context(InvalidRegexSnafu { pattern })?,
718                ),
719                Some(("~*", pattern)) => (
720                    LocationPatternKind::Regex,
721                    regex_utils::case_insensitive_regex(pattern)
722                        .context(InvalidRegexSnafu { pattern })?,
723                ),
724                Some((symbol, ..)) => {
725                    return UnknownSymbolSnafu::fail(UnknownSymbolSnafu {
726                        symbol: symbol.to_string(),
727                        expect: &["=", "^~", "~", "~*"] as &'static [&'static str],
728                    });
729                }
730            };
731            Ok(Self {
732                kind,
733                regex,
734                pattern,
735            })
736        }
737    }
738}
739
740pub use parse_location_pattern::ParseLocationPatternError;
741
742impl<Kind> Display for Pattern<Kind> {
743    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
744        write!(f, "{}", self.as_str())
745    }
746}
747
748impl<Kind> Serialize for Pattern<Kind> {
749    #[inline]
750    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
751    where
752        S: serde::Serializer,
753    {
754        self.as_str().serialize(serializer)
755    }
756}
757
758impl<'de, Kind> Deserialize<'de> for Pattern<Kind>
759where
760    Self: FromStr<Err: Display>,
761{
762    #[inline]
763    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
764    where
765        D: serde::Deserializer<'de>,
766    {
767        let s = String::deserialize(deserializer)?;
768        s.parse().map_err(serde::de::Error::custom)
769    }
770}
771
772impl<Kind: PartialEq> PartialEq for Pattern<Kind> {
773    fn eq(&self, other: &Self) -> bool {
774        self.kind == other.kind && self.pattern == other.pattern
775    }
776}
777
778impl<Kind: Eq> Eq for Pattern<Kind> {}
779
780impl<Kind: PartialOrd> PartialOrd for Pattern<Kind> {
781    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
782        self.kind
783            .partial_cmp(&other.kind)
784            .map(|ord| ord.then_with(|| self.pattern.cmp(&other.pattern)))
785    }
786}
787
788impl<Kind: Ord> Ord for Pattern<Kind> {
789    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
790        self.kind
791            .cmp(&other.kind)
792            .then_with(|| self.pattern.cmp(&other.pattern))
793    }
794}
795
#[cfg(test)]
mod dhttp_suffix_tests {
    use super::*;

    /// A client-name regex is applied to the name with `DhttpName::SUFFIX` removed,
    /// so an anchored regex matches only names that carry that suffix.
    #[test]
    fn client_name_pattern_uses_dhttp_name_suffix() {
        let pattern: ClientNamePattern = "~ ^reimu\\.pilot$"
            .parse()
            .expect("valid client name pattern");

        let with_dhttp_suffix = "reimu.pilot.dhttp.net";
        let with_other_suffix = "reimu.pilot.genmeta.net";
        assert!(pattern.is_match(with_dhttp_suffix));
        assert!(!pattern.is_match(with_other_suffix));
    }
}