use std::collections::HashMap;
use std::ffi::{OsStr, OsString};
use std::fmt;
use std::path::PathBuf;
use std::process::Stdio;
use std::process::{ChildStderr, ChildStdin, ChildStdout};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::{env, mem, process};
use std::{io, thread};
use ipc_channel::ipc::{self, IpcOneShotServer, IpcReceiver, IpcSender};
use serde::{de::DeserializeOwned, Serialize};
use crate::core::{assert_spawn_okay, should_pass_args, MarshalledCall, ENV_NAME};
use crate::error::{PanicInfo, SpawnError};
use crate::pool::PooledHandle;
use crate::serde::with_ipc_mode;
/// Closure type accepted by the builders' `pre_exec` method and stored in
/// [`ProcCommon::pre_exec`].
type PreExecFunc = dyn FnMut() -> io::Result<()> + Send + Sync + 'static;
/// Process settings that [`Builder`] applies to the child it spawns.
#[derive(Clone)]
pub struct ProcCommon {
/// Environment variables for the child. `Default` seeds this map with a
/// snapshot of the current process environment.
pub vars: HashMap<OsString, OsString>,
/// User id to run the child as (unix only); `None` leaves it unset.
#[cfg(unix)]
pub uid: Option<u32>,
/// Group id to run the child as (unix only); `None` leaves it unset.
#[cfg(unix)]
pub gid: Option<u32>,
/// Optional hook that `Builder::spawn_helper` registers through
/// `CommandExt::pre_exec` (unix only). The `Arc` wrapper keeps this struct
/// cloneable; the `Mutex` provides the `&mut` access an `FnMut` call needs.
#[cfg(unix)]
pub pre_exec: Option<Arc<Mutex<Box<PreExecFunc>>>>,
}
impl fmt::Debug for ProcCommon {
    /// Renders the struct as `ProcCommon { vars: .. }`.
    ///
    /// Only the environment map is printed; the unix-only fields are left
    /// out of the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("ProcCommon");
        out.field("vars", &self.vars);
        out.finish()
    }
}
impl Default for ProcCommon {
fn default() -> ProcCommon {
ProcCommon {
vars: std::env::vars_os().collect(),
#[cfg(unix)]
uid: None,
#[cfg(unix)]
gid: None,
#[cfg(unix)]
pre_exec: None,
}
}
}
/// Builder that configures and spawns a function call in a child process.
#[derive(Debug, Default)]
pub struct Builder {
// Stdio overrides; a stream left as `None` is not set on the child command
// by this builder's own configuration.
stdin: Option<Stdio>,
stdout: Option<Stdio>,
stderr: Option<Stdio>,
// Environment and (on unix) uid / gid / pre-exec settings.
common: ProcCommon,
}
// Expands to the configuration methods that operate on a `common: ProcCommon`
// field. It must be invoked inside an `impl` block of a type that has such a
// field (see `impl Builder`).
macro_rules! define_common_methods {
() => {
/// Sets the environment variable `key` to `val` in the builder's
/// environment map, replacing any previous value.
pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
where
K: AsRef<OsStr>,
V: AsRef<OsStr>,
{
self.common
.vars
.insert(key.as_ref().to_owned(), val.as_ref().to_owned());
self
}
/// Adds every `(key, value)` pair from `vars` to the builder's
/// environment map, replacing existing entries with the same key.
pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<OsStr>,
V: AsRef<OsStr>,
{
self.common.vars.extend(
vars.into_iter()
.map(|(k, v)| (k.as_ref().to_owned(), v.as_ref().to_owned())),
);
self
}
/// Removes `key` from the builder's environment map.
pub fn env_remove<K: AsRef<OsStr>>(&mut self, key: K) -> &mut Self {
self.common.vars.remove(key.as_ref());
self
}
/// Removes every entry from the builder's environment map.
pub fn env_clear(&mut self) -> &mut Self {
self.common.vars.clear();
self
}
/// Records the user id for the child process (unix only).
#[cfg(unix)]
pub fn uid(&mut self, id: u32) -> &mut Self {
self.common.uid = Some(id);
self
}
/// Records the group id for the child process (unix only).
#[cfg(unix)]
pub fn gid(&mut self, id: u32) -> &mut Self {
self.common.gid = Some(id);
self
}
/// Stores a hook to be run through the command's pre-exec mechanism
/// (unix only). A later call replaces the previously stored hook.
///
/// # Safety
///
/// `Builder::spawn_helper` forwards the closure to
/// `std::os::unix::process::CommandExt::pre_exec`, so the caller must
/// uphold the requirements documented for that method.
#[cfg(unix)]
pub unsafe fn pre_exec<F>(&mut self, f: F) -> &mut Self
where
F: FnMut() -> io::Result<()> + Send + Sync + 'static,
{
self.common.pre_exec = Some(Arc::new(Mutex::new(Box::new(f))));
self
}
};
}
impl Builder {
    /// Creates a builder with no stdio overrides and default [`ProcCommon`]
    /// settings (which snapshot the current environment).
    pub fn new() -> Self {
        Self {
            stdin: None,
            stdout: None,
            stderr: None,
            common: ProcCommon::default(),
        }
    }

