use std::io;
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::path::Path;
use std::slice;
use std::sync::atomic::{AtomicBool, Ordering};
use nix::errno::Errno;
use nix::sys::socket::{
c_uint, recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, CMSG_SPACE,
};
use nix::sys::uio::IoVec;
use nix::unistd;
use tokio::io::unix::AsyncFd;
/// Flags passed to every `recvmsg` call made by this file.
///
/// Linux and Android let the kernel mark descriptors received through
/// `SCM_RIGHTS` as close-on-exec as part of the `recvmsg` call itself. The
/// rest of this file already treats Android like Linux (see the credential
/// handling), so the constant is defined for both.
#[cfg(any(target_os = "android", target_os = "linux"))]
const MSG_FLAGS: MsgFlags = MsgFlags::MSG_CMSG_CLOEXEC;
/// macOS gets no special receive flags here; `recv_impl` instead marks each
/// received descriptor close-on-exec with `FIOCLEX` after the fact.
#[cfg(target_os = "macos")]
const MSG_FLAGS: MsgFlags = MsgFlags::empty();
/// Fixed-size frame header that precedes every payload on the wire.
///
/// The header is transmitted as its raw in-memory bytes (native endianness),
/// which is why the layout is pinned with `#[repr(C)]`.
#[repr(C)]
#[derive(Default, Debug)]
struct MsgHeader {
    /// Number of payload bytes that follow the header.
    payload_len: u32,
    /// Number of file descriptors announced for the payload.
    fd_count: u32,
}

impl MsgHeader {
    /// Views the header as the raw bytes that are written to the socket.
    pub fn as_buf(&self) -> &[u8] {
        let start = (self as *const Self).cast::<u8>();
        // SAFETY: `start` points at `self`, which is valid for reads of
        // `size_of::<Self>()` bytes for as long as the returned borrow lives.
        // The struct is `repr(C)` with two `u32` fields, so every byte is
        // initialized and there is no padding.
        unsafe { slice::from_raw_parts(start, mem::size_of::<Self>()) }
    }

    /// Views the header as mutable raw bytes so it can be filled in directly
    /// from the socket.
    pub fn as_buf_mut(&mut self) -> &mut [u8] {
        let start = (self as *mut Self).cast::<u8>();
        // SAFETY: same layout argument as in `as_buf`; additionally every bit
        // pattern is a valid pair of `u32` values, so arbitrary writes through
        // the slice cannot create an invalid `MsgHeader`.
        unsafe { slice::from_raw_parts_mut(start, mem::size_of::<Self>()) }
    }

    /// Allocates a zero-filled buffer with room for exactly `payload_len`
    /// bytes.
    pub fn make_buffer(&self) -> Vec<u8> {
        let len = self.payload_len as usize;
        vec![0u8; len]
    }
}
/// Process credentials (pid, uid, gid) received alongside a message.
///
/// The type is a plain bundle of three integer ids, so it is cheap to copy
/// and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Credentials {
    pid: libc::pid_t,
    uid: libc::uid_t,
    gid: libc::gid_t,
}

impl Credentials {
    /// Returns the process id.
    #[must_use]
    pub fn pid(&self) -> libc::pid_t {
        self.pid
    }

    /// Returns the user id.
    #[must_use]
    pub fn uid(&self) -> libc::uid_t {
        self.uid
    }

