procspawn/
proc.rs

1use std::collections::HashMap;
2use std::ffi::{OsStr, OsString};
3use std::fmt;
4use std::path::PathBuf;
5use std::process::Stdio;
6use std::process::{ChildStderr, ChildStdin, ChildStdout};
7use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use std::{env, mem, process};
11use std::{io, thread};
12
13use ipc_channel::ipc::{self, IpcOneShotServer, IpcReceiver, IpcSender};
14use serde::{de::DeserializeOwned, Serialize};
15
16use crate::core::{assert_spawn_okay, should_pass_args, MarshalledCall, ENV_NAME};
17use crate::error::{PanicInfo, SpawnError};
18use crate::pool::PooledHandle;
19use crate::serde::with_ipc_mode;
20
/// Trait-object type of the closures stored for `pre_exec`: an `FnMut` that
/// reports failure through `io::Result` and is `Send + Sync + 'static` so it
/// can be kept behind a shared, mutex-protected pointer.
#[cfg(unix)]
type PreExecFunc = dyn FnMut() -> io::Result<()> + Send + Sync + 'static;
23
/// Spawn settings that are stored inside a `Builder` and edited through the
/// methods generated by `define_common_methods!`.
///
/// The type is `Clone`; the `pre_exec` hook is shared between clones because
/// it lives behind an `Arc`.
#[derive(Clone)]
pub struct ProcCommon {
    /// Environment variables passed to the child process. `Default` fills
    /// this with a snapshot of the current process environment.
    pub vars: HashMap<OsString, OsString>,
    /// User ID to set on the child process, if any (unix only).
    #[cfg(unix)]
    pub uid: Option<u32>,
    /// Group ID to set on the child process, if any (unix only).
    #[cfg(unix)]
    pub gid: Option<u32>,
    /// Optional closure registered via `CommandExt::pre_exec` when spawning
    /// (unix only). Wrapped in a mutex because the closure is `FnMut`.
    #[cfg(unix)]
    pub pre_exec: Option<Arc<std::sync::Mutex<Box<PreExecFunc>>>>,
}
34
35impl fmt::Debug for ProcCommon {
36    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37        f.debug_struct("ProcCommon")
38            .field("vars", &self.vars)
39            .finish()
40    }
41}
42
43impl Default for ProcCommon {
44    fn default() -> ProcCommon {
45        ProcCommon {
46            vars: std::env::vars_os().collect(),
47            #[cfg(unix)]
48            uid: None,
49            #[cfg(unix)]
50            gid: None,
51            #[cfg(unix)]
52            pre_exec: None,
53        }
54    }
55}
56
/// Process factory, which can be used in order to configure the properties
/// of a process being created.
///
/// Methods can be chained on it in order to configure it.
#[derive(Debug, Default)]
pub struct Builder {
    // Optional stdio configuration for the child. `None` means the
    // corresponding `Command` setting is left untouched when spawning.
    stdin: Option<Stdio>,
    stdout: Option<Stdio>,
    stderr: Option<Stdio>,
    // Environment and (on unix) uid / gid / pre-exec settings.
    common: ProcCommon,
}
68
// Expands to the configuration methods that operate on a `common` field.
// The invoking `impl` must belong to a type whose `self.common` exposes
// `vars` (and, on unix, `uid`, `gid` and `pre_exec`).
macro_rules! define_common_methods {
    () => {
        /// Adds or overrides a single environment variable for the spawned
        /// process.
        ///
        /// Equivalent to `Command::env`
        pub fn env<K, V>(&mut self, key: K, val: V) -> &mut Self
        where
            K: AsRef<OsStr>,
            V: AsRef<OsStr>,
        {
            let name = key.as_ref().to_os_string();
            let value = val.as_ref().to_os_string();
            self.common.vars.insert(name, value);
            self
        }

        /// Adds or overrides several environment variables for the spawned
        /// process. Later pairs win over earlier ones with the same key.
        ///
        /// Equivalent to `Command::envs`
        pub fn envs<I, K, V>(&mut self, vars: I) -> &mut Self
        where
            I: IntoIterator<Item = (K, V)>,
            K: AsRef<OsStr>,
            V: AsRef<OsStr>,
        {
            for (k, v) in vars {
                self.common
                    .vars
                    .insert(k.as_ref().to_os_string(), v.as_ref().to_os_string());
            }
            self
        }

        /// Drops one environment variable from the set configured for the
        /// spawned process.
        ///
        /// Equivalent to `Command::env_remove`
        pub fn env_remove<K: AsRef<OsStr>>(&mut self, key: K) -> &mut Self {
            let name = key.as_ref();
            self.common.vars.remove(name);
            self
        }

        /// Drops every environment variable configured for the spawned
        /// process.
        ///
        /// Equivalent to `Command::env_clear`
        pub fn env_clear(&mut self) -> &mut Self {
            self.common.vars.clear();
            self
        }

        /// Sets the child process's user ID. This translates to a
        /// `setuid` call in the child process. Failure in the `setuid`
        /// call will cause the spawn to fail.
        ///
        /// Unix-specific extension only available on unix.
        ///
        /// Equivalent to `std::os::unix::process::CommandExt::uid`
        #[cfg(unix)]
        pub fn uid(&mut self, id: u32) -> &mut Self {
            self.common.uid = Some(id);
            self
        }

        /// Sets the child process's group ID, with the same semantics as
        /// `uid`.
        ///
        /// Unix-specific extension only available on unix.
        ///
        /// Equivalent to `std::os::unix::process::CommandExt::gid`
        #[cfg(unix)]
        pub fn gid(&mut self, id: u32) -> &mut Self {
            self.common.gid = Some(id);
            self
        }

        /// Registers a closure to run just before the `exec` function is
        /// invoked. A previously registered closure is replaced.
        ///
        /// # Safety
        ///
        /// This method is inherently unsafe.  See the notes of the unix command
        /// ext for more information.
        ///
        /// Equivalent to `std::os::unix::process::CommandExt::pre_exec`
        #[cfg(unix)]
        pub unsafe fn pre_exec<F>(&mut self, f: F) -> &mut Self
        where
            F: FnMut() -> io::Result<()> + Send + Sync + 'static,
        {
            // Erase the concrete closure type first so the shared pointer
            // holds a trait object.
            let hook: Box<dyn FnMut() -> io::Result<()> + Send + Sync + 'static> = Box::new(f);
            self.common.pre_exec = Some(Arc::new(std::sync::Mutex::new(hook)));
            self
        }
    };
}
161
162impl Builder {
163    /// Generates the base configuration for spawning a thread, from which
164    /// configuration methods can be chained.
165    pub fn new() -> Self {
166        Self {
167            stdin: None,
168            stdout: None,
169            stderr: None,
170            common: ProcCommon::default(),
171        }
172    }
173
174    pub(crate) fn common(&mut self, common: ProcCommon) -> &mut Self {
175        self.common = common;
176        self
177    }
178
179    define_common_methods!();
180
181    /// Captures the `stdin` of the spawned process, allowing you to manually
182    /// send data via `JoinHandle::stdin`
183    pub fn stdin<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
184        self.stdin = Some(cfg.into());
185        self
186    }
187
188    /// Captures the `stdout` of the spawned process, allowing you to manually
189    /// receive data via `JoinHandle::stdout`
190    pub fn stdout<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
191        self.stdout = Some(cfg.into());
192        self
193    }
194
195    /// Captures the `stderr` of the spawned process, allowing you to manually
196    /// receive data via `JoinHandle::stderr`
197    pub fn stderr<T: Into<Stdio>>(&mut self, cfg: T) -> &mut Self {
198        self.stderr = Some(cfg.into());
199        self
200    }
201
202    /// Spawns the process.
203    pub fn spawn<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
204        &mut self,
205        args: A,
206        func: fn(A) -> R,
207    ) -> JoinHandle<R> {
208        assert_spawn_okay();
209        JoinHandle {
210            inner: mem::take(self)
211                .spawn_helper(args, func)
212                .map(JoinHandleInner::Process),
213        }
214    }
215
216    fn spawn_helper<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
217        self,
218        args: A,
219        func: fn(A) -> R,
220    ) -> Result<ProcessHandle<R>, SpawnError> {
221        let (server, token) = IpcOneShotServer::<IpcSender<MarshalledCall>>::new()?;
222        let me = if cfg!(target_os = "linux") {
223            // will work even if exe is moved
224            let path: PathBuf = "/proc/self/exe".into();
225            if path.is_file() {
226                path
227            } else {
228                // might not exist, e.g. on chroot
229                env::current_exe()?
230            }
231        } else {
232            env::current_exe()?
233        };
234        let mut child = process::Command::new(me);
235        child.envs(self.common.vars);
236        child.env(ENV_NAME, token);
237
238        #[cfg(unix)]
239        {
240            use std::os::unix::process::CommandExt;
241            if let Some(id) = self.common.uid {
242                child.uid(id);
243            }
244            if let Some(id) = self.common.gid {
245                child.gid(id);
246            }
247            if let Some(ref func) = self.common.pre_exec {
248                let func = func.clone();
249                unsafe {
250                    #[allow(clippy::needless_borrow)]
251                    child.pre_exec(move || (&mut *func.lock().unwrap())());
252                }
253            }
254        }
255
256        let (can_pass_args, should_silence_stdout) = {
257            #[cfg(feature = "test-support")]
258            {
259                match crate::testsupport::update_command_for_tests(&mut child) {
260                    None => (true, false),
261                    Some(crate::testsupport::TestMode {
262                        can_pass_args,
263                        should_silence_stdout,
264                    }) => (can_pass_args, should_silence_stdout),
265                }
266            }
267            #[cfg(not(feature = "test-support"))]
268            {
269                (true, false)
270            }
271        };
272
273        if can_pass_args && should_pass_args() {
274            child.args(env::args_os().skip(1));
275        }
276
277        if let Some(stdin) = self.stdin {
278            child.stdin(stdin);
279        }
280        if let Some(stdout) = self.stdout {
281            child.stdout(stdout);
282        } else if should_silence_stdout {
283            child.stdout(Stdio::null());
284        }
285        if let Some(stderr) = self.stderr {
286            child.stderr(stderr);
287        }
288        let process = child.spawn()?;
289
290        let (_rx, tx) = server.accept()?;
291
292        let (args_tx, args_rx) = ipc::channel()?;
293        let (return_tx, return_rx) = ipc::channel()?;
294
295        tx.send(MarshalledCall::marshal::<A, R>(func, args_rx, return_tx))?;
296        with_ipc_mode(|| -> Result<_, SpawnError> {
297            args_tx.send(args)?;
298            Ok(())
299        })?;
300
301        Ok(ProcessHandle {
302            recv: return_rx,
303            state: Arc::new(ProcessHandleState::new(Some(process.id()))),
304            process,
305        })
306    }
307}
308
/// Atomically updatable state of a process: whether it has been marked as
/// exited and its process ID.
#[derive(Debug)]
pub struct ProcessHandleState {
    /// Set to `true` once the process has been waited on or a kill was
    /// issued through this state.
    pub exited: AtomicBool,
    /// Process ID stored as `usize`; `0` encodes "no pid available".
    pub pid: AtomicUsize,
}
314
315impl ProcessHandleState {
316    pub fn new(pid: Option<u32>) -> ProcessHandleState {
317        ProcessHandleState {
318            exited: AtomicBool::new(false),
319            pid: AtomicUsize::new(pid.unwrap_or(0) as usize),
320        }
321    }
322
323    pub fn pid(&self) -> Option<u32> {
324        match self.pid.load(Ordering::SeqCst) {
325            0 => None,
326            x => Some(x as u32),
327        }
328    }
329
330    pub fn kill(&self) {
331        if !self.exited.load(Ordering::SeqCst) {
332            self.exited.store(true, Ordering::SeqCst);
333            if let Some(pid) = self.pid() {
334                unsafe {
335                    #[cfg(unix)]
336                    {
337                        libc::kill(pid as i32, libc::SIGKILL);
338                    }
339                    #[cfg(windows)]
340                    {
341                        use windows_sys::Win32::System::Threading;
342                        let proc =
343                            Threading::OpenProcess(Threading::PROCESS_ALL_ACCESS, 0, pid as _);
344                        Threading::TerminateProcess(proc, 1);
345                    }
346                }
347            }
348        }
349    }
350}
351
/// Handle to a single spawned child process and the channel its result is
/// received on.
pub struct ProcessHandle<T> {
    /// Receives either the function's return value or panic information.
    pub(crate) recv: IpcReceiver<Result<T, PanicInfo>>,
    /// The child process itself.
    pub(crate) process: process::Child,
    /// Exit flag and pid, handed out to other holders via `state()`.
    pub(crate) state: Arc<ProcessHandleState>,
}
357
358fn is_ipc_timeout(err: &ipc_channel::ipc::TryRecvError) -> bool {
359    matches!(err, ipc_channel::ipc::TryRecvError::Empty)
360}
361
362impl<T> ProcessHandle<T> {
363    pub fn state(&self) -> Arc<ProcessHandleState> {
364        self.state.clone()
365    }
366
367    pub fn kill(&mut self) -> Result<(), SpawnError> {
368        if self.state.exited.load(Ordering::SeqCst) {
369            return Ok(());
370        }
371
372        let rv = self.process.kill().map_err(Into::into);
373        self.wait();
374        rv
375    }
376
377    pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
378        self.process.stdin.as_mut()
379    }
380
381    pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
382        self.process.stdout.as_mut()
383    }
384
385    pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
386        self.process.stderr.as_mut()
387    }
388
389    fn wait(&mut self) {
390        self.process.wait().ok();
391        self.state.exited.store(true, Ordering::SeqCst);
392    }
393}
394
395impl<T: Serialize + DeserializeOwned> ProcessHandle<T> {
396    pub fn join(&mut self) -> Result<T, SpawnError> {
397        let rv = with_ipc_mode(|| self.recv.recv())?.map_err(Into::into);
398        self.wait();
399        rv
400    }
401
402    pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, SpawnError> {
403        let deadline = match Instant::now().checked_add(timeout) {
404            Some(deadline) => deadline,
405            None => {
406                return Err(io::Error::new(io::ErrorKind::Other, "timeout out of bounds").into())
407            }
408        };
409        let mut to_sleep = Duration::from_millis(1);
410        let rv = loop {
411            match with_ipc_mode(|| self.recv.try_recv()) {
412                Ok(rv) => break rv.map_err(Into::into),
413                Err(err) if is_ipc_timeout(&err) => {
414                    if let Some(remaining) = deadline.checked_duration_since(Instant::now()) {
415                        thread::sleep(remaining.min(to_sleep));
416                        to_sleep *= 2;
417                    } else {
418                        return Err(SpawnError::new_timeout());
419                    }
420                }
421                Err(err) => return Err(err.into()),
422            }
423        };
424
425        self.wait();
426        rv
427    }
428}
429
/// The two kinds of handle a `JoinHandle` can wrap.
pub enum JoinHandleInner<T> {
    /// Handle to a process spawned directly via `Builder`.
    Process(ProcessHandle<T>),
    /// Handle obtained from `crate::pool`.
    Pooled(PooledHandle<T>),
}
434
/// An owned permission to join on a process (block on its termination).
///
/// The join handle can be used to join a process but also provides the
/// ability to kill it.
pub struct JoinHandle<T> {
    /// `Ok` with the underlying handle, or `Err` holding the spawn error or
    /// the "consumed" marker left behind by `join_timeout`.
    pub(crate) inner: Result<JoinHandleInner<T>, SpawnError>,
}
442
443impl<T> fmt::Debug for JoinHandle<T> {
444    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
445        f.debug_struct("JoinHandle")
446            .field("pid", &self.pid())
447            .finish()
448    }
449}
450
451impl<T> JoinHandle<T> {
452    pub(crate) fn process_handle_state(&self) -> Option<Arc<ProcessHandleState>> {
453        match self.inner {
454            Ok(JoinHandleInner::Process(ref handle)) => Some(handle.state()),
455            Ok(JoinHandleInner::Pooled(ref handle)) => handle.process_handle_state(),
456            Err(..) => None,
457        }
458    }
459
460    /// Returns the process ID if available.
461    ///
462    /// The process ID is unavailable when pooled calls are not scheduled to
463    /// processes.
464    pub fn pid(&self) -> Option<u32> {
465        self.process_handle_state().and_then(|x| x.pid())
466    }
467
468    /// Kill the child process.
469    ///
470    /// If the join handle was created from a pool this call will do one of
471    /// two things depending on the situation:
472    ///
473    /// * if the call was already picked up by the process, the process will
474    ///   be killed.
475    /// * if the call was not yet scheduled to a process it will be cancelled.
476    pub fn kill(&mut self) -> Result<(), SpawnError> {
477        match self.inner {
478            Ok(JoinHandleInner::Process(ref mut handle)) => handle.kill(),
479            Ok(JoinHandleInner::Pooled(ref mut handle)) => handle.kill(),
480            Err(_) => Ok(()),
481        }
482    }
483
484    /// Fetch the `stdin` handle if it has been captured
485    pub fn stdin(&mut self) -> Option<&mut ChildStdin> {
486        match self.inner {
487            Ok(JoinHandleInner::Process(ref mut process)) => process.stdin(),
488            Ok(JoinHandleInner::Pooled(..)) => None,
489            Err(_) => None,
490        }
491    }
492
493    /// Fetch the `stdout` handle if it has been captured
494    pub fn stdout(&mut self) -> Option<&mut ChildStdout> {
495        match self.inner {
496            Ok(JoinHandleInner::Process(ref mut process)) => process.stdout(),
497            Ok(JoinHandleInner::Pooled(..)) => None,
498            Err(_) => None,
499        }
500    }
501
502    /// Fetch the `stderr` handle if it has been captured
503    pub fn stderr(&mut self) -> Option<&mut ChildStderr> {
504        match self.inner {
505            Ok(JoinHandleInner::Process(ref mut process)) => process.stderr(),
506            Ok(JoinHandleInner::Pooled(..)) => None,
507            Err(_) => None,
508        }
509    }
510}
511
512impl<T: Serialize + DeserializeOwned> JoinHandle<T> {
513    /// Wait for the child process to return a result.
514    ///
515    /// If the join handle was created from a pool the join is virtualized.
516    pub fn join(self) -> Result<T, SpawnError> {
517        match self.inner {
518            Ok(JoinHandleInner::Process(mut handle)) => handle.join(),
519            Ok(JoinHandleInner::Pooled(mut handle)) => handle.join(),
520            Err(err) => Err(err),
521        }
522    }
523
524    /// Like `join` but with a timeout.
525    ///
526    /// Can be called multiple times. If anything other than a timeout error is returned, the
527    /// handle becomes unusuable, and subsequent calls to either `join` or `join_timeout` will
528    /// return an error.
529    pub fn join_timeout(&mut self, timeout: Duration) -> Result<T, SpawnError> {
530        match self.inner {
531            Ok(ref mut handle_inner) => {
532                let result = match handle_inner {
533                    JoinHandleInner::Process(ref mut handle) => handle.join_timeout(timeout),
534                    JoinHandleInner::Pooled(ref mut handle) => handle.join_timeout(timeout),
535                };
536
537                if result.is_ok() {
538                    self.inner = Err(SpawnError::new_consumed());
539                }
540
541                result
542            }
543            Err(ref mut err) => {
544                let mut rv_err = SpawnError::new_consumed();
545                mem::swap(&mut rv_err, err);
546                Err(rv_err)
547            }
548        }
549    }
550}
551
552/// Spawn a new process to run a function with some payload.
553///
554/// ```rust,no_run
555/// // call this early in your main() function.  This is where all spawned
556/// // functions will be invoked.
557/// procspawn::init();
558///
559/// let data = vec![1, 2, 3, 4];
560/// let handle = procspawn::spawn(data, |data| {
561///     println!("Received data {:?}", &data);
562///     data.into_iter().sum::<i64>()
563/// });
564/// let result = handle.join().unwrap();
565/// ```
566pub fn spawn<A: Serialize + DeserializeOwned, R: Serialize + DeserializeOwned>(
567    args: A,
568    f: fn(A) -> R,
569) -> JoinHandle<R> {
570    Builder::new().spawn(args, f)
571}