use anyhow::{Result, anyhow, bail};
use std::net::SocketAddr;
use std::path::PathBuf;
use std::str::FromStr;
/// Kind of socket a [`BoundAddr`]-style address describes: IP (TCP/UDP) or
/// one of the Unix-domain socket types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SocketKind {
    /// TCP; URL scheme `tcp`, IP address family.
    TcpStream,
    /// UDP; URL scheme `udp`, IP address family.
    UdpDgram,
    /// Unix-domain stream socket; URL scheme `unix-stream`.
    UnixStream,
    /// Unix-domain datagram socket; URL scheme `unix-dgram`.
    UnixDgram,
    /// Unix-domain seqpacket socket; URL scheme `unix-seqpacket`.
    UnixSeqpacket,
}

impl SocketKind {
    /// Returns `true` for the byte-stream kinds: `TcpStream` and `UnixStream`.
    pub fn is_byte_stream(self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::TcpStream | Self::UnixStream => true,
            Self::UdpDgram | Self::UnixDgram | Self::UnixSeqpacket => false,
        }
    }

    /// Returns `true` for every kind that is not a byte stream
    /// (`UdpDgram`, `UnixDgram` and `UnixSeqpacket`).
    ///
    /// Defined as the exact complement of [`SocketKind::is_byte_stream`].
    pub fn is_datagram_stream(self) -> bool {
        !Self::is_byte_stream(self)
    }

    /// Address family used by this kind: `Inet` for TCP/UDP, `Unix` otherwise.
    // No non-test caller of `family` is visible in this file; the allow keeps
    // the build warning-free if that remains the case.
    #[allow(dead_code)]
    pub fn family(self) -> AddrFamily {
        match self {
            Self::TcpStream | Self::UdpDgram => AddrFamily::Inet,
            Self::UnixStream | Self::UnixDgram | Self::UnixSeqpacket => AddrFamily::Unix,
        }
    }

    /// URL scheme (without the `:` / `://` separator) used in the textual
    /// address form.
    pub fn scheme(self) -> &'static str {
        match self {
            Self::TcpStream => "tcp",
            Self::UdpDgram => "udp",
            Self::UnixStream => "unix-stream",
            Self::UnixDgram => "unix-dgram",
            Self::UnixSeqpacket => "unix-seqpacket",
        }
    }
}

