use std::{
collections::{hash_map::DefaultHasher, HashMap},
hash::Hasher,
io,
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use tokio::{
net::{UnixListener, UnixStream},
process::Command,
sync::oneshot::{self, Sender},
sync::{mpsc, Mutex},
time,
};
use log::{error, info, warn};
use once_cell::sync::OnceCell;
use rand::Rng;
use super::{Result, SOCKET, WORKER_ID};
/// Boxed callback invoked for every accepted IPC connection; it receives the
/// connected stream and a sender for control messages addressed to the supervisor.
type IpcHandler = Box<dyn Fn(UnixStream, mpsc::Sender<Message>) + Send + Sync>;
/// Returns the process-wide worker registry, creating it (empty) on first access.
fn supervisor_state() -> &'static Mutex<SupervisorState> {
    static REGISTRY: OnceCell<Mutex<SupervisorState>> = OnceCell::new();
    REGISTRY.get_or_init(|| {
        let empty = SupervisorState {
            workers: Vec::new(),
        };
        Mutex::new(empty)
    })
}
/// Control messages consumed by the supervisor's control channel.
pub enum Message {
/// Requests that the registered worker with the given identifier be shut down.
Shutdown {
/// Identifier of the worker to stop.
id: String,
},
/// Requests that a new worker be spawned for the given task.
Spawn {
/// Description of the process to launch.
task: Task,
},
}
/// Description of a worker process and of how it should be supervised.
#[derive(Debug, Clone)]
pub struct Task {
// Program handed to `Command::new` when the worker is spawned.
cmd: String,
// Arguments passed to the program.
args: Vec<String>,
// Extra environment variables set on the child process.
envs: HashMap<String, String>,
// When true, the worker is restarted after its process exits (unless it is being reaped).
daemon: bool,
// When true, the supervisor socket path is not exported to the child's environment.
detached: bool,
// Restarting stops once the attempt counter reaches this value.
limit: usize,
// Multiplied by the attempt counter to get the restart delay in milliseconds; 0 means no delay.
factor: usize,
}
impl Task {
pub fn new(cmd: &str) -> Self {
Self {
cmd: cmd.to_string(),
args: Vec::new(),
envs: HashMap::new(),
daemon: false,
detached: false,
limit: 5,
factor: 0,
}
}
pub fn args<I, S>(mut self, args: I) -> Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let args = args
.into_iter()
.map(|s| s.as_ref().to_string())
.collect::<Vec<_>>();
self.args = args;
self
}
pub fn envs<I, K, V>(mut self, vars: I) -> Self
where
I: IntoIterator<Item = (K, V)>,
K: AsRef<str>,
V: AsRef<str>,
{
let envs = vars
.into_iter()
.map(|(k, v)| (k.as_ref().to_string(), v.as_ref().to_string()))
.collect::<HashMap<_, _>>();
self.envs = envs;
self
}
pub fn daemon(mut self, flag: bool) -> Self {
self.daemon = flag;
self
}
pub fn detached(mut self, flag: bool) -> Self {
self.detached = flag;
self
}
pub fn retry_limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
pub fn retry_factor(mut self, factor: usize) -> Self {
self.factor = factor;
self
}
fn retry(&self) -> Retry {
Retry {
limit: self.limit,
factor: self.factor,
attempts: 0,
}
}
}
/// Restart bookkeeping carried along with a worker across restarts.
#[derive(Clone, Copy)]
struct Retry {
// Restarting stops once `attempts` reaches this value.
limit: usize,
// Multiplied by `attempts` to get the delay before a restart, in milliseconds.
factor: usize,
// Number of restarts attempted so far.
attempts: usize,
}
/// Builder used to configure and create a `Supervisor`.
pub struct SupervisorBuilder {
// Path of the Unix socket used for IPC.
socket: PathBuf,
// Tasks to spawn as workers when the supervisor runs.
commands: Vec<Task>,
// Optional callback for accepted IPC connections.
ipc_handler: Option<IpcHandler>,
// Optional receiver for a shutdown signal.
shutdown: Option<oneshot::Receiver<()>>,
}
impl SupervisorBuilder {
pub fn new() -> Self {
let socket = std::env::temp_dir().join("psup.sock");
Self {
socket,
commands: Vec::new(),
ipc_handler: None,
shutdown: None,
}
}
pub fn server<F: 'static>(mut self, handler: F) -> Self
where
F: Fn(UnixStream, mpsc::Sender<Message>) + Send + Sync,
{
self.ipc_handler = Some(Box::new(handler));
self
}
pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.socket = path.as_ref().to_path_buf();
self
}
pub fn add_worker(mut self, task: Task) -> Self {
self.commands.push(task);
self
}
pub fn shutdown(mut self, rx: oneshot::Receiver<()>) -> Self {
self.shutdown = Some(rx);
self
}
pub fn build(self) -> Supervisor {
Supervisor {
socket: self.socket,
commands: self.commands,
ipc_handler: self.ipc_handler.map(Arc::new),
shutdown: self.shutdown,
}
}
}
/// Supervisor that spawns worker processes and optionally serves an IPC socket.
pub struct Supervisor {
// Path of the Unix socket; also handed to spawned workers.
socket: PathBuf,
// Tasks spawned as workers by `run`.
commands: Vec<Task>,
// Shared callback for accepted IPC connections; `run` only starts the listener when set.
ipc_handler: Option<Arc<IpcHandler>>,
// Shutdown signal receiver; `run` takes it out of the struct when it installs the shutdown task.
shutdown: Option<oneshot::Receiver<()>>,
}
impl Supervisor {
pub async fn run(&mut self) -> Result<()> {
if let Some(ref ipc_handler) = self.ipc_handler {
let socket = self.socket.clone();
let control_socket = self.socket.clone();
let (control_tx, mut control_rx) = mpsc::channel::<Message>(1024);
let (tx, rx) = oneshot::channel::<()>();
let handler = Arc::clone(ipc_handler);
if let Some(shutdown) = self.shutdown.take() {
tokio::spawn(async move {
let _ = shutdown.await;
let mut state = supervisor_state().lock().await;
let workers = state.workers.drain(..);
for worker in workers {
let tx = worker.shutdown.clone();
let _ = tx.send(worker).await;
}
});
}
tokio::spawn(async move {
while let Some(msg) = control_rx.recv().await {
match msg {
Message::Shutdown { id } => {
let mut state = supervisor_state().lock().await;
let mut worker = state.remove(&id);
drop(state);
if let Some(worker) = worker.take() {
let tx = worker.shutdown.clone();
let _ = tx.send(worker).await;
} else {
warn!("Could not find worker to shutdown with id: {}", id);
}
}
Message::Spawn { task } => {
let id = id();
let retry = task.retry();
spawn_worker(
id,
task,
control_socket.clone(),
retry,
);
}
}
}
});
tokio::spawn(async move {
listen(&socket, tx, handler, control_tx)
.await
.expect("Supervisor failed to bind to socket");
});
let _ = rx.await?;
info!("Supervisor is listening {}", self.socket.display());
}
for task in self.commands.iter() {
self.spawn(task.clone());
}
Ok(())
}
pub fn spawn(&self, task: Task) -> String {
let id = id();
let retry = task.retry();
spawn_worker(id.clone(), task, self.socket.clone(), retry);
id
}
}
/// Registry of the workers currently known to the supervisor.
struct SupervisorState {
// One record per registered worker, in no particular order.
workers: Vec<WorkerState>,
}
impl SupervisorState {
    /// Removes and returns the first worker whose identifier equals `id`, or `None` when no
    /// such worker is registered. Removal swaps the last element into the vacated slot, so
    /// the order of the remaining workers is not preserved.
    fn remove(&mut self, id: &str) -> Option<WorkerState> {
        let index = self.workers.iter().position(|worker| worker.id == id)?;
        Some(self.workers.swap_remove(index))
    }
}
/// Record kept for every running worker.
#[derive(Debug)]
struct WorkerState {
// Task the worker was spawned from; reused when the worker is restarted.
task: Task,
// Identifier assigned to the worker.
id: String,
// Supervisor socket path the worker was spawned with.
socket: PathBuf,
// Process id reported for the child, if any.
pid: Option<u32>,
// Registered as `false`; a daemon worker is only restarted while this is `false`.
reap: bool,
// Channel on which this record is sent to the worker's task to request a shutdown.
shutdown: mpsc::Sender<WorkerState>,
}
/// Worker records compare equal when both their identifier and their process id match.
impl PartialEq for WorkerState {
    fn eq(&self, other: &Self) -> bool {
        (&self.id, &self.pid) == (&other.id, &other.pid)
    }
}

