use std::{
io,
marker::PhantomData,
os::{
fd::{AsFd, BorrowedFd, OwnedFd},
unix::net::UnixStream,
},
};
/// One end of a typed, bidirectional channel over a connected `UnixStream`.
///
/// `TX` is the message type sent from this end and `RX` the type received on it; the
/// other end created by `channel` has the two parameters swapped.
pub struct Channel<TX, RX> {
    // The connected socket carrying both message bytes and file descriptors.
    stream: UnixStream,
    // No `TX`/`RX` value is stored. `fn(TX) -> RX` plus `fn(RX) -> TX` make the type
    // invariant in both parameters, and (function pointers being `Send + Sync`) keep
    // `Channel`'s auto traits independent of `TX`/`RX`.
    _marker: PhantomData<(fn(TX) -> RX, fn(RX) -> TX)>,
}
/// Creates a connected pair of channel ends backed by `UnixStream::pair`.
///
/// The first end sends `TX` and receives `RX`; the second end is its mirror image,
/// sending `RX` and receiving `TX`.
///
/// # Errors
/// Propagates any error from `UnixStream::pair`.
pub fn channel<TX, RX>() -> io::Result<(Channel<TX, RX>, Channel<RX, TX>)> {
    UnixStream::pair().map(|(near, far)| {
        let near_end = Channel {
            stream: near,
            _marker: PhantomData,
        };
        let far_end = Channel {
            stream: far,
            _marker: PhantomData,
        };
        (near_end, far_end)
    })
}
impl<TX, RX> Clone for Channel<TX, RX> {
    /// Returns a second handle to the same connection by duplicating the socket
    /// descriptor with `UnixStream::try_clone`.
    ///
    /// # Panics
    /// Panics if the descriptor cannot be duplicated (e.g. the process has run out of
    /// file descriptors), because `Clone` has no way to report the error.
    fn clone(&self) -> Self {
        let stream = self
            .stream
            .try_clone()
            .expect("Channel::clone: failed to duplicate the socket file descriptor");
        Self {
            stream,
            _marker: PhantomData,
        }
    }
}
#[cfg(use_unstable_unix_socket_ancillary_data_2021)]
mod sys {
    //! `sendmsg`/`recvmsg` wrappers built on std's unstable `SocketAncillary` API,
    //! selected by the `use_unstable_unix_socket_ancillary_data_2021` cfg.
    use super::*;
    use std::os::fd::FromRawFd;
    use std::os::unix::net::{AncillaryData, SocketAncillary};

