use std::{
collections::VecDeque,
fmt,
io::{self, prelude::*, Error, ErrorKind, IoSlice, IoSliceMut},
iter,
net::Shutdown,
os::unix::{
io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
net::{SocketAddr, UnixListener as StdUnixListner, UnixStream as StdUnixStream},
},
path::Path,
};
use std::usize;
use iomsg::{cmsg_buffer_fds_space, Fd, MsgHdr};
use tracing::{trace, warn};
use crate::{DequeueFd, EnqueueFd, QueueFullError};
mod iomsg;
/// A Unix domain stream socket that can carry file descriptors alongside
/// its byte stream (see the `EnqueueFd` / `DequeueFd` impls).
#[derive(Debug)]
pub struct UnixStream {
/// The wrapped standard-library socket that performs the actual I/O.
inner: StdUnixStream,
/// Descriptors received by reads, waiting to be taken with `dequeue`.
infd: VecDeque<Fd>,
/// Raw descriptors queued by `enqueue` for the next write; allocated lazily
/// on the first `enqueue`. Only the descriptor numbers are stored, not ownership.
outfd: Option<Vec<RawFd>>,
}
/// A Unix domain socket listener whose accepted connections are
/// fd-passing [`UnixStream`]s.
#[derive(Debug)]
pub struct UnixListener {
/// The wrapped standard-library listener.
inner: StdUnixListner,
}
/// Iterator returned by [`UnixListener::incoming`]; each call to `next`
/// accepts one connection and it never returns `None`.
#[derive(Debug)]
pub struct Incoming<'a> {
/// The listener connections are accepted from.
listener: &'a UnixListener,
}
/// Error payload used when the kernel reports that the control-message
/// (ancillary data) buffer was too small for what was received.
#[derive(Debug)]
struct CMsgTruncatedError {}
/// Error payload used when a received descriptor could not be stored in
/// the receive sink.
#[derive(Debug)]
struct PushFailureError {}
/// A sink that accepts items one at a time and may refuse them.
trait Push<A> {
/// Stores `item`, or hands it back unchanged in `Err` if it cannot be stored.
fn push(&mut self, item: A) -> Result<(), A>;
}
impl UnixStream {
pub const FD_QUEUE_SIZE: usize = 2;
pub fn connect<P: AsRef<Path>>(path: P) -> io::Result<UnixStream> {
StdUnixStream::connect(path).map(|s| s.into())
}
pub fn pair() -> io::Result<(UnixStream, UnixStream)> {
StdUnixStream::pair().map(|(s1, s2)| (s1.into(), s2.into()))
}
pub fn try_clone(&self) -> io::Result<UnixStream> {
self.inner.try_clone().map(|s| s.into())
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.local_addr()
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
self.inner.peer_addr()
}
pub fn take_error(&self) -> io::Result<Option<Error>> {
self.inner.take_error()
}
pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
self.inner.shutdown(how)
}
#[allow(dead_code)]
pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.inner.set_nonblocking(nonblocking)
}
}
/// Sends one message on `sockfd` built from `bufs`, attaching every
/// descriptor yielded by `fds` as SCM_RIGHTS control data.
///
/// Returns the number of payload bytes sent.
///
/// # Panics
///
/// Panics if `UnixStream::FD_QUEUE_SIZE` exceeds the generated
/// `constants::MAX_FD_COUNT`, i.e. if a full queue could not fit in the
/// control buffer.
fn send_fds(
    sockfd: RawFd,
    bufs: &[IoSlice],
    fds: impl Iterator<Item = RawFd>,
) -> io::Result<usize> {
    // The build-time constant must agree with the size computed at runtime,
    // otherwise the fixed-size control buffer below would be mis-sized.
    debug_assert_eq!(
        constants::CMSG_SCM_RIGHTS_SPACE as usize,
        cmsg_buffer_fds_space(constants::MAX_FD_COUNT)
    );
    assert!(UnixStream::FD_QUEUE_SIZE <= constants::MAX_FD_COUNT);

    let mut control = [0u8; constants::CMSG_SCM_RIGHTS_SPACE as _];
    let sent = MsgHdr::from_io_slice(bufs, &mut control)
        .encode_fds(fds)?
        .send(sockfd)?;

    let byte_count = sent.bytes_sent();
    trace!(
        source = "UnixStream",
        event = "write",
        fds_count = sent.fds_sent(),
        byte_count,
    );
    Ok(byte_count)
}
/// Receives one message from `sockfd` into `bufs` and pushes every
/// descriptor found in its control data into `fds_sink`.
///
/// Returns the number of payload bytes received.
///
/// # Errors
///
/// * Any error from the underlying receive.
/// * `PushFailureError` if `fds_sink` refuses a descriptor.
/// * `CMsgTruncatedError` if the kernel reports the control data was truncated.
///
/// NOTE(review): on the two error paths above, the payload has already been
/// read from the socket but its byte count is not returned. Descriptors pushed
/// before the failure remain in `fds_sink`. Whether descriptors left unpushed
/// are closed depends on `Fd`'s drop behaviour (defined in `iomsg`) — confirm.
fn recv_fds(
    sockfd: RawFd,
    bufs: &mut [IoSliceMut],
    fds_sink: &mut impl Push<Fd>,
) -> io::Result<usize> {
    // The build-time constant must agree with the size computed at runtime,
    // otherwise the fixed-size control buffer below would be mis-sized.
    debug_assert_eq!(
        constants::CMSG_SCM_RIGHTS_SPACE as usize,
        cmsg_buffer_fds_space(constants::MAX_FD_COUNT)
    );

    let mut control = [0u8; constants::CMSG_SCM_RIGHTS_SPACE as _];
    let mut received = MsgHdr::from_io_slice_mut(bufs, &mut control).recv(sockfd)?;

    let mut fds_count = 0;
    for fd in received.take_fds() {
        if let Err(_rejected) = fds_sink.push(fd) {
            warn!(
                source = "UnixStream",
                event = "read",
                condition = "too many fds received"
            );
            return Err(PushFailureError::new());
        }
        fds_count += 1;
    }

    if !received.was_control_truncated() {
        let byte_count = received.bytes_recvieved();
        trace!(
            source = "UnixStream",
            event = "read",
            fds_count,
            byte_count,
        );
        return Ok(byte_count);
    }

    warn!(
        source = "UnixStream",
        event = "read",
        condition = "cmsgs truncated"
    );
    Err(CMsgTruncatedError::new())
}
impl EnqueueFd for UnixStream {
    /// Queues the descriptor of `fd` to be sent with the next write.
    ///
    /// Only the raw descriptor number is recorded; ownership is not taken, so
    /// the caller must keep `fd` open until it has been written.
    ///
    /// Returns `QueueFullError` once `FD_QUEUE_SIZE` descriptors are queued.
    fn enqueue(&mut self, fd: &impl AsRawFd) -> std::result::Result<(), QueueFullError> {
        let queue = self
            .outfd
            .get_or_insert_with(|| Vec::with_capacity(Self::FD_QUEUE_SIZE));

        if queue.len() < Self::FD_QUEUE_SIZE {
            trace!(source = "UnixStream", event = "enqueue", count = 1);
            queue.push(fd.as_raw_fd());
            return Ok(());
        }

        warn!(source = "UnixStream", event = "enqueue", condition = "full");
        Err(QueueFullError::new())
    }
}
impl DequeueFd for UnixStream {
    /// Takes the oldest descriptor received by previous reads, if any.
    ///
    /// The descriptor is released from its `Fd` wrapper via `into_raw_fd`, so
    /// presumably the caller becomes responsible for closing it.
    fn dequeue(&mut self) -> Option<RawFd> {
        let next = self.infd.pop_front();
        let count = match next {
            Some(_) => 1,
            None => 0,
        };
        trace!(source = "UnixStream", event = "dequeue", count);
        next.map(|fd| fd.into_raw_fd())
    }
}
impl Read for UnixStream {
    /// Reads into `buf`; descriptors carried by the message are queued for
    /// `dequeue`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut single = [IoSliceMut::new(buf)];
        self.read_vectored(&mut single)
    }

    /// Scatter-read into `bufs`; descriptors carried by the message are
    /// queued for `dequeue`.
    fn read_vectored(&mut self, bufs: &mut [IoSliceMut]) -> io::Result<usize> {
        let sockfd = self.as_raw_fd();
        recv_fds(sockfd, bufs, &mut self.infd)
    }
}
impl Write for UnixStream {
    /// Writes `buf`, attaching any descriptors queued with `enqueue`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.write_vectored(&[IoSlice::new(buf)])
    }

    /// Gather-writes `bufs`, attaching any descriptors queued with `enqueue`
    /// as control data on this message.
    ///
    /// The queued descriptors are discarded only after the send succeeds.
    /// If it fails (for example because the socket is non-blocking and would
    /// block), they stay queued so a retried write still carries them.
    fn write_vectored(&mut self, bufs: &[IoSlice]) -> io::Result<usize> {
        let sockfd = self.as_raw_fd();
        match self.outfd.as_mut() {
            Some(fds) if !fds.is_empty() => {
                let written = send_fds(sockfd, bufs, fds.iter().copied())?;
                // The descriptors went out with this message; empty the queue
                // but keep its allocation for the next batch.
                fds.clear();
                Ok(written)
            }
            _ => send_fds(sockfd, bufs, iter::empty()),
        }
    }

    /// Flushes the underlying socket.
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}
impl AsRawFd for UnixStream {
    /// Returns the descriptor of the underlying socket without transferring ownership.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl FromRawFd for UnixStream {
    /// Wraps an existing socket descriptor; both descriptor queues start empty.
    ///
    /// # Safety
    ///
    /// `fd` must be an open Unix stream socket descriptor owned by the caller;
    /// ownership passes to the returned value (the `std` `FromRawFd` contract).
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        let std_stream = StdUnixStream::from_raw_fd(fd);
        Self::from(std_stream)
    }
}
impl IntoRawFd for UnixStream {
    /// Releases the socket descriptor to the caller. Any descriptors still
    /// sitting in the send/receive queues are dropped along with `self`.
    fn into_raw_fd(self) -> RawFd {
        IntoRawFd::into_raw_fd(self.inner)
    }
}
impl From<StdUnixStream> for UnixStream {
fn from(inner: StdUnixStream) -> Self {
Self {
inner,
infd: VecDeque::with_capacity(Self::FD_QUEUE_SIZE),
outfd: None,
}
}
}
impl Push<Fd> for VecDeque<Fd> {
    /// Appends `item` at the back of the queue. This sink is unbounded and
    /// never rejects an item.
    fn push(&mut self, item: Fd) -> Result<(), Fd> {
        VecDeque::push_back(self, item);
        Ok(())
    }
}
impl UnixListener {
pub fn bind(path: impl AsRef<Path>) -> io::Result<UnixListener> {
StdUnixListner::bind(path).map(|s| s.into())
}
pub fn accept(&self) -> io::Result<(UnixStream, SocketAddr)> {
self.inner.accept().map(|(s, a)| (s.into(), a))
}
pub fn try_clone(&self) -> io::Result<UnixListener> {
self.inner.try_clone().map(|s| s.into())
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
self.inner.local_addr()
}
pub fn take_error(&self) -> io::Result<Option<Error>> {
self.inner.take_error()
}
pub fn incoming(&self) -> Incoming {
Incoming { listener: self }
}
pub(crate) fn set_nonblocking(&self, nonblocking: bool) -> io::Result<()> {
self.inner.set_nonblocking(nonblocking)
}
}
impl AsRawFd for UnixListener {
    /// Returns the descriptor of the listening socket without transferring ownership.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl FromRawFd for UnixListener {
    /// Wraps an existing listening socket descriptor.
    ///
    /// # Safety
    ///
    /// `fd` must be an open Unix listening socket descriptor owned by the
    /// caller; ownership passes to the returned value (the `std` `FromRawFd`
    /// contract).
    unsafe fn from_raw_fd(fd: RawFd) -> Self {
        let std_listener = StdUnixListner::from_raw_fd(fd);
        Self::from(std_listener)
    }
}
impl IntoRawFd for UnixListener {
    /// Releases the listening socket descriptor to the caller.
    fn into_raw_fd(self) -> RawFd {
        IntoRawFd::into_raw_fd(self.inner)
    }
}
impl<'a> IntoIterator for &'a UnixListener {
type Item = io::Result<UnixStream>;
type IntoIter = Incoming<'a>;
fn into_iter(self) -> Self::IntoIter {
self.incoming()
}
}
impl From<StdUnixListner> for UnixListener {
fn from(inner: StdUnixListner) -> Self {
UnixListener { inner }
}
}
impl Iterator for Incoming<'_> {
    type Item = io::Result<UnixStream>;

    /// Accepts the next connection, discarding the peer address. Never
    /// returns `None`; accept failures are yielded as `Some(Err(_))`.
    fn next(&mut self) -> Option<Self::Item> {
        let accepted = self.listener.accept();
        Some(accepted.map(|(stream, _peer)| stream))
    }

    /// Reports the iterator as unbounded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (usize::MAX, None)
    }
}
impl CMsgTruncatedError {
    /// Builds an `io::Error` of kind `Other` carrying this error as its source.
    fn new() -> Error {
        let cause = CMsgTruncatedError {};
        Error::new(ErrorKind::Other, cause)
    }
}
impl fmt::Display for CMsgTruncatedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The buffer used to receive file descriptors was too small.")
    }
}
impl std::error::Error for CMsgTruncatedError {}
impl PushFailureError {
    /// Builds an `io::Error` of kind `Other` carrying this error as its source.
    fn new() -> Error {
        let cause = PushFailureError {};
        Error::new(ErrorKind::Other, cause)
    }
}
impl fmt::Display for PushFailureError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("The sink for receiving file descriptors was unexpectedly full.")
    }
}
impl std::error::Error for PushFailureError {}
/// Constants generated by the build script into `$OUT_DIR/constants.rs`.
/// This file uses `MAX_FD_COUNT` and `CMSG_SCM_RIGHTS_SPACE` from it to size
/// the control-message buffers for fd passing.
mod constants {
include!(concat!(env!("OUT_DIR"), "/constants.rs"));
}
#[cfg(test)]
mod test {
    use super::*;
    use std::convert::AsMut;
    use std::ffi::c_void;
    use std::ptr;
    use std::slice;
    use nix::fcntl::OFlag;
    use nix::sys::mman::{mmap, munmap, shm_open, shm_unlink, MapFlags, ProtFlags};
    use nix::sys::stat::Mode;
    use nix::unistd::{close, ftruncate};

