#![doc = include_str!("../README.md")]
#![deny(missing_docs)]
use std::{
collections::VecDeque,
io::{IoSlice, IoSliceMut},
mem::MaybeUninit,
os::{
fd::{AsFd, BorrowedFd, OwnedFd},
unix::net::UnixStream,
},
pin::Pin,
task::{Context, Poll, ready},
};
use rustix::net::{
RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
SendAncillaryMessage, SendFlags, recvmsg, sendmsg,
};
use tokio::io::{self, AsyncRead, AsyncWrite, ReadBuf, unix::AsyncFd};
/// Descriptor count the D-Bus control-buffer size is derived from.
///
/// 253 matches the Linux kernel's `SCM_MAX_FD`, the most descriptors a single
/// `SCM_RIGHTS` control message may carry.
pub const DBUS_FD_LIMIT: usize = 253;
/// Control-buffer size, in bytes, able to hold [`DBUS_FD_LIMIT`] descriptors in one
/// `SCM_RIGHTS` message; intended as the `S` parameter of [`AnchovyStream`] for D-Bus.
pub const DBUS_SCM_RIGHTS: usize = rustix::cmsg_space!(ScmRights(DBUS_FD_LIMIT));
/// Descriptor count the Wayland control-buffer size is derived from.
///
/// 28 mirrors libwayland's per-`sendmsg` descriptor limit (`MAX_FDS_OUT`).
pub const WAYLAND_FD_LIMIT: usize = 28;
/// Control-buffer size, in bytes, able to hold [`WAYLAND_FD_LIMIT`] descriptors in one
/// `SCM_RIGHTS` message; intended as the `S` parameter of [`AnchovyStream`] for Wayland.
pub const WAYLAND_SCM_RIGHTS: usize = rustix::cmsg_space!(ScmRights(WAYLAND_FD_LIMIT));
/// An async Unix-domain stream that carries file descriptors alongside its bytes.
///
/// `S` is the size, in bytes, of the control-message buffer used for every
/// `sendmsg`/`recvmsg` call. [`DBUS_SCM_RIGHTS`] and [`WAYLAND_SCM_RIGHTS`] are
/// ready-made values.
///
/// Descriptors received by reads are appended to the read queue
/// ([`read_queue`](Self::read_queue)). Descriptors pushed onto the write queue
/// ([`write_queue_mut`](Self::write_queue_mut)) are attached to subsequent writes.
#[derive(Debug)]
pub struct AnchovyStream<const S: usize> {
    /// The socket, wrapped for tokio readiness notifications.
    stream: AsyncFd<UnixStream>,
    /// Descriptors received by reads, in arrival order.
    decode_fds: VecDeque<OwnedFd>,
    /// Descriptors waiting to be sent with upcoming writes, front first.
    encode_fds: VecDeque<OwnedFd>,
}
mod sealed {
    //! Private supertrait that keeps `IntoUnixStream` closed to outside implementations.

    use std::os::unix::net::UnixStream as StdStream;
    use tokio::net::UnixStream as TokioStream;

    /// Marker implemented only for the stream types this crate can convert.
    pub trait Sealed {}

    impl Sealed for StdStream {}
    impl Sealed for TokioStream {}
}
/// Conversion into a standard-library [`UnixStream`] that [`AnchovyStream::new`] can wrap.
///
/// This trait is sealed. It is implemented for [`std::os::unix::net::UnixStream`] and
/// [`tokio::net::UnixStream`] only, and cannot be implemented outside this crate.
pub trait IntoUnixStream: sealed::Sealed {
    /// Converts `self` into a [`std::os::unix::net::UnixStream`].
    ///
    /// # Errors
    ///
    /// Returns any error produced by the underlying conversion.
    fn into_unix_stream(self) -> io::Result<UnixStream>;
}
impl IntoUnixStream for UnixStream {
    /// Already the standard-library type, so the stream is handed back unchanged.
    fn into_unix_stream(self) -> io::Result<UnixStream> {
        io::Result::Ok(self)
    }
}
impl IntoUnixStream for tokio::net::UnixStream {
    /// Converts the tokio stream into its standard-library counterpart via tokio's
    /// own `into_std`, forwarding any error it reports.
    fn into_unix_stream(self) -> io::Result<UnixStream> {
        tokio::net::UnixStream::into_std(self)
    }
}
impl<const S: usize> AnchovyStream<S> {
    /// Wraps a Unix-domain stream so it can exchange file descriptors.
    ///
    /// `stream` may be a [`std::os::unix::net::UnixStream`] or a
    /// [`tokio::net::UnixStream`]. It is converted with
    /// [`IntoUnixStream::into_unix_stream`] and then registered with tokio through
    /// [`AsyncFd::new`]. Both descriptor queues start empty.
    ///
    /// # Errors
    ///
    /// Returns an error if the conversion or the registration with tokio fails.
    ///
    /// # Panics
    ///
    /// Per [`AsyncFd::new`]'s documentation, panics if there is no current tokio reactor.
    pub fn new<T: IntoUnixStream>(stream: T) -> io::Result<Self> {
        AsyncFd::new(stream.into_unix_stream()?).map(|stream| Self {
            stream,
            decode_fds: VecDeque::new(),
            encode_fds: VecDeque::new(),
        })
    }

