use std::{
borrow::Cow,
sync::{
Arc, Mutex,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
time::Duration,
};
use futures::FutureExt as _;
use once_cell::sync::OnceCell;
use std::panic::AssertUnwindSafe;
use tokio::{
sync::mpsc,
task::{JoinHandle, JoinSet},
time::Instant,
};
#[cfg(feature = "tracing")]
use tracing::Instrument;
use crate::{CtrlFuture, ProcFuture, ProcessControlHandler, Runnable, RuntimeError};
/// Global counter handing out manager ids: `ProcessManager::new` takes the next
/// value, and the default `process_name` embeds it as `process-manager-{id}`.
static PID: AtomicUsize = AtomicUsize::new(0);
/// Bookkeeping for one spawned child process.
struct Child {
    /// Id taken from `Inner::next_id`; the same id tags the child's completion message.
    id: usize,
    // NOTE(review): `proc` is read by `Handle::shutdown` (for the log name), so
    // this `allow` looks unnecessary — verify before removing.
    #[allow(dead_code)]
    proc: Arc<dyn Runnable>,
    /// Control handle obtained from `proc.process_handle()` when the child was spawned.
    handle: Arc<dyn ProcessControlHandler>,
    /// Task running the child; shared via `Arc` so `Handle::shutdown` can poll
    /// `is_finished()` and `abort()` it.
    join_handle: Arc<JoinHandle<()>>,
}
/// Receiving end of the child-completion channel, carrying `(child id, outcome)`
/// pairs. `process_start` locks this tokio mutex and keeps the guard across
/// `.await` points while it drains messages.
type ProcessCompletionChannel =
    tokio::sync::Mutex<mpsc::UnboundedReceiver<(usize, Result<(), RuntimeError>)>>;
/// State shared between a `ProcessManager`, its `Handle`s and the child tasks.
struct Inner {
    /// All registered children. With `auto_cleanup`, children that finish with
    /// `Ok` are removed by `process_start`.
    processes: Mutex<Vec<Child>>,
    /// Control handles of the registered children; `Handle::reload` fans out over these.
    handles: Mutex<Vec<Arc<dyn ProcessControlHandler>>>,
    /// Set by `process_start` and cleared when no children remain active;
    /// `insert` requires it to be false, `add` requires it to be true.
    running: AtomicBool,
    /// Next child id to hand out.
    next_id: AtomicUsize,
    /// Children spawned whose completion message has not yet been processed.
    active: AtomicUsize,
    /// Sender cloned into every child task to report its outcome.
    completion_tx: mpsc::UnboundedSender<(usize, Result<(), RuntimeError>)>,
    /// Receiver side of `completion_tx`; populated once in `ProcessManager::new`.
    completion_rx: OnceCell<ProcessCompletionChannel>,
    /// How long `Handle::shutdown` waits for each child before aborting its task.
    shutdown_grace_period: Duration,
}
/// Supervises a group of `Runnable` children: spawns them, waits for them to
/// finish, and shuts all of them down when the first one fails.
pub struct ProcessManager {
    /// Unique manager id (from `PID`), used in the default name.
    id: usize,
    /// Processes queued with `insert`; spawned by `process_start`.
    pre_start: Vec<Arc<dyn Runnable>>,
    /// State shared with `Handle`s and child tasks.
    inner: Arc<Inner>,
    /// Replaces the default `process-manager-{id}` name when set.
    pub(crate) custom_name: Option<Cow<'static, str>>,
    /// When true, children that finish with `Ok` are dropped from bookkeeping.
    pub(crate) auto_cleanup: bool,
}
impl ProcessManager {
    /// Creates an idle manager.
    ///
    /// Each manager gets a unique id from the global `PID` counter, a
    /// 30-second shutdown grace period, and `auto_cleanup` enabled.
    pub fn new() -> Self {
        let id = PID.fetch_add(1, Ordering::SeqCst);
        let (tx, rx) = mpsc::unbounded_channel();
        Self {
            id,
            pre_start: Vec::new(),
            inner: Arc::new(Inner {
                processes: Mutex::new(Vec::new()),
                handles: Mutex::new(Vec::new()),
                running: AtomicBool::new(false),
                next_id: AtomicUsize::new(0),
                active: AtomicUsize::new(0),
                completion_tx: tx,
                completion_rx: {
                    let cell = OnceCell::new();
                    let _ = cell.set(tokio::sync::Mutex::new(rx));
                    cell
                },
                shutdown_grace_period: Duration::from_secs(30),
            }),
            custom_name: None,
            auto_cleanup: true,
        }
    }

    /// Time each child is given to stop after its shutdown is requested,
    /// before its task is aborted.
    pub fn shutdown_grace_period(&self) -> Duration {
        self.inner.shutdown_grace_period
    }

    /// Overrides the shutdown grace period.
    ///
    /// # Panics
    ///
    /// Panics if the shared state is no longer uniquely owned, e.g. after a
    /// handle has been obtained through `process_handle`.
    pub(crate) fn set_shutdown_grace_period(&mut self, duration: Duration) {
        let inner = Arc::get_mut(&mut self.inner)
            .expect("inner must be uniquely owned during manager setup");
        inner.shutdown_grace_period = duration;
    }

    /// Queues `process` to be spawned when the manager starts.
    ///
    /// # Panics
    ///
    /// Panics if the manager is already running; use [`ProcessManager::add`] then.
    pub fn insert(&mut self, process: impl Runnable) {
        assert!(
            !self.inner.running.load(Ordering::SeqCst),
            "cannot call insert() after manager has started – use add() instead"
        );
        self.pre_start
            .push(Arc::from(Box::new(process) as Box<dyn Runnable>));
    }

    /// Spawns `process` immediately as a child of the running manager.
    ///
    /// # Panics
    ///
    /// Panics if the manager is not running; use [`ProcessManager::insert`]
    /// before it starts.
    pub fn add(&self, process: impl Runnable) {
        assert!(
            self.inner.running.load(Ordering::SeqCst),
            "cannot call add() before manager has started – use insert() instead"
        );
        let proc: Arc<dyn Runnable> = Arc::from(Box::new(process) as Box<dyn Runnable>);
        let id = self.inner.next_id.fetch_add(1, Ordering::SeqCst);
        let handle = proc.process_handle();
        // Register the child and its handle while holding the `processes` lock,
        // matching `process_start` and the auto-cleanup path (processes, then
        // handles). Pushing the handle first and releasing the lock let a
        // concurrent auto-cleanup — which keeps only handles whose child is in
        // `processes` — discard the new handle before its child was registered,
        // hiding the process from `reload()`.
        let mut processes = self.inner.processes.lock().unwrap();
        processes.push(Child {
            id,
            proc: Arc::clone(&proc),
            handle: Arc::clone(&handle),
            join_handle: Arc::new(spawn_child(id, proc, Arc::clone(&self.inner))),
        });
        self.inner.handles.lock().unwrap().push(handle);
    }
}
impl Runnable for ProcessManager {
    /// Drives the manager until every child it spawned has reported back.
    ///
    /// Marks the manager as running, spawns every process queued with
    /// `insert`, then consumes completion messages one by one. The first child
    /// error triggers a shutdown of all children. Once no children remain
    /// active, the running flag is cleared and the first error (if any) is
    /// returned. If the completion channel closes while children are still
    /// active, `RuntimeError::Internal` is returned instead.
    fn process_start(&self) -> ProcFuture<'_> {
        let shared = Arc::clone(&self.inner);
        let queued = self.pre_start.clone();
        let self_handle = self.process_handle();
        let auto_cleanup = self.auto_cleanup;
        Box::pin(async move {
            shared.running.store(true, Ordering::SeqCst);
            let name = self.process_name();
            #[cfg(feature = "tracing")]
            ::tracing::info!("Start process manager {name}");
            #[cfg(all(not(feature = "tracing"), feature = "log"))]
            ::log::info!("Start process manager {name}");
            #[cfg(all(not(feature = "tracing"), not(feature = "log")))]
            eprintln!("Start process manager {name}");

            for proc in queued {
                let id = shared.next_id.fetch_add(1, Ordering::SeqCst);
                let handle = proc.process_handle();
                // Child and handle are registered under a single `processes` lock.
                let mut children = shared.processes.lock().unwrap();
                children.push(Child {
                    id,
                    proc: Arc::clone(&proc),
                    handle: Arc::clone(&handle),
                    join_handle: Arc::new(spawn_child(id, proc, Arc::clone(&shared))),
                });
                shared.handles.lock().unwrap().push(handle);
            }

            let mut receiver = shared
                .completion_rx
                .get()
                .expect("process_start called twice")
                .lock()
                .await;
            let mut failure: Option<RuntimeError> = None;

            // Child states are traced before every check of the active count.
            #[cfg(feature = "tracing")]
            trace_children(&shared);
            while shared.active.load(Ordering::SeqCst) != 0 {
                let Some((cid, res)) = receiver.recv().await else {
                    return Err(RuntimeError::Internal {
                        message: "completion channel closed unexpectedly".into(),
                    });
                };
                match res {
                    Ok(()) if auto_cleanup => {
                        let mut children = shared.processes.lock().unwrap();
                        children.retain(|child| child.id != cid);
                        shared
                            .handles
                            .lock()
                            .unwrap()
                            .retain(|h| children.iter().any(|child| Arc::ptr_eq(&child.handle, h)));
                    }
                    Ok(()) => {}
                    Err(err) => {
                        // Only the first failure is kept, and it alone triggers shutdown.
                        if failure.is_none() {
                            failure = Some(err);
                            self_handle.shutdown().await;
                        }
                    }
                }
                shared.active.fetch_sub(1, Ordering::SeqCst);
                #[cfg(feature = "tracing")]
                trace_children(&shared);
            }
            shared.running.store(false, Ordering::SeqCst);

            if let Some(error) = failure {
                #[cfg(feature = "tracing")]
                ::tracing::warn!("Shutdown process manager {name} with error: {error:?}");
                #[cfg(all(not(feature = "tracing"), feature = "log"))]
                ::log::warn!("Shutdown process manager {name} with error: {error:?}");
                #[cfg(all(not(feature = "tracing"), not(feature = "log")))]
                eprintln!("Shutdown process manager {name} with error: {error:?}");
                return Err(error);
            }

            #[cfg(feature = "tracing")]
            ::tracing::info!("Shutdown process manager {name}");
            #[cfg(all(not(feature = "tracing"), feature = "log"))]
            ::log::info!("Shutdown process manager {name}");
            #[cfg(all(not(feature = "tracing"), not(feature = "log")))]
            eprintln!("Shutdown process manager {name}");
            Ok(())
        })
    }

    /// The custom name if one was set, otherwise `process-manager-{id}`.
    fn process_name(&self) -> Cow<'static, str> {
        match &self.custom_name {
            Some(name) => name.clone(),
            None => Cow::Owned(format!("process-manager-{}", self.id)),
        }
    }

    /// A control handle sharing this manager's state.
    fn process_handle(&self) -> Arc<dyn ProcessControlHandler> {
        let handle = Handle {
            inner: self.inner.clone(),
        };
        Arc::new(handle)
    }
}

