1use super::Host;
2use rama_core::error::{ErrorContext, OpaqueError};
3use smol_str::SmolStr;
4use std::{cmp::Ordering, fmt, iter::repeat};
5
/// A domain name backed by a [`SmolStr`].
///
/// The constructors visible in this file (`from_static` and the `TryFrom`
/// impls) only build a value after `is_valid_name` has accepted the input.
/// Equality, ordering and hashing are implemented by hand further down so
/// that they ignore ASCII case and a single leading dot.
#[derive(Debug, Clone)]
pub struct Domain(SmolStr);
14
15impl Domain {
16 pub const fn from_static(s: &'static str) -> Self {
24 if !is_valid_name(s.as_bytes()) {
25 panic!("static str is an invalid domain");
26 }
27 Self(SmolStr::new_static(s))
28 }
29
30 pub fn example() -> Self {
32 Self::from_static("example.com")
33 }
34
35 pub fn tld_private() -> Self {
42 Self::from_static("internal")
43 }
44
45 pub fn tld_localhost() -> Self {
47 Self::from_static("localhost")
48 }
49
50 pub fn into_host(self) -> Host {
52 Host::Name(self)
53 }
54
55 pub fn is_fqdn(&self) -> bool {
57 self.0.ends_with('.')
58 }
59
60 pub fn is_sub_of(&self, other: &Domain) -> bool {
64 let a = self.as_ref().trim_matches('.');
65 let b = other.as_ref().trim_matches('.');
66 match a.len().cmp(&b.len()) {
67 Ordering::Equal => a.eq_ignore_ascii_case(b),
68 Ordering::Greater => {
69 let n = a.len() - b.len();
70 let dot_char = a.chars().nth(n - 1);
71 let host_parent = &a[n..];
72 dot_char == Some('.') && b.eq_ignore_ascii_case(host_parent)
73 }
74 Ordering::Less => false,
75 }
76 }
77
78 #[inline]
79 pub fn is_parent_of(&self, other: &Domain) -> bool {
83 other.is_sub_of(self)
84 }
85
86 pub fn have_same_registrable_domain(&self, other: &Domain) -> bool {
106 let this_rd = psl::domain_str(self.as_str());
107 let other_rd = psl::domain_str(other.as_str());
108 this_rd == other_rd
109 }
110
111 pub fn suffix(&self) -> Option<&str> {
122 psl::suffix_str(self.as_str())
123 }
124
125 pub fn as_str(&self) -> &str {
127 self.as_ref()
128 }
129
130 pub(crate) fn into_inner(self) -> SmolStr {
134 self.0
135 }
136}
137
138impl std::hash::Hash for Domain {
139 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
140 let this = self.as_ref();
141 let this = this.strip_prefix('.').unwrap_or(this);
142 for b in this.bytes() {
143 let b = b.to_ascii_lowercase();
144 b.hash(state);
145 }
146 }
147}
148
impl AsRef<str> for Domain {
    /// Borrows the name exactly as stored, without any normalisation.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
154
impl fmt::Display for Domain {
    /// Formats the domain by forwarding to the wrapped string's `fmt`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}
160
161impl std::str::FromStr for Domain {
162 type Err = OpaqueError;
163
164 fn from_str(s: &str) -> Result<Self, Self::Err> {
165 Domain::try_from(s.to_owned())
166 }
167}
168
169impl TryFrom<String> for Domain {
170 type Error = OpaqueError;
171
172 fn try_from(name: String) -> Result<Self, Self::Error> {
173 if is_valid_name(name.as_bytes()) {
174 Ok(Self(SmolStr::new(name)))
175 } else {
176 Err(OpaqueError::from_display("invalid domain"))
177 }
178 }
179}
180
181impl<'a> TryFrom<&'a [u8]> for Domain {
182 type Error = OpaqueError;
183
184 fn try_from(name: &'a [u8]) -> Result<Self, Self::Error> {
185 if is_valid_name(name) {
186 Ok(Self(SmolStr::new(
187 std::str::from_utf8(name).context("convert domain bytes to utf-8 string")?,
188 )))
189 } else {
190 Err(OpaqueError::from_display("invalid domain"))
191 }
192 }
193}
194
195impl TryFrom<Vec<u8>> for Domain {
196 type Error = OpaqueError;
197
198 fn try_from(name: Vec<u8>) -> Result<Self, Self::Error> {
199 if is_valid_name(name.as_slice()) {
200 Ok(Self(SmolStr::new(
201 String::from_utf8(name).context("convert domain bytes to utf-8 string")?,
202 )))
203 } else {
204 Err(OpaqueError::from_display("invalid domain"))
205 }
206 }
207}
208
/// Orders two domain-like strings byte by byte, ignoring ASCII case and a
/// single leading dot on either side.
///
/// When one name is a strict prefix of the other, the shorter one sorts
/// first.
fn cmp_domain(a: impl AsRef<str>, b: impl AsRef<str>) -> Ordering {
    /// Yields the lowercased bytes of `name` (minus one leading dot),
    /// followed by an endless run of `None` so the two sides can be zipped
    /// without either running out.
    fn padded(name: &str) -> impl Iterator<Item = Option<u8>> + '_ {
        name.strip_prefix('.')
            .unwrap_or(name)
            .bytes()
            .map(|byte| Some(byte.to_ascii_lowercase()))
            .chain(repeat(None))
    }

