//! Multiplex data from several writers into a single reader, tagging each chunk with the
//! writer it came from.
//!
//! A [`Mux`] owns one receiving UNIX datagram socket. [`Mux::make_sender`] creates
//! [`MuxSender`]s connected to it, each paired with a [`Tag`]. A sender can be written to
//! directly, or converted into a [`std::process::Stdio`] and handed to a child process as
//! its stdout or stderr. [`Mux::read`] then returns the received chunks one at a time as
//! [`TaggedData`]. Because all senders deliver into the same socket, chunks come back in
//! the order the mux received them. For example, a child's interleaved stdout and stderr
//! lines keep their relative order.
//!
//! With the `async` feature enabled, `AsyncMux` provides the same functionality for use
//! with `async-io`.
//!
//! # Portability
//!
//! io-mux only builds on UNIX. Linux is the primary platform. Other UNIX platforms require
//! enabling the `experimental-unix-support` feature. On those platforms:
//!
//! - `Mux::new_abstract` is unavailable, because abstract socket addresses are
//!   Linux-specific.
//! - The size of each incoming datagram is discovered by repeatedly peeking into a growing
//!   buffer rather than by querying it directly.
#![forbid(missing_docs)]
#[cfg(not(unix))]
compile_error!("io-mux only runs on UNIX");
#[cfg(all(
    unix,
    not(target_os = "linux"),
    not(feature = "experimental-unix-support")
))]
compile_error!(
"io-mux support for non-Linux platforms is experimental.
Please read the portability note in the io-mux documentation for more information
and potential caveats, before enabling io-mux's experimental UNIX support."
);
use std::io;
use std::net::Shutdown;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
#[cfg(target_os = "linux")]
use std::os::linux::net::SocketAddrExt;
use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
use std::os::unix::net::{SocketAddr, UnixDatagram};
use std::path::Path;
use std::process::Stdio;
#[cfg(feature = "async")]
use async_io::Async;
use rustix::net::RecvFlags;
/// Initial size, in bytes, of a mux's receive buffer (8 KiB). The buffer grows on demand
/// whenever a larger datagram arrives.
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
/// A receiver that collects data written by any number of [`MuxSender`]s.
///
/// Create senders with [`Mux::make_sender`] and receive their data with [`Mux::read`].
/// Each received chunk carries the [`Tag`] of the sender that wrote it. All senders
/// deliver into a single socket, so chunks are returned in the order they arrived.
pub struct Mux {
    /// Datagram socket that every sender is connected to.
    receive: UnixDatagram,
    /// Address `receive` is bound to; new senders connect to it.
    receive_addr: SocketAddr,
    /// Temporary directory holding the socket files, or `None` when abstract socket
    /// addresses are used instead.
    tempdir: Option<tempfile::TempDir>,
    /// Reusable receive buffer, grown as needed to fit the largest datagram seen so far.
    buf: Vec<u8>,
}
impl AsFd for Mux {
    fn as_fd(&self) -> BorrowedFd<'_> {
        // The mux is represented by its receiving socket. Exposing it lets the mux be
        // registered with an event loop (as `AsyncMux` does).
        AsFd::as_fd(&self.receive)
    }
}
impl AsRawFd for Mux {
fn as_raw_fd(&self) -> RawFd {
self.receive.as_raw_fd()
}
}
/// The sending half of a [`Mux`], created by [`Mux::make_sender`].
///
/// A `MuxSender` is a datagram socket connected to the mux. It can be written to directly
/// through [`std::io::Write`], or converted into a [`Stdio`] or an [`OwnedFd`], for
/// example to serve as a child process's stdout or stderr.
///
/// Each `write` call sends its whole buffer as a single datagram, which the mux returns as
/// one chunk tagged with this sender's [`Tag`]. Formatted writes such as `write!` may issue
/// several `write` calls and therefore produce several chunks.
pub struct MuxSender(UnixDatagram);
impl AsRawFd for MuxSender {
fn as_raw_fd(&self) -> RawFd {
self.0.as_raw_fd()
}
}
impl IntoRawFd for MuxSender {
fn into_raw_fd(self) -> RawFd {
self.0.into_raw_fd()
}
}
impl AsFd for MuxSender {
    fn as_fd(&self) -> BorrowedFd<'_> {
        // Borrow the connected socket's descriptor for as long as the sender lives.
        AsFd::as_fd(&self.0)
    }
}
impl From<MuxSender> for OwnedFd {
fn from(sender: MuxSender) -> OwnedFd {
sender.0.into()
}
}
impl From<MuxSender> for Stdio {
    fn from(sender: MuxSender) -> Stdio {
        // Go through `OwnedFd` so a sender can be passed straight to
        // `Command::stdout` / `Command::stderr`.
        let fd: OwnedFd = sender.into();
        fd.into()
    }
}
impl io::Write for MuxSender {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.send(buf)
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
/// Identifies which [`MuxSender`] a chunk of [`TaggedData`] came from.
///
/// [`Mux::make_sender`] returns a `Tag` alongside each sender. Compare it with
/// [`TaggedData::tag`] to attribute received data to its sender. Internally a tag is the
/// sender socket's address, which the mux observes as the source of each datagram.
#[derive(Clone, Debug)]
pub struct Tag(SocketAddr);
impl PartialEq<Tag> for Tag {
    fn eq(&self, rhs: &Tag) -> bool {
        // `SocketAddr` has no `PartialEq`, so compare by address kind: two tags are equal
        // when both are the same kind of address with the same name, or both are unnamed.
        let (left, right) = (&self.0, &rhs.0);
        #[cfg(target_os = "linux")]
        {
            if let (Some(a), Some(b)) = (left.as_abstract_name(), right.as_abstract_name()) {
                return a == b;
            }
        }
        match (left.as_pathname(), right.as_pathname()) {
            (Some(a), Some(b)) => a == b,
            _ => left.is_unnamed() && right.is_unnamed(),
        }
    }
}
impl Eq for Tag {}
/// One chunk of data received by a [`Mux`], together with the [`Tag`] of its sender.
///
/// `data` borrows the mux's internal receive buffer. It is only valid until the next read
/// from the same mux.
#[derive(Debug, Eq, PartialEq)]
pub struct TaggedData<'a> {
    /// The contents of one datagram, i.e. the data passed to one `write` on a
    /// [`MuxSender`].
    pub data: &'a [u8],
    /// Identifies the sender that wrote `data`.
    pub tag: Tag,
}
impl Mux {
    /// Create a new `Mux` that uses Linux abstract-namespace socket addresses.
    ///
    /// Unlike [`Mux::new`], this creates nothing in the filesystem. The receiving socket,
    /// and every sender later created with [`Mux::make_sender`], get randomly chosen
    /// abstract names. If a chosen name is already in use, another one is tried, up to
    /// 32768 times.
    ///
    /// Only available on Linux.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::AddrInUse`] error if no unused name was found. Also
    /// returns any other error encountered while creating or binding the socket.
    #[cfg(target_os = "linux")]
    pub fn new_abstract() -> io::Result<Self> {
        for _ in 0..32768 {
            let receive_addr =
                SocketAddr::from_abstract_name(format!("io-mux-{:x}", fastrand::u128(..)))?;
            match Self::new_with_addr(receive_addr, None) {
                // Name collision: try another random name.
                Err(e) if e.kind() == io::ErrorKind::AddrInUse => continue,
                result => return result,
            }
        }
        Err(io::Error::new(
            io::ErrorKind::AddrInUse,
            "couldn't create unique socket name",
        ))
    }

