use std::io;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream};
/// Combined bound for types that are both [`AsyncRead`] and [`AsyncWrite`].
///
/// The trait declares no methods of its own; it exists so that the two
/// capabilities can be named together, for example behind a single `dyn`.
pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}

// Blanket impl: anything that already reads and writes asynchronously
// satisfies `AsyncReadWrite` without further work.
impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite {}
/// Boxed, type-erased stream that is [`AsyncReadWrite`], `Unpin` and `Send`.
///
/// [`Listener`] hands out this type so that callers do not have to care
/// whether the underlying stream is a TCP or a Unix-domain one.
pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
/// A listening socket that is either TCP or Unix-domain.
pub enum Listener {
/// Listener backed by a TCP socket.
Tcp(TcpListener),
/// Listener backed by a Unix-domain socket.
Unix(UnixListener),
}
impl Listener {
    /// Waits for the next incoming connection.
    ///
    /// Returns the accepted stream boxed as an [`AsyncReadWriteBox`]. For a
    /// TCP listener the peer's socket address is returned alongside it; for a
    /// Unix listener the address slot is always `None`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by the underlying `accept` call.
    pub async fn accept(
        &mut self,
    ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
        match self {
            Listener::Tcp(listener) => {
                let (stream, addr) = listener.accept().await?;
                Ok((Box::new(stream), Some(addr)))
            }
            Listener::Unix(listener) => {
                // The Unix peer address is discarded: it cannot be expressed
                // as a `std::net::SocketAddr`.
                let (stream, _) = listener.accept().await?;
                Ok((Box::new(stream), None))
            }
        }
    }

    /// Creates a listener bound to `addr`.
    ///
    /// The kind of socket is chosen from the shape of `addr`:
    ///
    /// * starts with `/` or `.` — treated as a filesystem path and bound as a
    ///   Unix-domain socket;
    /// * starts with `:` (e.g. `":8080"`) — `127.0.0.1` is prepended and the
    ///   result is bound as TCP;
    /// * anything else — passed unchanged to [`TcpListener::bind`].
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported while binding the socket.
    pub async fn bind(addr: &str) -> io::Result<Self> {
        if addr.starts_with('/') || addr.starts_with('.') {
            return UnixListener::bind(addr).map(Listener::Unix);
        }

        // Only the `:port` shorthand needs a freshly built string; every
        // other address is handed to tokio as-is without allocating.
        let listener = if addr.starts_with(':') {
            TcpListener::bind(format!("127.0.0.1{}", addr)).await?
        } else {
            TcpListener::bind(addr).await?
        };
        Ok(Listener::Tcp(listener))
    }

    /// Opens a new client connection to this listener's own local address.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from looking up the local address or from
    /// connecting. For a Unix listener whose local address has no filesystem
    /// path, an [`io::ErrorKind::AddrNotAvailable`] error is returned.
    // NOTE(review): within this file only the test module calls `connect`,
    // which is presumably why the dead-code lint is silenced — confirm.
    #[allow(dead_code)]
    pub async fn connect(&self) -> io::Result<AsyncReadWriteBox> {
        match self {
            Listener::Tcp(listener) => {
                let stream = TcpStream::connect(listener.local_addr()?).await?;
                Ok(Box::new(stream))
            }
            Listener::Unix(listener) => {
                let local = listener.local_addr()?;
                // A listener without a pathname cannot be reached by path, so
                // report that as an error rather than panicking.
                let path = local.as_pathname().ok_or_else(|| {
                    io::Error::new(
                        io::ErrorKind::AddrNotAvailable,
                        "unix listener is not bound to a filesystem path",
                    )
                })?;
                let stream = UnixStream::connect(path).await?;
                Ok(Box::new(stream))
            }
        }
    }
}
/// Renders the address the listener is bound to.
///
/// * TCP: the local socket address in its standard textual form, i.e.
///   `ip:port` for IPv4 and `[ip]:port` for IPv6.
/// * Unix: the filesystem path of the socket.
///
/// Formatting never panics: if the local address cannot be queried a
/// placeholder is written, and a Unix socket without a pathname falls back to
/// the `Debug` rendering of its address.
impl std::fmt::Display for Listener {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Listener::Tcp(listener) => match listener.local_addr() {
                // `SocketAddr`'s own `Display` brackets IPv6 addresses, which
                // keeps the ip/port boundary unambiguous.
                Ok(addr) => write!(f, "{}", addr),
                // A failed lookup is not a formatter error, so emit a
                // placeholder instead of failing or panicking.
                Err(_) => f.write_str("<unknown tcp address>"),
            },
            Listener::Unix(listener) => match listener.local_addr() {
                Ok(addr) => match addr.as_pathname() {
                    Some(path) => write!(f, "{}", path.display()),
                    // No pathname available; show whatever the address type
                    // itself reports.
                    None => write!(f, "{:?}", addr),
                },
                Err(_) => f.write_str("<unknown unix address>"),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Binds a listener on `addr`, connects to it, and checks that bytes
    /// written on the accepted side arrive intact on the connecting side.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr).await.unwrap();

        // The client connects before `accept` is called; the test relies on
        // the pending connection still being available when `accept` runs.
        let mut client_side = listener.connect().await.unwrap();
        let (mut server_side, _peer) = listener.accept().await.unwrap();

        let expected = b"Hello from server!";
        server_side.write_all(expected).await.unwrap();
        // Drop the accepted stream so the client's `read_to_end` can finish.
        drop(server_side);

        let mut received = Vec::new();
        client_side.read_to_end(&mut received).await.unwrap();
        assert_eq!(expected.to_vec(), received);
    }

    /// Round-trips data through a TCP listener bound via the `:port` shorthand.
    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener(":0").await;
    }

    /// Round-trips data through a Unix listener bound inside a temp directory.
    #[tokio::test]
    async fn test_bind_unix() {
        let scratch = tempfile::tempdir().unwrap();
        let socket_path = scratch.path().join("test.sock");
        exercise_listener(socket_path.to_str().unwrap()).await;
    }
}