#![allow(unsafe_code)]
#[cfg(target_os = "linux")]
use crate::backend::net::msghdr::noaddr_msghdr;
use crate::backend::{self, c};
use crate::fd::{AsFd, BorrowedFd, OwnedFd};
use crate::io::{self, IoSlice, IoSliceMut};
use crate::net::addr::SocketAddrArg;
#[cfg(linux_kernel)]
use crate::net::UCred;
use core::iter::FusedIterator;
use core::marker::PhantomData;
use core::mem::{align_of, size_of, size_of_val, take, MaybeUninit};
#[cfg(linux_kernel)]
use core::ptr::addr_of;
use core::{ptr, slice};
use super::{RecvFlags, ReturnFlags, SendFlags, SocketAddrAny};
/// Computes the number of bytes of ancillary buffer to reserve for the listed
/// message kinds.
///
/// Each argument names a message kind (`ScmRights`, `ScmCredentials`, or
/// `TxTime`) followed by an element count in parentheses. Several kinds may
/// be listed, separated by commas; the result is the sum of the individual
/// requirements.
///
/// The first kind is sized with `__cmsg_space`, which adds
/// `align_of::<cmsghdr>()` bytes on top of the `CMSG_SPACE` figure; every
/// further kind is sized with `cmsg_aligned_space!`, which does not add those
/// extra bytes again.
#[macro_export]
macro_rules! cmsg_space {
// Single-kind arms: the payload size is `count * size_of::<element>()`.
(ScmRights($len:expr)) => {
$crate::net::__cmsg_space(
$len * ::core::mem::size_of::<$crate::fd::BorrowedFd<'static>>(),
)
};
(ScmCredentials($len:expr)) => {
$crate::net::__cmsg_space(
$len * ::core::mem::size_of::<$crate::net::UCred>(),
)
};
(TxTime($len:expr)) => {
$crate::net::__cmsg_space(
$len * ::core::mem::size_of::<::core::primitive::u64>(),
)
};
// Multi-kind arm: only the first kind pays for the alignment slack.
($firstid:ident($firstex:expr), $($restid:ident($restex:expr)),*) => {{
let sum = $crate::cmsg_space!($firstid($firstex));
$(
let sum = sum + $crate::cmsg_aligned_space!($restid($restex));
)*
sum
}};
}
// Like `cmsg_space!`, but every kind (including the first) is sized with
// `__cmsg_aligned_space`, i.e. without the extra `align_of::<cmsghdr>()`
// bytes that `__cmsg_space` adds. `cmsg_space!` uses this macro for every
// kind after the first.
#[doc(hidden)]
#[macro_export]
macro_rules! cmsg_aligned_space {
// Single-kind arms: the payload size is `count * size_of::<element>()`.
(ScmRights($len:expr)) => {
$crate::net::__cmsg_aligned_space(
$len * ::core::mem::size_of::<$crate::fd::BorrowedFd<'static>>(),
)
};
(ScmCredentials($len:expr)) => {
$crate::net::__cmsg_aligned_space(
$len * ::core::mem::size_of::<$crate::net::UCred>(),
)
};
(TxTime($len:expr)) => {
$crate::net::__cmsg_aligned_space(
$len * ::core::mem::size_of::<::core::primitive::u64>(),
)
};
// Multi-kind arm: sum the per-kind requirements.
($firstid:ident($firstex:expr), $($restid:ident($restex:expr)),*) => {{
let sum = $crate::cmsg_aligned_space!($firstid($firstex));
$(
let sum = sum + $crate::cmsg_aligned_space!($restid($restex));
)*
sum
}};
}
// Support function for `cmsg_space!`: sizes a control message carrying `len`
// payload bytes, plus `align_of::<cmsghdr>()` additional bytes.
#[doc(hidden)]
pub const fn __cmsg_space(len: usize) -> usize {
    // The extra bytes are counted as payload before the `CMSG_SPACE`
    // computation performed by `__cmsg_aligned_space`.
    __cmsg_aligned_space(align_of::<c::cmsghdr>() + len)
}
// Support function for the `cmsg_space!` family: returns `CMSG_SPACE(len)` as
// a `usize`.
//
// `len` has to be narrowed to `u32` for `CMSG_SPACE`; a value that does not
// survive the round trip aborts evaluation via `unreachable!()`.
#[doc(hidden)]
pub const fn __cmsg_aligned_space(len: usize) -> usize {
    let narrowed = len as u32;
    if narrowed as usize == len {
        unsafe { c::CMSG_SPACE(narrowed) as usize }
    } else {
        unreachable!()
    }
}
/// An ancillary (control) message that can be pushed into a
/// [`SendAncillaryBuffer`] and transmitted with [`sendmsg`] or
/// [`sendmsg_addr`].
#[non_exhaustive]
pub enum SendAncillaryMessage<'slice, 'fd> {
    /// File descriptors to send, as a `SOL_SOCKET`/`SCM_RIGHTS` message.
    #[doc(alias = "SCM_RIGHTS")]
    ScmRights(&'slice [BorrowedFd<'fd>]),
    /// A [`UCred`] to send, as a `SOL_SOCKET`/`SCM_CREDENTIALS` message.
    #[cfg(linux_kernel)]
    #[doc(alias = "SCM_CREDENTIALS")]
    ScmCredentials(UCred),
    /// A `u64` value to send as a `SOL_SOCKET`/`SO_TXTIME` message.
    #[cfg(target_os = "linux")]
    #[doc(alias = "SCM_TXTIME")]
    TxTime(u64),
}
impl SendAncillaryMessage<'_, '_> {
    /// Returns the buffer size computed by `cmsg_space!` for this message.
    ///
    /// For `ScmRights` the count is the number of descriptors in the slice;
    /// the single-value variants always count as one element.
    pub const fn size(&self) -> usize {
        match self {
            #[cfg(target_os = "linux")]
            Self::TxTime(_) => cmsg_space!(TxTime(1)),
            #[cfg(linux_kernel)]
            Self::ScmCredentials(_) => cmsg_space!(ScmCredentials(1)),
            Self::ScmRights(fds) => {
                let count = fds.len();
                cmsg_space!(ScmRights(count))
            }
        }
    }
}
/// An ancillary (control) message decoded from a [`RecvAncillaryBuffer`].
#[non_exhaustive]
pub enum RecvAncillaryMessage<'a> {
/// Received file descriptors, yielded as `OwnedFd` values by the contained
/// iterator.
#[doc(alias = "SCM_RIGHTS")]
ScmRights(AncillaryIter<'a, OwnedFd>),
/// A received [`UCred`].
#[cfg(linux_kernel)]
#[doc(alias = "SCM_CREDENTIALS")]
ScmCredentials(UCred),
}
/// A caller-provided byte buffer into which outgoing ancillary messages are
/// encoded.
pub struct SendAncillaryBuffer<'buf, 'slice, 'fd> {
// Backing storage for the encoded control messages.
buffer: &'buf mut [MaybeUninit<u8>],
// Number of bytes at the start of `buffer` currently occupied by messages.
length: usize,
// Ties the buffer to the lifetimes of the descriptor slices pushed into it.
_phantom: PhantomData<&'slice [BorrowedFd<'fd>]>,
}
impl<'buf> From<&'buf mut [MaybeUninit<u8>]> for SendAncillaryBuffer<'buf, '_, '_> {
fn from(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self::new(buffer)
}
}
impl Default for SendAncillaryBuffer<'_, '_, '_> {
    /// Creates a buffer with no backing storage and no messages.
    fn default() -> Self {
        Self {
            _phantom: PhantomData,
            length: 0,
            buffer: &mut [],
        }
    }
}
impl<'buf, 'slice, 'fd> SendAncillaryBuffer<'buf, 'slice, 'fd> {
/// Creates an empty ancillary buffer backed by `buffer`.
///
/// The slice is first passed through `align_for_cmsghdr`, so the usable
/// region may be shorter than the slice that was passed in.
#[inline]
pub fn new(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self {
buffer: align_for_cmsghdr(buffer),
length: 0,
_phantom: PhantomData,
}
}
/// Returns the pointer to hand to the OS as the control buffer.
pub(crate) fn as_control_ptr(&mut self) -> *mut u8 {
// Outside `linux_kernel` targets, report "no control data" as a null
// pointer rather than the address of an empty slice.
// NOTE(review): presumably some platforms reject that address in
// `sendmsg` — confirm.
#[cfg(not(linux_kernel))]
if self.length == 0 {
return core::ptr::null_mut();
}
self.buffer.as_mut_ptr().cast()
}
/// Returns the number of bytes of control data currently in the buffer.
pub(crate) fn control_len(&self) -> usize {
self.length
}
/// Discards all pushed messages by resetting the used length to zero.
pub fn clear(&mut self) {
self.length = 0;
}
/// Appends `msg` to the buffer.
///
/// Returns `true` if the message was written and `false` if it could not
/// be added (see `push_ancillary` for the failure conditions).
pub fn push(&mut self, msg: SendAncillaryMessage<'slice, 'fd>) -> bool {
match msg {
SendAncillaryMessage::ScmRights(fds) => {
// View the descriptor slice as raw bytes for copying into the payload.
let fds_bytes =
unsafe { slice::from_raw_parts(fds.as_ptr().cast::<u8>(), size_of_val(fds)) };
self.push_ancillary(fds_bytes, c::SOL_SOCKET as _, c::SCM_RIGHTS as _)
}
#[cfg(linux_kernel)]
SendAncillaryMessage::ScmCredentials(ucred) => {
// View the by-value `UCred` as raw bytes.
let ucred_bytes = unsafe {
slice::from_raw_parts(addr_of!(ucred).cast::<u8>(), size_of_val(&ucred))
};
self.push_ancillary(ucred_bytes, c::SOL_SOCKET as _, c::SCM_CREDENTIALS as _)
}
#[cfg(target_os = "linux")]
SendAncillaryMessage::TxTime(tx_time) => {
// View the by-value `u64` as raw bytes.
let tx_time_bytes = unsafe {
slice::from_raw_parts(addr_of!(tx_time).cast::<u8>(), size_of_val(&tx_time))
};
self.push_ancillary(tx_time_bytes, c::SOL_SOCKET as _, c::SO_TXTIME as _)
}
}
}
/// Appends one control message with the given level, type and payload.
///
/// Returns `false`, leaving the stored messages untouched, when the payload
/// length does not fit in `u32`, the new length overflows, or the buffer
/// is too small. Also returns `false` if no header can be located in the
/// extended region.
fn push_ancillary(&mut self, source: &[u8], cmsg_level: c::c_int, cmsg_type: c::c_int) -> bool {
// Unwrap an `Option`, bailing out of the function with `false` on `None`.
macro_rules! leap {
($e:expr) => {{
match ($e) {
Some(x) => x,
None => return false,
}
}};
}
// `CMSG_SPACE`/`CMSG_LEN` take the payload length as `u32`.
let source_len = leap!(u32::try_from(source.len()).ok());
// Work out how far the used region grows and check that it still fits.
let additional_space = unsafe { c::CMSG_SPACE(source_len) };
let new_length = leap!(self.length.checked_add(additional_space as usize));
let buffer = leap!(self.buffer.get_mut(..new_length));
// Zero the newly claimed bytes before walking the headers.
buffer[self.length..new_length].fill(MaybeUninit::new(0));
self.length = new_length;
// The last header found in the extended region is the one to fill in.
let last_header = leap!(messages::Messages::new(buffer).last());
last_header.cmsg_len = unsafe { c::CMSG_LEN(source_len) } as _;
last_header.cmsg_level = cmsg_level;
last_header.cmsg_type = cmsg_type;
// Copy the payload bytes into the data area of that header.
unsafe {
let payload = c::CMSG_DATA(last_header);
ptr::copy_nonoverlapping(source.as_ptr(), payload, source_len as usize);
}
true
}
}
impl<'slice, 'fd> Extend<SendAncillaryMessage<'slice, 'fd>>
    for SendAncillaryBuffer<'_, 'slice, 'fd>
{
    /// Pushes messages from `iter` in order, stopping at the first one that
    /// `push` rejects. Messages after a rejected one are not consumed.
    fn extend<T: IntoIterator<Item = SendAncillaryMessage<'slice, 'fd>>>(&mut self, iter: T) {
        for msg in iter {
            if !self.push(msg) {
                break;
            }
        }
    }
}
/// A caller-provided byte buffer that receives ancillary messages and hands
/// them out through [`RecvAncillaryBuffer::drain`].
#[derive(Default)]
pub struct RecvAncillaryBuffer<'buf> {
// Backing storage for received control data.
buffer: &'buf mut [MaybeUninit<u8>],
// Offset into `buffer` of the first byte not yet drained.
read: usize,
// Number of control-data bytes remaining from `read` onwards.
length: usize,
}
impl<'buf> From<&'buf mut [MaybeUninit<u8>]> for RecvAncillaryBuffer<'buf> {
fn from(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
Self::new(buffer)
}
}
impl<'buf> RecvAncillaryBuffer<'buf> {
    /// Creates an empty receive buffer backed by `buffer`.
    ///
    /// The slice is first passed through `align_for_cmsghdr`, so the usable
    /// region may be shorter than the slice that was passed in.
    #[inline]
    pub fn new(buffer: &'buf mut [MaybeUninit<u8>]) -> Self {
        let buffer = align_for_cmsghdr(buffer);
        Self {
            buffer,
            read: 0,
            length: 0,
        }
    }

