use std::{
future::Future,
io::{Error, Result},
ops::ControlFlow,
os::unix::process::ExitStatusExt,
process::ExitStatus,
};
use nix::{
errno::Errno,
libc,
sys::{
signal::{killpg, Signal},
wait::WaitPidFlag,
},
unistd::{setpgid, Pid},
};
use tokio::{
process::{Child, Command},
task::spawn_blocking,
};
use tracing::instrument;
use crate::ChildExitStatus;
use super::{TokioChildWrapper, TokioCommandWrap, TokioCommandWrapper};
/// Command wrapper that places the spawned child into a Unix process group,
/// so that the whole group (not only the direct child) can later be signalled
/// with `killpg` and reaped with `waitpid(-pgid, ..)` by [`ProcessGroupChild`].
#[derive(Debug, Clone)]
pub struct ProcessGroup {
/// PGID handed to `setpgid` (or Tokio's `process_group`) in the child before
/// exec. `0` asks `setpgid` to use the child's own PID, i.e. start a new group
/// with the child as leader.
leader: Pid,
}
impl ProcessGroup {
    /// Spawn the command as the leader of a brand-new process group.
    ///
    /// Stores PGID `0`, which `setpgid(2)` interprets as "use the PID of the
    /// process being moved", so the child's PGID becomes its own PID.
    pub fn leader() -> Self {
        let leader = Pid::from_raw(0);
        Self { leader }
    }

    /// Spawn the command into the existing process group whose ID is `leader`.
    ///
    /// The value is reinterpreted with an `as i32` cast, so inputs above
    /// `i32::MAX` wrap to negative numbers rather than being rejected here.
    pub fn attach_to(leader: u32) -> Self {
        let raw_pgid = leader as i32;
        Self {
            leader: Pid::from_raw(raw_pgid),
        }
    }
}
/// Child handle produced by [`ProcessGroup`]. Signals go to the whole process
/// group via `killpg`, and waiting also reaps other children of this process
/// that belong to the group.
#[derive(Debug)]
pub struct ProcessGroupChild {
/// The wrapped child handle (possibly itself another wrapper layer).
inner: Box<dyn TokioChildWrapper>,
/// Cached exit status of the direct child once observed, so repeated
/// `wait`/`try_wait` calls return it without waiting again.
exit_status: ChildExitStatus,
/// Process-group ID used for `killpg` and `waitpid(-pgid, ..)`.
pgid: Pid,
}
impl ProcessGroupChild {
    /// Wraps `inner`, recording `pgid` as the process group to signal and reap.
    ///
    /// The handle starts in the `Running` state; the exit status is filled in
    /// once `wait` or `try_wait` observes the child's exit.
    #[instrument(level = "debug")]
    pub(crate) fn new(inner: Box<dyn TokioChildWrapper>, pgid: Pid) -> Self {
        let exit_status = ChildExitStatus::Running;
        Self {
            pgid,
            inner,
            exit_status,
        }
    }

    /// Returns the process-group ID this handle signals and reaps.
    pub fn pgid(&self) -> u32 {
        let raw: i32 = self.pgid.as_raw();
        raw as u32
    }
}
impl TokioCommandWrapper for ProcessGroup {
    /// Arranges for the child to join (or create) the configured process group
    /// before it execs.
    ///
    /// With `tokio_unstable`, Tokio's native `Command::process_group` is used.
    /// Otherwise a `pre_exec` hook calls `setpgid(Pid::this(), leader)` inside
    /// the forked child.
    ///
    /// # Errors
    ///
    /// Always returns `Ok(())` here. A failing `setpgid` in the `pre_exec`
    /// hook is reported later, by `spawn`.
    #[instrument(level = "debug", skip(self))]
    fn pre_spawn(&mut self, command: &mut Command, _core: &TokioCommandWrap) -> Result<()> {
        #[cfg(tokio_unstable)]
        {
            command.process_group(self.leader.as_raw());
        }

        // The whole fallback must sit under one cfg-gated block. Gating only the
        // `let` would leave the `pre_exec` hook active under `tokio_unstable`,
        // referring to an undefined `leader` (and duplicating the work above).
        #[cfg(not(tokio_unstable))]
        {
            let leader = self.leader;
            // SAFETY: the hook runs in the forked child before exec, where only
            // async-signal-safe work is permitted. It captures a `Copy` value,
            // issues `getpid`/`setpgid` syscalls and converts an errno into an
            // `io::Error`; it takes no locks.
            unsafe {
                command.pre_exec(move || {
                    setpgid(Pid::this(), leader)
                        .map_err(Error::from)
                        .map(|_| ())
                });
            }
        }

        Ok(())
    }

