uni_addr/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::borrow::Cow;
4use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
5use std::str::FromStr;
6use std::sync::Arc;
7use std::{fmt, io};
8
/// Unix domain socket (UDS) address support, compiled only on Unix targets.
#[cfg(unix)]
pub mod unix;
11
/// The URI prefix that marks a Unix domain socket (UDS) address.
///
/// [`UniAddr::new`] strips this prefix and parses the remainder as a UDS
/// address:
///
/// - `unix:///path/to/socket` for a pathname socket address.
/// - `unix://@abstract.unix.socket` for an abstract socket address.
pub const UNIX_URI_PREFIX: &str = "unix://";
17
wrapper_lite::wrapper!(
    #[wrapper_impl(Debug)]
    #[wrapper_impl(Display)]
    #[wrapper_impl(AsRef)]
    #[wrapper_impl(Deref)]
    #[repr(align(cache))]
    #[derive(Clone, PartialEq, Eq, Hash)]
    /// A unified address type that can represent:
    ///
    /// - [`std::net::SocketAddr`]
    /// - [`unix::SocketAddr`] (a wrapper over
    ///   [`std::os::unix::net::SocketAddr`])
    /// - A host name with port. See
    ///   [`ToSocketAddrs`](std::net::ToSocketAddrs).
    ///
    /// # Parsing Behaviour
    ///
    /// [`UniAddr::new`] (and therefore the [`FromStr`] implementation)
    /// applies the following steps in order:
    ///
    /// - An empty input is rejected.
    /// - If the address starts with [`UNIX_URI_PREFIX`], the remainder is
    ///   parsed as a UDS address (unsupported on non-Unix platforms).
    /// - Otherwise the input is split at the last `:`; the part after it
    ///   must be a valid `u16` port.
    /// - If the host part starts with an ASCII digit, it is first tried as
    ///   an IPv4 address and, failing that, validated as a host name (a host
    ///   name may also start with a digit).
    /// - If the host part is enclosed in `[` and `]`, it must be a valid
    ///   IPv6 address; it is never treated as a host name.
    /// - Otherwise, the host part is validated as a host name.
    pub struct UniAddr(UniAddrInner);
);
43
44impl From<SocketAddr> for UniAddr {
45    fn from(addr: SocketAddr) -> Self {
46        UniAddr::const_from(UniAddrInner::Inet(addr))
47    }
48}
49
50#[cfg(unix)]
51impl From<std::os::unix::net::SocketAddr> for UniAddr {
52    fn from(addr: std::os::unix::net::SocketAddr) -> Self {
53        UniAddr::const_from(UniAddrInner::Unix(addr.into()))
54    }
55}
56
/// Wraps a Tokio Unix domain socket address into a [`UniAddr`].
#[cfg(all(unix, feature = "feat-tokio"))]
impl From<tokio::net::unix::SocketAddr> for UniAddr {
    fn from(tokio_addr: tokio::net::unix::SocketAddr) -> Self {
        // Two hops: Tokio address -> std address -> this crate's wrapper.
        let std_addr: std::os::unix::net::SocketAddr = tokio_addr.into();
        let wrapped = unix::SocketAddr::from(std_addr);

        Self::const_from(UniAddrInner::Unix(wrapped))
    }
}
63
64#[cfg(unix)]
65impl From<crate::unix::SocketAddr> for UniAddr {
66    fn from(addr: crate::unix::SocketAddr) -> Self {
67        UniAddr::const_from(UniAddrInner::Unix(addr))
68    }
69}
70
71impl FromStr for UniAddr {
72    type Err = ParseError;
73
74    fn from_str(addr: &str) -> Result<Self, Self::Err> {
75        Self::new(addr)
76    }
77}
78
/// Serializes a [`UniAddr`] as the string produced by [`UniAddr::to_str`].
#[cfg(feature = "feat-serde")]
impl serde::Serialize for UniAddr {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let repr = self.to_str();