    /// Returns the pointer to hand to the OS as the control buffer.
    pub(crate) fn as_control_ptr(&mut self) -> *mut u8 {
        // Outside `linux_kernel` targets an empty buffer is reported as null.
        #[cfg(not(linux_kernel))]
        if self.buffer.is_empty() {
            return core::ptr::null_mut();
        }
        let base = self.buffer.as_mut_ptr();
        base.cast::<u8>()
    }

    /// Returns the capacity, in bytes, available for received control data.
    pub(crate) fn control_len(&self) -> usize {
        self.buffer.len()
    }

    /// Records that `len` bytes of control data are now present at the start
    /// of the buffer and restarts draining from the beginning.
    ///
    /// # Safety
    ///
    /// NOTE(review): presumably the first `len` bytes of the buffer must hold
    /// valid control messages, since `drain` later interprets them — confirm
    /// against the backend callers.
    pub(crate) unsafe fn set_control_len(&mut self, len: usize) {
        self.read = 0;
        self.length = len;
    }

    /// Drains and drops every message still held in the buffer.
    pub(crate) fn clear(&mut self) {
        for message in self.drain() {
            drop(message);
        }
    }

    /// Returns an iterator over the not-yet-drained messages.
    ///
    /// The iterator updates this buffer's read position and remaining length
    /// as it advances.
    pub fn drain(&mut self) -> AncillaryDrain<'_> {
        let Self {
            buffer,
            read,
            length,
        } = self;
        let pending = &mut buffer[*read..][..*length];
        AncillaryDrain {
            messages: messages::Messages::new(pending),
            read_and_length: Some((read, length)),
        }
    }
}
impl Drop for RecvAncillaryBuffer<'_> {
    /// Drains and drops any messages that were never consumed.
    fn drop(&mut self) {
        self.drain().for_each(drop);
    }
}
/// Returns the tail of `buffer` that starts at the first address aligned for
/// `c::cmsghdr`.
///
/// An empty buffer is returned unchanged. A buffer that ends before the next
/// alignment boundary contains no aligned byte, so an empty slice is returned
/// for it rather than panicking on an out-of-range slice index.
#[inline]
fn align_for_cmsghdr(buffer: &mut [MaybeUninit<u8>]) -> &mut [MaybeUninit<u8>] {
    // Nothing will be written into an empty buffer, so it needs no alignment.
    if buffer.is_empty() {
        return buffer;
    }

    let align = align_of::<c::cmsghdr>();
    let addr = buffer.as_ptr() as usize;
    // Round the start address up to the next multiple of `align`.
    let adjusted = (addr + (align - 1)) & align.wrapping_neg();
    // Never skip past the end: a too-short buffer degrades to an empty one.
    let skip = (adjusted - addr).min(buffer.len());
    &mut buffer[skip..]
}
/// An iterator that decodes ancillary messages, as returned by
/// [`RecvAncillaryBuffer::drain`] or [`AncillaryDrain::parse`].
pub struct AncillaryDrain<'buf> {
// Underlying iterator over the raw control-message headers.
messages: messages::Messages<'buf>,
// When draining a `RecvAncillaryBuffer`: its `read` offset and remaining
// `length`, updated as messages are consumed. `None` for `parse`.
read_and_length: Option<(&'buf mut usize, &'buf mut usize)>,
}
impl<'buf> AncillaryDrain<'buf> {
/// Creates a drain that decodes control messages directly from `buffer`,
/// without an associated [`RecvAncillaryBuffer`].
///
/// # Safety
///
/// NOTE(review): presumably `buffer` must contain valid control-message
/// data, since its contents are interpreted as `cmsghdr` records and their
/// payloads — confirm the exact contract.
pub unsafe fn parse(buffer: &'buf mut [u8]) -> Self {
Self {
messages: messages::Messages::new(buffer),
read_and_length: None,
}
}
/// Accounts for `msg` in the owning buffer's bookkeeping (if any) and then
/// decodes it.
fn advance(
read_and_length: &mut Option<(&'buf mut usize, &'buf mut usize)>,
msg: &c::cmsghdr,
) -> Option<RecvAncillaryMessage<'buf>> {
// Move the read offset forward and shrink the remaining length by the
// message's `cmsg_len`.
if let Some((read, length)) = read_and_length {
let msg_len = msg.cmsg_len as usize;
**read += msg_len;
**length -= msg_len;
}
Self::cvt_msg(msg)
}
/// Decodes a raw header into a [`RecvAncillaryMessage`].
///
/// Returns `None` for an unrecognized level/type pair, and for an
/// `SCM_CREDENTIALS` message whose payload is smaller than `UCred`.
fn cvt_msg(msg: &c::cmsghdr) -> Option<RecvAncillaryMessage<'buf>> {
unsafe {
// The payload follows the header; its length is `cmsg_len` minus the
// length of a message with no payload.
let payload = c::CMSG_DATA(msg);
let payload_len = msg.cmsg_len as usize - c::CMSG_LEN(0) as usize;
let payload: &'buf mut [u8] = slice::from_raw_parts_mut(payload, payload_len);
let (level, msg_type) = (msg.cmsg_level, msg.cmsg_type);
match (level as _, msg_type as _) {
(c::SOL_SOCKET, c::SCM_RIGHTS) => {
// The payload is handed to an iterator that reads out `OwnedFd`s.
let fds = AncillaryIter::new(payload);
Some(RecvAncillaryMessage::ScmRights(fds))
}
#[cfg(linux_kernel)]
(c::SOL_SOCKET, c::SCM_CREDENTIALS) => {
// Only read a `UCred` when the payload is large enough to hold one.
if payload_len >= size_of::<UCred>() {
let ucred = payload.as_ptr().cast::<UCred>().read_unaligned();
Some(RecvAncillaryMessage::ScmCredentials(ucred))
} else {
None
}
}
_ => None,
}
}
}
}
impl<'buf> Iterator for AncillaryDrain<'buf> {
    type Item = RecvAncillaryMessage<'buf>;

