use std::{
convert::TryInto,
error, fmt,
io::{self, IoSlice, IoSliceMut},
marker::PhantomData,
mem,
ops::Neg,
os::unix::io::{IntoRawFd, RawFd},
ptr::{self, NonNull},
slice,
};
use std::isize;
use libc::{
c_int, c_uint, close, cmsghdr, iovec, msghdr, recvmsg, sendmsg, CMSG_DATA, CMSG_FIRSTHDR,
CMSG_LEN, CMSG_NXTHDR, CMSG_SPACE, MSG_CTRUNC, SCM_RIGHTS, SOL_SOCKET,
};
use num_traits::One;
/// Typestate wrapper around a `libc::msghdr` used with `sendmsg`/`recvmsg`.
///
/// `State` records which step of the exchange the header is at (`SendStart` -> `SendReady`
/// -> `SendEnd` for sending, `RecvStart` for receiving), so only the operations valid for
/// that step are callable. `mhdr` holds raw pointers into the caller's I/O slices and
/// control buffer; the `'a` lifetime keeps those borrows alive.
#[derive(Debug)]
pub struct MsgHdr<'a, State> {
    mhdr: msghdr,
    // Per-step data carried alongside the header (e.g. the number of encoded descriptors).
    state: State,
    // Keeps the iovec array and the control buffer borrowed for `'a`, even though `mhdr`
    // only stores raw pointers to them.
    _phantom: PhantomData<(&'a mut [iovec], &'a mut [u8])>,
}
/// Marker implemented by the send states that follow `encode_fds`. Presumably it flags that
/// `msg_control` may be null in those states (`encode_fds` nulls it when no descriptors are
/// encoded). NOTE(review): it is not used as a trait bound anywhere in the visible code.
trait NullableControl {}

/// Receive-side starting state: the header is built but `recvmsg` has not run yet.
#[derive(Debug, Default)]
pub struct RecvStart {}

/// Outcome of a successful `recvmsg`: the header as filled in by the kernel, plus
/// bookkeeping for the received file descriptors.
#[derive(Debug)]
pub struct MsgHdrRecvEnd<'a> {
    mhdr: msghdr,
    // Byte count returned by recvmsg (the field keeps its historical spelling).
    bytes_recvieved: usize,
    // Set once the descriptors have been handed out through `take_fds`, so they are never
    // yielded (or closed) twice.
    fds_taken: bool,
    // Keeps the caller's data and control buffers borrowed for `'a`.
    _phantom: PhantomData<(&'a mut [iovec], &'a mut [u8])>,
}

/// Send-side starting state: data buffers attached, no descriptors encoded yet.
#[derive(Debug, Default)]
pub struct SendStart {}

/// Send-side state after `encode_fds`: ready to `send`.
#[derive(Debug)]
pub struct SendReady {
    // Number of descriptors written into the control message.
    fds_count: usize,
}
impl NullableControl for SendReady {}

/// Send-side final state after a successful `send`.
#[derive(Debug)]
pub struct SendEnd {
    // Byte count returned by sendmsg.
    bytes_sent: usize,
    // Number of descriptors that were encoded into the sent message.
    fds_sent: usize,
}
impl NullableControl for SendEnd {}
/// Iterator over the SCM_RIGHTS descriptors in a control buffer. It walks every control
/// message and skips those that are not `SOL_SOCKET`/`SCM_RIGHTS`. Descriptors not yet
/// yielded when it is dropped are closed.
struct FdsIter<'a> {
    // Header whose control buffer is being walked.
    mhdr: &'a msghdr,
    // Control message currently being read; `None` once the walk has finished (or for an
    // empty iterator).
    cmsg: Option<&'a cmsghdr>,
    // Cursor over the descriptors of `cmsg`, present only when it is an SCM_RIGHTS message.
    data: Option<FdsIterData>,
}

/// Half-open range `[curr, end)` of `RawFd` values inside a control-message payload. The
/// values are read with `read_unaligned`, so the range need not be aligned for `RawFd`.
struct FdsIterData {
    curr: *const RawFd,
    end: *const RawFd,
}