    for pair in padded(a.as_ref()).zip(padded(b.as_ref())) {
        match pair {
            (Some(x), Some(y)) if x == y => continue,
            (Some(x), Some(y)) => return x.cmp(&y),
            (Some(_), None) => return Ordering::Greater,
            (None, Some(_)) => return Ordering::Less,
            (None, None) => return Ordering::Equal,
        }
    }
    unreachable!("both sides end in an endless run of `None`")
}
231
impl PartialOrd<Domain> for Domain {
    /// Always returns `Some`: forwards to the total order from [`Ord::cmp`].
    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
237
impl Ord for Domain {
    /// Orders domains with `cmp_domain`, i.e. ignoring ASCII case and a
    /// single leading dot.
    fn cmp(&self, other: &Self) -> Ordering {
        cmp_domain(self, other)
    }
}
243
impl PartialOrd<str> for Domain {
    /// Compares against a raw `str` using `cmp_domain`; always `Some`.
    /// The string is not validated as a domain first.
    fn partial_cmp(&self, other: &str) -> Option<Ordering> {
        Some(cmp_domain(self, other))
    }
}
249
impl PartialOrd<Domain> for str {
    /// Compares a raw `str` against a domain using `cmp_domain`; always `Some`.
    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
        Some(cmp_domain(self, other))
    }
}
255
impl PartialOrd<&str> for Domain {
    /// Compares against a `&str` using `cmp_domain`; always `Some`.
    fn partial_cmp(&self, other: &&str) -> Option<Ordering> {
        Some(cmp_domain(self, other))
    }
}
261
impl PartialOrd<Domain> for &str {
    /// Compares a `&str` against a domain using `cmp_domain`; always `Some`.
    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
        Some(cmp_domain(self, other))
    }
}
267
impl PartialOrd<String> for Domain {
    /// Compares against an owned `String` using `cmp_domain`; always `Some`.
    fn partial_cmp(&self, other: &String) -> Option<Ordering> {
        Some(cmp_domain(self, other))
    }
}
273
impl PartialOrd<Domain> for String {
    /// Compares an owned `String` against a domain using `cmp_domain`;
    /// always `Some`.
    fn partial_cmp(&self, other: &Domain) -> Option<Ordering> {
        Some(cmp_domain(self, other))
    }
}
279
/// Tests two domain-like strings for equality, ignoring ASCII case and a
/// single leading dot on either side. Trailing dots are significant.
fn partial_eq_domain(a: impl AsRef<str>, b: impl AsRef<str>) -> bool {
    /// Drops at most one leading dot.
    fn bare(name: &str) -> &str {
        name.strip_prefix('.').unwrap_or(name)
    }

    bare(a.as_ref()).eq_ignore_ascii_case(bare(b.as_ref()))
}
289
impl PartialEq<Domain> for Domain {
    /// Two domains are equal when `partial_eq_domain` says so: ASCII case
    /// and a single leading dot are ignored.
    fn eq(&self, other: &Domain) -> bool {
        partial_eq_domain(self, other)
    }
}
295
// Marker impl: the hand-written `PartialEq<Domain>` above is declared a full
// equivalence relation.
impl Eq for Domain {}
297
impl PartialEq<str> for Domain {
    /// Compares against a raw `str` with `partial_eq_domain`; the string is
    /// not validated as a domain first.
    fn eq(&self, other: &str) -> bool {
        partial_eq_domain(self, other)
    }
}
303
impl PartialEq<&str> for Domain {
    /// Compares against a `&str` with `partial_eq_domain`.
    fn eq(&self, other: &&str) -> bool {
        partial_eq_domain(self, other)
    }
}
309
310impl PartialEq<Domain> for str {
311 fn eq(&self, other: &Domain) -> bool {
312 other == self
313 }
314}
315
impl PartialEq<Domain> for &str {
    /// Compares a `&str` against a domain with `partial_eq_domain`.
    fn eq(&self, other: &Domain) -> bool {
        partial_eq_domain(self, other)
    }
}
321
impl PartialEq<String> for Domain {
    /// Compares against an owned `String` with `partial_eq_domain`.
    fn eq(&self, other: &String) -> bool {
        partial_eq_domain(self, other)
    }
}
327
impl PartialEq<Domain> for String {
    /// Compares an owned `String` against a domain with `partial_eq_domain`.
    fn eq(&self, other: &Domain) -> bool {
        partial_eq_domain(self, other)
    }
}
333
impl serde::Serialize for Domain {
    /// Serializes the domain by delegating to the wrapped `SmolStr`'s own
    /// `Serialize` implementation.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        self.0.serialize(serializer)
    }
}
342
impl<'de> serde::Deserialize<'de> for Domain {
    /// Deserializes a string and parses it through `FromStr`, so invalid
    /// names are rejected during deserialization.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // `Cow` accepts either a borrowed or an owned string from the deserializer.
        let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
        // Surface the parse failure as a serde custom error.
        s.parse().map_err(serde::de::Error::custom)
    }
}
352
impl Domain {
    /// Maximum number of bytes accepted for a single label; enforced by
    /// `is_valid_label`.
    const MAX_LABEL_LEN: usize = 63;