    /// Returns the next message that decodes successfully. Headers that do
    /// not decode are skipped but still accounted for by `advance`.
    fn next(&mut self) -> Option<Self::Item> {
        for header in self.messages.by_ref() {
            if let Some(message) = Self::advance(&mut self.read_and_length, header) {
                return Some(message);
            }
        }
        None
    }

    /// Any header may fail to decode, so the lower bound is zero; the upper
    /// bound is that of the underlying header iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, self.messages.size_hint().1)
    }

    fn fold<B, F>(self, init: B, f: F) -> B
    where
        Self: Sized,
        F: FnMut(B, Self::Item) -> B,
    {
        let Self {
            messages,
            mut read_and_length,
        } = self;
        messages
            .filter_map(|header| Self::advance(&mut read_and_length, header))
            .fold(init, f)
    }

    fn count(self) -> usize {
        let Self {
            messages,
            mut read_and_length,
        } = self;
        messages
            .filter_map(|header| Self::advance(&mut read_and_length, header))
            .count()
    }

    fn last(self) -> Option<Self::Item>
    where
        Self: Sized,
    {
        let Self {
            messages,
            mut read_and_length,
        } = self;
        messages
            .filter_map(|header| Self::advance(&mut read_and_length, header))
            .last()
    }

    fn collect<B: FromIterator<Self::Item>>(self) -> B
    where
        Self: Sized,
    {
        let Self {
            messages,
            mut read_and_length,
        } = self;
        messages
            .filter_map(|header| Self::advance(&mut read_and_length, header))
            .collect()
    }
}

