use {
    core::mem,
    std::{
        io::{self, Error, ErrorKind, IoSlice, IoSliceMut},
        os::unix::{
            io::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd},
            net::{AncillaryData, SocketAncillary, UnixStream},
        },
        process::Stdio,
    },
    crate::Result,
};
/// Size in bytes of a single raw file descriptor, used to size ancillary buffers.
const RAW_FD_SIZE: usize = core::mem::size_of::<RawFd>();
/// Passing of raw file descriptors ("streams") over a socket, tagged with a byte-string ID.
pub trait UdsxUnixStream {
    /// Sends `streams` to the peer together with the in-band bytes of `id`.
    fn send_streams<B: AsRef<[u8]>, const N: usize>(&self, id: B, streams: [RawFd; N]) -> Result<()>;

    /// Sends the standard input, output and error descriptors of the current process,
    /// tagged with `id`.
    fn send_ioe<B: AsRef<[u8]>>(&self, id: B) -> Result<()>;

    /// Receives exactly `N` descriptors tagged with `id` and wraps each one in a `T`.
    ///
    /// # Safety
    ///
    /// Every received descriptor is converted with [`FromRawFd::from_raw_fd`]. The caller must
    /// make sure that handing ownership of those descriptors to `T` is sound.
    unsafe fn recv_streams<B: AsRef<[u8]>, T: FromRawFd, const N: usize>(&self, id: B) -> Result<[T; N]>;

    /// Receives three descriptors tagged with `id`. They are returned in the order they were
    /// attached, which is `[stdin, stdout, stderr]` when the peer used `send_ioe`.
    ///
    /// # Safety
    ///
    /// Same contract as [`UdsxUnixStream::recv_streams`].
    unsafe fn recv_ioe<B: AsRef<[u8]>>(&self, id: B) -> Result<[Stdio; 3]>;
}
impl UdsxUnixStream for UnixStream {
fn send_streams<B, const N: usize>(&self, id: B, streams: [RawFd; N]) -> Result<()> where B: AsRef<[u8]> {
verify_id(&id)?;
let mut ancillary_buf = vec!(0; make_size_of_streams(streams.len())?);
let mut ancillary = SocketAncillary::new(&mut ancillary_buf);
if ancillary.add_fds(&streams) == false {
return Err(Error::new(ErrorKind::Other, __!()));
}
self.send_vectored_with_ancillary(&[IoSlice::new(id.as_ref())], &mut ancillary)?;
Ok(())
}
fn send_ioe<B>(&self, id: B) -> Result<()> where B: AsRef<[u8]> {
self.send_streams(id, [io::stdin().as_raw_fd(), io::stdout().as_raw_fd(), io::stderr().as_raw_fd()])
}
unsafe fn recv_streams<B, T, const N: usize>(&self, id: B) -> Result<[T; N]> where B: AsRef<[u8]>, T: FromRawFd {
verify_id(&id)?;
let mut ancillary_buf = vec!(0; make_size_of_streams(N)?);
let mut ancillary = SocketAncillary::new(&mut ancillary_buf);
let id = id.as_ref();
if id != {
let mut io_slices = vec!(0; id.len());
{
let mut io_slices = [IoSliceMut::new(&mut io_slices)];
self.recv_vectored_with_ancillary(&mut io_slices, &mut ancillary)?;
}
io_slices
} {
return Err(Error::new(ErrorKind::InvalidData, __!("Invalid ID of streams")));
}
let mut result = Vec::with_capacity(N);
for messages in ancillary.messages() {
let data = messages.map_err(|e| Error::new(ErrorKind::Other, __!("{:?}", e)))?;
match data {
AncillaryData::ScmRights(scm_rights) => scm_rights.for_each(|fd| result.push(T::from_raw_fd(fd))),
AncillaryData::ScmCredentials(_) => return Err(
Error::new(ErrorKind::InvalidData, __!("Expected ScmRights, got ScmCredentials"))
),
};
}
Ok(<[T; N]>::try_from(result).map_err(|v| err!("Expected to receive {n} stream{s}, got: {len}", n=N, s=plural_s!(N), len=v.len()))?)
}
unsafe fn recv_ioe<B>(&self, id: B) -> Result<[Stdio; 3]> where B: AsRef<[u8]> {
self.recv_streams(id)
}
}
fn verify_id<B>(id: B) -> Result<()> where B: AsRef<[u8]> {
if id.as_ref().is_empty() {
Err(Error::new(ErrorKind::InvalidData, __!("ID must not be empty")))
} else {
Ok(())
}
}
fn make_size_of_streams(count: usize) -> Result<usize> {
const FACTOR: usize = 6;
match RAW_FD_SIZE.checked_mul(count).map(|x| x.checked_mul(FACTOR)) {
Some(Some(result)) => Ok(result),
_ => Err(Error::new(ErrorKind::InvalidData, __!("Stream has too much items: {count}", count=count))),
}
}