    /// Maximum number of bytes accepted for a whole name, dots included;
    /// enforced by `is_valid_name`.
    const MAX_NAME_LEN: usize = 253;
}
360
361const fn is_valid_label(name: &[u8], start: usize, stop: usize) -> bool {
362 if start >= stop
363 || stop - start > Domain::MAX_LABEL_LEN
364 || name[start] == b'-'
365 || start == stop
366 || name[stop - 1] == b'-'
367 {
368 false
369 } else {
370 let mut i = start;
371 while i < stop {
372 let c = name[i];
373 if !c.is_ascii_alphanumeric() && (c != b'-' || i == start) {
374 return false;
375 }
376 i += 1;
377 }
378 true
379 }
380}
381
382const fn is_valid_name(name: &[u8]) -> bool {
384 if name.is_empty() || name.len() > Domain::MAX_NAME_LEN {
385 false
386 } else {
387 let mut non_empty_groups = 0;
388 let mut i = 0;
389 let mut offset = 0;
390 while i < name.len() {
391 let c = name[i];
392 if c == b'.' {
393 if offset == i {
394 if i == 0 || i == name.len() - 1 {
396 i += 1;
397 offset = i + 1;
398 continue;
399 } else {
400 return false;
402 }
403 }
404 if !is_valid_label(name, offset, i) {
405 return false;
406 }
407 offset = i + 1;
408 non_empty_groups += 1;
409 }
410 i += 1;
411 }
412 if offset == i {
413 non_empty_groups > 0
414 } else {
415 is_valid_label(name, offset, i)
416 }
417 }
418}
419
420#[cfg(test)]
421#[allow(clippy::expect_fun_call)]
422mod tests {
423 use super::*;
424 use std::collections::HashMap;
425
    /// The named constructors return the expected literal names.
    #[test]
    fn test_specials() {
        assert_eq!(Domain::tld_localhost(), "localhost");
        assert_eq!(Domain::tld_private(), "internal");
        assert_eq!(Domain::example(), "example.com");
    }
432
433 #[test]
434 fn test_domain_parse_valid() {
435 for str in [
436 "example.com",
437 "www.example.com",
438 "a-b-c.com",
439 "a-b-c.example.com",
440 "a-b-c.example",
441 "aA1",
442 ".example.com",
443 "example.com.",
444 ".example.com.",
445 "rr5---sn-q4fl6n6s.video.com", "127.0.0.1",
447 ] {
448 let msg = format!("to parse: {}", str);
449 assert_eq!(Domain::try_from(str.to_owned()).expect(msg.as_str()), str);
450 assert_eq!(
451 Domain::try_from(str.as_bytes().to_vec()).expect(msg.as_str()),
452 str
453 );
454 }
455 }
456
457 #[test]
458 fn test_domain_parse_invalid() {
459 for str in [
460 "",
461 ".",
462 "..",
463 "-",
464 ".-",
465 "-.",
466 ".-.",
467 "-.-.",
468 "-.-.-",
469 ".-.-",
470 "2001:db8:3333:4444:5555:6666:7777:8888",
471 "-example.com",
472 "local!host",
473 "thislabeliswaytoolongforbeingeversomethingwewishtocareabout-example.com",
474 "example-thislabeliswaytoolongforbeingeversomethingwewishtocareabout.com",
475 "こんにちは",
476 "こんにちは.com",
477 "😀",
478 "example..com",
479 "example dot com",
480 "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
481 ] {
482 assert!(Domain::try_from(str.to_owned()).is_err());
483 assert!(Domain::try_from(str.as_bytes().to_vec()).is_err());
484 }
485 }
486
487 #[test]
488 fn is_parent() {
489 let test_cases = vec![
490 ("www.example.com", "www.example.com"),
491 ("www.example.com", "www.example.com."),
492 ("www.example.com", ".www.example.com."),
493 (".www.example.com", "www.example.com"),
494 (".www.example.com", "www.example.com."),
495 (".www.example.com.", "www.example.com."),
496 ("www.example.com", "WwW.ExamplE.COM"),
497 ("example.com", "www.example.com"),
498 ("example.com", "m.example.com"),
499 ("example.com", "www.EXAMPLE.com"),
500 ("example.com", "M.example.com"),
501 ];
502 for (a, b) in test_cases.into_iter() {
503 let a = Domain::from_static(a);
504 let b = Domain::from_static(b);
505 assert!(a.is_parent_of(&b), "({:?}).is_parent_of({})", a, b);
506 }
507 }
508
509 #[test]
510 fn is_not_parent() {
511 let test_cases = vec![
512 ("www.example.com", "www.example.co"),
513 ("www.example.com", "www.ejemplo.com"),
514 ("www.example.com", "www3.example.com"),
515 ("w.example.com", "www.example.com"),
516 ("gel.com", "kegel.com"),
517 ];
518 for (a, b) in test_cases.into_iter() {
519 let a = Domain::from_static(a);
520 let b = Domain::from_static(b);
521 assert!(!a.is_parent_of(&b), "!({:?}).is_parent_of({})", a, b);
522 }
523 }
524
525 #[test]
526 fn is_equal() {
527 let test_cases = vec![
528 ("example.com", "example.com"),
529 ("example.com", "EXAMPLE.com"),
530 (".example.com", ".example.com"),
531 (".example.com", "example.com"),
532 ("example.com", ".example.com"),
533 ];
534 for (a, b) in test_cases.into_iter() {
535 assert_eq!(Domain::from_static(a), b);
536 assert_eq!(Domain::from_static(a), b.to_owned());
537 assert_eq!(Domain::from_static(a), Domain::from_static(b));
538 assert_eq!(a, Domain::from_static(b));
539 assert_eq!(a.to_owned(), Domain::from_static(b));
540 }
541 }
542
543 #[test]
544 fn is_not_equal() {
545 let test_cases = vec![
546 ("example.com", "localhost"),
547 ("example.com", "example.com."),
548 ("example.com", "example.co"),
549 ("example.com", "examine.com"),
550 ("example.com", "example.com.us"),
551 ("example.com", "www.example.com"),
552 ];
553 for (a, b) in test_cases.into_iter() {
554 assert_ne!(Domain::from_static(a), b);
555 assert_ne!(Domain::from_static(a), b.to_owned());
556 assert_ne!(Domain::from_static(a), Domain::from_static(b));
557 assert_ne!(a, Domain::from_static(b));
558 assert_ne!(a.to_owned(), Domain::from_static(b));
559 }
560 }
561
562 #[test]
563 fn cmp() {
564 let test_cases = vec![
565 ("example.com", "example.com", Ordering::Equal),
566 ("example.com", "EXAMPLE.com", Ordering::Equal),
567 (".example.com", ".example.com", Ordering::Equal),
568 (".example.com", "example.com", Ordering::Equal),
569 ("example.com", ".example.com", Ordering::Equal),
570 ("example.com", "localhost", Ordering::Less),
571 ("example.com", "example.com.", Ordering::Less),
572 ("example.com", "example.co", Ordering::Greater),
573 ("example.com", "examine.com", Ordering::Greater),
574 ("example.com", "example.com.us", Ordering::Less),
575 ("example.com", "www.example.com", Ordering::Less),
576 ];
577 for (a, b, expected) in test_cases.into_iter() {
578 assert_eq!(Some(expected), Domain::from_static(a).partial_cmp(&b));
579 assert_eq!(
580 Some(expected),
581 Domain::from_static(a).partial_cmp(&b.to_owned())
582 );
583 assert_eq!(
584 Some(expected),
585 Domain::from_static(a).partial_cmp(&Domain::from_static(b))
586 );
587 assert_eq!(
588 expected,
589 Domain::from_static(a).cmp(&Domain::from_static(b))
590 );
591 assert_eq!(Some(expected), a.partial_cmp(&Domain::from_static(b)));
592 assert_eq!(
593 Some(expected),
594 a.to_owned().partial_cmp(&Domain::from_static(b))
595 );
596 }
597 }
598
599 #[test]
600 fn test_hash() {
601 let mut m = HashMap::new();
602
603 assert!(!m.contains_key(&Domain::from_static("example.com")));
604 assert!(!m.contains_key(&Domain::from_static("EXAMPLE.COM")));
605 assert!(!m.contains_key(&Domain::from_static(".example.com")));
606 assert!(!m.contains_key(&Domain::from_static(".example.COM")));
607
608 m.insert(Domain::from_static("eXaMpLe.COm"), ());
609
610 assert!(m.contains_key(&Domain::from_static("example.com")));
611 assert!(m.contains_key(&Domain::from_static("EXAMPLE.COM")));
612 assert!(m.contains_key(&Domain::from_static(".example.com")));
613 assert!(m.contains_key(&Domain::from_static(".example.COM")));
614
615 assert!(!m.contains_key(&Domain::from_static("www.example.com")));
616 assert!(!m.contains_key(&Domain::from_static("examine.com")));
617 assert!(!m.contains_key(&Domain::from_static("example.com.")));
618 assert!(!m.contains_key(&Domain::from_static("example.co")));
619 assert!(!m.contains_key(&Domain::from_static("example.commerce")));
620 }
621}