    /// Replaces the common process settings wholesale.
    pub(crate) fn common(&mut self, common: ProcCommon) -> &mut Self {
        self.common = common;
        self
    }

    // Generates `env`, `envs`, `env_remove`, `env_clear` and, on unix,
    // `uid`, `gid` and `pre_exec`.
    define_common_methods!();

    /// Sets the child's standard input.
    pub fn stdin<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
        self.stdin = Some(cfg.into());
        self
    }

    /// Sets the child's standard output.
    pub fn stdout<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
        self.stdout = Some(cfg.into());
        self
    }

    /// Sets the child's standard error.
    pub fn stderr<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
        self.stderr = Some(cfg.into());
        self
    }

    /// Spawns a child process that runs `func(args)` and returns a handle to
    /// it.
    ///
    /// The builder's configuration is moved out with `mem::take`, so after
    /// this call the builder is back in its `Default` state.
    ///
    /// A spawn failure does not panic here: it is stored in the returned
    /// [`JoinHandle`] and reported when the handle is joined.
    pub fn spawn<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
        &mut self,
        args: A,
        func: fn(A) -> R,
    ) -> JoinHandle<R> {
        assert_spawn_okay();
        JoinHandle {
            inner: mem::take(self)
                .spawn_helper(args, func)
                .map(JoinHandleInner::Process),
        }
    }

    /// Starts the child process, performs the IPC handshake and ships the
    /// call (function pointer plus arguments) to it.
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] if the executable cannot be located, the
    /// process cannot be started, or any IPC step fails. When a step fails
    /// after the process was started, the child is killed and reaped before
    /// the error is returned.
    fn spawn_helper<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
        self,
        args: A,
        func: fn(A) -> R,
    ) -> Result<ProcessHandle<R>, SpawnError> {
        let (server, token) = IpcOneShotServer::<IpcSender<MarshalledCall>>::new()?;

        // The child is another instance of the current executable. On Linux
        // `/proc/self/exe` is preferred when it exists; otherwise fall back
        // to `env::current_exe()`.
        let me = if cfg!(target_os = "linux") {
            let path: PathBuf = "/proc/self/exe".into();
            if path.is_file() {
                path
            } else {
                env::current_exe()?
            }
        } else {
            env::current_exe()?
        };

        let mut child = process::Command::new(me);
        // `Command` inherits the parent's environment by default, and `envs`
        // only adds to or overrides it. `vars` already began as a snapshot of
        // the full environment, so without clearing first, entries dropped
        // via `env_remove` / `env_clear` would still reach the child.
        child.env_clear();
        child.envs(self.common.vars.into_iter());
        child.env(ENV_NAME, token);

        #[cfg(unix)]
        {
            use std::os::unix::process::CommandExt;
            if let Some(id) = self.common.uid {
                child.uid(id);
            }
            if let Some(id) = self.common.gid {
                child.gid(id);
            }
            if let Some(ref func) = self.common.pre_exec {
                let func = func.clone();
                // SAFETY: the hook can only have been installed through the
                // `unsafe fn pre_exec` builder method, whose caller accepted
                // the `CommandExt::pre_exec` contract.
                unsafe {
                    #[allow(clippy::needless_borrow)]
                    child.pre_exec(move || (&mut *func.lock().unwrap())());
                }
            }
        }

        let (can_pass_args, should_silence_stdout) = {
            #[cfg(feature = "test-support")]
            {
                match crate::testsupport::update_command_for_tests(&mut child) {
                    None => (true, false),
                    Some(crate::testsupport::TestMode {
                        can_pass_args,
                        should_silence_stdout,
                    }) => (can_pass_args, should_silence_stdout),
                }
            }
            #[cfg(not(feature = "test-support"))]
            {
                (true, false)
            }
        };

        // Forward this process's command-line arguments (minus argv[0]).
        if can_pass_args && should_pass_args() {
            child.args(env::args_os().skip(1));
        }

        if let Some(stdin) = self.stdin {
            child.stdin(stdin);
        }
        if let Some(stdout) = self.stdout {
            child.stdout(stdout);
        } else if should_silence_stdout {
            child.stdout(Stdio::null());
        }
        if let Some(stderr) = self.stderr {
            child.stderr(stderr);
        }

        let mut process = child.spawn()?;

        // Everything that can fail after the process exists is grouped here
        // so that a failure can be followed by cleanup of the child.
        let handshake = move || -> Result<_, SpawnError> {
            // Accept the connection on the one-shot server whose token was
            // exported to the child in `ENV_NAME`.
            let (_rx, tx) = server.accept()?;
            let (args_tx, args_rx) = ipc::channel()?;
            let (return_tx, return_rx) = ipc::channel()?;
            tx.send(MarshalledCall::marshal::<A, R>(func, args_rx, return_tx))?;
            with_ipc_mode(|| -> Result<_, SpawnError> {
                args_tx.send(args)?;
                Ok(())
            })?;
            Ok(return_rx)
        };

        match handshake() {
            Ok(return_rx) => Ok(ProcessHandle {
                recv: return_rx,
                state: Arc::new(ProcessHandleState::new(Some(process.id()))),
                process,
            }),
            Err(err) => {
                // Best effort: the child can never receive its call, so do
                // not leave it running or unreaped. Cleanup errors are
                // ignored in favour of reporting the original failure.
                process.kill().ok();
                process.wait().ok();
                Err(err)
            }
        }
    }
}
#[derive(Debug)]
pub struct ProcessHandleState {
pub exited: AtomicBool,
pub pid: AtomicUsize,
}
impl ProcessHandleState {
pub fn new(pid: Option<u32>) -> ProcessHandleState {
ProcessHandleState {
exited: AtomicBool::new(false),
pid: AtomicUsize::new(pid.unwrap_or(0) as usize),
}
}
pub fn pid(&self) -> Option<u32> {
match self.pid.load(Ordering::SeqCst) {
0 => None,
x => Some(x as u32),
}
}
pub fn kill(&self) {
if !self.exited.load(Ordering::SeqCst) {
self.exited.store(true, Ordering::SeqCst);
if let Some(pid) = self.pid() {
unsafe {
libc::kill(pid as i32, libc::SIGKILL);
}
}
}
}
}
/// Handle to a function call running in its own child process.
pub struct ProcessHandle<T> {
/// Receives the call's outcome: the return value or a `PanicInfo`.
pub(crate) recv: IpcReceiver<Result<T, PanicInfo>>,
/// The spawned child process.
pub(crate) process: process::Child,
/// Pid and exit flag; handed out as a clone of the `Arc` by `state()`.
pub(crate) state: Arc<ProcessHandleState>,
}
/// Returns `true` when `err` is `TryRecvError::Empty`, i.e. the non-blocking
/// receive found no message; every other variant yields `false`.
fn is_ipc_timeout(err: &ipc_channel::ipc::TryRecvError) -> bool {
    match err {
        ipc_channel::ipc::TryRecvError::Empty => true,
        _ => false,
    }
}
impl<T> ProcessHandle<T> {
    /// Returns another reference to the shared pid / exit-flag state.
    pub fn state(&self) -> Arc<ProcessHandleState> {
        Arc::clone(&self.state)
    }

