kubizone_common/
pattern.rs

1use std::fmt::{Display, Write};
2
3use schemars::JsonSchema;
4use serde::{de::Error, Deserialize, Serialize};
5use thiserror::Error;
6
7use crate::{segment::DomainSegment, FullyQualifiedDomainName};
8
/// Error type for failures that concern a [`Pattern`] as a whole.
///
/// The enum currently has no variants, so a value of this type can never be
/// constructed. Parsing a [`Pattern`] reports per-segment problems through
/// [`PatternSegmentError`] instead.
#[derive(Error, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum PatternError {}
11
/// A domain-name pattern: an ordered list of [`PatternSegment`]s.
///
/// Segments are stored in the order they are written, i.e. the leftmost
/// label comes first. The default value is the empty pattern, which is what
/// [`Pattern::origin`] returns.
#[derive(Default, Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Pattern(Vec<PatternSegment>);
14
15impl Pattern {
16    /// Returns a pattern that only matches the origin of the parent
17    /// FQDN.
18    pub fn origin() -> Self {
19        Pattern::default()
20    }
21
22    /// Iterates over the [`PatternSegment`]s of the pattern.
23    pub fn iter(&self) -> impl Iterator<Item = &PatternSegment> + '_ {
24        self.0.iter()
25    }
26
27    /// Returns a new pattern with the origin appended.
28    pub fn with_origin(&self, origin: &FullyQualifiedDomainName) -> Pattern {
29        let mut cloned = self.clone();
30        cloned.0.extend(origin.iter().map(PatternSegment::from));
31        cloned
32    }
33
34    /// Returns true if the papttern matches the given domain.
35    pub fn matches(&self, domain: &FullyQualifiedDomainName) -> bool {
36        let domain_segments = domain.as_ref().iter().rev();
37        let pattern_segments = self.0[..].iter().rev();
38
39        if domain_segments.len() < pattern_segments.len() {
40            // Patterns longer than the domain segment cannot possibly match.
41            return false;
42        }
43
44        if domain_segments.len() > pattern_segments.len()
45            // Domains longer than patterns can never match, unless the first
46            // segment of the pattern is a standalone wildcard (*)
47            && !self.0.first().is_some_and(|pattern| pattern.as_ref() == "*")
48        {
49            return false;
50        }
51
52        for (pattern, domain) in pattern_segments.zip(domain_segments) {
53            // If we have hit a pattern segment containing only a wildcard, the rest of the
54            // domain segments are automatically matched.
55            if pattern.as_ref() == "*" {
56                return true;
57            }
58
59            if !pattern.matches(domain) {
60                return false;
61            }
62        }
63
64        true
65    }
66}
67
68impl FromIterator<PatternSegment> for Pattern {
69    fn from_iter<T: IntoIterator<Item = PatternSegment>>(iter: T) -> Self {
70        Pattern(iter.into_iter().collect())
71    }
72}
73
74impl TryFrom<&str> for Pattern {
75    type Error = PatternSegmentError;
76
77    fn try_from(value: &str) -> Result<Self, Self::Error> {
78        let segments = Result::from_iter(
79            value
80                .trim_end_matches('.')
81                .split('.')
82                .map(PatternSegment::try_from),
83        )?;
84        Ok(Pattern(segments))
85    }
86}
87
88impl TryFrom<String> for Pattern {
89    type Error = PatternSegmentError;
90
91    fn try_from(value: String) -> Result<Self, Self::Error> {
92        Self::try_from(value.as_ref())
93    }
94}
95
96impl Display for Pattern {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        for segment in &self.0 {
99            write!(f, "{}", segment)?;
100            f.write_char('.')?;
101        }
102
103        Ok(())
104    }
105}
106
107impl JsonSchema for Pattern {
108    fn schema_name() -> String {
109        <String as schemars::JsonSchema>::schema_name()
110    }
111
112    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
113        <String as schemars::JsonSchema>::json_schema(gen)
114    }
115}
116
117impl<'de> Deserialize<'de> for Pattern {
118    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
119    where
120        D: serde::Deserializer<'de>,
121    {
122        let value = String::deserialize(deserializer)?;
123
124        Self::try_from(value).map_err(D::Error::custom)
125    }
126}
127
128impl Serialize for Pattern {
129    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
130    where
131        S: serde::Serializer,
132    {
133        self.to_string().serialize(serializer)
134    }
135}
136
/// Segment of a pattern.
///
/// Used for matching against a single [`DomainSegment`].
///
/// Segments built through `TryFrom<&str>` are lower-cased and contain at
/// most one wildcard (`*`) character.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PatternSegment(String);
142
143impl PatternSegment {
144    /// Returns true if the pattern segment matches the provided domain segment.
145    pub fn matches(&self, domain_segment: &DomainSegment) -> bool {
146        if self.0 == domain_segment.as_ref() {
147            return true;
148        }
149
150        if let Some((head, tail)) = self.0.split_once('*') {
151            return domain_segment.as_ref().starts_with(head)
152                && domain_segment.as_ref().ends_with(tail);
153        }
154
155        false
156    }
157
158    // Segments cannot be empty.
159    #[allow(clippy::len_without_is_empty)]
160    pub fn len(&self) -> usize {
161        self.0.len()
162    }
163}
164
/// Produced when attempting to construct a [`PatternSegment`]
/// from an invalid string.
#[derive(Error, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum PatternSegmentError {
    /// Domain name segments (and therefore pattern segments)
    /// can contain hyphens, but crucially:
    ///
    /// * Not at the beginning of a segment.
    /// * Not at the end of a segment.
    /// * Not at the 3rd and 4th position *simultaneously* (used for [Punycode encoding](https://en.wikipedia.org/wiki/Punycode))
    ///
    /// The payload is the 1-based position of the offending hyphen.
    #[error("illegal hyphen at position {0}")]
    IllegalHyphen(usize),
    /// Segment contains an invalid character; the payload is the first
    /// such character found.
    #[error("invalid character {0}")]
    InvalidCharacter(char),
    /// Pattern segment is longer than the permitted 63 characters; the
    /// payload is the actual length of the rejected segment.
    #[error("pattern too long {0} > 63")]
    TooLong(usize),
    /// Pattern segment is empty.
    #[error("pattern is an empty string")]
    EmptyString,
    /// Pattern segment contains more than one wildcard (*) character.
    #[error("patterns can only have one wildcard")]
    MultipleWildcards,
}
190
/// Characters a pattern segment may consist of: underscore, hyphen, ASCII
/// digits, lower-case ASCII letters and the wildcard (`*`). Input is
/// lower-cased before it is checked against this set, so upper-case ASCII
/// letters are accepted as well.
const VALID_CHARACTERS: &str = "_-0123456789abcdefghijklmnopqrstuvwxyz*";
192
193impl TryFrom<&str> for PatternSegment {
194    type Error = PatternSegmentError;
195
196    fn try_from(value: &str) -> Result<Self, Self::Error> {
197        let value = value.to_ascii_lowercase();
198
199        if value.is_empty() {
200            return Err(PatternSegmentError::EmptyString);
201        }
202
203        if value.len() > 63 {
204            return Err(PatternSegmentError::TooLong(value.len()));
205        }
206
207        if let Some(character) = value.chars().find(|c| !VALID_CHARACTERS.contains(*c)) {
208            return Err(PatternSegmentError::InvalidCharacter(character));
209        }
210
211        if value.starts_with('-') {
212            return Err(PatternSegmentError::IllegalHyphen(1));
213        }
214
215        if value.ends_with('-') {
216            return Err(PatternSegmentError::IllegalHyphen(value.len()));
217        }
218
219        if value.get(2..4) == Some("--") {
220            return Err(PatternSegmentError::IllegalHyphen(3));
221        }
222
223        if value.chars().filter(|c| *c == '*').count() > 1 {
224            return Err(PatternSegmentError::MultipleWildcards);
225        }
226
227        Ok(PatternSegment(value))
228    }
229}
230
231impl From<DomainSegment> for PatternSegment {
232    fn from(value: DomainSegment) -> Self {
233        PatternSegment(value.to_string())
234    }
235}
236
237impl From<&DomainSegment> for PatternSegment {
238    fn from(value: &DomainSegment) -> Self {
239        PatternSegment(value.to_string())
240    }
241}
242
243impl TryFrom<String> for PatternSegment {
244    type Error = PatternSegmentError;
245
246    fn try_from(value: String) -> Result<Self, Self::Error> {
247        Self::try_from(value.as_str())
248    }
249}
250
251impl Display for PatternSegment {
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        f.write_str(&self.0)
254    }
255}
256
257impl AsRef<str> for PatternSegment {
258    fn as_ref(&self) -> &str {
259        self.0.as_str()
260    }
261}
262
#[cfg(test)]
mod tests {
    use crate::{
        error::PatternSegmentError, pattern::PatternSegment, segment::DomainSegment,
        FullyQualifiedDomainName, Pattern,
    };