/// Writable view of a single control-message header inside a send-side control buffer.
struct CMsgMut<'a> {
    cmsg: NonNull<cmsghdr>,
    // Ties the view to the exclusive borrow of the `msghdr` it was obtained from.
    _phantom: PhantomData<&'a mut msghdr>,
}
/// An owned file descriptor taken from a received SCM_RIGHTS message.
///
/// The descriptor is closed when the `Fd` is dropped, unless ownership is released first
/// with `into_raw_fd`.
#[derive(Debug)]
pub struct Fd {
    // `Some` until `into_raw_fd` takes the descriptor out.
    fd: Option<RawFd>,
}
impl<'a, State: Default> MsgHdr<'a, State> {
unsafe fn new(iov: *mut iovec, iov_len: usize, cmsg_buffer: &'a mut [u8]) -> Self {
let mhdr = {
let mut mhdr = mem::MaybeUninit::<libc::msghdr>::zeroed();
let p = mhdr.as_mut_ptr();
(*p).msg_name = ptr::null_mut();
(*p).msg_namelen = 0;
(*p).msg_iov = iov;
(*p).msg_iovlen = iov_len;
(*p).msg_control = cmsg_buffer.as_mut_ptr() as _;
(*p).msg_controllen = cmsg_buffer.len();
(*p).msg_flags = 0;
mhdr.assume_init()
};
Self {
mhdr,
state: Default::default(),
_phantom: PhantomData,
}
}
}
impl<'a> MsgHdr<'a, RecvStart> {
    /// Prepares a receive header that scatters incoming data into `bufs` and collects
    /// ancillary data (such as passed file descriptors) into `cmsg_buffer`.
    pub fn from_io_slice_mut(bufs: &'a mut [IoSliceMut], cmsg_buffer: &'a mut [u8]) -> Self {
        // IoSliceMut is documented to be ABI-compatible with `iovec` on Unix.
        let (iov_ptr, iov_count) = (bufs.as_mut_ptr().cast::<iovec>(), bufs.len());
        unsafe { Self::new(iov_ptr, iov_count, cmsg_buffer) }
    }

    /// Receives one message from `sockfd` with `recvmsg(2)` and no flags.
    ///
    /// NOTE(review): because `MSG_CMSG_CLOEXEC` is not passed, received descriptors are not
    /// marked close-on-exec.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `recvmsg` when it fails (e.g. `sockfd` is not a
    /// socket).
    pub fn recv(mut self, sockfd: RawFd) -> io::Result<MsgHdrRecvEnd<'a>> {
        let received = call_res(|| unsafe { recvmsg(sockfd, &mut self.mhdr, 0) })?;
        Ok(MsgHdrRecvEnd {
            bytes_recvieved: received as usize,
            fds_taken: false,
            mhdr: self.mhdr,
            _phantom: PhantomData,
        })
    }
}
impl<'a> MsgHdrRecvEnd<'a> {
    /// Number of data bytes reported by `recvmsg`.
    ///
    /// Kept under its original (misspelled) name for backward compatibility; see
    /// [`bytes_received`](Self::bytes_received).
    pub fn bytes_recvieved(&self) -> usize {
        self.bytes_recvieved
    }

    /// Number of data bytes reported by `recvmsg`. This is the correctly spelled alias of
    /// [`bytes_recvieved`](Self::bytes_recvieved).
    pub fn bytes_received(&self) -> usize {
        self.bytes_recvieved
    }

    /// Returns `true` when the kernel set `MSG_CTRUNC`. That means the control buffer was
    /// too small, some ancillary data was discarded, and descriptors may have been lost.
    pub fn was_control_truncated(&self) -> bool {
        self.mhdr.msg_flags & MSG_CTRUNC != 0
    }