    /// Returns the group id.
    #[must_use]
    pub fn gid(&self) -> libc::gid_t {
        self.gid
    }
}
#[cfg(any(target_os = "android", target_os = "linux"))]
impl From<nix::sys::socket::UnixCredentials> for Credentials {
    /// Copies the pid, uid and gid out of the nix credentials value.
    fn from(unix_creds: nix::sys::socket::UnixCredentials) -> Self {
        let (pid, uid, gid) = (unix_creds.pid(), unix_creds.uid(), unix_creds.gid());
        Self { pid, uid, gid }
    }
}
/// Generates the descriptor plumbing for a channel half.
///
/// The target type must have an `inner: AsyncFd<RawFd>` field and a
/// `dead: AtomicBool` field. `dead` records that the raw descriptor has been
/// handed out to someone else, in which case `Drop` must not close it.
macro_rules! fd_impl {
    ($ty:ty) => {
        #[allow(dead_code)]
        impl $ty {
            /// Wraps an already-open descriptor.
            ///
            /// # Safety
            ///
            /// `fd` must be an open descriptor that nothing else will close;
            /// the returned value closes it on drop.
            pub(crate) unsafe fn from_raw_fd(fd: RawFd) -> io::Result<Self> {
                Ok(Self {
                    inner: AsyncFd::new(fd)?,
                    dead: AtomicBool::new(false),
                })
            }

            /// Takes ownership of a standard library unix stream.
            ///
            /// NOTE(review): `AsyncFd` presumably expects the descriptor to
            /// be in non-blocking mode; this function does not change the
            /// mode — confirm that all callers pass non-blocking streams.
            pub(crate) fn from_std(stream: UnixStream) -> io::Result<Self> {
                // Register first, while `stream` still owns the descriptor:
                // if registration fails the stream is dropped and the
                // descriptor is closed instead of being leaked.
                let inner = AsyncFd::new(stream.as_raw_fd())?;
                // Registration succeeded; ownership of the descriptor moves
                // to `Self`, so stop the stream from closing it.
                let _ = stream.into_raw_fd();
                Ok(Self {
                    inner,
                    dead: AtomicBool::new(false),
                })
            }

            /// Hands out the raw descriptor and marks this handle as dead so
            /// that dropping it no longer closes the descriptor.
            ///
            /// # Panics
            ///
            /// Panics if the descriptor was already extracted before.
            pub(crate) fn extract_raw_fd(&self) -> RawFd {
                if self.dead.swap(true, Ordering::SeqCst) {
                    panic!("handle was moved previously");
                } else {
                    self.inner.as_raw_fd()
                }
            }
        }

        impl FromRawFd for $ty {
            /// Infallible variant of the inherent `from_raw_fd`.
            ///
            /// # Panics
            ///
            /// Panics if the descriptor cannot be registered with tokio.
            unsafe fn from_raw_fd(fd: RawFd) -> Self {
                // Resolves to the inherent method above (inherent methods
                // take precedence over trait methods), so this does not
                // recurse.
                Self::from_raw_fd(fd)
                    .expect("conversion from RawFd requires an active tokio runtime")
            }
        }

        impl IntoRawFd for $ty {
            fn into_raw_fd(self) -> RawFd {
                // Marks the handle dead, so the `Drop` that runs right after
                // this call leaves the descriptor open for the caller.
                self.extract_raw_fd()
            }
        }

        impl AsRawFd for $ty {
            fn as_raw_fd(&self) -> RawFd {
                self.inner.as_raw_fd()
            }
        }

        impl Drop for $ty {
            fn drop(&mut self) {
                // Only close descriptors that are still owned by this handle.
                if !self.dead.load(Ordering::SeqCst) {
                    // Best effort: there is nothing useful to do on failure.
                    unistd::close(self.as_raw_fd()).ok();
                }
            }
        }
    };
}
// Instantiate the shared descriptor plumbing (constructors, raw-fd
// conversions and `Drop`) for both halves of the channel.
fd_impl!(RawReceiver);
fd_impl!(RawSender);
/// Evaluates `$call` repeatedly until it yields something other than
/// `Err(Errno::EINTR)`, and evaluates to that final result.
///
/// The expression is re-evaluated from scratch on every retry.
macro_rules! nix_eintr {
    ($call:expr) => {
        loop {
            let outcome = $call;
            if !matches!(outcome, Err(Errno::EINTR)) {
                break outcome;
            }
        }
    };
}
fn recv_impl(
fd: RawFd,
buf: &mut [u8],
fds: Option<Vec<i32>>,
fd_count: usize,
_want_creds: bool,
) -> io::Result<(usize, Option<Vec<RawFd>>, Option<Credentials>)> {
let iov = [IoVec::from_mut_slice(buf)];
let mut new_fds = None;
#[allow(unused_mut)]
let mut creds = None;
let msgspace_size = {
let fd_size = unsafe { CMSG_SPACE(mem::size_of::<RawFd>() as c_uint) * fd_count as u32 };
#[cfg(any(target_os = "android", target_os = "linux"))]
{
let cred_size: u32 = _want_creds
.then(|| unsafe {
CMSG_SPACE(mem::size_of::<nix::sys::socket::UnixCredentials>() as c_uint)
})
.unwrap_or_default();
fd_size + cred_size
}
#[cfg(not(any(target_os = "android", target_os = "linux")))]
{
fd_size
}
};
let mut cmsgspace = vec![0u8; msgspace_size as usize];
let msg = nix_eintr!(recvmsg(fd, &iov, Some(&mut cmsgspace), MSG_FLAGS))?;
for cmsg in msg.cmsgs() {
match cmsg {
ControlMessageOwned::ScmRights(fds) => {
if !fds.is_empty() {
#[cfg(target_os = "macos")]
unsafe {
for &fd in &fds {
libc::ioctl(fd, libc::FIOCLEX);
}
}
new_fds = Some(fds);
}
}
#[cfg(any(target_os = "android", target_os = "linux"))]
ControlMessageOwned::ScmCredentials(c) => {
creds = Some(c.into());
}
_ => {}
}
}
if msg.bytes == 0 {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"could not read",
));
}
let fds = match (fds, new_fds) {
(None, Some(new)) => Some(new),
(Some(mut old), Some(new)) => {
old.extend(new);
Some(old)
}
(old, None) => old,
};
Ok((msg.bytes, fds, creds))
}
/// Performs a single `sendmsg` on `fd` (Linux/Android flavour).
///
/// Sends `data`, attaching `fds` as an `SCM_RIGHTS` message when the slice is
/// non-empty and, when `creds` is set, the credentials produced by
/// `UnixCredentials::new()` as an `SCM_CREDENTIALS` message. `EINTR` is
/// retried.
///
/// Returns the number of bytes written, or a `WriteZero` error if nothing was
/// written.
#[cfg(any(target_os = "android", target_os = "linux"))]
fn send_impl(fd: RawFd, data: &[u8], fds: &[RawFd], creds: bool) -> io::Result<usize> {
    let iov = [IoVec::from_slice(data)];
    let own_creds = creds.then(nix::sys::socket::UnixCredentials::new);

    let written = if let Some(own_creds) = own_creds.as_ref() {
        if fds.is_empty() {
            let cmsgs = [ControlMessage::ScmCredentials(own_creds)];
            nix_eintr!(sendmsg(fd, &iov, &cmsgs, MsgFlags::empty(), None))?
        } else {
            // Rights first, credentials second.
            let cmsgs = [
                ControlMessage::ScmRights(fds),
                ControlMessage::ScmCredentials(own_creds),
            ];
            nix_eintr!(sendmsg(fd, &iov, &cmsgs, MsgFlags::empty(), None))?
        }
    } else if fds.is_empty() {
        nix_eintr!(sendmsg(fd, &iov, &[], MsgFlags::empty(), None))?
    } else {
        let cmsgs = [ControlMessage::ScmRights(fds)];
        nix_eintr!(sendmsg(fd, &iov, &cmsgs, MsgFlags::empty(), None))?
    };

    match written {
        0 => Err(io::Error::new(io::ErrorKind::WriteZero, "could not send")),
        count => Ok(count),
    }
}
/// Performs a single `sendmsg` on `fd` (platforms without credential
/// passing).
///
/// Sends `data`, attaching `fds` as an `SCM_RIGHTS` message when the slice is
/// non-empty. `_creds` is accepted for signature parity with the
/// Linux/Android variant and ignored. `EINTR` is retried.
///
/// Returns the number of bytes written, or a `WriteZero` error if nothing was
/// written.
#[cfg(not(any(target_os = "android", target_os = "linux")))]
fn send_impl(fd: RawFd, data: &[u8], fds: &[RawFd], _creds: bool) -> io::Result<usize> {
    let iov = [IoVec::from_slice(data)];

    let written = if fds.is_empty() {
        nix_eintr!(sendmsg(fd, &iov, &[], MsgFlags::empty(), None))?
    } else {
        let cmsgs = [ControlMessage::ScmRights(fds)];
        nix_eintr!(sendmsg(fd, &iov, &cmsgs, MsgFlags::empty(), None))?
    };

    match written {
        0 => Err(io::Error::new(io::ErrorKind::WriteZero, "could not send")),
        count => Ok(count),
    }
}
/// Creates a connected sender/receiver pair backed by a unix socket pair.
///
/// # Errors
///
/// Fails if the socket pair cannot be created or if either end cannot be
/// converted and registered.
pub fn raw_channel() -> io::Result<(RawSender, RawReceiver)> {
    let (tx_stream, rx_stream) = tokio::net::UnixStream::pair()?;
    let sender = RawSender::from_std(tx_stream.into_std()?)?;
    let receiver = RawReceiver::from_std(rx_stream.into_std()?)?;
    Ok((sender, receiver))
}
/// Receiving half of the raw channel: reads framed byte payloads together
/// with any file descriptors sent alongside them.
#[derive(Debug)]
pub struct RawReceiver {
/// The raw socket descriptor, wrapped in an `AsyncFd`.
inner: AsyncFd<RawFd>,
/// Set once the raw descriptor has been handed out via `extract_raw_fd`;
/// `Drop` then leaves the descriptor open.
dead: AtomicBool,
}
impl RawReceiver {
pub async fn connect<P: AsRef<Path>>(p: P) -> io::Result<RawReceiver> {
let stream = tokio::net::UnixStream::connect(p).await?;
RawReceiver::from_std(stream.into_std()?)
}
pub async fn recv(&self) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>)> {
let mut header = MsgHeader::default();
self.recv_impl(header.as_buf_mut(), 0, false).await?;
let mut buf = header.make_buffer();
let (_, fds, _) = self
.recv_impl(&mut buf, header.fd_count as usize, false)
.await?;
Ok((buf, fds))
}
#[cfg(any(target_os = "android", target_os = "linux"))]
pub async fn recv_with_credentials(
&self,
) -> io::Result<(Vec<u8>, Option<Vec<RawFd>>, Credentials)> {
nix::sys::socket::setsockopt(
self.inner.as_raw_fd(),
nix::sys::socket::sockopt::PassCred,
&true,
)?;
let mut header = MsgHeader::default();
let (_, _, creds) = self.recv_impl(header.as_buf_mut(), 0, true).await?;
let creds = creds.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"Remote did not provide credentials",
)
})?;
let mut buf = header.make_buffer();
let (_, fds, _) = self
.recv_impl(&mut buf, header.fd_count as usize, false)
.await?;
Ok((buf, fds, creds))
}
async fn recv_impl(
&self,
buf: &mut [u8],
fd_count: usize,
want_creds: bool,
) -> io::Result<(usize, Option<Vec<RawFd>>, Option<Credentials>)> {
let mut pos = 0;
let mut fds = None;
loop {
let mut guard = self.inner.readable().await?;
let (bytes, new_fds, creds) = match guard.try_io(|inner| {
recv_impl(
inner.as_raw_fd(),
&mut buf[pos..],
fds.take(),
fd_count,
want_creds,
)
}) {
Ok(result) => result,
Err(_would_block) => continue,
}?;
fds = new_fds;
pos += bytes;
if pos >= buf.len() {
return Ok((pos, fds, creds));
}
}
}
}
// Manually asserts that `RawReceiver` may be moved to and shared between
// threads. Its only fields are an `AsyncFd<RawFd>` and an `AtomicBool`.
// NOTE(review): both field types look like they are already `Send + Sync`,
// which would make these manual impls redundant — confirm before relying on
// or removing them.
unsafe impl Send for RawReceiver {}
unsafe impl Sync for RawReceiver {}
/// Sending half of the raw channel: writes framed byte payloads together
/// with any file descriptors to pass along.
#[derive(Debug)]
pub struct RawSender {
/// The raw socket descriptor, wrapped in an `AsyncFd`.
inner: AsyncFd<RawFd>,
/// Set once the raw descriptor has been handed out via `extract_raw_fd`;
/// `Drop` then leaves the descriptor open.
#[allow(dead_code)]
dead: AtomicBool,
}
impl RawSender {
pub async fn send(&self, data: &[u8], fds: &[RawFd]) -> io::Result<usize> {
let header = MsgHeader {
payload_len: data.len() as u32,
fd_count: fds.len() as u32,
};
self.send_impl(header.as_buf(), &[][..], false).await?;
self.send_impl(&data, fds, false).await
}
#[cfg(any(target_os = "android", target_os = "linux"))]
pub async fn send_with_credentials(&self, data: &[u8], fds: &[RawFd]) -> io::Result<usize> {
let header = MsgHeader {
payload_len: data.len() as u32,
fd_count: fds.len() as u32,
};
self.send_impl(header.as_buf(), &[][..], true).await?;
self.send_impl(&data, fds, false).await
}
async fn send_impl(&self, data: &[u8], mut fds: &[RawFd], creds: bool) -> io::Result<usize> {
let mut pos = 0;
loop {
let mut guard = self.inner.writable().await?;
let sent = match guard
.try_io(|inner| send_impl(inner.as_raw_fd(), &data[pos..], fds, creds))
{
Ok(result) => result,
Err(_would_block) => continue,
}?;
pos += sent;
fds = &[][..];
if pos >= data.len() {
return Ok(pos);
}
}
}
}
// Manually asserts that `RawSender` may be moved to and shared between
// threads. Its only fields are an `AsyncFd<RawFd>` and an `AtomicBool`.
// NOTE(review): both field types look like they are already `Send + Sync`,
// which would make these manual impls redundant — confirm before relying on
// or removing them.
unsafe impl Send for RawSender {}
unsafe impl Sync for RawSender {}