    /// Parses a fully qualified domain name, panicking on invalid input.
    fn parse_fqdn(raw: &str) -> FullyQualifiedDomainName {
        FullyQualifiedDomainName::try_from(raw).unwrap()
    }

    /// Parses a pattern, panicking on invalid input.
    fn parse_pattern(raw: &str) -> Pattern {
        Pattern::try_from(raw).unwrap()
    }

    /// Parses both arguments as single segments and reports whether the
    /// pattern segment matches the domain segment.
    fn segment_matches(pattern: &str, domain: &str) -> bool {
        let pattern = PatternSegment::try_from(pattern).unwrap();
        let domain = DomainSegment::try_from(domain).unwrap();
        pattern.matches(&domain)
    }

    #[test]
    fn literal_matches() {
        assert!(segment_matches("example", "example"));
    }

    #[test]
    fn wildcard() {
        assert!(segment_matches("*", "example"));
    }

    #[test]
    fn leading_wildcard() {
        assert!(segment_matches("*ample", "example"));
    }

    #[test]
    fn trailing_wildcard() {
        assert!(segment_matches("examp*", "example"));
    }

    #[test]
    fn splitting_wildcard() {
        assert!(segment_matches("ex*le", "example"));
    }

    #[test]
    fn multiple_wildcards() {
        let parsed = PatternSegment::try_from("*amp*");
        assert_eq!(parsed, Err(PatternSegmentError::MultipleWildcards));
    }

