use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net;
use std::{alloc, io, mem, ptr};
pub mod changelog;
/// Sending a byte buffer together with file descriptors over a socket.
///
/// Implemented in this module for `UnixStream` and `UnixDatagram`.
pub trait SendWithFd {
    /// Sends `bytes` with `fds` attached as `SCM_RIGHTS` ancillary data in the same message.
    ///
    /// Returns the number of bytes reported as sent by `sendmsg(2)`.
    ///
    /// # Errors
    ///
    /// Returns the OS error set by `sendmsg(2)`, e.g. when a descriptor in `fds` is not open.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize>;
}
/// Receiving a byte buffer together with file descriptors from a socket.
///
/// Implemented in this module for `UnixStream` and `UnixDatagram`.
pub trait RecvWithFd {
    /// Receives one message into `bytes` and stores any `SCM_RIGHTS` descriptors into `fds`.
    ///
    /// Returns `(bytes_read, fds_read)`:
    /// - `bytes_read` is the byte count reported by `recvmsg(2)`.
    /// - `fds_read` is the number of descriptors written to the front of `fds`.
    ///
    /// Received descriptors are new descriptors in this process. The caller is responsible
    /// for closing them.
    ///
    /// # Errors
    ///
    /// Returns the OS error set by `recvmsg(2)`.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)>;
}
/// Signed byte distance from `origin` to `this` (that is, `this - origin`).
///
/// The distance is computed on the raw addresses with wrapping arithmetic.
///
/// # Safety
///
/// Neither pointer is dereferenced. The result is only meaningful when both pointers point
/// into the same allocation.
unsafe fn ptr_offset_from(this: *const u8, origin: *const u8) -> isize {
    let (end, start) = (this as isize, origin as isize);
    end.wrapping_sub(start)
}
/// Builds a `msghdr` for a single-buffer message, with a heap-allocated control buffer.
///
/// The control buffer is large enough for one `SCM_RIGHTS` message carrying `fd_count`
/// descriptors.
///
/// Returns `(msghdr, layout, fd_len)`:
/// - `msghdr` has `msg_iov` pointing at `iov` and `msg_control` pointing at the new,
///   zero-filled control buffer.
/// - `layout` is the allocation layout. The caller must free `msghdr.msg_control` with
///   `alloc::dealloc` and this layout.
/// - `fd_len` is the descriptor payload length in bytes (`fd_count * size_of::<RawFd>()`).
///
/// # Safety
///
/// `iov` must stay valid for as long as the returned `msghdr` is used.
///
/// # Panics
///
/// Panics if `fd_count` is so large that the control buffer size cannot be represented.
/// Truncating the size instead would under-allocate a buffer that callers fill with
/// descriptors.
unsafe fn construct_msghdr_for(
    iov: &mut libc::iovec,
    fd_count: usize,
) -> (libc::msghdr, alloc::Layout, usize) {
    let fd_len = mem::size_of::<RawFd>()
        .checked_mul(fd_count)
        .expect("file descriptor payload length overflows usize");
    // `CMSG_SPACE` takes a `u32`; an unchecked `as u32` would silently truncate.
    assert!(
        fd_len <= u32::max_value() as usize,
        "too many file descriptors for a single control message"
    );
    let cmsg_buffer_len = libc::CMSG_SPACE(fd_len as u32) as usize;
    // `CMSG_SPACE` adds header and padding bytes. If that wrapped, the result would be
    // smaller than the payload itself.
    assert!(
        cmsg_buffer_len >= fd_len,
        "too many file descriptors for a single control message"
    );
    let cmsg_layout =
        alloc::Layout::from_size_align(cmsg_buffer_len, mem::align_of::<libc::cmsghdr>())
            .expect("control message buffer size exceeds the maximum allocation size");
    // Zero-filled so that padding bytes later passed to `sendmsg` are initialised.
    // `cmsg_buffer_len` is non-zero: `CMSG_SPACE` always includes a `cmsghdr`.
    let cmsg_buffer = alloc::alloc_zeroed(cmsg_layout);
    if cmsg_buffer.is_null() {
        alloc::handle_alloc_error(cmsg_layout);
    }
    // Zero the whole header, then assign fields individually. A struct expression with
    // `..mem::zeroed()` does not compile on targets whose `msghdr` has private padding
    // fields (e.g. 64-bit musl).
    let mut msghdr: libc::msghdr = mem::zeroed();
    msghdr.msg_name = ptr::null_mut();
    msghdr.msg_namelen = 0;
    msghdr.msg_iov = iov as *mut _;
    msghdr.msg_iovlen = 1;
    msghdr.msg_control = cmsg_buffer as *mut _;
    msghdr.msg_controllen = cmsg_buffer_len as _;
    (msghdr, cmsg_layout, fd_len)
}
/// Sends `bs` on `socket`, attaching `fds` as a single `SCM_RIGHTS` control message.
///
/// Returns the byte count reported by `sendmsg(2)`, or the OS error it set. The temporary
/// control buffer is released on both paths.
fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
    unsafe {
        // `iovec` only offers a `*mut` base pointer, but `sendmsg` does not write through it.
        let mut iov = libc::iovec {
            iov_base: bs.as_ptr() as *const _ as *mut _,
            iov_len: bs.len(),
        };
        let (mut msghdr, cmsg_layout, fd_len) = construct_msghdr_for(&mut iov, fds.len());
        let cmsg_buffer = msghdr.msg_control;
        let cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
        // Assign the fields one by one rather than writing a `cmsghdr` struct literal. On some
        // targets (e.g. 64-bit musl) `cmsghdr` has extra padding fields, so a literal naming
        // only these three does not compile.
        (*cmsg_header).cmsg_level = libc::SOL_SOCKET;
        (*cmsg_header).cmsg_type = libc::SCM_RIGHTS;
        (*cmsg_header).cmsg_len = libc::CMSG_LEN(fd_len as u32) as _;
        // `CMSG_DATA` is not guaranteed to be aligned for `RawFd`, hence the unaligned writes.
        #[allow(clippy::cast_ptr_alignment)]
        let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd;
        for (i, &fd) in fds.iter().enumerate() {
            ptr::write_unaligned(cmsg_data.add(i), fd);
        }
        let count = libc::sendmsg(socket, &msghdr as *const _, 0);
        // Read errno before deallocating, in case deallocation touches it.
        let result = if count < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(count as usize)
        };
        alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
        result
    }
}
/// Receives one message from `socket` into `bs`, collecting `SCM_RIGHTS` descriptors into `fds`.
///
/// Returns `(bytes, descriptors)`:
/// - `bytes` is the byte count reported by `recvmsg(2)`.
/// - `descriptors` is how many descriptors were written to the front of `fds`.
///
/// The control buffer is sized with `CMSG_SPACE`, whose alignment padding can leave room for
/// more descriptors than `fds.len()`. For example, on 64-bit Linux, `CMSG_SPACE(4)` fits two.
/// The kernel may then install more descriptors than `fds` can hold. Those extras are closed
/// here, so they neither leak nor cause a panic.
fn recv_with_fd(socket: RawFd, bs: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
    unsafe {
        let mut iov = libc::iovec {
            iov_base: bs.as_mut_ptr() as *mut _,
            iov_len: bs.len(),
        };
        let (mut msghdr, cmsg_layout, _) = construct_msghdr_for(&mut iov, fds.len());
        let cmsg_buffer = msghdr.msg_control;
        let count = libc::recvmsg(socket, &mut msghdr as *mut _, 0);
        if count < 0 {
            // Read errno before deallocating, in case deallocation touches it.
            let error = io::Error::last_os_error();
            alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
            return Err(error);
        }
        let mut descriptor_count = 0;
        let mut cmsg_header = libc::CMSG_FIRSTHDR(&mut msghdr as *mut _);
        while !cmsg_header.is_null() {
            if (*cmsg_header).cmsg_level == libc::SOL_SOCKET
                && (*cmsg_header).cmsg_type == libc::SCM_RIGHTS
            {
                let data_ptr = libc::CMSG_DATA(cmsg_header);
                let data_offset = ptr_offset_from(data_ptr, cmsg_header as *const _);
                debug_assert!(data_offset >= 0);
                // `cmsg_len` includes the header. Saturate rather than underflow if it is ever
                // shorter than the data offset.
                let data_byte_count =
                    ((*cmsg_header).cmsg_len as usize).saturating_sub(data_offset as usize);
                debug_assert!(data_byte_count % mem::size_of::<RawFd>() == 0);
                let rawfd_count = data_byte_count / mem::size_of::<RawFd>();
                // `CMSG_DATA` is not guaranteed to be aligned for `RawFd`, hence unaligned reads.
                #[allow(clippy::cast_ptr_alignment)]
                let fd_ptr = data_ptr as *const RawFd;
                for i in 0..rawfd_count {
                    let fd = ptr::read_unaligned(fd_ptr.add(i));
                    if let Some(slot) = fds.get_mut(descriptor_count) {
                        *slot = fd;
                        descriptor_count += 1;
                    } else {
                        // No room left in `fds`. The descriptor is already open in this
                        // process, so close it rather than leak it. This is best effort, so
                        // the result is ignored.
                        libc::close(fd);
                    }
                }
            }
            cmsg_header = libc::CMSG_NXTHDR(&mut msghdr as *mut _, cmsg_header);
        }
        alloc::dealloc(cmsg_buffer as *mut _, cmsg_layout);
        Ok((count as usize, descriptor_count))
    }
}
impl SendWithFd for net::UnixStream {
    /// Sends through this stream's underlying socket descriptor.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let socket = self.as_raw_fd();
        send_with_fd(socket, bytes, fds)
    }
}
impl SendWithFd for net::UnixDatagram {
    /// Sends through this datagram socket's underlying descriptor.
    fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
        let socket = self.as_raw_fd();
        send_with_fd(socket, bytes, fds)
    }
}
impl RecvWithFd for net::UnixStream {
    /// Receives through this stream's underlying socket descriptor.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let socket = self.as_raw_fd();
        recv_with_fd(socket, bytes, fds)
    }
}
impl RecvWithFd for net::UnixDatagram {
    /// Receives through this datagram socket's underlying descriptor.
    fn recv_with_fd(&self, bytes: &mut [u8], fds: &mut [RawFd]) -> io::Result<(usize, usize)> {
        let socket = self.as_raw_fd();
        recv_with_fd(socket, bytes, fds)
    }
}
#[cfg(test)]
mod tests {
    use super::{RecvWithFd, SendWithFd};
    use std::os::unix::io::{AsRawFd, FromRawFd};
    use std::os::unix::net;

