use std::{
io::{self, prelude::*, Error, IoSlice, IoSliceMut},
net::Shutdown,
os::unix::{
io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
net::{SocketAddr, UnixListener as StdUnixListner, UnixStream as StdUnixStream},
},
path::Path,
};
use std::usize;
use crate::biqueue::BiQueue;
use crate::{DequeueFd, EnqueueFd, QueueFullError};
/// A Unix-domain stream socket that can carry file descriptors alongside its data.
///
/// Wraps a `std::os::unix::net::UnixStream` together with a `BiQueue`.
/// Descriptors are added with `EnqueueFd::enqueue` and taken out with
/// `DequeueFd::dequeue`. All reads and writes go through the `BiQueue`'s
/// `read_vectored`/`write_vectored`, which receive this socket's raw fd.
#[derive(Debug)]
pub struct UnixStream {
    /// The wrapped standard-library socket; plain socket operations delegate to it.
    inner: StdUnixStream,
    /// The fd queues used by the `Read`/`Write` impls.
    // NOTE(review): by its name, presumably one queue per direction — confirm in `biqueue`.
    biqueue: BiQueue,
}
/// A Unix-domain socket server whose accepted connections are fd-passing `UnixStream`s.
///
/// Thin wrapper around `std::os::unix::net::UnixListener`. Every accepted
/// standard-library stream is converted into this crate's `UnixStream`.
#[derive(Debug)]
pub struct UnixListener {
    /// The wrapped standard-library listener.
    inner: StdUnixListner,
}
/// An endless iterator over connections accepted by a [`UnixListener`].
///
/// Created by [`UnixListener::incoming`] or by iterating over `&UnixListener`.
/// Each call to `next` performs one `accept` on the listener and never
/// yields `None`. Like the standard library's `Incoming`, it does nothing
/// until it is consumed.
#[derive(Debug)]
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct Incoming<'a> {
    /// The listener whose `accept` is invoked on every iteration step.
    listener: &'a UnixListener,
}
impl UnixStream {
    /// Size of this stream's fd queues (the same value as `BiQueue::FD_QUEUE_SIZE`).
    pub const FD_QUEUE_SIZE: usize = BiQueue::FD_QUEUE_SIZE;

