use std::marker::PhantomData;
use std::os::unix::io::{BorrowedFd, FromRawFd, OwnedFd, RawFd};
use std::{fmt, mem};
/// Error returned by `SocketAncillary::add_fds` when a control message cannot
/// be placed in the ancillary buffer.
///
/// The type carries no data; its `Display` text is "ancillary buffer too small".
#[derive(Debug, Clone)]
pub struct AncillaryError;
impl fmt::Display for AncillaryError {
    /// Writes the fixed, human-readable description of the error.
    ///
    /// Width, fill and precision flags on the formatter are ignored; the
    /// message is emitted verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ancillary buffer too small")
    }
}
// The error wraps no underlying cause, so the trait's default methods are used as-is.
impl std::error::Error for AncillaryError {}
/// One decoded control message, as yielded by the `Messages` iterator.
pub enum AncillaryData<'a> {
/// File descriptors carried by a `SOL_SOCKET` / `SCM_RIGHTS` control message.
ScmRights(ScmRights<'a>),
}
/// Iterator over the file descriptors stored in the payload of a single
/// `SCM_RIGHTS` control message.
pub struct ScmRights<'a> {
/// Raw payload bytes: native-endian `RawFd` values stored back to back.
data: &'a [u8],
/// Byte offset into `data` of the next descriptor that has not been yielded yet.
offset: usize,
}
impl<'a> ScmRights<'a> {
pub(crate) fn new(data: &'a [u8]) -> Self {
ScmRights { data, offset: 0 }
}
}
impl Iterator for ScmRights<'_> {
    type Item = OwnedFd;

    /// Decodes the next native-endian `RawFd` from the payload and hands it
    /// out as an `OwnedFd`.
    ///
    /// Returns `None` once fewer than `size_of::<RawFd>()` bytes remain; a
    /// trailing partial descriptor is ignored.
    fn next(&mut self) -> Option<Self::Item> {
        const FD_SIZE: usize = mem::size_of::<RawFd>();

        let end = self.offset + FD_SIZE;
        // `get` yields `None` when the range runs past the payload, ending iteration.
        let chunk = self.data.get(self.offset..end)?;

        let mut raw_bytes = [0u8; FD_SIZE];
        raw_bytes.copy_from_slice(chunk);
        self.offset = end;

        let descriptor = RawFd::from_ne_bytes(raw_bytes);
        // SAFETY: the payload is taken to hold descriptors that this process may
        // own; each position is yielded only once per `ScmRights` value.
        Some(unsafe { OwnedFd::from_raw_fd(descriptor) })
    }
}
/// Iterator over the control messages stored in an ancillary buffer.
pub struct Messages<'a> {
/// Header of the next control message to examine; null once iteration is finished.
current: *const libc::cmsghdr,
/// `msghdr` whose `msg_control` / `msg_controllen` describe the buffer being walked;
/// it is passed to `CMSG_NXTHDR` to find the following header.
msg: libc::msghdr,
/// Ties the raw pointers above to the lifetime of the borrowed buffer.
_marker: PhantomData<&'a [u8]>,
}
impl<'a> Messages<'a> {
    /// Starts iterating over the control messages in `buffer`.
    ///
    /// `length` is stored as the control length of the internal `msghdr`; the
    /// iterator is positioned on whatever `CMSG_FIRSTHDR` reports for it
    /// (possibly null, in which case the iterator is empty).
    fn new(buffer: &'a [u8], length: usize) -> Self {
        // SAFETY: `msghdr` is a plain C struct; it is zero-initialised here in
        // the usual C fashion before the two control fields are filled in.
        let mut header: libc::msghdr = unsafe { mem::zeroed() };
        header.msg_controllen = length as _;
        header.msg_control = buffer.as_ptr() as *mut libc::c_void;

        // SAFETY: `header` describes `buffer`, which is borrowed for `'a`.
        let first = unsafe { libc::CMSG_FIRSTHDR(&header) };

        Self {
            _marker: PhantomData,
            msg: header,
            current: first,
        }
    }
}
impl<'a> Iterator for Messages<'a> {
    type Item = AncillaryData<'a>;