    /// Kills the child process and waits for it.
    ///
    /// Does nothing and returns `Ok(())` when the exit flag is already set.
    /// Otherwise the result of the kill request is returned after the
    /// process has been waited on.
    pub fn kill(&mut self) -> Result<(), SpawnError> {
        let already_exited = self.state.exited.load(Ordering::SeqCst);
        if !already_exited {
            let outcome = match self.process.kill() {
                Ok(()) => Ok(()),
                Err(err) => Err(err.into()),
            };
            self.wait();
            outcome
        } else {
            Ok(())
        }
    }

    /// Mutable access to the child's stdin pipe, if one is held.
    pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
        match self.process.stdin {
            Some(ref mut pipe) => Some(pipe),
            None => None,
        }
    }

    /// Mutable access to the child's stdout pipe, if one is held.
    pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
        match self.process.stdout {
            Some(ref mut pipe) => Some(pipe),
            None => None,
        }
    }

    /// Mutable access to the child's stderr pipe, if one is held.
    pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
        match self.process.stderr {
            Some(ref mut pipe) => Some(pipe),
            None => None,
        }
    }

    /// Waits for the child (ignoring any wait error) and then sets the
    /// shared exit flag.
    fn wait(&mut self) {
        let _ = self.process.wait();
        self.state.exited.store(true, Ordering::SeqCst);
    }
}
impl<T: Serialize + DeserializeOwned> ProcessHandle<T> {
    /// Blocks until the child delivers its result, then reaps the child.
    ///
    /// # Errors
    ///
    /// Returns a [`SpawnError`] when receiving fails or when the child
    /// reported a panic. The child is waited on in both the success and the
    /// failure case, so a failed receive does not leave it unreaped.
    pub fn join(&mut self) -> Result<T, SpawnError> {
        let received = with_ipc_mode(|| self.recv.recv());
        // Reap before looking at the outcome: returning early on a receive
        // error would skip the wait and never set the exit flag.
        self.wait();
        received?.map_err(Into::into)
    }