    /// Payload written into the shared-memory object that is passed over the socket.
    const HELLO: &[u8] = b"Hello World!\0";

    /// A read/write shared mapping of a POSIX shared-memory object (or of a
    /// received descriptor). Dropping it unmaps, closes the descriptor and,
    /// when it was created by name, unlinks the object.
    struct Shm {
        fd: RawFd,
        ptr: *mut u8,
        len: usize,
        /// Empty when the mapping wraps a received descriptor (nothing to unlink).
        name: String,
    }

    impl Shm {
        /// Creates (or opens) the shared-memory object `name`, sizes it to
        /// `size` bytes and maps it read/write.
        fn new(name: &str, size: i64) -> Shm {
            let oflag = OFlag::O_CREAT | OFlag::O_RDWR;
            let fd =
                shm_open(name, oflag, Mode::S_IRUSR | Mode::S_IWUSR).expect("Can't create shm.");
            ftruncate(fd, size).expect("Can't ftruncate");
            let len: usize = size as usize;
            let prot = ProtFlags::PROT_READ | ProtFlags::PROT_WRITE;
            let flags = MapFlags::MAP_SHARED;
            // SAFETY: a fresh mapping at a kernel-chosen address of the object
            // just sized to `len` bytes; it is unmapped in `Drop`.
            let ptr = unsafe {
                mmap(ptr::null_mut(), len, prot, flags, fd, 0).expect("Can't mmap") as *mut u8
            };
            Shm {
                fd,
                ptr,
                len,
                name: name.to_string(),
            }
        }