    /// Returns the next `SOL_SOCKET` / `SCM_RIGHTS` control message, silently
    /// skipping every other kind of message.
    ///
    /// Returns `None` when the control buffer is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // Iterate instead of recursing so a long run of unrelated messages
        // cannot grow the stack.
        while !self.current.is_null() {
            let header = self.current;

            // The casts are needed on targets where these fields are narrower than `usize`.
            // SAFETY: `header` is non-null and was produced by `CMSG_FIRSTHDR` /
            // `CMSG_NXTHDR` for the buffer described by `self.msg`, which stays
            // borrowed for `'a`.
            #[allow(clippy::unnecessary_cast)]
            let (level, kind, cmsg_len, control_len) = unsafe {
                (
                    (*header).cmsg_level,
                    (*header).cmsg_type,
                    (*header).cmsg_len as usize,
                    self.msg.msg_controllen as usize,
                )
            };

            // SAFETY: same invariant as above.
            let following: *const libc::cmsghdr =
                unsafe { libc::CMSG_NXTHDR(&self.msg, header) };
            // Some platforms hand back the same header again when its length is
            // zero; treat that as the end of the buffer instead of spinning forever.
            self.current = if following == header {
                std::ptr::null()
            } else {
                following
            };

            if level != libc::SOL_SOCKET || kind != libc::SCM_RIGHTS {
                continue;
            }

            // SAFETY: `CMSG_DATA` points just past the header inside the control
            // buffer, and the length is clamped so the slice never extends
            // beyond the end of that buffer.
            let data = unsafe {
                let data_ptr = libc::CMSG_DATA(header) as *const u8;
                let header_len = data_ptr as usize - header as usize;
                let buffer_end = self.msg.msg_control as usize + control_len;
                let available = buffer_end.saturating_sub(data_ptr as usize);
                // A malformed `cmsg_len` smaller than the header yields an empty
                // payload rather than an arithmetic underflow.
                let data_len = cmsg_len.saturating_sub(header_len).min(available);
                std::slice::from_raw_parts(data_ptr, data_len)
            };
            return Some(AncillaryData::ScmRights(ScmRights::new(data)));
        }
        None
    }
}
/// Socket ancillary data (control messages) stored in a caller-provided buffer.
pub struct SocketAncillary<'a> {
/// Backing storage for the control messages.
pub(crate) buffer: &'a mut [u8],
/// Number of bytes at the start of `buffer` that currently hold control messages.
pub(crate) length: usize,
/// Value reported by `is_truncated`. Nothing in this part of the file sets it to `true`;
/// presumably receive code elsewhere in the crate does — TODO confirm.
pub(crate) truncated: bool,
}
impl<'a> SocketAncillary<'a> {
pub fn new(buffer: &'a mut [u8]) -> Self {
SocketAncillary {
buffer,
length: 0,
truncated: false,
}
}
pub fn buffer_size_for_rights(num_fds: usize) -> usize {
unsafe { libc::CMSG_SPACE((num_fds * mem::size_of::<RawFd>()) as libc::c_uint) as usize }
}
pub fn add_fds(&mut self, fds: &[BorrowedFd<'_>]) -> Result<(), AncillaryError> {
let raw_fds: Vec<RawFd> = fds.iter().map(|fd| {
use std::os::unix::io::AsRawFd;
fd.as_raw_fd()
}).collect();
let fd_bytes_len = raw_fds.len() * mem::size_of::<RawFd>();
let space = unsafe { libc::CMSG_SPACE(fd_bytes_len as libc::c_uint) as usize };
if self.length + space > self.buffer.len() {
return Err(AncillaryError);
}
unsafe {
let mut msg: libc::msghdr = mem::zeroed();
msg.msg_control = self.buffer.as_mut_ptr() as *mut libc::c_void;
msg.msg_controllen = (self.length + space) as _;
let cmsg = if self.length == 0 {
libc::CMSG_FIRSTHDR(&msg)
} else {
let mut walk_msg: libc::msghdr = mem::zeroed();
walk_msg.msg_control = self.buffer.as_mut_ptr() as *mut libc::c_void;
walk_msg.msg_controllen = self.length as _;
let mut cur = libc::CMSG_FIRSTHDR(&walk_msg);
while !cur.is_null() {
let next = libc::CMSG_NXTHDR(&walk_msg, cur);
if next.is_null() {
break;
}
cur = next;
}
msg.msg_controllen = (self.length + space) as _;
if cur.is_null() {
libc::CMSG_FIRSTHDR(&msg)
} else {
libc::CMSG_NXTHDR(&msg, cur)
}
};
if cmsg.is_null() {
return Err(AncillaryError);
}
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
(*cmsg).cmsg_len = libc::CMSG_LEN(fd_bytes_len as libc::c_uint) as _;
let data_ptr = libc::CMSG_DATA(cmsg);
std::ptr::copy_nonoverlapping(
raw_fds.as_ptr() as *const u8,
data_ptr,
fd_bytes_len,
);
}
self.length += space;
Ok(())
}
pub fn messages(&self) -> Messages<'_> {
Messages::new(&self.buffer[..self.length], self.length)
}
#[must_use]
pub fn is_truncated(&self) -> bool {
self.truncated
}
pub fn clear(&mut self) {
self.length = 0;
self.truncated = false;
}
}
impl Drop for SocketAncillary<'_> {
    /// On macOS, when the `truncated` flag is set, walks every stored
    /// `SCM_RIGHTS` message and drops the resulting `OwnedFd`s, which closes
    /// the descriptors. On other targets this does nothing.
    fn drop(&mut self) {
        #[cfg(target_os = "macos")]
        {
            if !self.truncated {
                return;
            }
            for AncillaryData::ScmRights(rights) in self.messages() {
                // Dropping each `OwnedFd` closes the descriptor it owns.
                rights.for_each(drop);
            }
        }
    }
}
/// Sets the `FD_CLOEXEC` flag on `fd`, preserving its other descriptor flags.
///
/// Only compiled on targets other than the ones listed in the `cfg` below.
///
/// # Errors
///
/// Returns the OS error reported by `fcntl` if either reading or updating the
/// descriptor flags fails.
#[cfg(not(any(
    target_os = "linux",
    target_os = "android",
    target_os = "freebsd",
    target_os = "dragonfly",
    target_os = "netbsd",
    target_os = "openbsd",
)))]
pub(crate) fn set_cloexec(fd: RawFd) -> std::io::Result<()> {
    // `std::io` is spelled out in full because the file's imports do not bring
    // an `io` name into scope.

    // SAFETY: `fcntl(F_GETFD)` only inspects the descriptor; an invalid `fd`
    // is reported through the return value.
    let flags = unsafe { libc::fcntl(fd, libc::F_GETFD) };
    if flags < 0 {
        return Err(std::io::Error::last_os_error());
    }

    // Nothing to do when the flag is already set; skip the second syscall.
    if flags & libc::FD_CLOEXEC != 0 {
        return Ok(());
    }

    // SAFETY: `fcntl(F_SETFD)` only updates the descriptor's flags; failures
    // are reported through the return value.
    let ret = unsafe { libc::fcntl(fd, libc::F_SETFD, flags | libc::FD_CLOEXEC) };
    if ret < 0 {
        return Err(std::io::Error::last_os_error());
    }
    Ok(())
}