uni_addr/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(clippy::must_use_candidate)]
3
4use std::borrow::Cow;
5use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
6use std::str::FromStr;
7use std::sync::Arc;
8use std::{fmt, io};
9
10#[cfg(unix)]
11pub mod unix;
12
/// The prefix for Unix domain socket (UDS) URIs.
///
/// - `unix:///path/to/socket` for a pathname socket address.
/// - `unix://@abstract.unix.socket` for an abstract socket address.
///
/// [`UniAddr::new`] strips this prefix and parses the remainder as a UDS
/// address. On non-Unix platforms, an input starting with this prefix is
/// rejected with [`ParseError::Unsupported`].
pub const UNIX_URI_PREFIX: &str = "unix://";
18
19wrapper_lite::wrapper!(
20    #[wrapper_impl(Debug)]
21    #[wrapper_impl(Display)]
22    #[wrapper_impl(AsRef)]
23    #[wrapper_impl(Deref)]
24    #[repr(align(cache))]
25    #[derive(Clone, PartialEq, Eq, Hash)]
26    /// A unified address type that can represent:
27    ///
28    /// - [`std::net::SocketAddr`]
29    /// - [`unix::SocketAddr`] (a wrapper over
30    ///   [`std::os::unix::net::SocketAddr`])
31    /// - A host name with port. See [`ToSocketAddrs`].
32    ///
33    /// # Parsing Behaviour
34    ///
35    /// - Checks if the address started with [`UNIX_URI_PREFIX`]: parse as a UDS
36    ///   address.
37    /// - Checks if the address is started with a alphabetic character (a-z,
38    ///   A-Z): treat as a host name. Notes that we will not validate if the
39    ///   host name is valid.
40    /// - Tries to parse as a network socket address.
41    /// - Otherwise, treats the input as a host name.
42    pub struct UniAddr(UniAddrInner);
43);
44
45impl From<SocketAddr> for UniAddr {
46    fn from(addr: SocketAddr) -> Self {
47        UniAddr::const_from(UniAddrInner::Inet(addr))
48    }
49}
50
51#[cfg(unix)]
52impl From<std::os::unix::net::SocketAddr> for UniAddr {
53    fn from(addr: std::os::unix::net::SocketAddr) -> Self {
54        UniAddr::const_from(UniAddrInner::Unix(addr.into()))
55    }
56}
57
#[cfg(all(unix, feature = "feat-tokio"))]
impl From<tokio::net::unix::SocketAddr> for UniAddr {
    /// Converts a Tokio UDS address by way of
    /// [`std::os::unix::net::SocketAddr`].
    fn from(addr: tokio::net::unix::SocketAddr) -> Self {
        // Name the intermediate type explicitly: in
        // `unix::SocketAddr::from(addr.into())` the `Into` target cannot be
        // inferred, because the blanket `impl<T> From<T> for T` competes with
        // the `From<std::os::unix::net::SocketAddr>` impl.
        let std_addr: std::os::unix::net::SocketAddr = addr.into();
        UniAddr::const_from(UniAddrInner::Unix(unix::SocketAddr::from(std_addr)))
    }
}
64
65#[cfg(unix)]
66impl From<crate::unix::SocketAddr> for UniAddr {
67    fn from(addr: crate::unix::SocketAddr) -> Self {
68        UniAddr::const_from(UniAddrInner::Unix(addr))
69    }
70}
71
72impl FromStr for UniAddr {
73    type Err = ParseError;
74
75    fn from_str(addr: &str) -> Result<Self, Self::Err> {
76        Self::new(addr)
77    }
78}
79
#[cfg(feature = "feat-serde")]
impl serde::Serialize for UniAddr {
    /// Serializes the address as its string form (see [`UniAddr::to_str`]).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let rendered = self.to_str();
        serializer.serialize_str(&rendered)
    }
}
89
#[cfg(feature = "feat-serde")]
impl<'de> serde::Deserialize<'de> for UniAddr {
    /// Deserializes an owned string and parses it with [`UniAddr::new`],
    /// reporting parse failures as a custom deserializer error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let text = <String as serde::Deserialize<'de>>::deserialize(deserializer)?;
        UniAddr::new(&text).map_err(<D::Error as serde::de::Error>::custom)
    }
}
99
100impl UniAddr {
101    #[inline]
102    /// Creates a new [`UniAddr`] from its string representation.
103    ///
104    /// # Errors
105    ///
106    /// Not a valid address string.
107    pub fn new(addr: &str) -> Result<Self, ParseError> {
108        if addr.is_empty() {
109            return Err(ParseError::Empty);
110        }
111
112        #[cfg(unix)]
113        if let Some(addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
114            return unix::SocketAddr::new(addr)
115                .map(UniAddrInner::Unix)
116                .map(Self::const_from)
117                .map_err(ParseError::InvalidUDSAddress);
118        }
119
120        #[cfg(not(unix))]
121        if let Some(_addr) = addr.strip_prefix(UNIX_URI_PREFIX) {
122            return Err(ParseError::Unsupported);
123        }
124
125        let Some((host, port)) = addr.rsplit_once(':') else {
126            return Err(ParseError::InvalidPort);
127        };
128
129        let Ok(port) = port.parse::<u16>() else {
130            return Err(ParseError::InvalidPort);
131        };
132
133        // Short-circuit: IPv4 address starts with a digit.
134        if host.chars().next().is_some_and(|c| c.is_ascii_digit()) {
135            return Ipv4Addr::from_str(host)
136                .map(|ip| SocketAddr::V4(SocketAddrV4::new(ip, port)))
137                .map(UniAddrInner::Inet)
138                .map(Self::const_from)
139                .map_err(|_| ParseError::InvalidHost)
140                .or_else(|_| {
141                    // A host name may also start with a digit.
142                    Self::new_host(addr, Some((host, port)))
143                });
144        }
145
146        // Short-circuit: if starts with '[' and ends with ']', may be an IPv6 address
147        // and can never be a host.
148        if let Some(ipv6_addr) = host.strip_prefix('[').and_then(|s| s.strip_suffix(']')) {
149            return Ipv6Addr::from_str(ipv6_addr)
150                .map(|ip| SocketAddr::V6(SocketAddrV6::new(ip, port, 0, 0)))
151                .map(UniAddrInner::Inet)
152                .map(Self::const_from)
153                .map_err(|_| ParseError::InvalidHost);
154        }
155
156        // Fallback: check if is a valid host name.
157        Self::new_host(addr, Some((host, port)))
158    }
159
160    /// Creates a new [`UniAddr`] from a string containing a host name and port,
161    /// like `example.com:8080`.
162    ///
163    /// # Errors
164    ///
165    /// - [`ParseError::InvalidHost`] if the host name is invalid.
166    /// - [`ParseError::InvalidPort`] if the port is invalid.
167    pub fn new_host(addr: &str, parsed: Option<(&str, u16)>) -> Result<Self, ParseError> {
168        let (hostname, _port) = match parsed {
169            Some((hostname, port)) => (hostname, port),
170            None => addr
171                .rsplit_once(':')
172                .ok_or(ParseError::InvalidPort)
173                .and_then(|(hostname, port)| {
174                    let Ok(port) = port.parse::<u16>() else {
175                        return Err(ParseError::InvalidPort);
176                    };
177
178                    Ok((hostname, port))
179                })?,
180        };
181
182        Self::validate_host_name(hostname.as_bytes()).map_err(|()| ParseError::InvalidHost)?;
183
184        Ok(Self::const_from(UniAddrInner::Host(Arc::from(addr))))
185    }
186
187    // https://github.com/rustls/pki-types/blob/b8c04aa6b7a34875e2c4a33edc9b78d31da49523/src/server_name.rs
188    const fn validate_host_name(input: &[u8]) -> Result<(), ()> {
189        enum State {
190            Start,
191            Next,
192            NumericOnly { len: usize },
193            NextAfterNumericOnly,
194            Subsequent { len: usize },
195            Hyphen { len: usize },
196        }
197
198        use State::{Hyphen, Next, NextAfterNumericOnly, NumericOnly, Start, Subsequent};
199
200        /// "Labels must be 63 characters or less."
201        const MAX_LABEL_LENGTH: usize = 63;
202
203        /// <https://devblogs.microsoft.com/oldnewthing/20120412-00/?p=7873>
204        const MAX_NAME_LENGTH: usize = 253;
205
206        let mut state = Start;
207
208        if input.len() > MAX_NAME_LENGTH {
209            return Err(());
210        }
211
212        let mut idx = 0;
213        while idx < input.len() {
214            let ch = input[idx];
215            state = match (state, ch) {
216                (Start | Next | NextAfterNumericOnly | Hyphen { .. }, b'.') => {
217                    return Err(());
218                }
219                (Subsequent { .. }, b'.') => Next,
220                (NumericOnly { .. }, b'.') => NextAfterNumericOnly,
221                (Subsequent { len } | NumericOnly { len } | Hyphen { len }, _)
222                    if len >= MAX_LABEL_LENGTH =>
223                {
224                    return Err(());
225                }
226                (Start | Next | NextAfterNumericOnly, b'0'..=b'9') => NumericOnly { len: 1 },
227                (NumericOnly { len }, b'0'..=b'9') => NumericOnly { len: len + 1 },
228                (Start | Next | NextAfterNumericOnly, b'a'..=b'z' | b'A'..=b'Z' | b'_') => {
229                    Subsequent { len: 1 }
230                }
231                (Subsequent { len } | NumericOnly { len } | Hyphen { len }, b'-') => {
232                    Hyphen { len: len + 1 }
233                }
234                (
235                    Subsequent { len } | NumericOnly { len } | Hyphen { len },
236                    b'a'..=b'z' | b'A'..=b'Z' | b'_' | b'0'..=b'9',
237                ) => Subsequent { len: len + 1 },
238                _ => return Err(()),
239            };
240            idx += 1;
241        }
242
243        if matches!(
244            state,
245            Start | Hyphen { .. } | NumericOnly { .. } | NextAfterNumericOnly
246        ) {
247            return Err(());
248        }
249
250        Ok(())
251    }
252
253    #[inline]
254    /// Serializes the address to a string.
255    pub fn to_str(&self) -> Cow<'_, str> {
256        self.as_inner().to_str()
257    }
258}
259
260#[non_exhaustive]
261#[derive(Debug, Clone, PartialEq, Eq, Hash)]
262/// See [`UniAddr`].
263///
264/// Generally, you should use [`UniAddr`] instead of this type directly, as
265/// we expose this type only for easier pattern matching. A valid [`UniAddr`]
266/// can be constructed only through [`FromStr`] implementation.
267pub enum UniAddrInner {
268    /// See [`SocketAddr`].
269    Inet(SocketAddr),
270
271    #[cfg(unix)]
272    /// See [`SocketAddr`](crate::unix::SocketAddr).
273    Unix(crate::unix::SocketAddr),
274
275    /// A host name with port. See [`ToSocketAddrs`](std::net::ToSocketAddrs).
276    Host(Arc<str>),
277}
278
279impl fmt::Display for UniAddrInner {
280    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281        self.to_str().fmt(f)
282    }
283}
284
285impl UniAddrInner {
286    #[inline]
287    /// Serializes the address to a string.
288    pub fn to_str(&self) -> Cow<'_, str> {
289        match self {
290            Self::Inet(addr) => addr.to_string().into(),
291            #[cfg(unix)]
292            Self::Unix(addr) => addr
293                .to_os_string_impl(UNIX_URI_PREFIX, "@")
294                .to_string_lossy()
295                .to_string()
296                .into(),
297            Self::Host(host) => Cow::Borrowed(host),
298        }
299    }
300}
301
#[derive(Debug)]
/// Errors that can occur when parsing a [`UniAddr`] from a string.
pub enum ParseError {
    /// The input string is empty.
    Empty,

