rama_net/address/
socket_address.rs

1use crate::address::parse_utils::try_to_parse_str_to_ip;
2use rama_core::error::{ErrorContext, OpaqueError};
3#[cfg(feature = "http")]
4use rama_http_types::HeaderValue;
5use std::fmt;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
7use std::str::FromStr;
8
/// An [`IpAddr`] with an associated port
///
/// The textual form is `ip:port` for IPv4 and `[ip]:port` for IPv6; IPv6
/// addresses must be bracketed when parsed from text.
///
/// NOTE: the derived `PartialOrd`/`Ord` compare `ip_addr` first and `port`
/// second (derive follows field declaration order), so reordering the fields
/// below changes how addresses sort.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SocketAddress {
    // The IP address, either IPv4 or IPv6.
    ip_addr: IpAddr,
    // The port number.
    port: u16,
}
15
16impl SocketAddress {
17    /// creates a new [`SocketAddress`]
18    pub const fn new(ip_addr: IpAddr, port: u16) -> Self {
19        SocketAddress { ip_addr, port }
20    }
21
22    /// creates a new local ipv4 [`SocketAddress`] for the given port
23    ///
24    /// # Example
25    ///
26    /// ```
27    /// use rama_net::address::SocketAddress;
28    ///
29    /// let addr = SocketAddress::local_ipv4(8080);
30    /// assert_eq!("127.0.0.1:8080", addr.to_string());
31    /// ```
32    pub const fn local_ipv4(port: u16) -> Self {
33        SocketAddress {
34            ip_addr: IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
35            port,
36        }
37    }
38
39    /// creates a new local ipv6 [`SocketAddress`] for the given port.
40    ///
41    /// # Example
42    ///
43    /// ```
44    /// use rama_net::address::SocketAddress;
45    ///
46    /// let addr = SocketAddress::local_ipv6(8080);
47    /// assert_eq!("[::1]:8080", addr.to_string());
48    /// ```
49    pub const fn local_ipv6(port: u16) -> Self {
50        SocketAddress {
51            ip_addr: IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
52            port,
53        }
54    }
55
56    /// creates a new default ipv4 [`SocketAddress`] for the given port
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use rama_net::address::SocketAddress;
62    ///
63    /// let addr = SocketAddress::default_ipv4(8080);
64    /// assert_eq!("0.0.0.0:8080", addr.to_string());
65    /// ```
66    pub const fn default_ipv4(port: u16) -> Self {
67        SocketAddress {
68            ip_addr: IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)),
69            port,
70        }
71    }
72
73    /// creates a new default ipv6 [`SocketAddress`] for the given port.
74    ///
75    /// # Example
76    ///
77    /// ```
78    /// use rama_net::address::SocketAddress;
79    ///
80    /// let addr = SocketAddress::default_ipv6(8080);
81    /// assert_eq!("[::]:8080", addr.to_string());
82    /// ```
83    pub const fn default_ipv6(port: u16) -> Self {
84        SocketAddress {
85            ip_addr: IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)),
86            port,
87        }
88    }
89
90    /// creates a new broadcast ipv4 [`SocketAddress`] for the given port
91    ///
92    /// # Example
93    ///
94    /// ```
95    /// use rama_net::address::SocketAddress;
96    ///
97    /// let addr = SocketAddress::broadcast_ipv4(8080);
98    /// assert_eq!("255.255.255.255:8080", addr.to_string());
99    /// ```
100    pub const fn broadcast_ipv4(port: u16) -> Self {
101        SocketAddress {
102            ip_addr: IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)),
103            port,
104        }
105    }
106
107    /// Gets the [`IpAddr`] reference.
108    pub fn ip_addr(&self) -> &IpAddr {
109        &self.ip_addr
110    }
111
112    /// Consumes the [`SocketAddress`] and returns the [`IpAddr`].
113    pub fn into_ip_addr(self) -> IpAddr {
114        self.ip_addr
115    }
116
117    /// Gets the port
118    pub fn port(&self) -> u16 {
119        self.port
120    }
121
122    /// Consume self into its parts: `(ip_addr, port)`
123    pub fn into_parts(self) -> (IpAddr, u16) {
124        (self.ip_addr, self.port)
125    }
126}
127
128impl From<SocketAddr> for SocketAddress {
129    fn from(addr: SocketAddr) -> Self {
130        SocketAddress {
131            ip_addr: addr.ip(),
132            port: addr.port(),
133        }
134    }
135}
136
137impl From<&SocketAddr> for SocketAddress {
138    fn from(addr: &SocketAddr) -> Self {
139        SocketAddress {
140            ip_addr: addr.ip(),
141            port: addr.port(),
142        }
143    }
144}
145
146impl From<SocketAddrV4> for SocketAddress {
147    fn from(value: SocketAddrV4) -> Self {
148        SocketAddress {
149            ip_addr: (*value.ip()).into(),
150            port: value.port(),
151        }
152    }
153}
154
155impl From<SocketAddrV6> for SocketAddress {
156    fn from(value: SocketAddrV6) -> Self {
157        SocketAddress {
158            ip_addr: (*value.ip()).into(),
159            port: value.port(),
160        }
161    }
162}
163
164impl From<SocketAddress> for SocketAddr {
165    fn from(addr: SocketAddress) -> Self {
166        SocketAddr::new(addr.ip_addr, addr.port)
167    }
168}
169
170impl From<(IpAddr, u16)> for SocketAddress {
171    #[inline]
172    fn from((ip_addr, port): (IpAddr, u16)) -> Self {
173        Self { ip_addr, port }
174    }
175}
176
177impl From<(Ipv4Addr, u16)> for SocketAddress {
178    #[inline]
179    fn from((ip, port): (Ipv4Addr, u16)) -> Self {
180        Self {
181            ip_addr: ip.into(),
182            port,
183        }
184    }
185}
186
187impl From<([u8; 4], u16)> for SocketAddress {
188    #[inline]
189    fn from((ip, port): ([u8; 4], u16)) -> Self {
190        let ip: IpAddr = ip.into();
191        (ip, port).into()
192    }
193}
194
195impl From<(Ipv6Addr, u16)> for SocketAddress {
196    #[inline]
197    fn from((ip, port): (Ipv6Addr, u16)) -> Self {
198        Self {
199            ip_addr: ip.into(),
200            port,
201        }
202    }
203}
204
205impl From<([u8; 16], u16)> for SocketAddress {
206    #[inline]
207    fn from((ip, port): ([u8; 16], u16)) -> Self {
208        let ip: IpAddr = ip.into();
209        (ip, port).into()
210    }
211}
212
213impl fmt::Display for SocketAddress {
214    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
215        match &self.ip_addr {
216            IpAddr::V4(ip) => write!(f, "{}:{}", ip, self.port),
217            IpAddr::V6(ip) => write!(f, "[{}]:{}", ip, self.port),
218        }
219    }
220}
221
222impl FromStr for SocketAddress {
223    type Err = OpaqueError;
224
225    fn from_str(s: &str) -> Result<Self, Self::Err> {
226        SocketAddress::try_from(s)
227    }
228}
229
230impl TryFrom<String> for SocketAddress {
231    type Error = OpaqueError;
232
233    fn try_from(s: String) -> Result<Self, Self::Error> {
234        s.as_str().try_into()
235    }
236}
237
238impl TryFrom<&String> for SocketAddress {
239    type Error = OpaqueError;
240
241    fn try_from(value: &String) -> Result<Self, Self::Error> {
242        value.as_str().try_into()
243    }
244}
245
246impl TryFrom<&str> for SocketAddress {
247    type Error = OpaqueError;
248
249    fn try_from(s: &str) -> Result<Self, Self::Error> {
250        let (ip_addr, port) = crate::address::parse_utils::split_port_from_str(s)?;
251        let ip_addr =
252            try_to_parse_str_to_ip(ip_addr).context("parse ip address from socket address")?;
253        match ip_addr {
254            IpAddr::V6(_) if !s.starts_with('[') => Err(OpaqueError::from_display(
255                "missing brackets for IPv6 address with port",
256            )),
257            _ => Ok(SocketAddress { ip_addr, port }),
258        }
259    }
260}
261
#[cfg(feature = "http")]
impl TryFrom<HeaderValue> for SocketAddress {
    type Error = OpaqueError;

