together_rs/
manager.rs

1use std::{collections::HashMap, sync::mpsc};
2
3use crate::{
4    errors::{TogetherError, TogetherInternalError, TogetherResult},
5    log, log_err,
6    process::{Process, ProcessId, ProcessSignal, ProcessStdio},
7};
8
9pub enum ProcessAction {
10    Create(String),
11    CreateAdvanced(String, CreateOptions),
12    Wait(ProcessId),
13    Kill(ProcessId),
14    KillAdvanced(ProcessId, ProcessSignal),
15    KillAll,
16    List,
17}
18
19#[derive(Debug)]
20pub enum ProcessActionResponse {
21    Created(ProcessId),
22    Waited(mpsc::Receiver<()>),
23    Killed,
24    KilledAll,
25    List(Vec<ProcessId>),
26    Error(ProcessManagerError),
27}
28
/// Reasons a [`ProcessAction`] can fail.
#[derive(Debug)]
pub enum ProcessManagerError {
    /// Spawning the child failed; carries the underlying error message.
    SpawnChildFailed(String),
    /// Signalling the child failed; carries the underlying error message.
    KillChildFailed(String),
    /// No tracked process has the requested id.
    NoSuchProcess,
    /// Unspecified failure (used when one or more kills fail during `KillAll`).
    Unknown,
}
36
37#[derive(Default, Clone)]
38pub struct CreateOptions {
39    pub stdio: Option<ProcessStdio>,
40    pub cwd: Option<String>,
41}
42
43impl CreateOptions {
44    pub fn with_stderr_only(mut self) -> Self {
45        self.stdio = Some(ProcessStdio::StderrOnly);
46        self
47    }
48}
49
50pub struct Message(ProcessAction, mpsc::Sender<ProcessActionResponse>);
51
52pub struct ProcessManager {
53    processes: HashMap<ProcessId, Process>,
54    receiver: mpsc::Receiver<Message>,
55    sender: mpsc::Sender<Message>,
56    wait_handles: HashMap<ProcessId, mpsc::Sender<()>>,
57    index: u32,
58    raw_stdio: bool,
59    exit_on_error: bool,
60    quit_on_completion: bool,
61    killed: bool,
62    cwd: Option<String>,
63}
64
65impl ProcessManager {
66    pub fn new() -> Self {
67        let (sender, receiver) = mpsc::channel();
68        Self {
69            processes: HashMap::new(),
70            receiver,
71            sender,
72            wait_handles: HashMap::new(),
73            index: 0,
74            raw_stdio: false,
75            exit_on_error: false,
76            quit_on_completion: true,
77            killed: false,
78            cwd: None,
79        }
80    }
81
82    pub fn with_raw_mode(mut self, raw_mode: bool) -> Self {
83        self.raw_stdio = raw_mode;
84        self
85    }
86
87    pub fn with_exit_on_error(mut self, exit_on_error: bool) -> Self {
88        self.exit_on_error = exit_on_error;
89        self
90    }
91
92    pub fn with_quit_on_completion(mut self, quit_on_completion: bool) -> Self {
93        self.quit_on_completion = quit_on_completion;
94        self
95    }
96
97    pub fn with_working_directory(mut self, working_directory: Option<String>) -> Self {
98        self.cwd = working_directory;
99        self
100    }
101
102    pub fn start(self) -> ProcessManagerHandle {
103        let sender = self.sender.clone();
104        let thread = std::thread::spawn(move || self.rx_message_loop());
105        ProcessManagerHandle {
106            thread: Some(thread),
107            sender,
108        }
109    }
110
111    fn rx_message_loop(mut self) {
112        let timeout = std::time::Duration::from_millis(100);
113        loop {
114            match self.receiver.recv_timeout(timeout) {
115                Ok(message) => {
116                    let response = self.process_message(message.0);
117                    message.1.send(response).unwrap();
118                }
119                Err(mpsc::RecvTimeoutError::Timeout) => {
120                    if self.killed {
121                        break;
122                    }
123                    if !self.processes.is_empty() {
124                        self.cleanup_dead_processes();
125
126                        if self.processes.is_empty() {
127                            if self.quit_on_completion || self.killed {
128                                log!("All processes have exited, stopping...");
129                                std::process::exit(0);
130                            }
131
132                            match self
133                                .receiver
134                                .recv_timeout(std::time::Duration::from_millis(100))
135                            {
136                                Ok(Message(ProcessAction::KillAll, _)) => {
137                                    std::process::exit(0);
138                                }
139                                Ok(message) => {
140                                    let response = self.process_message(message.0);
141                                    message.1.send(response).unwrap();
142                                }
143                                Err(mpsc::RecvTimeoutError::Timeout) => {
144                                    log!("No more processes running, waiting for new commands...");
145                                }
146                                Err(mpsc::RecvTimeoutError::Disconnected) => {
147                                    break;
148                                }
149                            }
150                        }
151                    }
152                }
153                Err(mpsc::RecvTimeoutError::Disconnected) => {
154                    break;
155                }
156            }
157        }
158
159        std::process::exit(0);
160    }
161
162    fn process_message(&mut self, payload: ProcessAction) -> ProcessActionResponse {
163        match payload {
164            ProcessAction::Create(command) => {
165                let id = self.index;
166                self.index += 1;
167
168                self.start_new_process(command, self.cwd.clone(), self.raw_stdio.into(), id)
169            }
170            ProcessAction::CreateAdvanced(command, options) => {
171                let id = self.index;
172                self.index += 1;
173
174                let raw = options.stdio.unwrap_or(self.raw_stdio.into());
175                let cwd = options.cwd.clone().or_else(|| self.cwd.clone());
176
177                self.start_new_process(command, cwd, raw, id)
178            }
179            ProcessAction::Wait(id) => match self.processes.get(&id) {
180                Some(_) => {
181                    let (sender, receiver) = mpsc::channel();
182                    self.wait_handles.insert(id.clone(), sender);
183                    ProcessActionResponse::Waited(receiver)
184                }
185                None => ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess),
186            },
187            ProcessAction::Kill(id) => match self.processes.get_mut(&id) {
188                Some(child) => match child.kill(None) {
189                    Ok(_) => {
190                        log!("Killing {}", id);
191                        ProcessActionResponse::Killed
192                    }
193                    Err(e) => ProcessActionResponse::Error(ProcessManagerError::KillChildFailed(
194                        e.to_string(),
195                    )),
196                },
197                None => ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess),
198            },
199            ProcessAction::KillAdvanced(id, signal) => match self.processes.get_mut(&id) {
200                Some(child) => match child.kill(Some(&signal)) {
201                    Ok(_) => {
202                        log!("Killing {} with signal {:?}", id, signal);
203                        ProcessActionResponse::Killed
204                    }
205                    Err(e) => ProcessActionResponse::Error(ProcessManagerError::KillChildFailed(
206                        e.to_string(),
207                    )),
208                },
209                None => ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess),
210            },
211            ProcessAction::KillAll => {
212                self.killed = true;
213
214                let mut errors = vec![];
215                for (id, child) in self.processes.iter_mut() {
216                    match child.kill(None) {
217                        Ok(_) => {
218                            log!("Killing {}", id);
219                        }
220                        Err(e) => {
221                            errors.push(ProcessManagerError::KillChildFailed(e.to_string()));
222                        }
223                    }
224                }
225                if errors.is_empty() {
226                    ProcessActionResponse::KilledAll
227                } else {
228                    ProcessActionResponse::Error(ProcessManagerError::Unknown)
229                }
230            }
231            ProcessAction::List => {
232                let list = self.processes.keys().cloned().collect();
233                ProcessActionResponse::List(list)
234            }
235        }
236    }
237
238    fn start_new_process(
239        &mut self,
240        command: String,
241        cwd: Option<String>,
242        stdio: ProcessStdio,
243        id: u32,
244    ) -> ProcessActionResponse {
245        match Process::spawn(&command, cwd.as_deref(), stdio) {
246            Ok(mut child) => {
247                let id = ProcessId::new(id, command);
248                if let ProcessStdio::Inherit = stdio {
249                    child.forward_stdio(&id);
250                }
251                self.processes.insert(id.clone(), child);
252                log!("Started  {}", id);
253                ProcessActionResponse::Created(id)
254            }
255            Err(e) => {
256                ProcessActionResponse::Error(ProcessManagerError::SpawnChildFailed(e.to_string()))
257            }
258        }
259    }
260
261    fn cleanup_dead_processes(&mut self) {
262        let mut remove = vec![];
263        let mut kill_all = false;
264
265        for (id, child) in self.processes.iter_mut() {
266            match child.try_wait() {
267                Ok(Some(status)) => {
268                    remove.push(id.clone());
269                    if status != 0 && self.exit_on_error {
270                        log_err!("{}: exited with non-zero status", id);
271                        kill_all = true;
272                    }
273                }
274                Ok(None) => {}
275                Err(e) => {
276                    log_err!("Failed to check child status: {}", e);
277                }
278            }
279        }
280
281        for id in remove {
282            if let Some(handle) = self.wait_handles.remove(&id) {
283                handle.send(()).unwrap();
284            }
285            self.processes.remove(&id);
286            log!("Finished {}", id);
287        }
288        if kill_all {
289            for (id, mut child) in self.processes.drain() {
290                match child.kill(None) {
291                    Ok(_) => {}
292                    Err(e) => {
293                        log_err!("Failed to kill {id} => {}", e);
294                    }
295                }
296            }
297        }
298    }
299}
300
301pub struct ProcessManagerHandle {
302    thread: Option<std::thread::JoinHandle<()>>,
303    sender: mpsc::Sender<Message>,
304}
305
306impl ProcessManagerHandle {
307    pub fn send(&self, action: ProcessAction) -> TogetherResult<ProcessActionResponse> {
308        let (sender, receiver) = mpsc::channel();
309        self.sender
310            .send(Message(action, sender))
311            .map_err(|e| TogetherError::DynError(e.into()))?;
312        receiver.recv().map_err(|e| e.into())
313    }
314    pub fn subscribe(&self) -> ProcessManagerHandle {
315        ProcessManagerHandle {
316            thread: None,
317            sender: self.sender.clone(),
318        }
319    }
320    pub fn list(&self) -> TogetherResult<Vec<ProcessId>> {
321        self.send(ProcessAction::List).and_then(|r| match r {
322            ProcessActionResponse::List(list) => Ok(list),
323            _ => Err(TogetherInternalError::UnexpectedResponse.into()),
324        })
325    }
326    pub fn spawn(&self, command: &str) -> TogetherResult<ProcessId> {
327        self.send(ProcessAction::Create(command.to_string()))
328            .and_then(|r| match r {
329                ProcessActionResponse::Created(id) => Ok(id),
330                _ => Err(TogetherInternalError::UnexpectedResponse.into()),
331            })
332    }
333    pub fn spawn_advanced(
334        &self,
335        command: &str,
336        options: &CreateOptions,
337    ) -> TogetherResult<ProcessId> {
338        self.send(ProcessAction::CreateAdvanced(
339            command.to_string(),
340            options.clone(),
341        ))
342        .and_then(|r| match r {
343            ProcessActionResponse::Created(id) => Ok(id),
344            _ => Err(TogetherInternalError::UnexpectedResponse.into()),
345        })
346    }
347    pub fn kill(&self, id: ProcessId) -> TogetherResult<Option<()>> {
348        self.send(ProcessAction::Kill(id)).and_then(|r| match r {
349            ProcessActionResponse::Killed => Ok(Some(())),
350            ProcessActionResponse::Error(ProcessManagerError::NoSuchProcess) => Ok(None),
351            _ => Err(TogetherInternalError::UnexpectedResponse.into()),
352        })
353    }
354    pub fn restart(&self, id: ProcessId, command: &str) -> TogetherResult<Option<ProcessId>> {
355        match self.kill(id)? {
356            Some(()) => Ok(Some(self.spawn(command)?)),
357            None => Ok(None),
358        }
359    }
360    pub fn wait(&self, id: ProcessId) -> TogetherResult<()> {
361        self.send(ProcessAction::Wait(id)).and_then(|r| match r {
362            ProcessActionResponse::Waited(done) => done.recv().map_err(|e| e.into()),
363            _ => Err(TogetherInternalError::UnexpectedResponse.into()),
364        })
365    }
366}
367
368impl Drop for ProcessManagerHandle {
369    fn drop(&mut self) {
370        let Some(thread) = self.thread.take() else {
371            return;
372        };
373        let (sender, receiver) = mpsc::channel();
374
375        if let Err(_) = self.sender.send(Message(ProcessAction::KillAll, sender)) {
376            // the process manager has already exited, nothing to do
377            return;
378        };
379
380        match receiver.recv() {
381            Ok(ProcessActionResponse::KilledAll) => {
382                if let Err(e) = thread.join() {
383                    log_err!("Failed to join process manager thread: {:?}", e);
384                }
385            }
386            Ok(ProcessActionResponse::Error(response)) => {
387                log_err!("Failed to kill all processes: {:?}", response);
388            }
389            Ok(_) => {
390                log_err!("Received unexpected kill all response");
391            }
392            Err(std::sync::mpsc::RecvError) => {
393                // the process manager has already exited, nothing to do
394            }
395        }
396    }
397}