    /// Like [`ProcessHandle::join`], but gives up once `timeout` has elapsed.
    ///
    /// The result channel is polled with a sleep between attempts. The sleep
    /// starts at 1 ms and doubles up to a fixed cap, and never extends past
    /// the deadline.
    ///
    /// # Errors
    ///
    /// * a timeout error when no result arrived before the deadline (the
    ///   child is left running in that case);
    /// * an I/O error when `timeout` is too large to form a deadline;
    /// * any receive error, or the panic reported by the child.
    pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, SpawnError> {
        // Upper bound for the poll interval. Without it the doubling is
        // unbounded, so a result could sit unnoticed for about as long as
        // the call has already been waiting (and `Duration * 2` could in
        // principle overflow).
        const MAX_POLL_INTERVAL: Duration = Duration::from_millis(100);

        let deadline = match Instant::now().checked_add(timeout) {
            Some(deadline) => deadline,
            None => {
                return Err(io::Error::new(io::ErrorKind::Other, "timeout out of bounds").into())
            }
        };

        let mut to_sleep = Duration::from_millis(1);
        let rv = loop {
            match with_ipc_mode(|| self.recv.try_recv()) {
                Ok(rv) => break rv.map_err(Into::into),
                Err(err) if is_ipc_timeout(&err) => {
                    match deadline.checked_duration_since(Instant::now()) {
                        Some(remaining) => {
                            thread::sleep(remaining.min(to_sleep));
                            to_sleep = (to_sleep * 2).min(MAX_POLL_INTERVAL);
                        }
                        None => return Err(SpawnError::new_timeout()),
                    }
                }
                Err(err) => {
                    // The result channel is unusable, so no result will ever
                    // arrive; reap the child as the success path does.
                    // NOTE(review): assumes a non-empty receive error means
                    // the child has exited or is exiting — confirm.
                    self.wait();
                    return Err(err.into());
                }
            }
        };
        self.wait();
        rv
    }
}
/// The concrete handle behind a [`JoinHandle`].
pub enum JoinHandleInner<T> {
/// A call running in its own child process.
Process(ProcessHandle<T>),
/// A call backed by a [`PooledHandle`] from `crate::pool`.
Pooled(PooledHandle<T>),
}
/// Handle to a spawned call; used to join it, kill it, or reach its stdio.
pub struct JoinHandle<T> {
/// The underlying handle, or the error produced while spawning. A stored
/// error is returned by `join` / `join_timeout`.
pub(crate) inner: Result<JoinHandleInner<T>, SpawnError>,
}
impl<T> fmt::Debug for JoinHandle<T> {
    /// Renders the handle as `JoinHandle { pid: .. }`, using the value
    /// returned by `pid()`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let pid = self.pid();
        let mut out = f.debug_struct("JoinHandle");
        out.field("pid", &pid);
        out.finish()
    }
}
impl<T> JoinHandle<T> {
    /// Returns the shared process state of the underlying handle, or `None`
    /// when spawning failed or the pooled handle has none to offer.
    pub(crate) fn process_handle_state(&self) -> Option<Arc<ProcessHandleState>> {
        match self.inner.as_ref().ok()? {
            JoinHandleInner::Process(handle) => Some(handle.state()),
            JoinHandleInner::Pooled(handle) => handle.process_handle_state(),
        }
    }

