use std::sync::atomic::{AtomicBool, Ordering};
use std::{sync::Arc, time::Duration};
use tokio::{sync::Notify, sync::broadcast, sync::mpsc, time::timeout};
use tokio_util::sync::CancellationToken;
#[cfg(feature = "controller")]
use tokio::sync::OnceCell;
use crate::core::{
alive::AliveTracker,
builder::SupervisorBuilder,
registry::{Registry, RegistryCommand},
};
use crate::{
core::SupervisorConfig,
error::RuntimeError,
events::{Bus, Event, EventKind},
subscribers::{Subscribe, SubscriberSet},
tasks::TaskSpec,
};
/// Task supervisor: owns the event bus, the task registry, and the
/// subscriber fan-out, and coordinates task addition, cancellation and
/// graceful shutdown.
///
/// Instances are normally obtained through [`Supervisor::new`] or
/// [`Supervisor::builder`], both of which return an `Arc<Supervisor>`.
pub struct Supervisor {
    /// Supervisor configuration; `cfg.grace` bounds shutdown and default
    /// per-task cancellation waits.
    cfg: SupervisorConfig,
    /// Event bus; every lifecycle event is published here.
    bus: Bus,
    /// Subscribers that receive every bus event via the listener spawned in
    /// `start`; closed at the end of shutdown.
    subs: Arc<SubscriberSet>,
    /// Liveness tracker updated from every bus event; backs `snapshot` and
    /// `is_alive`.
    alive: Arc<AliveTracker>,
    /// Task registry; its listener is spawned on first `start`.
    registry: Arc<Registry>,
    /// Notified (via `notify_waiters`) on the first `start`.
    ready: Arc<Notify>,
    /// Cancelled as the final step of `shutdown` / `run`.
    runtime_token: CancellationToken,
    /// Set on the first `start` so listeners are spawned only once.
    started: AtomicBool,
    /// Sends `Add` / `Remove` commands toward the registry (presumably
    /// consumed by the registry listener — confirm in `registry`).
    cmd_tx: mpsc::UnboundedSender<RegistryCommand>,
    /// Optional controller, installed at most once; `submit`/`try_submit`
    /// report `NotConfigured` while it is unset.
    #[cfg(feature = "controller")]
    pub(super) controller: OnceCell<Arc<crate::controller::Controller>>,
}
impl std::fmt::Debug for Supervisor {
    // Only the configuration and the started flag are shown; channels,
    // trackers and tokens are elided via `finish_non_exhaustive`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Supervisor");
        out.field("cfg", &self.cfg);
        let is_started = self.started.load(Ordering::Relaxed);
        out.field("started", &is_started);
        out.finish_non_exhaustive()
    }
}
impl Supervisor {
    /// Constructs a supervisor from already-wired components.
    ///
    /// The supervisor is created in the "not started" state: no background
    /// listeners run until [`Supervisor::start`] (directly or via `serve` /
    /// `run`) is called. The readiness notifier and, with the `controller`
    /// feature, the controller slot start out empty.
    pub(crate) fn new_internal(
        cfg: SupervisorConfig,
        bus: Bus,
        subs: Arc<SubscriberSet>,
        alive: Arc<AliveTracker>,
        registry: Arc<Registry>,
        runtime_token: CancellationToken,
        cmd_tx: mpsc::UnboundedSender<RegistryCommand>,
    ) -> Self {
        Self {
            cfg,
            bus,
            subs,
            alive,
            registry,
            runtime_token,
            ready: Arc::new(Notify::new()),
            started: AtomicBool::new(false),
            cmd_tx,
            #[cfg(feature = "controller")]
            controller: OnceCell::new(),
        }
    }

    /// Creates a supervisor with `cfg` and the given event subscribers.
    ///
    /// Shorthand for `Supervisor::builder(cfg).with_subscribers(subscribers).build()`.
    pub fn new(cfg: SupervisorConfig, subscribers: Vec<Arc<dyn Subscribe>>) -> Arc<Self> {
        Self::builder(cfg).with_subscribers(subscribers).build()
    }

    /// Returns a [`SupervisorBuilder`] seeded with `cfg`.
    pub fn builder(cfg: SupervisorConfig) -> SupervisorBuilder {
        SupervisorBuilder::new(cfg)
    }

    /// Starts the supervisor (no-op if already started) and returns a handle
    /// that owns its own `Arc` clone of this supervisor.
    pub fn serve(self: &Arc<Self>) -> super::handle::SupervisorHandle {
        self.start();
        super::handle::SupervisorHandle::new(Arc::clone(self))
    }

