use crate::{imp::implements, pod::PlainOldData, Deserializer, Object, Serializer};
use nix::libc::{AF_UNIX, SOCK_CLOEXEC, SOCK_SEQPACKET};
use nix::{
cmsg_space,
sys::socket::{recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags},
};
use std::io::{Error, ErrorKind, IoSlice, IoSliceMut, Result};
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::os::unix::{
io::{FromRawFd, OwnedFd, RawFd},
net::UnixStream,
};
/// Upper bound, in bytes, on one packet as laid out by the sender in this file:
/// a one-byte marker followed by at most `MAX_PACKET_SIZE - 1` payload bytes.
pub(crate) const MAX_PACKET_SIZE: usize = 16 * 1024;
/// Maximum number of file descriptors attached to a single packet; the receiver
/// also sizes its control-message buffer for this many descriptors.
// NOTE(review): 253 presumably mirrors the kernel's per-message SCM_RIGHTS
// limit on Linux — confirm before changing.
pub(crate) const MAX_PACKET_FDS: usize = 253;
pub(crate) fn socketpair(flags: i32) -> Result<(UnixStream, UnixStream)> {
unsafe {
let mut fds = [0, 0];
if nix::libc::socketpair(
AF_UNIX,
SOCK_SEQPACKET | SOCK_CLOEXEC | flags,
0,
fds.as_mut_ptr(),
) == -1
{
return Err(std::io::Error::last_os_error());
}
Ok((
UnixStream::from_raw_fd(fds[0]),
UnixStream::from_raw_fd(fds[1]),
))
}
}
/// State for sending one object over a socket as a sequence of packets.
///
/// The payload comes either from `bytes` (the value's own memory, used for
/// plain-old-data types) or from `buffer` (the serialized form); `data()`
/// picks between them.
pub(crate) struct SingleObjectSender<'a> {
// Socket every packet is written to.
socket_fd: RawFd,
// Raw in-memory bytes of the value for plain-old-data types; empty otherwise.
bytes: &'a [u8],
// Handles drained from the serializer; empty for plain-old-data types.
fds: Vec<RawFd>,
// Serialized payload; empty for plain-old-data types.
buffer: Vec<u8>,
// Number of payload bytes already sent.
data_pos: usize,
// Number of entries of `fds` already sent.
fds_pos: usize,
// Flags passed to every `sendmsg` call (`MSG_DONTWAIT` when non-blocking).
flags: MsgFlags,
}
impl<'a> SingleObjectSender<'a> {
    /// Prepares `value` for transmission over `socket_fd`.
    ///
    /// Plain-old-data values are sent directly from their in-memory bytes with
    /// no descriptors attached. Every other type is run through `Serializer`,
    /// whose output bytes and drained handles become the payload.
    ///
    /// When `blocking` is `false`, every send uses `MSG_DONTWAIT`.
    pub(crate) fn new<T: Object>(socket_fd: RawFd, value: &'a T, blocking: bool) -> Self {
        let flags = if blocking {
            MsgFlags::empty()
        } else {
            MsgFlags::MSG_DONTWAIT
        };

        let (bytes, fds, buffer) = if implements!(T: PlainOldData) {
            // SAFETY: `value` is a valid `&'a T`, so the pointer is readable for
            // `size_of::<T>()` bytes for `'a`, and `u8` has no alignment
            // requirement.
            // NOTE(review): that every byte (padding included) is initialized
            // presumably follows from the `PlainOldData` contract — confirm.
            let raw = unsafe {
                std::slice::from_raw_parts((value as *const T).cast::<u8>(), std::mem::size_of::<T>())
            };
            (raw, Vec::new(), Vec::new())
        } else {
            let no_raw_bytes: &'a [u8] = &[];
            let mut serializer = Serializer::new();
            serializer.serialize(value);
            let handles = serializer.drain_handles();
            (no_raw_bytes, handles, serializer.into_vec())
        };

        Self {
            socket_fd,
            bytes,
            fds,
            buffer,
            data_pos: 0,
            fds_pos: 0,
            flags,
        }
    }

