use std::{
    collections::{hash_map::DefaultHasher, HashMap},
    hash::Hasher,
    io,
    path::{Path, PathBuf},
    sync::Arc,
    time::Duration,
};

use tokio::{
    net::{UnixListener, UnixStream},
    process::Command,
    sync::oneshot::{self, Sender},
    sync::{mpsc, Mutex},
    time,
};

use log::{error, info, warn};
use once_cell::sync::OnceCell;
use rand::Rng;

use super::{Result, SOCKET, WORKER_ID};

type IpcHandler = Box<dyn Fn(UnixStream, mpsc::Sender<Message>) + Send + Sync>;

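/// Global supervisor state, initialized lazily and guarded by a tokio `Mutex`.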
fn supervisor_state() -> &'static Mutex<SupervisorState> {
    static INSTANCE: OnceCell<Mutex<SupervisorState>> = OnceCell::new();
    INSTANCE.get_or_init(|| Mutex::new(SupervisorState { workers: vec![] }))
}

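/// Control messages handled by the supervisor.
///
/// `Shutdown` stops the worker with the given id; `Spawn` starts a new
/// worker from the given task definition.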
pub enum Message {
    Shutdown {
        id: String,
    },

    Spawn {
        task: Task,
    },
}

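/// Worker definition: the command to run, its arguments and environment,
/// and the restart behaviour applied by the supervisor.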
#[derive(Debug, Clone)]
pub struct Task {
    cmd: String,
    args: Vec<String>,
    envs: HashMap<String, String>,
    daemon: bool,
    detached: bool,
    limit: usize,
    factor: usize,
}

impl Task {
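    /// Create a task for the given command.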
    pub fn new(cmd: &str) -> Self {
        Self {
            cmd: cmd.to_string(),
            args: Vec::new(),
            envs: HashMap::new(),
            daemon: false,
            detached: false,
            limit: 5,
            factor: 0,
        }
    }

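    /// Set the arguments passed to the command.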
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let args = args
            .into_iter()
            .map(|s| s.as_ref().to_string())
            .collect::<Vec<_>>();
        self.args = args;
        self
    }

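    /// Set the environment variables passed to the command.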
    pub fn envs<I, K, V>(mut self, vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let envs = vars
            .into_iter()
            .map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string()))
            .collect::<HashMap<_, _>>();
        self.envs = envs;
        self
    }

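    /// Mark the task as a daemon; daemon workers are restarted by the
    /// supervisor when they exit unexpectedly.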
    pub fn daemon(mut self, flag: bool) -> Self {
        self.daemon = flag;
        self
    }

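    /// Mark the task as detached; detached workers do not receive the
    /// supervisor socket path in their environment.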
    pub fn detached(mut self, flag: bool) -> Self {
        self.detached = flag;
        self
    }

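    /// Set the maximum number of restart attempts (default: 5).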
    pub fn retry_limit(mut self, limit: usize) -> Self {
        self.limit = limit;
        self
    }

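    /// Set the restart backoff factor; each restart is delayed by
    /// `attempts * factor` milliseconds (default: 0, no delay).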
    pub fn retry_factor(mut self, factor: usize) -> Self {
        self.factor = factor;
        self
    }

    fn retry(&self) -> Retry {
        Retry {
            limit: self.limit,
            factor: self.factor,
            attempts: 0,
        }
    }
}

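/// Restart counters and limits tracked for each worker.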
#[derive(Clone, Copy)]
struct Retry {
    limit: usize,
    factor: usize,
    attempts: usize,
}

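/// Builder for a [`Supervisor`].
///
/// A minimal usage sketch; the `worker` command and socket file name are
/// placeholders, not values defined by this module:
///
/// ```ignore
/// let (_shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
/// let mut supervisor = SupervisorBuilder::new()
///     .path(std::env::temp_dir().join("my-app.sock"))
///     .server(|stream, _control| {
///         // Handle the connected worker's UnixStream here.
///         drop(stream);
///     })
///     .add_worker(Task::new("worker").daemon(true))
///     .shutdown(shutdown_rx)
///     .build();
/// supervisor.run().await?;
/// ```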
pub struct SupervisorBuilder {
    socket: PathBuf,
    commands: Vec<Task>,
    ipc_handler: Option<IpcHandler>,
    shutdown: Option<oneshot::Receiver<()>>,
}

impl SupervisorBuilder {
    pub fn new() -> Self {
        let socket = std::env::temp_dir().join("psup.sock");
        Self {
            socket,
            commands: Vec::new(),
            ipc_handler: None,
            shutdown: None,
        }
    }

    pub fn server<F: 'static>(mut self, handler: F) -> Self
    where
        F: Fn(UnixStream, mpsc::Sender<Message>) + Send + Sync,
    {
        self.ipc_handler = Some(Box::new(handler));
        self
    }

    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.socket = path.as_ref().to_path_buf();
        self
    }

    pub fn add_worker(mut self, task: Task) -> Self {
        self.commands.push(task);
        self
    }

    pub fn shutdown(mut self, rx: oneshot::Receiver<()>) -> Self {
        self.shutdown = Some(rx);
        self
    }

    pub fn build(self) -> Supervisor {
        Supervisor {
            socket: self.socket,
            commands: self.commands,
            ipc_handler: self.ipc_handler.map(Arc::new),
            shutdown: self.shutdown,
        }
    }
}

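/// Supervisor that spawns and monitors worker processes and, when a server
/// handler is configured, listens on a Unix domain socket for worker
/// connections.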
pub struct Supervisor {
    socket: PathBuf,
    commands: Vec<Task>,
    ipc_handler: Option<Arc<IpcHandler>>,
    shutdown: Option<oneshot::Receiver<()>>,
}

impl Supervisor {
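    /// Start the supervisor: bind the IPC socket (when a server handler was
    /// configured) and spawn the configured workers.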
    pub async fn run(&mut self) -> Result<()> {
        if let Some(ref ipc_handler) = self.ipc_handler {
            let socket = self.socket.clone();
            let control_socket = self.socket.clone();

            let (control_tx, mut control_rx) = mpsc::channel::<Message>(1024);
            let (tx, rx) = oneshot::channel::<()>();
            let handler = Arc::clone(ipc_handler);

            if let Some(shutdown) = self.shutdown.take() {
                // When the shutdown signal fires, drain all workers and ask
                // each one to stop.
                tokio::spawn(async move {
                    let _ = shutdown.await;
                    let mut state = supervisor_state().lock().await;
                    let workers = state.workers.drain(..);
                    for worker in workers {
                        let tx = worker.shutdown.clone();
                        let _ = tx.send(worker).await;
                    }
                });
            }

            // Handle control messages coming from the IPC handler.
            tokio::spawn(async move {
                while let Some(msg) = control_rx.recv().await {
                    match msg {
                        Message::Shutdown { id } => {
                            let mut state = supervisor_state().lock().await;
                            let worker = state.remove(&id);
                            drop(state);
                            if let Some(worker) = worker {
                                let tx = worker.shutdown.clone();
                                let _ = tx.send(worker).await;
                            } else {
                                warn!("Could not find worker to shutdown with id: {}", id);
                            }
                        }
                        Message::Spawn { task } => {
                            let id = id();
                            let retry = task.retry();
                            spawn_worker(id, task, control_socket.clone(), retry);
                        }
                    }
                }
            });

            tokio::spawn(async move {
                listen(&socket, tx, handler, control_tx)
                    .await
                    .expect("Supervisor failed to bind to socket");
            });

            // Wait until the socket listener is ready before reporting.
            rx.await?;
            info!("Supervisor is listening on {}", self.socket.display());
        }

        for task in self.commands.iter() {
            self.spawn(task.clone());
        }

        Ok(())
    }

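    /// Spawn a worker for the given task and return its generated id.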
    pub fn spawn(&self, task: Task) -> String {
        let id = id();
        let retry = task.retry();
        spawn_worker(id.clone(), task, self.socket.clone(), retry);
        id
    }
}

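/// In-memory registry of the workers currently managed by the supervisor.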
struct SupervisorState {
    workers: Vec<WorkerState>,
}

impl SupervisorState {
    fn remove(&mut self, id: &str) -> Option<WorkerState> {
        let position = self.workers.iter().position(|w| w.id == id)?;
        Some(self.workers.swap_remove(position))
    }
}

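/// State tracked for a single running worker process.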
#[derive(Debug)]
struct WorkerState {
    task: Task,
    id: String,
    socket: PathBuf,
    pid: Option<u32>,
    reap: bool,
    shutdown: mpsc::Sender<WorkerState>,
}

impl PartialEq for WorkerState {
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.pid == other.pid
    }
}

impl Eq for WorkerState {}

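/// Restart a worker, backing off by `attempts * factor` milliseconds and
/// giving up once the retry limit is reached.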
async fn restart(worker: WorkerState, mut retry: Retry) {
    info!("Restarting worker {}", worker.id);
    retry.attempts += 1;

    if retry.attempts >= retry.limit {
        error!(
            "Failed to restart worker {}, exceeded retry limit {}",
            worker.id, retry.limit
        );
    } else {
        if retry.factor > 0 {
            let ms = retry.attempts * retry.factor;
            info!("Delay restart {}ms", ms);
            time::sleep(Duration::from_millis(ms as u64)).await;
        }
        spawn_worker(worker.id, worker.task, worker.socket, retry)
    }
}

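/// Generate a random identifier for a worker.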
pub fn id() -> String {
    let mut rng = rand::thread_rng();
    let mut hasher = DefaultHasher::new();
    hasher.write_usize(rng.gen());
    format!("{:x}", hasher.finish())
}

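/// Spawn a worker process and monitor it: restart daemon workers that die
/// unexpectedly and reap workers that were shut down on request.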
fn spawn_worker(id: String, task: Task, socket: PathBuf, retry: Retry) {
    tokio::task::spawn(async move {
        let mut envs = task.envs.clone();
        envs.insert(WORKER_ID.to_string(), id.clone());
        if !task.detached {
            envs.insert(
                SOCKET.to_string(),
                socket.to_string_lossy().into_owned(),
            );
        }

        let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<WorkerState>(1);

        info!("Spawn worker {} {}", &task.cmd, task.args.join(" "));

        let mut child = Command::new(task.cmd.clone())
            .args(task.args.clone())
            .envs(envs)
            .spawn()?;

        let pid = child.id();

        if let Some(pid) = pid {
            info!("Worker pid {}", pid);
        }

        // Register the worker in the global supervisor state.
        {
            let worker = WorkerState {
                id: id.clone(),
                task,
                socket,
                pid,
                reap: false,
                shutdown: shutdown_tx,
            };
            let mut state = supervisor_state().lock().await;
            state.workers.push(worker);
        }

        let mut reaping = false;

        loop {
            tokio::select!(
                // The child exited: remove it from the state and restart it
                // if it is a daemon that died unexpectedly.
                res = child.wait() => {
                    match res {
                        Ok(status) => {
                            let pid = pid.unwrap_or(0);
                            if !reaping {
                                if let Some(code) = status.code() {
                                    warn!("Worker process died: {} (code: {})", pid, code);
                                } else {
                                    warn!("Worker process died: {} ({})", pid, status);
                                }
                            }
                            let mut state = supervisor_state().lock().await;
                            let worker = state.remove(&id);
                            drop(state);
                            if let Some(worker) = worker {
                                info!("Removed child worker (id: {}, pid {})", worker.id, pid);
                                if !worker.reap && worker.task.daemon {
                                    restart(worker, retry).await;
                                }
                            } else if !reaping {
                                error!("Failed to remove stale worker for pid {}", pid);
                            }
                            break;
                        }
                        Err(e) => return Err(e),
                    }
                }
                // A shutdown was requested for this worker: mark it as
                // reaped and kill the child process.
                worker = shutdown_rx.recv() => {
                    if let Some(mut worker) = worker {
                        reaping = true;
                        info!("Shutdown worker {}", worker.id);
                        worker.reap = true;
                        child.kill().await?;
                    }
                }
            )
        }

        Ok::<(), io::Error>(())
    });
}

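/// Bind the supervisor socket, signal readiness on `tx`, and hand each
/// accepted connection to the IPC handler.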
async fn listen<P: AsRef<Path>>(
    socket: P,
    tx: Sender<()>,
    handler: Arc<IpcHandler>,
    control_tx: mpsc::Sender<Message>,
) -> Result<()> {
    let path = socket.as_ref();

    // Remove a stale socket file left over from a previous run.
    if path.exists() {
        std::fs::remove_file(path)?;
    }

    let listener = UnixListener::bind(path)?;
    tx.send(()).unwrap();

    loop {
        match listener.accept().await {
            Ok((stream, _addr)) => (handler)(stream, control_tx.clone()),
            Err(e) => {
                warn!("Supervisor failed to accept worker socket {}", e);
            }
        }
    }
}