Skip to main content

pubky_homeserver/data_directory/
domain.rs

1use std::fmt::{self, Display};
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
/// Validated domain name according to RFC 1123.
///
/// Values produced by [`Domain::new`], [`str::parse`] (via [`FromStr`]) and
/// deserialization are checked with [`Domain::is_valid_domain`].
///
/// Note: the inner field is public, so constructing `Domain(..)` directly
/// skips that check; prefer `Domain::new` or `parse` for untrusted input.
///
/// Equality and hashing compare the stored string exactly (case-sensitive).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Domain(pub String);
9
10impl Domain {
11    /// Create a new domain from a string.
12    pub fn new(domain: String) -> Result<Self, anyhow::Error> {
13        Self::is_valid_domain(&domain)?;
14        Ok(Self(domain))
15    }
16
17    /// Validate a domain name according to RFC 1123
18    pub fn is_valid_domain(domain: &str) -> anyhow::Result<()> {
19        // Check if it's a valid hostname according to RFC 1123
20        if !hostname_validator::is_valid(domain) {
21            return Err(anyhow::anyhow!(
22                "Invalid domain '{}': is not a valid RFC 1123 hostname",
23                domain
24            ));
25        }
26        Ok(())
27    }
28}
29
30impl Default for Domain {
31    fn default() -> Self {
32        Self("localhost".to_string())
33    }
34}
35
36impl FromStr for Domain {
37    type Err = anyhow::Error;
38
39    fn from_str(s: &str) -> Result<Self, Self::Err> {
40        Self::is_valid_domain(s)?;
41        Ok(Self(s.to_string()))
42    }
43}
44
45impl Display for Domain {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        write!(f, "{}", self.0)
48    }
49}
50
51impl Serialize for Domain {
52    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
53    where
54        S: serde::Serializer,
55    {
56        serializer.serialize_str(&self.to_string())
57    }
58}
59
60impl<'de> Deserialize<'de> for Domain {
61    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62    where
63        D: serde::Deserializer<'de>,
64    {
65        let s = String::deserialize(deserializer)?;
66        Self::from_str(&s).map_err(|e| serde::de::Error::custom(e.to_string()))
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_domain_validation() {
        // Test valid domains
        let valid_domains = [
            "example.com",
            "sub.example.com",
            "a.b.c.d",
            "valid-domain.com",
            "valid.domain-name.com",
            "localhost",
            "test.local",
        ];

        for domain in valid_domains {
            let result: anyhow::Result<Domain> = domain.parse();
            assert!(result.is_ok(), "Domain '{}' should be valid", domain);
            // `Domain::new` must agree with `FromStr`.
            assert!(
                Domain::new(domain.to_string()).is_ok(),
                "Domain::new('{}') should succeed",
                domain
            );
        }

        // Test invalid domains
        let invalid_domains = [
            ("invalid@domain.com", "contains invalid characters"),
            ("domain..com", "contains consecutive dots"),
            (".domain.com", "starts with a dot"),
            ("domain.com.", "ends with a dot"),
            ("-domain.com", "starts with a hyphen"),
            ("domain.com-", "ends with a hyphen"),
            ("", "is empty"),
            ("under_score.com", "contains an underscore"),
        ];

        for (domain, reason) in invalid_domains {
            let result: anyhow::Result<Domain> = domain.parse();
            assert!(
                result.is_err(),
                "Domain '{}' should be invalid: {}",
                domain,
                reason
            );
            assert!(
                Domain::new(domain.to_string()).is_err(),
                "Domain::new('{}') should fail: {}",
                domain,
                reason
            );
        }
    }

    #[test]
    fn test_label_length_limit() {
        // RFC 1123 keeps the RFC 1035 limit of 63 octets per label.
        let longest_label = "a".repeat(63);
        assert!(Domain::new(format!("{}.com", longest_label)).is_ok());

        let too_long_label = "a".repeat(64);
        assert!(Domain::new(format!("{}.com", too_long_label)).is_err());
    }

    #[test]
    fn test_error_message_mentions_domain() {
        let err = Domain::new("bad@domain.com".to_string()).unwrap_err();
        assert!(err.to_string().contains("'bad@domain.com'"));
    }

    #[test]
    fn test_default_and_display_round_trip() {
        let domain = Domain::default();
        assert_eq!(domain.to_string(), "localhost");

        // The default value must itself pass validation.
        let reparsed: Domain = domain.to_string().parse().unwrap();
        assert_eq!(reparsed, domain);
    }

    #[test]
    fn test_deserialize_validates() {
        use serde::de::value::{Error as ValueError, StrDeserializer};
        use serde::de::IntoDeserializer;

        let ok: StrDeserializer<'_, ValueError> = "example.com".into_deserializer();
        let parsed = <Domain as serde::Deserialize>::deserialize(ok).unwrap();
        assert_eq!(parsed, Domain("example.com".to_string()));

        let bad: StrDeserializer<'_, ValueError> = "bad@domain.com".into_deserializer();
        assert!(<Domain as serde::Deserialize>::deserialize(bad).is_err());
    }
}