    /// Sends `bytes` and the descriptors in `fds` as a single message on `stream`.
    ///
    /// # Errors
    /// Fails if the descriptors don't fit in the 64-byte ancillary buffer, if the send
    /// itself fails, or if fewer than `bytes.len()` bytes were written.
    pub(super) fn stream_sendmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        bytes: io::IoSlice<'_>,
        fds: &[BorrowedFd<'_>; FD_LEN],
    ) -> io::Result<()> {
        let mut ancillary_buffer = [0; 64];
        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
        // The pointer cast relies on `BorrowedFd` having the same layout as a raw fd
        // (`repr(transparent)`), so the descriptors can be passed as plain `i32`s.
        if !ancillary.add_fds(unsafe { &*(fds as *const [BorrowedFd<'_>] as *const [i32]) }) {
            return Err(io::Error::other(format!(
                "failed to send {FD_LEN} file descriptors: \
                 the resulting cmsg doesn't fit in {} bytes",
                ancillary.capacity()
            )));
        }
        let written_len = stream.send_vectored_with_ancillary(&[bytes], &mut ancillary)?;
        if written_len != bytes.len() {
            return Err(io::Error::other(format!(
                "partial write (only {written_len} out of {})",
                bytes.len()
            )));
        }
        Ok(())
    }

    /// Receives exactly `bytes.len()` bytes plus exactly `FD_LEN` descriptors from `stream`.
    ///
    /// Each received descriptor is wrapped in an `OwnedFd` as soon as it is seen. Any that
    /// are not returned (extras, ones from a second `SCM_RIGHTS` cmsg, or all of them on
    /// error) are therefore closed when dropped.
    ///
    /// # Errors
    /// Fails on a partial read, a truncated ancillary buffer, a malformed or unexpected
    /// cmsg, more than one `SCM_RIGHTS` cmsg, a `-1` descriptor, or a descriptor count
    /// other than `FD_LEN`. Several problems are reported together in one message.
    pub(super) fn stream_recvmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        bytes: io::IoSliceMut<'_>,
    ) -> io::Result<[OwnedFd; FD_LEN]> {
        let mut ancillary_buffer = [0; 64];
        let mut ancillary = SocketAncillary::new(&mut ancillary_buffer);
        let expected_len = bytes.len();
        let read_len = stream.recv_vectored_with_ancillary(&mut [bytes], &mut ancillary)?;
        let partial_read = read_len != expected_len;
        let (anciliary_truncated, anciliary_capacity) =
            (ancillary.truncated(), ancillary.capacity());
        // Problems are collected rather than returned immediately, so every cmsg is still
        // walked and every received descriptor gets adopted (and thus closed on error).
        let mut errors = vec![];
        let mut accepted_fds = [(); FD_LEN].map(|()| None);
        let mut accepted_fd_count = 0;
        for cmsg in ancillary.messages() {
            match cmsg {
                Err(err) => errors.push(format!("{err:?}")),
                Ok(AncillaryData::ScmRights(raw_fds)) => {
                    // Only the first SCM_RIGHTS cmsg contributes descriptors.
                    let is_first_scm_rights = accepted_fd_count == 0;
                    for raw_fd in raw_fds {
                        if raw_fd == -1 {
                            errors.push("invalid fd (-1) received".into());
                            continue;
                        }
                        // Adopt the descriptor right away; if it is not stored below it is
                        // dropped (closed) at the end of this iteration.
                        let fd = unsafe { OwnedFd::from_raw_fd(raw_fd) };
                        if is_first_scm_rights {
                            let i = accepted_fd_count;
                            accepted_fd_count += 1;
                            // Descriptors beyond FD_LEN are counted (for the error below)
                            // but not stored, so they are closed.
                            if let Some(slot) = accepted_fds.get_mut(i) {
                                *slot = Some(fd);
                            }
                        }
                    }
                    if !is_first_scm_rights {
                        errors.push("received more than one SCM_RIGHTS cmsg".into());
                    }
                }
                Ok(AncillaryData::ScmCredentials(_)) => {
                    errors.push("received unexpected SCM_CREDS-like cmsg".into());
                }
            }
        }
        if accepted_fd_count != FD_LEN {
            errors.push(format!(
                "wrong number of received fds: expected {FD_LEN}, got {accepted_fd_count}"
            ))
        }
        if partial_read {
            return Err(io::Error::other(format!(
                "partial read: only {read_len} out of {expected_len}"
            )));
        }
        if anciliary_truncated {
            return Err(io::Error::other(format!(
                "truncated anciliary buffer: received cmsg doesn't fit in {anciliary_capacity} bytes"
            )));
        }
        if errors.is_empty() {
            // No errors implies accepted_fd_count == FD_LEN, so every slot was filled.
            Ok(accepted_fds.map(Option::unwrap))
        } else {
            Err(io::Error::other(if errors.len() == 1 {
                errors.pop().unwrap()
            } else {
                format!("errors during receiving:\n {}", errors.join("\n "))
            }))
        }
    }
}
#[cfg(not(use_unstable_unix_socket_ancillary_data_2021))]
mod sys {
    //! Fallback `sendmsg`/`recvmsg` wrappers for passing file descriptors over a
    //! `UnixStream`, used when the unstable std ancillary-data API is not enabled.
    //!
    //! NOTE(review): the hand-written `msghdr`/`cmsghdr` layouts and the `SOL_SOCKET`,
    //! `SCM_RIGHTS` and `MSG_*` values below are the Linux (64-bit) ones, but nothing
    //! restricts this module to Linux — confirm before building for other Unixes.
    #![allow(non_camel_case_types)]
    use super::*;
    use std::{
        ffi::{c_int, c_void},
        os::fd::FromRawFd,
        ptr,
    };

    /// Builds an `io::Error` of kind `Other` (the same as `io::Error::other`).
    fn io_error_other(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
        io::Error::new(io::ErrorKind::Other, error)
    }

    type socklen_t = u32;

    /// Mirror of C `struct msghdr`; `IOV` is the iovec-compatible element type.
    #[repr(C)]
    struct msghdr<IOV> {
        msg_name: *mut c_void,
        msg_namelen: socklen_t,
        msg_iov: *mut IOV,
        msg_iovlen: usize,
        msg_control: *mut c_void,
        msg_controllen: usize,
        msg_flags: c_int,
    }

    const SOL_SOCKET: c_int = 1;
    const SCM_RIGHTS: c_int = 1;
    /// `recvmsg` flag: atomically mark received descriptors close-on-exec.
    const MSG_CMSG_CLOEXEC: c_int = 0x4000_0000;
    /// `msg_flags` bit set by `recvmsg` when control data did not fit the buffer.
    const MSG_CTRUNC: c_int = 0x8;

    /// Mirror of C `struct cmsghdr`.
    #[repr(C)]
    struct cmsghdr {
        cmsg_len: usize,
        cmsg_level: c_int,
        cmsg_type: c_int,
    }
    // The payload must start right after the header (CMSG_DATA), with no extra alignment gap.
    const _: () = assert!(std::mem::size_of::<cmsghdr>() % std::mem::size_of::<usize>() == 0);

    extern "C" {
        fn sendmsg(
            sockfd: BorrowedFd<'_>,
            msg: *const msghdr<io::IoSlice<'_>>,
            flags: c_int,
        ) -> isize;
        fn recvmsg(
            sockfd: BorrowedFd<'_>,
            msg: *mut msghdr<io::IoSliceMut<'_>>,
            flags: c_int,
        ) -> isize;
    }

    /// A single control message holding `FD_LEN` descriptors. Its size is
    /// CMSG_SPACE(FD_LEN * sizeof(int)), because `repr(C)` pads it to `usize` alignment.
    #[repr(C)]
    struct CMsgBuf<FD, const FD_LEN: usize> {
        header: cmsghdr,
        fds: [FD; FD_LEN],
    }

    /// Sends `bytes` together with the descriptors in `fds` as one `SCM_RIGHTS` message.
    ///
    /// # Errors
    /// Returns the OS error if `sendmsg` fails, or an `Other` error on a partial write.
    pub(super) fn stream_sendmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        mut bytes: io::IoSlice<'_>,
        fds: &[BorrowedFd<'_>; FD_LEN],
    ) -> io::Result<()> {
        let mut cmsg_buf = CMsgBuf {
            header: cmsghdr {
                // CMSG_LEN: header plus descriptor payload, without trailing padding.
                cmsg_len: std::mem::size_of::<cmsghdr>() + std::mem::size_of_val(fds),
                cmsg_level: SOL_SOCKET,
                cmsg_type: SCM_RIGHTS,
            },
            fds: *fds,
        };
        let msg = msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut bytes,
            msg_iovlen: 1,
            msg_control: ptr::addr_of_mut!(cmsg_buf).cast(),
            // CMSG_SPACE: the whole padded buffer.
            msg_controllen: std::mem::size_of_val(&cmsg_buf),
            msg_flags: 0,
        };
        // SAFETY: `msg` points at one `IoSlice` (ABI-compatible with `struct iovec`) and at
        // `cmsg_buf`, whose size is exactly `msg_controllen`. Both outlive the call, and
        // `BorrowedFd` passes the socket as a plain `int`.
        let written_len = unsafe { sendmsg(stream.as_fd(), &msg, 0) };
        if written_len < 0 {
            return Err(io::Error::last_os_error());
        }
        if written_len as usize != bytes.len() {
            return Err(io_error_other(format!(
                "partial write (only {written_len} out of {})",
                bytes.len()
            )));
        }
        Ok(())
    }

    /// Receives exactly `bytes.len()` bytes plus exactly `FD_LEN` descriptors.
    ///
    /// Received descriptors are close-on-exec (`MSG_CMSG_CLOEXEC`). Every descriptor the
    /// kernel installed is adopted before validation, so on any error it is closed rather
    /// than leaked. Nothing that is not confirmed to be `SCM_RIGHTS` payload is ever
    /// treated as (and closed as) a descriptor.
    ///
    /// # Errors
    /// Returns the OS error if `recvmsg` fails. Otherwise returns an `Other` error on a
    /// partial read, truncated or mis-sized control data, a non-`SCM_RIGHTS` cmsg, or a
    /// `-1` descriptor.
    pub(super) fn stream_recvmsg<const FD_LEN: usize>(
        stream: &UnixStream,
        mut bytes: io::IoSliceMut<'_>,
    ) -> io::Result<[OwnedFd; FD_LEN]> {
        let expected_len = bytes.len();
        // Receive into plain `c_int`s. A value only becomes an `OwnedFd` (which closes it
        // on drop) once the header shows it really came from an SCM_RIGHTS cmsg.
        let mut cmsg_buf: CMsgBuf<c_int, FD_LEN> = CMsgBuf {
            header: cmsghdr {
                cmsg_len: 0,
                cmsg_level: 0,
                cmsg_type: 0,
            },
            fds: [-1; FD_LEN],
        };
        // CMSG_LEN / CMSG_SPACE for FD_LEN descriptors, as produced by `stream_sendmsg`.
        let expected_cmsg_len =
            std::mem::size_of::<cmsghdr>() + FD_LEN * std::mem::size_of::<c_int>();
        let expected_msg_controllen = std::mem::size_of_val(&cmsg_buf);
        let mut msg = msghdr {
            msg_name: ptr::null_mut(),
            msg_namelen: 0,
            msg_iov: &mut bytes,
            msg_iovlen: 1,
            msg_control: ptr::addr_of_mut!(cmsg_buf).cast(),
            msg_controllen: expected_msg_controllen,
            msg_flags: 0,
        };
        // SAFETY: `msg` points at one writable `IoSliceMut` (ABI-compatible with
        // `struct iovec`) and at `cmsg_buf`, whose size is exactly `msg_controllen`. Both
        // outlive the call, and any bit pattern is valid for the `c_int`/`usize` fields the
        // kernel writes.
        let read_len = unsafe { recvmsg(stream.as_fd(), &mut msg, MSG_CMSG_CLOEXEC) };
        if read_len < 0 {
            return Err(io::Error::last_os_error());
        }

        // Adopt every descriptor the kernel installed *before* any validation, so each
        // error return below closes them (by dropping `received`) instead of leaking them.
        let header = &cmsg_buf.header;
        let is_scm_rights = msg.msg_controllen >= std::mem::size_of::<cmsghdr>()
            && (header.cmsg_level, header.cmsg_type) == (SOL_SOCKET, SCM_RIGHTS);
        let received_count = if is_scm_rights {
            let payload_len = header.cmsg_len.saturating_sub(std::mem::size_of::<cmsghdr>());
            (payload_len / std::mem::size_of::<c_int>()).min(FD_LEN)
        } else {
            0
        };
        let mut saw_invalid_fd = false;
        let mut received: [Option<OwnedFd>; FD_LEN] = [(); FD_LEN].map(|()| None);
        for (slot, &raw_fd) in received.iter_mut().zip(&cmsg_buf.fds[..received_count]) {
            if raw_fd == -1 {
                saw_invalid_fd = true;
            } else {
                // SAFETY: the kernel just installed this descriptor into our table as part
                // of this SCM_RIGHTS message; nothing else owns it.
                *slot = Some(unsafe { OwnedFd::from_raw_fd(raw_fd) });
            }
        }

        if read_len as usize != expected_len {
            return Err(io_error_other(format!(
                "partial read: only {read_len} out of {expected_len}"
            )));
        }
        if msg.msg_flags & MSG_CTRUNC != 0 {
            return Err(io_error_other(format!(
                "truncated ancillary data: received cmsg doesn't fit in {expected_msg_controllen} bytes"
            )));
        }
        if msg.msg_controllen != expected_msg_controllen {
            return Err(io_error_other(format!(
                "recvmsg msg_controllen mismatch: got {}, expected {expected_msg_controllen}",
                msg.msg_controllen,
            )));
        }
        if header.cmsg_len != expected_cmsg_len {
            return Err(io_error_other(format!(
                "recvmsg cmsg_len mismatch: got {}, expected {expected_cmsg_len}",
                header.cmsg_len
            )));
        }
        if !is_scm_rights {
            return Err(io_error_other("unsupported non-SCM_RIGHTS CMSG"));
        }
        if saw_invalid_fd {
            return Err(io_error_other("recvmsg got invalid (-1) fds"));
        }
        // All checks passed: the header describes exactly FD_LEN descriptors and none was
        // -1, so every slot was filled above.
        Ok(received.map(|fd| fd.expect("every slot is filled once all checks pass")))
    }
}
impl<TX, RX> Channel<TX, RX> {
    /// Encodes `msg` and sends its bytes and file descriptors together as one message.
    ///
    /// # Panics
    /// Panics if `TX` encodes no file descriptors (`TX_FD_LEN == 0`), which is unsupported.
    ///
    /// # Errors
    /// Returns any error reported by the platform `sys::stream_sendmsg` helper.
    pub fn send<const TX_BYTE_LEN: usize, const TX_FD_LEN: usize>(&self, msg: TX) -> io::Result<()>
    where
        TX: FixedSizeEncoding<TX_BYTE_LEN, TX_FD_LEN>,
    {
        assert_ne!(
            TX_FD_LEN,
            0,
            "Channel<{}, _> unsupported (lacks file descriptors)",
            std::any::type_name::<TX>()
        );
        // `fds` borrows from `msg`, which stays alive until after the send.
        let (bytes, fds) = msg.encode();
        sys::stream_sendmsg(&self.stream, io::IoSlice::new(&bytes), &fds)
    }

