1use std::{collections::HashMap, sync::mpsc};
2
3use crate::{
4 errors::{TogetherError, TogetherInternalError, TogetherResult},
5 log, log_err,
6 process::{Process, ProcessId, ProcessSignal, ProcessStdio},
7};
8
9pub enum ProcessAction {
10 Create(String),
11 CreateAdvanced(String, CreateOptions),
12 Wait(ProcessId),
13 Kill(ProcessId),
14 KillAdvanced(ProcessId, ProcessSignal),
15 KillAll,
16 List,
17}
18
19#[derive(Debug)]
20pub enum ProcessActionResponse {
21 Created(ProcessId),
22 Waited(mpsc::Receiver<()>),
23 Killed,
24 KilledAll,
25 List(Vec<ProcessId>),
26 Error(ProcessManagerError),
27}
28
/// Reasons a manager request can fail.
#[derive(Debug)]
pub enum ProcessManagerError {
    /// Spawning the child failed; carries the underlying error message.
    SpawnChildFailed(String),
    /// Killing a child failed; carries the underlying error message.
    KillChildFailed(String),
    /// No tracked process has the requested id.
    NoSuchProcess,
    /// A failure without a more specific description (e.g. one or more kills
    /// failed during `KillAll`).
    Unknown,
}
36
37#[derive(Default, Clone)]
38pub struct CreateOptions {
39 pub stdio: Option<ProcessStdio>,
40 pub cwd: Option<String>,
41}
42
43impl CreateOptions {
44 pub fn with_stderr_only(mut self) -> Self {
45 self.stdio = Some(ProcessStdio::StderrOnly);
46 self
47 }
48}
49
50pub struct Message(ProcessAction, mpsc::Sender<ProcessActionResponse>);
51
52pub struct ProcessManager {
53 processes: HashMap<ProcessId, Process>,
54 receiver: mpsc::Receiver<Message>,
55 sender: mpsc::Sender<Message>,
56 wait_handles: HashMap<ProcessId, mpsc::Sender<()>>,
57 index: u32,
58 raw_stdio: bool,
59 exit_on_error: bool,
60 quit_on_completion: bool,
61 killed: bool,
62 cwd: Option<String>,
63}
64
65impl ProcessManager {
66 pub fn new() -> Self {
67 let (sender, receiver) = mpsc::channel();
68 Self {
69 processes: HashMap::new(),
70 receiver,
71 sender,
72 wait_handles: HashMap::new(),
73 index: 0,
74 raw_stdio: false,
75 exit_on_error: false,
76 quit_on_completion: true,
77 killed: false,
78 cwd: None,
79 }
80 }
81
82 pub fn with_raw_mode(mut self, raw_mode: bool) -> Self {
83 self.raw_stdio = raw_mode;
84 self
85 }
86
87 pub fn with_exit_on_error(mut self, exit_on_error: bool) -> Self {
88 self.exit_on_error = exit_on_error;
89 self
90 }
91
92 pub fn with_quit_on_completion(mut self, quit_on_completion: bool) -> Self {
93 self.quit_on_completion = quit_on_completion;
94 self
95 }
96
97 pub fn with_working_directory(mut self, working_directory: Option<String>) -> Self {
98 self.cwd = working_directory;
99 self
100 }
101
102 pub fn start(self) -> ProcessManagerHandle {
103 let sender = self.sender.clone();
104 let thread = std::thread::spawn(move || self.rx_message_loop());
105 ProcessManagerHandle {
106 thread: Some(thread),
107 sender,
108 }
109 }
110
111 fn rx_message_loop(mut self) {
112 let timeout = std::time::Duration::from_millis(100);
113 loop {
114 match self.receiver.recv_timeout(timeout) {
115 Ok(message) => {
116 let response = self.process_message(message.0);
117 message.1.send(response).unwrap();
118 }
119 Err(mpsc::RecvTimeoutError::Timeout) => {
120 if self.killed {
121 break;
122 }
123 if !self.processes.is_empty() {
124 self.cleanup_dead_processes();
125
126 if self.processes.is_empty() {
127 if self.quit_on_completion || self.killed {
128 log!("All processes have exited, stopping...");
129 std::process::exit(0);
130 }
131
132 match self
133 .receiver
134 .recv_timeout(std::time::Duration::from_millis(100))
135 {
136 Ok(Message(ProcessAction::KillAll, _)) => {
137 std::process::exit(0);
138 }
139 Ok(message) => {
140 let response = self.process_message(message.0);
141 message.1.send(response).unwrap();
142 }
143 Err(mpsc::RecvTimeoutError::Timeout) => {
144 log!("No more processes running, waiting for new commands...");
145 }
146 Err(mpsc::RecvTimeoutError::Disconnected) => {
147 break;
148 }
149 }
150 }
151 }
152 }
153 Err(mpsc::RecvTimeoutError::Disconnected) => {
154 break;
155 }
156 }
157 }
158
159 std::process::exit(0);
160 }
161
162 fn process_message(&mut self, payload: ProcessAction) -> ProcessActionResponse {
163 match payload {
164 ProcessAction::Create(command) => {
165 let id = self.index;
166 self.index += 1;
167
168 self.start_new_process(command, self.cwd.clone(), self.raw_stdio.into(), id)
169 }
170 ProcessAction::CreateAdvanced(command, options) => {
171 let id = self.index;
172 self.index += 1;
173
174 let raw = options.stdio.unwrap_or(self.raw_stdio.into());
175 let cwd = options.cwd.clone().or_else(|| self.cwd.clone());
176
177 self.start_new_process(command, cwd, raw, id)
178 }
179 ProcessAction::Wait(id) => match self.processes.get(&id) {
180 Some(_) => {
181 let (sender, receiver) = mpsc::channel();
182 self.wait_handles.insert(id.clone(), sender);
183 ProcessActionResponse::Waited(receiver)
184 }
185 None => ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess),
186 },
187 ProcessAction::Kill(id) => match self.processes.get_mut(&id) {
188 Some(child) => match child.kill(None) {
189 Ok(_) => {
190 log!("Killing {}", id);
191 ProcessActionResponse::Killed
192 }
193 Err(e) => ProcessActionResponse::Error(ProcessManagerError::KillChildFailed(
194 e.to_string(),
195 )),
196 },
197 None => ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess),
198 },
199 ProcessAction::KillAdvanced(id, signal) => match self.processes.get_mut(&id) {
200 Some(child) => match child.kill(Some(&signal)) {
201 Ok(_) => {
202 log!("Killing {} with signal {:?}", id, signal);
203 ProcessActionResponse::Killed
204 }
205 Err(e) => ProcessActionResponse::Error(ProcessManagerError::KillChildFailed(
206 e.to_string(),
207 )),
208 },
209 None => ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess),
210 },
211 ProcessAction::KillAll => {
212 self.killed = true;
213
214 let mut errors = vec![];
215 for (id, child) in self.processes.iter_mut() {
216 match child.kill(None) {
217 Ok(_) => {
218 log!("Killing {}", id);
219 }
220 Err(e) => {
221 errors.push(ProcessManagerError::KillChildFailed(e.to_string()));
222 }
223 }
224 }
225 if errors.is_empty() {
226 ProcessActionResponse::KilledAll
227 } else {
228 ProcessActionResponse::Error(ProcessManagerError::Unknown)
229 }
230 }
231 ProcessAction::List => {
232 let list = self.processes.keys().cloned().collect();
233 ProcessActionResponse::List(list)
234 }
235 }
236 }
237
238 fn start_new_process(
239 &mut self,
240 command: String,
241 cwd: Option<String>,
242 stdio: ProcessStdio,
243 id: u32,
244 ) -> ProcessActionResponse {
245 match Process::spawn(&command, cwd.as_deref(), stdio) {
246 Ok(mut child) => {
247 let id = ProcessId::new(id, command);
248 if let ProcessStdio::Inherit = stdio {
249 child.forward_stdio(&id);
250 }
251 self.processes.insert(id.clone(), child);
252 log!("Started {}", id);
253 ProcessActionResponse::Created(id)
254 }
255 Err(e) => {
256 ProcessActionResponse::Error(ProcessManagerError::SpawnChildFailed(e.to_string()))
257 }
258 }
259 }
260
261 fn cleanup_dead_processes(&mut self) {
262 let mut remove = vec![];
263 let mut kill_all = false;
264
265 for (id, child) in self.processes.iter_mut() {
266 match child.try_wait() {
267 Ok(Some(status)) => {
268 remove.push(id.clone());
269 if status != 0 && self.exit_on_error {
270 log_err!("{}: exited with non-zero status", id);
271 kill_all = true;
272 }
273 }
274 Ok(None) => {}
275 Err(e) => {
276 log_err!("Failed to check child status: {}", e);
277 }
278 }
279 }
280
281 for id in remove {
282 if let Some(handle) = self.wait_handles.remove(&id) {
283 handle.send(()).unwrap();
284 }
285 self.processes.remove(&id);
286 log!("Finished {}", id);
287 }
288 if kill_all {
289 for (id, mut child) in self.processes.drain() {
290 match child.kill(None) {
291 Ok(_) => {}
292 Err(e) => {
293 log_err!("Failed to kill {id} => {}", e);
294 }
295 }
296 }
297 }
298 }
299}
300
301pub struct ProcessManagerHandle {
302 thread: Option<std::thread::JoinHandle<()>>,
303 sender: mpsc::Sender<Message>,
304}
305
306impl ProcessManagerHandle {
307 pub fn send(&self, action: ProcessAction) -> TogetherResult<ProcessActionResponse> {
308 let (sender, receiver) = mpsc::channel();
309 self.sender
310 .send(Message(action, sender))
311 .map_err(|e| TogetherError::DynError(e.into()))?;
312 receiver.recv().map_err(|e| e.into())
313 }
314 pub fn subscribe(&self) -> ProcessManagerHandle {
315 ProcessManagerHandle {
316 thread: None,
317 sender: self.sender.clone(),
318 }
319 }
320 pub fn list(&self) -> TogetherResult<Vec<ProcessId>> {
321 self.send(ProcessAction::List).and_then(|r| match r {
322 ProcessActionResponse::List(list) => Ok(list),
323 _ => Err(TogetherInternalError::UnexpectedResponse.into()),
324 })
325 }
326 pub fn spawn(&self, command: &str) -> TogetherResult<ProcessId> {
327 self.send(ProcessAction::Create(command.to_string()))
328 .and_then(|r| match r {
329 ProcessActionResponse::Created(id) => Ok(id),
330 _ => Err(TogetherInternalError::UnexpectedResponse.into()),
331 })
332 }
333 pub fn spawn_advanced(
334 &self,
335 command: &str,
336 options: &CreateOptions,
337 ) -> TogetherResult<ProcessId> {
338 self.send(ProcessAction::CreateAdvanced(
339 command.to_string(),
340 options.clone(),
341 ))
342 .and_then(|r| match r {
343 ProcessActionResponse::Created(id) => Ok(id),
344 _ => Err(TogetherInternalError::UnexpectedResponse.into()),
345 })
346 }
347 pub fn kill(&self, id: ProcessId) -> TogetherResult<Option<()>> {
348 self.send(ProcessAction::Kill(id)).and_then(|r| match r {
349 ProcessActionResponse::Killed => Ok(Some(())),
350 ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess) => Ok(None),
351 _ => Err(TogetherInternalError::UnexpectedResponse.into()),
352 })
353 }
354 pub fn restart(&self, id: ProcessId, command: &str) -> TogetherResult<Option<ProcessId>> {
355 match self.kill(id)? {
356 Some(()) => Ok(Some(self.spawn(command)?)),
357 None => Ok(None),
358 }
359 }
360 pub fn wait(&self, id: ProcessId) -> TogetherResult<()> {
361 self.send(ProcessAction::Wait(id)).and_then(|r| match r {
362 ProcessActionResponse::Waited(done) => done.recv().map_err(|e| e.into()),
363 _ => Err(TogetherInternalError::UnexpectedResponse.into()),
364 })
365 }
366}
367
368impl Drop for ProcessManagerHandle {
369 fn drop(&mut self) {
370 let Some(thread) = self.thread.take() else {
371 return;
372 };
373 let (sender, receiver) = mpsc::channel();
374
375 if let Err(_) = self.sender.send(Message(ProcessAction::KillAll, sender)) {
376 return;
378 };
379
380 match receiver.recv() {
381 Ok(ProcessActionResponse::KilledAll) => {
382 if let Err(e) = thread.join() {
383 log_err!("Failed to join process manager thread: {:?}", e);
384 }
385 }
386 Ok(ProcessActionResponse::Error(response)) => {
387 log_err!("Failed to kill all processes: {:?}", response);
388 }
389 Ok(_) => {
390 log_err!("Received unexpected kill all response");
391 }
392 Err(std::sync::mpsc::RecvError) => {
393 }
395 }
396 }
397}