        serializer.serialize_str(repr.as_ref())
    }
}
88
/// Deserializes a [`UniAddr`] from its string representation.
///
/// The string is parsed with [`UniAddr::new`]; a parse failure is reported
/// as a custom deserialization error.
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for UniAddr {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor that parses any string the deserializer hands out.
        ///
        /// Deserializing into `&str` only succeeds when the deserializer can
        /// lend a string borrowed from the input. Implementing `visit_str`
        /// accepts borrowed, transient and owned strings alike (the default
        /// `visit_borrowed_str` / `visit_string` forward to it), without an
        /// intermediate allocation.
        struct UniAddrVisitor;

        impl<'v> serde::de::Visitor<'v> for UniAddrVisitor {
            type Value = UniAddr;

            fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("a socket address string")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                UniAddr::new(value).map_err(E::custom)
            }
        }

        deserializer.deserialize_str(UniAddrVisitor)
    }
}
98
99impl UniAddr {
100    #[inline]
101    /// Creates a new [`UniAddr`] from its string representation.
102    pub fn new(addr: &str) -> Result<Self, ParseError> {
103        if addr.is_empty() {
104            return Err(ParseError::Empty);
105        }
106
107        if let Some(addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
108            #[cfg(unix)]
109            {
110                return unix::SocketAddr::new(addr)
111                    .map(UniAddrInner::Unix)
112                    .map(Self::const_from)
113                    .map_err(ParseError::InvalidUDSAddress);
114            }
115
116            #[cfg(not(unix))]
117            {
118                return Err(ParseError::Unsupported);
119            }
120        }
121
122        let Some((host, port)) = addr.rsplit_once(':') else {
123            return Err(ParseError::InvalidPort);
124        };
125
126        let Ok(port) = port.parse::<u16>() else {
127            return Err(ParseError::InvalidPort);
128        };
129
130        // Short-circuit: IPv4 address starts with a digit.
131        if host.chars().next().is_some_and(|c| c.is_ascii_digit()) {
132            return Ipv4Addr::from_str(host)
133                .map(|ip| SocketAddr::V4(SocketAddrV4::new(ip, port)))
134                .map(UniAddrInner::Inet)
135                .map(Self::const_from)
136                .map_err(|_| ParseError::InvalidHost)
137                .or_else(|_| {
138                    // A host name may also start with a digit.
139                    Self::new_host(addr, Some((host, port)))
140                });
141        }
142
143        // Short-circuit: if starts with '[' and ends with ']', may be an IPv6 address
144        // and can never be a host.
145        if let Some(ipv6_addr) = host.strip_prefix('[').and_then(|s| s.strip_suffix(']')) {
146            return Ipv6Addr::from_str(ipv6_addr)
147                .map(|ip| SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))
148                .map(UniAddrInner::Inet)
149                .map(Self::const_from)
150                .map_err(|_| ParseError::InvalidHost);
151        }
152
153        // Fallback: check if is a valid host name.
154        Self::new_host(addr, Some((host, port)))
155    }
156
157    /// Creates a new [`UniAddr`] from a string containing a host name and port,
158    /// like `example.com:8080`.
159    pub fn new_host(addr: &str, parsed: Option<(&str, u16)>) -> Result<Self, ParseError> {
160        let (hostname, _port) = match parsed {
161            Some((hostname, port)) => (hostname, port),
162            None => addr
163                .rsplit_once(':')
164                .ok_or(ParseError::InvalidPort)
165                .and_then(|(hostname, port)| {
166                    let Ok(port) = port.parse::<u16>() else {
167                        return Err(ParseError::InvalidPort);
168                    };
169
170                    Ok((hostname, port))
171                })?,
172        };
173
174        Self::validate_host_name(hostname.as_bytes()).map_err(|_| ParseError::InvalidHost)?;
175
176        Ok(Self::const_from(UniAddrInner::Host(Arc::from(addr))))
177    }
178
179    // https://github.com/rustls/pki-types/blob/b8c04aa6b7a34875e2c4a33edc9b78d31da49523/src/server_name.rs
180    const fn validate_host_name(input: &[u8]) -> Result<(), ()> {
181        enum State {
182            Start,
183            Next,
184            NumericOnly { len: usize },
185            NextAfterNumericOnly,
186            Subsequent { len: usize },
187            Hyphen { len: usize },
188        }
189
190        use State::*;
191
192        let mut state = Start;
193
194        /// "Labels must be 63 characters or less."
195        const MAX_LABEL_LENGTH: usize = 63;
196
197        /// https://devblogs.microsoft.com/oldnewthing/20120412-00/?p=7873
198        const MAX_NAME_LENGTH: usize = 253;
199
200        if input.len() > MAX_NAME_LENGTH {
201            return Err(());
202        }
203
204        let mut idx = 0;
205        while idx < input.len() {
206            let ch = input[idx];
207            state = match (state, ch) {
208                (Start | Next | NextAfterNumericOnly | Hyphen { .. }, b'.') => {
209                    return Err(());
210                }
211                (Subsequent { .. }, b'.') => Next,
212                (NumericOnly { .. }, b'.') => NextAfterNumericOnly,
213                (Subsequent { len } | NumericOnly { len } | Hyphen { len }, _)
214                    if len >= MAX_LABEL_LENGTH =>
215                {
216                    return Err(());
217                }
218                (Start | Next | NextAfterNumericOnly, b'0'..=b'9') => NumericOnly { len: 1 },
219                (NumericOnly { len }, b'0'..=b'9') => NumericOnly { len: len + 1 },
220                (Start | Next | NextAfterNumericOnly, b'a'..=b'z' | b'A'..=b'Z' | b'_') => {
221                    Subsequent { len: 1 }
222                }
223                (Subsequent { len } | NumericOnly { len } | Hyphen { len }, b'-') => {
224                    Hyphen { len: len + 1 }
225                }
226                (
227                    Subsequent { len } | NumericOnly { len } | Hyphen { len },
228                    b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'0'..=b'9',
229                ) => Subsequent { len: len + 1 },
230                _ => return Err(()),
231            };
232            idx += 1;
233        }
234
235        if matches!(
236            state,
237            Start | Hyphen { .. } | NumericOnly { .. } | NextAfterNumericOnly
238        ) {
239            return Err(());
240        }
241
242        Ok(())
243    }
244
245    #[inline]
246    /// Serializes the address to a string.
247    pub fn to_str(&self) -> Cow<'_, str> {
248        self.as_inner().to_str()
249    }
250}
251
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
/// The concrete address kinds a [`UniAddr`] can hold. See [`UniAddr`].
///
/// Generally, you should use [`UniAddr`] instead of this type directly, as
/// we expose this type only for easier pattern matching. A [`UniAddr`] is
/// normally obtained through [`UniAddr::new`] / its [`FromStr`]
/// implementation, or through one of its `From` conversions.
pub enum UniAddrInner {
    /// A network (IPv4 / IPv6) socket address. See
    /// [`SocketAddr`](std::net::SocketAddr).
    Inet(std::net::SocketAddr),

