use std::io;
use std::net::TcpStream;
#[cfg(unix)]
use std::os::unix::net::UnixStream;
#[cfg(unix)]
use std::path::PathBuf;
/// A parsed endpoint to connect to: either a Unix domain socket path or a TCP
/// `(host, port)` pair.
///
/// Values are normally produced by [`parse_addr`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SocketAddr {
    /// Filesystem path of a Unix domain socket (only available on Unix targets).
    #[cfg(unix)]
    Unix(PathBuf),
    /// TCP endpoint: the host (a name or an IP literal) and the port number.
    Tcp(String, u16),
}
/// Parses a socket address string into a [`SocketAddr`].
///
/// Accepted forms, checked in this order:
///
/// * `:PORT` — TCP on `127.0.0.1` at `PORT`.
/// * A path starting with `/` (Unix only) — a Unix domain socket.
/// * `HOST:PORT` — TCP. `HOST` must be non-empty and must not contain `:`,
///   unless it is a bracketed literal such as `[::1]:8080`, in which case the
///   brackets are removed from the stored host.
/// * Anything else (Unix only) — treated as a relative Unix socket path.
///
/// # Errors
///
/// Returns a human-readable message when the input is empty, when the `:PORT`
/// shorthand has an invalid port, or (on non-Unix platforms) when the input
/// does not match any TCP form.
pub fn parse_addr(s: &str) -> Result<SocketAddr, String> {
    // An empty path can never be connected to; report it here rather than as
    // an opaque I/O error later.
    if s.is_empty() {
        return Err("empty socket address".to_string());
    }
    if let Some(rest) = s.strip_prefix(':') {
        let port: u16 = rest
            .parse()
            .map_err(|_| format!("invalid TCP port in `{s}`"))?;
        return Ok(SocketAddr::Tcp("127.0.0.1".to_string(), port));
    }
    #[cfg(unix)]
    if s.starts_with('/') {
        return Ok(SocketAddr::Unix(PathBuf::from(s)));
    }
    if let Some((host, port)) = split_host_port(s) {
        return Ok(SocketAddr::Tcp(host.to_string(), port));
    }
    #[cfg(unix)]
    {
        Ok(SocketAddr::Unix(PathBuf::from(s)))
    }
    #[cfg(not(unix))]
    {
        Err(format!(
            "ambiguous socket address `{s}` on non-unix platforms"
        ))
    }
}

/// Splits `HOST:PORT` into its parts, or returns `None` if `s` is not of that form.
///
/// The split is at the *last* `:` so bracketed IPv6 literals (`[::1]:80`) work;
/// their brackets are stripped from the returned host. An unbracketed host that
/// still contains `:` (e.g. `a:b:80`) is rejected so such inputs keep falling
/// through to the caller's fallback instead of becoming a TCP host `a:b`.
fn split_host_port(s: &str) -> Option<(&str, u16)> {
    let (host, port) = s.rsplit_once(':')?;
    let port = port.parse::<u16>().ok()?;
    let host = match host.strip_prefix('[').and_then(|h| h.strip_suffix(']')) {
        Some(inner) => inner,
        None if host.contains(':') => return None,
        None => host,
    };
    if host.is_empty() {
        None
    } else {
        Some((host, port))
    }
}
/// An established connection over either a Unix domain socket or TCP.
///
/// Reading and writing are forwarded to whichever std stream is wrapped.
#[derive(Debug)]
pub enum SocketStream {
    /// Connected Unix domain socket (only available on Unix targets).
    #[cfg(unix)]
    Unix(UnixStream),
    /// Connected TCP socket.
    Tcp(TcpStream),
}
impl SocketStream {
    /// Opens a new connection to `addr`, choosing the transport from its variant.
    ///
    /// # Errors
    ///
    /// Propagates the `io::Error` from `UnixStream::connect` or `TcpStream::connect`.
    pub fn connect(addr: &SocketAddr) -> io::Result<Self> {
        let stream = match addr {
            #[cfg(unix)]
            SocketAddr::Unix(path) => SocketStream::Unix(UnixStream::connect(path)?),
            SocketAddr::Tcp(host, port) => {
                let target = (host.as_str(), *port);
                SocketStream::Tcp(TcpStream::connect(target)?)
            }
        };
        Ok(stream)
    }

