use std::{
convert::TryFrom,
io::{ErrorKind, IoSlice, IoSliceMut},
net::Shutdown,
os::unix::{
io::{AsRawFd, RawFd},
net::{SocketAddr, UnixStream as StdUnixStream},
},
path::Path,
pin::Pin,
task::{Context, Poll},
};
use futures_core::stream::Stream;
use futures_util::ready;
use pin_project::pin_project;
use tokio::{
io::{self, AsyncRead, AsyncWrite, Interest, ReadBuf},
net::{
unix::SocketAddr as TokioSocketAddr, UnixListener as TokioUnixListener,
UnixStream as TokioUnixStream,
},
};
use crate::{biqueue::BiQueue, DequeueFd, EnqueueFd, QueueFullError};
/// Asynchronous Unix domain stream that pairs a Tokio `UnixStream` with a
/// `BiQueue`.
///
/// The `AsyncRead`/`AsyncWrite` impls in this file route every read and write
/// through the `BiQueue`, and the `EnqueueFd`/`DequeueFd` impls expose the
/// queue so file descriptors can be queued for sending and retrieved after
/// receiving.
#[pin_project]
#[derive(Debug)]
pub struct UnixStream {
/// Underlying Tokio socket; marked `#[pin]` so it is projected as a pinned
/// reference inside the poll methods.
#[pin]
inner: TokioUnixStream,
/// Queue used by the read/write paths and by `enqueue`/`dequeue`.
biqueue: BiQueue,
}
/// Listener for incoming Unix domain connections that yields this module's
/// `UnixStream` wrapper instead of the raw Tokio stream.
#[derive(Debug)]
pub struct UnixListener {
/// Underlying Tokio listener that performs the actual accepts.
inner: TokioUnixListener,
}
impl UnixStream {
    /// Connects to the Unix socket at `path` and wraps the resulting Tokio
    /// stream together with a fresh fd queue.
    ///
    /// # Errors
    /// Returns whatever error the underlying Tokio `connect` reports.
    pub async fn connect(path: impl AsRef<Path>) -> io::Result<UnixStream> {
        let stream = TokioUnixStream::connect(path).await?;
        Ok(stream.into())
    }

    /// Creates a connected pair of streams, each with its own fd queue.
    ///
    /// # Errors
    /// Returns the error from the underlying Tokio `pair` call.
    pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
        let (left, right) = TokioUnixStream::pair()?;
        Ok((left.into(), right.into()))
    }

    /// Returns the address of the local end of the socket, converted to the
    /// standard library's `SocketAddr`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let local = self.inner.local_addr()?;
        to_addr(local)
    }

    /// Returns the address of the remote end of the socket, converted to the
    /// standard library's `SocketAddr`.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let peer = self.inner.peer_addr()?;
        to_addr(peer)
    }

    /// Forwards to the inner stream's `take_error`.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        let socket = &self.inner;
        socket.take_error()
    }

    /// Shuts down the read half, the write half, or both, as selected by
    /// `how`, by calling the module-level `shutdown` helper on this socket.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        let direction = how;
        shutdown(self, direction)
    }
}
impl EnqueueFd for UnixStream {
    /// Hands `fd` to the stream's `BiQueue`.
    ///
    /// # Errors
    /// Propagates the `QueueFullError` reported by the queue.
    fn enqueue(&mut self, fd: &impl AsRawFd) -> Result<(), QueueFullError> {
        let queue = &mut self.biqueue;
        queue.enqueue(fd)
    }
}
impl DequeueFd for UnixStream {
    /// Takes the next file descriptor out of the stream's `BiQueue`, or
    /// returns `None` when the queue has nothing to hand out.
    fn dequeue(&mut self) -> Option<RawFd> {
        let queue = &mut self.biqueue;
        queue.dequeue()
    }
}
impl AsRawFd for UnixStream {
    /// Exposes the raw descriptor of the wrapped Tokio stream.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl AsyncRead for UnixStream {
    /// Reads into the unfilled part of `buf` through the `BiQueue`.
    ///
    /// Waits for the socket to report read readiness, then attempts a single
    /// vectored read on the raw descriptor. A `WouldBlock` result sends the
    /// loop back to waiting for readiness; any other error is returned.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut ReadBuf,
    ) -> Poll<io::Result<()>> {
        let projected = self.project();
        let (socket, queue) = (projected.inner, projected.biqueue);
        let raw_fd = socket.as_raw_fd();

        loop {
            // Returns `Pending` (with the waker registered) until readable.
            ready!(socket.poll_read_ready(cx))?;

            let attempt = socket.try_io(Interest::READABLE, || {
                let mut slices = [IoSliceMut::new(buf.initialize_unfilled())];
                queue.read_vectored(raw_fd, &mut slices)
            });

            match attempt {
                Ok(read) => {
                    // Record how many bytes the read placed into the buffer.
                    buf.advance(read);
                    return Poll::Ready(Ok(()));
                }
                // Spurious readiness: go back and wait for the next event.
                Err(err) if err.kind() == ErrorKind::WouldBlock => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }
}
impl AsyncWrite for UnixStream {
    /// Writes `buf` by delegating to `poll_write_vectored` with a single slice.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let single = [IoSlice::new(buf)];
        self.poll_write_vectored(cx, &single)
    }

    /// Forwards the flush to the wrapped Tokio stream.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let socket = self.project().inner;
        socket.poll_flush(cx)
    }

    /// Forwards the shutdown to the wrapped Tokio stream.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let socket = self.project().inner;
        socket.poll_shutdown(cx)
    }

    /// Writes `bufs` through the `BiQueue`.
    ///
    /// Waits for write readiness, then attempts one vectored write on the raw
    /// descriptor. A `WouldBlock` result sends the loop back to waiting for
    /// readiness; every other outcome (success or error) is returned as-is.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let projected = self.project();
        let (socket, queue) = (projected.inner, projected.biqueue);
        let raw_fd = socket.as_raw_fd();

        loop {
            // Returns `Pending` (with the waker registered) until writable.
            ready!(socket.poll_write_ready(cx))?;

            let attempt = socket.try_io(Interest::WRITABLE, || queue.write_vectored(raw_fd, bufs));

            match attempt {
                // Spurious readiness: go back and wait for the next event.
                Err(err) if err.kind() == ErrorKind::WouldBlock => continue,
                outcome => return Poll::Ready(outcome),
            }
        }
    }

    /// Vectored writes are handled natively by `poll_write_vectored`.
    fn is_write_vectored(&self) -> bool {
        true
    }
}
impl From<TokioUnixStream> for UnixStream {
    /// Wraps an existing Tokio stream, attaching a newly created `BiQueue`.
    fn from(inner: TokioUnixStream) -> UnixStream {
        let biqueue = BiQueue::new();
        Self { inner, biqueue }
    }
}
impl TryFrom<StdUnixStream> for UnixStream {
    type Error = io::Error;

