launch_pad/
lib.rs

1// SPDX-License-Identifier: MPL-2.0
2#[macro_use]
3extern crate log;
4
5pub mod error;
6pub mod message;
7pub mod process;
8pub mod util;
9
10use self::{
11    error::{Error, Result},
12    process::{Process, ProcessCallbacks, ReturnFuture},
13};
14
15use rand::Rng;
16use slotmap::{SlotMap, new_key_type};
17use std::{
18    borrow::Cow,
19    os::{
20        fd::{AsRawFd, BorrowedFd, OwnedFd},
21        unix::process::ExitStatusExt,
22    },
23    process::Stdio,
24    sync::Arc,
25};
26use tokio::{
27    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
28    process::{Child, Command},
29    sync::{RwLock, mpsc, oneshot},
30    time::Duration,
31};
32use tokio_util::sync::CancellationToken;
33
34#[cfg(target_os = "linux")]
35use rustix::io_uring::Signal;
36#[cfg(target_os = "linux")]
37use rustix::process::{Pid, kill_process};
38
// Slot-map key type that identifies one managed process inside a
// `ProcessManager`; returned by `start` / `start_process`.
new_key_type! { pub struct ProcessKey; }
40
/// Cloneable handle to a set of managed child processes.
///
/// Every clone shares the same state (`inner`), the same spawn-request
/// channel (`tx`) and the same cancellation token.
#[derive(Clone)]
pub struct ProcessManager {
    /// Shared restart settings and the table of managed processes.
    inner: Arc<RwLock<ProcessManagerInner>>,
    /// Transmitter for ProcessManager instances
    /// a Process will be sent to the main loop for spawning
    /// and a key will be sent back to the caller
    tx: mpsc::UnboundedSender<(Process, oneshot::Sender<Result<ProcessKey>>)>,
    /// Cancelled by `stop`; every managed process gets a child of this token.
    cancel_token: CancellationToken,
}
50
51impl ProcessManager {
52    pub async fn new() -> Self {
53        let (tx, mut rx) = mpsc::unbounded_channel();
54        let cancel = CancellationToken::new();
55        let inner = Arc::new(RwLock::new(ProcessManagerInner {
56            restart_mode: RestartMode::Instant,
57            max_restarts: 3,
58            processes: SlotMap::with_key(),
59        }));
60        let manager = ProcessManager {
61            inner,
62            tx,
63            cancel_token: cancel.clone(),
64        };
65        let manager_clone = manager.clone();
66        tokio::spawn(async move {
67            loop {
68                tokio::select! {
69                    _ = cancel.cancelled() => break,
70                    msg = rx.recv() => match msg {
71                        Some((process, return_tx)) => {
72                            return_tx
73                                .send(manager_clone.start_process(process).await)
74                                .unwrap();
75                        }
76                        None => break,
77                    }
78                }
79            }
80        });
81        manager
82    }
83
84    /// Starts a process with the given configuration. implicitly calls
85    /// `start_process`
86    pub async fn start(&self, process: Process) -> Result<ProcessKey> {
87        let (return_tx, return_rx) = oneshot::channel();
88        // send a process to spawn and a transmitter to the loop above
89        // and wait for the key to be returned
90        let _ = self.tx.send((process, return_tx));
91        return_rx.await?
92    }
93
94    /// Returns the current restart mode.
95    pub async fn restart_mode(&self) -> RestartMode {
96        self.inner.read().await.restart_mode
97    }
98
99    /// Sets the restart mode.
100    pub async fn set_restart_mode(&self, restart_mode: RestartMode) {
101        self.inner.write().await.restart_mode = restart_mode;
102    }
103
104    /// Returns the maximum amount of times a process can be restarted before
105    /// giving up.
106    pub async fn max_restarts(&self) -> usize {
107        self.inner.read().await.max_restarts
108    }
109
110    /// Sets the maximum amount of times a process can be restarted before
111    /// giving up.
112    pub async fn set_max_restarts(&self, max_restarts: usize) {
113        self.inner.write().await.max_restarts = max_restarts;
114    }
115
116    /// Returns whether the process manager has been stopped or not.
117    /// If the process manager has been stopped, no new processes can be
118    /// started.
119    pub fn is_stopped(&self) -> bool {
120        self.cancel_token.is_cancelled()
121    }
122
123    /// Stops the process manager, halting all processes and preventing new
124    /// processes from being started.
125    pub fn stop(&self) {
126        self.cancel_token.cancel();
127    }
128
129    /// Stops a single process.
130    pub async fn stop_process(&self, key: ProcessKey) -> Result<()> {
131        let inner = self.inner.read().await;
132        let process = inner.processes.get(key).ok_or(Error::NonExistantProcess)?;
133        process.cancel_token.cancel();
134        Ok(())
135    }
136
137    /// Send a message to a process over stdin
138    pub async fn send_message(&self, key: ProcessKey, message: Cow<'static, [u8]>) -> Result<()> {
139        let inner = self.inner.read().await;
140        let process = inner.processes.get(key).ok_or(Error::NonExistantProcess)?;
141        process.process.stdin_tx.send(message).await?;
142        Ok(())
143    }
144
145    pub async fn start_process(&self, mut process: Process) -> Result<ProcessKey> {
146        if self.is_stopped() {
147            return Err(Error::Stopped);
148        }
149
150        let Some(rx) = process.stdin_rx.take() else {
151            return Err(Error::MissingStdinReceiver);
152        };
153        info!(
154            "starting process '{} {} {}'",
155            process.env_text(),
156            process.exe_text(),
157            process.args_text()
158        );
159        let mut callbacks = std::mem::take(&mut process.callbacks);
160        let cancel_timeout = process.cancel_timeout;
161        let (callback_tx, mut callback_rx) = mpsc::unbounded_channel();
162
163        let cancel_token = self.cancel_token.child_token();
164
165        let fd_list = if let Some(fds) = callbacks.fds.take() {
166            fds()
167        } else {
168            Vec::new()
169        };
170        let raw_fds = fd_list.iter().map(|fd| fd.as_raw_fd()).collect::<Vec<_>>();
171
172        let mut command = Command::new(&process.executable);
173
174        command
175            .args(&process.args)
176            .envs(process.env.iter().map(|(k, v)| (k.as_str(), v.as_str())))
177            .stdout(Stdio::piped())
178            .stderr(Stdio::piped())
179            .stdin(Stdio::piped())
180            .kill_on_drop(true);
181
182        let key = self.inner.write().await.processes.insert(ProcessData {
183            process,
184            pid: None,
185            restarts: 0,
186            cancel_token: cancel_token.clone(),
187            cancel_timeout,
188        });
189
190        let command = unsafe {
191            command
192                .pre_exec(move || {
193                    for fd in &raw_fds {
194                        util::mark_as_not_cloexec(BorrowedFd::borrow_raw(*fd))?;
195                    }
196                    Ok(())
197                })
198                .spawn()
199                .map_err(Error::Process)?
200        };
201        drop(fd_list);
202        self.inner.write().await.processes.get_mut(key).unwrap().pid = command.id();
203        // This adds futures into a queue and executes them in a separate task, in order
204        // to both ensure execution of callbacks is in the same order the events are
205        // received, and to avoid blocking the reception of events if a callback is slow
206        // to return.
207        tokio::spawn(async move {
208            while let Some(f) = callback_rx.recv().await {
209                f.await
210            }
211        });
212        if let Some(on_start) = &callbacks.on_start {
213            let _ = callback_tx.send(on_start(self.clone(), key, false));
214        }
215        tokio::spawn(self.clone().process_loop(
216            key,
217            cancel_token.child_token(),
218            command,
219            callbacks,
220            callback_tx,
221            rx,
222        ));
223        Ok(key)
224    }
225
226    /// Just gives you the exe, along with the pid, of a managed process
227    pub async fn get_exe_and_pid(&self, key: ProcessKey) -> Result<(String, Option<u32>)> {
228        let inner = self.inner.read().await;
229        let pdata = inner
230            .processes
231            .get(key)
232            .ok_or(error::Error::NonExistantProcess)?;
233        Ok((pdata.process.executable.clone(), pdata.pid))
234    }
235
236    /// Get the pid of a managed process
237    pub async fn get_pid(&self, key: ProcessKey) -> Result<Option<u32>> {
238        let inner = self.inner.read().await;
239        Ok(inner
240            .processes
241            .get(key)
242            .ok_or(error::Error::NonExistantProcess)?
243            .pid)
244    }
245
246    async fn restart_process(&self, process_key: ProcessKey) -> Result<Child> {
247        let inner = self.inner.read().await;
248        let restart_mode = inner.restart_mode;
249        let process_data = inner
250            .processes
251            .get(process_key)
252            .ok_or(Error::InvalidProcess(process_key))?;
253        let restarts = process_data.restarts;
254        let executable = process_data.process.executable.clone();
255        drop(inner);
256
257        // delay before restarting
258        match restart_mode {
259            RestartMode::ExponentialBackoff(backoff) => {
260                let backoff = backoff.as_millis() as u64;
261                let jittered_delay: u64 = rand::rng().random_range(0..backoff);
262                let backoff = Duration::from_millis(
263                    2_u64
264                        .saturating_pow(restarts as u32)
265                        .saturating_mul(jittered_delay),
266                );
267                info!(
268                    "sleeping for {}ms before restarting process {} (restart {})",
269                    backoff.as_millis(),
270                    executable,
271                    restarts
272                );
273
274                tokio::time::sleep(backoff).await;
275            }
276            RestartMode::Delayed(backoff) => {
277                info!(
278                    "sleeping for {}ms before restarting process {} (restart {})",
279                    backoff.as_millis(),
280                    executable,
281                    restarts
282                );
283                tokio::time::sleep(backoff).await;
284            }
285            RestartMode::Instant => {}
286        }
287        let mut inner = self.inner.write().await;
288        let process_data = inner
289            .processes
290            .get_mut(process_key)
291            .ok_or(Error::InvalidProcess(process_key))?;
292        let mut fd_callback = process_data.process.callbacks.fds.take();
293        let fd_list = if let Some(fds) = fd_callback.take() {
294            fds()
295        } else {
296            Vec::new()
297        };
298        let raw_fds = fd_list.iter().map(|fd| fd.as_raw_fd()).collect::<Vec<_>>();
299        let command = unsafe {
300            Command::new(&process_data.process.executable)
301                .args(&process_data.process.args)
302                .envs(process_data.process.env.clone())
303                .stdout(Stdio::piped())
304                .stderr(Stdio::piped())
305                .stdin(Stdio::piped())
306                .kill_on_drop(true)
307                .pre_exec(move || {
308                    for fd in &raw_fds {
309                        util::mark_as_not_cloexec(BorrowedFd::borrow_raw(*fd))?;
310                    }
311                    Ok(())
312                })
313                .spawn()
314                .map_err(Error::Process)?
315        };
316        process_data.pid = command.id();
317        drop(fd_list);
318
319        process_data.restarts += 1;
320        info!(
321            "restarted process '{} {} {}', now at {} restarts",
322            process_data.process.env_text(),
323            process_data.process.exe_text(),
324            process_data.process.args_text(),
325            process_data.restarts
326        );
327        Ok(command)
328    }
329
330    async fn process_loop(
331        self,
332        key: ProcessKey,
333        cancel_token: CancellationToken,
334        mut command: Child,
335        callbacks: ProcessCallbacks,
336        callback_tx: mpsc::UnboundedSender<ReturnFuture>,
337        mut stdin_rx: mpsc::Receiver<Cow<'static, [u8]>>,
338    ) {
339        let (mut stdout, mut stderr) = match (command.stdout.take(), command.stderr.take()) {
340            (Some(stdout), Some(stderr)) => (
341                BufReader::new(stdout).lines(),
342                BufReader::new(stderr).lines(),
343            ),
344            (Some(_), None) => panic!("no stderr in process, even though we should be piping it"),
345            (None, Some(_)) => panic!("no stdout in process, even though we should be piping it"),
346            (None, None) => {
347                panic!("no stdout or stderr in process, even though we should be piping it")
348            }
349        };
350        let mut stdin = command
351            .stdin
352            .take()
353            .expect("No stdin in process, even though we should be piping it");
354        loop {
355            tokio::select! {
356                _ = cancel_token.cancelled() => {
357                    info!("process '{:?}' cancelled", key);
358                    let mut exit_code = None;
359                    if let Some(id) = command.id() {
360                        #[cfg(target_os = "linux")]
361                        if let Some(pid) = Pid::from_raw(id as i32) {
362                            if let Err(err) = kill_process(pid, Signal::TERM) {
363                                log::error!("Error sending SIGTERM: {err:?}");
364                            }
365                        }
366
367                        #[cfg(not(target_os = "linux"))]
368                        if unsafe { libc::kill(id as i32, libc::SIGTERM) == -1 } {
369                            log::error!("Error sending SIGTERM: {:?}", io::Error::last_os_error());
370                        }
371
372                        if let Some(t) = {
373                            let inner = self.inner.read().await;
374                            inner.processes.get(key).and_then(|p| p.cancel_timeout)
375                        } {
376                            match tokio::time::timeout(t, command.wait()).await {
377                                Ok(res) => {
378                                    match res {
379                                        Ok(status) => {
380                                            exit_code = status.code();
381                                        },
382                                        Err(err) => {
383                                            log::error!("Failed to stop program gracefully. {err:?}");
384                                        },
385                                    }
386                                }
387                                Err(_) => {
388                                    log::error!("Failed to stop program gracefully before cancel timeout.");
389                                }
390                            };
391                        } else {
392                            match command.wait().await {
393                                Ok(status) => {
394                                    exit_code = status.code();
395                                },
396                                Err(err) => {
397                                    log::error!("Failed to stop program gracefully. {err:?}");
398                                },
399                            }						}
400
401                    } else {
402                        log::error!("Failed to stop program gracefully. Missing pid.");
403                    }
404
405                    if exit_code.is_none() {
406                        if let Err(err) = command.kill().await {
407                            log::error!("Failed to kill program. {err:?}");
408                        };
409                        exit_code = Some(137);
410                    }
411
412                    if let Some(on_exit) = &callbacks.on_exit {
413                        // wait for this to complete before potentially restarting
414                        on_exit(self.clone(), key, exit_code, false).await;
415                    }
416                    break;
417                },
418                Some(message) = stdin_rx.recv() => {
419                    if let Err(err) =
420                        stdin.write_all(&message).await {
421                        error!("failed to write to stdin of process '{:?}': {}", key, err);
422                    }
423                }
424                Ok(Some(stdout_line)) = stdout.next_line() => {
425                    if let Some(on_stdout) = &callbacks.on_stdout {
426                        let _ = callback_tx.send(on_stdout(self.clone(), key, stdout_line));
427                    }
428                }
429                Ok(Some(stderr_line)) = stderr.next_line() => {
430                    if let Some(on_stderr) = &callbacks.on_stderr {
431                        let _ = callback_tx.send(on_stderr(self.clone(), key, stderr_line));
432                    }
433                }
434                ret = command.wait() => {
435                    let ret = ret.unwrap();
436                    let is_restarting = {
437                        let inner = self.inner.read().await;
438                        let process = inner.processes.get(key).unwrap();
439                        if !ret.success() {
440                            let env_text = process.process.env_text();
441                            let exe_text = process.process.exe_text();
442                            let args_text = process.process.args_text();
443                            if let Some(signal) = ret.signal() {
444                                error!("process '{} {} {}' terminated with signal {}", env_text, exe_text, args_text, signal);
445                            } else if let Some(code) = ret.code() {
446                                error!("process '{} {} {}' failed with code {}", env_text, exe_text, args_text, code);
447                            }
448                        }
449                        !ret.success() && (inner.max_restarts > process.restarts)
450                    };
451                    if let Some(on_exit) = &callbacks.on_exit {
452                        // wait for this to complete before potentially restarting
453                        on_exit(self.clone(), key, ret.code(), is_restarting).await;
454                    }
455                    if is_restarting {
456                        info!("draining stdin receiver before restarting process");
457                        while let Ok(_) = stdin_rx.try_recv() {}
458
459                        match self.restart_process(key).await {
460                            Ok(new_command) =>  {
461                                command = new_command;
462                                (stdout, stderr) = match (command.stdout.take(), command.stderr.take()) {
463                                    (Some(stdout), Some(stderr)) => (
464                                        BufReader::new(stdout).lines(),
465                                        BufReader::new(stderr).lines(),
466                                    ),
467                                    (Some(_), None) => panic!("no stderr in process, even though we should be piping it"),
468                                    (None, Some(_)) => panic!("no stdout in process, even though we should be piping it"),
469                                    (None, None) => {
470                                        panic!("no stdout or stderr in process, even though we should be piping it")
471                                    }
472                                };
473                                stdin = command
474                                    .stdin
475                                    .take()
476                                    .expect("No stdin in process, even though we should be piping it");
477                                if let Some(on_start) = &callbacks.on_start {
478                                    let _ = callback_tx.send(on_start(self.clone(), key, true));
479                                }
480                                continue;
481                            }
482                            Err(err) => {
483                                error!("failed to restart process '{:?}: {}", key, err);
484                            }
485                        }
486                    }
487                    break;
488                }
489            }
490        }
491    }
492
493    /// update the args of a managed process
494    /// This will reset previous args if they are not set again
495    /// changes will be applied after the process restarts
496    pub async fn update_process_args(&mut self, key: &ProcessKey, args: Vec<String>) -> Result<()> {
497        let mut r = self.inner.write().await;
498        if let Some(pdata) = r.processes.get_mut(*key) {
499            pdata.process.args = args;
500            Ok(())
501        } else {
502            Err(Error::NonExistantProcess)
503        }
504    }
505
506    /// update the env of a managed process
507    /// changes will be applied after the process restarts
508    pub async fn update_process_env(
509        &mut self,
510        key: &ProcessKey,
511        env: impl IntoIterator<Item = (impl ToString, impl ToString)>,
512    ) -> Result<()> {
513        let mut r = self.inner.write().await;
514        if let Some(pdata) = r.processes.get_mut(*key) {
515            let mut new_env: Vec<(_, _)> = env
516                .into_iter()
517                .map(|(k, v)| (k.to_string(), v.to_string()))
518                .collect();
519            pdata
520                .process
521                .env
522                .retain(|(k, _)| !new_env.iter().any(|(k_new, _)| k == k_new));
523            pdata.process.env.append(&mut new_env);
524            Ok(())
525        } else {
526            Err(Error::NonExistantProcess)
527        }
528    }
529
530    pub async fn update_process_fds<F>(&mut self, key: &ProcessKey, f: F) -> Result<()>
531    where
532        F: FnOnce() -> Vec<OwnedFd> + Send + Sync + 'static,
533    {
534        let mut r = self.inner.write().await;
535        if let Some(pdata) = r.processes.get_mut(*key) {
536            pdata.process.callbacks.fds = Some(Box::new(f));
537            Ok(())
538        } else {
539            Err(Error::NonExistantProcess)
540        }
541    }
542
543    /// update the env of a managed process
544    /// changes will be applied after the process restarts
545    pub async fn clear_process_env(&mut self, key: &ProcessKey) -> Result<()> {
546        let mut r = self.inner.write().await;
547        if let Some(pdata) = r.processes.get_mut(*key) {
548            pdata.process.env.clear();
549            Ok(())
550        } else {
551            Err(Error::NonExistantProcess)
552        }
553    }
554
555    // TODO methods for modifying other process data
556}
557
/// Book-keeping the manager stores for every process it has started.
struct ProcessData {
    /// Configuration of the process; its args, env and fd callback can be
    /// changed through the `update_process_*` methods and are read on restart.
    process: Process,
    /// Pid reported by the most recently spawned child, if any.
    pid: Option<u32>,
    /// Number of restarts performed so far; compared against `max_restarts`.
    restarts: usize,
    /// Child of the manager-wide token; cancelled by `stop_process`.
    cancel_token: CancellationToken,
    /// Copied from the process configuration at start. When set, it bounds
    /// how long the process loop waits after SIGTERM before killing the child.
    cancel_timeout: Option<Duration>,
}
565
/// State shared by all clones of a `ProcessManager`, guarded by its `RwLock`.
struct ProcessManagerInner {
    /// How long to wait before restarting a failed process.
    restart_mode: RestartMode,
    /// Upper bound on restarts per process; a process is only restarted while
    /// this is greater than its restart count.
    max_restarts: usize,
    /// Table of all processes registered with the manager.
    processes: SlotMap<ProcessKey, ProcessData>,
}
571
/// Delay policy applied before a failed process is restarted.
#[derive(Clone, Copy, Debug)]
pub enum RestartMode {
    /// Restart immediately, without any delay.
    Instant,
    /// Sleep for the given duration before every restart.
    Delayed(Duration),
    /// Sleep for `2^restarts` times a random number of milliseconds drawn
    /// from below the given duration before restarting.
    ExponentialBackoff(Duration),
}