use std::sync::atomic::{AtomicBool, Ordering};
use std::{sync::Arc, time::Duration};
use tokio::{sync::Notify, sync::broadcast, sync::mpsc, time::timeout};
use tokio_util::sync::CancellationToken;
#[cfg(feature = "controller")]
use tokio::sync::OnceCell;
use crate::core::{
alive::AliveTracker,
builder::SupervisorBuilder,
registry::{Registry, RegistryCommand},
};
use crate::{
core::SupervisorConfig,
error::RuntimeError,
events::{Bus, Event, EventKind},
identity::TaskId,
subscribers::{Subscribe, SubscriberSet},
tasks::TaskSpec,
};
/// Runtime coordinator that ties together the event bus, the task registry,
/// the alive tracker and the subscriber fan-out.
///
/// Instances are created through [`Supervisor::new`] or [`Supervisor::builder`];
/// `start` spawns the background listeners and `shutdown` / `run` tear them down.
pub struct Supervisor {
/// Supervisor configuration; `cfg.grace` is used as the drain and default cancel timeout.
cfg: SupervisorConfig,
/// Event bus used to publish lifecycle events and to create broadcast receivers.
bus: Bus,
/// Subscriber set that receives every event forwarded by the subscriber listener
/// and is closed at the end of shutdown.
subs: Arc<SubscriberSet>,
/// Tracker updated from each bus event; backs `snapshot` and `is_alive`.
alive: Arc<AliveTracker>,
/// Task registry queried for membership/termination and asked to cancel tasks on drain.
registry: Arc<Registry>,
/// Notifier whose current waiters are woken once `start` has spawned the listeners.
ready: Arc<Notify>,
/// Token cancelled during shutdown; the subscriber listener drains and exits when it fires.
runtime_token: CancellationToken,
/// Set by the first `start` call so that later calls are no-ops.
started: AtomicBool,
/// Sender for `RegistryCommand::Add` / `RegistryCommand::Remove` requests.
/// NOTE(review): presumably consumed by the registry listener — confirm in `Registry`.
cmd_tx: mpsc::UnboundedSender<RegistryCommand>,
/// Join handle of the spawned subscriber listener; taken and awaited during shutdown.
subscriber_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
/// Controller slot that can be set at most once; `submit` / `try_submit` report
/// `NotConfigured` while it is empty.
#[cfg(feature = "controller")]
pub(super) controller: OnceCell<Arc<crate::controller::Controller>>,
}
impl std::fmt::Debug for Supervisor {
    /// Prints only the configuration and the `started` flag; every other
    /// field is elided through `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let started = self.started.load(Ordering::Relaxed);
        let mut out = f.debug_struct("Supervisor");
        out.field("cfg", &self.cfg);
        out.field("started", &started);
        out.finish_non_exhaustive()
    }
}
impl Supervisor {
/// Assembles a supervisor from already-constructed components.
///
/// The supervisor starts in the "not started" state with a fresh readiness
/// notifier, no subscriber-listener handle and (with the `controller`
/// feature) an empty controller slot.
pub(crate) fn new_internal(
    cfg: SupervisorConfig,
    bus: Bus,
    subs: Arc<SubscriberSet>,
    alive: Arc<AliveTracker>,
    registry: Arc<Registry>,
    runtime_token: CancellationToken,
    cmd_tx: mpsc::UnboundedSender<RegistryCommand>,
) -> Self {
    // State owned by the supervisor itself rather than injected by the caller.
    let ready = Arc::new(Notify::new());
    let started = AtomicBool::new(false);
    let subscriber_handle = std::sync::Mutex::new(None);

    Self {
        cfg,
        bus,
        subs,
        alive,
        registry,
        ready,
        runtime_token,
        started,
        cmd_tx,
        subscriber_handle,
        #[cfg(feature = "controller")]
        controller: OnceCell::new(),
    }
}
pub fn new(cfg: SupervisorConfig, subscribers: Vec<Arc<dyn Subscribe>>) -> Arc<Self> {
Self::builder(cfg).with_subscribers(subscribers).build()
}
/// Returns a [`SupervisorBuilder`] seeded with `cfg`.
pub fn builder(cfg: SupervisorConfig) -> SupervisorBuilder {
SupervisorBuilder::new(cfg)
}
/// Starts the background listeners (a no-op if already started) and returns
/// a handle that shares ownership of this supervisor.
pub fn serve(self: &Arc<Self>) -> super::handle::SupervisorHandle {
    self.start();
    let shared = Arc::clone(self);
    super::handle::SupervisorHandle::new(shared)
}
/// Requests that `spec` be added to the registry and returns its freshly
/// allocated id.
///
/// A `TaskAddRequested` event is published first, then the `Add` command is
/// queued. A returned id only means the request was queued.
///
/// # Errors
///
/// Returns [`RuntimeError::ShuttingDown`] when the command channel is closed.
pub(crate) fn add_task(&self, spec: TaskSpec) -> Result<TaskId, RuntimeError> {
    let id = TaskId::next();

    let requested = Event::new(EventKind::TaskAddRequested)
        .with_task(spec.task().name())
        .with_id(id);
    self.bus.publish(requested);

    match self.cmd_tx.send(RegistryCommand::Add(id, spec)) {
        Ok(()) => Ok(id),
        Err(_) => Err(RuntimeError::ShuttingDown),
    }
}
/// Requests removal of task `id` without waiting for it to complete.
///
/// Publishes `TaskRemoveRequested` and then queues the `Remove` command.
///
/// # Errors
///
/// Returns [`RuntimeError::ShuttingDown`] when the command channel is closed.
pub(crate) fn remove(&self, id: TaskId) -> Result<(), RuntimeError> {
    let requested = Event::new(EventKind::TaskRemoveRequested).with_id(id);
    self.bus.publish(requested);

    match self.cmd_tx.send(RegistryCommand::Remove(id)) {
        Ok(()) => Ok(()),
        Err(_) => Err(RuntimeError::ShuttingDown),
    }
}
/// Returns the registry's current task listing as `(id, name)` pairs.
///
/// Thin delegation to the registry's `list`.
pub(crate) async fn list_tasks(&self) -> Vec<(TaskId, Arc<str>)> {
self.registry.list().await
}
/// Reports whether the registry currently contains task `id`.
pub(crate) async fn contains_id(&self, id: TaskId) -> bool {
self.registry.contains(id).await
}
/// Asks the registry for the task id associated with label `name`.
///
/// NOTE(review): `None` presumably means no task is registered under that
/// label — confirm against `Registry::id_for_label`.
pub(crate) async fn id_for_label(&self, name: &str) -> Option<TaskId> {
self.registry.id_for_label(name).await
}
/// Creates a new broadcast receiver on the event bus.
pub(crate) fn subscribe_bus(&self) -> broadcast::Receiver<Arc<Event>> {
self.bus.subscribe()
}
pub(crate) fn start(&self) {
if self.started.swap(true, Ordering::AcqRel) {
return;
}
self.subscriber_listener();
self.registry.clone().spawn_listener();
self.ready.notify_waiters();
}
/// Starts the supervisor, submits `tasks`, and then blocks until shutdown.
///
/// When `tasks` is non-empty the bus is subscribed *before* the first add
/// request so no confirmation can be missed. The method then waits for each
/// task to be confirmed before entering the shutdown-driving phase.
///
/// # Errors
///
/// Returns [`RuntimeError::ShuttingDown`] as soon as one add request cannot
/// be queued; otherwise returns the result of the shutdown-driving phase.
pub async fn run(&self, tasks: Vec<TaskSpec>) -> Result<(), RuntimeError> {
    self.start();

    if tasks.is_empty() {
        return self.drive_shutdown().await;
    }

    let mut confirmations = self.bus.subscribe();
    // Stops at the first failed add, exactly like a `?` inside a loop.
    let submitted = tasks
        .into_iter()
        .map(|spec| self.add_task(spec))
        .collect::<Result<Vec<_>, _>>()?;
    self.wait_tasks_registered(&mut confirmations, &submitted)
        .await;

    self.drive_shutdown().await
}
async fn wait_tasks_registered(
&self,
rx: &mut broadcast::Receiver<Arc<Event>>,
ids: &[TaskId],
) {
let mut pending: Vec<TaskId> = ids.to_vec();
let confirm = async {
while !pending.is_empty() {
match rx.recv().await {
Ok(ev)
if matches!(ev.kind, EventKind::TaskAdded | EventKind::TaskAddFailed) =>
{
if let Some(id) = ev.id {
pending.retain(|p| *p != id);
}
}
Ok(_) => {}
Err(broadcast::error::RecvError::Lagged(_)) => {
let mut still = Vec::new();
for id in std::mem::take(&mut pending) {
if !self.registry.contains(id).await {
still.push(id);
}
}
pending = still;
}
Err(broadcast::error::RecvError::Closed) => break,
}
}
};
const CONFIRM_BACKSTOP: Duration = Duration::from_secs(5);
let _ = timeout(CONFIRM_BACKSTOP, confirm).await;
}
/// Requests a graceful shutdown and tears the runtime down.
///
/// Publishes `ShutdownRequested`, drains tasks within the configured grace
/// period and then — whatever the drain outcome — cancels the runtime token,
/// waits for the subscriber listener to finish and closes the subscriber set.
///
/// # Errors
///
/// Returns the error produced by the drain step
/// ([`RuntimeError::GraceExceeded`] when tasks were still stuck after the
/// grace period).
pub(crate) async fn shutdown(&self) -> Result<(), RuntimeError> {
self.bus.publish(Event::new(EventKind::ShutdownRequested));
// Hold on to the drain result so teardown always runs before it is returned.
let res = self.drain_with_grace().await;
// Cancelling the token lets the subscriber listener drain and exit.
self.runtime_token.cancel();
self.join_subscriber_listener().await;
// Close subscribers only after the listener has stopped forwarding events.
self.subs.close().await;
res
}
/// Returns the alive tracker's current snapshot.
///
/// NOTE(review): presumably the names of tasks currently considered alive —
/// confirm against `AliveTracker::snapshot`.
pub(crate) async fn snapshot(&self) -> Vec<Arc<str>> {
self.alive.snapshot().await
}
/// Asks the alive tracker whether the task called `name` is alive.
pub(crate) async fn is_alive(&self, name: &str) -> bool {
self.alive.is_alive(name).await
}
/// Cancels task `id`, waiting up to the configured grace period for removal.
///
/// Equivalent to `cancel_with_timeout(id, cfg.grace)`; see that method for the
/// meaning of the returned flag and the possible errors.
pub(crate) async fn cancel(&self, id: TaskId) -> Result<bool, RuntimeError> {
self.cancel_with_timeout(id, self.cfg.grace).await
}
/// Cancels task `id` and waits up to `wait_for` for its removal.
///
/// Returns `Ok(false)` when the registry does not contain the task and
/// `Ok(true)` once removal is confirmed.
///
/// # Errors
///
/// * [`RuntimeError::ShuttingDown`] — the command channel is closed.
/// * [`RuntimeError::TaskRemoveTimeout`] — removal was not confirmed in time.
pub(crate) async fn cancel_with_timeout(
    &self,
    id: TaskId,
    wait_for: Duration,
) -> Result<bool, RuntimeError> {
    // Subscribe before doing anything else so the removal event cannot be missed.
    let mut rx = self.bus.subscribe();

    let known = self.registry.contains(id).await;
    if !known {
        return Ok(false);
    }

    let request = Event::new(EventKind::TaskRemoveRequested)
        .with_id(id)
        .with_reason("manual_cancel");
    self.bus.publish(request);

    if self.cmd_tx.send(RegistryCommand::Remove(id)).is_err() {
        return Err(RuntimeError::ShuttingDown);
    }

    self.wait_task_removed(&mut rx, id, wait_for).await
}
/// Forwards `spec` to the controller, awaiting its asynchronous `submit`.
///
/// # Errors
///
/// Returns `ControllerError::NotConfigured` when no controller has been
/// installed; otherwise returns whatever the controller handle reports.
#[cfg(feature = "controller")]
pub(crate) async fn submit(
    &self,
    spec: crate::controller::ControllerSpec,
) -> Result<(), crate::controller::ControllerError> {
    let Some(ctrl) = self.controller.get() else {
        return Err(crate::controller::ControllerError::NotConfigured);
    };
    ctrl.handle().submit(spec).await
}
/// Synchronous counterpart of `submit`: forwards `spec` to the controller
/// handle's `try_submit`.
///
/// # Errors
///
/// Returns `ControllerError::NotConfigured` when no controller has been
/// installed; otherwise returns whatever the controller handle reports.
#[cfg(feature = "controller")]
pub(crate) fn try_submit(
    &self,
    spec: crate::controller::ControllerSpec,
) -> Result<(), crate::controller::ControllerError> {
    let Some(ctrl) = self.controller.get() else {
        return Err(crate::controller::ControllerError::NotConfigured);
    };
    ctrl.handle().try_submit(spec)
}
fn subscriber_listener(&self) {
let mut rx = self.bus.subscribe();
let set = Arc::clone(&self.subs);
let alive = Arc::clone(&self.alive);
let rt = self.runtime_token.clone();
let handle = tokio::spawn(async move {
loop {
tokio::select! {
biased;
msg = rx.recv() => match msg {
Ok(arc_ev) => {
alive.update(&arc_ev).await;
set.emit_arc(arc_ev);
}
Err(broadcast::error::RecvError::Lagged(skipped)) => {
let e = Event::new(EventKind::SubscriberOverflow)
.with_task("subscriber_listener")
.with_reason(format!("lagged({skipped})"));
let arc_e = Arc::new(e);
alive.update(&arc_e).await;
set.emit_arc(arc_e);
}
Err(broadcast::error::RecvError::Closed) => break,
},
_ = rt.cancelled() => {
while let Ok(arc_ev) = rx.try_recv() {
alive.update(&arc_ev).await;
set.emit_arc(arc_ev);
}
break;
}
}
}
});
*self.subscriber_handle.lock().unwrap() = Some(handle);
}
/// Waits for the subscriber listener task to finish, if one was spawned and
/// has not been joined yet.
///
/// The handle is taken out of its slot, so a second call is a no-op. The
/// task's join result is ignored.
async fn join_subscriber_listener(&self) {
    // Release the mutex guard before awaiting: it must not live across `.await`.
    let taken = {
        let mut slot = self.subscriber_handle.lock().unwrap();
        slot.take()
    };
    let Some(listener) = taken else {
        return;
    };
    let _ = listener.await;
}
/// Blocks until the runtime should stop, then tears it down.
///
/// Two conditions are raced:
/// * a shutdown signal arrives — `ShutdownRequested` is published and tasks
///   are drained within the grace period;
/// * the registry becomes empty — nothing is left to drain, so the result is `Ok`.
///
/// Either way the runtime token is cancelled, the subscriber listener is
/// joined and the subscriber set is closed before returning.
///
/// # Errors
///
/// Returns the drain error ([`RuntimeError::GraceExceeded`]) when the
/// signal-triggered drain left tasks stuck.
async fn drive_shutdown(&self) -> Result<(), RuntimeError> {
let res = tokio::select! {
_ = crate::core::shutdown::wait_for_shutdown_signal() => {
self.bus.publish(Event::new(EventKind::ShutdownRequested));
self.drain_with_grace().await
}
_ = self.registry.wait_until_empty() => {
Ok(())
}
};
// Teardown is unconditional; `res` is returned only after it completes.
self.runtime_token.cancel();
self.join_subscriber_listener().await;
self.subs.close().await;
res
}
/// Asks the registry to cancel every task within the configured grace period
/// and publishes the outcome on the bus.
///
/// # Errors
///
/// Returns [`RuntimeError::GraceExceeded`], carrying the grace period and the
/// stuck tasks, when the registry reports tasks that did not stop in time.
async fn drain_with_grace(&self) -> Result<(), RuntimeError> {
    let grace = self.cfg.grace;
    let stuck = self.registry.cancel_all_within(grace).await;

    if !stuck.is_empty() {
        self.bus.publish(Event::new(EventKind::GraceExceeded));
        return Err(RuntimeError::GraceExceeded { grace, stuck });
    }

    self.bus
        .publish(Event::new(EventKind::AllStoppedWithinGrace));
    Ok(())
}
/// Waits up to `wait_for` for confirmation that task `id` was removed.
///
/// Confirmation is a `TaskRemoved` event for `id` on `rx`. When the receiver
/// lagged or the bus closed, the registry's termination state is consulted
/// instead. If the wait ends without confirmation, the registry is checked
/// one last time before giving up.
///
/// # Errors
///
/// Returns [`RuntimeError::TaskRemoveTimeout`] when neither the bus nor the
/// registry confirmed the removal.
async fn wait_task_removed(
    &self,
    rx: &mut broadcast::Receiver<Arc<Event>>,
    id: TaskId,
    wait_for: Duration,
) -> Result<bool, RuntimeError> {
    let observe_removal = async {
        loop {
            match rx.recv().await {
                Ok(event) => {
                    if event.id == Some(id) && matches!(event.kind, EventKind::TaskRemoved) {
                        return true;
                    }
                }
                Err(broadcast::error::RecvError::Lagged(_)) => {
                    // The removal event may be among the dropped ones.
                    if self.registry.is_terminated(id).await {
                        return true;
                    }
                }
                Err(broadcast::error::RecvError::Closed) => {
                    return self.registry.is_terminated(id).await;
                }
            }
        }
    };

    let seen_on_bus = matches!(timeout(wait_for, observe_removal).await, Ok(true));
    // The registry is only consulted when the bus did not confirm the removal.
    if seen_on_bus || self.registry.is_terminated(id).await {
        Ok(true)
    } else {
        Err(RuntimeError::TaskRemoveTimeout {
            id,
            timeout: wait_for,
        })
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{TaskFn, TaskRef};
#[tokio::test]
async fn shutdown_force_terminates_noncooperative_task_within_grace() {
let cfg = SupervisorConfig {
grace: Duration::from_millis(200),
..Default::default()
};
let sup = Supervisor::new(cfg, vec![]);
let handle = sup.serve();
let stubborn: TaskRef = TaskFn::arc("stubborn", |_ctx: CancellationToken| async move {
tokio::time::sleep(Duration::from_secs(30)).await;
Ok(())
});
handle
.add_and_wait(TaskSpec::once(stubborn), Duration::from_secs(1))
.await
.expect("task should register");
let res = tokio::time::timeout(Duration::from_secs(5), handle.shutdown())
.await
.expect("shutdown must return within grace, not block on the stuck task");
match res {
Err(RuntimeError::GraceExceeded { stuck, .. }) => {
assert!(
stuck.iter().any(|n| &**n == "stubborn"),
"stuck set must list the non-cooperative task, got {stuck:?}"
);
}
other => panic!("expected GraceExceeded listing the stuck task, got {other:?}"),
}
}
/// A task that exits as soon as its token is cancelled lets shutdown finish
/// with `Ok`.
#[tokio::test]
async fn shutdown_cooperative_task_returns_ok() {
    let supervisor = Supervisor::new(
        SupervisorConfig {
            grace: Duration::from_secs(5),
            ..Default::default()
        },
        vec![],
    );
    let handle = supervisor.serve();

    let cooperative: TaskRef = TaskFn::arc("good", |ctx: CancellationToken| async move {
        ctx.cancelled().await;
        Ok(())
    });
    handle
        .add_and_wait(TaskSpec::restartable(cooperative), Duration::from_secs(1))
        .await
        .expect("task should register");

    let res = timeout(Duration::from_secs(5), handle.shutdown())
        .await
        .expect("cooperative shutdown must not hang");
    let succeeded = res.is_ok();
    assert!(
        succeeded,
        "cooperative shutdown should be Ok, got {res:?}"
    );
}
/// Adding a second task under an already-registered name is rejected with
/// `TaskAlreadyExists`.
#[tokio::test]
async fn add_and_wait_duplicate_name_returns_already_exists() {
    let supervisor = Supervisor::new(SupervisorConfig::default(), vec![]);
    let handle = supervisor.serve();

    // Factory producing identically named tasks that run until cancelled.
    let make = || -> TaskRef {
        TaskFn::arc("dup", |ctx: CancellationToken| async move {
            ctx.cancelled().await;
            Ok(())
        })
    };

    let first = TaskSpec::restartable(make());
    handle
        .add_and_wait(first, Duration::from_secs(1))
        .await
        .expect("first add should succeed");

    let second = TaskSpec::restartable(make());
    let res = handle.add_and_wait(second, Duration::from_secs(1)).await;
    let rejected = matches!(res, Err(RuntimeError::TaskAlreadyExists { .. }));
    assert!(
        rejected,
        "duplicate add must return TaskAlreadyExists, got {res:?}"
    );

    let _ = handle.shutdown().await;
}
/// A receiver that lagged past the `TaskRemoved` event must still get
/// `Ok(true)` from `wait_task_removed`, via the registry fallback, instead of
/// a spurious timeout.
#[tokio::test]
async fn wait_task_removed_survives_bus_lag() {
// A small bus so that 32 noise events are enough to make a receiver lag.
let cfg = SupervisorConfig {
bus_capacity: 8,
..Default::default()
};
let sup = Supervisor::new(cfg, vec![]);
let handle = sup.serve();
let t: TaskRef = TaskFn::arc("laggy", |ctx: CancellationToken| async move {
ctx.cancelled().await;
Ok(())
});
let id = handle
.add_and_wait(TaskSpec::restartable(t), Duration::from_secs(1))
.await
.expect("add should succeed");
// Both receivers are created before the removal is requested:
// `lagged_rx` is deliberately never read until the end, `observer` is read promptly.
let mut lagged_rx = sup.subscribe_bus();
let mut observer = sup.subscribe_bus();
sup.remove(id).expect("remove should be accepted");
// Wait until the removal has actually happened, as seen by the healthy receiver.
let observed = timeout(Duration::from_secs(2), async {
loop {
if let Ok(ev) = observer.recv().await
&& ev.kind == EventKind::TaskRemoved
&& ev.id == Some(id)
{
return;
}
}
})
.await;
observed.expect("TaskRemoved must be observed by a healthy receiver");
// Flood the bus beyond its capacity so the unread receiver falls behind.
for _ in 0..32 {
sup.bus
.publish(Event::new(EventKind::TaskStarting).with_task("noise"));
}
// The outer timeout guards against a hang; the inner 300 ms is the wait under test.
let res = timeout(
Duration::from_secs(1),
sup.wait_task_removed(&mut lagged_rx, id, Duration::from_millis(300)),
)
.await
.expect("wait_task_removed must not hang");
assert!(
matches!(res, Ok(true)),
"a lagged receiver must fall back to runtime state instead of reporting \
a spurious TaskRemoveTimeout, got {res:?}"
);
let _ = handle.shutdown().await;
}
/// Cancelling a live task reports `true`; cancelling the same id again
/// reports `false` because the task is no longer registered.
#[tokio::test]
async fn cancel_reports_removed_then_absent() {
    let supervisor = Supervisor::new(SupervisorConfig::default(), vec![]);
    let handle = supervisor.serve();

    let task: TaskRef = TaskFn::arc("c", |ctx: CancellationToken| async move {
        ctx.cancelled().await;
        Ok(())
    });
    let spec = TaskSpec::restartable(task);
    let id = handle
        .add_and_wait(spec, Duration::from_secs(1))
        .await
        .expect("add should succeed");

    let first_cancel = handle.cancel(id).await.expect("cancel should not error");
    assert!(
        first_cancel,
        "cancelling an existing task should report removed=true"
    );

    let second_cancel = handle.cancel(id).await.expect("cancel should not error");
    assert!(
        !second_cancel,
        "cancelling a missing task should report false"
    );

    let _ = handle.shutdown().await;
}
}