use std::{
collections::HashMap,
sync::Arc,
time::{Duration, Instant},
};
use thiserror::Error as ThisError;
use tokio::{
sync::watch,
task::{Id as TaskId, JoinError, JoinSet},
time::{sleep, timeout},
};
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use crate::{
policy::{ErrorAction, ExitAction, FailureAction, RestartPolicy, ServicePolicy},
readiness::{ReadinessMode, ReadinessTracker, ReadySignal, SupervisorReadiness},
service::{BoxFuture, ServiceError, ServiceOutcome, SupervisedService},
};
/// Derives a service-specific context value from the supervisor-wide state `S`.
///
/// The builder calls [`FromSupervisorState::from_state`] once per registered service to
/// produce the context that service will run with.
pub trait FromSupervisorState<S>: Sized {
    /// Builds `Self` from a shared reference to the supervisor state.
    fn from_state(state: &S) -> Self;
}

/// Any cloneable state can act as its own context: the context is a clone of the state.
impl<S: Clone> FromSupervisorState<S> for S {
    fn from_state(state: &S) -> Self {
        S::clone(state)
    }
}
/// Handle given to a service for one run: a cancellation token, a readiness signal, and a
/// caller-defined payload of type `C`.
#[derive(Clone, Debug)]
pub struct Context<C> {
/// Cancellation token, exposed through `token()`.
token: CancellationToken,
/// Readiness signal, exposed through `readiness()`.
readiness: ReadySignal,
/// Caller-defined payload, exposed through `ctx()` and `into_inner()`.
ctx: C,
}
impl<C> Context<C> {
pub fn new(token: CancellationToken, ctx: C) -> Self {
Self {
token,
readiness: ReadySignal::immediate(),
ctx,
}
}
pub(crate) fn with_readiness(token: CancellationToken, readiness: ReadySignal, ctx: C) -> Self {
Self {
token,
readiness,
ctx,
}
}
pub fn token(&self) -> &CancellationToken {
&self.token
}
#[cfg(test)]
pub(crate) fn begin_teardown(&self) {
self.token.cancel();
}
pub fn readiness(&self) -> &ReadySignal {
&self.readiness
}
pub fn ctx(&self) -> &C {
&self.ctx
}
pub fn into_inner(self) -> C {
self.ctx
}
pub fn map<D>(self, map: impl FnOnce(C) -> D) -> Context<D> {
Context {
token: self.token,
readiness: self.readiness,
ctx: map(self.ctx),
}
}
}
/// Optional per-service overrides applied by `SupervisorBuilder::add_with_options`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Options {
/// Policy override; when `None` the builder constructs
/// `ServicePolicy::new(ExitAction::Ignore, <builder default restart policy>)`.
policy: Option<ServicePolicy>,
/// Readiness-mode override; when `None` the builder asks the service via `readiness()`.
readiness: Option<ReadinessMode>,
}
impl Options {
pub fn new() -> Self {
Self {
policy: None,
readiness: None,
}
}
pub fn policy(mut self, policy: ServicePolicy) -> Self {
self.policy = Some(policy);
self
}
pub fn readiness(mut self, readiness: ReadinessMode) -> Self {
self.readiness = Some(readiness);
self
}
}
impl Default for Options {
fn default() -> Self {
Self::new()
}
}
/// Why a supervisor run ended, as reported in `RunSummary`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShutdownCause {
/// Every task finished without anything triggering a shutdown.
Completed,
/// A service completed and its policy's completion action was `ExitAction::Shutdown`.
Requested,
/// Produced by `ShutdownRequest::Signal`, which the built-in ctrl-c registration uses.
Signal,
/// The named service returned `ServiceOutcome::RequestedShutdown`.
ServiceRequested { service: &'static str },
/// The named service failed and its restart policy resolved to a terminal
/// `ErrorAction::Shutdown`.
FatalService { service: &'static str },
/// A task of the named service finished while the readiness tracker did not report it ready.
ReadinessFailed { service: &'static str },
}
/// Final per-service record included in a `RunSummary`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ServiceSummary {
/// Name the service was registered under.
name: &'static str,
/// Last outcome recorded for the service.
outcome: ServiceOutcome,
/// Number of times the supervisor respawned the service.
restarts: usize,
}
impl ServiceSummary {
    /// Name the service was registered under.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Last outcome recorded for the service.
    pub fn outcome(&self) -> &ServiceOutcome {
        let Self { outcome, .. } = self;
        outcome
    }

    /// How many times the service was respawned during the run.
    pub fn restarts(&self) -> usize {
        self.restarts
    }
}
/// Result of a finished supervisor run.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RunSummary {
/// Why the run ended.
shutdown_cause: ShutdownCause,
/// Readiness state read from the tracker when the summary was built.
readiness: SupervisorReadiness,
/// One entry per registered service, in registration order.
services: Vec<ServiceSummary>,
}
impl RunSummary {
pub fn new(
shutdown_cause: ShutdownCause,
readiness: SupervisorReadiness,
services: Vec<ServiceSummary>,
) -> Self {
Self {
shutdown_cause,
readiness,
services,
}
}
pub fn shutdown_cause(&self) -> &ShutdownCause {
&self.shutdown_cause
}
pub fn readiness(&self) -> SupervisorReadiness {
self.readiness
}
pub fn services(&self) -> &[ServiceSummary] {
&self.services
}
pub fn service(&self, name: &str) -> Option<&ServiceSummary> {
self.services.iter().find(|service| service.name == name)
}
}
/// Builder that collects service registrations and settings before producing a `Supervisor`.
pub struct SupervisorBuilder<S> {
/// Shared state from which each service's context is derived at registration time.
state: S,
/// Services registered so far, in registration order.
registrations: Vec<Registration>,
/// Time budget handed to the shutdown phase (initialised to 5 seconds by `new`).
shutdown_timeout: Duration,
/// Restart policy used for services registered without an explicit policy.
default_restart_policy: RestartPolicy,
}
impl<S> SupervisorBuilder<S> {
pub fn new(state: S) -> Self {
Self {
state,
registrations: Vec::new(),
shutdown_timeout: Duration::from_secs(5),
default_restart_policy: RestartPolicy::never(ErrorAction::Shutdown),
}
}
pub fn shutdown_timeout(mut self, timeout: Duration) -> Self {
self.shutdown_timeout = timeout;
self
}
pub fn default_restart_policy(mut self, restart: RestartPolicy) -> Self {
self.default_restart_policy = restart;
self
}
pub fn shutdown_on_ctrl_c(mut self) -> Self {
self.registrations.push(Registration::ctrl_c());
self
}
pub fn add<T>(self, service: T) -> Self
where
T: SupervisedService,
T::Context: FromSupervisorState<S>,
{
self.add_with_options(service, Options::new())
}
pub fn add_with_options<T>(mut self, service: T, options: Options) -> Self
where
T: SupervisedService,
T::Context: FromSupervisorState<S>,
{
let policy = options.policy.unwrap_or_else(|| {
ServicePolicy::new(ExitAction::Ignore, self.default_restart_policy.clone())
});
let readiness = options.readiness.unwrap_or_else(|| service.readiness());
let ctx = T::Context::from_state(&self.state);
self.registrations
.push(Registration::new(service, ctx, policy, readiness));
self
}
pub fn build(self) -> Supervisor {
let readiness = ReadinessTracker::new(
self.registrations
.iter()
.map(|registration| registration.readiness),
);
Supervisor {
registrations: self.registrations,
shutdown_timeout: self.shutdown_timeout,
readiness,
}
}
}
/// A configured set of services ready to be driven by `run`.
pub struct Supervisor {
/// Registered services, in registration order.
registrations: Vec<Registration>,
/// Time budget handed to `finish_shutdown` once a shutdown is triggered.
shutdown_timeout: Duration,
/// Tracker built from every registration's readiness mode.
readiness: ReadinessTracker,
}
impl Supervisor {
    /// Subscribes to the supervisor's readiness tracker.
    pub fn readiness(&self) -> watch::Receiver<SupervisorReadiness> {
        self.readiness.subscribe()
    }