    /// Descriptors received by reads and not yet taken, in arrival order.
    pub fn read_queue(&self) -> &VecDeque<OwnedFd> {
        &self.decode_fds
    }

    /// Mutable access to the read queue, e.g. to `pop_front` received descriptors.
    pub fn read_queue_mut(&mut self) -> &mut VecDeque<OwnedFd> {
        &mut self.decode_fds
    }

    /// Descriptors waiting to be attached to upcoming writes, front first.
    pub fn write_queue(&self) -> &VecDeque<OwnedFd> {
        &self.encode_fds
    }

    /// Mutable access to the write queue. Push descriptors here before writing the
    /// bytes they should accompany.
    pub fn write_queue_mut(&mut self) -> &mut VecDeque<OwnedFd> {
        &mut self.encode_fds
    }

    /// Sends `bufs` with a single `sendmsg` call, attaching queued descriptors.
    ///
    /// Descriptors are attached from the front of the write queue, as many as fit in
    /// `S` bytes of control space. They leave the queue only after the kernel accepts
    /// the write carrying them. Any that did not fit stay queued for the next write
    /// instead of being dropped.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] if descriptors are queued but `S` is
    /// too small to carry even one. Socket errors from `sendmsg` are propagated.
    fn poll_write_impl(
        &mut self,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        // Descriptors must travel with at least one payload byte. On Linux a
        // zero-length stream send transmits nothing, so queued descriptors would
        // be lost. An empty write therefore succeeds trivially and keeps the queue.
        if bufs.iter().all(|buf| buf.is_empty()) {
            return Poll::Ready(Ok(0));
        }
        let stream = &mut self.stream;
        let encode_fds = &mut self.encode_fds;
        loop {
            let mut guard = ready!(stream.poll_write_ready(cx))?;
            let (send_result, attached) = {
                let raw: Vec<BorrowedFd<'_>> = encode_fds.iter().map(|fd| fd.as_fd()).collect();
                let mut cmsg_space = [MaybeUninit::uninit(); S];
                let mut ancillary = SendAncillaryBuffer::new(&mut cmsg_space);
                // `push` returns false when the message does not fit the control
                // buffer, so shrink the attached prefix until it does.
                let mut attached = raw.len();
                while attached > 0
                    && !ancillary.push(SendAncillaryMessage::ScmRights(&raw[..attached]))
                {
                    attached -= 1;
                }
                if attached == 0 && !raw.is_empty() {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "ancillary buffer is too small to carry a file descriptor",
                    )));
                }
                let result = guard.try_io(|inner| {
                    sendmsg(
                        inner.get_ref(),
                        bufs,
                        &mut ancillary,
                        SendFlags::DONTWAIT | SendFlags::NOSIGNAL,
                    )
                    .map_err(|e| io::Error::from_raw_os_error(e.raw_os_error()))
                });
                (result, attached)
            };
            match send_result {
                Ok(Ok(written)) => {
                    // Only the descriptors that actually went out are removed
                    // (and our copies closed); the rest wait for the next write.
                    encode_fds.drain(..attached).for_each(drop);
                    return Poll::Ready(Ok(written));
                }
                Ok(Err(err)) => return Poll::Ready(Err(err)),
                // Spurious readiness: `try_io` cleared it, so wait again.
                Err(_would_block) => continue,
            }
        }
    }
}
impl<const S: usize> AsyncRead for AnchovyStream<S> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let this = self.get_mut();
let stream = &mut this.stream;
let decode_fds = &mut this.decode_fds;
loop {
let mut guard = ready!(stream.poll_read_ready(cx))?;
let mut cmsg_space = [MaybeUninit::uninit(); S];
let mut ancillary = RecvAncillaryBuffer::new(&mut cmsg_space);
let unfilled = buf.initialize_unfilled();
match guard.try_io(|inner| {
recvmsg(
inner.get_ref(),
&mut [IoSliceMut::new(unfilled)],
&mut ancillary,
RecvFlags::DONTWAIT | RecvFlags::CMSG_CLOEXEC,
)
.map_err(|e| io::Error::from_raw_os_error(e.raw_os_error()))
}) {
Ok(Ok(msg)) => {
for message in ancillary.drain() {
if let RecvAncillaryMessage::ScmRights(fds) = message {
for fd in fds {
decode_fds.push_back(fd);
}
}
}
buf.advance(msg.bytes);
return Poll::Ready(Ok(()));
}
Ok(Err(err)) => return Poll::Ready(Err(err)),
Err(_would_block) => continue,
}
}
}
}
impl<const S: usize> AsyncWrite for AnchovyStream<S> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // A contiguous buffer is just a one-element vectored write.
        let single = [IoSlice::new(buf)];
        Pin::into_inner(self).poll_write_impl(cx, &single)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<io::Result<usize>> {
        Pin::into_inner(self).poll_write_impl(cx, bufs)
    }

    fn is_write_vectored(&self) -> bool {
        // Every write goes through `sendmsg`, which takes the slices as-is.
        true
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Bytes are never buffered here; each write is handed straight to `sendmsg`.
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Close only the write half; the read half stays usable.
        let socket = Pin::into_inner(self).stream.get_ref();
        Poll::Ready(socket.shutdown(std::net::Shutdown::Write))
    }
}