    #[cfg(unix)]
    /// A Unix domain socket address (Unix targets only). See
    /// [`SocketAddr`](crate::unix::SocketAddr).
    Unix(crate::unix::SocketAddr),

    /// A host name with port, stored as the full `host:port` string. See
    /// [`ToSocketAddrs`](std::net::ToSocketAddrs).
    Host(Arc<str>),
}
270
271impl fmt::Display for UniAddrInner {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        self.to_str().fmt(f)
274    }
275}
276
277impl UniAddrInner {
278    #[inline]
279    /// Serializes the address to a string.
280    pub fn to_str(&self) -> Cow<'_, str> {
281        match self {
282            Self::Inet(addr) => addr.to_string().into(),
283            Self::Unix(addr) => addr
284                .to_os_string_impl(UNIX_URI_PREFIX, "@")
285                .to_string_lossy()
286                .to_string()
287                .into(),
288            Self::Host(host) => Cow::Borrowed(host),
289        }
290    }
291}
292
#[derive(Debug)]
/// Errors that can occur when parsing a [`UniAddr`] from a string.
pub enum ParseError {
    /// The input string is empty.
    Empty,

    /// Invalid or missing host name, or an invalid IPv4 / IPv6 address.
    InvalidHost,

    /// Invalid address format: the `:` separator is missing, or the port is
    /// missing or not a valid `u16`.
    InvalidPort,

    /// Invalid UDS address format; carries the underlying I/O error.
    InvalidUDSAddress(io::Error),

    /// Unsupported address type on this platform (e.g. a `unix://` address
    /// on a non-Unix target).
    Unsupported,
}
311
312impl fmt::Display for ParseError {
313    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
314        match self {
315            Self::Empty => write!(f, "empty address string"),
316            Self::InvalidHost => write!(f, "invalid or missing host address"),
317            Self::InvalidPort => write!(f, "invalid or missing port"),
318            Self::InvalidUDSAddress(err) => write!(f, "invalid UDS address: {}", err),
319            Self::Unsupported => write!(f, "unsupported address type on this platform"),
320        }
321    }
322}
323
324impl std::error::Error for ParseError {
325    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
326        match self {
327            Self::InvalidUDSAddress(err) => Some(err),
328            _ => None,
329        }
330    }
331}
332
333impl From<ParseError> for io::Error {
334    fn from(value: ParseError) -> Self {
335        io::Error::new(io::ErrorKind::Other, value)
336    }
337}
338
#[cfg(test)]
mod tests {
    // `SocketAddrExt` lives in an OS-specific module: `std::os::linux` only
    // exists on Linux and `std::os::android` only on Android, so importing
    // either under a plain `cfg(unix)` breaks other Unix targets.
    #[cfg(target_os = "android")]
    use std::os::android::net::SocketAddrExt;
    #[cfg(target_os = "linux")]
    use std::os::linux::net::SocketAddrExt;

    use super::*;