    /// Parses an owned header value.
    ///
    /// Fails if `HeaderValue::to_str` rejects the value, or if the resulting text is not
    /// a valid socket address.
    fn try_from(header: HeaderValue) -> Result<Self, Self::Error> {
        header.to_str().context("convert header to str")?.try_into()
    }
}
270
#[cfg(feature = "http")]
impl TryFrom<&HeaderValue> for SocketAddress {
    type Error = OpaqueError;

    /// Parses a borrowed header value by converting it to `&str` first.
    fn try_from(header: &HeaderValue) -> Result<Self, Self::Error> {
        let text = header.to_str().context("convert header to str")?;
        Self::try_from(text)
    }
}
279
280impl TryFrom<Vec<u8>> for SocketAddress {
281    type Error = OpaqueError;
282
283    fn try_from(bytes: Vec<u8>) -> Result<Self, Self::Error> {
284        Self::try_from(bytes.as_slice())
285    }
286}
287
288impl TryFrom<&[u8]> for SocketAddress {
289    type Error = OpaqueError;
290
291    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
292        let s = std::str::from_utf8(bytes).context("parse sock address from bytes")?;
293        s.try_into()
294    }
295}
296
297impl serde::Serialize for SocketAddress {
298    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
299    where
300        S: serde::Serializer,
301    {
302        let address = self.to_string();
303        address.serialize(serializer)
304    }
305}
306
307impl<'de> serde::Deserialize<'de> for SocketAddress {
308    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
309    where
310        D: serde::Deserializer<'de>,
311    {
312        let s = <std::borrow::Cow<'de, str>>::deserialize(deserializer)?;
313        s.parse().map_err(serde::de::Error::custom)
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `parsed` holds `expected_ip` (compared via its `Display` text) and
    /// `expected_port`. `input` only appears in the failure messages.
    fn assert_parsed(input: &str, parsed: SocketAddress, expected_ip: &str, expected_port: u16) {
        assert_eq!(
            parsed.ip_addr().to_string(),
            expected_ip,
            "parsing: {}",
            input
        );
        assert_eq!(parsed.port(), expected_port, "parsing: {}", input);
    }

    #[test]
    fn test_parse_valid() {
        let cases = [
            ("[::1]:80", "::1", 80),
            ("127.0.0.1:80", "127.0.0.1", 80),
            (
                "[2001:db8:3333:4444:5555:6666:7777:8888]:80",
                "2001:db8:3333:4444:5555:6666:7777:8888",
                80,
            ),
        ];

        for (input, ip, port) in cases {
            let msg = format!("parsing '{}'", input);

            // Every parsing entry point must agree on the result.
            assert_parsed(input, input.parse().expect(&msg), ip, port);
            assert_parsed(input, input.try_into().expect(&msg), ip, port);
            assert_parsed(input, input.to_owned().try_into().expect(&msg), ip, port);
            assert_parsed(input, input.as_bytes().try_into().expect(&msg), ip, port);
            assert_parsed(
                input,
                input.as_bytes().to_vec().try_into().expect(&msg),
                ip,
                port,
            );
        }
    }

    #[test]
    fn test_parse_invalid() {
        const INVALID: &[&str] = &[
            "",
            "-",
            ".",
            ":",
            ":80",
            "-.",
            ".-",
            "::1",
            "127.0.0.1",
            "[::1]",
            "2001:db8:3333:4444:5555:6666:7777:8888",
            "[2001:db8:3333:4444:5555:6666:7777:8888]",
            "example.com",
            "example.com:",
            "example.com:-1",
            "example.com:999999",
            "example.com:80",
            "example:com",
            "[127.0.0.1]:80",
            "2001:db8:3333:4444:5555:6666:7777:8888:80",
        ];

        for &input in INVALID {
            let msg = format!("parsing '{}'", input);

            // Every parsing entry point must reject the input.
            assert!(input.parse::<SocketAddress>().is_err(), "{}", msg);
            assert!(SocketAddress::try_from(input).is_err(), "{}", msg);
            assert!(SocketAddress::try_from(input.to_owned()).is_err(), "{}", msg);
            assert!(SocketAddress::try_from(input.as_bytes()).is_err(), "{}", msg);
            assert!(
                SocketAddress::try_from(input.as_bytes().to_vec()).is_err(),
                "{}",
                msg
            );
        }
    }

    #[test]
    fn test_parse_display() {
        let cases = [("[::1]:80", "[::1]:80"), ("127.0.0.1:80", "127.0.0.1:80")];

        for (input, expected) in cases {
            let msg = format!("parsing '{}'", input);
            let parsed = input.parse::<SocketAddress>().expect(&msg);
            assert_eq!(parsed.to_string(), expected, "{}", msg);
        }
    }
}