use colored::Colorize;
use log::error;
#[cfg(not(target_os = "windows"))]
use nix::{sys::signal, unistd};
use std::collections::HashMap;
#[cfg(not(target_os = "windows"))]
use std::collections::HashSet;
use std::ffi::OsStr;
use std::io;
use std::process::Stdio;
use std::time::Duration;
#[cfg(not(target_os = "windows"))]
use std::time::Instant;
#[cfg(not(target_os = "windows"))]
use sysinfo::{ProcessExt, System, SystemExt};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio::task;
use tokio::time::sleep;
#[cfg(target_os = "windows")]
use winapi::um::processthreadsapi::{OpenProcess, TerminateProcess};
#[cfg(target_os = "windows")]
use winapi::um::winnt::{PROCESS_QUERY_INFORMATION, PROCESS_TERMINATE};
#[cfg(target_os = "windows")]
use std::ptr::null_mut;
/// Polling interval (100 ms) used by `kill_pstree` between checks of whether
/// the signalled process is still present.
pub const SLEEP_STEP: Duration = Duration::from_millis(100);
/// Forcibly terminates the process identified by `pid` (Windows only).
///
/// The process is opened with query + terminate rights and terminated with
/// exit code `1`. Nothing happens when the process cannot be opened.
#[cfg(target_os = "windows")]
fn kill(pid: u32) {
    // SAFETY: `OpenProcess` is called with plain integer arguments; the
    // returned handle is checked against null before it is used.
    let handle = unsafe { OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_TERMINATE, 0, pid) };
    if handle == null_mut() {
        return;
    }
    // NOTE(review): the handle is not closed anywhere in this function.
    // SAFETY: `handle` is non-null and was opened with `PROCESS_TERMINATE`.
    unsafe { TerminateProcess(handle, 1) };
}
/// Schedules a forced kill of the current process after `timeout`.
///
/// When `warn` is set, a bold red notice is written to stderr first. The
/// function returns immediately; a background thread sleeps for `timeout`
/// and then kills this process (SIGKILL on Unix-like systems, `kill` on
/// Windows). The result of sending the signal is ignored.
pub fn suicide(timeout: Duration, warn: bool) {
    if warn {
        let notice = format!("Killing process in {:?}", timeout).red().bold();
        eprintln!("{}", notice);
    }
    std::thread::spawn(move || {
        let own_pid = std::process::id();
        std::thread::sleep(timeout);
        #[cfg(target_os = "windows")]
        kill(own_pid);
        #[allow(clippy::cast_possible_wrap)]
        #[cfg(not(target_os = "windows"))]
        let _ = signal::kill(
            unistd::Pid::from_raw(own_pid as i32),
            signal::Signal::SIGKILL,
        );
    });
}
/// Outcome of a command run: exit code plus the captured output lines.
#[derive(Debug, Default)]
pub struct CommandResult {
    /// Exit code of the process; `None` when no code was recorded.
    pub code: Option<i32>,
    /// Lines captured from standard output.
    pub out: Vec<String>,
    /// Lines captured from standard error.
    pub err: Vec<String>,
}