/// Emits one debug record per registered child, reporting whether its task is still running.
#[cfg(feature = "tracing")]
fn trace_children(inner: &Inner) {
    for child in inner.processes.lock().unwrap().iter() {
        ::tracing::debug!(
            "Process {}: running={:?}",
            child.proc.process_name(),
            !child.join_handle.is_finished()
        );
    }
}
impl Default for ProcessManager {
fn default() -> Self {
Self::new()
}
}
/// Control handle for a `ProcessManager`; `shutdown` and `reload` fan out to the
/// children recorded in the shared `Inner`.
struct Handle {
    inner: Arc<Inner>,
}
impl ProcessControlHandler for Handle {
fn shutdown(&self) -> CtrlFuture<'_> {
let inner = Arc::clone(&self.inner);
Box::pin(async move {
let mut set = JoinSet::new();
let handles = {
let guard = inner.processes.lock().unwrap();
guard
.iter()
.map(|child| {
(
child.proc.process_name(),
child.handle.clone(),
child.join_handle.clone(),
)
})
.collect::<Vec<_>>()
};
let shutdown_grace_period = inner.shutdown_grace_period;
for (name, h, jh) in handles {
set.spawn(async move {
#[cfg(feature = "tracing")]
::tracing::info!(name = %name, "Initiate shutdown");
#[cfg(all(not(feature = "tracing"), feature = "log"))]
::log::info!("Initiate shutdown {name}");
#[cfg(all(not(feature = "tracing"), not(feature = "log")))]
eprintln!("Initiate shutdown {name}");
let dur = shutdown_grace_period;
let now = Instant::now();
let watched = Arc::clone(&jh);
let timeout = tokio::time::timeout(dur, async move {
h.shutdown().await;
while !watched.is_finished() {
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await;
let _elapsed = now.elapsed();
match timeout {
Ok(_) => {
#[cfg(feature = "tracing")]
::tracing::info!(name = %name, elapsed = ?_elapsed, "Shutdown completed");
#[cfg(all(not(feature = "tracing"), feature = "log"))]
::log::info!("Process {name}: shutdown completed");
#[cfg(all(not(feature = "tracing"), not(feature = "log")))]
eprintln!("Process {name}: shutdown completed");
}
Err(_) => {
jh.abort();
#[cfg(feature = "tracing")]
::tracing::info!(name = %name, elapsed = ?_elapsed, "Shutdown timed out");
#[cfg(all(not(feature = "tracing"), feature = "log"))]
::log::info!("Process {name}: Shutdown timed out after {dur:?}");
#[cfg(all(not(feature = "tracing"), not(feature = "log")))]
eprintln!("Process {name}: Shutdown timed out after {dur:?}");
}
}
});
}
let _ = set.join_all().await;
})
}
fn reload(&self) -> CtrlFuture<'_> {
let inner = Arc::clone(&self.inner);
Box::pin(async move {
let mut set = JoinSet::new();
let handles = {
let guard = inner.handles.lock().unwrap();
guard.clone()
};
for h in handles {
set.spawn(async move {
h.reload().await;
});
}
let _ = set.join_all().await;
})
}
}
fn spawn_child(id: usize, proc: Arc<dyn Runnable>, inner: Arc<Inner>) -> JoinHandle<()> {
inner.active.fetch_add(1, Ordering::SeqCst);
let tx = inner.completion_tx.clone();
tokio::spawn(async move {
let name = proc.process_name();
#[cfg(feature = "tracing")]
::tracing::info!(name = %name, "Start process");
#[cfg(all(not(feature = "tracing"), feature = "log"))]
::log::info!("Start process {name}");
#[cfg(all(not(feature = "tracing"), not(feature = "log")))]
eprintln!("Start process {name}");
let catch_fut = AssertUnwindSafe(proc.process_start()).catch_unwind();
#[cfg(feature = "tracing")]
let catch_result = {
let span = ::tracing::info_span!("process", name = %name);
catch_fut.instrument(span).await
};
#[cfg(not(feature = "tracing"))]
let catch_result = { catch_fut.await };
let res = catch_result.unwrap_or_else(|panic| {
let msg = if let Some(s) = panic.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = panic.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic".to_string()
};
Err(RuntimeError::Internal {
message: format!("process panicked: {msg}"),
})
});
match &res {
Ok(_) => {
#[cfg(feature = "tracing")]
::tracing::info!(name = %name, "Process stopped");
#[cfg(all(not(feature = "tracing"), feature = "log"))]
::log::info!("Process {name}: stopped");
#[cfg(all(not(feature = "tracing"), not(feature = "log")))]
eprintln!("Process {name}: stopped");
}
Err(err) => {
#[cfg(feature = "tracing")]
::tracing::error!(name = %name, "Process failed: {err:?}");
#[cfg(all(not(feature = "tracing"), feature = "log"))]
::log::error!("Process {name}: failed {err:?}");
#[cfg(all(not(feature = "tracing"), not(feature = "log")))]
eprintln!("Process {name}: failed {err:?}");
}
}
let _ = tx.send((id, res)); })
}