use std::{future::Future, sync::Arc, time::Duration};
use tokio::task::JoinSet;
use crate::core::{CoreError, CoreResult, FnService, Service, ShutdownToken, shutdown_signal};
/// Default value of [`ServiceGroupConfig::shutdown_timeout`].
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

/// Tuning knobs for a [`ServiceGroup`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServiceGroupConfig {
    /// Time budget used while shutting the group down.
    ///
    /// The group applies it to the stop-hook phase and then, separately, to waiting for
    /// the remaining service tasks.
    pub shutdown_timeout: Duration,
    /// When `true`, the first failing service cancels the whole group.
    pub stop_on_first_error: bool,
}

impl Default for ServiceGroupConfig {
    /// A 30-second shutdown timeout, and stop on the first failure.
    fn default() -> Self {
        let shutdown_timeout = DEFAULT_SHUTDOWN_TIMEOUT;
        let stop_on_first_error = true;
        Self {
            shutdown_timeout,
            stop_on_first_error,
        }
    }
}
/// Cloneable remote control for a [`ServiceGroup`].
///
/// A handle holds a clone of the group's `ShutdownToken` (as produced by
/// `ServiceGroup::handle`). The running group watches that token, so cancelling it
/// through a handle asks the group to shut down.
#[derive(Debug, Clone)]
pub struct ServiceGroupHandle {
    shutdown: ShutdownToken,
}

impl ServiceGroupHandle {
    /// Requests shutdown by cancelling the shared token.
    pub fn stop(&self) {
        let token = &self.shutdown;
        token.cancel();
    }

    /// Reports whether the shared token has been cancelled.
    ///
    /// `true` means shutdown was requested. The group may still be running stop hooks
    /// or waiting for its services to exit.
    pub fn is_stopped(&self) -> bool {
        let token = &self.shutdown;
        token.is_cancelled()
    }
}
/// A set of services that are started together and shut down together.
pub struct ServiceGroup {
    /// Shutdown behaviour: timeout and failure policy.
    config: ServiceGroupConfig,
    /// Registered services, in registration order.
    services: Vec<Arc<dyn Service>>,
    /// Token cloned into every service task and every [`ServiceGroupHandle`].
    shutdown: ShutdownToken,
}