    /// Returns the process id recorded for this handle, if any.
    pub fn pid(&self) -> Option<u32> {
        let state = self.process_handle_state()?;
        state.pid()
    }

    /// Kills the underlying call. A handle whose spawn failed has nothing to
    /// kill and reports `Ok(())`.
    pub fn kill(&mut self) -> Result<(), SpawnError> {
        match &mut self.inner {
            Ok(JoinHandleInner::Process(handle)) => handle.kill(),
            Ok(JoinHandleInner::Pooled(handle)) => handle.kill(),
            Err(_) => Ok(()),
        }
    }

    /// Mutable access to the child's stdin pipe. Only process-backed handles
    /// can provide one; pooled and failed handles return `None`.
    pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
        match &mut self.inner {
            Ok(JoinHandleInner::Process(process)) => process.stdin(),
            Ok(JoinHandleInner::Pooled(_)) | Err(_) => None,
        }
    }

    /// Mutable access to the child's stdout pipe. Only process-backed
    /// handles can provide one; pooled and failed handles return `None`.
    pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
        match &mut self.inner {
            Ok(JoinHandleInner::Process(process)) => process.stdout(),
            Ok(JoinHandleInner::Pooled(_)) | Err(_) => None,
        }
    }

    /// Mutable access to the child's stderr pipe. Only process-backed
    /// handles can provide one; pooled and failed handles return `None`.
    pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
        match &mut self.inner {
            Ok(JoinHandleInner::Process(process)) => process.stderr(),
            Ok(JoinHandleInner::Pooled(_)) | Err(_) => None,
        }
    }
}
impl<T: Serialize + DeserializeOwned> JoinHandle<T> {
    /// Consumes the handle and waits for the call's result.
    ///
    /// # Errors
    ///
    /// Returns the stored spawn error if spawning failed, otherwise whatever
    /// the underlying handle's `join` reports.
    pub fn join(self) -> Result<T, SpawnError> {
        match self.inner? {
            JoinHandleInner::Process(mut handle) => handle.join(),
            JoinHandleInner::Pooled(mut handle) => handle.join(),
        }
    }

    /// Consumes the handle and waits up to `timeout` for the call's result.
    ///
    /// # Errors
    ///
    /// Returns the stored spawn error if spawning failed, otherwise whatever
    /// the underlying handle's `join_timeout` reports.
    pub fn join_timeout(self, timeout: Duration) -> Result<T, SpawnError> {
        match self.inner? {
            JoinHandleInner::Process(mut handle) => handle.join_timeout(timeout),
            JoinHandleInner::Pooled(mut handle) => handle.join_timeout(timeout),
        }
    }
}
/// Spawns `f(args)` in a child process using a default-configured
/// [`Builder`] and returns the handle to it.
pub fn spawn<A, R>(args: A, f: fn(A) -> R) -> JoinHandle<R>
where
    A: Serialize + DeserializeOwned,
    R: Serialize + DeserializeOwned,
{
    let mut builder = Builder::new();
    builder.spawn(args, f)
}