    /// Connects to the socket bound at `path`.
    ///
    /// The returned stream starts with a fresh `BiQueue::new()`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `std::os::unix::net::UnixStream::connect`.
    pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
        let stream = StdUnixStream::connect(path)?;
        Ok(Self::from(stream))
    }

    /// Creates a connected, unnamed pair of streams, each with its own fresh fd queues.
    ///
    /// # Errors
    ///
    /// Propagates any error from `std::os::unix::net::UnixStream::pair`.
    pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
        let (first, second) = StdUnixStream::pair()?;
        Ok((Self::from(first), Self::from(second)))
    }

    /// Duplicates the underlying socket.
    ///
    /// The clone gets a brand-new `BiQueue`. Descriptors already held in
    /// `self`'s queues are not visible through the clone.
    pub fn try_clone(&self) -> io::Result<UnixStream> {
        let duplicate = self.inner.try_clone()?;
        Ok(Self::from(duplicate))
    }

    /// Returns the address this end of the socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        StdUnixStream::local_addr(&self.inner)
    }

    /// Returns the address of the connected peer.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        StdUnixStream::peer_addr(&self.inner)
    }

    /// Returns and clears any pending `SO_ERROR` on the underlying socket.
    pub fn take_error(&self) -> io::Result<Option<Error>> {
        StdUnixStream::take_error(&self.inner)
    }

    /// Shuts down the read half, the write half, or both halves of the connection.
    pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        StdUnixStream::shutdown(&self.inner, how)
    }

    /// Moves the underlying socket into or out of non-blocking mode.
    // Crate-internal helper that may have no caller in some builds; the allowance
    // keeps those builds warning-free. NOTE(review): confirm it is still needed.
    #[allow(dead_code)]
    pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
        StdUnixStream::set_nonblocking(&self.inner, nonblocking)
    }
}
impl EnqueueFd for UnixStream {
    /// Adds `fd` to this stream's fd queue by delegating to `BiQueue::enqueue`.
    ///
    /// `fd` is only borrowed here. NOTE(review): presumably the descriptor is
    /// transmitted by the next `write`/`write_vectored`, so the caller should keep
    /// it open until then — confirm against `BiQueue`.
    ///
    /// # Errors
    ///
    /// Returns `QueueFullError` whenever `BiQueue::enqueue` does; presumably when
    /// `FD_QUEUE_SIZE` descriptors are already waiting — TODO confirm.
    fn enqueue(&mut self, fd: &impl AsRawFd) -> std::result::Result<(), QueueFullError> {
        self.biqueue.enqueue(fd)
    }
}
impl DequeueFd for UnixStream {
    /// Takes the next descriptor from this stream's fd queue, or `None` if it is empty.
    ///
    /// Delegates to `BiQueue::dequeue`. NOTE(review): the returned `RawFd` is
    /// presumably owned by the caller afterwards (who must close it) — confirm
    /// against `BiQueue`.
    fn dequeue(&mut self) -> Option<RawFd> {
        self.biqueue.dequeue()
    }
}
impl Read for UnixStream {
    /// Reads into `buf`.
    ///
    /// Routed through `read_vectored` with a single slice, so scalar and
    /// vectored reads share the same `BiQueue` receive path.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut single = [IoSliceMut::new(buf)];
        self.read_vectored(&mut single)
    }

    /// Scatter-reads into `bufs` by handing this socket's fd to `BiQueue::read_vectored`.
    // NOTE(review): descriptors arriving with the data are presumably stored for
    // `dequeue` — confirm in `biqueue`.
    fn read_vectored(&mut self, bufs: &mut [IoSliceMut]) -> io::Result<usize> {
        let fd = self.as_raw_fd();
        self.biqueue.read_vectored(fd, bufs)
    }
}
impl Write for UnixStream {
    /// Writes `buf`.
    ///
    /// Routed through `write_vectored` with a single slice, so scalar and
    /// vectored writes share the same `BiQueue` send path.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let single = [IoSlice::new(buf)];
        self.write_vectored(&single)
    }

    /// Gather-writes `bufs` by handing this socket's fd to `BiQueue::write_vectored`.
    fn write_vectored(&mut self, bufs: &[IoSlice]) -> io::Result<usize> {
        let fd = self.as_raw_fd();
        self.biqueue.write_vectored(fd, bufs)
    }

    /// Flushes the wrapped standard-library socket.
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl AsRawFd for UnixStream {
    /// Returns the raw descriptor of the wrapped socket without transferring ownership.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl FromRawFd for UnixStream {
    /// Wraps a Unix stream socket descriptor, attaching a fresh `BiQueue`.
    ///
    /// # Safety
    ///
    /// Same contract as `std::os::unix::net::UnixStream::from_raw_fd`: `fd` must
    /// be an open, owned socket descriptor that nothing else will close.
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        let stream = StdUnixStream::from_raw_fd(fd);
        Self::from(stream)
    }
}
impl IntoRawFd for UnixStream {
    /// Consumes the stream and hands ownership of the socket descriptor to the caller.
    ///
    /// NOTE(review): the `BiQueue` is dropped here, so any descriptors still
    /// queued are discarded. Whether they get closed depends on `BiQueue`'s
    /// `Drop` — confirm.
    fn into_raw_fd(self) -> RawFd {
        let Self { inner, .. } = self;
        inner.into_raw_fd()
    }
}
impl From<StdUnixStream> for UnixStream {
fn from(inner: StdUnixStream) -> Self {
Self {
inner,
biqueue: BiQueue::new(),
}
}
}
impl UnixListener {
pub fn bind(path: impl AsRef<Path>) -> io::Result<UnixListener> {
StdUnixListner::bind(path).map(|s| s.into())
}
pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
self.inner.accept().map(|(s, a)| (s.into(), a))
}
pub fn try_clone(&self) -> io::Result<UnixListener> {
self.inner.try_clone().map(|s| s.into())
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.local_addr()
}
pub fn take_error(&self) -> io::Result<Option<Error>> {
self.inner.take_error()
}
pub fn incoming(&self) -> Incoming {
Incoming { listener: self }
}
}
impl AsRawFd for UnixListener {
    /// Returns the raw descriptor of the listening socket without transferring ownership.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl FromRawFd for UnixListener {
    /// Wraps a listening Unix socket descriptor.
    ///
    /// # Safety
    ///
    /// Same contract as `std::os::unix::net::UnixListener::from_raw_fd`: `fd` must
    /// be an open, owned socket descriptor that nothing else will close.
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        let listener = StdUnixListner::from_raw_fd(fd);
        Self::from(listener)
    }
}
impl IntoRawFd for UnixListener {
    /// Consumes the listener and hands ownership of its descriptor to the caller.
    fn into_raw_fd(self) -> RawFd {
        let Self { inner } = self;
        inner.into_raw_fd()
    }
}
impl<'a> IntoIterator for &'a UnixListener {
type Item = io::Result<UnixStream>;
type IntoIter = Incoming<'a>;
fn into_iter(self) -> Self::IntoIter {
self.incoming()
}
}
impl From<StdUnixListner> for UnixListener {
fn from(inner: StdUnixListner) -> Self {
UnixListener { inner }
}
}
impl Iterator for Incoming<'_> {
    type Item = io::Result<UnixStream>;

    /// Accepts the next connection and drops the peer address.
    ///
    /// Never returns `None`; accept errors are yielded as `Some(Err(_))`.
    fn next(&mut self) -> Option<Self::Item> {
        let accepted = self.listener.accept();
        Some(accepted.map(|(stream, _peer)| stream))
    }

    /// Reports the iterator as unbounded: a lower bound of `usize::MAX` and no upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let lower = usize::MAX;
        (lower, None)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use std::convert::AsMut;
    use std::ffi::c_void;
    use std::ptr;
    use std::slice;

    use nix::fcntl::OFlag;
    use nix::sys::mman::{mmap, munmap, shm_open, shm_unlink, MapFlags, ProtFlags};
    use nix::sys::stat::Mode;
    use nix::unistd::{close, ftruncate};

    /// A read/write mapping of a POSIX shared-memory object, used as a
    /// recognizable descriptor to pass over the socket.
    ///
    /// Owns `fd` and the mapping, and releases both on drop. When `name` is
    /// non-empty, the object is also unlinked.
    struct Shm {
        fd: RawFd,
        ptr: *mut u8,
        len: usize,
        name: String,
    }

    impl Shm {
        /// Creates (or opens) the shm object `name`, sizes it to `size` bytes,
        /// and maps it read/write.
        fn new(name: &str, size: i64) -> Shm {
            let oflag = OFlag::O_CREAT | OFlag::O_RDWR;
            let fd =
                shm_open(name, oflag, Mode::S_IRUSR | Mode::S_IWUSR).expect("Can't create shm.");
            ftruncate(fd, size).expect("Can't ftruncate");
            let len: usize = size as usize;
            let prot = ProtFlags::PROT_READ | ProtFlags::PROT_WRITE;
            let flags = MapFlags::MAP_SHARED;
            // SAFETY: a fresh mapping (null address hint) of an fd we own, whose
            // object was just truncated to exactly `len` bytes.
            let ptr = unsafe {
                mmap(ptr::null_mut(), len, prot, flags, fd, 0).expect("Can't mmap") as *mut u8
            };
            Shm {
                fd,
                ptr,
                len,
                name: name.to_string(),
            }
        }

        /// Maps `size` bytes of an existing shm descriptor.
        ///
        /// Takes ownership of `fd`, which is closed on drop; nothing is unlinked.
        fn from_raw_fd(fd: RawFd, size: usize) -> Shm {
            let prot = ProtFlags::PROT_READ | ProtFlags::PROT_WRITE;
            let flags = MapFlags::MAP_SHARED;
            // SAFETY: a fresh mapping (null address hint) of a descriptor handed to us;
            // the caller guarantees the object is at least `size` bytes long.
            let ptr = unsafe {
                mmap(ptr::null_mut(), size, prot, flags, fd, 0).expect("Can't mmap") as *mut u8
            };
            Shm {
                fd,
                ptr,
                len: size,
                name: String::new(),
            }
        }
    }

    impl Drop for Shm {
        fn drop(&mut self) {
            // SAFETY: `ptr`/`len` are exactly the mapping created in the constructor,
            // and `&mut self` guarantees no slice from `as_mut` is still alive.
            unsafe {
                munmap(self.ptr as *mut c_void, self.len).expect("Can't munmap");
            }
            close(self.fd).expect("Can't close");
            if !self.name.is_empty() {
                let name: &str = self.name.as_ref();
                shm_unlink(name).expect("Can't shm_unlink");
            }
        }
    }

    impl AsMut<[u8]> for Shm {
        fn as_mut(&mut self) -> &mut [u8] {
            // SAFETY: `ptr` points to a live, writable mapping of `len` bytes owned by
            // `self`, and the `&mut self` borrow keeps the slice unique on the Rust side.
            unsafe { slice::from_raw_parts_mut(self.ptr, self.len) }
        }
    }

    impl AsRawFd for Shm {
        fn as_raw_fd(&self) -> RawFd {
            self.fd
        }
    }

    /// Creates the shm object `name` holding exactly `b"Hello World!\0"`.
    fn make_hello(name: &str) -> Shm {
        let hello = b"Hello World!\0";
        let mut shm = Shm::new(name, hello.len() as i64);
        shm.as_mut().copy_from_slice(hello.as_ref());
        shm
    }

    /// Maps `fd` (taking ownership of it) and checks that it starts with `b"Hello World!\0"`.
    fn compare_hello(fd: RawFd) -> bool {
        let hello = b"Hello World!\0";
        let mut shm = Shm::from_raw_fd(fd, hello.len());
        shm.as_mut()[..hello.len()] == hello[..]
    }

    #[test]
    fn unix_stream_passes_fd() {
        let shm = make_hello("/unix_stream_passes_fd");
        let (mut sut1, mut sut2) = UnixStream::pair().expect("Can't make pair");
        sut1.enqueue(&shm).expect("Can't enqueue");
        // `write_all`/`read_exact` rather than `write`/`read`: the plain calls may
        // transfer fewer bytes than requested, and ignoring the count hides that.
        sut1.write_all(b"abc").expect("Can't write");
        sut1.flush().expect("Can't flush");
        let mut buf = [0u8; 3];
        sut2.read_exact(&mut buf).expect("Can't read");
        assert_eq!(&buf, b"abc", "payload corrupted in transit");
        let fd = sut2.dequeue().expect("Empty fd queue");
        // The receiver must get a new descriptor number that refers to the same object.
        assert_ne!(fd, shm.fd, "fd's unexpectedly equal");
        assert!(compare_hello(fd), "fd didn't contain expected contents");
    }
}