    /// Wraps the freshly spawned child in a [`ProcessGroupChild`] whose PGID is
    /// the child's own PID.
    ///
    /// # Panics
    ///
    /// Panics if the child has no PID (it was already reaped) or its PID does
    /// not fit in an `i32`; both are broken invariants right after spawn.
    #[instrument(level = "debug", skip(self))]
    fn wrap_child(
        &mut self,
        inner: Box<dyn TokioChildWrapper>,
        _core: &TokioCommandWrap,
    ) -> Result<Box<dyn TokioChildWrapper>> {
        // NOTE(review): with `ProcessGroup::attach_to`, the child joins
        // `self.leader`'s group and is not a group leader, so its PID is not
        // its PGID and `killpg` on it may hit a non-existent group — confirm
        // whether `self.leader` should be used in that case.
        let pgid = Pid::from_raw(
            i32::try_from(
                inner
                    .id()
                    .expect("Command was reaped before we could read its PID"),
            )
            .expect("Command PID > i32::MAX"),
        );
        Ok(Box::new(ProcessGroupChild::new(inner, pgid)))
    }
}
impl ProcessGroupChild {
    /// Sends `sig` to every process in the group via `killpg(2)`.
    #[instrument(level = "debug", skip(self))]
    fn signal_imp(&self, sig: Signal) -> Result<()> {
        killpg(self.pgid, sig).map_err(Error::from)
    }

    /// Reaps terminated children of this process whose process-group ID is
    /// `pgid`, looping on `waitpid(-pgid, .., flag)`.
    ///
    /// Returns:
    /// - `Break(Some(status))` when the group leader (PID == `pgid`) was reaped
    ///   during this call. Either the group was then fully drained, or (with
    ///   `WNOHANG`) the remaining members are still running. The status is
    ///   reported in both cases because, once reaped, it cannot be obtained
    ///   again.
    /// - `Break(None)` when no children of ours remain in the group and the
    ///   leader was not reaped by this call.
    /// - `Continue(())` (only possible with `WNOHANG`) when members are still
    ///   running and the leader was not reaped by this call.
    ///
    /// # Errors
    ///
    /// Returns any `waitpid` error other than `ECHILD`; `EINTR` is retried.
    #[instrument(level = "debug")]
    fn wait_imp(pgid: Pid, flag: WaitPidFlag) -> Result<ControlFlow<Option<ExitStatus>>> {
        let mut parent_exit_status: Option<ExitStatus> = None;
        loop {
            // The raw libc call is used because std's `ExitStatus::from_raw`
            // needs the raw wait status, which nix's safe wrapper does not expose.
            let mut status: libc::c_int = 0;
            // SAFETY: `status` is a live, writable `c_int` for the whole call.
            let ret = unsafe { libc::waitpid(-pgid.as_raw(), &mut status, flag.bits()) };
            match ret {
                0 => {
                    // Only happens with WNOHANG: some group members are still
                    // running. If the leader was reaped in this call, report it
                    // now; otherwise its status would be lost for good.
                    return Ok(match parent_exit_status {
                        Some(leader_status) => ControlFlow::Break(Some(leader_status)),
                        None => ControlFlow::Continue(()),
                    });
                }
                -1 => match Errno::last() {
                    // No children left in the group: everything has been reaped.
                    Errno::ECHILD => return Ok(ControlFlow::Break(parent_exit_status)),
                    // Interrupted by a signal before any child changed state.
                    Errno::EINTR => continue,
                    errno => return Err(Error::from(errno)),
                },
                pid if pid == pgid.as_raw() => {
                    parent_exit_status = Some(ExitStatus::from_raw(status));
                }
                _ => {
                    // Another member of the group was reaped; keep draining.
                }
            }
        }
    }
}
impl TokioChildWrapper for ProcessGroupChild {
    fn inner(&self) -> &Child {
        self.inner.inner()
    }