impl Eq for WorkerState {}
/// Attempts to restart `worker`.
///
/// The attempt counter is incremented first. If it has reached the retry limit an error is
/// logged and nothing is spawned. Otherwise, when the retry factor is non-zero, the function
/// sleeps for `attempts * factor` milliseconds before spawning the worker again under the same
/// identifier, task and socket.
async fn restart(worker: WorkerState, mut retry: Retry) {
    info!("Restarting worker {}", worker.id);
    retry.attempts += 1;

    if retry.attempts >= retry.limit {
        error!(
            "Failed to restart worker {}, exceeded retry limit {}",
            worker.id, retry.limit
        );
        return;
    }

    if retry.factor > 0 {
        let ms = retry.attempts * retry.factor;
        info!("Delay restart {}ms", ms);
        time::sleep(Duration::from_millis(ms as u64)).await;
    }

    spawn_worker(worker.id, worker.task, worker.socket, retry)
}
/// Generates a worker identifier: the lowercase hexadecimal hash of a random `usize`.
pub fn id() -> String {
    let seed: usize = rand::thread_rng().gen();
    let mut hasher = DefaultHasher::new();
    hasher.write_usize(seed);
    let digest = hasher.finish();
    format!("{:x}", digest)
}
fn spawn_worker(id: String, task: Task, socket: PathBuf, retry: Retry) {
tokio::task::spawn(async move {
let mut envs = task.envs.clone();
envs.insert(WORKER_ID.to_string(), id.clone());
if !task.detached {
envs.insert(
SOCKET.to_string(),
socket.to_string_lossy().into_owned(),
);
}
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<WorkerState>(1);
info!("Spawn worker {} {}", &task.cmd, task.args.join(" "));
let mut child = Command::new(task.cmd.clone())
.args(task.args.clone())
.envs(envs)
.spawn()?;
let pid = child.id();
if let Some(ref id) = pid {
info!("Worker pid {}", id);
}
{
let worker = WorkerState {
id: id.clone(),
task,
socket,
pid,
reap: false,
shutdown: shutdown_tx,
};
let mut state = supervisor_state().lock().await;
state.workers.push(worker);
}
let mut reaping = false;
loop {
tokio::select!(
res = child.wait() => {
match res {
Ok(status) => {
let pid = pid.unwrap_or(0);
if !reaping {
if let Some(code) = status.code() {
warn!("Worker process died: {} (code: {})", pid, code);
} else {
warn!("Worker process died: {} ({})", pid, status);
}
}
let mut state = supervisor_state().lock().await;
let worker = state.remove(&id);
drop(state);
if let Some(worker) = worker {
info!("Removed child worker (id: {}, pid {})", worker.id, pid);
if !worker.reap && worker.task.daemon {
restart(worker, retry).await;
}
} else {
if !reaping {
error!("Failed to remove stale worker for pid {}", pid);
}
}
break;
}
Err(e) => return Err(e),
}
}
mut worker = shutdown_rx.recv() => {
if let Some(mut worker) = worker.take() {
reaping = true;
info!("Shutdown worker {}", worker.id);
worker.reap = true;
child.kill().await?;
}
}
)
}
Ok::<(), io::Error>(())
});
}
async fn listen<P: AsRef<Path>>(
socket: P,
tx: Sender<()>,
handler: Arc<IpcHandler>,
control_tx: mpsc::Sender<Message>,
) -> Result<()> {
let path = socket.as_ref();
if path.exists() {
std::fs::remove_file(path)?;
}
let listener = UnixListener::bind(socket).unwrap();
tx.send(()).unwrap();
loop {
match listener.accept().await {
Ok((stream, _addr)) => (handler)(stream, control_tx.clone()),
Err(e) => {
warn!("Supervisor failed to accept worker socket {}", e);
}
}
}
}