    /// Requests that `spec` be added to the registry.
    ///
    /// Publishes `TaskAddRequested`, then sends `RegistryCommand::Add`.
    /// Registration itself happens asynchronously on the receiving side of
    /// the command channel.
    ///
    /// # Errors
    /// `RuntimeError::ShuttingDown` if the command channel is closed. The
    /// `TaskAddRequested` event has already been published in that case.
    pub(crate) fn add_task(&self, spec: TaskSpec) -> Result<(), RuntimeError> {
        self.bus
            .publish(Event::new(EventKind::TaskAddRequested).with_task(spec.task().name()));
        self.cmd_tx
            .send(RegistryCommand::Add(spec))
            .map_err(|_| RuntimeError::ShuttingDown)
    }

    /// Requests removal of the task called `name` without waiting for it.
    ///
    /// # Errors
    /// `RuntimeError::ShuttingDown` if the command channel is closed. The
    /// `TaskRemoveRequested` event has already been published in that case.
    pub(crate) fn remove_task(&self, name: &str) -> Result<(), RuntimeError> {
        self.bus
            .publish(Event::new(EventKind::TaskRemoveRequested).with_task(name));
        self.cmd_tx
            .send(RegistryCommand::Remove(Arc::from(name)))
            .map_err(|_| RuntimeError::ShuttingDown)
    }

    /// Returns the names of the tasks currently held by the registry.
    pub(crate) async fn list_tasks(&self) -> Vec<Arc<str>> {
        self.registry.list().await
    }

    /// Returns `true` if the registry currently holds a task called `name`.
    pub(crate) async fn registry_contains(&self, name: &str) -> bool {
        self.registry.contains(name).await
    }

    /// Returns a new receiver for all subsequent bus events.
    pub(crate) fn subscribe_bus(&self) -> broadcast::Receiver<Arc<Event>> {
        self.bus.subscribe()
    }

    /// Spawns the background listeners exactly once.
    ///
    /// The first call spawns the subscriber fan-out task and the registry
    /// listener, then wakes tasks waiting on the readiness [`Notify`].
    /// Later calls return immediately.
    ///
    /// Note: `Notify::notify_waiters` stores no permit. Only tasks already
    /// parked on `ready` at this moment are woken.
    pub(crate) fn start(&self) {
        if self.started.swap(true, Ordering::AcqRel) {
            return;
        }
        self.subscriber_listener();
        self.registry.clone().spawn_listener();
        self.ready.notify_waiters();
    }

    /// Starts the supervisor, registers `tasks`, and drives it to shutdown.
    ///
    /// After submitting every spec, waits for one `TaskAdded` event per spec
    /// (see `wait_tasks_registered` for caveats). It then returns when either:
    /// - a shutdown signal arrives (`core::shutdown::wait_for_shutdown_signal`), or
    /// - the registry becomes empty.
    ///
    /// # Errors
    /// - `RuntimeError::ShuttingDown` if a spec cannot be submitted. This is
    ///   returned immediately, without the shutdown cleanup.
    /// - `RuntimeError::GraceExceeded` if a signal-triggered shutdown does not
    ///   drain within `cfg.grace`.
    pub async fn run(&self, tasks: Vec<TaskSpec>) -> Result<(), RuntimeError> {
        self.start();
        let expected = tasks.len();
        if expected > 0 {
            // Subscribe before submitting so no `TaskAdded` can be missed.
            let mut rx = self.bus.subscribe();
            for spec in tasks {
                self.add_task(spec)?;
            }
            self.wait_tasks_registered(&mut rx, expected).await;
        }
        self.drive_shutdown().await
    }

    /// Waits until `expected` `TaskAdded` events have been received on `rx`.
    ///
    /// Returns early if the receiver lags or the bus closes.
    ///
    /// NOTE(review): `TaskAdded` events for *any* task are counted, not only
    /// the ones submitted by `run`. The early exit on `Lagged` means the
    /// registry may not yet hold every task when `drive_shutdown` starts
    /// checking whether it is empty. Confirm this is acceptable for the
    /// configured bus capacity.
    async fn wait_tasks_registered(
        &self,
        rx: &mut broadcast::Receiver<Arc<Event>>,
        expected: usize,
    ) {
        let mut registered = 0usize;
        while registered < expected {
            match rx.recv().await {
                Ok(ev) if ev.kind == EventKind::TaskAdded => {
                    registered += 1;
                }
                Ok(_) => continue,
                Err(broadcast::error::RecvError::Lagged(_)) => break,
                Err(broadcast::error::RecvError::Closed) => break,
            }
        }
    }

