use std::marker::PhantomData;
use std::os::unix::io::{AsRawFd, BorrowedFd, FromRawFd, OwnedFd, RawFd};
use std::{fmt, mem};
/// Error returned when ancillary data cannot be placed into the caller's buffer.
#[derive(Debug, Clone)]
pub struct AncillaryError;

impl fmt::Display for AncillaryError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("ancillary buffer too small")
    }
}

impl std::error::Error for AncillaryError {}
/// A control message decoded by the [`Messages`] iterator.
///
/// Only `SOL_SOCKET` / `SCM_RIGHTS` messages are represented; the iterator skips
/// every other level/type combination.
pub enum AncillaryData<'a> {
    /// File descriptors passed with `SOL_SOCKET` / `SCM_RIGHTS`.
    ScmRights(ScmRights<'a>),
}
/// The file descriptors carried by one `SCM_RIGHTS` control message.
///
/// Iterating decodes the payload as consecutive native-endian `RawFd` values and
/// yields each non-negative one as an [`OwnedFd`], which closes the descriptor when
/// dropped. NOTE(review): descriptors that are never yielded are not closed by this
/// iterator, and yielding takes ownership unconditionally. Only iterate payloads whose
/// descriptors really belong to the caller (e.g. freshly received ones), and only once.
pub struct ScmRights<'a> {
    /// Raw control-message payload.
    data: &'a [u8],
    /// Byte position of the next descriptor to decode in `data`.
    offset: usize,
}

impl<'a> ScmRights<'a> {
    /// Wraps the payload bytes of an `SCM_RIGHTS` control message.
    pub(crate) fn new(data: &'a [u8]) -> Self {
        Self { offset: 0, data }
    }
}

impl Iterator for ScmRights<'_> {
    type Item = OwnedFd;

    /// Decodes descriptors until a non-negative one is found and takes ownership of it.
    ///
    /// Negative values are skipped. Trailing bytes too short to hold a whole `RawFd`
    /// end iteration.
    fn next(&mut self) -> Option<Self::Item> {
        const WIDTH: usize = mem::size_of::<RawFd>();
        let data = self.data;
        while let Some(chunk) = data.get(self.offset..self.offset + WIDTH) {
            self.offset += WIDTH;
            let mut raw_bytes = [0u8; WIDTH];
            raw_bytes.copy_from_slice(chunk);
            let fd = RawFd::from_ne_bytes(raw_bytes);
            if fd >= 0 {
                // SAFETY(review): `fd` is non-negative, but assuming ownership is only
                // sound if nothing else owns it. Presumably the payload came from
                // recvmsg; confirm at call sites.
                return Some(unsafe { OwnedFd::from_raw_fd(fd) });
            }
        }
        None
    }
}
/// Iterator over the control messages stored in an ancillary-data byte buffer.
///
/// Only `SOL_SOCKET` / `SCM_RIGHTS` messages are yielded; other kinds are skipped.
pub struct Messages<'a> {
    /// Next header to inspect, or null once iteration has finished.
    current: *const libc::cmsghdr,
    /// Synthetic `msghdr` used only so the `CMSG_*` macros can see the control
    /// buffer (`msg_control` / `msg_controllen`); all other fields are zero.
    msg: libc::msghdr,
    /// Ties the raw pointers above to the lifetime of the borrowed buffer.
    _marker: PhantomData<&'a [u8]>,
}

impl<'a> Messages<'a> {
    /// Creates an iterator over the first `length` bytes of `buffer`.
    ///
    /// `length` is clamped to `buffer.len()`. This keeps the `CMSG_*` macros from
    /// being pointed past the end of the borrowed slice even if a caller passes a
    /// stale or oversized length.
    fn new(buffer: &'a [u8], length: usize) -> Self {
        let length = length.min(buffer.len());
        // SAFETY: `msghdr` is a plain C struct; all-zero (null pointers, zero lengths)
        // is a valid value.
        let mut msg: libc::msghdr = unsafe { mem::zeroed() };
        msg.msg_control = buffer.as_ptr() as *mut libc::c_void;
        msg.msg_controllen = length as _;
        // SAFETY: `msg` describes `length` readable bytes of `buffer`. CMSG_FIRSTHDR
        // yields null when the buffer is too short to hold a header.
        let current = unsafe { libc::CMSG_FIRSTHDR(&msg) };
        Messages {
            current,
            msg,
            _marker: PhantomData,
        }
    }
}
impl<'a> Iterator for Messages<'a> {
    type Item = AncillaryData<'a>;