    /// An `ip:port` string with a dotted-quad host parses as an IPv4 address.
    #[test]
    fn test_socket_addr_new_ipv4() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Inet(std_addr) => {
                assert_eq!(std_addr.ip().to_string(), "127.0.0.1");
                assert_eq!(std_addr.port(), 8080);
            }
            _ => panic!("Expected Inet address, got {:?}", addr),
        }
    }

    /// A bracketed host parses as an IPv6 address.
    #[test]
    fn test_socket_addr_new_ipv6() {
        let addr = UniAddr::new("[::1]:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Inet(std_addr) => {
                assert_eq!(std_addr.ip().to_string(), "::1");
                assert_eq!(std_addr.port(), 8080);
            }
            // The wildcard is needed on every platform: the `Host` variant
            // always exists, so the match is not exhaustive without it.
            _ => unreachable!(),
        }
    }

    /// A `unix://` URI followed by a path parses as a pathname UDS address.
    #[cfg(unix)]
    #[test]
    fn test_socket_addr_new_unix_pathname() {
        let addr = UniAddr::new("unix:///tmp/test.sock").unwrap();

        match addr.as_inner() {
            UniAddrInner::Unix(unix_addr) => {
                assert!(unix_addr.as_pathname().is_some());
            }
            _ => unreachable!(),
        }
    }

    /// A `unix://@name` URI parses as an abstract UDS address. Gated on
    /// Linux / Android like the abstract case in `test_edge_cases`.
    #[cfg(any(target_os = "android", target_os = "linux"))]
    #[test]
    fn test_socket_addr_new_unix_abstract() {
        let addr = UniAddr::new("unix://@test.abstract").unwrap();

        match addr.as_inner() {
            UniAddrInner::Unix(unix_addr) => {
                assert!(unix_addr.as_abstract_name().is_some());
            }
            _ => unreachable!(),
        }
    }

    /// A `name:port` string is kept verbatim as a host address.
    #[test]
    fn test_socket_addr_new_host() {
        let addr = UniAddr::new("example.com:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Host(host) => {
                assert_eq!(&**host, "example.com:8080");
            }
            _ => unreachable!(),
        }
    }

    /// Inputs without a usable port are rejected.
    #[test]
    fn test_socket_addr_new_invalid() {
        // Invalid format
        assert!(UniAddr::new("invalid").is_err());
        assert!(UniAddr::new("127.0.0.1").is_err()); // Missing port
        assert!(UniAddr::new("example.com:invalid").is_err()); // Invalid port
        assert!(UniAddr::new("127.0.0.1:invalid").is_err()); // Invalid port
    }

    /// `unix://` URIs are reported as unsupported off Unix.
    #[cfg(not(unix))]
    #[test]
    fn test_socket_addr_new_unix_unsupported() {
        // Unix sockets should be unsupported on non-Unix platforms
        let result = UniAddr::new("unix:///tmp/test.sock");

        assert!(matches!(result.unwrap_err(), ParseError::Unsupported));
    }

    /// `to_str` round-trips every supported address form.
    #[test]
    fn test_socket_addr_display() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();
        assert_eq!(&addr.to_str(), "127.0.0.1:8080");

        let addr = UniAddr::new("[::1]:8080").unwrap();
        assert_eq!(&addr.to_str(), "[::1]:8080");

        #[cfg(unix)]
        {
            let addr = UniAddr::new("unix:///tmp/test.sock").unwrap();
            assert_eq!(&addr.to_str(), "unix:///tmp/test.sock");
        }

        // Abstract addresses are gated on Linux / Android, as in
        // `test_edge_cases`.
        #[cfg(any(target_os = "android", target_os = "linux"))]
        {
            let addr = UniAddr::new("unix://@test.abstract").unwrap();
            assert_eq!(&addr.to_str(), "unix://@test.abstract");
        }

        let addr = UniAddr::new("example.com:8080").unwrap();
        assert_eq!(&addr.to_str(), "example.com:8080");
    }

    /// The `Debug` output contains the textual address.
    #[test]
    fn test_socket_addr_debug() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();
        let debug_str = format!("{:?}", addr);

        assert!(debug_str.contains("127.0.0.1:8080"));
    }

    /// Boundary inputs: empty string, missing port, out-of-range port and
    /// empty UDS addresses.
    #[test]
    fn test_edge_cases() {
        assert!(UniAddr::new("").is_err());
        assert!(UniAddr::new("not-an-address").is_err());
        assert!(UniAddr::new("127.0.0.1:99999").is_err()); // Port too high

        #[cfg(unix)]
        {
            assert!(UniAddr::new("unix://").is_ok()); // Empty path -> unnamed one
            #[cfg(any(target_os = "android", target_os = "linux"))]
            assert!(UniAddr::new("unix://@").is_ok()); // Empty abstract one
        }
    }
}