    /// Spawns every registered service and supervises them until they all finish or a
    /// shutdown is triggered.
    ///
    /// For each finished task the loop records the outcome, then decides whether to ignore
    /// it, respawn the service, or shut everything down. On shutdown the shared root token
    /// is cancelled and the remaining tasks are drained by `finish_shutdown`.
    ///
    /// # Errors
    ///
    /// * [`Error::NoServices`] when nothing was registered.
    /// * [`Error::UnknownTask`] / [`Error::UnknownServiceIndex`] when the internal task
    ///   bookkeeping is inconsistent.
    #[instrument(skip_all)]
    pub async fn run(self) -> Result<RunSummary, Error> {
        let Self {
            registrations,
            shutdown_timeout,
            readiness,
        } = self;
        if registrations.is_empty() {
            return Err(Error::NoServices);
        }

        // Every service shares this token; cancelling it asks all of them to stop.
        let root = CancellationToken::new();
        let mut tasks = JoinSet::new();
        // Maps a live task id to the index of its service in `states`.
        let mut running = HashMap::new();
        let mut states = registrations
            .into_iter()
            .map(ServiceState::new)
            .collect::<Vec<_>>();
        for (index, state) in states.iter().enumerate() {
            spawn_service(
                &mut tasks,
                &mut running,
                &state.registration,
                index,
                root.clone(),
                readiness.signal(index, state.registration.readiness),
                None,
            );
        }

        loop {
            let Some(joined) = tasks.join_next_with_id().await else {
                // The set drained without any shutdown trigger.
                return Ok(RunSummary::new(
                    ShutdownCause::Completed,
                    readiness.state(),
                    summarize(&states),
                ));
            };
            let result = join_result(joined, &mut running, &mut states)?;
            let state = state_mut(&mut states, result.index)?;
            state.last_outcome = result.outcome.clone();

            // A service that ran to completion counts as ready.
            if matches!(result.outcome, ServiceOutcome::Completed) {
                readiness.mark_ready(result.index);
            }

            // Decide what this exit means. `None` keeps supervising; `Some(cause)` starts the
            // shutdown phase. A task that ended before its service was ready takes precedence
            // over the service policy.
            let cause = if !readiness.is_ready(result.index) {
                Some(ShutdownCause::ReadinessFailed {
                    service: state.registration.name,
                })
            } else {
                match result.outcome {
                    ServiceOutcome::Completed => match state.registration.policy.on_completed() {
                        ExitAction::Ignore => None,
                        ExitAction::Restart => {
                            state.restarts += 1;
                            spawn_service(
                                &mut tasks,
                                &mut running,
                                &state.registration,
                                result.index,
                                root.clone(),
                                readiness.signal(result.index, state.registration.readiness),
                                None,
                            );
                            None
                        },
                        ExitAction::Shutdown => Some(ShutdownCause::Requested),
                    },
                    ServiceOutcome::Cancelled => None,
                    ServiceOutcome::RequestedShutdown => {
                        Some(state.registration.shutdown_request.cause())
                    },
                    ServiceOutcome::Error(_) => {
                        match state.registration.policy.restart().action(state.restarts) {
                            FailureAction::Restart { backoff } => {
                                state.restarts += 1;
                                // The backoff is slept inside the respawned task.
                                spawn_service(
                                    &mut tasks,
                                    &mut running,
                                    &state.registration,
                                    result.index,
                                    root.clone(),
                                    readiness.signal(result.index, state.registration.readiness),
                                    Some(backoff),
                                );
                                None
                            },
                            FailureAction::Terminal(ErrorAction::Ignore) => None,
                            FailureAction::Terminal(ErrorAction::Shutdown) => {
                                Some(ShutdownCause::FatalService {
                                    service: state.registration.name,
                                })
                            },
                        }
                    },
                }
            };

            if let Some(cause) = cause {
                // Ask every remaining service to stop, then drain them within the timeout.
                root.cancel();
                return finish_shutdown(shutdown_timeout, tasks, running, states, readiness, cause)
                    .await;
            }
        }
    }
}
/// One registered service together with the settings the supervisor applies to it.
#[derive(Clone)]
struct Registration {
/// Name reported by the service at registration time.
name: &'static str,
/// Policy consulted when the service completes or fails.
policy: ServicePolicy,
/// Readiness mode passed to the readiness tracker for this service.
readiness: ReadinessMode,
/// Determines the `ShutdownCause` reported when the service requests a shutdown.
shutdown_request: ShutdownRequest,
/// Type-erased launcher that starts one run of the service.
runner: Arc<dyn Runner>,
}
impl Registration {
fn new<S>(service: S, ctx: S::Context, policy: ServicePolicy, readiness: ReadinessMode) -> Self
where
S: SupervisedService,
{
let name = service.name();
Self {
name,
policy,
readiness,
shutdown_request: ShutdownRequest::Service { service: name },
runner: Arc::new(ServiceRunner {
service: Arc::new(service),
ctx,
}),
}
}
fn ctrl_c() -> Self {
Self::new(
CtrlC,
(),
ServicePolicy::new(
ExitAction::Shutdown,
RestartPolicy::never(ErrorAction::Shutdown),
),
ReadinessMode::Immediate,
)
.with_shutdown_request(ShutdownRequest::Signal)
}
fn with_shutdown_request(mut self, shutdown_request: ShutdownRequest) -> Self {
self.shutdown_request = shutdown_request;
self
}
fn run(
&self,
token: CancellationToken,
readiness: ReadySignal,
restart_delay: Option<Duration>,
) -> BoxFuture<ServiceOutcome> {
self.runner.run(token, readiness, restart_delay)
}
}
/// How a service's shutdown request should be attributed in the run summary.
#[derive(Clone, Copy)]
enum ShutdownRequest {
/// Attributed to the named service (the default set by `Registration::new`).
Service { service: &'static str },
/// Attributed to a signal (used by the ctrl-c registration).
Signal,
}
impl ShutdownRequest {
fn cause(self) -> ShutdownCause {
match self {
Self::Service { service } => ShutdownCause::ServiceRequested { service },
Self::Signal => ShutdownCause::Signal,
}
}
}
struct CtrlC;
impl SupervisedService for CtrlC {
type Context = ();
fn name(&self) -> &'static str {
"ctrl-c"
}
fn run(&self, ctx: Context<Self::Context>) -> BoxFuture<ServiceOutcome> {
let token = ctx.token().clone();
Box::pin(async move {
match token
.run_until_cancelled_owned(tokio::signal::ctrl_c())
.await
{
Some(Ok(())) => ServiceOutcome::requested_shutdown(),
Some(Err(error)) => ServiceOutcome::failed(ServiceError::from_error(error)),
None => ServiceOutcome::cancelled(),
}
})
}
}
/// Type-erased launcher for a service, so registrations of different service types can be
/// stored together behind `Arc<dyn Runner>`.
trait Runner: Send + Sync {
/// Starts one run of the service and returns the future producing its outcome.
///
/// `token` and `readiness` are handed to the service through its context; `restart_delay`,
/// when present, is a delay to wait before the service is started.
fn run(
&self,
token: CancellationToken,
readiness: ReadySignal,
restart_delay: Option<Duration>,
) -> BoxFuture<ServiceOutcome>;
}
/// `Runner` implementation that pairs a concrete service with the context value it runs with.
struct ServiceRunner<S>
where
S: SupervisedService,
{
/// The service, shared so each run's future can hold its own handle.
service: Arc<S>,
/// Context value cloned for every run.
ctx: S::Context,
}
impl<S> Runner for ServiceRunner<S>
where
    S: SupervisedService,
{
    /// Builds the future for one run of the service.
    ///
    /// When `restart_delay` is set, the future first sleeps for that long; if the token is
    /// cancelled during the sleep it resolves to `ServiceOutcome::Cancelled` without ever
    /// starting the service.
    fn run(
        &self,
        token: CancellationToken,
        readiness: ReadySignal,
        restart_delay: Option<Duration>,
    ) -> BoxFuture<ServiceOutcome> {
        let service = Arc::clone(&self.service);
        let ctx = self.ctx.clone();
        Box::pin(async move {
            if let Some(delay) = restart_delay {
                // `None` means the token fired before the backoff elapsed.
                let slept = token
                    .clone()
                    .run_until_cancelled_owned(sleep(delay))
                    .await;
                if slept.is_none() {
                    return ServiceOutcome::Cancelled;
                }
            }
            let context = Context::with_readiness(token, readiness, ctx);
            service.run(context).await
        })
    }
}
/// Mutable bookkeeping the supervisor keeps for each registered service while running.
struct ServiceState {
/// The registration this state belongs to.
registration: Registration,
/// Number of times the service has been respawned.
restarts: usize,
/// Most recent outcome recorded for the service (starts as `Cancelled`).
last_outcome: ServiceOutcome,
}
impl ServiceState {
fn new(registration: Registration) -> Self {
Self {
registration,
restarts: 0,
last_outcome: ServiceOutcome::Cancelled,
}
}
}
/// Value produced by a supervised task when it finishes.
struct TaskResult {
/// Index of the service in the supervisor's state list.
index: usize,
/// Outcome the service's future resolved to.
outcome: ServiceOutcome,
}
/// Spawns one run of `registration` onto `tasks` and records the new task id in `running`
/// so the finished task can later be mapped back to service `index`.
fn spawn_service(
    tasks: &mut JoinSet<TaskResult>,
    running: &mut HashMap<TaskId, usize>,
    registration: &Registration,
    index: usize,
    root: CancellationToken,
    readiness: ReadySignal,
    restart_delay: Option<Duration>,
) {
    let service_future = registration.run(root, readiness, restart_delay);
    let task = async move {
        let outcome = service_future.await;
        TaskResult { index, outcome }
    };
    let task_id = tasks.spawn(task).id();
    running.insert(task_id, index);
}
fn join_result(
joined: Result<(TaskId, TaskResult), JoinError>,
running: &mut HashMap<TaskId, usize>,
states: &mut [ServiceState],
) -> Result<TaskResult, Error> {
match joined {
Ok((task_id, result)) => {
running.remove(&task_id).ok_or(Error::UnknownTask {
task_id: task_id.to_string(),
})?;
Ok(result)
},
Err(source) if source.is_cancelled() => {
let task_id = source.id();
let index = running.remove(&task_id).ok_or(Error::UnknownTask {
task_id: task_id.to_string(),
})?;
state_mut(states, index)?.last_outcome = ServiceOutcome::Cancelled;
Ok(TaskResult::new(index, ServiceOutcome::Cancelled))
},
Err(source) => {
let task_id = source.id();
let index = running.remove(&task_id).ok_or(Error::UnknownTask {
task_id: task_id.to_string(),
})?;
Ok(TaskResult::new(
index,
ServiceOutcome::failed(source.to_string()),
))
},
}
}
impl TaskResult {
fn new(index: usize, outcome: ServiceOutcome) -> Self {
Self { index, outcome }
}
}
/// Records `ServiceOutcome::Cancelled` as the last outcome of every service that still has a
/// task in `running`. Indices that fall outside `states` are skipped.
fn mark_cancelled(running: &HashMap<TaskId, usize>, states: &mut [ServiceState]) {
    for &index in running.values() {
        let Some(state) = states.get_mut(index) else {
            continue;
        };
        state.last_outcome = ServiceOutcome::Cancelled;
    }
}
#[instrument(skip_all, fields(shutdown_cause = ?cause))]
async fn finish_shutdown(
shutdown_timeout: Duration,
mut tasks: JoinSet<TaskResult>,
mut running: HashMap<TaskId, usize>,
mut states: Vec<ServiceState>,
readiness: ReadinessTracker,
cause: ShutdownCause,
) -> Result<RunSummary, Error> {
let started = Instant::now();
while !tasks.is_empty() {
let elapsed = started.elapsed();
if elapsed >= shutdown_timeout {
mark_cancelled(&running, &mut states);
tasks.abort_all();
break;
}
let remaining = shutdown_timeout - elapsed;
match timeout(remaining, tasks.join_next_with_id()).await {
Ok(Some(joined)) => {
let result = join_result(joined, &mut running, &mut states)?;
state_mut(&mut states, result.index)?.last_outcome = result.outcome;
},
Ok(None) => break,
Err(_) => {
mark_cancelled(&running, &mut states);
tasks.abort_all();
break;
},
}
}
while let Some(joined) = tasks.join_next_with_id().await {
let result = join_result(joined, &mut running, &mut states)?;
state_mut(&mut states, result.index)?.last_outcome = result.outcome;
}
Ok(RunSummary::new(
cause,
readiness.state(),
summarize(&states),
))
}
fn summarize(states: &[ServiceState]) -> Vec<ServiceSummary> {
states
.iter()
.map(|state| ServiceSummary {
name: state.registration.name,
outcome: state.last_outcome.clone(),
restarts: state.restarts,
})
.collect()
}
fn state_mut(states: &mut [ServiceState], index: usize) -> Result<&mut ServiceState, Error> {
states
.get_mut(index)
.ok_or(Error::UnknownServiceIndex { index })
}
/// Errors returned by `Supervisor::run`.
#[derive(Debug, ThisError)]
pub enum Error {
/// `run` was called on a supervisor with no registrations.
#[error("cannot run a supervisor without any registered services")]
NoServices,
/// A finished task's id was not present in the supervisor's running-task map.
#[error("received completion for unknown supervised task {task_id}")]
UnknownTask { task_id: String },
/// A service index recorded in the bookkeeping was out of range for the state list.
#[error("received an out-of-range service index {index} from supervisor bookkeeping")]
UnknownServiceIndex { index: usize },
}
// Unit tests for shutdown-cause mapping and token-driven teardown.
#[cfg(test)]
mod tests {
use std::{future, time::Duration};
use tokio::time::timeout;
use super::{Context, ServiceOutcome, ShutdownCause, ShutdownRequest, SupervisorBuilder};
use crate::{service_fn, ServiceExt};
// The signal request used by the ctrl-c registration must map to `ShutdownCause::Signal`.
#[test]
fn signal_shutdown_request_maps_to_signal_cause() {
assert_eq!(ShutdownRequest::Signal.cause(), ShutdownCause::Signal);
}
// A service that cancels the shared token via `begin_teardown` should cause every other
// service to finish as `Cancelled`, with the run reported as `Completed`.
#[tokio::test(flavor = "current_thread")]
async fn context_test_teardown_cancels_running_services() {
// The outer 2 second timeout turns a hung supervisor into a test failure.
let summary = timeout(
Duration::from_secs(2),
SupervisorBuilder::new(())
.shutdown_on_ctrl_c()
.add(
// "loop" never resolves by itself.
// NOTE(review): `until_cancelled()` presumably ends it with `Cancelled` once the
// token fires — confirm against `ServiceExt`.
service_fn("loop", |_ctx: Context<()>| {
future::pending::<ServiceOutcome>()
})
.until_cancelled(),
)
// "teardown" cancels the shared token and then completes normally.
.add(service_fn("teardown", |ctx: Context<()>| async move {
ctx.begin_teardown();
ServiceOutcome::completed()
}))
.build()
.run(),
)
.await
.expect("teardown should stop the supervisor")
.expect("supervisor should run");
// No shutdown trigger fired, so the run is reported as completed.
assert_eq!(summary.shutdown_cause(), &ShutdownCause::Completed);
assert_eq!(
summary.service("loop").expect("loop summary").outcome(),
&ServiceOutcome::Cancelled
);
assert_eq!(
summary.service("ctrl-c").expect("ctrl-c summary").outcome(),
&ServiceOutcome::Cancelled
);
}
}