    /// Converts a standard-library stream into this wrapper.
    ///
    /// The socket is switched to non-blocking mode before it is handed to
    /// Tokio's `from_std`.
    ///
    /// # Errors
    /// Fails if setting non-blocking mode fails or if `from_std` reports an
    /// error.
    // NOTE(review): Tokio documents `from_std` as panicking when called
    // outside a runtime with I/O enabled — confirm for the Tokio version used.
    fn try_from(inner: StdUnixStream) -> Result<Self, Self::Error> {
        inner.set_nonblocking(true)?;
        let stream = TokioUnixStream::from_std(inner)?;
        Ok(stream.into())
    }
}
impl UnixListener {
    /// Binds a new listener to the socket file at `path`.
    ///
    /// # Errors
    /// Returns the error from the underlying Tokio `bind` call.
    pub fn bind(path: impl AsRef<Path>) -> io::Result<UnixListener> {
        let listener = TokioUnixListener::bind(path)?;
        Ok(listener.into())
    }

    /// Returns the address this listener is bound to, converted to the
    /// standard library's `SocketAddr`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let local = self.inner.local_addr()?;
        to_addr(local)
    }

    /// Forwards to the inner listener's `take_error`.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        let listener = &self.inner;
        listener.take_error()
    }

    /// Waits for the next incoming connection and returns it, wrapped in this
    /// module's `UnixStream`, together with the peer address.
    ///
    /// # Errors
    /// Fails if the accept itself fails or if the peer address cannot be
    /// converted.
    pub async fn accept(&mut self) -> io::Result<(UnixStream, SocketAddr)> {
        let (stream, peer) = self.inner.accept().await?;
        let peer = to_addr(peer)?;
        Ok((stream.into(), peer))
    }

    /// Poll-based counterpart of `accept`, used by the `Stream` impl.
    fn poll_accept(&self, cx: &mut Context) -> Poll<io::Result<(UnixStream, SocketAddr)>> {
        match self.inner.poll_accept(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Ready(Ok((stream, peer))) => {
                Poll::Ready(to_addr(peer).map(|peer| (stream.into(), peer)))
            }
        }
    }
}
impl AsRawFd for UnixListener {
    /// Exposes the raw descriptor of the wrapped Tokio listener.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl Stream for UnixListener {
type Item = io::Result<UnixStream>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
use Poll::{Pending, Ready};
match self.poll_accept(cx) {
Pending => Pending,
Ready(Ok((stream, _))) => Ready(Some(Ok(stream))),
Ready(Err(err)) => Ready(Some(Err(err))),
}
}
}
impl From<TokioUnixListener> for UnixListener {
fn from(inner: TokioUnixListener) -> UnixListener {
UnixListener { inner }
}
}
fn to_addr(addr: TokioSocketAddr) -> io::Result<SocketAddr> {
addr.as_pathname()
.map_or(SocketAddr::from_pathname(""), |path| {
SocketAddr::from_pathname(path)
})
}
/// Shuts down the read half, write half, or both halves of `socket` by
/// calling `libc::shutdown` on its raw descriptor.
///
/// # Errors
/// When the call returns `-1`, the last OS error is returned.
fn shutdown(socket: &impl AsRawFd, how: Shutdown) -> io::Result<()> {
    let flag = match how {
        Shutdown::Read => libc::SHUT_RD,
        Shutdown::Write => libc::SHUT_WR,
        Shutdown::Both => libc::SHUT_RDWR,
    };

    // SAFETY: only plain integers (the descriptor and the mode flag) are
    // passed to the call; no pointers or borrowed memory are involved.
    match unsafe { libc::shutdown(socket.as_raw_fd(), flag) } {
        -1 => Err(io::Error::last_os_error()),
        _ => Ok(()),
    }
}
// Tests covering plain byte transfer, file-descriptor passing, and the
// listener/connect path of the wrappers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use std::fs::File;
use std::io::{prelude::*, SeekFrom};
use std::os::unix::io::FromRawFd as _;
use tempfile::{tempdir, tempfile};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
// Bytes written on one end of a socket pair must arrive unchanged on the other.
#[tokio::test]
async fn unix_stream_reads_other_sides_writes() {
let mut buf: [u8; 12] = [0; 12];
let (mut sut, mut other) = UnixStream::pair().expect("Can't create UnixStream's");
// The writer runs in its own task so the read below can wait on it.
tokio::spawn(async move {
other
.write_all(b"Hello World!".as_ref())
.await
.expect("Can't write to UnixStream");
});
sut.read_exact(buf.as_mut())
.await
.expect("Can't read from UnixStream");
assert_eq!(&buf, b"Hello World!");
}
// A descriptor enqueued before a write must be retrievable on the receiving
// side after the accompanying byte has been read, and must refer to the same
// file contents.
#[tokio::test]
async fn unix_stream_passes_fd() {
// Prepare a temp file with known contents and rewind it so the receiver
// reads from the start.
let mut file1 = tempfile().expect("Can't create temp file.");
file1
.write_all(b"Hello World!\0")
.expect("Can't write to temp file.");
file1.flush().expect("Can't flush temp file.");
file1
.seek(SeekFrom::Start(0))
.expect("Couldn't seek the file.");
let mut buf = [0u8];
let (mut sut, mut other) = UnixStream::pair().expect("Can't create UnixStream's");
tokio::spawn(async move {
// Enqueue first, then write a single byte to carry the descriptor across.
other.enqueue(&file1).expect("Can't enqueue fd.");
other
.write_all(b"1".as_ref())
.await
.expect("Can't write to UnixStream");
});
sut.read_exact(buf.as_mut())
.await
.expect("Can't read from UnixStream");
let fd = sut.dequeue().expect("Can't dequeue fd");
// Take ownership of the received descriptor so it is closed when `file2` drops.
let mut file2 = unsafe { File::from_raw_fd(fd) };
let mut buf2 = Vec::new();
file2.read_to_end(&mut buf2).expect("Can't read from file");
assert_eq!(&buf2[..], b"Hello World!\0".as_ref());
}
// A client connecting to a bound listener must be accepted, and the bytes it
// writes must be readable from the accepted stream.
#[tokio::test]
async fn unix_stream_connects_to_listner() {
// The socket file lives in a temp dir that is removed when `dir` drops.
let dir = tempdir().expect("Can't create temp dir");
let sock_addr = dir.as_ref().join("socket");
let mut buf: [u8; 12] = [0; 12];
// Bind before spawning the client so the connect has something to reach.
let mut listener = UnixListener::bind(&sock_addr).expect("Can't bind listener");
tokio::spawn(async move {
let mut client = UnixStream::connect(sock_addr)
.await
.expect("Can't connect to listener");
client
.write_all(b"Hello World!".as_ref())
.await
.expect("Can't write to client");
});
let (mut server, _) = listener.accept().await.expect("Can't accept on listener");
server
.read_exact(buf.as_mut())
.await
.expect("Can't read from server");
assert_eq!(&buf, b"Hello World!");
}
}