use std::{
io, thread,
hash::Hasher,
path::{Path, PathBuf},
process::Command,
sync::Arc,
sync::Mutex,
collections::{hash_map::DefaultHasher, HashMap},
};
use tokio::net::{UnixListener, UnixStream};
use tokio::sync::oneshot::{self, Sender};
use log::{error, info, warn};
use once_cell::sync::OnceCell;
use rand::Rng;
use super::{Result, SOCKET, WORKER_ID, DETACHED};
/// Callback invoked with each Unix stream accepted by the supervisor's listener.
///
/// Boxed as a trait object and bounded by `Send + Sync` so it can be shared
/// behind an `Arc` with the listener task.
type IpcHandler = Box<dyn Fn(UnixStream) + Send + Sync>;
/// Returns the process-wide worker registry, creating it empty on first use.
///
/// The registry lives in a function-local static, so every caller (from any
/// thread) observes the same `Mutex<SupervisorState>`.
fn supervisor_state() -> &'static Mutex<SupervisorState> {
    static REGISTRY: OnceCell<Mutex<SupervisorState>> = OnceCell::new();
    REGISTRY.get_or_init(|| {
        let empty = SupervisorState {
            workers: Vec::new(),
        };
        Mutex::new(empty)
    })
}
/// Description of a worker process to launch: the program, its arguments,
/// extra environment variables and supervision flags.
///
/// Built fluently: start with [`Task::new`] and chain the setter methods.
#[derive(Debug, Clone)]
pub struct Task {
    /// Program to execute.
    cmd: String,
    /// Arguments passed to the program.
    args: Vec<String>,
    /// Additional environment variables for the child process.
    envs: HashMap<String, String>,
    /// When `true`, the supervisor restarts the worker after it exits.
    daemon: bool,
    /// Flag forwarded to the child through its environment.
    detached: bool,
}

impl Task {
    /// Creates a task that runs `cmd` with no arguments, no extra
    /// environment variables, and both flags disabled.
    pub fn new(cmd: &str) -> Self {
        Self {
            cmd: cmd.to_owned(),
            args: Vec::new(),
            envs: HashMap::new(),
            daemon: false,
            detached: false,
        }
    }

    /// Replaces the argument list with `args`.
    ///
    /// Any previously configured arguments are discarded, not appended to.
    pub fn args<I, S>(mut self, args: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.args = args
            .into_iter()
            .map(|arg| arg.as_ref().to_owned())
            .collect();
        self
    }

    /// Replaces the environment variable map with the pairs in `vars`.
    ///
    /// Any previously configured variables are discarded; when `vars`
    /// repeats a key, the last value wins.
    pub fn envs<I, K, V>(mut self, vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.envs = vars
            .into_iter()
            .map(|(key, value)| (key.as_ref().to_owned(), value.as_ref().to_owned()))
            .collect();
        self
    }

    /// Sets whether the worker should be restarted after it exits.
    pub fn daemon(self, flag: bool) -> Self {
        Self {
            daemon: flag,
            ..self
        }
    }

    /// Sets the detached flag that is forwarded to the worker.
    pub fn detached(self, flag: bool) -> Self {
        Self {
            detached: flag,
            ..self
        }
    }
}
/// Fluent builder that collects the socket path, the initial worker tasks
/// and the IPC handler for a [`Supervisor`].
pub struct SupervisorBuilder {
    /// Path of the Unix socket the supervisor will listen on.
    socket: PathBuf,
    /// Worker tasks to launch when the supervisor runs.
    commands: Vec<Task>,
    /// Callback handed every accepted connection.
    ipc_handler: IpcHandler,
}

impl SupervisorBuilder {
    /// Creates a builder with no workers whose socket defaults to
    /// `psup.sock` inside the system temporary directory.
    pub fn new(ipc_handler: IpcHandler) -> Self {
        Self {
            socket: std::env::temp_dir().join("psup.sock"),
            commands: Vec::new(),
            ipc_handler,
        }
    }

    /// Overrides the socket path.
    pub fn path(self, path: PathBuf) -> Self {
        Self {
            socket: path,
            ..self
        }
    }

    /// Appends `task` to the list of workers launched at startup.
    pub fn add_worker(mut self, task: Task) -> Self {
        self.commands.push(task);
        self
    }

    /// Consumes the builder and produces the configured [`Supervisor`].
    ///
    /// The handler is wrapped in an `Arc` so the listener can share it.
    pub fn build(self) -> Supervisor {
        let Self {
            socket,
            commands,
            ipc_handler,
        } = self;
        Supervisor {
            socket,
            commands,
            ipc_handler: Arc::new(ipc_handler),
        }
    }
}
/// Supervisor that listens on a Unix socket and launches worker processes.
pub struct Supervisor {
    /// Path of the Unix socket workers are told to connect to.
    socket: PathBuf,
    /// Worker tasks launched by [`Supervisor::run`].
    commands: Vec<Task>,
    /// Callback handed every accepted connection.
    ipc_handler: Arc<IpcHandler>,
}

