use std::{
convert::{TryFrom, TryInto},
io::{Read, Write},
net::Shutdown,
os::unix::{
io::{AsRawFd, RawFd},
net::{SocketAddr, UnixStream as StdUnixStream},
},
path::Path,
pin::Pin,
task::{Context, Poll},
};
use futures_core::stream::Stream;
use futures_util::{future::poll_fn, ready};
use pin_project::pin_project;
use socket2::{Domain, SockAddr, Socket, Type};
use tokio::io::{self, unix::AsyncFd, AsyncRead, AsyncWrite, ReadBuf};
use crate::{DequeueFd, EnqueueFd, QueueFullError};
/// An async Unix-domain stream socket that supports passing file descriptors.
///
/// Wraps a non-blocking `crate::net::UnixStream` in a tokio [`AsyncFd`], so reads
/// and writes are driven by readiness events from the tokio reactor. Descriptors
/// are queued for sending via [`EnqueueFd`] and retrieved after reading via
/// [`DequeueFd`].
#[pin_project]
#[derive(Debug)]
pub struct UnixStream {
    // The non-blocking stream, registered for readiness notifications.
    inner: AsyncFd<crate::net::UnixStream>,
}
/// An async Unix-domain listening socket.
///
/// Wraps a non-blocking `crate::net::UnixListener` in a tokio [`AsyncFd`].
/// Accepted connections are returned as [`UnixStream`]s, either through
/// [`UnixListener::accept`] or by polling the listener as a [`Stream`].
#[derive(Debug)]
pub struct UnixListener {
    // The non-blocking listener, registered for readiness notifications.
    inner: AsyncFd<crate::net::UnixListener>,
}
impl UnixStream {
    /// Connects to the Unix-domain stream socket bound at `path`.
    ///
    /// The socket is created non-blocking and close-on-exec, then registered with
    /// the tokio reactor. The returned future resolves once the socket reports
    /// write readiness and carries no pending socket error.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket cannot be created, `path` cannot be turned
    /// into a socket address, `connect(2)` fails with anything other than
    /// `WouldBlock`, registration with the reactor fails, or the socket reports a
    /// pending error once it becomes writable.
    pub async fn connect(path: impl AsRef<Path>) -> io::Result<UnixStream> {
        let typ = Type::stream().non_blocking().cloexec();
        let socket = Socket::new(Domain::unix(), typ, None)?;
        let addr = SockAddr::unix(path)?;
        match socket.connect(&addr) {
            Ok(()) => {}
            // For non-blocking Unix-domain sockets, connect(2) reports EAGAIN
            // (WouldBlock) when the connection cannot be completed immediately;
            // wait for writability below instead of failing right away.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
            Err(e) => return Err(e),
        }
        let stream: UnixStream = socket.into_unix_stream().try_into()?;
        // Keep the observed write readiness so the first write does not have to
        // wait for a fresh readiness event.
        poll_fn(|cx| stream.inner.poll_write_ready(cx))
            .await?
            .retain_ready();
        // Writability only says the connect attempt has finished; a connect that
        // failed asynchronously is reported through the socket's pending error.
        if let Some(e) = stream.take_error()? {
            return Err(e);
        }
        Ok(stream)
    }