impl FusedIterator for AncillaryDrain<'_> {}
/// One message entry for [`sendmmsg`]: a `#[repr(transparent)]` wrapper around
/// `c::mmsghdr`.
#[cfg(target_os = "linux")]
#[repr(transparent)]
pub struct MMsgHdr<'a> {
// The raw header passed to the OS.
raw: c::mmsghdr,
// Carries the `'a` lifetime of the data the raw header was built from.
_phantom: PhantomData<&'a mut ()>,
}
#[cfg(target_os = "linux")]
impl<'a> MMsgHdr<'a> {
    /// Builds a message entry with no destination address from `iov` and
    /// `control`.
    pub fn new(iov: &'a [IoSlice<'_>], control: &'a mut SendAncillaryBuffer<'_, '_, '_>) -> Self {
        let header = noaddr_msghdr(iov, control);
        Self::wrap(header)
    }

    /// Builds a message entry addressed to `addr` from `iov` and `control`.
    ///
    /// The header stores the pointer obtained from `addr`, which is why the
    /// address is borrowed for `'a`.
    pub fn new_with_addr(
        addr: &'a SocketAddrAny,
        iov: &'a [IoSlice<'_>],
        control: &'a mut SendAncillaryBuffer<'_, '_, '_>,
    ) -> Self {
        let mut header = noaddr_msghdr(iov, control);
        header.msg_namelen = bitcast!(addr.addr_len());
        header.msg_name = addr.as_ptr() as _;
        Self::wrap(header)
    }

    /// Wraps a raw `msghdr` in an `mmsghdr` whose `msg_len` starts at zero.
    fn wrap(msg_hdr: c::msghdr) -> Self {
        let raw = c::mmsghdr {
            msg_hdr,
            msg_len: 0,
        };
        Self {
            raw,
            _phantom: PhantomData,
        }
    }

    /// Returns the `msg_len` field of the raw header as a `usize`.
    pub fn bytes_sent(&self) -> usize {
        let sent = self.raw.msg_len;
        sent as usize
    }
}
/// Sends the data in `iov` together with the ancillary messages in `control`
/// on `socket`, by forwarding to the backend `sendmsg` implementation.
#[inline]
pub fn sendmsg<Fd: AsFd>(
    socket: Fd,
    iov: &[IoSlice<'_>],
    control: &mut SendAncillaryBuffer<'_, '_, '_>,
    flags: SendFlags,
) -> io::Result<usize> {
    let fd = socket.as_fd();
    backend::net::syscalls::sendmsg(fd, iov, control, flags)
}
/// Like [`sendmsg`], but additionally passes the destination address `addr`
/// to the backend `sendmsg_addr` implementation.
#[inline]
pub fn sendmsg_addr<Fd: AsFd>(
    socket: Fd,
    addr: &impl SocketAddrArg,
    iov: &[IoSlice<'_>],
    control: &mut SendAncillaryBuffer<'_, '_, '_>,
    flags: SendFlags,
) -> io::Result<usize> {
    let fd = socket.as_fd();
    backend::net::syscalls::sendmsg_addr(fd, addr, iov, control, flags)
}
/// Submits the message entries in `msgs` on `socket` by forwarding to the
/// backend `sendmmsg` implementation.
#[inline]
#[cfg(target_os = "linux")]
pub fn sendmmsg<Fd: AsFd>(
    socket: Fd,
    msgs: &mut [MMsgHdr<'_>],
    flags: SendFlags,
) -> io::Result<usize> {
    let fd = socket.as_fd();
    backend::net::syscalls::sendmmsg(fd, msgs, flags)
}
/// Receives a message from `socket` into `iov`, with ancillary data going to
/// `control`, by forwarding to the backend `recvmsg` implementation.
#[inline]
pub fn recvmsg<Fd: AsFd>(
    socket: Fd,
    iov: &mut [IoSliceMut<'_>],
    control: &mut RecvAncillaryBuffer<'_>,
    flags: RecvFlags,
) -> io::Result<RecvMsg> {
    let fd = socket.as_fd();
    backend::net::syscalls::recvmsg(fd, iov, control, flags)
}
/// The result of a successful [`recvmsg`] call.
#[derive(Debug, Clone)]
pub struct RecvMsg {
/// Byte count reported for the received message.
pub bytes: usize,
/// Flags reported for the received message.
pub flags: ReturnFlags,
/// The address associated with the message, if one was reported.
pub address: Option<SocketAddrAny>,
}
/// An iterator that reads consecutive `T` values out of the payload bytes of
/// an ancillary message.
pub struct AncillaryIter<'data, T> {
    /// The payload bytes that have not been read out yet.
    data: &'data mut [u8],
    /// Records the element type without storing a `T`.
    _marker: PhantomData<T>,
}

impl<'data, T> AncillaryIter<'data, T> {
    /// Creates an iterator over `data`.
    ///
    /// # Panics
    ///
    /// Panics if the length of `data` is not a multiple of `size_of::<T>()`.
    ///
    /// # Safety
    ///
    /// Every `size_of::<T>()`-byte chunk of `data` is later read out as a
    /// `T` by value, so the bytes must encode valid `T` values.
    unsafe fn new(data: &'data mut [u8]) -> Self {
        assert_eq!(data.len() % size_of::<T>(), 0);
        Self {
            _marker: PhantomData,
            data,
        }
    }
}

impl<'data, T> Drop for AncillaryIter<'data, T> {
    /// Reads out and drops every element that was never yielded.
    fn drop(&mut self) {
        for leftover in self.by_ref() {
            drop(leftover);
        }
    }
}

impl<T> Iterator for AncillaryIter<'_, T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        let width = size_of::<T>();
        if self.data.len() < width {
            return None;
        }
        // Split the front element off and keep only the rest.
        let (head, tail) = take(&mut self.data).split_at_mut(width);
        self.data = tail;
        // The payload carries no alignment guarantee, hence the unaligned read.
        Some(unsafe { head.as_ptr().cast::<T>().read_unaligned() })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.data.len() / size_of::<T>();
        (remaining, Some(remaining))
    }

    fn count(self) -> usize {
        ExactSizeIterator::len(&self)
    }

    fn last(mut self) -> Option<Self::Item> {
        DoubleEndedIterator::next_back(&mut self)
    }
}

impl<T> FusedIterator for AncillaryIter<'_, T> {}

impl<T> ExactSizeIterator for AncillaryIter<'_, T> {
    fn len(&self) -> usize {
        let width = size_of::<T>();
        self.data.len() / width
    }
}