    /// Create a new `Mux` whose sockets live in a new temporary directory.
    ///
    /// The directory is created with `tempfile::tempdir` (in the system's default temporary
    /// location) and is kept for as long as the `Mux` exists. The receiving socket is bound
    /// to the path `r` inside it. Each sender created with [`Mux::make_sender`] gets its own
    /// socket file in the same directory.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created or the socket cannot be bound.
    pub fn new() -> io::Result<Self> {
        Self::new_with_tempdir(tempfile::tempdir()?)
    }

    /// Create a new `Mux` like [`Mux::new`], but place its temporary directory inside `dir`.
    ///
    /// Socket paths have a platform-specific maximum length. A long `dir` path can therefore
    /// make this function, or later calls to [`Mux::make_sender`], fail.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the directory cannot be created inside `dir`;
    /// - the resulting socket path is not usable (for example, too long);
    /// - the socket cannot be bound.
    pub fn new_in<P: AsRef<Path>>(dir: P) -> io::Result<Self> {
        Self::new_with_tempdir(tempfile::tempdir_in(dir)?)
    }

    // Bind the receiving socket at `<tempdir>/r`. The `TempDir` is stored in the `Mux` so
    // the directory stays alive as long as the sockets inside it.
    fn new_with_tempdir(tempdir: tempfile::TempDir) -> io::Result<Self> {
        let receive_addr = SocketAddr::from_pathname(tempdir.path().join("r"))?;
        Self::new_with_addr(receive_addr, Some(tempdir))
    }

