rama_net/address/
host.rs

1use super::{Domain, parse_utils};
2use rama_core::error::{ErrorContext, OpaqueError};
3use std::{
4    fmt,
5    net::{IpAddr, Ipv4Addr, Ipv6Addr},
6};
7
8#[cfg(feature = "http")]
9use rama_http_types::HeaderValue;
10
11/// Either a [`Domain`] or an [`IpAddr`].
12#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
13pub enum Host {
14    /// A domain.
15    Name(Domain),
16
17    /// An IP address.
18    Address(IpAddr),
19}
20
21impl Host {
22    /// Local loopback address (IPv4)
23    pub const LOCALHOST_IPV4: Self = Self::Address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)));
24
25    /// Local loopback address (IPv6)
26    pub const LOCALHOST_IPV6: Self =
27        Self::Address(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)));
28
29    /// Local loopback name
30    pub const LOCALHOST_NAME: Self = Self::Name(Domain::from_static("localhost"));
31
32    /// Default address, not routable
33    pub const DEFAULT_IPV4: Self = Self::Address(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)));
34
35    /// Default address, not routable (IPv6)
36    pub const DEFAULT_IPV6: Self = Self::Address(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)));
37
38    /// Broadcast address (IPv4)
39    pub const BROADCAST_IPV4: Self = Self::Address(IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)));
40
41    /// `example.com` domain name
42    pub const EXAMPLE_NAME: Self = Self::Name(Domain::from_static("example.com"));
43}
44
45impl PartialEq<str> for Host {
46    fn eq(&self, other: &str) -> bool {
47        match self {
48            Self::Name(domain) => domain.as_str() == other,
49            Self::Address(ip) => ip.to_string() == other,
50        }
51    }
52}
53
54impl PartialEq<Host> for str {
55    fn eq(&self, other: &Host) -> bool {
56        other == self
57    }
58}
59
60impl PartialEq<&str> for Host {
61    fn eq(&self, other: &&str) -> bool {
62        self == *other
63    }
64}
65
66impl PartialEq<Host> for &str {
67    fn eq(&self, other: &Host) -> bool {
68        other == *self
69    }
70}
71
72impl PartialEq<String> for Host {
73    fn eq(&self, other: &String) -> bool {
74        self == other.as_str()
75    }
76}
77
78impl PartialEq<Host> for String {
79    fn eq(&self, other: &Host) -> bool {
80        other == self.as_str()
81    }
82}
83
84impl PartialEq<Ipv4Addr> for Host {
85    fn eq(&self, other: &Ipv4Addr) -> bool {
86        match self {
87            Self::Name(_) => false,
88            Self::Address(ip) => match ip {
89                IpAddr::V4(ip) => ip == other,
90                IpAddr::V6(ip) => ip.to_ipv4().map(|ip| ip == *other).unwrap_or_default(),
91            },
92        }
93    }
94}
95
96impl PartialEq<Host> for Ipv4Addr {
97    fn eq(&self, other: &Host) -> bool {
98        other == self
99    }
100}
101
102impl PartialEq<Ipv6Addr> for Host {
103    fn eq(&self, other: &Ipv6Addr) -> bool {
104        match self {
105            Self::Name(_) => false,
106            Self::Address(ip) => match ip {
107                IpAddr::V4(ip) => ip.to_ipv6_mapped() == *other,
108                IpAddr::V6(ip) => ip == other,
109            },
110        }
111    }
112}
113
114impl PartialEq<Host> for Ipv6Addr {
115    fn eq(&self, other: &Host) -> bool {
116        other == self
117    }
118}
119
120impl PartialEq<IpAddr> for Host {
121    fn eq(&self, other: &IpAddr) -> bool {
122        match other {
123            IpAddr::V4(ip) => self == ip,
124            IpAddr::V6(ip) => self == ip,
125        }
126    }
127}
128
129impl PartialEq<Host> for IpAddr {
130    fn eq(&self, other: &Host) -> bool {
131        other == self
132    }
133}
134
135impl From<Domain> for Host {
136    fn from(domain: Domain) -> Self {
137        Host::Name(domain)
138    }
139}
140
141impl From<IpAddr> for Host {
142    fn from(ip: IpAddr) -> Self {
143        Host::Address(ip)
144    }
145}
146
147impl From<Ipv4Addr> for Host {
148    fn from(ip: Ipv4Addr) -> Self {
149        Host::Address(IpAddr::V4(ip))
150    }
151}
152
153impl From<Ipv6Addr> for Host {
154    fn from(ip: Ipv6Addr) -> Self {
155        Host::Address(IpAddr::V6(ip))
156    }
157}
158
159impl fmt::Display for Host {
160    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
161        match self {
162            Self::Name(domain) => domain.fmt(f),
163            Self::Address(ip) => ip.fmt(f),
164        }
165    }
166}
167
168impl std::str::FromStr for Host {
169    type Err = OpaqueError;
170
171    fn from_str(s: &str) -> Result<Self, Self::Err> {
172        Host::try_from(s)
173    }
174}
175
176impl TryFrom<String> for Host {
177    type Error = OpaqueError;
178
179    fn try_from(name: String) -> Result<Self, Self::Error> {
180        parse_utils::try_to_parse_str_to_ip(name.as_str())
181            .map(Host::Address)
182            .or_else(|| Domain::try_from(name).ok().map(Host::Name))
183            .context("parse host from string")
184    }
185}
186
187impl TryFrom<&str> for Host {
188    type Error = OpaqueError;
189
190    fn try_from(name: &str) -> Result<Self, Self::Error> {
191        parse_utils::try_to_parse_str_to_ip(name)
192            .map(Host::Address)
193            .or_else(|| Domain::try_from(name.to_owned()).ok().map(Host::Name))
194            .context("parse host from string")
195    }
196}
197
#[cfg(feature = "http")]
impl TryFrom<HeaderValue> for Host {
    type Error = OpaqueError;