    /// Performs a graceful shutdown.
    ///
    /// Steps, in order:
    /// 1. Publish `ShutdownRequested` and cancel every registered task.
    /// 2. Wait up to `cfg.grace` for the registry to drain.
    /// 3. Cancel the runtime token and close the subscribers. This step runs
    ///    whether or not the grace period was exceeded.
    ///
    /// # Errors
    /// `RuntimeError::GraceExceeded` with the tasks still reported alive when
    /// the grace period expired.
    pub(crate) async fn shutdown(&self) -> Result<(), RuntimeError> {
        self.bus.publish(Event::new(EventKind::ShutdownRequested));
        self.registry.cancel_all().await;
        let res = self.wait_all_with_grace().await;
        self.runtime_token.cancel();
        self.subs.close().await;
        res
    }

    /// Returns the names of tasks the liveness tracker currently reports as alive.
    pub(crate) async fn snapshot(&self) -> Vec<Arc<str>> {
        self.alive.snapshot().await
    }

    /// Returns whether the liveness tracker reports task `name` as alive.
    pub(crate) async fn is_alive(&self, name: &str) -> bool {
        self.alive.is_alive(name).await
    }

    /// Cancels task `name`, waiting up to `cfg.grace` for its removal.
    ///
    /// See [`Supervisor::cancel_with_timeout`] for return values and errors.
    pub(crate) async fn cancel(&self, name: &str) -> Result<bool, RuntimeError> {
        self.cancel_with_timeout(name, self.cfg.grace).await
    }

    /// Cancels task `name` and waits up to `wait_for` for it to be removed.
    ///
    /// Returns `Ok(false)` if no such task is registered. Returns `Ok(true)`
    /// once the task is observed as removed, either via its `TaskRemoved`
    /// event or by no longer being in the registry.
    ///
    /// # Errors
    /// - `RuntimeError::ShuttingDown` if the registry command channel is closed.
    /// - `RuntimeError::TaskRemoveTimeout` if the task is still registered
    ///   after `wait_for`.
    pub(crate) async fn cancel_with_timeout(
        &self,
        name: &str,
        wait_for: Duration,
    ) -> Result<bool, RuntimeError> {
        // Subscribe *before* the existence check. A `TaskRemoved` published
        // after the check can then no longer slip past unobserved.
        let mut rx = self.bus.subscribe();
        if !self.registry.contains(name).await {
            return Ok(false);
        }
        self.bus.publish(
            Event::new(EventKind::TaskRemoveRequested)
                .with_task(name)
                .with_reason("manual_cancel"),
        );
        self.cmd_tx
            .send(RegistryCommand::Remove(Arc::from(name)))
            .map_err(|_| RuntimeError::ShuttingDown)?;
        self.wait_task_removed(&mut rx, name, wait_for).await
    }

    /// Submits `spec` to the installed controller, waiting if necessary.
    ///
    /// # Errors
    /// `ControllerError::NotConfigured` if no controller has been installed;
    /// otherwise whatever the controller's `submit` returns.
    #[cfg(feature = "controller")]
    pub(crate) async fn submit(
        &self,
        spec: crate::controller::ControllerSpec,
    ) -> Result<(), crate::controller::ControllerError> {
        match self.controller.get() {
            Some(ctrl) => ctrl.handle().submit(spec).await,
            None => Err(crate::controller::ControllerError::NotConfigured),
        }
    }

    /// Non-async variant of [`Supervisor::submit`].
    ///
    /// # Errors
    /// `ControllerError::NotConfigured` if no controller has been installed;
    /// otherwise whatever the controller's `try_submit` returns.
    #[cfg(feature = "controller")]
    pub(crate) fn try_submit(
        &self,
        spec: crate::controller::ControllerSpec,
    ) -> Result<(), crate::controller::ControllerError> {
        match self.controller.get() {
            Some(ctrl) => ctrl.handle().try_submit(spec),
            None => Err(crate::controller::ControllerError::NotConfigured),
        }
    }