    // Bind the receiving socket to `receive_addr` and build the `Mux` around it.
    // `tempdir` is `None` only for abstract-address muxes.
    fn new_with_addr(
        receive_addr: SocketAddr,
        tempdir: Option<tempfile::TempDir>,
    ) -> io::Result<Self> {
        let receive = UnixDatagram::bind_addr(&receive_addr)?;
        // The receiver never sends. This shutdown is best-effort and its failure is ignored.
        // NOTE(review): presumably some platforms reject shutdown on an unconnected
        // datagram socket; confirm before turning this into a hard error.
        let _ = receive.shutdown(Shutdown::Write);
        Ok(Mux {
            receive,
            receive_addr,
            tempdir,
            buf: vec![0; DEFAULT_BUF_SIZE],
        })
    }

    /// Create a new sender that delivers data to this `Mux`.
    ///
    /// Returns the sender along with the [`Tag`] that [`Mux::read`] attaches to every chunk
    /// the sender sends. The sender can be written to directly, or converted into a
    /// [`Stdio`] to serve as a child process's stdout or stderr.
    ///
    /// # Errors
    ///
    /// Returns an [`io::ErrorKind::AddrInUse`] error if no unused socket name was found
    /// after 32768 attempts. Also returns any other error encountered while creating,
    /// binding or connecting the sender's socket.
    pub fn make_sender(&self) -> io::Result<(Tag, MuxSender)> {
        if let Some(ref tempdir) = self.tempdir {
            // Path-based mux: each sender gets its own socket file in the temporary directory.
            self.make_sender_with_retry(|n| {
                SocketAddr::from_pathname(tempdir.path().join(format!("{n:x}")))
            })
        } else {
            // Only `new_abstract` creates a `Mux` without a temporary directory.
            #[cfg(target_os = "linux")]
            return self.make_sender_with_retry(|n| {
                SocketAddr::from_abstract_name(format!("io-mux-send-{n:x}"))
            });
            // `new_abstract` is Linux-only, so this branch cannot be reached elsewhere.
            #[cfg(not(target_os = "linux"))]
            panic!("Mux without tempdir on non-Linux platform")
        }
    }

