// process_wrap/tokio/job_object.rs

1use std::{future::Future, io::Result, pin::Pin, process::ExitStatus, time::Duration};
2
3use tokio::{process::Command, task::spawn_blocking};
4#[cfg(feature = "tracing")]
5use tracing::{debug, instrument};
6use windows::Win32::{
7	Foundation::{CloseHandle, HANDLE},
8	System::Threading::CREATE_SUSPENDED,
9};
10
11use crate::{
12	windows::{make_job_object, resume_threads, terminate_job, wait_on_job, JobPort},
13	ChildExitStatus,
14};
15
16#[cfg(feature = "creation-flags")]
17use super::CreationFlags;
18#[cfg(feature = "kill-on-drop")]
19use super::KillOnDrop;
20use super::{ChildWrapper, CommandWrap, CommandWrapper};
21
/// Command wrapper that runs the spawned process inside a Windows Job Object.
///
/// Windows only.
///
/// The [`Command`] is associated with a newly created
/// [Job Object](https://docs.microsoft.com/en-us/windows/win32/procthread/job-objects). A job
/// object plays a role similar to a process group on Unix (or even a cgroup on Linux), and can
/// also be used to restrict resource usage.
///
/// The corresponding child wrapper is [`JobObjectChild`].
///
/// # Interaction with [`CreationFlags`]
///
/// If [`CreationFlags`] is used together with [`JobObject`], one of these must hold:
/// - `CreationFlags` comes first, or
/// - `CreationFlags` itself includes `CREATE_SUSPENDED`.
#[derive(Clone, Copy, Debug)]
pub struct JobObject;
37
38impl CommandWrapper for JobObject {
39	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
40	fn pre_spawn(&mut self, command: &mut Command, core: &CommandWrap) -> Result<()> {
41		let mut flags = CREATE_SUSPENDED;
42		#[cfg(feature = "creation-flags")]
43		if let Some(CreationFlags(user_flags)) = core.get_wrap::<CreationFlags>() {
44			flags |= *user_flags;
45		}
46
47		command.creation_flags(flags.0);
48		Ok(())
49	}
50
51	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
52	fn wrap_child(
53		&mut self,
54		inner: Box<dyn ChildWrapper>,
55		core: &CommandWrap,
56	) -> Result<Box<dyn ChildWrapper>> {
57		#[cfg(feature = "kill-on-drop")]
58		let kill_on_drop = core.has_wrap::<KillOnDrop>();
59		#[cfg(not(feature = "kill-on-drop"))]
60		let kill_on_drop = false;
61
62		#[cfg(feature = "creation-flags")]
63		let create_suspended = core
64			.get_wrap::<CreationFlags>()
65			.map_or(false, |flags| flags.0.contains(CREATE_SUSPENDED));
66		#[cfg(not(feature = "creation-flags"))]
67		let create_suspended = false;
68
69		#[cfg(feature = "tracing")]
70		debug!(
71			?kill_on_drop,
72			?create_suspended,
73			"options from other wrappers"
74		);
75
76		let handle = HANDLE(
77			inner
78				.inner_child()
79				.raw_handle()
80				.expect("child has exited but it has not even started") as _,
81		);
82
83		let job_port = make_job_object(handle, kill_on_drop)?;
84
85		// only resume if the user didn't specify CREATE_SUSPENDED
86		if !create_suspended {
87			resume_threads(handle)?;
88		}
89
90		Ok(Box::new(JobObjectChild::new(inner, job_port)))
91	}
92}
93
94/// Wrapper for `Child` which waits on all processes within the job.
95#[derive(Debug)]
96pub struct JobObjectChild {
97	inner: Box<dyn ChildWrapper>,
98	exit_status: ChildExitStatus,
99	job_port: JobPort,
100}
101
102impl JobObjectChild {
103	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(job_port)))]
104	pub(crate) fn new(inner: Box<dyn ChildWrapper>, job_port: JobPort) -> Self {
105		Self {
106			inner,
107			exit_status: ChildExitStatus::Running,
108			job_port,
109		}
110	}
111}
112
113impl ChildWrapper for JobObjectChild {
114	fn inner(&self) -> &dyn ChildWrapper {
115		self.inner.inner()
116	}
117	fn inner_mut(&mut self) -> &mut dyn ChildWrapper {
118		self.inner.inner_mut()
119	}
120	fn into_inner(self: Box<Self>) -> Box<dyn ChildWrapper> {
121		// manually drop the completion port
122		let its = std::mem::ManuallyDrop::new(self.job_port);
123		unsafe { CloseHandle(its.completion_port.0) }.ok();
124		// we leave the job handle unclosed, otherwise the Child is useless
125		// (as closing it will terminate the job)
126
127		self.inner.into_inner()
128	}
129
130	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
131	fn start_kill(&mut self) -> Result<()> {
132		terminate_job(self.job_port.job, 1)
133	}
134
135	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
136	fn wait(&mut self) -> Pin<Box<dyn Future<Output = Result<ExitStatus>> + Send + '_>> {
137		Box::pin(async {
138			if let ChildExitStatus::Exited(status) = &self.exit_status {
139				return Ok(*status);
140			}
141
142			const MAX_RETRY_ATTEMPT: usize = 10;
143
144			// always wait for parent to exit first, as by the time it does,
145			// it's likely that all its children have already exited.
146			let status = self.inner.wait().await?;
147			self.exit_status = ChildExitStatus::Exited(status);
148
149			// nevertheless, now try reaping all children a few times...
150			for _ in 1..MAX_RETRY_ATTEMPT {
151				if wait_on_job(self.job_port.completion_port, Some(Duration::ZERO))?.is_break() {
152					return Ok(status);
153				}
154			}
155
156			// ...finally, if there are some that are still alive,
157			// block in the background to reap them fully.
158			let JobPort {
159				completion_port, ..
160			} = self.job_port;
161			let _ = spawn_blocking(move || wait_on_job(completion_port, None)).await??;
162			Ok(status)
163		})
164	}
165
166	#[cfg_attr(feature = "tracing", instrument(level = "debug", skip(self)))]
167	fn try_wait(&mut self) -> Result<Option<ExitStatus>> {
168		let _ = wait_on_job(self.job_port.completion_port, Some(Duration::ZERO))?;
169		self.inner.try_wait()
170	}
171}