    /// Hands out the received SCM_RIGHTS descriptors as owned [`Fd`]s.
    ///
    /// Only the first call yields descriptors; later calls return an empty iterator.
    /// Descriptors still inside the iterator when it is dropped are closed. So is every
    /// yielded `Fd` that is not released with `into_raw_fd`.
    pub fn take_fds<'b>(&'b mut self) -> impl Iterator<Item = Fd> + 'b {
        if self.fds_taken {
            FdsIter::empty(&self.mhdr)
        } else {
            // Flip the flag before iterating so the descriptors can never be handed out twice.
            self.fds_taken = true;
            // SAFETY: `mhdr` is presumably the header filled in by `MsgHdr::recv` (tests build
            // one by hand). Its control buffer is borrowed for 'a, which outlives 'b.
            unsafe { FdsIter::new(&self.mhdr) }
        }
    }
}
impl<'a> Drop for MsgHdrRecvEnd<'a> {
    fn drop(&mut self) {
        // Descriptors already handed out belong to the caller now.
        if self.fds_taken {
            return;
        }
        // Close every received descriptor nobody claimed: each `Fd` closes itself on drop.
        for unclaimed in self.take_fds() {
            drop(unclaimed);
        }
    }
}
impl<'a> MsgHdr<'a, SendStart> {
    /// Prepares a send header that gathers outgoing data from `bufs`. `cmsg_buffer` is the
    /// scratch space for the control message written by
    /// [`encode_fds`](MsgHdr::encode_fds).
    pub fn from_io_slice(bufs: &'a [IoSlice], cmsg_buffer: &'a mut [u8]) -> Self {
        // IoSlice is ABI-compatible with `iovec` on Unix. `sendmsg` only reads through
        // `msg_iov`, so the cast to `*mut` never leads to writes into `bufs`.
        let iov: *mut iovec = bufs.as_ptr() as *mut iovec;
        let iov_len = bufs.len();
        unsafe { Self::new(iov, iov_len, cmsg_buffer) }
    }

    /// Encodes `fds` as a single `SOL_SOCKET`/`SCM_RIGHTS` control message.
    ///
    /// If `fds` is empty, no control message is sent (`msg_control` is set to null).
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::Other` error wrapping `CMsgBufferTooSmallError` when the
    /// control buffer cannot hold every descriptor. This includes the case where the buffer
    /// has no room even for a header while `fds` is non-empty.
    pub fn encode_fds(
        mut self,
        mut fds: impl Iterator<Item = RawFd>,
    ) -> io::Result<MsgHdr<'a, SendReady>> {
        // Number of control bytes actually backed by the caller's buffer.
        let control_capacity = self.mhdr.msg_controllen;
        let count = match unsafe { CMsgMut::first_cmsg(&mut self.mhdr, SOL_SOCKET, SCM_RIGHTS) } {
            None => {
                // No room for a header: acceptable only if there is nothing to send.
                if fds.next().is_some() {
                    return Err(CMsgBufferTooSmallError::new());
                }
                0
            }
            Some(mut cmsg) => {
                let mut count = 0;
                // Payload area spanning the rest of the buffer; shrunk to fit afterwards.
                let mut data = cmsg.data();
                for fd in fds {
                    let fd_size = mem::size_of_val(&fd);
                    if data.len() < fd_size {
                        return Err(CMsgBufferTooSmallError::new());
                    }
                    let (nextval, nextdata) = data.split_at_mut(fd_size);
                    nextval.copy_from_slice(&fd.to_ne_bytes());
                    data = nextdata;
                    count += 1;
                }
                cmsg.shrink_data_len((count * mem::size_of::<RawFd>()).try_into().unwrap());
                count
            }
        };
        if count == 0 {
            self.mhdr.msg_control = ptr::null_mut();
            self.mhdr.msg_controllen = 0;
        } else {
            // CMSG_SPACE rounds the payload up to the platform alignment. That can exceed the
            // buffer when the payload only just fits: on 64-bit Linux a 20-byte buffer holds
            // CMSG_LEN(4) but CMSG_SPACE(4) is 24. Advertising more bytes than the buffer has
            // would make the kernel read past it. CMSG_LEN(payload) always fits (checked
            // above) and is enough for the kernel to parse the single message.
            self.mhdr.msg_controllen = cmsg_buffer_fds_space(count).min(control_capacity);
        }
        Ok(MsgHdr {
            mhdr: self.mhdr,
            state: SendReady { fds_count: count },
            _phantom: PhantomData,
        })
    }
}
impl<'a> MsgHdr<'a, SendReady> {
    /// Sends the data and any encoded descriptors on `sock_fd` with `sendmsg(2)`, using no
    /// flags.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `sendmsg` when it fails (e.g. `sock_fd` is not a
    /// socket).
    pub fn send(self, sock_fd: RawFd) -> io::Result<MsgHdr<'a, SendEnd>> {
        let MsgHdr { mhdr, state, .. } = self;
        let sent = call_res(|| unsafe { sendmsg(sock_fd, &mhdr, 0) })?;
        let finished = SendEnd {
            bytes_sent: sent as usize,
            fds_sent: state.fds_count,
        };
        Ok(MsgHdr {
            mhdr,
            state: finished,
            _phantom: PhantomData,
        })
    }
}
impl<'a> MsgHdr<'a, SendEnd> {
pub fn bytes_sent(&self) -> usize {
self.state.bytes_sent
}
pub fn fds_sent(&self) -> usize {
self.state.fds_sent
}
}
impl<'a> FdsIter<'a> {
    /// Starts an iterator at the first control message of `mhdr`.
    ///
    /// # Safety
    ///
    /// `mhdr` must describe a well-formed control buffer (as filled in by `recvmsg`) that
    /// stays valid for `'a`.
    unsafe fn new(mhdr: &'a msghdr) -> Self {
        let first = Self::first_cmsg(mhdr);
        let data = match first {
            Some(header) => FdsIterData::new(header),
            None => None,
        };
        FdsIter {
            mhdr,
            cmsg: first,
            data,
        }
    }