    /// Receives one message of exactly `RX_BYTE_LEN` bytes and `RX_FD_LEN` descriptors
    /// and decodes it into an `RX`.
    ///
    /// # Panics
    /// Panics if `RX` encodes no file descriptors (`RX_FD_LEN == 0`), which is unsupported.
    ///
    /// # Errors
    /// Returns any error reported by the platform `sys::stream_recvmsg` helper.
    pub fn recv<const RX_BYTE_LEN: usize, const RX_FD_LEN: usize>(&self) -> io::Result<RX>
    where
        RX: FixedSizeEncoding<RX_BYTE_LEN, RX_FD_LEN>,
    {
        assert_ne!(
            RX_FD_LEN,
            0,
            "Channel<_, {}> unsupported (lacks file descriptors)",
            // The offending type is the *receive* type, shown in the `RX` slot.
            std::any::type_name::<RX>()
        );
        let mut bytes = [0; RX_BYTE_LEN];
        let fds = sys::stream_recvmsg(&self.stream, io::IoSliceMut::new(&mut bytes))?;
        Ok(RX::decode(bytes, fds))
    }

    /// Converts this end into an `InheritableChannel` whose socket descriptor stays open
    /// in child processes across `exec`.
    ///
    /// The socket is duplicated with `dup(2)`, whose result does not have `FD_CLOEXEC`
    /// set. The original descriptor is closed when `self` is dropped at the end of this call.
    ///
    /// # Errors
    /// Returns the OS error if `dup` fails.
    pub fn into_child_process_inheritable(self) -> io::Result<InheritableChannel<TX, RX>> {
        extern "C" {
            // Relies on `Option<OwnedFd>` using -1 as its `None` niche, which matches
            // `dup`'s error return value.
            fn dup(fd: BorrowedFd<'_>) -> Option<OwnedFd>;
        }
        // SAFETY: `dup` only reads the borrowed descriptor and returns a fresh one (or -1),
        // which we take ownership of.
        let inheritable_fd =
            unsafe { dup(self.stream.as_fd()) }.ok_or_else(io::Error::last_os_error)?;
        Ok(InheritableChannel(Self {
            stream: inheritable_fd.into(),
            _marker: PhantomData,
        }))
    }
}
/// A `Channel` whose socket descriptor was created by `dup` (see
/// `Channel::into_child_process_inheritable`), so it is not close-on-exec and can be
/// inherited by child processes.
///
/// The raw descriptor is exposed through `AsFd`, and a value can be rebuilt from an
/// `OwnedFd` (presumably on the child side after `exec` — confirm against callers).
pub struct InheritableChannel<TX, RX>(Channel<TX, RX>);
impl<TX, RX> AsFd for InheritableChannel<TX, RX> {
fn as_fd(&self) -> BorrowedFd<'_> {
self.0.stream.as_fd()
}
}
impl<TX, RX> From<OwnedFd> for InheritableChannel<TX, RX> {
    /// Takes ownership of an already-open socket descriptor and wraps it as an
    /// inheritable channel.
    ///
    /// The descriptor is not checked to be a connected Unix stream socket.
    fn from(fd: OwnedFd) -> Self {
        let stream: UnixStream = fd.into();
        let channel = Channel {
            stream,
            _marker: PhantomData,
        };
        Self(channel)
    }
}
impl<TX, RX> InheritableChannel<TX, RX> {
pub fn into_uninheritable(self) -> io::Result<Channel<TX, RX>> {
let Self(mut channel) = self;
channel.stream = channel.stream.as_fd().try_clone_to_owned()?.into();
Ok(channel)
}
}
/// An uninhabited type: no value of `Never` can ever be constructed.
// NOTE(review): not referenced in the visible part of this file; presumably meant as a
// `TX`/`RX` parameter for a channel direction that never carries messages — confirm.
pub enum Never {}
/// A message type with a fixed-size wire encoding: exactly `BYTE_LEN` bytes of data plus
/// exactly `FD_LEN` file descriptors.
pub trait FixedSizeEncoding<const BYTE_LEN: usize, const FD_LEN: usize> {
    /// `BYTE_LEN` re-exposed as an associated constant, so implementations can name it
    /// as `Self::BYTE_LEN`.
    const BYTE_LEN: usize = BYTE_LEN;
    /// `FD_LEN` re-exposed as an associated constant.
    const FD_LEN: usize = FD_LEN;
    /// Encodes `self` into its data bytes and its descriptors; the returned descriptors
    /// borrow from `self`.
    fn encode(&self) -> ([u8; BYTE_LEN], [BorrowedFd<'_>; FD_LEN]);
    /// Rebuilds a value from received bytes, taking ownership of the received descriptors.
    fn decode(bytes: [u8; BYTE_LEN], fds: [OwnedFd; FD_LEN]) -> Self;
}
impl<
        const BYTE_LEN: usize,
        const FD_LEN: usize,
        A: FixedSizeEncoding<BYTE_LEN, 0>,
        B: FixedSizeEncoding<0, FD_LEN>,
    > FixedSizeEncoding<BYTE_LEN, FD_LEN> for (A, B)
{
    /// Combines a bytes-only `A` with a descriptors-only `B`: the bytes come from the
    /// first element and the descriptors from the second.
    fn encode(&self) -> ([u8; BYTE_LEN], [BorrowedFd<'_>; FD_LEN]) {
        let (data, []) = self.0.encode();
        let ([], handles) = self.1.encode();
        (data, handles)
    }
    /// Splits a received message back: all bytes go to `A`, all descriptors to `B`.
    fn decode(bytes: [u8; BYTE_LEN], fds: [OwnedFd; FD_LEN]) -> Self {
        let first = A::decode(bytes, []);
        let second = B::decode([], fds);
        (first, second)
    }
}
// Implements `FixedSizeEncoding<{byte width}, 0>` for each listed integer type, encoding
// the value as its little-endian bytes (`to_le_bytes`/`from_le_bytes`) with no file
// descriptors.
macro_rules! fixed_size_le_prim_impls {
    ($($ty:ident)*) => {
        // `Self::BITS / 8` is the integer's width in bytes.
        $(impl FixedSizeEncoding<{(Self::BITS / 8) as usize}, 0> for $ty {
            fn encode(&self) -> ([u8; Self::BYTE_LEN], [BorrowedFd<'_>; 0]) {
                (self.to_le_bytes(), [])
            }
            fn decode(bytes: [u8; Self::BYTE_LEN], []: [OwnedFd; 0]) -> Self {
                Self::from_le_bytes(bytes)
            }
        })*
    }
}
// Only the unsigned types from 16 bits upward get an impl here.
fixed_size_le_prim_impls!(u16 u32 u64 u128);
impl FixedSizeEncoding<0, 1> for OwnedFd {
    /// A lone descriptor: no data bytes, and one borrowed copy of `self`'s fd.
    fn encode(&self) -> ([u8; 0], [BorrowedFd<'_>; 1]) {
        let fd = self.as_fd();
        ([], [fd])
    }
    /// Returns the single received descriptor as-is.
    fn decode(bytes: [u8; 0], fds: [OwnedFd; 1]) -> Self {
        let [] = bytes;
        let [owned] = fds;
        owned
    }
}