        /// Maps `size` bytes of an already-open descriptor and takes ownership
        /// of it (it is closed on drop).
        fn from_raw_fd(fd: RawFd, size: usize) -> Shm {
            let prot = ProtFlags::PROT_READ | ProtFlags::PROT_WRITE;
            let flags = MapFlags::MAP_SHARED;
            // SAFETY: a fresh mapping at a kernel-chosen address; it is
            // unmapped in `Drop`.
            let ptr = unsafe {
                mmap(ptr::null_mut(), size, prot, flags, fd, 0).expect("Can't mmap") as *mut u8
            };
            Shm {
                fd,
                ptr,
                len: size,
                name: String::new(),
            }
        }
    }

    impl Drop for Shm {
        fn drop(&mut self) {
            // SAFETY: `ptr`/`len` describe the mapping created in the constructor.
            unsafe {
                munmap(self.ptr as *mut c_void, self.len).expect("Can't munmap");
            }
            close(self.fd).expect("Can't close");
            if !self.name.is_empty() {
                let name: &str = self.name.as_ref();
                shm_unlink(name).expect("Can't shm_unlink");
            }
        }
    }

    impl AsMut<[u8]> for Shm {
        fn as_mut(&mut self) -> &mut [u8] {
            // SAFETY: the mapping is `len` bytes, readable and writable, and
            // lives as long as `self`.
            unsafe { slice::from_raw_parts_mut(self.ptr, self.len) }
        }
    }