    /// An iterator that yields nothing, used once the descriptors were already taken.
    fn empty(mhdr: &'a msghdr) -> Self {
        FdsIter {
            mhdr,
            cmsg: None,
            data: None,
        }
    }

    /// First control-message header of `mhdr`, or `None` if the buffer holds none.
    unsafe fn first_cmsg(mhdr: &'a msghdr) -> Option<&'a cmsghdr> {
        CMSG_FIRSTHDR(mhdr).as_ref()
    }

    /// Moves to the next control message and rebuilds the descriptor cursor for it.
    /// Does nothing once the walk has already finished.
    fn advance_cmsg(&mut self) {
        let current = match self.cmsg {
            Some(header) => header,
            None => return,
        };
        let next = unsafe { CMSG_NXTHDR(self.mhdr, current).as_ref() };
        self.data = next.and_then(|header| unsafe { FdsIterData::new(header) });
        self.cmsg = next;
    }
}
impl<'a> Iterator for FdsIter<'a> {
type Item = Fd;
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.data.as_mut().and_then(|data| data.next()) {
Some(data) => return Some(data),
None => {
self.advance_cmsg();
if self.cmsg.is_none() {
return None;
}
}
};
}
}
}
impl<'a> Drop for FdsIter<'a> {
    fn drop(&mut self) {
        // Exhaust the iterator so every descriptor not yet handed out is closed by `Fd`'s
        // destructor.
        self.for_each(drop);
    }
}
impl FdsIterData {
    /// Builds a cursor over the descriptors carried by `cmsg` if it is a
    /// `SOL_SOCKET`/`SCM_RIGHTS` message; returns `None` for any other message type.
    ///
    /// A header whose `cmsg_len` is shorter than the header itself is treated as carrying
    /// no descriptors, rather than underflowing into an enormous, out-of-bounds range.
    ///
    /// NOTE(review): the payload is reached through a pointer derived from `&cmsghdr`. That
    /// reference only covers the header, which may not be permitted under Stacked Borrows.
    /// Carrying raw pointers from the control buffer instead would avoid this.
    ///
    /// # Safety
    ///
    /// `cmsg_len` bytes starting at `cmsg` must be readable for as long as the cursor is used.
    unsafe fn new(cmsg: &cmsghdr) -> Option<Self> {
        if cmsg.cmsg_level != SOL_SOCKET || cmsg.cmsg_type != SCM_RIGHTS {
            return None;
        }
        // Keeps the pointer arithmetic below within `isize` range.
        assert!(cmsg.cmsg_len <= (isize::MAX as usize));
        let p_start = CMSG_DATA(cmsg) as *const u8;
        let header_size = (p_start as usize) - (cmsg as *const cmsghdr as usize);
        let data_size = cmsg.cmsg_len.saturating_sub(header_size);
        // Any trailing bytes smaller than one RawFd are ignored.
        let fds_count: usize = data_size / mem::size_of::<RawFd>();
        let curr = p_start.cast::<RawFd>();
        let end = curr.add(fds_count);
        Some(FdsIterData { curr, end })
    }
}
impl Iterator for FdsIterData {
    type Item = Fd;