    /// Creates a connected pair of streams, both registered with the tokio reactor.
    ///
    /// # Errors
    ///
    /// Fails if the underlying pair cannot be created or either end cannot be
    /// converted (made non-blocking and registered).
    pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
        let (stream1, stream2) = crate::net::UnixStream::pair()?;
        Ok((stream1.try_into()?, stream2.try_into()?))
    }

    /// Returns the local socket address of this stream.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.get_ref().local_addr()
    }

    /// Returns the socket address of the connected peer.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner.get_ref().peer_addr()
    }

    /// Returns (and clears) the pending socket error, if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.get_ref().take_error()
    }

    /// Shuts down the read half, the write half, or both, as selected by `how`.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.inner.get_ref().shutdown(how)
    }
}
impl EnqueueFd for UnixStream {
fn enqueue(&mut self, fd: &impl AsRawFd) -> Result<(), QueueFullError> {
self.inner.get_mut().enqueue(fd)
}
}
impl DequeueFd for UnixStream {
    /// Returns the next descriptor queued on the wrapped stream, as produced by
    /// that stream's own `DequeueFd::dequeue`, or `None`.
    fn dequeue(&mut self) -> Option<RawFd> {
        let stream = self.inner.get_mut();
        stream.dequeue()
    }
}
impl AsRawFd for UnixStream {
    /// Exposes the raw descriptor of the wrapped socket without transferring
    /// ownership.
    fn as_raw_fd(&self) -> RawFd {
        let stream = self.inner.get_ref();
        AsRawFd::as_raw_fd(stream)
    }
}
impl AsyncRead for UnixStream {
    /// Reads available bytes from the socket into the unfilled part of `buf`.
    ///
    /// Returns `Poll::Pending` only after the reactor has registered this task's
    /// waker, so the task is woken once the socket becomes readable again.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let inner = self.project().inner;
        // Loop instead of returning `Pending` on `WouldBlock`. `try_io` clears the
        // cached readiness, and only a `poll_read_ready_mut` call that itself
        // returns `Pending` registers the waker. Returning `Pending` directly
        // would leave the task with no scheduled wakeup.
        loop {
            let mut guard = ready!(inner.poll_read_ready_mut(cx))?;
            let unfilled = buf.initialize_unfilled();
            match guard.try_io(|fd| fd.get_mut().read(unfilled)) {
                Ok(Ok(count)) => {
                    buf.advance(count);
                    return Poll::Ready(Ok(()));
                }
                Ok(Err(e)) => return Poll::Ready(Err(e)),
                Err(_would_block) => continue,
            }
        }
    }
}
impl AsyncWrite for UnixStream {
    /// Writes as much of `buf` as the socket accepts, returning the byte count.
    ///
    /// Returns `Poll::Pending` only after the reactor has registered this task's
    /// waker for write readiness.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let inner = self.project().inner;
        // `try_io` clears readiness on `WouldBlock`; re-poll readiness so the
        // waker is registered before yielding, rather than returning `Pending`
        // with no wakeup scheduled.
        loop {
            let mut guard = ready!(inner.poll_write_ready_mut(cx))?;
            match guard.try_io(|fd| fd.get_mut().write(buf)) {
                Ok(result) => return Poll::Ready(result),
                Err(_would_block) => continue,
            }
        }
    }

    /// Writes go straight to the socket with no user-space buffering, so there is
    /// nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write half of the socket.
    ///
    /// `shutdown(2)` completes synchronously, so its result is returned directly.
    /// The previous mapping of `WouldBlock` to `Pending` yielded without
    /// registering a waker and could stall the caller forever.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(self.project().inner.get_ref().shutdown(Shutdown::Write))
    }
}
impl TryFrom<crate::net::UnixStream> for UnixStream {
    type Error = io::Error;

    /// Switches `inner` to non-blocking mode and wraps it in an [`AsyncFd`].
    ///
    /// # Errors
    ///
    /// Fails if non-blocking mode cannot be enabled or `AsyncFd` registration
    /// fails.
    fn try_from(inner: crate::net::UnixStream) -> io::Result<UnixStream> {
        inner.set_nonblocking(true)?;
        AsyncFd::new(inner).map(|registered| UnixStream { inner: registered })
    }
}
impl TryFrom<StdUnixStream> for UnixStream {
    type Error = io::Error;

    /// Converts a standard-library stream by first wrapping it in
    /// `crate::net::UnixStream`, then reusing that type's `TryFrom` conversion.
    fn try_from(inner: StdUnixStream) -> Result<Self, Self::Error> {
        let wrapped: crate::net::UnixStream = inner.into();
        // Dispatches to `TryFrom<crate::net::UnixStream>`, not back to this impl.
        Self::try_from(wrapped)
    }
}
impl UnixListener {
    /// Binds a new listener to `path` and registers it with the tokio reactor.
    ///
    /// # Errors
    ///
    /// Fails if binding fails or the listener cannot be made non-blocking and
    /// registered.
    pub fn bind(path: impl AsRef<Path>) -> io::Result<UnixListener> {
        crate::net::UnixListener::bind(path)?.try_into()
    }

    /// Returns the local socket address this listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.get_ref().local_addr()
    }

    /// Returns (and clears) the pending socket error, if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.inner.get_ref().take_error()
    }

    /// Waits for the next incoming connection.
    ///
    /// Returns the accepted stream and the address reported by the underlying
    /// `accept`.
    pub async fn accept(&mut self) -> io::Result<(UnixStream, SocketAddr)> {
        poll_fn(|cx| self.poll_accept(cx)).await
    }

    /// Polls for one accepted connection, converting it into a [`UnixStream`].
    fn poll_accept(&self, cx: &mut Context) -> Poll<io::Result<(UnixStream, SocketAddr)>> {
        // `try_io` clears readiness on `WouldBlock`; loop so the next
        // `poll_read_ready` registers the waker before this returns `Pending`.
        loop {
            let mut guard = ready!(self.inner.poll_read_ready(cx))?;
            match guard.try_io(|listener| listener.get_ref().accept()) {
                Ok(Ok((socket, addr))) => return Poll::Ready(Ok((socket.try_into()?, addr))),
                Ok(Err(e)) => return Poll::Ready(Err(e)),
                Err(_would_block) => continue,
            }
        }
    }
}
impl AsRawFd for UnixListener {
    /// Exposes the raw descriptor of the wrapped listening socket without
    /// transferring ownership.
    fn as_raw_fd(&self) -> RawFd {
        let listener = self.inner.get_ref();
        AsRawFd::as_raw_fd(listener)
    }
}
impl Stream for UnixListener {
    type Item = io::Result<UnixStream>;