    /// Returns the next `SOL_SOCKET` / `SCM_RIGHTS` message, skipping other kinds.
    ///
    /// Each header is validated against the control buffer before its payload is
    /// exposed. If `cmsg_len` is shorter than the header or runs past the end of the
    /// buffer, that message's payload is treated as empty and iteration stops after it.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.current.is_null() {
                return None;
            }
            let buf_start = self.msg.msg_control as usize;
            // `msg_controllen` is `size_t` on some targets and `socklen_t` on others,
            // so this cast is only redundant on some of them.
            #[allow(clippy::unnecessary_cast)]
            let buf_end = buf_start.saturating_add(self.msg.msg_controllen as usize);
            let cur_addr = self.current as usize;
            // Same portability reason for the `cmsg_len` cast below.
            #[allow(clippy::unnecessary_cast)]
            let (level, ty, data_ptr, data_len, well_formed) = unsafe {
                // SAFETY: `current` is non-null and came from CMSG_FIRSTHDR/CMSG_NXTHDR
                // on `self.msg`, which only hand out headers lying inside the control
                // buffer. That buffer is an arbitrary `&[u8]` with no alignment
                // guarantee, so the header is copied out with an unaligned read instead
                // of being referenced as `&cmsghdr`.
                let cmsg = std::ptr::read_unaligned(self.current);
                let data_ptr = libc::CMSG_DATA(self.current as *mut _);
                let header_len = (data_ptr as usize).saturating_sub(cur_addr);
                let total = cmsg.cmsg_len as usize;
                let remaining = buf_end.saturating_sub(cur_addr);
                let well_formed = total >= header_len && total <= remaining;
                let data_len = if well_formed { total - header_len } else { 0 };
                (
                    cmsg.cmsg_level,
                    cmsg.cmsg_type,
                    data_ptr,
                    data_len,
                    well_formed,
                )
            };
            // A malformed header's length cannot be trusted, so never step past it.
            self.current = if well_formed {
                // SAFETY: `current` is a header of `self.msg` whose length was validated.
                unsafe { libc::CMSG_NXTHDR(&self.msg, self.current) }
            } else {
                std::ptr::null()
            };
            if level == libc::SOL_SOCKET && ty == libc::SCM_RIGHTS {
                // SAFETY: when well-formed, `data_ptr + data_len == cur_addr + cmsg_len`
                // and that is `<= buf_end`, inside the buffer borrowed for `'a`.
                // Otherwise `data_len` is 0 and `data_ptr` is non-null.
                let data: &'a [u8] = unsafe { std::slice::from_raw_parts(data_ptr, data_len) };
                return Some(AncillaryData::ScmRights(ScmRights::new(data)));
            }
        }
    }
}
/// A caller-provided byte buffer holding socket ancillary (control) data.
///
/// Messages are appended with [`SocketAncillary::add_fds`] and read back with
/// [`SocketAncillary::messages`].
pub struct SocketAncillary<'a> {
    /// Backing storage for the control messages.
    pub(crate) buffer: &'a mut [u8],
    /// Number of leading bytes of `buffer` that currently hold control messages.
    pub(crate) length: usize,
    /// Value reported by [`SocketAncillary::is_truncated`]. Only reset in this file;
    /// presumably set by the receive path elsewhere in the crate (confirm).
    pub(crate) truncated: bool,
}

impl<'a> SocketAncillary<'a> {
    /// Wraps `buffer` as an empty ancillary-data buffer.
    pub fn new(buffer: &'a mut [u8]) -> Self {
        SocketAncillary {
            buffer,
            length: 0,
            truncated: false,
        }
    }

    /// Returns the number of buffer bytes (`CMSG_SPACE`) occupied by one `SCM_RIGHTS`
    /// message carrying `num_fds` descriptors.
    ///
    /// If that size is not representable (the payload overflows `c_uint`), this
    /// returns `usize::MAX`, a size no buffer can provide, rather than a silently
    /// truncated value.
    pub fn buffer_size_for_rights(num_fds: usize) -> usize {
        Self::rights_space(num_fds).map_or(usize::MAX, |(_, space)| space)
    }

    /// Computes `(payload_len, CMSG_SPACE(payload_len))` for `num_fds` descriptors,
    /// or `None` if either value would overflow.
    fn rights_space(num_fds: usize) -> Option<(libc::c_uint, usize)> {
        let payload = num_fds.checked_mul(mem::size_of::<RawFd>())?;
        if payload > libc::c_uint::MAX as usize {
            return None;
        }
        let payload_c = payload as libc::c_uint;
        // SAFETY: CMSG_SPACE is a pure size computation on its argument.
        let space = unsafe { libc::CMSG_SPACE(payload_c) } as usize;
        // CMSG_SPACE computes in `c_uint` and wraps for payloads close to
        // `c_uint::MAX`. A wrapped result is smaller than the payload itself.
        if space < payload {
            return None;
        }
        Some((payload_c, space))
    }