    // Bind a sender socket at an address built from a random number, retrying on name
    // collisions, then connect it to the receiving socket. The sender's own address
    // becomes its `Tag`, because `recv_from` on the receiver reports it as the source of
    // each datagram.
    fn make_sender_with_retry(
        &self,
        make_sender_addr: impl Fn(u128) -> io::Result<SocketAddr>,
    ) -> io::Result<(Tag, MuxSender)> {
        for _ in 0..32768 {
            let sender_addr = make_sender_addr(fastrand::u128(..))?;
            let sender = match UnixDatagram::bind_addr(&sender_addr) {
                Err(e) if e.kind() == io::ErrorKind::AddrInUse => continue,
                result => result,
            }?;
            sender.connect_addr(&self.receive_addr)?;
            // The sender only ever writes.
            sender.shutdown(Shutdown::Read)?;
            return Ok((Tag(sender_addr), MuxSender(sender)));
        }
        Err(io::Error::new(
            io::ErrorKind::AddrInUse,
            "couldn't create unique socket name",
        ))
    }

    // Receive one whole datagram into `self.buf` and return its contents together with the
    // sender's address. The buffer is grown first if needed.
    //
    // Linux version: a peek with `MSG_TRUNC` reports the full length of the next datagram
    // without consuming it. This lets the buffer be sized exactly before the real receive.
    #[cfg(all(target_os = "linux", not(feature = "test-portable")))]
    fn recv_from_full(&mut self) -> io::Result<(&[u8], SocketAddr)> {
        let next_packet_len = rustix::net::recv(
            &mut self.receive,
            &mut [],
            RecvFlags::PEEK | RecvFlags::TRUNC,
        )?;
        if next_packet_len > self.buf.len() {
            self.buf.resize(next_packet_len, 0);
        }
        let (bytes, addr) = self.receive.recv_from(&mut self.buf)?;
        Ok((&self.buf[..bytes], addr))
    }

    // Portable version of `recv_from_full`:
    // - peek into the buffer, doubling its size whenever the peek fills it completely
    //   (the datagram may have been truncated);
    // - once a peek returns fewer bytes than the buffer holds, consume the datagram with
    //   a zero-length receive, only to learn the sender's address.
    #[cfg(not(all(target_os = "linux", not(feature = "test-portable"))))]
    fn recv_from_full(&mut self) -> io::Result<(&[u8], SocketAddr)> {
        loop {
            let bytes = rustix::net::recv(&mut self.receive, &mut self.buf, RecvFlags::PEEK)?;
            if bytes == self.buf.len() {
                let new_len = self.buf.len().saturating_mul(2);
                self.buf.resize(new_len, 0);
            } else {
                let (_, addr) = self.receive.recv_from(&mut [])?;
                return Ok((&self.buf[..bytes], addr));
            }
        }
    }

    /// Receive the next chunk of data, together with the [`Tag`] of the sender that wrote
    /// it.
    ///
    /// Blocks until data is available, unless the caller has put the socket into
    /// non-blocking mode. Each call returns exactly one datagram, i.e. the data from one
    /// `write` on a [`MuxSender`]. The internal buffer grows as needed, so large chunks are
    /// not truncated. The returned data borrows that buffer and remains valid until the
    /// next call.
    ///
    /// This does not report end-of-stream when senders are closed. Callers need their own
    /// way to know when to stop reading, for instance a dedicated sender that writes a
    /// completion message.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the underlying socket receive calls.
    pub fn read(&mut self) -> io::Result<TaggedData<'_>> {
        let (data, addr) = self.recv_from_full()?;
        let tag = Tag(addr);
        Ok(TaggedData { data, tag })
    }
}
/// An asynchronous [`Mux`], for use with `async-io`-based executors.
///
/// Available with the `async` feature. It offers the same constructors and
/// `make_sender` as [`Mux`], plus an `async` [`read`](AsyncMux::read) and a non-waiting
/// [`read_nonblock`](AsyncMux::read_nonblock).
#[cfg(feature = "async")]
pub struct AsyncMux(Async<Mux>);
#[cfg(feature = "async")]
impl AsyncMux {
    /// Create a new `AsyncMux` that uses Linux abstract-namespace socket addresses.
    ///
    /// See [`Mux::new_abstract`]. The mux is registered with the `async-io` reactor, which
    /// puts its socket into non-blocking mode.
    ///
    /// # Errors
    ///
    /// Returns any error from [`Mux::new_abstract`] or from registering with the reactor.
    #[cfg(target_os = "linux")]
    pub fn new_abstract() -> io::Result<Self> {
        Ok(Self(Async::new(Mux::new_abstract()?)?))
    }