impl CommandResult {
    /// Creates an empty result: no exit code and no captured lines.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` only when an exit code is present and equals zero.
    #[must_use]
    pub fn ok(&self) -> bool {
        self.code == Some(0)
    }
}
/// Adds the PID of every descendant of `pid` found in `sys` to `to`.
///
/// The process table held by `sys` is scanned for entries whose parent is
/// `pid`; each match is recorded and then searched for its own children.
/// `pid` itself is not inserted.
#[cfg(not(target_os = "windows"))]
fn get_child_pids_recursive(pid: i32, sys: &System, to: &mut HashSet<i32>) {
    let direct_children = sys
        .processes()
        .iter()
        .filter(|(_, process)| process.parent() == Some(pid));
    for (child_pid, _) in direct_children {
        to.insert(*child_pid);
        get_child_pids_recursive(*child_pid, sys, to);
    }
}
/// Sends `signal` to the process tree rooted at `pid`.
///
/// The process list in `sys` is refreshed, the descendants of `pid` are added
/// to `pids` (together with `pid` itself when `kill_parent` is set), and the
/// signal is then sent to every PID in `pids`. Entries already present in
/// `pids` from an earlier call are signalled as well. Delivery errors are
/// ignored.
#[allow(clippy::cast_possible_wrap)]
#[cfg(not(target_os = "windows"))]
fn _kill_pstree(
    pid: u32,
    sys: &mut sysinfo::System,
    pids: &mut HashSet<i32>,
    signal: nix::sys::signal::Signal,
    kill_parent: bool,
) {
    let root = pid as i32;
    sys.refresh_processes();
    if kill_parent {
        pids.insert(root);
    }
    get_child_pids_recursive(root, sys, pids);
    pids.iter().copied().for_each(|target| {
        // Best effort: a target that cannot be signalled is skipped.
        let _ = signal::kill(unistd::Pid::from_raw(target), signal);
    });
}
/// Synchronously sends SIGKILL to every descendant of `pid`, and to `pid`
/// itself when `kill_parent` is `true`.
#[cfg(not(target_os = "windows"))]
pub fn kill_pstree_sync(pid: u32, kill_parent: bool) {
    let mut targets = HashSet::new();
    let mut process_table = System::new();
    _kill_pstree(
        pid,
        &mut process_table,
        &mut targets,
        signal::Signal::SIGKILL,
        kill_parent,
    );
}
/// Kills the process tree rooted at `pid`.
///
/// * `tki == None` — SIGKILL is sent to the tree immediately.
/// * `tki == Some(t)` — SIGTERM is sent to the tree first; the function then
///   polls every [`SLEEP_STEP`] until `pid` can no longer be signalled or `t`
///   has elapsed, and finally sends SIGKILL to every PID collected so far
///   (plus any descendants discovered in the second scan).
///
/// `pid` itself is only signalled when `kill_parent` is `true`. The liveness
/// poll uses the null signal, so it never delivers an additional SIGTERM and
/// never touches the parent when `kill_parent` is `false`.
#[allow(clippy::cast_possible_wrap)]
#[cfg(not(target_os = "windows"))]
pub async fn kill_pstree(pid: u32, tki: Option<Duration>, kill_parent: bool) {
    let mut sys = System::new();
    let mut pids = HashSet::new();
    match tki {
        None => {
            _kill_pstree(
                pid,
                &mut sys,
                &mut pids,
                signal::Signal::SIGKILL,
                kill_parent,
            );
        }
        Some(grace) => {
            _kill_pstree(
                pid,
                &mut sys,
                &mut pids,
                signal::Signal::SIGTERM,
                kill_parent,
            );
            let root = unistd::Pid::from_raw(pid as i32);
            let started = Instant::now();
            while started.elapsed() < grace {
                sleep(SLEEP_STEP).await;
                // Null signal: existence/permission check only, nothing is
                // delivered to the process.
                if signal::kill(root, None::<signal::Signal>).is_err() {
                    break;
                }
            }
            // `pids` still holds everything signalled above, so processes that
            // were re-parented in the meantime are killed as well.
            _kill_pstree(
                pid,
                &mut sys,
                &mut pids,
                signal::Signal::SIGKILL,
                kill_parent,
            );
        }
    }
}
/// Message sent from the helper tasks spawned by `command` to its main loop.
#[derive(Debug)]
enum CommandFrame {
    /// The child exited; carries its exit code (`-15` is reported by the
    /// runner task when the exit status has no code).
    Finished(i32),
    /// The timeout guard fired and the process tree was killed.
    Terminated,
    /// One line read from the child's standard output.
    Stdout(String),
    /// One line read from the child's standard error.
    Stderr(String),
    /// An I/O error raised while waiting for the child or reading its output.
    Error(io::Error),
}
/// Optional settings for [`command`]: extra environment variables, a
/// terminate-to-kill delay and data to feed to the child's stdin.
#[derive(Default, Clone)]
pub struct Options<'a> {
    /// Environment variables added to the child process.
    environment: HashMap<&'a str, &'a str>,
    /// Delay handed to `kill_pstree` between SIGTERM and SIGKILL; `None`
    /// means SIGKILL is sent right away.
    tki: Option<Duration>,
    /// Bytes written to the child's standard input, if any.
    input_data: Option<std::borrow::Cow<'a, Vec<u8>>>,
}

impl<'a> Options<'a> {
    /// Creates an empty option set (no environment, no delay, no input).
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the terminate-to-kill delay, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn tki(mut self, t: Duration) -> Self {
        self.tki = Some(t);
        self
    }

    /// Sets the data written to the child's stdin, replacing any previous value.
    #[inline]
    #[must_use]
    pub fn input(mut self, data: std::borrow::Cow<'a, Vec<u8>>) -> Self {
        self.input_data = Some(data);
        self
    }

    /// Adds (or overwrites) one environment variable.
    #[inline]
    #[must_use]
    pub fn env(mut self, name: &'a str, value: &'a str) -> Self {
        self.environment.insert(name, value);
        self
    }