impl Default for ServiceGroup {
    /// Same as [`ServiceGroup::new`].
    fn default() -> Self {
        ServiceGroup::new()
    }
}
impl ServiceGroup {
    /// Creates an empty group using [`ServiceGroupConfig::default`].
    pub fn new() -> Self {
        Self::with_config(ServiceGroupConfig::default())
    }
    /// Creates an empty group with the given configuration and a fresh shutdown token.
    pub fn with_config(config: ServiceGroupConfig) -> Self {
        Self {
            config,
            services: Vec::new(),
            shutdown: ShutdownToken::new(),
        }
    }
    /// Registers an owned service. Returns `&mut Self` so calls can be chained.
    pub fn add<S>(&mut self, service: S) -> &mut Self
    where
        S: Service,
    {
        self.services.push(Arc::new(service));
        self
    }
    /// Registers a service that is already behind an `Arc`, for example one that is
    /// also shared elsewhere. Returns `&mut Self` so calls can be chained.
    pub fn add_arc<S>(&mut self, service: Arc<S>) -> &mut Self
    where
        S: Service,
    {
        self.services.push(service);
        self
    }
    /// Registers a closure-backed service built with `FnService::new(name, start)`.
    ///
    /// `start` receives a `ShutdownToken`. It is presumably the group's token, passed
    /// through by `FnService` when the service is started (confirm in `crate::core`).
    /// Returns `&mut Self` so calls can be chained.
    pub fn add_fn<F, Fut>(&mut self, name: impl Into<String>, start: F) -> &mut Self
    where
        F: Fn(ShutdownToken) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = CoreResult<()>> + Send + 'static,
    {
        self.add(FnService::new(name, start))
    }
    /// Returns a handle that holds a clone of this group's shutdown token. It can be
    /// used to stop the group after `start` has consumed `self`.
    pub fn handle(&self) -> ServiceGroupHandle {
        ServiceGroupHandle {
            shutdown: self.shutdown.clone(),
        }
    }
    /// Runs the group until the crate's `shutdown_signal()` future resolves.
    /// NOTE(review): that future presumably waits for process termination signals
    /// (e.g. Ctrl-C); confirm in `crate::core`.
    ///
    /// See [`ServiceGroup::start_with_shutdown`] for the full lifecycle.
    pub async fn start(self) -> CoreResult<()> {
        self.start_with_shutdown(shutdown_signal()).await
    }
    /// Spawns every registered service and supervises them until one of these happens:
    /// - `shutdown` resolves,
    /// - the group's token is cancelled (e.g. via [`ServiceGroupHandle::stop`]),
    /// - a service fails while `stop_on_first_error` is set, or
    /// - every service task has finished.
    ///
    /// In the first three cases the token is cancelled. Then every service's stop hook
    /// runs, bounded by `shutdown_timeout`, and the still-running tasks are awaited,
    /// again bounded by `shutdown_timeout`.
    /// NOTE(review): the two phases run one after the other, each with its own full
    /// timeout, so a complete shutdown can take up to about twice `shutdown_timeout`.
    ///
    /// If every service exits on its own and none of the first three cases occurs, the
    /// stop hooks are not run and the token is left uncancelled.
    ///
    /// With no registered services this returns `Ok(())` immediately, without awaiting
    /// `shutdown`.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::Service` containing every collected failure joined with
    /// `"; "`. This covers service errors, failed/panicked tasks, stop-hook failures,
    /// and timeouts.
    pub async fn start_with_shutdown<F>(self, shutdown: F) -> CoreResult<()>
    where
        F: Future<Output = ()> + Send,
    {
        if self.services.is_empty() {
            return Ok(());
        }
        let mut tasks = spawn_services(&self.services, &self.shutdown);
        // Number of service tasks that have not yet been joined. It is handed to
        // `wait_for_tasks` so the shutdown drain knows how many results to expect.
        let mut active = self.services.len();
        let mut errors = Vec::new();
        // `shutdown` is polled by reference across loop iterations, so it must be pinned.
        tokio::pin!(shutdown);
        while active > 0 {
            tokio::select! {
                // External shutdown request: propagate it to every service via the token.
                _ = &mut shutdown => {
                    self.shutdown.cancel();
                    break;
                }
                // Token cancelled from elsewhere (e.g. a `ServiceGroupHandle`).
                _ = self.shutdown.cancelled() => {
                    break;
                }
                joined = tasks.join_next() => {
                    // `None` means the set is empty. This is not expected while
                    // `active > 0`, but exiting the loop is the safe response.
                    let Some(joined) = joined else { break; };
                    active -= 1;
                    // A failure only ends the group when the policy says so. Otherwise
                    // the other services keep running.
                    if handle_service_exit(joined, &mut errors) && self.config.stop_on_first_error {
                        self.shutdown.cancel();
                        break;
                    }
                }
            }
        }
        // The cleanup phase runs only if shutdown was requested or tasks are still
        // outstanding. A group whose services all exited by themselves skips it.
        if self.shutdown.is_cancelled() || active > 0 {
            self.shutdown.cancel();
            stop_services(&self.services, self.config.shutdown_timeout, &mut errors).await;
            wait_for_tasks(
                &mut tasks,
                active,
                self.config.shutdown_timeout,
                &mut errors,
            )
            .await;
        }
        into_group_result(errors)
    }
}
/// Outcome of one service's `start` future, tagged with the service's name so that
/// failures can be attributed in error messages.
struct ServiceTaskExit {
    /// Value of `Service::name()`, captured when the task was spawned.
    name: String,
    /// Whatever `Service::start` returned.
    result: CoreResult<()>,
}
/// Spawns one Tokio task per service. Each task drives that service's `start` future
/// with its own clone of `shutdown`.
///
/// Each task resolves to a [`ServiceTaskExit`] holding the service's name (captured
/// before spawning) and whatever `start` returned.
fn spawn_services(
    services: &[Arc<dyn Service>],
    shutdown: &ShutdownToken,
) -> JoinSet<ServiceTaskExit> {
    let mut set = JoinSet::new();
    services.iter().for_each(|svc| {
        let svc = Arc::clone(svc);
        let token = shutdown.clone();
        let name = svc.name().to_string();
        set.spawn(async move {
            ServiceTaskExit {
                name,
                result: svc.start(token).await,
            }
        });
    });
    set
}
fn handle_service_exit(
joined: Result<ServiceTaskExit, tokio::task::JoinError>,
errors: &mut Vec<String>,
) -> bool {
match joined {
Ok(exit) => {
if let Err(error) = exit.result {
errors.push(format!("service {} failed: {error}", exit.name));
return true;
}
}
Err(error) => {
errors.push(format!("service task failed: {error}"));
return true;
}
}
false
}
/// Runs every service's `stop` hook concurrently, bounded by `timeout`, and appends each
/// failure to `errors`.
///
/// If the deadline elapses:
/// - failures already reported by hooks that finished in time are kept;
/// - the hooks still running are aborted, because dropping the in-flight
///   `run_stop_hooks` future drops its `JoinSet`;
/// - a timeout message is appended.
async fn stop_services(services: &[Arc<dyn Service>], timeout: Duration, errors: &mut Vec<String>) {
    // `run_stop_hooks` writes into `errors` as each hook finishes. If the timeout cancels
    // it, earlier results are therefore not lost along with the future.
    if tokio::time::timeout(timeout, run_stop_hooks(services, errors))
        .await
        .is_err()
    {
        errors.push(format!(
            "service group stop hooks timed out after {:?}",
            timeout
        ));
    }
}