    /// Reads the next descriptor from the payload and takes ownership of it as an [`Fd`].
    fn next(&mut self) -> Option<Self::Item> {
        if self.curr >= self.end {
            return None;
        }
        // The payload may not be aligned for RawFd, hence the unaligned read.
        let raw = unsafe { self.curr.read_unaligned() };
        self.curr = unsafe { self.curr.add(1) };
        Some(Fd::new(raw))
    }
}
impl<'a> CMsgMut<'a> {
    /// Claims the first control-message header in `mhdr`'s control buffer. It stamps the
    /// header with `level`/`typ` and a `cmsg_len` covering the whole rest of the buffer, so
    /// that [`data`](Self::data) exposes every byte available for the payload.
    ///
    /// Returns `None` when the buffer is too small to hold even a header.
    ///
    /// # Safety
    ///
    /// `mhdr.msg_control` must point to a writable buffer of at least `msg_controllen` bytes,
    /// aligned for `cmsghdr`, that stays valid for `'a`.
    unsafe fn first_cmsg(mhdr: &'a mut msghdr, level: c_int, typ: c_int) -> Option<Self> {
        let cmsg = CMSG_FIRSTHDR(mhdr);
        if cmsg.is_null() {
            return None;
        }
        let control_max =
            (mhdr.msg_control.cast::<u8>()).offset(mhdr.msg_controllen.try_into().unwrap());
        let data = CMSG_DATA(cmsg);
        let data_size = (control_max as usize) - (data as usize);
        let max_len = CMSG_LEN(data_size.try_into().unwrap());
        (*cmsg).cmsg_level = level;
        (*cmsg).cmsg_type = typ;
        (*cmsg).cmsg_len = max_len.try_into().unwrap();
        Some(Self {
            cmsg: NonNull::new_unchecked(cmsg),
            _phantom: PhantomData,
        })
    }

    /// Lowers `cmsg_len` to `CMSG_LEN(len)`; never grows it.
    fn shrink_data_len(&mut self, len: c_uint) {
        let cmsg_len = unsafe { CMSG_LEN(len) }.try_into().unwrap();
        // SAFETY: `self.cmsg` points at a header inside the control buffer that is exclusively
        // borrowed for 'a, and `&mut self` rules out any other live view of it.
        let cmsg = unsafe { self.cmsg.as_mut() };
        if cmsg_len < cmsg.cmsg_len {
            cmsg.cmsg_len = cmsg_len;
        }
    }

    /// Mutable view of the payload bytes, from `CMSG_DATA` up to `cmsg_len`.
    fn data(&mut self) -> &mut [u8] {
        // Work from the raw pointer (not a `&cmsghdr`). A shared reference would only grant
        // read access to the header, while the payload lies past it and is written.
        let cmsg = self.cmsg.as_ptr();
        // SAFETY: `cmsg` came from CMSG_FIRSTHDR over the exclusively borrowed control buffer,
        // and `first_cmsg` set `cmsg_len` so the payload stays inside that buffer.
        unsafe {
            let data: *mut u8 = CMSG_DATA(cmsg);
            let hdr_size = (data as usize) - (cmsg as usize);
            let data_size = (*cmsg).cmsg_len - hdr_size;
            assert!(data_size <= (isize::MAX as usize));
            slice::from_raw_parts_mut(data, data_size)
        }
    }
}
impl Fd {
fn new(fd: RawFd) -> Self {
Self { fd: Some(fd) }
}
}
impl Drop for Fd {
    fn drop(&mut self) {
        // `fd` is `None` only after `into_raw_fd` has passed ownership to the caller.
        // The result of close(2) is ignored: a destructor has no way to report it.
        if let Some(raw) = self.fd.take() {
            unsafe {
                close(raw);
            }
        }
    }
}
impl IntoRawFd for Fd {
    /// Releases the descriptor without closing it; the caller becomes responsible for it.
    ///
    /// The `None` arm is not reachable in practice: `Fd::new` always stores `Some`, and this
    /// method consumes the `Fd`.
    fn into_raw_fd(mut self) -> RawFd {
        match self.fd.take() {
            Some(raw) => raw,
            None => panic!("Attempt to take the RawFd contained in an Fd a second time"),
        }
    }
}
/// Error payload reported when the control buffer given to `MsgHdr` has no room for all of
/// the file descriptors passed to `encode_fds`.
#[derive(Debug)]
struct CMsgBufferTooSmallError {}