    /// Invalid or missing host name, or an invalid IPv4 / IPv6 address.
    InvalidHost,

    /// Invalid address format: the `:port` suffix is missing or the port is
    /// not a valid `u16`.
    InvalidPort,

    /// Invalid UDS address; carries the underlying I/O error.
    InvalidUDSAddress(io::Error),

    /// The address type (e.g. a `unix://` URI) is not supported on this
    /// platform.
    Unsupported,
}
320
321impl fmt::Display for ParseError {
322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323        match self {
324            Self::Empty => write!(f, "empty address string"),
325            Self::InvalidHost => write!(f, "invalid or missing host address"),
326            Self::InvalidPort => write!(f, "invalid or missing port"),
327            Self::InvalidUDSAddress(err) => write!(f, "invalid UDS address: {err}"),
328            Self::Unsupported => write!(f, "unsupported address type on this platform"),
329        }
330    }
331}
332
333impl std::error::Error for ParseError {
334    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
335        match self {
336            Self::InvalidUDSAddress(err) => Some(err),
337            _ => None,
338        }
339    }
340}
341
342impl From<ParseError> for io::Error {
343    fn from(value: ParseError) -> Self {
344        io::Error::new(io::ErrorKind::Other, value)
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_socket_addr_new_ipv4() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Inet(std_addr) => {
                assert_eq!(std_addr.ip().to_string(), "127.0.0.1");
                assert_eq!(std_addr.port(), 8080);
            }
            _ => panic!("Got {:?}", addr),
        }
    }

    #[test]
    fn test_socket_addr_new_ipv6() {
        let addr = UniAddr::new("[::1]:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Inet(std_addr) => {
                assert_eq!(std_addr.ip().to_string(), "::1");
                assert_eq!(std_addr.port(), 8080);
            }
            _ => panic!("Got {:?}", addr),
        }
    }

    #[cfg(unix)]
    #[test]
    fn test_socket_addr_new_unix_pathname() {
        let addr = UniAddr::new("unix:///tmp/test.sock").unwrap();

        match addr.as_inner() {
            UniAddrInner::Unix(unix_addr) => {
                assert!(unix_addr.as_pathname().is_some());
            }
            _ => panic!("Got {:?}", addr),
        }
    }

    #[cfg(any(target_os = "android", target_os = "linux"))]
    #[test]
    fn test_socket_addr_new_unix_abstract() {
        #[cfg(target_os = "android")]
        use std::os::android::net::SocketAddrExt;
        #[cfg(target_os = "linux")]
        use std::os::linux::net::SocketAddrExt;

        let addr = UniAddr::new("unix://@test.abstract").unwrap();

        match addr.as_inner() {
            UniAddrInner::Unix(unix_addr) => {
                assert!(unix_addr.as_abstract_name().is_some());
            }
            _ => panic!("Got {:?}", addr),
        }
    }

    #[test]
    fn test_socket_addr_new_host() {
        let addr = UniAddr::new("example.com:8080").unwrap();

        match addr.as_inner() {
            UniAddrInner::Host(host) => {
                assert_eq!(&**host, "example.com:8080");
            }
            _ => panic!("Got {:?}", addr),
        }
    }

    #[test]
    fn test_socket_addr_new_host_starting_with_digit() {
        // Not an IPv4 address, so parsing must fall back to a host name.
        let addr = UniAddr::new("1example.com:80").unwrap();

        match addr.as_inner() {
            UniAddrInner::Host(host) => {
                assert_eq!(&**host, "1example.com:80");
            }
            _ => panic!("Got {:?}", addr),
        }
    }

    #[test]
    fn test_socket_addr_new_invalid() {
        // Invalid format
        assert!(UniAddr::new("invalid").is_err());
        assert!(UniAddr::new("127.0.0.1").is_err()); // Missing port
        assert!(UniAddr::new("example.com:invalid").is_err()); // Invalid port
        assert!(UniAddr::new("127.0.0.1:invalid").is_err()); // Invalid port
    }

    #[test]
    fn test_socket_addr_new_error_variants() {
        assert!(matches!(UniAddr::new(""), Err(ParseError::Empty)));

        // Missing or malformed port.
        assert!(matches!(UniAddr::new("invalid"), Err(ParseError::InvalidPort)));
        assert!(matches!(UniAddr::new("127.0.0.1"), Err(ParseError::InvalidPort)));
        assert!(matches!(UniAddr::new("127.0.0.1:99999"), Err(ParseError::InvalidPort)));
        assert!(matches!(UniAddr::new("example.com:invalid"), Err(ParseError::InvalidPort)));
        assert!(matches!(UniAddr::new("[::1]"), Err(ParseError::InvalidPort)));

        // Malformed host part.
        assert!(matches!(UniAddr::new(":8080"), Err(ParseError::InvalidHost)));
        assert!(matches!(UniAddr::new("1.2.3.4.5:80"), Err(ParseError::InvalidHost)));
        assert!(matches!(UniAddr::new("[not-ipv6]:80"), Err(ParseError::InvalidHost)));
        assert!(matches!(UniAddr::new("-example.com:80"), Err(ParseError::InvalidHost)));
        assert!(matches!(UniAddr::new("example-.com:80"), Err(ParseError::InvalidHost)));
    }

    #[test]
    fn test_socket_addr_new_host_direct() {
        assert!(UniAddr::new_host("example.com:443", None).is_ok());
        assert!(matches!(
            UniAddr::new_host("example.com", None),
            Err(ParseError::InvalidPort)
        ));
        assert!(matches!(
            UniAddr::new_host("example.com:abc", None),
            Err(ParseError::InvalidPort)
        ));
        assert!(matches!(
            UniAddr::new_host("bad host:443", None),
            Err(ParseError::InvalidHost)
        ));
    }

    #[test]
    fn test_socket_addr_from_std_and_from_str() {
        let std_addr = std::net::SocketAddr::from(([127, 0, 0, 1], 8080));

        assert_eq!(UniAddr::from(std_addr), UniAddr::new("127.0.0.1:8080").unwrap());
        assert_eq!("127.0.0.1:8080".parse::<UniAddr>().unwrap(), UniAddr::from(std_addr));
    }

    #[cfg(not(unix))]
    #[test]
    fn test_socket_addr_new_unix_unsupported() {
        // Unix sockets should be unsupported on non-Unix platforms
        let result = UniAddr::new("unix:///tmp/test.sock");

        assert!(matches!(result.unwrap_err(), ParseError::Unsupported));
    }

    #[test]
    fn test_socket_addr_display() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();
        assert_eq!(&addr.to_str(), "127.0.0.1:8080");

        let addr = UniAddr::new("[::1]:8080").unwrap();
        assert_eq!(&addr.to_str(), "[::1]:8080");

        #[cfg(unix)]
        {
            let addr = UniAddr::new("unix:///tmp/test.sock").unwrap();
            assert_eq!(&addr.to_str(), "unix:///tmp/test.sock");

            #[cfg(any(target_os = "android", target_os = "linux"))]
            {
                let addr = UniAddr::new("unix://@test.abstract").unwrap();
                assert_eq!(&addr.to_str(), "unix://@test.abstract");
            }
        }

        let addr = UniAddr::new("example.com:8080").unwrap();
        assert_eq!(&addr.to_str(), "example.com:8080");
    }

    #[test]
    fn test_socket_addr_debug() {
        let addr = UniAddr::new("127.0.0.1:8080").unwrap();
        let debug_str = format!("{:?}", addr);

        assert!(debug_str.contains("127.0.0.1:8080"));
    }

    #[test]
    fn test_parse_error_display_source_and_io_conversion() {
        assert_eq!(ParseError::Empty.to_string(), "empty address string");
        assert_eq!(ParseError::InvalidPort.to_string(), "invalid or missing port");

        assert!(std::error::Error::source(&ParseError::Empty).is_none());
        let uds = ParseError::InvalidUDSAddress(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "bad path",
        ));
        assert!(std::error::Error::source(&uds).is_some());

        let err: std::io::Error = ParseError::InvalidHost.into();
        assert_eq!(err.kind(), std::io::ErrorKind::Other);
        assert_eq!(err.to_string(), "invalid or missing host address");
    }

    #[test]
    fn test_edge_cases() {
        assert!(UniAddr::new("").is_err());
        assert!(UniAddr::new("not-an-address").is_err());
        assert!(UniAddr::new("127.0.0.1:99999").is_err()); // Port too high

        #[cfg(unix)]
        {
            assert!(UniAddr::new("unix://").is_ok()); // Empty path -> unnamed one
            #[cfg(any(target_os = "android", target_os = "linux"))]
            assert!(UniAddr::new("unix://@").is_ok()); // Empty abstract one
        }
    }
}