    /// Owned-header convenience that defers to the `&HeaderValue` conversion.
    fn try_from(header: HeaderValue) -> Result<Self, Self::Error> {
        <Self as TryFrom<&HeaderValue>>::try_from(&header)
    }
}
206
#[cfg(feature = "http")]
impl TryFrom<&HeaderValue> for Host {
    type Error = OpaqueError;

    /// Parses a header value as a host.
    ///
    /// The value must be visible-ASCII text, as checked by
    /// `HeaderValue::to_str`.
    fn try_from(header: &HeaderValue) -> Result<Self, Self::Error> {
        let text = header.to_str().context("convert header to str")?;
        Self::try_from(text)
    }
}
215
216impl TryFrom<Vec<u8>> for Host {
217    type Error = OpaqueError;
218
219    fn try_from(name: Vec<u8>) -> Result<Self, Self::Error> {
220        try_to_parse_bytes_to_ip(name.as_slice())
221            .map(Host::Address)
222            .or_else(|| Domain::try_from(name).ok().map(Host::Name))
223            .context("parse host from string")
224    }
225}
226
227impl TryFrom<&[u8]> for Host {
228    type Error = OpaqueError;
229
230    fn try_from(name: &[u8]) -> Result<Self, Self::Error> {
231        try_to_parse_bytes_to_ip(name)
232            .map(Host::Address)
233            .or_else(|| Domain::try_from(name.to_owned()).ok().map(Host::Name))
234            .context("parse host from string")
235    }
236}
237
238impl serde::Serialize for Host {
239    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
240    where
241        S: serde::Serializer,
242    {
243        let host = self.to_string();
244        host.serialize(serializer)
245    }
246}
247
248impl<'de> serde::Deserialize<'de> for Host {
249    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
250    where
251        D: serde::Deserializer<'de>,
252    {
253        let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
254        s.parse().map_err(serde::de::Error::custom)
255    }
256}
257
258fn try_to_parse_bytes_to_ip(value: &[u8]) -> Option<IpAddr> {
259    if let Some(ip) = std::str::from_utf8(value)
260        .ok()
261        .and_then(parse_utils::try_to_parse_str_to_ip)
262    {
263        return Some(ip);
264    }
265
266    if let Ok(ip) = TryInto::<&[u8; 4]>::try_into(value).map(|bytes| IpAddr::from(*bytes)) {
267        return Some(ip);
268    }
269
270    if let Ok(ip) = TryInto::<&[u8; 16]>::try_into(value).map(|bytes| IpAddr::from(*bytes)) {
271        return Some(ip);
272    }
273
274    None
275}
276
#[cfg(test)]
mod tests {
    use super::*;