    /// Create a new `AsyncMux` whose sockets live in a new temporary directory.
    ///
    /// See [`Mux::new`].
    ///
    /// # Errors
    ///
    /// Returns any error from [`Mux::new`] or from registering with the reactor.
    pub fn new() -> io::Result<Self> {
        Ok(Self(Async::new(Mux::new()?)?))
    }

    /// Create a new `AsyncMux` whose temporary directory is placed inside `dir`.
    ///
    /// See [`Mux::new_in`].
    ///
    /// # Errors
    ///
    /// Returns any error from [`Mux::new_in`] or from registering with the reactor.
    pub fn new_in<P: AsRef<Path>>(dir: P) -> io::Result<Self> {
        Ok(Self(Async::new(Mux::new_in(dir)?)?))
    }

    /// Create a new sender that delivers data to this mux.
    ///
    /// See [`Mux::make_sender`].
    pub fn make_sender(&self) -> io::Result<(Tag, MuxSender)> {
        self.0.get_ref().make_sender()
    }

    /// Receive the next chunk of data and the [`Tag`] of its sender, waiting asynchronously
    /// until data is available.
    ///
    /// See [`Mux::read`] for details about the returned data. Dropping the returned future
    /// before it completes loses no data: nothing is consumed until a datagram is known to
    /// be queued.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the reactor or by the socket receive calls.
    pub async fn read(&mut self) -> io::Result<TaggedData<'_>> {
        // Wait until a datagram is queued. A zero-length `MSG_PEEK` leaves the datagram in
        // place. `read_with` retries on `WouldBlock`, so a readiness notification that
        // doesn't pan out never surfaces as an error from this async read.
        self.0
            .read_with(|m| {
                rustix::net::recv(&m.receive, &mut [], RecvFlags::PEEK)?;
                Ok(())
            })
            .await?;
        // SAFETY: `Mux::read` only receives from the socket. It never drops or replaces the
        // I/O source registered with the reactor, which is what `Async::get_mut` requires.
        let m = unsafe { self.0.get_mut() };
        m.read()
    }

    /// Receive the next chunk of data if one is already queued, without waiting.
    ///
    /// Returns `Ok(None)` when no data is pending.
    ///
    /// # Errors
    ///
    /// Returns any socket error other than "would block".
    pub fn read_nonblock(&mut self) -> io::Result<Option<TaggedData<'_>>> {
        // SAFETY: `Mux::read` only receives from the socket. It never drops or replaces the
        // I/O source registered with the reactor, which is what `Async::get_mut` requires.
        let m = unsafe { self.0.get_mut() };
        match m.read() {
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(None),
            ret => ret.map(Some),
        }
    }
}
#[cfg(test)]
mod test {
    #[cfg(feature = "async")]
    use super::AsyncMux;
    use super::Mux;

    #[test]
    fn test() -> std::io::Result<()> {
        test_with_mux(Mux::new()?)
    }

    #[test]
    fn test_new_in() -> std::io::Result<()> {
        let dir = tempfile::tempdir()?;
        let dir_entries = || -> std::io::Result<usize> {
            Ok(dir.path().read_dir()?.collect::<Result<Vec<_>, _>>()?.len())
        };
        assert_eq!(dir_entries()?, 0);
        let mux = Mux::new_in(dir.path())?;
        // `new_in` creates exactly one entry (the mux's own temporary directory) in `dir`.
        assert_eq!(dir_entries()?, 1);
        test_with_mux(mux)
    }

    #[test]
    #[cfg(target_os = "linux")]
    fn test_abstract() -> std::io::Result<()> {
        test_with_mux(Mux::new_abstract()?)
    }

