use std::{
ffi::c_int,
io::{self, Read, Write},
mem::size_of,
os::{
fd::{AsRawFd, RawFd},
unix::net::UnixStream,
},
};
use crate::system::interface::ProcessId;
use crate::{exec::signal_fmt, system::wait::WaitStatus};
/// Tag identifying which message variant follows on the wire.
type Prefix = u8;
/// Encoded size of a [`Prefix`] tag.
const PREFIX_LEN: usize = size_of::<Prefix>();

/// Payload type of messages received by the parent (`ParentMessage`).
type ParentData = c_int;
/// Encoded size of a [`ParentData`] payload.
const PARENT_DATA_LEN: usize = size_of::<ParentData>();

/// Payload type of messages received by the monitor (`MonitorMessage`).
type MonitorData = c_int;
/// Encoded size of a [`MonitorData`] payload.
const MONITOR_DATA_LEN: usize = size_of::<MonitorData>();
/// The two connected ends of the parent/monitor backchannel.
pub(super) struct BackchannelPair {
    pub(super) parent: ParentBackchannel,
    pub(super) monitor: MonitorBackchannel,
}

impl BackchannelPair {
    /// Creates a connected pair of Unix stream sockets and switches both ends
    /// to non-blocking mode. One end goes to the parent and the other to the monitor.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket pair cannot be created or if either end
    /// cannot be put into non-blocking mode.
    pub(super) fn new() -> io::Result<Self> {
        let (parent_socket, monitor_socket) = UnixStream::pair()?;

        // The parent end is configured first, then the monitor end.
        for socket in [&parent_socket, &monitor_socket] {
            socket.set_nonblocking(true)?;
        }

        let parent = ParentBackchannel {
            socket: parent_socket,
        };
        let monitor = MonitorBackchannel {
            socket: monitor_socket,
        };
        Ok(Self { parent, monitor })
    }
}
/// Messages received by the parent end of the backchannel.
pub(super) enum ParentMessage {
    /// Raw OS error code (see the `TryFrom<io::Error>` impl).
    IoError(c_int),
    CommandStatus(WaitStatus),
    CommandPid(ProcessId),
    /// Produced from an `UnexpectedEof` I/O error.
    ShortRead,
}

impl ParentMessage {
    /// Number of bytes in one encoded message (tag followed by payload).
    const LEN: usize = PREFIX_LEN + PARENT_DATA_LEN;

    // Wire tags, one per variant.
    const IO_ERROR: Prefix = 0;
    const CMD_STATUS: Prefix = 1;
    const CMD_PID: Prefix = 2;
    const SHORT_READ: Prefix = 3;

    /// Rebuilds a message from its wire tag and payload.
    ///
    /// # Panics
    ///
    /// Panics if `tag` is not one of the tags assigned by `to_parts`.
    fn from_parts(tag: Prefix, payload: ParentData) -> Self {
        match tag {
            Self::IO_ERROR => Self::IoError(payload),
            Self::CMD_STATUS => Self::CommandStatus(WaitStatus::from_raw(payload)),
            Self::CMD_PID => Self::CommandPid(payload),
            Self::SHORT_READ => Self::ShortRead,
            _ => unreachable!(),
        }
    }

    /// Splits a message into its wire tag and payload. `ShortRead` carries no
    /// data and encodes a zero payload.
    fn to_parts(&self) -> (Prefix, ParentData) {
        match self {
            Self::IoError(code) => (Self::IO_ERROR, *code),
            Self::CommandStatus(status) => (Self::CMD_STATUS, status.into_raw()),
            Self::CommandPid(pid) => (Self::CMD_PID, *pid),
            Self::ShortRead => (Self::SHORT_READ, 0),
        }
    }
}
/// Converts an I/O error into a [`ParentMessage`] that can be sent to the parent.
///
/// An error carrying a raw OS error code becomes `IoError(code)`. Otherwise an
/// `UnexpectedEof` error becomes `ShortRead`. Any other error is handed back
/// unchanged as `Err`.
impl TryFrom<io::Error> for ParentMessage {
    type Error = io::Error;

    fn try_from(err: io::Error) -> Result<Self, Self::Error> {
        // The OS error code takes precedence over the error kind.
        if let Some(code) = err.raw_os_error() {
            return Ok(Self::IoError(code));
        }
        match err.kind() {
            io::ErrorKind::UnexpectedEof => Ok(Self::ShortRead),
            _ => Err(err),
        }
    }
}
impl From<WaitStatus> for ParentMessage {
fn from(status: WaitStatus) -> Self {
Self::CommandStatus(status)
}
}
/// The parent's end of the backchannel. It sends [`MonitorMessage`]s and
/// receives [`ParentMessage`]s.
pub(super) struct ParentBackchannel {
    socket: UnixStream,
}

impl ParentBackchannel {
    /// Encodes `event` and writes it to the socket with a single `write_all`.
    ///
    /// The frame is the message's `Prefix` tag followed by its payload, both
    /// in native byte order.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `write_all`. Sockets created by
    /// `BackchannelPair::new` are non-blocking, so this can include
    /// `ErrorKind::WouldBlock`.
    pub(super) fn send(&mut self, event: &MonitorMessage) -> io::Result<()> {
        let (tag, payload) = event.to_parts();

        let mut frame = [0u8; MonitorMessage::LEN];
        frame[..PREFIX_LEN].copy_from_slice(&tag.to_ne_bytes());
        frame[PREFIX_LEN..].copy_from_slice(&payload.to_ne_bytes());

        self.socket.write_all(&frame)
    }

