distant_net/common/destination/host.rs

1use std::fmt;
2use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
3use std::str::FromStr;
4
5use derive_more::{Display, Error, From};
6use serde::de::Deserializer;
7use serde::ser::Serializer;
8use serde::{Deserialize, Serialize};
9
10use super::{deserialize_from_str, serialize_to_str};
11
12/// Represents the host of a destination
13#[derive(Clone, Debug, From, Display, Hash, PartialEq, Eq)]
14pub enum Host {
15    Ipv4(Ipv4Addr),
16    Ipv6(Ipv6Addr),
17
18    /// Represents a hostname that follows the
19    /// [DoD Internet Host Table Specification](https://www.ietf.org/rfc/rfc0952.txt):
20    ///
21    /// * Hostname can be a maximum of 253 characters including '.'
22    /// * Each label is a-zA-Z0-9 alongside hyphen ('-') and a maximum size of 63 characters
23    /// * Labels can be segmented by periods ('.')
24    Name(String),
25}
26
27impl Host {
28    /// Indicates whether the host destination is globally routable
29    pub const fn is_global(&self) -> bool {
30        match self {
31            Self::Ipv4(x) => {
32                !(x.is_broadcast()
33                    || x.is_documentation()
34                    || x.is_link_local()
35                    || x.is_loopback()
36                    || x.is_private()
37                    || x.is_unspecified())
38            }
39            Self::Ipv6(x) => {
40                // NOTE: 14 is the global flag
41                x.is_multicast() && (x.segments()[0] & 0x000f == 14)
42            }
43            Self::Name(_) => false,
44        }
45    }
46
47    /// Returns true if host is an IPv4 address
48    pub const fn is_ipv4(&self) -> bool {
49        matches!(self, Self::Ipv4(_))
50    }
51
52    /// Returns true if host is an IPv6 address
53    pub const fn is_ipv6(&self) -> bool {
54        matches!(self, Self::Ipv6(_))
55    }
56
57    /// Returns true if host is a name
58    pub const fn is_name(&self) -> bool {
59        matches!(self, Self::Name(_))
60    }
61}
62
63impl From<IpAddr> for Host {
64    fn from(addr: IpAddr) -> Self {
65        match addr {
66            IpAddr::V4(x) => Self::Ipv4(x),
67            IpAddr::V6(x) => Self::Ipv6(x),
68        }
69    }
70}
71
72#[derive(Copy, Clone, Debug, Error, Hash, PartialEq, Eq)]
73pub enum HostParseError {
74    EmptyLabel,
75    EndsWithHyphen,
76    EndsWithPeriod,
77    InvalidLabel,
78    LargeLabel,
79    LargeName,
80    StartsWithHyphen,
81    StartsWithPeriod,
82}
83
84impl HostParseError {
85    /// Returns a static `str` describing the error
86    pub const fn into_static_str(self) -> &'static str {
87        match self {
88            Self::EmptyLabel => "Hostname cannot have an empty label",
89            Self::EndsWithHyphen => "Hostname cannot end with hyphen ('-')",
90            Self::EndsWithPeriod => "Hostname cannot end with period ('.')",
91            Self::InvalidLabel => "Hostname label can only be a-zA-Z0-9 or hyphen ('-')",
92            Self::LargeLabel => "Hostname label larger cannot be larger than 63 characters",
93            Self::LargeName => "Hostname cannot be larger than 253 characters",
94            Self::StartsWithHyphen => "Hostname cannot start with hyphen ('-')",
95            Self::StartsWithPeriod => "Hostname cannot start with period ('.')",
96        }
97    }
98}
99
100impl fmt::Display for HostParseError {
101    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
102        write!(f, "{}", self.into_static_str())
103    }
104}
105
106impl FromStr for Host {
107    type Err = HostParseError;
108
109    /// Parses a host from a str
110    ///
111    /// ### Examples
112    ///
113    /// ```
114    /// # use distant_net::common::Host;
115    /// # use std::net::{Ipv4Addr, Ipv6Addr};
116    /// // IPv4 address
117    /// assert_eq!("127.0.0.1".parse(), Ok(Host::Ipv4(Ipv4Addr::new(127, 0, 0, 1))));
118    ///
119    /// // IPv6 address
120    /// assert_eq!("::1".parse(), Ok(Host::Ipv6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1))));
121    ///
122    /// // Valid hostname
123    /// assert_eq!("localhost".parse(), Ok(Host::Name("localhost".to_string())));
124    ///
125    /// // Invalid hostname
126    /// assert!("local_host".parse::<Host>().is_err());
127    /// ```
128    fn from_str(s: &str) -> Result<Self, Self::Err> {
129        // Check if the str is a valid Ipv4 or Ipv6 address first
130        if let Ok(x) = s.parse::<Ipv4Addr>() {
131            return Ok(Self::Ipv4(x));
132        } else if let Ok(x) = s.parse::<Ipv6Addr>() {
133            return Ok(Self::Ipv6(x));
134        }
135
136        // NOTE: We have to catch an empty string seprately from invalid label checks
137        if s.is_empty() {
138            return Err(HostParseError::InvalidLabel);
139        }
140
141        // Since it is not, we need to validate the string as a hostname
142        let mut label_size_cnt = 0;
143        let mut last_char = None;
144        for (i, c) in s.char_indices() {
145            if i >= 253 {
146                return Err(HostParseError::LargeName);
147            }
148
149            // Dot and hyphen cannot be first character
150            if i == 0 && c == '.' {
151                return Err(HostParseError::StartsWithPeriod);
152            } else if i == 0 && c == '-' {
153                return Err(HostParseError::StartsWithHyphen);
154            }
155
156            if c.is_alphanumeric() {
157                label_size_cnt += 1;
158                if label_size_cnt > 63 {
159                    return Err(HostParseError::LargeLabel);
160                }
161            } else if c == '.' {
162                // Back-to-back dots are not allowed (would indicate an empty label, which is
163                // reserved)
164                if label_size_cnt == 0 {
165                    return Err(HostParseError::EmptyLabel);
166                }
167
168                label_size_cnt = 0;
169            } else if c != '-' {
170                return Err(HostParseError::InvalidLabel);
171            }
172
173            last_char = Some(c);
174        }
175
176        if last_char == Some('.') {
177            return Err(HostParseError::EndsWithPeriod);
178        } else if last_char == Some('-') {
179            return Err(HostParseError::EndsWithHyphen);
180        }
181
182        Ok(Self::Name(s.to_string()))
183    }
184}
185
186impl PartialEq<str> for Host {
187    fn eq(&self, other: &str) -> bool {
188        match self {
189            Self::Ipv4(x) => x.to_string() == other,
190            Self::Ipv6(x) => x.to_string() == other,
191            Self::Name(x) => x == other,
192        }
193    }
194}
195
196impl<'a> PartialEq<&'a str> for Host {
197    fn eq(&self, other: &&'a str) -> bool {
198        match self {
199            Self::Ipv4(x) => x.to_string() == *other,
200            Self::Ipv6(x) => x.to_string() == *other,
201            Self::Name(x) => x == other,
202        }
203    }
204}
205
206impl Serialize for Host {
207    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
208    where
209        S: Serializer,
210    {
211        serialize_to_str(self, serializer)
212    }
213}
214
215impl<'de> Deserialize<'de> for Host {
216    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
217    where
218        D: Deserializer<'de>,
219    {
220        deserialize_from_str(deserializer)
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking if it is not a valid host
    fn parse_ok(input: &str) -> Host {
        input.parse::<Host>().unwrap()
    }

    /// Parses `input`, panicking if it unexpectedly succeeds, and returns the error
    fn parse_err(input: &str) -> HostParseError {
        input.parse::<Host>().unwrap_err()
    }

    #[test]
    fn display_should_output_ipv4_correctly() {
        assert_eq!(Host::Ipv4(Ipv4Addr::LOCALHOST).to_string(), "127.0.0.1");
    }

    #[test]
    fn display_should_output_ipv6_correctly() {
        assert_eq!(Host::Ipv6(Ipv6Addr::LOCALHOST).to_string(), "::1");
    }

    #[test]
    fn display_should_output_hostname_verbatim() {
        assert_eq!(Host::Name(String::from("localhost")).to_string(), "localhost");
    }

    #[test]
    fn from_str_should_fail_if_str_is_empty() {
        assert_eq!(parse_err(""), HostParseError::InvalidLabel);
    }

    #[test]
    fn from_str_should_fail_if_str_is_larger_than_253_characters() {
        // Three 63-character labels and one 62-character label joined by three periods:
        // 63 * 3 + 62 + 3 = 254 characters
        let long_name = [63_usize, 63, 63, 62]
            .iter()
            .map(|&len| "a".repeat(len))
            .collect::<Vec<_>>()
            .join(".");
        assert_eq!(parse_err(&long_name), HostParseError::LargeName);
    }

    #[test]
    fn from_str_should_fail_if_str_starts_with_period() {
        assert_eq!(parse_err(".localhost"), HostParseError::StartsWithPeriod);
    }

    #[test]
    fn from_str_should_fail_if_str_ends_with_period() {
        assert_eq!(parse_err("localhost."), HostParseError::EndsWithPeriod);
    }

    #[test]
    fn from_str_should_fail_if_str_starts_with_hyphen() {
        assert_eq!(parse_err("-localhost"), HostParseError::StartsWithHyphen);
    }

    #[test]
    fn from_str_should_fail_if_str_ends_with_hyphen() {
        assert_eq!(parse_err("localhost-"), HostParseError::EndsWithHyphen);
    }

    #[test]
    fn from_str_should_fail_if_str_has_a_label_larger_than_63_characters() {
        let long_label = "a".repeat(64) + ".com";
        assert_eq!(parse_err(&long_label), HostParseError::LargeLabel);
    }

    #[test]
    fn from_str_should_fail_if_str_has_empty_label() {
        assert_eq!(parse_err("example..com"), HostParseError::EmptyLabel);
    }

    #[test]
    fn from_str_should_fail_if_str_has_invalid_label() {
        assert_eq!(parse_err("www.exa_mple.com"), HostParseError::InvalidLabel);
    }

    #[test]
    fn from_str_should_succeed_if_valid_ipv4_address() {
        assert_eq!(parse_ok("127.0.0.1"), Host::Ipv4(Ipv4Addr::new(127, 0, 0, 1)));
    }

    #[test]
    fn from_str_should_succeed_if_valid_ipv6_address() {
        assert_eq!(parse_ok("::1"), Host::Ipv6(Ipv6Addr::LOCALHOST));
    }

    #[test]
    fn from_str_should_succeed_if_valid_hostname() {
        // NOTE: "3.example.com" relies on RFC-1123 revising RFC-952 to allow a digit at the
        //       start of a label
        let names = [
            "localhost",
            "example.com",
            "w-w-w.example.com",
            "w3.example.com",
            "3.example.com",
        ];
        for &name in &names {
            assert_eq!(parse_ok(name), Host::Name(name.to_string()));
        }
    }
}