impl CMsgBufferTooSmallError {
    /// Wraps a fresh `CMsgBufferTooSmallError` in an `io::Error` of kind `Other`, ready to be
    /// returned from the `io::Result`-based API.
    fn new() -> io::Error {
        let payload = Self {};
        io::Error::new(io::ErrorKind::Other, payload)
    }
}

impl fmt::Display for CMsgBufferTooSmallError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(
            "The control buffer passed to MsgHdr was too small for the \
             number of file descriptors",
        )
    }
}

impl error::Error for CMsgBufferTooSmallError {}
pub fn cmsg_buffer_fds_space(count: usize) -> usize {
unsafe { CMSG_SPACE((count * mem::size_of::<RawFd>()) as u32) as usize }
}
/// Runs the libc call `f` and turns the C convention "returns -1 and sets errno" into a
/// `Result`. Any other return value is passed through as `Ok`.
fn call_res<F, R>(mut f: F) -> Result<R, io::Error>
where
    F: FnMut() -> R,
    R: One + Neg<Output = R> + PartialEq,
{
    // Call first, so nothing can disturb errno before it is read below.
    let outcome = f();
    if outcome != -R::one() {
        return Ok(outcome);
    }
    Err(io::Error::last_os_error())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::{iter, os::unix::io::AsRawFd};

    // Taking the descriptors a second time must yield nothing.
    // fds 1..=4 may be live descriptors of the test process (fd 1 is stdout), so the first
    // iterator is leaked with `mem::forget` instead of dropped; dropping it would close them.
    #[test]
    fn recv_end_take_fds_twice_is_empty() {
        let mut control_buffer = [0u8; 1024];
        let bufs: [IoSlice; 0] = [];
        let fds = [1, 2, 3, 4];
        // Use the send-side encoder to fill the control buffer, then reinterpret the header
        // as if recvmsg had produced it.
        let mhdr = MsgHdr::from_io_slice(&bufs, &mut control_buffer)
            .encode_fds(fds.iter().map(|fd| *fd))
            .expect("Can't encode fds");
        let mut sut = MsgHdrRecvEnd {
            mhdr: mhdr.mhdr,
            bytes_recvieved: 0,
            fds_taken: false,
            _phantom: PhantomData,
        };
        mem::forget(sut.take_fds());
        let taken = sut.take_fds();
        assert_eq!(taken.count(), 0);
    }

    // A buffer sized for one descriptor cannot hold ten, so encoding must fail.
    #[test]
    fn start_send_encode_fds_with_small_buffer_is_error() {
        let mut control_buffer = vec![0u8; cmsg_buffer_fds_space(1)];
        let bufs: [IoSlice; 0] = [];
        let fds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sut = MsgHdr::from_io_slice(&bufs, &mut control_buffer);
        let result = sut.encode_fds(fds.iter().map(|fd| *fd));
        assert!(result.is_err());
    }

    // The control buffer holds an SCM_CREDENTIALS message followed by an SCM_RIGHTS message.
    // Only the four descriptors of the second message must be yielded.
    #[test]
    fn recv_end_take_fds_handle_non_scm_rights() {
        let mut control_buffer = vec![0u8; cmsg_buffer_fds_space(4) + cmsg_buffer_cred_size()];
        let fds = [1, 2, 3, 4];
        let bufs: [IoSlice; 0] = [];
        let mhdr = MsgHdr::from_io_slice(&bufs, &mut control_buffer);
        unsafe {
            // Hand-encode both control messages directly into the buffer.
            let mut cmsg = CMSG_FIRSTHDR(&mhdr.mhdr);
            assert_ne!(cmsg, ptr::null_mut());
            encode_fake_cred(cmsg);
            cmsg = CMSG_NXTHDR(&mhdr.mhdr, cmsg);
            assert_ne!(cmsg, ptr::null_mut());
            encode_fds(cmsg, &fds);
        }
        let mut count = 0;
        let mut sut = MsgHdrRecvEnd {
            mhdr: mhdr.mhdr,
            bytes_recvieved: 0,
            fds_taken: false,
            _phantom: PhantomData,
        };
        for fd in sut.take_fds() {
            count += 1;
            // Release ownership so the process's own fds 1..=4 are not closed.
            let _ = fd.into_raw_fd();
        }
        assert_eq!(count, fds.len());
    }

    // sendmsg on a regular file (presumably failing with ENOTSOCK) must surface as an error.
    #[test]
    fn send_ready_send_on_non_socket_is_error() {
        let mut control_buffer = [0u8; 0];
        let bytes = [1u8, 2, 3, 4, 5];
        let bufs = [IoSlice::new(&bytes)];
        let file = tempfile::tempfile().expect("Can't get temporary file.");
        let sut = MsgHdr::from_io_slice(&bufs, &mut control_buffer)
            .encode_fds(iter::empty())
            .expect("Can't encode fds");
        let result = sut.send(file.as_raw_fd());
        assert!(result.is_err());
    }

    // Likewise, recvmsg on a regular file must surface as an error.
    #[test]
    fn recv_start_recv_on_non_socket_is_error() {
        let mut control_buffer = [0u8; 0];
        let mut bytes = [1u8, 2, 3, 4, 5];
        let mut bufs = [IoSliceMut::new(&mut bytes)];
        let file = tempfile::tempfile().expect("Can't get temporary file.");
        let sut = MsgHdr::from_io_slice_mut(&mut bufs, &mut control_buffer);
        let result = sut.recv(file.as_raw_fd());
        assert!(result.is_err());
    }

    // Writes an SOL_SOCKET/SCM_RIGHTS control message carrying `fds` at `cmsg`.
    unsafe fn encode_fds(cmsg: *mut cmsghdr, fds: &[RawFd]) {
        let data_size = fds.len() * mem::size_of::<RawFd>();
        (*cmsg).cmsg_len = CMSG_LEN((data_size) as u32) as usize;
        (*cmsg).cmsg_level = SOL_SOCKET;
        (*cmsg).cmsg_type = SCM_RIGHTS;
        ptr::copy_nonoverlapping(fds.as_ptr() as *const u8, CMSG_DATA(cmsg), data_size);
    }

    // Control-buffer space needed for one SCM_CREDENTIALS message.
    fn cmsg_buffer_cred_size() -> usize {
        unsafe { CMSG_SPACE(mem::size_of::<libc::ucred>() as u32) as usize }
    }

    // Writes an SOL_SOCKET/SCM_CREDENTIALS control message with fixed dummy pid/uid/gid
    // values at `cmsg`. It is used only as a non-SCM_RIGHTS message for the iterator to skip.
    unsafe fn encode_fake_cred(cmsg: *mut cmsghdr) {
        let data_size = mem::size_of::<libc::ucred>();
        (*cmsg).cmsg_len = CMSG_LEN((data_size) as u32) as usize;
        (*cmsg).cmsg_level = SOL_SOCKET;
        (*cmsg).cmsg_type = libc::SCM_CREDENTIALS;
        let fake_cred = libc::ucred {
            pid: 5,
            uid: 2,
            gid: 2,
        };
        ptr::copy_nonoverlapping(
            &fake_cred as *const libc::ucred as *const u8,
            CMSG_DATA(cmsg),
            data_size,
        );
    }
}