/// Spawns every service's `stop` hook on its own task and records each failure in
/// `errors` as soon as that hook finishes.
///
/// Hooks are spawned in reverse registration order. They then run concurrently, so that
/// order does not determine the order in which they stop or finish.
///
/// `JoinSet::join_next` is cancel-safe, so it is sound for the caller to drop this
/// future mid-way (e.g. on timeout).
async fn run_stop_hooks(services: &[Arc<dyn Service>], errors: &mut Vec<String>) {
    let mut tasks = JoinSet::new();
    for service in services.iter().rev() {
        let service = Arc::clone(service);
        let name = service.name().to_string();
        tasks.spawn(async move {
            let result = service.stop().await;
            (name, result)
        });
    }
    while let Some(joined) = tasks.join_next().await {
        match joined {
            Ok((_name, Ok(()))) => {}
            Ok((name, Err(error))) => errors.push(format!("service {name} stop failed: {error}")),
            Err(error) => errors.push(format!("service stop task failed: {error}")),
        }
    }
}
async fn wait_for_tasks(
tasks: &mut JoinSet<ServiceTaskExit>,
active: usize,
timeout: Duration,
errors: &mut Vec<String>,
) {
let wait = async {
let mut remaining = active;
while remaining > 0 {
let Some(joined) = tasks.join_next().await else {
break;
};
remaining -= 1;
match joined {
Ok(exit) => {
if let Err(error) = exit.result {
errors.push(format!(
"service {} failed during shutdown: {error}",
exit.name
));
}
}
Err(error) => errors.push(format!("service task failed during shutdown: {error}")),
}
}
};
if tokio::time::timeout(timeout, wait).await.is_err() {
tasks.abort_all();
errors.push(format!(
"service group shutdown timed out after {:?}",
timeout
));
}
}
/// Collapses the accumulated failure messages into the group's final result.
///
/// An empty list means success. Otherwise all messages are joined with `"; "` into a
/// single `CoreError::Service`.
fn into_group_result(errors: Vec<String>) -> CoreResult<()> {
    if !errors.is_empty() {
        let summary = errors.join("; ");
        return Err(CoreError::Service(summary));
    }
    Ok(())
}