impl<T> DoubleEndedIterator for AncillaryIter<'_, T> {
    fn next_back(&mut self) -> Option<Self::Item> {
        // `checked_sub` yields `None` exactly when fewer than one element is left.
        let keep = self.data.len().checked_sub(size_of::<T>())?;
        // Split the back element off and keep only the front.
        let (front, back) = take(&mut self.data).split_at_mut(keep);
        self.data = front;
        Some(unsafe { back.as_ptr().cast::<T>().read_unaligned() })
    }
}
/// Low-level iteration over the `cmsghdr` records stored in a control buffer.
mod messages {
    use crate::backend::c;
    use crate::backend::net::msghdr;
    use core::iter::FusedIterator;
    use core::marker::PhantomData;
    use core::mem::MaybeUninit;
    use core::ptr::NonNull;

    /// An iterator over the control-message headers in a buffer.
    pub(super) struct Messages<'buf> {
        /// A `msghdr` whose control pointer and length describe the buffer.
        msghdr: c::msghdr,
        /// The header the next call to `next` will yield, if any.
        header: Option<NonNull<c::cmsghdr>>,
        /// Ties the iterator to the lifetime of the borrowed buffer.
        _buffer: PhantomData<&'buf mut [MaybeUninit<u8>]>,
    }

    /// Marker for the element types a buffer given to `Messages::new` may have.
    pub(super) trait AllowedMsgBufType {}
    impl AllowedMsgBufType for u8 {}
    impl AllowedMsgBufType for MaybeUninit<u8> {}

    impl<'buf> Messages<'buf> {
        /// Creates an iterator over the headers stored in `buf`.
        ///
        /// # Panics
        ///
        /// Panics if the buffer length does not fit the `msg_controllen`
        /// field.
        pub(super) fn new(buf: &'buf mut [impl AllowedMsgBufType]) -> Self {
            let mut msghdr = msghdr::zero_msghdr();
            msghdr.msg_control = buf.as_mut_ptr().cast();
            msghdr.msg_controllen = buf.len().try_into().expect("buffer too large for msghdr");

            // Locate the first header, if the buffer holds one.
            let first = unsafe { c::CMSG_FIRSTHDR(&msghdr) };
            Self {
                msghdr,
                header: NonNull::new(first),
                _buffer: PhantomData,
            }
        }
    }

    impl<'a> Iterator for Messages<'a> {
        type Item = &'a mut c::cmsghdr;

        #[inline]
        fn next(&mut self) -> Option<Self::Item> {
            let current = self.header?;

            // Look up the follower; a follower equal to the current header
            // ends the iteration.
            let follower = unsafe { c::CMSG_NXTHDR(&self.msghdr, current.as_ptr()) };
            self.header = NonNull::new(follower).filter(|candidate| *candidate != current);

            Some(unsafe { &mut *current.as_ptr() })
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            match self.header {
                None => (0, Some(0)),
                Some(_) => {
                    // Upper bound: the buffer holding nothing but
                    // payload-less messages.
                    let smallest = unsafe { c::CMSG_LEN(0) } as usize;
                    let upper = self.msghdr.msg_controllen as usize / smallest;
                    (1, Some(upper))
                }
            }
        }
    }

    impl FusedIterator for Messages<'_> {}
}
#[cfg(test)]
mod tests {
// Checks that the exported macros expand correctly in a scope with no
// prelude and with deliberately conflicting local names.
#[no_implicit_prelude]
mod hygiene {
#[allow(unused_macros)]
#[test]
fn macro_hygiene() {
// A local type named `u64`, to shadow the primitive inside this function.
#[allow(dead_code, non_camel_case_types)]
struct u64([u8]);
// Local macros with the same names as the exported ones; they panic if
// an expansion ever resolves to them.
macro_rules! cmsg_space {
($($tt:tt)*) => {{
let v: usize = ::core::panic!("Wrong cmsg_space! macro called");
v
}};
}
macro_rules! cmsg_aligned_space {
($($tt:tt)*) => {{
let v: usize = ::core::panic!("Wrong cmsg_aligned_space! macro called");
v
}};
}
// Invoke the crate's macros by path; each expansion must compile and run
// without touching the decoys above.
crate::cmsg_space!(ScmRights(1));
crate::cmsg_space!(TxTime(1));
#[cfg(linux_kernel)]
{
crate::cmsg_space!(ScmCredentials(1));
crate::cmsg_space!(ScmRights(1), ScmCredentials(1), TxTime(1));
crate::cmsg_aligned_space!(ScmRights(1), ScmCredentials(1), TxTime(1));
}
}
}
}