    /// Spawns the task that fans bus events out to the liveness tracker and
    /// to subscribers.
    ///
    /// When the receiver lags, a synthetic `SubscriberOverflow` event with
    /// reason `lagged(<skipped>)` is delivered in place of the lost events.
    /// The task exits when the bus closes.
    fn subscriber_listener(&self) {
        let mut rx = self.bus.subscribe();
        let set = Arc::clone(&self.subs);
        let alive = Arc::clone(&self.alive);
        tokio::spawn(async move {
            loop {
                match rx.recv().await {
                    Ok(arc_ev) => {
                        alive.update(&arc_ev).await;
                        set.emit_arc(arc_ev);
                    }
                    Err(broadcast::error::RecvError::Lagged(skipped)) => {
                        let e = Event::new(EventKind::SubscriberOverflow)
                            .with_task("subscriber_listener")
                            .with_reason(format!("lagged({skipped})"));
                        let arc_e = Arc::new(e);
                        alive.update(&arc_e).await;
                        set.emit_arc(arc_e);
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        break;
                    }
                }
            }
        });
    }

    /// Waits for either a shutdown signal or an empty registry, then
    /// cancels the runtime token and closes subscribers.
    ///
    /// On a signal, the same cancel-and-wait-with-grace sequence as
    /// [`Supervisor::shutdown`] runs first. Draining naturally yields `Ok(())`.
    async fn drive_shutdown(&self) -> Result<(), RuntimeError> {
        let res = tokio::select! {
            _ = crate::core::shutdown::wait_for_shutdown_signal() => {
                self.bus.publish(Event::new(EventKind::ShutdownRequested));
                self.registry.cancel_all().await;
                self.wait_all_with_grace().await
            }
            _ = self.registry.wait_until_empty() => {
                Ok(())
            }
        };
        self.runtime_token.cancel();
        self.subs.close().await;
        res
    }

    /// Waits up to `cfg.grace` for the registry to empty.
    ///
    /// Publishes `AllStoppedWithinGrace` on success and `GraceExceeded` on
    /// timeout.
    ///
    /// # Errors
    /// `RuntimeError::GraceExceeded` carrying the grace period and the tasks
    /// still reported alive.
    async fn wait_all_with_grace(&self) -> Result<(), RuntimeError> {
        let grace = self.cfg.grace;
        match timeout(grace, self.registry.wait_until_empty()).await {
            Ok(_) => {
                self.bus
                    .publish(Event::new(EventKind::AllStoppedWithinGrace));
                Ok(())
            }
            Err(_) => {
                self.bus.publish(Event::new(EventKind::GraceExceeded));
                let stuck = self.snapshot().await;
                Err(RuntimeError::GraceExceeded { grace, stuck })
            }
        }
    }

    /// Waits up to `wait_for` for task `name` to be removed.
    ///
    /// A `TaskRemoved` event for `name` counts as removal. Because events can
    /// be lost (lag) or stop arriving (bus closed, or removal via a path that
    /// was not observed), absence from the registry also counts as removal.
    /// This applies on lag, on close, and on timeout.
    ///
    /// # Errors
    /// `RuntimeError::TaskRemoveTimeout` if the task is still registered when
    /// `wait_for` elapses.
    async fn wait_task_removed(
        &self,
        rx: &mut broadcast::Receiver<Arc<Event>>,
        name: &str,
        wait_for: Duration,
    ) -> Result<bool, RuntimeError> {
        let target: Arc<str> = Arc::from(name);
        let wait_for_event = async {
            loop {
                match rx.recv().await {
                    Ok(ev)
                        if matches!(ev.kind, EventKind::TaskRemoved)
                            && ev.task.as_deref() == Some(&*target) =>
                    {
                        return Ok(true);
                    }
                    Ok(_) => {}
                    Err(broadcast::error::RecvError::Lagged(_)) => {
                        // The TaskRemoved event may have been among the
                        // skipped messages; fall back to the registry.
                        if !self.registry.contains(&target).await {
                            return Ok(true);
                        }
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        return Ok(!self.registry.contains(&target).await);
                    }
                }
            }
        };
        match timeout(wait_for, wait_for_event).await {
            Ok(result) => result,
            Err(_) => {
                // No matching event arrived in time. If the task has
                // nonetheless left the registry, report success instead of a
                // spurious timeout, consistent with the fallbacks above.
                if self.registry.contains(&target).await {
                    Err(RuntimeError::TaskRemoveTimeout {
                        name: target,
                        timeout: wait_for,
                    })
                } else {
                    Ok(true)
                }
            }
        }
    }
}