    // Steps:
    // 1. Run a shell that alternates between stdout and stderr.
    // 2. Check that the mux returns every line, tagged with the right sender, in the order
    //    it was written.
    // 3. A third sender, fed by a thread that waits for the child, reports its exit status.
    fn test_with_mux(mut mux: Mux) -> std::io::Result<()> {
        let (out_tag, out_sender) = mux.make_sender()?;
        let (err_tag, err_sender) = mux.make_sender()?;
        let mut child = std::process::Command::new("sh")
            .arg("-c")
            .arg("echo out1 && echo err1 1>&2 && echo out2 && echo err2 1>&2")
            .stdout(out_sender)
            .stderr(err_sender)
            .spawn()?;
        let (done_tag, mut done_sender) = mux.make_sender()?;
        std::thread::spawn(move || {
            use std::io::Write;
            match child.wait() {
                Ok(status) if status.success() => {
                    let _ = writeln!(done_sender, "Done");
                }
                Ok(_) => {
                    let _ = writeln!(done_sender, "Child process failed");
                }
                Err(e) => {
                    let _ = writeln!(done_sender, "Error: {e:?}");
                }
            }
        });
        let data1 = mux.read()?;
        assert_eq!(data1.tag, out_tag);
        assert_eq!(data1.data, b"out1\n");
        let data2 = mux.read()?;
        assert_eq!(data2.tag, err_tag);
        assert_eq!(data2.data, b"err1\n");
        let data3 = mux.read()?;
        assert_eq!(data3.tag, out_tag);
        assert_eq!(data3.data, b"out2\n");
        let data4 = mux.read()?;
        assert_eq!(data4.tag, err_tag);
        assert_eq!(data4.data, b"err2\n");
        let done = mux.read()?;
        assert_eq!(done.tag, done_tag);
        assert_eq!(done.data, b"Done\n");
        Ok(())
    }

    // Async variant: race the child's exit against reads from the mux. Once the child has
    // exited, drain whatever is still queued without waiting.
    #[cfg(feature = "async")]
    fn test_with_async_mux(mut mux: AsyncMux) -> std::io::Result<()> {
        use futures_lite::{FutureExt, future};
        future::block_on(async {
            let (out_tag, out_sender) = mux.make_sender()?;
            let (err_tag, err_sender) = mux.make_sender()?;
            let mut child = async_process::Command::new("sh")
                .arg("-c")
                .arg("echo out1 && echo err1 1>&2 && echo out2 && echo err2 1>&2")
                .stdout(out_sender)
                .stderr(err_sender)
                .spawn()?;
            let mut expected = vec![
                (out_tag.clone(), b"out1\n"),
                (err_tag.clone(), b"err1\n"),
                (out_tag, b"out2\n"),
                (err_tag, b"err2\n"),
            ];
            let mut expected = expected.drain(..);
            let mut status = None;
            while status.is_none() {
                async {
                    status = Some(child.status().await?);
                    Ok::<(), std::io::Error>(())
                }
                .or(async {
                    let data = mux.read().await?;
                    let (expected_tag, expected_data) = expected.next().unwrap();
                    assert_eq!(data.tag, expected_tag);
                    assert_eq!(data.data, expected_data);
                    Ok(())
                })
                .await?;
            }
            while let Some(data) = mux.read_nonblock()? {
                let (expected_tag, expected_data) = expected.next().unwrap();
                assert_eq!(data.tag, expected_tag);
                assert_eq!(data.data, expected_data);
            }
            assert!(status.unwrap().success());
            assert_eq!(expected.next(), None);
            Ok(())
        })
    }

    #[cfg(feature = "async")]
    #[test]
    fn test_async() -> std::io::Result<()> {
        test_with_async_mux(AsyncMux::new()?)
    }

    #[cfg(all(feature = "async", target_os = "linux"))]
    #[test]
    fn test_abstract_async() -> std::io::Result<()> {
        test_with_async_mux(AsyncMux::new_abstract()?)
    }
}