    /// Returns the configured environment variables.
    #[inline]
    pub fn environment(&self) -> &HashMap<&'a str, &'a str> {
        &self.environment
    }

    /// Returns the environment map for in-place modification.
    ///
    /// The borrow lasts only as long as the returned reference is used, so
    /// the options can still be read or moved afterwards.
    #[inline]
    pub fn environment_mut(&mut self) -> &mut HashMap<&'a str, &'a str> {
        &mut self.environment
    }
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::missing_panics_doc)]
pub async fn command<'a, I, S>(
command: &str,
args: I,
timeout: Duration,
opts: Options<'a>,
) -> Result<CommandResult, io::Error>
where
I: IntoIterator<Item = S>,
S: AsRef<OsStr>,
{
let mut child = Command::new(command)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.kill_on_drop(true)
.args(args)
.envs(opts.environment)
.spawn()?;
let stdin = if opts.input_data.is_some() {
match child.stdin.take() {
Some(v) => Some(v),
None => {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to create stdin writer",
))
}
}
} else {
None
};
let stdin_writer = stdin.map(BufWriter::new);
let stdout = match child.stdout.take() {
Some(v) => v,
None => {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to create stdout reader",
))
}
};
let mut stdout_reader = BufReader::new(stdout).lines();
let stderr = match child.stderr.take() {
Some(v) => v,
None => {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Unable to create stderr reader",
))
}
};
let mut stderr_reader = BufReader::new(stderr).lines();
let ppid = child.id();
let (tx_runner, mut rx) = mpsc::channel(2);
let tx_guard = tx_runner.clone();
let tx_out = tx_runner.clone();
let tx_err = tx_runner.clone();
let runner = task::spawn(async move {
let frame = match child.wait().await {
Ok(v) => CommandFrame::Finished(if let Some(v) = v.code() {
v
} else {
sleep(timeout).await;
-15
}),
Err(e) => CommandFrame::Error(e),
};
let _r = tx_runner.send(frame).await;
});
let guard = ppid.map(|pid| {
task::spawn(async move {
sleep(timeout).await;
#[allow(clippy::cast_possible_wrap)]
#[cfg(not(target_os = "windows"))]
kill_pstree(pid, opts.tki, true).await;
#[cfg(target_os = "windows")]
kill(pid);
let _r = tx_guard.send(CommandFrame::Terminated).await;
})
});
let fut_stdin = stdin_writer.map(|mut writer| {
let input_data = opts.input_data.unwrap().into_owned();
task::spawn(async move {
if let Err(e) = writer.write_all(&input_data).await {
error!("Unable to write to stdin: {}", e);
} else if let Err(e) = writer.flush().await {
error!("Unable to flush stdin: {}", e);
}
})
});
let fut_stdout = task::spawn(async move {
while let Some(line) = match stdout_reader.next_line().await {
Ok(v) => v,
Err(e) => {
let _r = tx_out.send(CommandFrame::Error(e)).await;
return;
}
} {
let _r = tx_out.send(CommandFrame::Stdout(line)).await;
}
});
let fut_stderr = task::spawn(async move {
while let Some(line) = match stderr_reader.next_line().await {
Ok(v) => v,
Err(e) => {
let _r = tx_err.send(CommandFrame::Error(e)).await;
return;
}
} {
let _r = tx_err.send(CommandFrame::Stderr(line)).await;
}
});
let mut result = CommandResult::new();
while let Some(r) = rx.recv().await {
match r {
CommandFrame::Finished(code) => {
if let Some(g) = guard {
g.abort();
}
result.code = Some(code);
while let Some(r) = rx.recv().await {
match r {
CommandFrame::Stdout(v) => result.out.push(v),
CommandFrame::Stderr(v) => result.err.push(v),
_ => {}
}
}
return Ok(result);
}
CommandFrame::Terminated => {
runner.abort();
if let Some(f) = fut_stdin {
f.abort();
}
fut_stdout.abort();
fut_stderr.abort();
return Ok(result);
}
CommandFrame::Error(e) => {
runner.abort();
if let Some(g) = guard {
g.abort();
}
if let Some(f) = fut_stdin {
f.abort();
}
fut_stdout.abort();
fut_stderr.abort();
#[allow(clippy::cast_possible_wrap)]
ppid.map(|pid| async move {
#[cfg(not(target_os = "windows"))]
kill_pstree(pid, opts.tki, true).await;
#[cfg(target_os = "windows")]
kill(pid);
});
return Err(e);
}
CommandFrame::Stdout(v) => result.out.push(v),
CommandFrame::Stderr(v) => result.err.push(v),
}
}
Ok(result)
}