impl Supervisor {
    /// Starts the socket listener on a background tokio task, waits until it
    /// reports that it is bound, then launches every configured worker.
    ///
    /// Returns once the workers have been handed off to their own threads;
    /// it does not wait for them to exit.
    ///
    /// # Errors
    ///
    /// Fails if the readiness channel closes before the listener signals.
    pub async fn run(&self) -> Result<()> {
        let (ready_tx, ready_rx) = oneshot::channel::<()>();
        let handler = Arc::clone(&self.ipc_handler);
        let listen_path = self.socket.clone();

        tokio::spawn(async move {
            let served = listen(&listen_path, ready_tx, handler).await;
            served.expect("Supervisor failed to bind to socket");
        });

        // Block until the listener signals that the socket is ready.
        ready_rx.await?;
        info!("Supervisor is listening {}", self.socket.display());

        self.commands
            .iter()
            .cloned()
            .for_each(|task| spawn_worker(task, self.socket.clone()));
        Ok(())
    }

    /// Launches one additional worker for `task` using this supervisor's socket.
    pub fn spawn(&self, task: Task) {
        let socket = self.socket.clone();
        spawn_worker(task, socket);
    }
}
/// Registry of the workers the supervisor has started.
struct SupervisorState {
    /// Tracked workers, in no particular order.
    workers: Vec<WorkerState>,
}

impl SupervisorState {
    /// Removes and returns the first tracked worker whose process id is
    /// `pid`, or `None` when no such worker is registered.
    ///
    /// Uses `swap_remove`, so the last worker takes the removed slot and
    /// the ordering of the remaining entries is not preserved.
    fn remove(&mut self, pid: u32) -> Option<WorkerState> {
        self.workers
            .iter()
            .position(|candidate| candidate.pid == pid)
            .map(|index| self.workers.swap_remove(index))
    }
}
/// Bookkeeping record for one running worker process.
#[derive(Debug)]
struct WorkerState {
    /// Task the worker was started from; reused when the worker is restarted.
    task: Task,
    /// Identifier generated for the worker when it was spawned.
    id: String,
    /// Socket path handed to the worker.
    socket: PathBuf,
    /// Operating-system process id of the child.
    pid: u32,
    /// When `true`, the worker is not restarted after it exits.
    reap: bool,
}

/// Two records are equal when both the worker id and the process id match;
/// the remaining fields do not take part in the comparison.
impl PartialEq for WorkerState {
    fn eq(&self, other: &Self) -> bool {
        let lhs = (&self.id, self.pid);
        let rhs = (&other.id, other.pid);
        lhs == rhs
    }
}

impl Eq for WorkerState {}
fn restart(worker: WorkerState) {
info!("Restarting worker {}", worker.id);
spawn_worker(worker.task, worker.socket)
}
fn spawn_worker(task: Task, socket: PathBuf) {
let mut rng = rand::thread_rng();
let mut hasher = DefaultHasher::new();
hasher.write_usize(rng.gen());
let id = format!("{:x}", hasher.finish());
thread::spawn(move || {
let mut envs = task.envs.clone();
envs.insert(WORKER_ID.to_string(), id.clone());
envs.insert(
SOCKET.to_string(),
socket.to_string_lossy().into_owned(),
);
envs.insert(DETACHED.to_string(), task.detached.to_string());
info!("Spawn worker {}", &task.cmd);
let child = Command::new(task.cmd.clone())
.args(task.args.clone())
.envs(envs)
.spawn()?;
let pid = child.id();
{
let worker = WorkerState {
task,
id,
socket,
pid,
reap: false,
};
let mut state = supervisor_state().lock().unwrap();
state.workers.push(worker);
}
let _ = child.wait_with_output()?;
warn!("Worker process died: {}", pid);
let mut state = supervisor_state().lock().unwrap();
let worker = state.remove(pid);
drop(state);
if let Some(worker) = worker {
info!("Removed child worker (id: {}, pid {})", worker.id, pid);
if !worker.reap && worker.task.daemon {
restart(worker);
}
} else {
error!("Failed to remove stale worker for pid {}", pid);
}
Ok::<(), io::Error>(())
});
}
async fn listen<P: AsRef<Path>>(
socket: P,
tx: Sender<()>,
handler: Arc<IpcHandler>,
) -> Result<()> {
let path = socket.as_ref();
if path.exists() {
std::fs::remove_file(path)?;
}
let listener = UnixListener::bind(socket).unwrap();
tx.send(()).unwrap();
loop {
match listener.accept().await {
Ok((stream, _addr)) => (handler)(stream),
Err(e) => {
warn!("Supervisor failed to accept worker socket {}", e);
}
}
}
}