psup_impl/
supervisor.rs

1//! Supervisor manages a collection of worker processes.
2use std::{
3    collections::{hash_map::DefaultHasher, HashMap},
4    hash::Hasher,
5    io,
6    path::{Path, PathBuf},
7    sync::Arc,
8    time::Duration,
9};
10
11use tokio::{
12    net::{UnixListener, UnixStream},
13    process::Command,
14    sync::oneshot::{self, Sender},
15    sync::{mpsc, Mutex},
16    time,
17};
18
19use log::{error, info, warn};
20use once_cell::sync::OnceCell;
21use rand::Rng;
22
23use super::{Result, SOCKET, WORKER_ID};
24
25type IpcHandler = Box<dyn Fn(UnixStream, mpsc::Sender<Message>) + Send + Sync>;
26
27/// Get the supervisor state.
28fn supervisor_state() -> &'static Mutex<SupervisorState> {
29    static INSTANCE: OnceCell<Mutex<SupervisorState>> = OnceCell::new();
30    INSTANCE.get_or_init(|| Mutex::new(SupervisorState { workers: vec![] }))
31}
32
33/// Control messages sent by the server handler to the
34/// supervisor.
35pub enum Message {
36    /// Shutdown a worker process using it's opaque identifier.
37    ///
38    /// If the worker is a daemon it will *not be restarted*.
39    Shutdown {
40        /// Opaque identifier for the worker.
41        id: String,
42    },
43
44    /// Spawn a new worker process.
45    Spawn {
46        /// Task definition for the new process.
47        task: Task,
48    },
49}
50
/// Defines a worker process command.
#[derive(Debug, Clone)]
pub struct Task {
    /// Program to execute.
    cmd: String,
    /// Arguments passed to the program.
    args: Vec<String>,
    /// Extra environment variables for the process.
    envs: HashMap<String, String>,
    /// Restart the worker when it dies unexpectedly.
    daemon: bool,
    /// Do not expose the IPC socket path to the worker.
    detached: bool,
    /// Maximum number of restart attempts.
    limit: usize,
    /// Restart delay factor in milliseconds.
    factor: usize,
}

impl Task {
    /// Create a task that runs `cmd`.
    ///
    /// The task starts with no arguments and no extra environment. It is
    /// neither a daemon nor detached, and uses a retry limit of `5` and a
    /// retry factor of `0`.
    pub fn new(cmd: &str) -> Self {
        Task {
            cmd: cmd.to_owned(),
            args: vec![],
            envs: HashMap::new(),
            daemon: false,
            detached: false,
            limit: 5,
            factor: 0,
        }
    }

    /// Set command arguments, replacing any previously configured ones.
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.args = args
            .into_iter()
            .map(|arg| String::from(arg.as_ref()))
            .collect();
        self
    }

    /// Set command environment variables, replacing any previously
    /// configured ones.
    ///
    /// When a key appears more than once the last value wins.
    pub fn envs<I, K, V>(mut self, vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut table = HashMap::new();
        for (key, value) in vars {
            table.insert(key.as_ref().to_owned(), value.as_ref().to_owned());
        }
        self.envs = table;
        self
    }

    /// Set the daemon flag for the worker command.
    ///
    /// Daemon processes are restarted if they die without being explicitly
    /// shut down by the supervisor.
    pub fn daemon(self, flag: bool) -> Self {
        Self {
            daemon: flag,
            ..self
        }
    }

    /// Set the detached flag for the worker command.
    ///
    /// A detached worker is not given the IPC socket path, so it will not
    /// connect to the IPC socket.
    pub fn detached(self, flag: bool) -> Self {
        Self {
            detached: flag,
            ..self
        }
    }

    /// Set the retry limit used when restarting dead workers.
    ///
    /// This only applies to tasks with the `daemon` flag set, because
    /// non-daemon tasks are never restarted. A value of zero overrides the
    /// `daemon` flag, and no restart attempts are made.
    ///
    /// The default value is `5`.
    pub fn retry_limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }

    /// Set the retry factor in milliseconds.
    ///
    /// The delay before restart attempt `n` is `n * factor` milliseconds.
    /// The default value of `0` means restarts happen immediately.
    pub fn retry_factor(self, factor: usize) -> Self {
        Self { factor, ..self }
    }

    /// Fresh retry bookkeeping (zero attempts) for this task.
    fn retry(&self) -> Retry {
        Retry {
            attempts: 0,
            limit: self.limit,
            factor: self.factor,
        }
    }
}