    fn inner_mut(&mut self) -> &mut Child {
        self.inner.inner_mut()
    }

    fn into_inner(self: Box<Self>) -> Child {
        self.inner.into_inner()
    }

    /// Sends `SIGKILL` to the whole process group without waiting for it to exit.
    #[instrument(level = "debug", skip(self))]
    fn start_kill(&mut self) -> Result<()> {
        self.signal_imp(Signal::SIGKILL)
    }

    /// Waits for the direct child, then reaps the rest of its process group.
    ///
    /// The child's exit status is cached, so later calls return it at once.
    /// Remaining group members are first polled with `WNOHANG` a bounded number
    /// of times. If any are still running, a blocking `waitpid` runs on Tokio's
    /// blocking pool until no children remain in the group.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner wait, from `waitpid`, or from joining
    /// the blocking task.
    #[instrument(level = "debug", skip(self))]
    fn wait(&mut self) -> Box<dyn Future<Output = Result<ExitStatus>> + '_> {
        Box::new(async {
            if let ChildExitStatus::Exited(status) = &self.exit_status {
                return Ok(*status);
            }

            const MAX_RETRY_ATTEMPT: usize = 10;
            let pgid = self.pgid;

            let status = Box::into_pin(self.inner.wait()).await?;
            self.exit_status = ChildExitStatus::Exited(status);

            // A few non-blocking passes first, so that no blocking thread is
            // needed when the group has already been drained.
            for _ in 0..MAX_RETRY_ATTEMPT {
                if Self::wait_imp(pgid, WaitPidFlag::WNOHANG)?.is_break() {
                    return Ok(status);
                }
            }

            // Members are still alive: block, off the async runtime, until every
            // child in the group has been reaped.
            spawn_blocking(move || Self::wait_imp(pgid, WaitPidFlag::empty())).await??;
            Ok(status)
        })
    }

    /// Checks, without blocking, whether the direct child has exited.
    ///
    /// Terminated group members are reaped along the way. An observed exit
    /// status is cached for subsequent `wait`/`try_wait` calls.
    ///
    /// # Errors
    ///
    /// Propagates errors from `waitpid` or from the inner handle's `try_wait`.
    #[instrument(level = "debug", skip(self))]
    fn try_wait(&mut self) -> Result<Option<ExitStatus>> {
        if let ChildExitStatus::Exited(status) = &self.exit_status {
            return Ok(Some(*status));
        }

        let exited = match Self::wait_imp(self.pgid, WaitPidFlag::WNOHANG)? {
            // Our own waitpid reaped the leader, so the inner handle can no
            // longer observe it; this is the only copy of its status.
            ControlFlow::Break(Some(status)) => Some(status),
            // Either members are still running, or no children of ours remain
            // in group `pgid` without the leader having been reaped there (e.g.
            // the child is not in that group). In both cases the inner handle
            // is the authority on whether the child itself has exited.
            ControlFlow::Break(None) | ControlFlow::Continue(()) => self.inner.try_wait()?,
        };

        if let Some(status) = exited {
            self.exit_status = ChildExitStatus::Exited(status);
        }
        Ok(exited)
    }

    /// Sends the raw signal number `sig` to the whole process group.
    ///
    /// # Errors
    ///
    /// Fails if `sig` is not a valid signal number or if `killpg` fails.
    fn signal(&self, sig: i32) -> Result<()> {
        self.signal_imp(Signal::try_from(sig)?)
    }
}