use std::ffi::CString;
use std::io;
use std::os::fd::{AsRawFd, OwnedFd, RawFd};
use std::path::PathBuf;
use std::process::ExitStatus;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use nix::pty::ForkptyResult;
use nix::sys::signal::Signal;
use nix::sys::wait::{WaitStatus, waitpid};
use nix::unistd::{Pid, chdir, execvp};
use tokio::io::unix::AsyncFd;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::procserv::error::{ProcServError, ProcServResult};
/// Event emitted by a spawned child's background tasks on the channel returned by
/// [`ChildHandle::spawn`].
#[derive(Debug)]
pub enum ChildEvent {
    /// A chunk of bytes read from the PTY master (at most one read buffer per event).
    Output(Vec<u8>),
    /// The reaper collected the child's termination.
    ///
    /// `status` is `None` when no usable wait status could be obtained.
    Exited {
        status: Option<ExitStatus>,
    },
}
/// Description of the program to run on a fresh pseudo-terminal.
#[derive(Debug, Clone)]
pub struct ChildSpec {
    /// Program passed to `execvp`; it is also used as `argv[0]`.
    pub program: PathBuf,
    /// Arguments placed after `argv[0]`.
    pub args: Vec<String>,
    /// Directory the child `chdir`s into before exec, if any.
    pub cwd: Option<PathBuf>,
    /// Bytes that `ChildHandle::write_stdin` strips from input before writing it to the PTY.
    pub ignore_chars: Vec<u8>,
}
/// Cloneable handle to a child process running on a PTY, created by [`ChildHandle::spawn`].
///
/// Clones share the PTY master, the input ignore-list and the liveness flag via `Arc`s.
#[derive(Clone)]
pub struct ChildHandle {
    /// Child process id as returned by `forkpty`.
    pid: Pid,
    /// PTY master, registered with the tokio reactor.
    master: Arc<AsyncFd<OwnedFd>>,
    /// Bytes dropped from input by `write_stdin`.
    ignore_chars: Arc<Vec<u8>>,
    /// Set to `false` by the reaper task once the child's termination has been collected.
    alive: Arc<AtomicBool>,
}
impl std::fmt::Debug for ChildHandle {
    /// Shows only the raw pid and the current liveness flag; the fd and ignore-list are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let raw_pid = self.pid.as_raw();
        let is_alive = self.alive.load(Ordering::Relaxed);
        let mut builder = f.debug_struct("ChildHandle");
        builder.field("pid", &raw_pid);
        builder.field("alive", &is_alive);
        builder.finish()
    }
}
impl ChildHandle {
pub fn spawn(spec: &ChildSpec) -> ProcServResult<(Self, mpsc::Receiver<ChildEvent>)> {
let result = unsafe { nix::pty::forkpty(None, None) }
.map_err(|e| ProcServError::Forkpty(e.to_string()))?;
match result {
ForkptyResult::Parent { child, master } => {
let alive = Arc::new(AtomicBool::new(true));
set_nonblocking(&master)
.map_err(|e| ProcServError::Forkpty(format!("set O_NONBLOCK: {e}")))?;
let master_fd = Arc::new(
AsyncFd::new(master)
.map_err(|e| ProcServError::Forkpty(format!("AsyncFd: {e}")))?,
);
let (tx, rx) = mpsc::channel::<ChildEvent>(64);
spawn_reader(master_fd.clone(), tx.clone());
spawn_reaper(child, alive.clone(), tx);
Ok((
Self {
pid: child,
master: master_fd,
ignore_chars: Arc::new(spec.ignore_chars.clone()),
alive,
},
rx,
))
}
ForkptyResult::Child => {
in_child_setup_and_exec(spec);
}
}
}
pub async fn write_stdin(&self, bytes: &[u8]) -> ProcServResult<()> {
if !self.alive.load(Ordering::Acquire) {
return Err(ProcServError::ChildExited(None));
}
let filtered: Vec<u8> = if self.ignore_chars.is_empty() {
bytes.to_vec()
} else {
bytes
.iter()
.copied()
.filter(|b| !self.ignore_chars.contains(b))
.collect()
};
if filtered.is_empty() {
return Ok(());
}
let mut written = 0;
while written < filtered.len() {
let mut guard = self.master.writable().await.map_err(ProcServError::Io)?;
let raw = self.master.as_ref().as_raw_fd();
let result = guard.try_io(|_| {
let n = unsafe {
libc::write(
raw,
filtered[written..].as_ptr() as *const libc::c_void,
filtered.len() - written,
)
};
if n < 0 {
Err(io::Error::last_os_error())
} else {
Ok(n as usize)
}
});
match result {
Ok(Ok(n)) => written += n,
Ok(Err(e)) => return Err(ProcServError::Io(e)),
Err(_would_block) => continue,
}
}
Ok(())
}
pub fn signal(&self, signo: i32) -> ProcServResult<()> {
let sig = Signal::try_from(signo)
.map_err(|e| ProcServError::Config(format!("invalid signal {signo}: {e}")))?;
let pgid = Pid::from_raw(-self.pid.as_raw());
nix::sys::signal::kill(pgid, sig)
.map_err(|e| ProcServError::Io(io::Error::other(e.to_string())))?;
Ok(())
}
pub fn is_alive(&self) -> bool {
self.alive.load(Ordering::Acquire)
}
pub fn pid(&self) -> i32 {
self.pid.as_raw()
}
}
fn in_child_setup_and_exec(spec: &ChildSpec) -> ! {
if let Some(ref cwd) = spec.cwd {
let c_cwd = match CString::new(cwd.as_os_str().as_encoded_bytes()) {
Ok(c) => c,
Err(_) => {
eprintln!("procserv child: invalid chdir path");
std::process::exit(126);
}
};
if let Err(e) = chdir(c_cwd.as_c_str()) {
eprintln!("procserv child: chdir to {} failed: {e}", cwd.display());
std::process::exit(126);
}
}
let prog = match CString::new(spec.program.as_os_str().as_encoded_bytes()) {
Ok(c) => c,
Err(_) => {
eprintln!("procserv child: program name contains NUL");
std::process::exit(126);
}
};
let mut argv: Vec<CString> = Vec::with_capacity(1 + spec.args.len());
argv.push(prog.clone());
for a in &spec.args {
match CString::new(a.as_bytes()) {
Ok(c) => argv.push(c),
Err(_) => {
eprintln!("procserv child: argument contains NUL: {a:?}");
std::process::exit(126);
}
}
}
let argv_refs: Vec<&std::ffi::CStr> = argv.iter().map(|c| c.as_c_str()).collect();
match execvp(prog.as_c_str(), &argv_refs) {
Ok(infallible) => match infallible {},
Err(e) => {
eprintln!(
"procserv child: execvp({}) failed: {e}",
spec.program.display()
);
std::process::exit(127);
}
}
}
fn set_nonblocking(fd: &OwnedFd) -> io::Result<()> {
let raw = fd.as_raw_fd();
let flags = unsafe { libc::fcntl(raw, libc::F_GETFL) };
if flags < 0 {
return Err(io::Error::last_os_error());
}
if unsafe { libc::fcntl(raw, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
/// Spawns a task that forwards everything read from the PTY master as
/// [`ChildEvent::Output`] chunks of at most 4096 bytes.
///
/// The task stops when any of the following happens:
/// - `read` returns 0 (end of file);
/// - `read` fails with any error other than `EINTR`, which is retried (`EIO` is logged as the
///   slave side having closed);
/// - waiting for readability fails;
/// - the receiver is dropped.
fn spawn_reader(master: Arc<AsyncFd<OwnedFd>>, tx: mpsc::Sender<ChildEvent>) -> JoinHandle<()> {
    tokio::spawn(async move {
        let mut buf = vec![0u8; 4096];
        let raw: RawFd = master.as_ref().as_raw_fd();
        loop {
            let mut guard = match master.readable().await {
                Ok(g) => g,
                Err(e) => {
                    tracing::debug!(error = %e, "procserv child PTY readable() ended");
                    break;
                }
            };
            let read = guard.try_io(|_| {
                // SAFETY: `buf` is a live, writable buffer of `buf.len()` bytes, and `raw` stays
                // open because this task holds an `Arc` to `master`.
                let n =
                    unsafe { libc::read(raw, buf.as_mut_ptr() as *mut libc::c_void, buf.len()) };
                if n < 0 {
                    Err(io::Error::last_os_error())
                } else {
                    Ok(n as usize)
                }
            });
            // Release the readiness guard before awaiting on the channel.
            drop(guard);
            match read {
                Ok(Ok(0)) => break,
                Ok(Ok(n)) => {
                    if tx.send(ChildEvent::Output(buf[..n].to_vec())).await.is_err() {
                        break;
                    }
                }
                // Interrupted by a signal before any data was read; retry instead of
                // abandoning the PTY.
                Ok(Err(e)) if e.kind() == io::ErrorKind::Interrupted => continue,
                Ok(Err(e)) => {
                    if e.raw_os_error() == Some(libc::EIO) {
                        tracing::debug!("procserv child PTY EIO (slave closed)");
                    } else {
                        tracing::debug!(error = %e, "procserv child PTY read error");
                    }
                    break;
                }
                // Readiness was stale; `try_io` cleared it, so wait again.
                Err(_would_block) => continue,
            }
        }
    })
}
fn spawn_reaper(pid: Pid, alive: Arc<AtomicBool>, tx: mpsc::Sender<ChildEvent>) -> JoinHandle<()> {
tokio::task::spawn(async move {
let res = tokio::task::spawn_blocking(move || waitpid(pid, None))
.await
.ok();
let exit_code = match res {
Some(Ok(WaitStatus::Exited(_, code))) => Some(make_exit_status(code)),
Some(Ok(WaitStatus::Signaled(_, sig, _))) => Some(make_exit_status(128 + sig as i32)),
_ => None,
};
alive.store(false, Ordering::Release);
let _ = tx.send(ChildEvent::Exited { status: exit_code }).await;
})
}
/// Builds an [`ExitStatus`] that reports `code` as a normal exit code.
///
/// The raw value is a wait status with the exit code in bits 8..16, which is where
/// `ExitStatus::code` reads it back from.
#[cfg(unix)]
fn make_exit_status(code: i32) -> ExitStatus {
    use std::os::unix::process::ExitStatusExt;
    let raw_wait_status = code << 8;
    <ExitStatus as ExitStatusExt>::from_raw(raw_wait_status)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{Duration, Instant, sleep};

    /// Collects PTY output from `rx` until the channel closes or `deadline` passes.
    ///
    /// Returns the concatenated `Output` bytes and whether an `Exited` event arrived. The
    /// 50 ms sleep arm lets the deadline be re-checked even when no events are flowing.
    async fn drain_until_closed(
        rx: &mut mpsc::Receiver<ChildEvent>,
        deadline: Instant,
    ) -> (Vec<u8>, bool) {
        let mut collected = Vec::new();
        let mut saw_exit = false;
        loop {
            if Instant::now() >= deadline {
                break;
            }
            tokio::select! {
                maybe_event = rx.recv() => {
                    let Some(event) = maybe_event else { break };
                    match event {
                        ChildEvent::Output(chunk) => collected.extend_from_slice(&chunk),
                        ChildEvent::Exited { .. } => saw_exit = true,
                    }
                }
                _ = sleep(Duration::from_millis(50)) => {}
            }
        }
        (collected, saw_exit)
    }

    /// Builds a spec for `program` with no working-directory override.
    fn spec_for(program: &str, args: Vec<String>, ignore_chars: Vec<u8>) -> ChildSpec {
        ChildSpec {
            program: PathBuf::from(program),
            args,
            cwd: None,
            ignore_chars,
        }
    }

    #[tokio::test]
    async fn spawn_echo_child_yields_output_and_exits() {
        let spec = spec_for("/bin/echo", vec!["hello procserv".into()], Vec::new());
        let (handle, mut rx) = ChildHandle::spawn(&spec).expect("spawn");
        let deadline = Instant::now() + Duration::from_secs(3);
        let (output, exited) = drain_until_closed(&mut rx, deadline).await;
        assert!(exited, "child should have exited");
        assert!(!handle.is_alive(), "alive flag should flip false");
        let text = String::from_utf8_lossy(&output);
        assert!(text.contains("hello procserv"), "got: {text:?}");
    }

    #[tokio::test]
    async fn write_stdin_filters_ignore_chars() {
        let spec = spec_for("/bin/cat", Vec::new(), vec![b'X']);
        let (handle, mut rx) = ChildHandle::spawn(&spec).expect("spawn");
        let settle = Duration::from_millis(150);
        sleep(settle).await;
        handle.write_stdin(b"abXXcd\n").await.expect("write");
        sleep(settle).await;
        // 0x04 is the usual PTY EOF character (Ctrl-D), intended to let `cat` finish; the
        // result is deliberately ignored.
        let _ = handle.write_stdin(&[0x04]).await;
        let deadline = Instant::now() + Duration::from_secs(3);
        let (output, _) = drain_until_closed(&mut rx, deadline).await;
        let text = String::from_utf8_lossy(&output);
        assert!(text.contains("abcd"), "filter stripped X's, got: {text:?}");
        assert!(
            !text.contains('X'),
            "X bytes should not appear, got: {text:?}"
        );
    }
}