/// Address family a [`SocketKind`] belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
pub enum AddrFamily {
    /// IP (v4 or v6) socket addresses.
    Inet,
    /// Unix-domain socket paths.
    Unix,
}
/// Where a socket is bound: an IP socket address or a Unix-domain socket path.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AddrLocation {
/// IPv4/IPv6 address and port (the part after `tcp://` / `udp://`).
Inet(SocketAddr),
/// Unix-domain socket path (the text after a `unix-*:` prefix).
Unix(PathBuf),
}
/// A socket address: the socket kind plus its location.
///
/// The textual form is produced by `BoundAddr::to_url` and read back by
/// `BoundAddr::parse`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BoundAddr {
/// Socket kind (transport plus address family).
pub kind: SocketKind,
/// Address or path. `parse` always pairs TCP/UDP kinds with
/// `AddrLocation::Inet` and Unix kinds with `AddrLocation::Unix`; the
/// fields are public, so that pairing is not enforced by the type.
pub location: AddrLocation,
}
impl std::fmt::Display for BoundAddr {
    /// Writes the same text as [`BoundAddr::to_url`].
    ///
    /// Goes through `Formatter::pad` so width, fill, alignment and precision
    /// flags (e.g. `{:<30}` in aligned log output) apply to the whole URL, as
    /// they do for `str`. `write_str` would silently ignore those flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(&self.to_url())
    }
}
impl BoundAddr {
pub fn to_url(&self) -> String {
match (&self.kind, &self.location) {
(k, AddrLocation::Inet(sa)) => {
format!("{}://{}", k.scheme(), sa)
}
(k, AddrLocation::Unix(path)) => {
format!("{}:{}", k.scheme(), path.display())
}
}
}
pub fn as_inet(&self) -> Option<SocketAddr> {
match &self.location {
AddrLocation::Inet(sa) => Some(*sa),
_ => None,
}
}
pub fn as_unix_path(&self) -> Option<&std::path::Path> {
match &self.location {
AddrLocation::Unix(p) => Some(p),
_ => None,
}
}
pub fn parse(s: &str) -> Result<Self> {
if let Some(rest) = s.strip_prefix("tcp://") {
let sa = parse_inet(rest)?;
return Ok(BoundAddr {
kind: SocketKind::TcpStream,
location: AddrLocation::Inet(sa),
});
}
if let Some(rest) = s.strip_prefix("udp://") {
let sa = parse_inet(rest)?;
return Ok(BoundAddr {
kind: SocketKind::UdpDgram,
location: AddrLocation::Inet(sa),
});
}
if let Some(rest) = s.strip_prefix("unix-stream:") {
return Ok(BoundAddr {
kind: SocketKind::UnixStream,
location: AddrLocation::Unix(parse_unix_path(rest)?),
});
}
if let Some(rest) = s.strip_prefix("unix-dgram:") {
return Ok(BoundAddr {
kind: SocketKind::UnixDgram,
location: AddrLocation::Unix(parse_unix_path(rest)?),
});
}
if let Some(rest) = s.strip_prefix("unix-seqpacket:") {
return Ok(BoundAddr {
kind: SocketKind::UnixSeqpacket,
location: AddrLocation::Unix(parse_unix_path(rest)?),
});
}
bail!(
"address `{s}` is missing a scheme; expected one of \
tcp://host:port, udp://host:port, unix-stream:/path, \
unix-dgram:/path, unix-seqpacket:/path"
)
}
}
impl FromStr for BoundAddr {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
BoundAddr::parse(s)
}
}
/// Parses the `ip:port` part of a `tcp://` or `udp://` address.
///
/// Delegates to `SocketAddr`'s `FromStr`, so only literal IP addresses with a
/// numeric port are accepted (IPv6 in brackets). On failure, the error echoes
/// the offending text and the underlying parse error.
fn parse_inet(rest: &str) -> Result<SocketAddr> {
    match rest.parse::<SocketAddr>() {
        Ok(addr) => Ok(addr),
        Err(e) => Err(anyhow!(
            "invalid host:port `{rest}` ({e}); use a literal IP \
             address with a numeric port"
        )),
    }
}
/// Converts the text after a `unix-*:` prefix into a socket path.
///
/// Only emptiness is rejected here; any other text is taken verbatim.
fn parse_unix_path(rest: &str) -> Result<PathBuf> {
    match rest {
        "" => bail!("unix socket path must not be empty"),
        path => Ok(PathBuf::from(path)),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    #[test]
    fn parses_tcp_v4() {
        let a = BoundAddr::parse("tcp://127.0.0.1:8080").unwrap();
        assert_eq!(a.kind, SocketKind::TcpStream);
        assert!(a.kind.is_byte_stream());
        assert_eq!(a.as_inet().unwrap().to_string(), "127.0.0.1:8080");
    }

    #[test]
    fn parses_tcp_v6() {
        let a = BoundAddr::parse("tcp://[::1]:443").unwrap();
        assert!(matches!(a.kind, SocketKind::TcpStream));
        assert_eq!(a.as_inet().unwrap().port(), 443);
    }

    #[test]
    fn parses_udp() {
        let a = BoundAddr::parse("udp://0.0.0.0:53").unwrap();
        assert_eq!(a.kind, SocketKind::UdpDgram);
        assert!(a.kind.is_datagram_stream());
    }

    #[test]
    fn parses_unix_variants() {
        for (s, want) in [
            ("unix-stream:/run/a.sock", SocketKind::UnixStream),
            ("unix-dgram:/run/b.sock", SocketKind::UnixDgram),
            ("unix-seqpacket:/run/c.sock", SocketKind::UnixSeqpacket),
        ] {
            let a = BoundAddr::parse(s).unwrap();
            assert_eq!(a.kind, want);
            assert_eq!(a.kind.family(), AddrFamily::Unix);
        }
    }

    #[test]
    fn accessors_are_family_specific() {
        let inet = BoundAddr::parse("tcp://127.0.0.1:8080").unwrap();
        assert!(inet.as_unix_path().is_none());

        let unix = BoundAddr::parse("unix-stream:/run/a.sock").unwrap();
        assert!(unix.as_inet().is_none());
        assert_eq!(unix.as_unix_path(), Some(Path::new("/run/a.sock")));
    }

    #[test]
    fn from_str_matches_parse() {
        for s in ["udp://0.0.0.0:53", "unix-dgram:/run/b.sock"] {
            let via_trait: BoundAddr = s.parse().unwrap();
            assert_eq!(via_trait, BoundAddr::parse(s).unwrap());
        }
        assert!("127.0.0.1:8080".parse::<BoundAddr>().is_err());
    }

    #[test]
    fn display_matches_to_url() {
        for s in ["tcp://[::1]:443", "unix-seqpacket:/run/c.sock"] {
            let a = BoundAddr::parse(s).unwrap();
            assert_eq!(a.to_string(), a.to_url());
        }
    }

    #[test]
    fn rejects_bare_host_port() {
        let err = BoundAddr::parse("127.0.0.1:8080").unwrap_err();
        assert!(err.to_string().contains("missing a scheme"));
    }

    #[test]
    fn rejects_unknown_scheme() {
        let err = BoundAddr::parse("sctp://127.0.0.1:9").unwrap_err();
        assert!(err.to_string().contains("missing a scheme"));
    }

    #[test]
    fn rejects_legacy_udp_prefix() {
        let err = BoundAddr::parse("udp:127.0.0.1:53").unwrap_err();
        assert!(err.to_string().contains("missing a scheme"));
    }

    #[test]
    fn rejects_legacy_unix_prefix() {
        let err = BoundAddr::parse("unix:/run/a.sock").unwrap_err();
        assert!(err.to_string().contains("missing a scheme"));
    }

    #[test]
    fn rejects_empty_unix_path() {
        let err = BoundAddr::parse("unix-stream:").unwrap_err();
        assert!(err.to_string().contains("must not be empty"));
    }

    #[test]
    fn rejects_hostname() {
        let err = BoundAddr::parse("tcp://example.com:80").unwrap_err();
        assert!(err.to_string().contains("invalid host:port"));
        let err = BoundAddr::parse("udp://localhost:53").unwrap_err();
        assert!(err.to_string().contains("invalid host:port"));
    }

    #[test]
    fn rejects_out_of_range_port() {
        let err = BoundAddr::parse("tcp://127.0.0.1:70000").unwrap_err();
        assert!(err.to_string().contains("invalid host:port"));
    }

    #[test]
    fn to_url_round_trips() {
        for s in [
            "tcp://127.0.0.1:8080",
            "udp://0.0.0.0:53",
            "unix-stream:/run/a.sock",
            "unix-dgram:/run/b.sock",
            "unix-seqpacket:/run/c.sock",
        ] {
            assert_eq!(BoundAddr::parse(s).unwrap().to_url(), s);
        }
    }

    #[test]
    fn round_trips_v6_in_url() {
        let a = BoundAddr::parse("tcp://[::1]:443").unwrap();
        assert_eq!(a.to_url(), "tcp://[::1]:443");
    }
}