/// Restart bookkeeping carried from one incarnation of a worker to the next.
#[derive(Clone, Copy)]
struct Retry {
    /// Upper bound on the number of restart attempts.
    limit: usize,
    /// Linear restart delay factor, in milliseconds.
    factor: usize,
    /// Restart attempts made so far.
    attempts: usize,
}
165
166/// Build a supervisor.
167pub struct SupervisorBuilder {
168    socket: PathBuf,
169    commands: Vec<Task>,
170    ipc_handler: Option<IpcHandler>,
171    shutdown: Option<oneshot::Receiver<()>>,
172}
173
174impl SupervisorBuilder {
175    /// Create a new supervisor builder.
176    pub fn new() -> Self {
177        let socket = std::env::temp_dir().join("psup.sock");
178        Self {
179            socket,
180            commands: Vec::new(),
181            ipc_handler: None,
182            shutdown: None,
183        }
184    }
185
186    /// Set the IPC server handler.
187    pub fn server<F: 'static>(mut self, handler: F) -> Self
188    where
189        F: Fn(UnixStream, mpsc::Sender<Message>) + Send + Sync,
190    {
191        self.ipc_handler = Some(Box::new(handler));
192        self
193    }
194
195    /// Set the socket path.
196    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
197        self.socket = path.as_ref().to_path_buf();
198        self
199    }
200
201    /// Add a worker process.
202    pub fn add_worker(mut self, task: Task) -> Self {
203        self.commands.push(task);
204        self
205    }
206
207    /// Register a shutdown handler with the supervisor.
208    ///
209    /// When a message is received on the shutdown receiver all
210    /// managed processes are killed.
211    pub fn shutdown(mut self, rx: oneshot::Receiver<()>) -> Self {
212        self.shutdown = Some(rx);
213        self
214    }
215
216    /// Return the supervisor.
217    pub fn build(self) -> Supervisor {
218        Supervisor {
219            socket: self.socket,
220            commands: self.commands,
221            ipc_handler: self.ipc_handler.map(Arc::new),
222            shutdown: self.shutdown,
223        }
224    }
225}
226
227/// Supervisor manages long-running worker processes.
228pub struct Supervisor {
229    socket: PathBuf,
230    commands: Vec<Task>,
231    ipc_handler: Option<Arc<IpcHandler>>,
232    shutdown: Option<oneshot::Receiver<()>>,
233}
234
235impl Supervisor {
236    /// Start the supervisor running.
237    ///
238    /// Listens on the socket path and starts any initial workers.
239    pub async fn run(&mut self) -> Result<()> {
240        // Set up the server listener and control channel.
241        if let Some(ref ipc_handler) = self.ipc_handler {
242            let socket = self.socket.clone();
243            let control_socket = self.socket.clone();
244
245            let (control_tx, mut control_rx) = mpsc::channel::<Message>(1024);
246            let (tx, rx) = oneshot::channel::<()>();
247            let handler = Arc::clone(ipc_handler);
248
249            // Handle global shutdown signal, kills all the workers
250            if let Some(shutdown) = self.shutdown.take() {
251                tokio::spawn(async move {
252                    let _ = shutdown.await;
253                    let mut state = supervisor_state().lock().await;
254                    let workers = state.workers.drain(..);
255                    for worker in workers {
256                        let tx = worker.shutdown.clone();
257                        let _ = tx.send(worker).await;
258                    }
259                });
260            }
261
262            tokio::spawn(async move {
263                while let Some(msg) = control_rx.recv().await {
264                    match msg {
265                        Message::Shutdown { id } => {
266                            let mut state = supervisor_state().lock().await;
267                            let mut worker = state.remove(&id);
268                            drop(state);
269                            if let Some(worker) = worker.take() {
270                                let tx = worker.shutdown.clone();
271                                let _ = tx.send(worker).await;
272                            } else {
273                                warn!("Could not find worker to shutdown with id: {}", id);
274                            }
275                        }
276                        Message::Spawn { task } => {
277                            // FIXME: return the id to the caller?
278                            let id = id();
279                            let retry = task.retry();
280                            spawn_worker(
281                                id,
282                                task,
283                                control_socket.clone(),
284                                retry,
285                            );
286                        }
287                    }
288                }
289            });
290
291            tokio::spawn(async move {
292                listen(&socket, tx, handler, control_tx)
293                    .await
294                    .expect("Supervisor failed to bind to socket");
295            });
296
297            let _ = rx.await?;
298            info!("Supervisor is listening {}", self.socket.display());
299        }
300
301        // Spawn initial worker processes.
302        for task in self.commands.iter() {
303            self.spawn(task.clone());
304        }
305
306        Ok(())
307    }
308
309    /// Spawn a worker task.
310    pub fn spawn(&self, task: Task) -> String {
311        let id = id();
312        let retry = task.retry();
313        spawn_worker(id.clone(), task, self.socket.clone(), retry);
314        id
315    }
316
317    /*
318    /// Get the workers mapped from opaque identifier to process PID.
319    pub fn workers() -> HashMap<String, u32> {
320        let state = supervisor_state().lock().unwrap();
321        state.workers.iter()
322            .map(|w| (w.id.clone(), w.pid))
323            .collect::<HashMap<_, _>>()
324    }
325    */
326}
327
328/// State of the supervisor worker processes.
329struct SupervisorState {
330    workers: Vec<WorkerState>,
331}
332
333impl SupervisorState {
334    fn remove(&mut self, id: &str) -> Option<WorkerState> {
335        let res = self.workers.iter().enumerate().find_map(|(i, w)| {
336            if &w.id == id {
337                Some(i)
338            } else {
339                None
340            }
341        });
342        if let Some(position) = res {
343            Some(self.workers.swap_remove(position))
344        } else {
345            None
346        }
347    }
348}
349
350#[derive(Debug)]
351struct WorkerState {
352    task: Task,
353    id: String,
354    socket: PathBuf,
355    pid: Option<u32>,
356    /// If we are shutting down this worker explicitly
357    /// this flag will be set to prevent the worker from
358    /// being re-spawned.
359    reap: bool,
360    shutdown: mpsc::Sender<WorkerState>,
361}
362
363impl PartialEq for WorkerState {
364    fn eq(&self, other: &Self) -> bool {
365        self.id == other.id && self.pid == other.pid
366    }
367}
368
369impl Eq for WorkerState {}
370
371/// Attempt to restart a worker that died.
372async fn restart(worker: WorkerState, mut retry: Retry) {
373    info!("Restarting worker {}", worker.id);
374    retry.attempts = retry.attempts + 1;
375
376    if retry.attempts >= retry.limit {
377        error!(
378            "Failed to restart worker {}, exceeded retry limit {}",
379            worker.id, retry.limit
380        );
381    } else {
382        if retry.factor > 0 {
383            let ms = retry.attempts * retry.factor;
384            info!("Delay restart {}ms", ms);
385            time::sleep(Duration::from_millis(ms as u64)).await;
386        }
387        spawn_worker(worker.id, worker.task, worker.socket, retry)
388    }
389}
390
391/// Generate a random opaque identifier.
392pub fn id() -> String {
393    let mut rng = rand::thread_rng();
394    let mut hasher = DefaultHasher::new();
395    hasher.write_usize(rng.gen());
396    format!("{:x}", hasher.finish())
397}
398
399fn spawn_worker(id: String, task: Task, socket: PathBuf, retry: Retry) {
400    tokio::task::spawn(async move {
401        // Setup built in environment variables
402        let mut envs = task.envs.clone();
403        envs.insert(WORKER_ID.to_string(), id.clone());
404        if !task.detached {
405            envs.insert(
406                SOCKET.to_string(),
407                socket.to_string_lossy().into_owned(),
408            );
409        }
410
411        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<WorkerState>(1);
412
413        info!("Spawn worker {} {}", &task.cmd, task.args.join(" "));
414
415        let mut child = Command::new(task.cmd.clone())
416            .args(task.args.clone())
417            .envs(envs)
418            .spawn()?;
419
420        let pid = child.id();
421
422        if let Some(ref id) = pid {
423            info!("Worker pid {}", id);
424        }
425
426        {
427            let worker = WorkerState {
428                id: id.clone(),
429                task,
430                socket,
431                pid,
432                reap: false,
433                shutdown: shutdown_tx,
434            };
435            let mut state = supervisor_state().lock().await;
436            state.workers.push(worker);
437        }
438
439        let mut reaping = false;
440
441        loop {
442            tokio::select!(
443                res = child.wait() => {
444                    match res {
445                        Ok(status) => {
446                            let pid = pid.unwrap_or(0);
447                            if !reaping {
448                                if let Some(code) = status.code() {
449                                    warn!("Worker process died: {} (code: {})", pid, code);
450                                } else {
451                                    warn!("Worker process died: {} ({})", pid, status);
452                                }
453                            }
454                            let mut state = supervisor_state().lock().await;
455                            let worker = state.remove(&id);
456                            drop(state);
457                            if let Some(worker) = worker {
458                                info!("Removed child worker (id: {}, pid {})", worker.id, pid);
459                                if !worker.reap && worker.task.daemon {
460                                    restart(worker, retry).await;
461                                }
462                            } else {
463                                if !reaping {
464                                    error!("Failed to remove stale worker for pid {}", pid);
465                                }
466                            }
467                            break;
468                        }
469                        Err(e) => return Err(e),
470                    }
471                }
472                mut worker = shutdown_rx.recv() => {
473                    if let Some(mut worker) = worker.take() {
474                        reaping = true;
475                        info!("Shutdown worker {}", worker.id);
476                        worker.reap = true;
477                        child.kill().await?;
478                    }
479                }
480            )
481        }
482
483        Ok::<(), io::Error>(())
484    });
485}
486
487async fn listen<P: AsRef<Path>>(
488    socket: P,
489    tx: Sender<()>,
490    handler: Arc<IpcHandler>,
491    control_tx: mpsc::Sender<Message>,
492) -> Result<()> {
493    let path = socket.as_ref();
494
495    // If the socket file exists we must remove to prevent `EADDRINUSE`
496    if path.exists() {
497        std::fs::remove_file(path)?;
498    }
499
500    let listener = UnixListener::bind(socket).unwrap();
501    tx.send(()).unwrap();
502
503    loop {
504        match listener.accept().await {
505            Ok((stream, _addr)) => (handler)(stream, control_tx.clone()),
506            Err(e) => {
507                warn!("Supervisor failed to accept worker socket {}", e);
508            }
509        }
510    }
511}