    /// Yields each accepted connection, discarding the peer address.
    ///
    /// This method never yields `None`. Accept errors are surfaced as
    /// `Some(Err(_))` items.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        self.poll_accept(cx)
            .map(|accepted| Some(accepted.map(|(stream, _addr)| stream)))
    }
}
impl TryFrom<crate::net::UnixListener> for UnixListener {
    type Error = io::Error;

    /// Switches `inner` to non-blocking mode and wraps it in an [`AsyncFd`].
    ///
    /// # Errors
    ///
    /// Fails if non-blocking mode cannot be enabled or `AsyncFd` registration
    /// fails.
    fn try_from(inner: crate::net::UnixListener) -> io::Result<UnixListener> {
        inner.set_nonblocking(true)?;
        let registered = AsyncFd::new(inner)?;
        Ok(Self { inner: registered })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::{prelude::*, SeekFrom};
    use std::os::unix::io::FromRawFd as _;
    use tempfile::{tempdir, tempfile};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Bytes written on one end of a `pair` are read back on the other end.
    #[tokio::test]
    async fn unix_stream_reads_other_sides_writes() {
        let mut buf: [u8; 12] = [0; 12];
        let (mut sut, mut other) = UnixStream::pair().expect("Can't create UnixStream's");
        let writer = tokio::spawn(async move {
            other
                .write_all(b"Hello World!".as_ref())
                .await
                .expect("Can't write to UnixStream");
        });
        sut.read_exact(buf.as_mut())
            .await
            .expect("Can't read from UnixStream");
        // Make sure the writer task ran to completion without panicking.
        writer.await.expect("Writer task failed");
        assert_eq!(&buf, b"Hello World!");
    }

    /// A descriptor enqueued before a write can be dequeued after the matching
    /// read, and it refers to the same file contents.
    #[tokio::test]
    async fn unix_stream_passes_fd() {
        let mut file1 = tempfile().expect("Can't create temp file.");
        file1
            .write_all(b"Hello World!\0")
            .expect("Can't write to temp file.");
        file1.flush().expect("Can't flush temp file.");
        // Rewind so reading through the received descriptor starts at the
        // beginning of the file.
        file1
            .seek(SeekFrom::Start(0))
            .expect("Couldn't seek the file.");
        let mut buf = [0u8];
        let (mut sut, mut other) = UnixStream::pair().expect("Can't create UnixStream's");
        let writer = tokio::spawn(async move {
            other.enqueue(&file1).expect("Can't enqueue fd.");
            other
                .write_all(b"1".as_ref())
                .await
                .expect("Can't write to UnixStream");
        });
        sut.read_exact(buf.as_mut())
            .await
            .expect("Can't read from UnixStream");
        writer.await.expect("Writer task failed");
        assert_eq!(&buf, b"1");
        let fd = sut.dequeue().expect("Can't dequeue fd");
        // SAFETY: assumes `dequeue` hands over ownership of the received
        // descriptor. Nothing else in this test owns it, so `file2` may take it
        // and close it on drop.
        let mut file2 = unsafe { File::from_raw_fd(fd) };
        let mut buf2 = Vec::new();
        file2.read_to_end(&mut buf2).expect("Can't read from file");
        assert_eq!(&buf2[..], b"Hello World!\0".as_ref());
    }

    /// A client connecting via `UnixStream::connect` is accepted by a bound
    /// `UnixListener`, and its writes reach the accepted stream.
    #[tokio::test]
    async fn unix_stream_connects_to_listener() {
        let dir = tempdir().expect("Can't create temp dir");
        let sock_addr = dir.as_ref().join("socket");
        let mut buf: [u8; 12] = [0; 12];
        let mut listener = UnixListener::bind(&sock_addr).expect("Can't bind listener");
        let client_task = tokio::spawn(async move {
            let mut client = UnixStream::connect(sock_addr)
                .await
                .expect("Can't connect to listener");
            client
                .write_all(b"Hello World!".as_ref())
                .await
                .expect("Can't write to client");
        });
        let (mut server, _) = listener.accept().await.expect("Can't accept on listener");
        server
            .read_exact(buf.as_mut())
            .await
            .expect("Can't read from server");
        client_task.await.expect("Client task failed");
        assert_eq!(&buf, b"Hello World!");
    }
}