    #[test]
    fn simple_pattern_match() {
        let pattern = parse_pattern("*.example.org");
        assert!(pattern.matches(&parse_fqdn("www.example.org.")));
    }

    #[test]
    fn longer_pattern_than_domain() {
        let pattern = parse_pattern("*.*.example.org");
        assert!(!pattern.matches(&parse_fqdn("www.example.org.")));
    }

    #[test]
    fn longer_domain_than_pattern() {
        let pattern = parse_pattern("*.example.org");
        assert!(pattern.matches(&parse_fqdn("www.sub.test.dev.example.org.")));
    }

    #[test]
    fn wildcard_segments() {
        let pattern = parse_pattern("dev*.example.org");

        for accepted in [
            "dev.example.org.",
            "dev-1.example.org.",
            "dev-hello.example.org.",
        ] {
            assert!(pattern.matches(&parse_fqdn(accepted)));
        }

        for rejected in ["de.example.org.", "www.dev-1.example.org."] {
            assert!(!pattern.matches(&parse_fqdn(rejected)));
        }
    }

    #[test]
    fn patterns_assumed_wildcard() {
        let with_trailing_dot = parse_pattern("example.org.");
        let without_trailing_dot = parse_pattern("example.org");
        assert_eq!(with_trailing_dot, without_trailing_dot);

        let domain = parse_fqdn("example.org.");
        assert_eq!(
            with_trailing_dot.matches(&domain),
            without_trailing_dot.matches(&domain)
        );
    }

    #[test]
    fn origin_insertion() {
        let pattern = parse_pattern("example");
        let domain = parse_fqdn("example.org.");

        // Without an origin the single-segment pattern is too short.
        assert!(!pattern.matches(&domain));

        let anchored = pattern.with_origin(&parse_fqdn("org."));
        assert!(anchored.matches(&parse_fqdn("example.org.")));
    }
}