process_wrap/std/
job_object.rs

1use std::{
2	io::Result,
3	os::windows::{io::AsRawHandle, process::CommandExt},
4	process::{Command, ExitStatus},
5	time::Duration,
6};
7
8#[cfg(feature = "tracing")]
9use tracing::{debug, instrument};
10use windows::Win32::{
11	Foundation::{CloseHandle, HANDLE},
12	System::Threading::CREATE_SUSPENDED,
13};
14
15use crate::{
16	windows::{make_job_object, resume_threads, terminate_job, wait_on_job, JobPort},
17	ChildExitStatus,
18};
19
20#[cfg(feature = "creation-flags")]
21use super::CreationFlags;
22use super::{ChildWrapper, CommandWrap, CommandWrapper};
23
/// Wrapper that places the spawned [`Command`] inside a Windows Job Object.
///
/// Only available on Windows.
///
/// A Job Object plays a role similar to a process group on Unix (or even a cgroup on Linux),
/// and can additionally be used to restrict resource usage.
/// See [Job Objects](https://docs.microsoft.com/en-us/windows/win32/procthread/job-objects).
///
/// The child wrapper produced by this command wrapper is [`JobObjectChild`].
///
/// If [`CreationFlags`] is combined with [`JobObject`], one of the following must hold:
/// - `CreationFlags` is added before `JobObject`, or
/// - the flags passed to `CreationFlags` include `CREATE_SUSPENDED`.
#[derive(Debug, Clone, Copy)]
pub struct JobObject;
39
40impl CommandWrapper for JobObject {
41	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
42	fn pre_spawn(&mut self, command: &mut Command, core: &CommandWrap) -> Result<()> {
43		let mut flags = CREATE_SUSPENDED;
44		#[cfg(feature = "creation-flags")]
45		if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>() {
46			flags |= *user_flags;
47		}
48
49		command.creation_flags(flags.0);
50		Ok(())
51	}
52
53	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
54	fn wrap_child(
55		&mut self,
56		inner: Box<dyn ChildWrapper>,
57		core: &CommandWrap,
58	) -> Result<Box<dyn ChildWrapper>> {
59		#[cfg(feature = "creation-flags")]
60		let create_suspended = core
61			.get_wrap::<CreationFlags>()
62			.map_or(false, |flags| flags.0.contains(CREATE_SUSPENDED));
63		#[cfg(not(feature = "creation-flags"))]
64		let create_suspended = false;
65
66		#[cfg(feature = "tracing")]
67		debug!(?create_suspended, "options from other wrappers");
68
69		let handle = HANDLE(inner.inner_child().as_raw_handle() as _);
70
71		let job_port = make_job_object(handle, false)?;
72
73		// only resume if the user didn't specify CREATE_SUSPENDED
74		if !create_suspended {
75			resume_threads(handle)?;
76		}
77
78		Ok(Box::new(JobObjectChild::new(inner, job_port)))
79	}
80}
81
82/// Wrapper for `Child` which waits on all processes within the job.
83#[derive(Debug)]
84pub struct JobObjectChild {
85	inner: Box<dyn ChildWrapper>,
86	exit_status: ChildExitStatus,
87	job_port: JobPort,
88}
89
90impl JobObjectChild {
91	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(job_port)))]
92	pub(crate) fn new(inner: Box<dyn ChildWrapper>, job_port: JobPort) -> Self {
93		Self {
94			inner,
95			exit_status: ChildExitStatus::Running,
96			job_port,
97		}
98	}
99}
100
101impl ChildWrapper for JobObjectChild {
102	fn inner(&self) -> &dyn ChildWrapper {
103		self.inner.inner()
104	}
105	fn inner_mut(&mut self) -> &mut dyn ChildWrapper {
106		self.inner.inner_mut()
107	}
108	fn into_inner(self: Box<Self>) -> Box<dyn ChildWrapper> {
109		// manually drop the completion port
110		let its = std::mem::ManuallyDrop::new(self.job_port);
111		unsafe { CloseHandle(its.completion_port.0) }.ok();
112		// we leave the job handle unclosed, otherwise the Child is useless
113		// (as closing it will terminate the job)
114
115		self.inner.into_inner()
116	}
117
118	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
119	fn start_kill(&mut self) -> Result<()> {
120		terminate_job(self.job_port.job, 1)
121	}
122
123	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
124	fn wait(&mut self) -> Result<ExitStatus> {
125		if let ChildExitStatus::Exited(status) = &self.exit_status {
126			return Ok(*status);
127		}
128
129		// always wait for parent to exit first, as by the time it does,
130		// it's likely that all its children have already exited.
131		let status = self.inner.wait()?;
132		self.exit_status = ChildExitStatus::Exited(status);
133
134		// nevertheless, now wait and make sure we reap all children.
135		let JobPort {
136			completion_port, ..
137		} = self.job_port;
138		let _ = wait_on_job(completion_port, None)?;
139		Ok(status)
140	}
141
142	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
143	fn try_wait(&mut self) -> Result<Option<ExitStatus>> {
144		let _ = wait_on_job(self.job_port.completion_port, Some(Duration::ZERO))?;
145		self.inner.try_wait()
146	}
147}