rama_net/address/
domain.rs

1use super::Host;
2use rama_core::error::{ErrorContext, OpaqueError};
3use smol_str::SmolStr;
4use std::{cmp::Ordering, fmt, iter::repeat};
5
6/// A domain.
7///
8/// # Remarks
9///
10/// The validation of domains created by this type is very shallow.
11/// Proper validation is offloaded to other services such as DNS resolvers.
12#[derive(Debug, Clone)]
13pub struct Domain(SmolStr);
14
15impl Domain {
16    /// Creates a domain at compile time.
17    ///
18    /// This function requires the static string to be a valid domain
19    ///
20    /// # Panics
21    ///
22    /// This function panics at **compile time** when the static string is not a valid domain.
23    pub const fn from_static(s: &'static str) -> Self {
24        if !is_valid_name(s.as_bytes()) {
25            panic!("static str is an invalid domain");
26        }
27        Self(SmolStr::new_static(s))
28    }
29
30    /// Creates the example [`Domain].
31    pub fn example() -> Self {
32        Self::from_static("example.com")
33    }
34
35    /// Create an new apex [`Domain`] (TLD) meant for loopback purposes.
36    ///
37    /// As proposed in
38    /// <https://itp.cdn.icann.org/en/files/security-and-stability-advisory-committee-ssac-reports/sac-113-en.pdf>.
39    ///
40    /// In specific this means that it will match on any domain with the TLD `.internal`.
41    pub fn tld_private() -> Self {
42        Self::from_static("internal")
43    }
44
45    /// Creates the localhost [`Domain`].
46    pub fn tld_localhost() -> Self {
47        Self::from_static("localhost")
48    }
49
50    /// Consumes the domain as a host.
51    pub fn into_host(self) -> Host {
52        Host::Name(self)
53    }
54
55    /// Returns `true` if this domain is a Fully Qualified Domain Name.
56    pub fn is_fqdn(&self) -> bool {
57        self.0.ends_with('.')
58    }
59
60    /// Returns `true` if this [`Domain`] is a parent of the other.
61    ///
62    /// Note that a [`Domain`] is a sub of itself.
63    pub fn is_sub_of(&self, other: &Domain) -> bool {
64        let a = self.as_ref().trim_matches('.');
65        let b = other.as_ref().trim_matches('.');
66        match a.len().cmp(&b.len()) {
67            Ordering::Equal => a.eq_ignore_ascii_case(b),
68            Ordering::Greater => {
69                let n = a.len() - b.len();
70                let dot_char = a.chars().nth(n - 1);
71                let host_parent = &a[n..];
72                dot_char == Some('.') && b.eq_ignore_ascii_case(host_parent)
73            }
74            Ordering::Less => false,
75        }
76    }
77
78    #[inline]
79    /// Returns `true` if this [`Domain`] is a subdomain of the other.
80    ///
81    /// Note that a [`Domain`] is a sub of itself.
82    pub fn is_parent_of(&self, other: &Domain) -> bool {
83        other.is_sub_of(self)
84    }
85
86    /// Compare the registrable domain
87    ///
88    /// # Example
89    ///
90    /// ```
91    /// use rama_net::address::Domain;
92    ///
93    /// assert!(Domain::from_static("www.example.com")
94    ///     .have_same_registrable_domain(&Domain::from_static("example.com")));
95    ///
96    /// assert!(Domain::from_static("example.com")
97    ///     .have_same_registrable_domain(&Domain::from_static("www.example.com")));
98    ///
99    /// assert!(Domain::from_static("a.example.com")
100    ///     .have_same_registrable_domain(&Domain::from_static("b.example.com")));
101    ///
102    /// assert!(Domain::from_static("example.com")
103    ///     .have_same_registrable_domain(&Domain::from_static("example.com")));
104    /// ```
105    pub fn have_same_registrable_domain(&self, other: &Domain) -> bool {
106        let this_rd = psl::domain_str(self.as_str());
107        let other_rd = psl::domain_str(other.as_str());
108        this_rd == other_rd
109    }
110
111    /// Get the public suffix of the domain
112    ///
113    /// # Example
114    ///
115    /// ```
116    /// use rama_net::address::Domain;
117    ///
118    /// assert_eq!(Some("com"), Domain::from_static("www.example.com").suffix());
119    /// assert_eq!(Some("co.uk"), Domain::from_static("site.co.uk").suffix());
120    /// ```
121    pub fn suffix(&self) -> Option<&str> {
122        psl::suffix_str(self.as_str())
123    }
124
125    /// Gets the domain name as reference.
126    pub fn as_str(&self) -> &str {
127        self.as_ref()
128    }
129
130    /// Returns the domain name inner value.
131    ///
132    /// Should not be exposed in the public rama API.
133    pub(crate) fn into_inner(self) -> SmolStr {
134        self.0
135    }
136}
137
138impl std::hash::Hash for Domain {
139    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
140        let this = self.as_ref();
141        let this = this.strip_prefix('.').unwrap_or(this);
142        for b in this.bytes() {
143            let b = b.to_ascii_lowercase();
144            b.hash(state);
145        }
146    }
147}
148
149impl AsRef<str> for Domain {
150    fn as_ref(&self) -> &str {
151        self.0.as_str()
152    }
153}
154
155impl fmt::Display for Domain {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
157        self.0.fmt(f)
158    }
159}
160
161impl std::str::FromStr for Domain {
162    type Err = OpaqueError;
163
164    fn from_str(s: &str) -> Result<Self, Self::Err> {
165        Domain::try_from(s.to_owned())
166    }
167}
168
169impl TryFrom<String> for Domain {
170    type Error = OpaqueError;
171
172    fn try_from(name: String) -> Result<Self, Self::Error> {
173        if is_valid_name(name.as_bytes()) {
174            Ok(Self(SmolStr::new(name)))
175        } else {
176            Err(OpaqueError::from_display("invalid domain"))
177        }
178    }
179}
180
181impl<'a> TryFrom<&'a [u8]> for Domain {
182    type Error = OpaqueError;
183
184    fn try_from(name: &'a [u8]) -> Result<Self, Self::Error> {
185        if is_valid_name(name) {
186            Ok(Self(SmolStr::new(
187                std::str::from_utf8(name).context("convert domain bytes to utf-8 string")?,
188            )))
189        } else {
190            Err(OpaqueError::from_display("invalid domain"))
191        }
192    }
193}
194
195impl TryFrom<Vec<u8>> for Domain {
196    type Error = OpaqueError;
197
198    fn try_from(name: Vec<u8>) -> Result<Self, Self::Error> {
199        if is_valid_name(name.as_slice()) {
200            Ok(Self(SmolStr::new(
201                String::from_utf8(name).context("convert domain bytes to utf-8 string")?,
202            )))
203        } else {
204            Err(OpaqueError::from_display("invalid domain"))
205        }
206    }
207}
208
/// Orders two domain strings byte-wise, ignoring ASCII case and a single leading dot.
///
/// When one (normalized) side is a prefix of the other, the shorter one sorts first.
fn cmp_domain(a: impl AsRef<str>, b: impl AsRef<str>) -> Ordering {
    /// Lowercased bytes of `s` (minus one leading dot), followed by an endless run of
    /// `None` so that the shorter input is padded while zipping both sides together.
    fn padded(s: &str) -> impl Iterator<Item = Option<u8>> + '_ {
        s.strip_prefix('.')
            .unwrap_or(s)
            .bytes()
            .map(|byte| Some(byte.to_ascii_lowercase()))
            .chain(repeat(None))
    }

    for pair in padded(a.as_ref()).zip(padded(b.as_ref())) {
        match pair {
            (Some(x), Some(y)) if x == y => {}
            (Some(x), Some(y)) => return x.cmp(&y),
            (Some(_), None) => return Ordering::Greater,
            (None, Some(_)) => return Ordering::Less,
            (None, None) => return Ordering::Equal,
        }
    }
    unreachable!("both sides are padded with an endless run of `None`")
}
231
232impl PartialOrd<Domain> for Domain {
233    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
234        Some(self.cmp(other))
235    }
236}
237
238impl Ord for Domain {
239    fn cmp(&self, other: &Self) -> Ordering {
240        cmp_domain(self, other)
241    }
242}
243
244impl PartialOrd<str> for Domain {
245    fn partial_cmp(&self, other: &str) -> Option<Ordering> {
246        Some(cmp_domain(self, other))
247    }
248}
249
250impl PartialOrd<Domain> for str {
251    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
252        Some(cmp_domain(self, other))
253    }
254}
255
256impl PartialOrd<&str> for Domain {
257    fn partial_cmp(&self, other: &&str) -> Option<Ordering> {
258        Some(cmp_domain(self, other))
259    }
260}
261
262impl PartialOrd<Domain> for &str {
263    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
264        Some(cmp_domain(self, other))
265    }
266}
267
268impl PartialOrd<String> for Domain {
269    fn partial_cmp(&self, other: &String) -> Option<Ordering> {
270        Some(cmp_domain(self, other))
271    }
272}
273
274impl PartialOrd<Domain> for String {
275    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
276        Some(cmp_domain(self, other))
277    }
278}
279
/// Compares two domain strings for equality, ignoring ASCII case and a single
/// leading dot on either side. Trailing dots are significant.
fn partial_eq_domain(a: impl AsRef<str>, b: impl AsRef<str>) -> bool {
    fn without_leading_dot(s: &str) -> &str {
        s.strip_prefix('.').unwrap_or(s)
    }
    without_leading_dot(a.as_ref()).eq_ignore_ascii_case(without_leading_dot(b.as_ref()))
}
289
290impl PartialEq<Domain> for Domain {
291    fn eq(&self, other: &Domain) -> bool {
292        partial_eq_domain(self, other)
293    }
294}
295
296impl Eq for Domain {}
297
298impl PartialEq<str> for Domain {
299    fn eq(&self, other: &str) -> bool {
300        partial_eq_domain(self, other)
301    }
302}
303
304impl PartialEq<&str> for Domain {
305    fn eq(&self, other: &&str) -> bool {
306        partial_eq_domain(self, other)
307    }
308}
309
310impl PartialEq<Domain> for str {
311    fn eq(&self, other: &Domain) -> bool {
312        other == self
313    }
314}
315
316impl PartialEq<Domain> for &str {
317    fn eq(&self, other: &Domain) -> bool {
318        partial_eq_domain(self, other)
319    }
320}
321
322impl PartialEq<String> for Domain {
323    fn eq(&self, other: &String) -> bool {
324        partial_eq_domain(self, other)
325    }
326}
327
328impl PartialEq<Domain> for String {
329    fn eq(&self, other: &Domain) -> bool {
330        partial_eq_domain(self, other)
331    }
332}
333
334impl serde::Serialize for Domain {
335    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
336    where
337        S: serde::Serializer,
338    {
339        self.0.serialize(serializer)
340    }
341}
342
343impl<'de> serde::Deserialize<'de> for Domain {
344    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
345    where
346        D: serde::Deserializer<'de>,
347    {
348        let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
349        s.parse().map_err(serde::de::Error::custom)
350    }
351}
352
353impl Domain {
354    /// The maximum length of a domain label.
355    const MAX_LABEL_LEN: usize = 63;
356
357    /// The maximum length of a domain name.
358    const MAX_NAME_LEN: usize = 253;
359}
360
361const fn is_valid_label(name: &[u8], start: usize, stop: usize) -> bool {
362    if start >= stop
363        || stop - start > Domain::MAX_LABEL_LEN
364        || name[start] == b'-'
365        || start == stop
366        || name[stop - 1] == b'-'
367    {
368        false
369    } else {
370        let mut i = start;
371        while i < stop {
372            let c = name[i];
373            if !c.is_ascii_alphanumeric() && (c != b'-' || i == start) {
374                return false;
375            }
376            i += 1;
377        }
378        true
379    }
380}
381
382/// Checks if the domain name is valid.
383const fn is_valid_name(name: &[u8]) -> bool {
384    if name.is_empty() || name.len() > Domain::MAX_NAME_LEN {
385        false
386    } else {
387        let mut non_empty_groups = 0;
388        let mut i = 0;
389        let mut offset = 0;
390        while i < name.len() {
391            let c = name[i];
392            if c == b'.' {
393                if offset == i {
394                    // empty
395                    if i == 0 || i == name.len() - 1 {
396                        i += 1;
397                        offset = i + 1;
398                        continue;
399                    } else {
400                        // double dot not allowed
401                        return false;
402                    }
403                }
404                if !is_valid_label(name, offset, i) {
405                    return false;
406                }
407                offset = i + 1;
408                non_empty_groups += 1;
409            }
410            i += 1;
411        }
412        if offset == i {
413            non_empty_groups > 0
414        } else {
415            is_valid_label(name, offset, i)
416        }
417    }
418}
419
#[cfg(test)]
// The parse tests build a per-input message with `format!` and hand it to `expect`.
#[allow(clippy::expect_fun_call)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    // The predefined constructors yield the expected names.
    #[test]
    fn test_specials() {
        assert_eq!(Domain::tld_localhost(), "localhost");
        assert_eq!(Domain::tld_private(), "internal");
        assert_eq!(Domain::example(), "example.com");
    }

    // Each name must be accepted by both the `String` and the `Vec<u8>` conversion,
    // and the resulting `Domain` must compare equal to the source text.
    #[test]
    fn test_domain_parse_valid() {
        for str in [
            "example.com",
            "www.example.com",
            "a-b-c.com",
            "a-b-c.example.com",
            "a-b-c.example",
            "aA1",
            ".example.com",  // single leading dot
            "example.com.",  // trailing dot (FQDN form)
            ".example.com.", // both
            "rr5---sn-q4fl6n6s.video.com", // multiple dashes
            "127.0.0.1",     // IPv4 literals pass the shallow label check
        ] {
            let msg = format!("to parse: {}", str);
            assert_eq!(Domain::try_from(str.to_owned()).expect(msg.as_str()), str);
            assert_eq!(
                Domain::try_from(str.as_bytes().to_vec()).expect(msg.as_str()),
                str
            );
        }
    }

    // Each name must be rejected by both the `String` and the `Vec<u8>` conversion:
    // empty input, lone/double dots, hyphen-bounded labels, non-ASCII or punctuation
    // characters, labels over 63 bytes and names over 253 bytes.
    #[test]
    fn test_domain_parse_invalid() {
        for str in [
            "",
            ".",
            "..",
            "-",
            ".-",
            "-.",
            ".-.",
            "-.-.",
            "-.-.-",
            ".-.-",
            "2001:db8:3333:4444:5555:6666:7777:8888",
            "-example.com",
            "local!host",
            "thislabeliswaytoolongforbeingeversomethingwewishtocareabout-example.com",
            "example-thislabeliswaytoolongforbeingeversomethingwewishtocareabout.com",
            "こんにちは",
            "こんにちは.com",
            "😀",
            "example..com",
            "example dot com",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
        ] {
            assert!(Domain::try_from(str.to_owned()).is_err());
            assert!(Domain::try_from(str.as_bytes().to_vec()).is_err());
        }
    }

    // `a.is_parent_of(b)` holds for equal domains (modulo leading/trailing dots and
    // ASCII case) and for real subdomains of `a`.
    #[test]
    fn is_parent() {
        let test_cases = vec![
            ("www.example.com", "www.example.com"),
            ("www.example.com", "www.example.com."),
            ("www.example.com", ".www.example.com."),
            (".www.example.com", "www.example.com"),
            (".www.example.com", "www.example.com."),
            (".www.example.com.", "www.example.com."),
            ("www.example.com", "WwW.ExamplE.COM"),
            ("example.com", "www.example.com"),
            ("example.com", "m.example.com"),
            ("example.com", "www.EXAMPLE.com"),
            ("example.com", "M.example.com"),
        ];
        for (a, b) in test_cases.into_iter() {
            let a = Domain::from_static(a);
            let b = Domain::from_static(b);
            assert!(a.is_parent_of(&b), "({:?}).is_parent_of({})", a, b);
        }
    }

    // Different domains, siblings, and suffix matches that do not fall on a label
    // boundary (`gel.com` vs `kegel.com`) are not parents.
    #[test]
    fn is_not_parent() {
        let test_cases = vec![
            ("www.example.com", "www.example.co"),
            ("www.example.com", "www.ejemplo.com"),
            ("www.example.com", "www3.example.com"),
            ("w.example.com", "www.example.com"),
            ("gel.com", "kegel.com"),
        ];
        for (a, b) in test_cases.into_iter() {
            let a = Domain::from_static(a);
            let b = Domain::from_static(b);
            assert!(!a.is_parent_of(&b), "!({:?}).is_parent_of({})", a, b);
        }
    }

    // Equality ignores ASCII case and a single leading dot, through every
    // `PartialEq` impl (`Domain`, `&str`, `String`, both directions).
    #[test]
    fn is_equal() {
        let test_cases = vec![
            ("example.com", "example.com"),
            ("example.com", "EXAMPLE.com"),
            (".example.com", ".example.com"),
            (".example.com", "example.com"),
            ("example.com", ".example.com"),
        ];
        for (a, b) in test_cases.into_iter() {
            assert_eq!(Domain::from_static(a), b);
            assert_eq!(Domain::from_static(a), b.to_owned());
            assert_eq!(Domain::from_static(a), Domain::from_static(b));
            assert_eq!(a, Domain::from_static(b));
            assert_eq!(a.to_owned(), Domain::from_static(b));
        }
    }

    // A trailing dot is significant for equality, unlike a leading one.
    #[test]
    fn is_not_equal() {
        let test_cases = vec![
            ("example.com", "localhost"),
            ("example.com", "example.com."),
            ("example.com", "example.co"),
            ("example.com", "examine.com"),
            ("example.com", "example.com.us"),
            ("example.com", "www.example.com"),
        ];
        for (a, b) in test_cases.into_iter() {
            assert_ne!(Domain::from_static(a), b);
            assert_ne!(Domain::from_static(a), b.to_owned());
            assert_ne!(Domain::from_static(a), Domain::from_static(b));
            assert_ne!(a, Domain::from_static(b));
            assert_ne!(a.to_owned(), Domain::from_static(b));
        }
    }

    // Ordering between `Domain`s and plain strings is checked through every
    // `PartialOrd`/`Ord` impl; a shorter prefix sorts first.
    #[test]
    fn cmp() {
        let test_cases = vec![
            ("example.com", "example.com", Ordering::Equal),
            ("example.com", "EXAMPLE.com", Ordering::Equal),
            (".example.com", ".example.com", Ordering::Equal),
            (".example.com", "example.com", Ordering::Equal),
            ("example.com", ".example.com", Ordering::Equal),
            ("example.com", "localhost", Ordering::Less),
            ("example.com", "example.com.", Ordering::Less),
            ("example.com", "example.co", Ordering::Greater),
            ("example.com", "examine.com", Ordering::Greater),
            ("example.com", "example.com.us", Ordering::Less),
            ("example.com", "www.example.com", Ordering::Less),
        ];
        for (a, b, expected) in test_cases.into_iter() {
            assert_eq!(Some(expected), Domain::from_static(a).partial_cmp(&b));
            assert_eq!(
                Some(expected),
                Domain::from_static(a).partial_cmp(&b.to_owned())
            );
            assert_eq!(
                Some(expected),
                Domain::from_static(a).partial_cmp(&Domain::from_static(b))
            );
            assert_eq!(
                expected,
                Domain::from_static(a).cmp(&Domain::from_static(b))
            );
            assert_eq!(Some(expected), a.partial_cmp(&Domain::from_static(b)));
            assert_eq!(
                Some(expected),
                a.to_owned().partial_cmp(&Domain::from_static(b))
            );
        }
    }

    // Keys that are equal under `PartialEq` (ASCII case, leading dot) must land on the
    // same `HashMap` entry; distinct domains (including a trailing dot) must not.
    #[test]
    fn test_hash() {
        let mut m = HashMap::new();

        assert!(!m.contains_key(&Domain::from_static("example.com")));
        assert!(!m.contains_key(&Domain::from_static("EXAMPLE.COM")));
        assert!(!m.contains_key(&Domain::from_static(".example.com")));
        assert!(!m.contains_key(&Domain::from_static(".example.COM")));

        m.insert(Domain::from_static("eXaMpLe.COm"), ());

        assert!(m.contains_key(&Domain::from_static("example.com")));
        assert!(m.contains_key(&Domain::from_static("EXAMPLE.COM")));
        assert!(m.contains_key(&Domain::from_static(".example.com")));
        assert!(m.contains_key(&Domain::from_static(".example.COM")));

        assert!(!m.contains_key(&Domain::from_static("www.example.com")));
        assert!(!m.contains_key(&Domain::from_static("examine.com")));
        assert!(!m.contains_key(&Domain::from_static("example.com.")));
        assert!(!m.contains_key(&Domain::from_static("example.co")));
        assert!(!m.contains_key(&Domain::from_static("example.commerce")));
    }
}