pubky_homeserver/data_directory/
domain.rs1use std::fmt::{self, Display};
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
/// A validated domain name, i.e. an RFC 1123 hostname such as `example.com` or `localhost`.
///
/// Values created through [`Domain::new`], [`str::parse`] / [`FromStr`], or serde
/// deserialization are validated. The inner field is `pub`, however, so
/// `Domain(String)` can still be constructed directly without validation.
///
/// Equality and hashing compare the inner string exactly (case-sensitive).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Domain(pub String);
9
10impl Domain {
11 pub fn new(domain: String) -> Result<Self, anyhow::Error> {
13 Self::is_valid_domain(&domain)?;
14 Ok(Self(domain))
15 }
16
17 pub fn is_valid_domain(domain: &str) -> anyhow::Result<()> {
19 if !hostname_validator::is_valid(domain) {
21 return Err(anyhow::anyhow!(
22 "Invalid domain '{}': is not a valid RFC 1123 hostname",
23 domain
24 ));
25 }
26 Ok(())
27 }
28}
29
30impl Default for Domain {
31 fn default() -> Self {
32 Self("localhost".to_string())
33 }
34}
35
36impl FromStr for Domain {
37 type Err = anyhow::Error;
38
39 fn from_str(s: &str) -> Result<Self, Self::Err> {
40 Self::is_valid_domain(s)?;
41 Ok(Self(s.to_string()))
42 }
43}
44
45impl Display for Domain {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 write!(f, "{}", self.0)
48 }
49}
50
51impl Serialize for Domain {
52 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
53 where
54 S: serde::Serializer,
55 {
56 serializer.serialize_str(&self.to_string())
57 }
58}
59
60impl<'de> Deserialize<'de> for Domain {
61 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62 where
63 D: serde::Deserializer<'de>,
64 {
65 let s = String::deserialize(deserializer)?;
66 Self::from_str(&s).map_err(|e| serde::de::Error::custom(e.to_string()))
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Representative valid and invalid hostnames parsed through `FromStr`.
    #[test]
    fn test_domain_validation() {
        let valid_domains = [
            "example.com",
            "sub.example.com",
            "a.b.c.d",
            "valid-domain.com",
            "valid.domain-name.com",
            "localhost",
            "test.local",
        ];

        for domain in valid_domains {
            let result: anyhow::Result<Domain> = domain.parse();
            assert!(result.is_ok(), "Domain '{}' should be valid", domain);
        }

        let invalid_domains = [
            ("invalid@domain.com", "contains invalid characters"),
            ("domain..com", "contains consecutive dots"),
            (".domain.com", "starts with a dot"),
            ("domain.com.", "ends with a dot"),
            ("-domain.com", "starts with a hyphen"),
            ("domain.com-", "ends with a hyphen"),
            ("", "is empty"),
        ];

        for (domain, reason) in invalid_domains {
            let result: anyhow::Result<Domain> = domain.parse();
            assert!(
                result.is_err(),
                "Domain '{}' should be invalid: {}",
                domain,
                reason
            );
        }
    }

    /// RFC 1123 limits each label to 63 octets.
    #[test]
    fn test_label_length_limit() {
        let max_label = "a".repeat(63);
        let too_long_label = "a".repeat(64);
        assert!(format!("{max_label}.com").parse::<Domain>().is_ok());
        assert!(format!("{too_long_label}.com").parse::<Domain>().is_err());
    }

    /// `Domain::new` and `str::parse` must accept and reject the same inputs.
    #[test]
    fn test_new_agrees_with_from_str() {
        for input in ["example.com", "localhost", "invalid@domain.com", ""] {
            let via_new = Domain::new(input.to_string());
            let via_parse: anyhow::Result<Domain> = input.parse();
            assert_eq!(
                via_new.is_ok(),
                via_parse.is_ok(),
                "Domain::new and parse disagree for '{}'",
                input
            );
            if let (Ok(a), Ok(b)) = (via_new, via_parse) {
                assert_eq!(a, b);
            }
        }
    }

    /// The default value must itself be a valid domain.
    #[test]
    fn test_default_is_valid() {
        let domain = Domain::default();
        assert_eq!(domain.0, "localhost");
        assert!(Domain::is_valid_domain(&domain.0).is_ok());
    }

    /// Displaying a domain and parsing it back yields the same value.
    #[test]
    fn test_display_round_trip() {
        let domain: Domain = "sub.example.com".parse().unwrap();
        assert_eq!(domain.to_string(), "sub.example.com");
        assert_eq!(domain.to_string().parse::<Domain>().unwrap(), domain);
    }
}