    /// Appends one `SOL_SOCKET` / `SCM_RIGHTS` message carrying `fds`.
    ///
    /// Only the raw descriptor numbers are copied; ownership stays with the caller.
    ///
    /// # Errors
    ///
    /// Returns [`AncillaryError`] if the message does not fit in the remaining buffer
    /// space, its size overflows, or no header slot can be located.
    pub fn add_fds(&mut self, fds: &[BorrowedFd<'_>]) -> Result<(), AncillaryError> {
        // Checked size computation: an unchecked `as c_uint` truncation would let a huge
        // `fds` pass the capacity check below and then be written past the buffer.
        let (payload_len, space) = Self::rights_space(fds.len()).ok_or(AncillaryError)?;
        let new_len = self.length.checked_add(space).ok_or(AncillaryError)?;
        if new_len > self.buffer.len() {
            return Err(AncillaryError);
        }
        // Clear the slot for the new message. Linux's CMSG_NXTHDR reads the `cmsg_len`
        // of the header it is about to return. Stale bytes left by an earlier `clear()`
        // or receive could otherwise make it reject a perfectly valid slot.
        self.buffer[self.length..new_len].fill(0);

        // Take one base pointer so every raw pointer below shares a single provenance.
        let base = self.buffer.as_mut_ptr();
        // SAFETY: all-zero `msghdr` / `cmsghdr` values are valid. `base` points to
        // `self.buffer`, which holds at least `new_len` bytes. Every write below is
        // bounds-checked against `new_len` first, and is unaligned because the byte
        // buffer carries no alignment guarantee.
        unsafe {
            let mut msg: libc::msghdr = mem::zeroed();
            msg.msg_control = base as *mut libc::c_void;
            msg.msg_controllen = new_len as _;

            // Locate the last message already in the buffer (null when there is none).
            let mut walk_msg: libc::msghdr = mem::zeroed();
            walk_msg.msg_control = base as *mut libc::c_void;
            walk_msg.msg_controllen = self.length as _;
            let mut last = if self.length == 0 {
                std::ptr::null_mut()
            } else {
                libc::CMSG_FIRSTHDR(&walk_msg)
            };
            if !last.is_null() {
                loop {
                    let next = libc::CMSG_NXTHDR(&walk_msg, last);
                    // Some non-Linux CMSG_NXTHDR implementations return their argument
                    // for a zero-length header; stop instead of spinning forever.
                    if next.is_null() || next == last {
                        break;
                    }
                    last = next;
                }
            }

            let cmsg = if last.is_null() {
                libc::CMSG_FIRSTHDR(&msg)
            } else {
                libc::CMSG_NXTHDR(&msg, last)
            };
            if cmsg.is_null() {
                return Err(AncillaryError);
            }

            // Ensure the whole message lands inside `buffer[..new_len]`, independent of
            // how the platform's CMSG_* macros bound-check.
            let msg_len = libc::CMSG_LEN(payload_len) as usize;
            let end = (cmsg as usize)
                .checked_sub(base as usize)
                .and_then(|offset| offset.checked_add(msg_len));
            if !matches!(end, Some(end) if end <= new_len) {
                return Err(AncillaryError);
            }

            let mut header: libc::cmsghdr = mem::zeroed();
            header.cmsg_level = libc::SOL_SOCKET;
            header.cmsg_type = libc::SCM_RIGHTS;
            header.cmsg_len = msg_len as _;
            std::ptr::write_unaligned(cmsg, header);

            let data_ptr = libc::CMSG_DATA(cmsg) as *mut RawFd;
            for (i, fd) in fds.iter().enumerate() {
                std::ptr::write_unaligned(data_ptr.add(i), fd.as_raw_fd());
            }
        }
        self.length = new_len;
        Ok(())
    }

    /// Returns an iterator over the `SCM_RIGHTS` messages currently in the buffer.
    ///
    /// NOTE(review): each `ScmRights` iterator yields `OwnedFd`s that close their
    /// descriptor when dropped. Draining the rights of messages written by
    /// [`SocketAncillary::add_fds`] (which only borrows its descriptors), or draining
    /// the same received messages more than once, closes descriptors still owned
    /// elsewhere. Drain received descriptors at most once.
    pub fn messages(&self) -> Messages<'_> {
        Messages::new(&self.buffer[..self.length], self.length)
    }

    /// Returns the `truncated` flag for this buffer.
    #[must_use]
    pub fn is_truncated(&self) -> bool {
        self.truncated
    }

    /// Forgets all stored messages and resets the truncation flag.
    ///
    /// The buffer bytes themselves are left as-is.
    pub fn clear(&mut self) {
        self.length = 0;
        self.truncated = false;
    }
}
/// Fuzzing entry point: parses all of `buf` as a control-message buffer.
///
/// # Safety
///
/// Draining any yielded `ScmRights` takes ownership of, and later closes, whatever
/// descriptor numbers appear in `buf`. Callers must not drive that iterator unless
/// those descriptors really belong to them. NOTE(review): presumably fuzz targets
/// only walk the headers; confirm.
#[doc(hidden)]
pub unsafe fn __fuzz_parse(buf: &[u8]) -> Messages<'_> {
    let whole = buf.len();
    Messages::new(buf, whole)
}