1use std::collections::HashMap;
2use std::ffi::{OsStr, OsString};
3use std::fmt;
4use std::path::PathBuf;
5use std::process::Stdio;
6use std::process::{ChildStderr, ChildStdin, ChildStdout};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use std::{env, mem, process};
11use std::{io, thread};
12
13use ipc_channel::ipc::{self, IpcOneShotServer, IpcReceiver, IpcSender};
14use serde::{de::DeserializeOwned, Serialize};
15
16use crate::core::{assert_spawn_okay, should_pass_args, MarshalledCall, ENV_NAME};
17use crate::error::{PanicInfo, SpawnError};
18use crate::pool::PooledHandle;
19use crate::serde::with_ipc_mode;
20
/// Trait-object type of a user hook handed to `CommandExt::pre_exec` for the child.
///
/// It must be shareable across threads because it is stored behind an `Arc`.
#[cfg(unix)]
type PreExecFunc = dyn Send + Sync + FnMut() -> io::Result<()> + 'static;
23
24#[derive(Clone)]
25pub struct ProcCommon {
26 pub vars: HashMap<OsString, OsString>,
27 #[cfg(unix)]
28 pub uid: Option<u32>,
29 #[cfg(unix)]
30 pub gid: Option<u32>,
31 #[cfg(unix)]
32 pub pre_exec: Option<Arc<std::sync::Mutex<Box<PreExecFunc>>>>,
33}
34
35impl fmt::Debug for ProcCommon {
36 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37 f.debug_struct("ProcCommon")
38 .field("vars", &self.vars)
39 .finish()
40 }
41}
42
43impl Default for ProcCommon {
44 fn default() -> ProcCommon {
45 ProcCommon {
46 vars: std::env::vars_os().collect(),
47 #[cfg(unix)]
48 uid: None,
49 #[cfg(unix)]
50 gid: None,
51 #[cfg(unix)]
52 pre_exec: None,
53 }
54 }
55}
56
57#[derive(Debug, Default)]
62pub struct Builder {
63 stdin: Option<Stdio>,
64 stdout: Option<Stdio>,
65 stderr: Option<Stdio>,
66 common: ProcCommon,
67}
68
// Generates the environment / credential / pre-exec setters shared by builder
// types. The expanding type must have a `common` field shaped like
// `ProcCommon`, and `OsStr`, `io` and `Arc` must be in scope at the expansion
// site.
macro_rules! define_common_methods {
    () => {
        /// Sets (or overrides) one environment variable for the child.
        pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
        where
            K: AsRef<OsStr>,
            V: AsRef<OsStr>,
        {
            let key = key.as_ref().to_os_string();
            let val = val.as_ref().to_os_string();
            self.common.vars.insert(key, val);
            self
        }

        /// Sets several environment variables; later pairs win over earlier ones.
        pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
        where
            I: IntoIterator<Item = (K, V)>,
            K: AsRef<OsStr>,
            V: AsRef<OsStr>,
        {
            for (key, val) in vars {
                self.env(key, val);
            }
            self
        }

        /// Removes one variable from the environment map for the child.
        pub fn env_remove<K: AsRef<OsStr>>(&mut self, key: K) -> &mut Self {
            let _ = self.common.vars.remove(key.as_ref());
            self
        }

        /// Empties the environment map for the child.
        pub fn env_clear(&mut self) -> &mut Self {
            self.common.vars.clear();
            self
        }

        /// Makes the child run with the given user id (Unix only).
        #[cfg(unix)]
        pub fn uid(&mut self, id: u32) -> &mut Self {
            self.common.uid = Some(id);
            self
        }

        /// Makes the child run with the given group id (Unix only).
        #[cfg(unix)]
        pub fn gid(&mut self, id: u32) -> &mut Self {
            self.common.gid = Some(id);
            self
        }

        /// Installs a hook that is forwarded to
        /// `std::os::unix::process::CommandExt::pre_exec` for the child.
        ///
        /// Calling this again replaces any previously installed hook.
        ///
        /// # Safety
        ///
        /// The hook runs in the forked child before `exec`, so the caller must
        /// uphold every requirement documented for `CommandExt::pre_exec`.
        #[cfg(unix)]
        pub unsafe fn pre_exec<F>(&mut self, f: F) -> &mut Self
        where
            F: FnMut() -> io::Result<()> + Send + Sync + 'static,
        {
            let boxed: Box<dyn FnMut() -> io::Result<()> + Send + Sync> = Box::new(f);
            self.common.pre_exec = Some(Arc::new(std::sync::Mutex::new(boxed)));
            self
        }
    };
}
161
162impl Builder {
163 pub fn new() -> Self {
166 Self {
167 stdin: None,
168 stdout: None,
169 stderr: None,
170 common: ProcCommon::default(),
171 }
172 }
173
174 pub(crate) fn common(&mut self, common: ProcCommon) -> &mut Self {
175 self.common = common;
176 self
177 }
178
179 define_common_methods!();
180
181 pub fn stdin<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
184 self.stdin = Some(cfg.into());
185 self
186 }
187
188 pub fn stdout<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
191 self.stdout = Some(cfg.into());
192 self
193 }
194
195 pub fn stderr<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
198 self.stderr = Some(cfg.into());
199 self
200 }
201
202 pub fn spawn<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
204 &mut self,
205 args: A,
206 func: fn(A) -> R,
207 ) -> JoinHandle<R> {
208 assert_spawn_okay();
209 JoinHandle {
210 inner: mem::take(self)
211 .spawn_helper(args, func)
212 .map(JoinHandleInner::Process),
213 }
214 }
215
216 fn spawn_helper<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
217 self,
218 args: A,
219 func: fn(A) -> R,
220 ) -> Result<ProcessHandle<R>, SpawnError> {
221 let (server, token) = IpcOneShotServer::<IpcSender<MarshalledCall>>::new()?;
222 let me = if cfg!(target_os = "linux") {
223 let path: PathBuf = "/proc/self/exe".into();
225 if path.is_file() {
226 path
227 } else {
228 env::current_exe()?
230 }
231 } else {
232 env::current_exe()?
233 };
234 let mut child = process::Command::new(me);
235 child.envs(self.common.vars);
236 child.env(ENV_NAME, token);
237
238 #[cfg(unix)]
239 {
240 use std::os::unix::process::CommandExt;
241 if let Some(id) = self.common.uid {
242 child.uid(id);
243 }
244 if let Some(id) = self.common.gid {
245 child.gid(id);
246 }
247 if let Some(ref func) = self.common.pre_exec {
248 let func = func.clone();
249 unsafe {
250 #[allow(clippy::needless_borrow)]
251 child.pre_exec(move || (&mut *func.lock().unwrap())());
252 }
253 }
254 }
255
256 let (can_pass_args, should_silence_stdout) = {
257 #[cfg(feature = "test-support")]
258 {
259 match crate::testsupport::update_command_for_tests(&mut child) {
260 None => (true, false),
261 Some(crate::testsupport::TestMode {
262 can_pass_args,
263 should_silence_stdout,
264 }) => (can_pass_args, should_silence_stdout),
265 }
266 }
267 #[cfg(not(feature = "test-support"))]
268 {
269 (true, false)
270 }
271 };
272
273 if can_pass_args && should_pass_args() {
274 child.args(env::args_os().skip(1));
275 }
276
277 if let Some(stdin) = self.stdin {
278 child.stdin(stdin);
279 }
280 if let Some(stdout) = self.stdout {
281 child.stdout(stdout);
282 } else if should_silence_stdout {
283 child.stdout(Stdio::null());
284 }
285 if let Some(stderr) = self.stderr {
286 child.stderr(stderr);
287 }
288 let process = child.spawn()?;
289
290 let (_rx, tx) = server.accept()?;
291
292 let (args_tx, args_rx) = ipc::channel()?;
293 let (return_tx, return_rx) = ipc::channel()?;
294
295 tx.send(MarshalledCall::marshal::<A, R>(func, args_rx, return_tx))?;
296 with_ipc_mode(|| -> Result<_, SpawnError> {
297 args_tx.send(args)?;
298 Ok(())
299 })?;
300
301 Ok(ProcessHandle {
302 recv: return_rx,
303 state: Arc::new(ProcessHandleState::new(Some(process.id()))),
304 process,
305 })
306 }
307}
308
/// Shared, thread-safe view of a child process's lifecycle.
#[derive(Debug)]
pub struct ProcessHandleState {
    /// Set once the process has been killed or waited for.
    pub exited: AtomicBool,
    /// Process id, with `0` meaning "unknown".
    pub pid: AtomicUsize,
}
314
315impl ProcessHandleState {
316 pub fn new(pid: Option<u32>) -> ProcessHandleState {
317 ProcessHandleState {
318 exited: AtomicBool::new(false),
319 pid: AtomicUsize::new(pid.unwrap_or(0) as usize),
320 }
321 }
322
323 pub fn pid(&self) -> Option<u32> {
324 match self.pid.load(Ordering::SeqCst) {
325 0 => None,
326 x => Some(x as u32),
327 }
328 }
329
330 pub fn kill(&self) {
331 if !self.exited.load(Ordering::SeqCst) {
332 self.exited.store(true, Ordering::SeqCst);
333 if let Some(pid) = self.pid() {
334 unsafe {
335 #[cfg(unix)]
336 {
337 libc::kill(pid as i32, libc::SIGKILL);
338 }
339 #[cfg(windows)]
340 {
341 use windows_sys::Win32::System::Threading;
342 let proc =
343 Threading::OpenProcess(Threading::PROCESS_ALL_ACCESS, 0, pid as _);
344 Threading::TerminateProcess(proc, 1);
345 }
346 }
347 }
348 }
349 }
350}
351
352pub struct ProcessHandle<T> {
353 pub(crate) recv: IpcReceiver<Result<T, PanicInfo>>,
354 pub(crate) process: process::Child,
355 pub(crate) state: Arc<ProcessHandleState>,
356}
357
358fn is_ipc_timeout(err: &ipc_channel::ipc::TryRecvError) -> bool {
359 matches!(err, ipc_channel::ipc::TryRecvError::Empty)
360}
361
362impl<T> ProcessHandle<T> {
363 pub fn state(&self) -> Arc<ProcessHandleState> {
364 self.state.clone()
365 }
366
367 pub fn kill(&mut self) -> Result<(), SpawnError> {
368 if self.state.exited.load(Ordering::SeqCst) {
369 return Ok(());
370 }
371
372 let rv = self.process.kill().map_err(Into::into);
373 self.wait();
374 rv
375 }
376
377 pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
378 self.process.stdin.as_mut()
379 }
380
381 pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
382 self.process.stdout.as_mut()
383 }
384
385 pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
386 self.process.stderr.as_mut()
387 }
388
389 fn wait(&mut self) {
390 self.process.wait().ok();
391 self.state.exited.store(true, Ordering::SeqCst);
392 }
393}
394
395impl<T: Serialize + DeserializeOwned> ProcessHandle<T> {
396 pub fn join(&mut self) -> Result<T, SpawnError> {
397 let rv = with_ipc_mode(|| self.recv.recv())?.map_err(Into::into);
398 self.wait();
399 rv
400 }
401
402 pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, SpawnError> {
403 let deadline = match Instant::now().checked_add(timeout) {
404 Some(deadline) => deadline,
405 None => {
406 return Err(io::Error::new(io::ErrorKind::Other, "timeout out of bounds").into())
407 }
408 };
409 let mut to_sleep = Duration::from_millis(1);
410 let rv = loop {
411 match with_ipc_mode(|| self.recv.try_recv()) {
412 Ok(rv) => break rv.map_err(Into::into),
413 Err(err) if is_ipc_timeout(&err) => {
414 if let Some(remaining) = deadline.checked_duration_since(Instant::now()) {
415 thread::sleep(remaining.min(to_sleep));
416 to_sleep *= 2;
417 } else {
418 return Err(SpawnError::new_timeout());
419 }
420 }
421 Err(err) => return Err(err.into()),
422 }
423 };
424
425 self.wait();
426 rv
427 }
428}
429
430pub enum JoinHandleInner<T> {
431 Process(ProcessHandle<T>),
432 Pooled(PooledHandle<T>),
433}
434
435pub struct JoinHandle<T> {
440 pub(crate) inner: Result<JoinHandleInner<T>, SpawnError>,
441}
442
443impl<T> fmt::Debug for JoinHandle<T> {
444 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
445 f.debug_struct("JoinHandle")
446 .field("pid", &self.pid())
447 .finish()
448 }
449}
450
451impl<T> JoinHandle<T> {
452 pub(crate) fn process_handle_state(&self) -> Option<Arc<ProcessHandleState>> {
453 match self.inner {
454 Ok(JoinHandleInner::Process(ref handle)) => Some(handle.state()),
455 Ok(JoinHandleInner::Pooled(ref handle)) => handle.process_handle_state(),
456 Err(..) => None,
457 }
458 }
459
460 pub fn pid(&self) -> Option<u32> {
465 self.process_handle_state().and_then(|x| x.pid())
466 }
467
468 pub fn kill(&mut self) -> Result<(), SpawnError> {
477 match self.inner {
478 Ok(JoinHandleInner::Process(ref mut handle)) => handle.kill(),
479 Ok(JoinHandleInner::Pooled(ref mut handle)) => handle.kill(),
480 Err(_) => Ok(()),
481 }
482 }
483
484 pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
486 match self.inner {
487 Ok(JoinHandleInner::Process(ref mut process)) => process.stdin(),
488 Ok(JoinHandleInner::Pooled(..)) => None,
489 Err(_) => None,
490 }
491 }
492
493 pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
495 match self.inner {
496 Ok(JoinHandleInner::Process(ref mut process)) => process.stdout(),
497 Ok(JoinHandleInner::Pooled(..)) => None,
498 Err(_) => None,
499 }
500 }
501
502 pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
504 match self.inner {
505 Ok(JoinHandleInner::Process(ref mut process)) => process.stderr(),
506 Ok(JoinHandleInner::Pooled(..)) => None,
507 Err(_) => None,
508 }
509 }
510}
511
512impl<T: Serialize + DeserializeOwned> JoinHandle<T> {
513 pub fn join(self) -> Result<T, SpawnError> {
517 match self.inner {
518 Ok(JoinHandleInner::Process(mut handle)) => handle.join(),
519 Ok(JoinHandleInner::Pooled(mut handle)) => handle.join(),
520 Err(err) => Err(err),
521 }
522 }
523
524 pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, SpawnError> {
530 match self.inner {
531 Ok(ref mut handle_inner) => {
532 let result = match handle_inner {
533 JoinHandleInner::Process(ref mut handle) => handle.join_timeout(timeout),
534 JoinHandleInner::Pooled(ref mut handle) => handle.join_timeout(timeout),
535 };
536
537 if result.is_ok() {
538 self.inner = Err(SpawnError::new_consumed());
539 }
540
541 result
542 }
543 Err(ref mut err) => {
544 let mut rv_err = SpawnError::new_consumed();
545 mem::swap(&mut rv_err, err);
546 Err(rv_err)
547 }
548 }
549 }
550}
551
552pub fn spawn<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
567 args: A,
568 f: fn(A) -> R,
569) -> JoinHandle<R> {
570 Builder::new().spawn(args, f)
571}