    impl AsRawFd for Shm {
        fn as_raw_fd(&self) -> RawFd {
            self.fd
        }
    }

    /// Creates the named shared-memory object and fills it with `HELLO`.
    fn make_hello(name: &str) -> Shm {
        let mut shm = Shm::new(name, HELLO.len() as i64);
        shm.as_mut().copy_from_slice(HELLO);
        shm
    }

    /// Maps `fd`, checks that it starts with `HELLO`, then closes it.
    fn compare_hello(fd: RawFd) -> bool {
        let mut shm = Shm::from_raw_fd(fd, HELLO.len());
        &shm.as_mut()[..HELLO.len()] == HELLO
    }

    #[test]
    fn unix_stream_passes_fd() {
        let shm = make_hello("/unix_stream_passes_fd");
        let mut buf = vec![0; 20];
        let (mut sut1, mut sut2) = UnixStream::pair().expect("Can't make pair");
        sut1.enqueue(&shm).expect("Can't enqueue");
        let written = sut1.write(b"abc").expect("Can't write");
        assert_eq!(written, 3, "short write");
        sut1.flush().expect("Can't flush");
        let received = sut2.read(&mut buf).expect("Can't read");
        assert_eq!(&buf[..received], b"abc", "unexpected payload");
        let fd = sut2.dequeue().expect("Empty fd queue");
        assert_ne!(fd, shm.fd, "fd's unexpectedly equal");
        assert!(compare_hello(fd), "fd didn't contain expect contents");
    }
}