    /// Sends packets until the one flagged as last has been written.
    ///
    /// Each packet carries a one-byte marker (`1` on the final packet, `0`
    /// otherwise), up to `MAX_PACKET_SIZE - 1` payload bytes, and up to
    /// `MAX_PACKET_FDS` descriptors in an `SCM_RIGHTS` control message.
    ///
    /// # Errors
    ///
    /// Propagates any `sendmsg` error. The byte and descriptor positions are
    /// advanced only after a packet was sent successfully, so a failed call
    /// leaves the recorded progress untouched.
    pub(crate) fn send_next(&mut self) -> Result<()> {
        loop {
            let total = self.data().len();
            let chunk_end = total.min(self.data_pos + MAX_PACKET_SIZE - 1);
            let rights_end = self.fds.len().min(self.fds_pos + MAX_PACKET_FDS);
            // The final packet is the one that exhausts both payload and descriptors.
            let is_last = chunk_end == total && rights_end == self.fds.len();

            let marker = [is_last as u8];
            let chunk = &self.data()[self.data_pos..chunk_end];
            let rights = &self.fds[self.fds_pos..rights_end];

            let sent = sendmsg::<()>(
                self.socket_fd,
                &[IoSlice::new(&marker), IoSlice::new(chunk)],
                &[ControlMessage::ScmRights(rights)],
                self.flags,
                None,
            )?;

            // The marker byte is part of the count but not of the payload.
            self.data_pos += sent - 1;
            self.fds_pos = rights_end;

            if is_last {
                return Ok(());
            }
        }
    }

    /// Returns the payload bytes: the raw value bytes when present, otherwise
    /// the serialized buffer.
    fn data(&self) -> &[u8] {
        match self.bytes.is_empty() {
            true => self.buffer.as_slice(),
            false => self.bytes,
        }
    }
}
/// State for receiving one object of type `T` from a socket, packet by packet.
///
/// Plain-old-data payloads are received directly into `value`; other payloads
/// accumulate in `buffer` and are deserialized once the final packet arrives.
pub(crate) struct SingleObjectReceiver<T: Object> {
// Socket every packet is read from.
socket_fd: RawFd,
// Accumulated payload bytes for types that are not plain-old-data.
buffer: Vec<u8>,
// Number of payload bytes received so far.
data_pos: usize,
// Zero-initialized storage that plain-old-data payloads are written into.
value: MaybeUninit<T>,
// Descriptors received so far through `SCM_RIGHTS` control messages.
fds: Vec<OwnedFd>,
// Flags passed to every `recvmsg` call (`MSG_DONTWAIT` when non-blocking).
flags: MsgFlags,
// Set once the final packet has been seen; `recv_next` asserts it is unset.
terminated: bool,
// Ties the receiver to `T` through a function-pointer phantom type.
marker: PhantomData<fn() -> T>,
}
// SAFETY: NOTE(review): the only field whose `Send`-ness depends on `T` is
// `value: MaybeUninit<T>`; this impl asserts that moving a receiver across
// threads is sound even when `T` itself is not `Send`. The justification is
// not stated in this file — confirm against the `Object` contract.
unsafe impl<T: Object> Send for SingleObjectReceiver<T> {}
impl<T: Object> SingleObjectReceiver<T> {
    /// Creates a receiver that reads one `T` from `socket_fd`.
    ///
    /// When `blocking` is `false`, every receive uses `MSG_DONTWAIT`.
    ///
    /// # Safety
    ///
    /// NOTE(review): the caller-side contract is not spelled out in this file.
    /// Presumably the peer must really be sending a `T`, since the received
    /// bytes are later reinterpreted or deserialized as one — confirm.
    pub(crate) unsafe fn new(socket_fd: RawFd, blocking: bool) -> Self {
        let flags = if blocking {
            MsgFlags::empty()
        } else {
            MsgFlags::MSG_DONTWAIT
        };
        Self {
            socket_fd,
            buffer: Vec::new(),
            data_pos: 0,
            value: MaybeUninit::zeroed(),
            fds: Vec::new(),
            flags,
            terminated: false,
            marker: PhantomData,
        }
    }