    /// Creates another handle to the same underlying socket via the std
    /// `try_clone` of the wrapped stream.
    ///
    /// # Errors
    ///
    /// Propagates any `io::Error` from the underlying `try_clone` call.
    pub fn try_clone(&self) -> io::Result<Self> {
        match self {
            #[cfg(unix)]
            SocketStream::Unix(inner) => inner.try_clone().map(SocketStream::Unix),
            SocketStream::Tcp(inner) => inner.try_clone().map(SocketStream::Tcp),
        }
    }

    /// Shuts down both the read and write halves of the socket.
    ///
    /// This is best-effort: any error reported by the OS is deliberately ignored.
    pub fn shutdown(&self) {
        use std::net::Shutdown;

        let _ = match self {
            #[cfg(unix)]
            SocketStream::Unix(inner) => inner.shutdown(Shutdown::Both),
            SocketStream::Tcp(inner) => inner.shutdown(Shutdown::Both),
        };
    }
}
/// Reads are delegated unchanged to the wrapped Unix or TCP stream.
impl io::Read for SocketStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self {
            #[cfg(unix)]
            SocketStream::Unix(inner) => io::Read::read(inner, buf),
            SocketStream::Tcp(inner) => io::Read::read(inner, buf),
        }
    }
}
/// Writes and flushes are delegated unchanged to the wrapped Unix or TCP stream.
impl io::Write for SocketStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            #[cfg(unix)]
            SocketStream::Unix(inner) => io::Write::write(inner, buf),
            SocketStream::Tcp(inner) => io::Write::write(inner, buf),
        }
    }

    fn flush(&mut self) -> io::Result<()> {
        match self {
            #[cfg(unix)]
            SocketStream::Unix(inner) => io::Write::flush(inner),
            SocketStream::Tcp(inner) => io::Write::flush(inner),
        }
    }
}
/// A live connection together with the address it was opened against.
pub struct SocketAdapter {
    /// The parsed address that `stream` was connected to.
    pub addr: SocketAddr,
    /// The open connection.
    pub stream: SocketStream,
}
impl SocketAdapter {
pub fn connect(addr_str: &str) -> std::result::Result<Self, crate::Error> {
let addr = parse_addr(addr_str).map_err(crate::Error::InvalidSettings)?;
let stream = SocketStream::connect(&addr)?;
Ok(Self { addr, stream })
}
}
/// Unit tests for [`parse_addr`].
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_tcp_port_only() {
        let addr = parse_addr(":4567").unwrap();
        assert_eq!(addr, SocketAddr::Tcp("127.0.0.1".to_string(), 4567));
    }

    #[test]
    fn parses_host_port() {
        let addr = parse_addr("example.com:8080").unwrap();
        assert_eq!(addr, SocketAddr::Tcp("example.com".to_string(), 8080));
    }

    #[test]
    fn parses_max_port() {
        let addr = parse_addr(":65535").unwrap();
        assert_eq!(addr, SocketAddr::Tcp("127.0.0.1".to_string(), 65535));
    }

    #[cfg(unix)]
    #[test]
    fn parses_unix_absolute_path() {
        let addr = parse_addr("/tmp/plushie.sock").unwrap();
        assert_eq!(addr, SocketAddr::Unix(PathBuf::from("/tmp/plushie.sock")));
    }

    #[cfg(unix)]
    #[test]
    fn parses_bare_name_as_unix_path() {
        let addr = parse_addr("plushie.sock").unwrap();
        assert_eq!(addr, SocketAddr::Unix(PathBuf::from("plushie.sock")));
    }

    // A `HOST:PORT`-looking string whose port is not numeric is not TCP; on Unix
    // it falls back to being a relative socket path.
    #[cfg(unix)]
    #[test]
    fn non_numeric_port_falls_back_to_unix_path() {
        let addr = parse_addr("localhost:http").unwrap();
        assert_eq!(addr, SocketAddr::Unix(PathBuf::from("localhost:http")));
    }

    #[cfg(not(unix))]
    #[test]
    fn bare_name_is_rejected_off_unix() {
        assert!(parse_addr("plushie.sock").is_err());
    }

    #[test]
    fn rejects_bad_port() {
        assert!(parse_addr(":not_a_port").is_err());
    }

    #[test]
    fn rejects_port_out_of_range() {
        assert!(parse_addr(":65536").is_err());
    }

    #[test]
    fn rejects_missing_port_in_shorthand() {
        assert!(parse_addr(":").is_err());
    }
}