    /// Reads exactly one encoded [`ParentMessage`] from the socket and decodes it.
    ///
    /// # Errors
    ///
    /// Returns any error from `read_exact`. For example, `UnexpectedEof` is
    /// returned if the peer closes the socket before a full frame arrives.
    ///
    /// # Panics
    ///
    /// Panics if the received tag is not a known `ParentMessage` tag.
    pub(super) fn recv(&mut self) -> io::Result<ParentMessage> {
        let mut frame = [0u8; ParentMessage::LEN];
        self.socket.read_exact(&mut frame)?;

        // Both slices have exactly the size of their target arrays, so the
        // conversions cannot fail.
        let tag = Prefix::from_ne_bytes(frame[..PREFIX_LEN].try_into().unwrap());
        let payload = ParentData::from_ne_bytes(frame[PREFIX_LEN..].try_into().unwrap());

        Ok(ParentMessage::from_parts(tag, payload))
    }
}

/// Exposes the underlying socket's file descriptor.
impl AsRawFd for ParentBackchannel {
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.socket)
    }
}
/// Messages received by the monitor end of the backchannel.
#[derive(PartialEq, Eq)]
pub(super) enum MonitorMessage {
    ExecCommand,
    /// Payload is formatted via `signal_fmt` in `Debug`, so presumably a
    /// signal number (NOTE(review): confirm against senders).
    Signal(c_int),
}

impl MonitorMessage {
    /// Number of bytes in one encoded message (tag followed by payload).
    const LEN: usize = PREFIX_LEN + MONITOR_DATA_LEN;

    // Wire tags, one per variant.
    const EXEC_CMD: Prefix = 0;
    const SIGNAL: Prefix = 1;

    /// Rebuilds a message from its wire tag and payload.
    ///
    /// # Panics
    ///
    /// Panics if `tag` is not one of the tags assigned by `to_parts`.
    fn from_parts(tag: Prefix, payload: MonitorData) -> Self {
        if tag == Self::EXEC_CMD {
            Self::ExecCommand
        } else if tag == Self::SIGNAL {
            Self::Signal(payload)
        } else {
            unreachable!()
        }
    }

    /// Splits a message into its wire tag and payload. `ExecCommand` carries
    /// no data and encodes a zero payload.
    fn to_parts(&self) -> (Prefix, MonitorData) {
        match *self {
            Self::ExecCommand => (Self::EXEC_CMD, 0),
            Self::Signal(signal) => (Self::SIGNAL, signal),
        }
    }
}
/// Debug formatting for [`MonitorMessage`]. It follows the shape a derived
/// `Debug` would produce, but renders the `Signal` payload through `signal_fmt`.
impl std::fmt::Debug for MonitorMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Write the variant name verbatim. Calling `.fmt(f)` on a `&str`
            // inside this impl resolves to `<str as Debug>::fmt`, which would
            // wrap the name in quotes (`"ExecCommand"`).
            Self::ExecCommand => f.write_str("ExecCommand"),
            &Self::Signal(signal) => write!(f, "Signal({})", signal_fmt(signal)),
        }
    }
}
/// The monitor's end of the backchannel. It sends [`ParentMessage`]s and
/// receives [`MonitorMessage`]s.
pub(super) struct MonitorBackchannel {
    socket: UnixStream,
}

impl MonitorBackchannel {
    /// Encodes `event` and writes it to the socket with a single `write_all`.
    ///
    /// The frame is the message's `Prefix` tag followed by its payload, both
    /// in native byte order.
    ///
    /// # Errors
    ///
    /// Returns any error reported by `write_all`. Sockets created by
    /// `BackchannelPair::new` are non-blocking, so this can include
    /// `ErrorKind::WouldBlock`.
    pub(super) fn send(&mut self, event: &ParentMessage) -> io::Result<()> {
        let (tag, payload) = event.to_parts();

        let mut frame = [0u8; ParentMessage::LEN];
        frame[..PREFIX_LEN].copy_from_slice(&tag.to_ne_bytes());
        frame[PREFIX_LEN..].copy_from_slice(&payload.to_ne_bytes());

        self.socket.write_all(&frame)
    }

    /// Reads exactly one encoded [`MonitorMessage`] from the socket and decodes it.
    ///
    /// # Errors
    ///
    /// Returns any error from `read_exact`. For example, `UnexpectedEof` is
    /// returned if the peer closes the socket before a full frame arrives.
    ///
    /// # Panics
    ///
    /// Panics if the received tag is not a known `MonitorMessage` tag.
    pub(super) fn recv(&mut self) -> io::Result<MonitorMessage> {
        let mut frame = [0u8; MonitorMessage::LEN];
        self.socket.read_exact(&mut frame)?;

        // Both slices have exactly the size of their target arrays, so the
        // conversions cannot fail.
        let tag = Prefix::from_ne_bytes(frame[..PREFIX_LEN].try_into().unwrap());
        let payload = MonitorData::from_ne_bytes(frame[PREFIX_LEN..].try_into().unwrap());

        Ok(MonitorMessage::from_parts(tag, payload))
    }
}

/// Exposes the underlying socket's file descriptor.
impl AsRawFd for MonitorBackchannel {
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.socket)
    }
}