    /// Receives packets until the one flagged as last arrives, then builds the value.
    ///
    /// Returns `Ok(None)` when a zero-length message without control messages
    /// arrives before any payload bytes or descriptors were received.
    ///
    /// # Errors
    ///
    /// Propagates `recvmsg` errors, and reports an error for a control message
    /// other than `SCM_RIGHTS`, for a zero-length message that is not a clean
    /// end of stream, and for deserialization failures (a `WouldBlock` from the
    /// deserializer is replaced by a generic error).
    ///
    /// # Panics
    ///
    /// Panics if called again after the final packet has been processed.
    pub(crate) fn recv_next(&mut self) -> Result<Option<T>> {
        assert!(
            !self.terminated,
            "Calling recv_next after it returned Ok(Some(...)) or Err(...) is undefined behavior",
        );

        let is_pod = implements!(T: PlainOldData);
        let protocol_error = |text: &'static str| Error::new(ErrorKind::Other, text);

        loop {
            let window: &mut [u8] = if is_pod {
                // SAFETY: `self.value` was zero-initialized in `new`, so all
                // `size_of::<T>()` bytes behind the pointer are initialized
                // and writable through `&mut self`.
                unsafe {
                    std::slice::from_raw_parts_mut(
                        self.value.as_mut_ptr().cast::<u8>(),
                        std::mem::size_of::<T>(),
                    )
                }
            } else {
                // Make room for one more full packet's worth of payload.
                self.buffer.resize(self.data_pos + MAX_PACKET_SIZE - 1, 0);
                self.buffer.as_mut_slice()
            };

            let mut last_flag = [0];
            let mut iovecs = [
                IoSliceMut::new(&mut last_flag),
                IoSliceMut::new(&mut window[self.data_pos..]),
            ];
            let mut ancillary = cmsg_space!([RawFd; MAX_PACKET_FDS]);
            let message = recvmsg::<()>(
                self.socket_fd,
                &mut iovecs,
                Some(&mut ancillary),
                self.flags | MsgFlags::MSG_CMSG_CLOEXEC,
            )?;

            let mut saw_cmsg = false;
            for cmsg in message.cmsgs() {
                saw_cmsg = true;
                match cmsg {
                    ControlMessageOwned::ScmRights(rights) => {
                        // SAFETY: descriptors delivered through SCM_RIGHTS are
                        // new to this process, so nothing else owns them.
                        self.fds
                            .extend(rights.into_iter().map(|fd| unsafe { OwnedFd::from_raw_fd(fd) }));
                    }
                    _ => return Err(protocol_error("Unexpected kind of cmsg on stream")),
                }
            }
            // Last use of `message`; the marker byte must only be read after this.
            let received = message.bytes;

            if received == 0 {
                if saw_cmsg {
                    return Err(protocol_error("Unexpected empty message on stream"));
                }
                if self.data_pos == 0 && self.fds.is_empty() {
                    return Ok(None);
                }
                return Err(protocol_error("Unterminated data on stream"));
            }

            // The marker byte is part of the count but not of the payload.
            self.data_pos += received - 1;
            if last_flag[0] == 1 {
                break;
            }
        }

        self.terminated = true;

        if is_pod {
            // SAFETY: the storage was zero-initialized in `new` and then
            // overwritten with received bytes, so every byte is initialized.
            // NOTE(review): that any such bit pattern is a valid `T` presumably
            // follows from the `PlainOldData` contract — confirm.
            return Ok(Some(unsafe { self.value.assume_init_read() }));
        }

        // Drop the unused tail that was reserved for a packet that never came.
        self.buffer.truncate(self.data_pos);
        let payload = std::mem::take(&mut self.buffer);
        let handles = std::mem::take(&mut self.fds);
        let mut deserializer = Deserializer::new(payload, handles);

        // SAFETY: NOTE(review): the requirements of `deserialize` are not
        // visible in this file — presumably the payload must have been produced
        // by serializing a `T`; confirm.
        let outcome: Result<T> = unsafe { deserializer.deserialize() };
        outcome.map(Some).map_err(|e| {
            if e.kind() == ErrorKind::WouldBlock {
                Error::other("Unexpected blocking event")
            } else {
                e
            }
        })
    }
}