    /// Round-trips two descriptors over a `UnixStream` pair. Each received descriptor must
    /// share socket state (the read timeout) with the descriptor that was sent.
    #[test]
    fn stream_works() {
        let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            let expected_value = Some(std::time::Duration::from_secs(42));
            unsafe {
                // Borrow the sent descriptor without taking ownership (hence `forget`).
                let s = net::UnixStream::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                std::mem::forget(s);
                // The received descriptor is owned here and closed on drop.
                assert_eq!(
                    net::UnixStream::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    /// Same round-trip as `stream_works`, over a `UnixDatagram` pair.
    #[test]
    fn datagram_works() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        assert_eq!(
            l.send_with_fd(&sent_bytes[..], &sent_fds[..])
                .expect("send should be successful"),
            sent_bytes.len()
        );
        let mut recv_bytes = [0; 128];
        let mut recv_fds = [0, 0, 0, 0, 0, 0, 0];
        assert_eq!(
            r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                .expect("recv should be successful"),
            (sent_bytes.len(), sent_fds.len())
        );
        assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
        for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
            let expected_value = Some(std::time::Duration::from_secs(42));
            unsafe {
                let s = net::UnixDatagram::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                std::mem::forget(s);
                assert_eq!(
                    net::UnixDatagram::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    /// Sends from a forked child process and receives in the parent.
    #[test]
    fn datagram_works_across_processes() {
        let (l, r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        let sent_fds = [l.as_raw_fd(), r.as_raw_fd()];
        unsafe {
            match libc::fork() {
                -1 => panic!("fork failed!"),
                0 => {
                    // Child: do not panic or run exit handlers here. A panic would unwind into
                    // the child's copy of the test harness instead of failing this test, so
                    // the outcome is reported through the exit status instead.
                    let sent = l.send_with_fd(&sent_bytes[..], &sent_fds[..]).is_ok();
                    libc::_exit(if sent { 0 } else { 1 });
                }
                child => {
                    // Parent: reap the child so no zombie is left behind. Fail fast if the
                    // child's send failed; otherwise the `recv_with_fd` below would block
                    // forever. A successfully sent datagram stays queued on `r` after the
                    // child exits.
                    let mut status: libc::c_int = 0;
                    assert_eq!(libc::waitpid(child, &mut status, 0), child, "waitpid failed");
                    assert!(
                        libc::WIFEXITED(status) && libc::WEXITSTATUS(status) == 0,
                        "child process failed to send"
                    );
                }
            }
            let mut recv_bytes = [0; 128];
            let mut recv_fds = [0, 0, 0, 0, 0, 0, 0];
            assert_eq!(
                r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
                    .expect("recv should be successful"),
                (sent_bytes.len(), sent_fds.len())
            );
            assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
            for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
                let expected_value = Some(std::time::Duration::from_secs(42));
                let s = net::UnixDatagram::from_raw_fd(sent);
                s.set_read_timeout(expected_value)
                    .expect("set read timeout");
                std::mem::forget(s);
                assert_eq!(
                    net::UnixDatagram::from_raw_fd(recvd)
                        .read_timeout()
                        .expect("get read timeout"),
                    expected_value
                );
            }
        }
    }

    /// Sending descriptors that are not open must fail.
    #[test]
    fn sending_junk_fails() {
        // Keep the peer alive. If it were dropped, the send could fail merely because the
        // peer is gone, and the test would pass without exercising descriptor validation.
        let (l, _r) = net::UnixDatagram::pair().expect("create UnixDatagram pair");
        let sent_bytes = b"hello world!";
        assert!(
            l.send_with_fd(&sent_bytes[..], &[i32::max_value()][..])
                .is_err(),
            "expected an error when sending a junk file descriptor"
        );
        assert!(
            l.send_with_fd(&sent_bytes[..], &[0xffi32][..]).is_err(),
            "expected an error when sending a junk file descriptor"
        );
    }
}