    /// The kind of [`Host`] a parse is expected to produce, with its textual form.
    #[derive(Debug, Clone, Copy)]
    enum Is {
        Domain(&'static str),
        Ip(&'static str),
    }

    /// Asserts that `host` is the variant described by `expected` and matches its text.
    fn assert_is(host: Host, expected: Is) {
        match (expected, host) {
            (Is::Domain(domain), Host::Name(name)) => assert_eq!(domain, name),
            (Is::Ip(ip), Host::Address(address)) => assert_eq!(ip, address.to_string()),
            (Is::Domain(domain), Host::Address(address)) => {
                panic!("expected host address {address} to be the domain: {domain}")
            }
            (Is::Ip(ip), Host::Name(name)) => {
                panic!("expected host domain {name} to be the ip: {ip}")
            }
        }
    }

    /// Parses `input` through every string- and byte-based `TryFrom` impl and
    /// asserts that each conversion yields `expected`.
    ///
    /// Failure messages are built lazily and include the conversion path and
    /// the underlying parse error.
    fn assert_str_parses_as(input: &str, expected: Is) {
        let attempts = [
            ("&str", Host::try_from(input)),
            ("String", Host::try_from(input.to_owned())),
            ("&[u8]", Host::try_from(input.as_bytes())),
            ("Vec<u8>", Host::try_from(input.as_bytes().to_vec())),
        ];
        for (via, result) in attempts {
            let host = result
                .unwrap_or_else(|err| panic!("parsing {input:?} via {via}: {err:?}"));
            assert_is(host, expected);
        }
    }

    /// Parses raw `bytes` through both byte-based `TryFrom` impls and asserts
    /// that each conversion yields `expected`.
    fn assert_bytes_parse_as(bytes: &[u8], expected: Is) {
        let attempts = [
            ("&[u8]", Host::try_from(bytes)),
            ("Vec<u8>", Host::try_from(bytes.to_vec())),
        ];
        for (via, result) in attempts {
            let host = result
                .unwrap_or_else(|err| panic!("parsing {bytes:?} via {via}: {err:?}"));
            assert_is(host, expected);
        }
    }

    /// Asserts that `host == other` and `other == host` both evaluate to `expected`.
    fn assert_eq_both_ways<T>(expected: bool, host: &Host, other: &T)
    where
        T: fmt::Display + PartialEq<Host>,
        Host: PartialEq<T>,
    {
        assert_eq!(expected, host == other, "a[{host}] == b[{other}]");
        assert_eq!(expected, other == host, "b[{other}] == a[{host}]");
    }

    #[test]
    fn test_parse_specials() {
        for (input, expected) in [
            ("localhost", Is::Domain("localhost")),
            ("0.0.0.0", Is::Ip("0.0.0.0")),
            ("::1", Is::Ip("::1")),
            ("[::1]", Is::Ip("::1")),
            ("127.0.0.1", Is::Ip("127.0.0.1")),
            ("::", Is::Ip("::")),
            ("[::]", Is::Ip("::")),
        ] {
            assert_str_parses_as(input, expected);
        }
    }

    #[test]
    fn test_parse_bytes_valid() {
        for (bytes, expected) in [
            ("example.com".as_bytes(), Is::Domain("example.com")),
            ("aA1".as_bytes(), Is::Domain("aA1")),
            (&[127, 0, 0, 1], Is::Ip("127.0.0.1")),
            (&[19, 117, 63, 126], Is::Ip("19.117.63.126")),
            (
                &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
                Is::Ip("::1"),
            ),
            (
                &[
                    32, 1, 13, 184, 51, 51, 68, 68, 85, 85, 102, 102, 119, 119, 136, 136,
                ],
                Is::Ip("2001:db8:3333:4444:5555:6666:7777:8888"),
            ),
        ] {
            assert_bytes_parse_as(bytes, expected);
        }
    }

    #[test]
    fn test_parse_valid() {
        for (input, expected) in [
            ("example.com", Is::Domain("example.com")),
            ("www.example.com", Is::Domain("www.example.com")),
            ("a-b-c.com", Is::Domain("a-b-c.com")),
            ("a-b-c.example.com", Is::Domain("a-b-c.example.com")),
            ("a-b-c.example", Is::Domain("a-b-c.example")),
            ("aA1", Is::Domain("aA1")),
            (".example.com", Is::Domain(".example.com")),
            ("example.com.", Is::Domain("example.com.")),
            (".example.com.", Is::Domain(".example.com.")),
            ("127.0.0.1", Is::Ip("127.0.0.1")),
            ("127.00.1", Is::Domain("127.00.1")),
            ("::1", Is::Ip("::1")),
            ("[::1]", Is::Ip("::1")),
            (
                "2001:db8:3333:4444:5555:6666:7777:8888",
                Is::Ip("2001:db8:3333:4444:5555:6666:7777:8888"),
            ),
            (
                "[2001:db8:3333:4444:5555:6666:7777:8888]",
                Is::Ip("2001:db8:3333:4444:5555:6666:7777:8888"),
            ),
            ("::", Is::Ip("::")),
            ("[::]", Is::Ip("::")),
            ("19.117.63.126", Is::Ip("19.117.63.126")),
        ] {
            assert_str_parses_as(input, expected);
        }
    }

    #[test]
    fn test_parse_str_invalid() {
        for input in [
            "",
            ".",
            "-",
            ".-",
            "-.",
            ".-.",
            "[::",
            "::]",
            "@",
            "こんにちは",
            "こんにちは.com",
        ] {
            assert!(Host::try_from(input).is_err(), "parsing {input}");
            assert!(Host::try_from(input.to_owned()).is_err(), "parsing {input}");
        }
    }

    #[test]
    fn compare_host_with_ipv4_bidirectional() {
        for (expected, host, ip) in [
            (true, "127.0.0.1", Ipv4Addr::new(127, 0, 0, 1)),
            (false, "127.0.0.2", Ipv4Addr::new(127, 0, 0, 1)),
            (false, "127.0.0.1", Ipv4Addr::new(127, 0, 0, 2)),
        ] {
            let host: Host = host.parse().unwrap();
            assert_eq_both_ways(expected, &host, &ip);
        }
    }

    #[test]
    fn compare_host_with_ipv6_bidirectional() {
        for (expected, host, ip) in [
            (true, "::1", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            (false, "::2", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
            (false, "::1", Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2)),
        ] {
            let host: Host = host.parse().unwrap();
            assert_eq_both_ways(expected, &host, &ip);
        }
    }

    #[test]
    fn compare_host_with_ip_bidirectional() {
        for (expected, host, ip) in [
            (true, "127.0.0.1", IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
            (false, "127.0.0.2", IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
            (false, "127.0.0.1", IpAddr::V4(Ipv4Addr::new(127, 0, 0, 2))),
            (false, "::2", IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))),
        ] {
            let host: Host = host.parse().unwrap();
            assert_eq_both_ways(expected, &host, &ip);
        }
    }
}