rs-zero 0.2.4

Rust-first microservice framework inspired by go-zero engineering practices
Documentation
use std::{future::Future, sync::Arc, time::Duration};

use tokio::task::JoinSet;

use crate::core::{CoreError, CoreResult, FnService, Service, ShutdownToken, shutdown_signal};

/// Runtime controls for a service group.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServiceGroupConfig {
    /// Time budget for shutdown.
    ///
    /// The same value bounds the stop-hook phase and, separately, the wait for
    /// service tasks to finish.
    pub shutdown_timeout: Duration,
    /// When `true`, the first service failure cancels the remaining services.
    pub stop_on_first_error: bool,
}

impl Default for ServiceGroupConfig {
    /// Returns a 30 second shutdown timeout with stop-on-first-error enabled.
    fn default() -> Self {
        let shutdown_timeout = Duration::from_secs(30);
        let stop_on_first_error = true;
        Self {
            shutdown_timeout,
            stop_on_first_error,
        }
    }
}

/// Handle that can request shutdown for a running [`ServiceGroup`].
///
/// The handle only carries a clone of the group's [`ShutdownToken`], so it can
/// be cloned and kept after the group itself has been moved into its start
/// future.
#[derive(Debug, Clone)]
pub struct ServiceGroupHandle {
    /// Clone of the owning group's shutdown token.
    shutdown: ShutdownToken,
}

impl ServiceGroupHandle {
    /// Asks the group to shut down by cancelling its shutdown token.
    ///
    /// Calling this more than once is safe.
    pub fn stop(&self) {
        let token = &self.shutdown;
        token.cancel();
    }

    /// Reports whether shutdown has already been requested.
    pub fn is_stopped(&self) -> bool {
        let Self { shutdown } = self;
        shutdown.is_cancelled()
    }
}

/// A group of async services that start together and stop together.
///
/// Services are started concurrently. Like go-zero's service group, startup
/// order must not be relied on. When the shutdown future completes, a handle
/// requests stop, or a service fails (only if
/// [`ServiceGroupConfig::stop_on_first_error`] is set), the group cancels the
/// shared shutdown token, runs stop hooks, waits for the remaining tasks and
/// returns the aggregated errors, if any.
///
/// If every service finishes on its own and shutdown was never requested, the
/// stop hooks are not run.
pub struct ServiceGroup {
    /// Shutdown timeout and error-handling policy.
    config: ServiceGroupConfig,
    /// Registered services, in the order they were added.
    services: Vec<Arc<dyn Service>>,
    /// Token handed to every service and cancelled when the group stops.
    shutdown: ShutdownToken,
}

impl Default for ServiceGroup {
    fn default() -> Self {
        Self::new()
    }
}

impl ServiceGroup {
    /// Builds an empty group that uses the default [`ServiceGroupConfig`].
    pub fn new() -> Self {
        Self::with_config(Default::default())
    }

    /// Builds an empty group governed by the supplied `config`.
    pub fn with_config(config: ServiceGroupConfig) -> Self {
        let services = Vec::new();
        let shutdown = ShutdownToken::new();
        Self {
            config,
            services,
            shutdown,
        }
    }

    /// Registers `service` with the group, wrapping it in an [`Arc`].
    pub fn add<S>(&mut self, service: S) -> &mut Self
    where
        S: Service,
    {
        self.add_arc(Arc::new(service))
    }

    /// Registers a service that is already shared behind an [`Arc`].
    pub fn add_arc<S>(&mut self, service: Arc<S>) -> &mut Self
    where
        S: Service,
    {
        let shared: Arc<dyn Service> = service;
        self.services.push(shared);
        self
    }

    /// Registers a service whose behaviour is the async function `start`.
    ///
    /// The function is wrapped in an [`FnService`] carrying `name`.
    pub fn add_fn<F, Fut>(&mut self, name: impl Into<String>, start: F) -> &mut Self
    where
        F: Fn(ShutdownToken) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = CoreResult<()>> + Send + 'static,
    {
        let service = FnService::new(name, start);
        self.add(service)
    }

    /// Returns a handle that can stop the group while it is running.
    pub fn handle(&self) -> ServiceGroupHandle {
        let shutdown = self.shutdown.clone();
        ServiceGroupHandle { shutdown }
    }

    /// Starts all services and runs until `shutdown_signal()` resolves or the
    /// group is stopped explicitly.
    pub async fn start(self) -> CoreResult<()> {
        let signal = shutdown_signal();
        self.start_with_shutdown(signal).await
    }

    /// Starts all services and runs until `shutdown` completes, a handle stops
    /// the group, or the services exit.
    ///
    /// Returns `Ok(())` immediately when no services are registered.
    ///
    /// # Errors
    ///
    /// Returns the aggregated failure messages collected from service exits,
    /// stop hooks and shutdown timeouts.
    pub async fn start_with_shutdown<F>(self, shutdown: F) -> CoreResult<()>
    where
        F: Future<Output = ()> + Send,
    {
        if self.services.is_empty() {
            return Ok(());
        }

        let grace = self.config.shutdown_timeout;
        let fail_fast = self.config.stop_on_first_error;

        let mut running = spawn_services(&self.services, &self.shutdown);
        let mut remaining = self.services.len();
        let mut failures = Vec::new();
        tokio::pin!(shutdown);

        while remaining > 0 {
            tokio::select! {
                // External shutdown request: propagate it to the services.
                _ = &mut shutdown => {
                    self.shutdown.cancel();
                    break;
                }
                // Someone already cancelled the token (e.g. through a handle).
                _ = self.shutdown.cancelled() => {
                    break;
                }
                // A service task ended on its own.
                finished = running.join_next() => {
                    match finished {
                        None => break,
                        Some(outcome) => {
                            remaining -= 1;
                            let failed = handle_service_exit(outcome, &mut failures);
                            if failed && fail_fast {
                                self.shutdown.cancel();
                                break;
                            }
                        }
                    }
                }
            }
        }

        // Stop hooks only run when shutdown was requested or tasks are still alive.
        let needs_shutdown = self.shutdown.is_cancelled() || remaining > 0;
        if needs_shutdown {
            self.shutdown.cancel();
            stop_services(&self.services, grace, &mut failures).await;
            wait_for_tasks(&mut running, remaining, grace, &mut failures).await;
        }

        into_group_result(failures)
    }
}

/// Outcome of one spawned service task: which service ran and how its
/// `start` call returned.
struct ServiceTaskExit {
    /// Service name captured before the task was spawned; used in error messages.
    name: String,
    /// Value returned by the service's `start` future.
    result: CoreResult<()>,
}

/// Spawns one task per service and returns the set that tracks them.
///
/// Each task receives its own clone of `shutdown` and reports the service
/// name together with the result of `start`.
fn spawn_services(
    services: &[Arc<dyn Service>],
    shutdown: &ShutdownToken,
) -> JoinSet<ServiceTaskExit> {
    let mut running = JoinSet::new();
    for shared in services.iter().cloned() {
        // Capture the name up front so it stays available after `start` returns.
        let name = shared.name().to_string();
        let token = shutdown.clone();
        running.spawn(async move {
            let result = shared.start(token).await;
            ServiceTaskExit { name, result }
        });
    }
    running
}

fn handle_service_exit(
    joined: Result<ServiceTaskExit, tokio::task::JoinError>,
    errors: &mut Vec<String>,
) -> bool {
    match joined {
        Ok(exit) => {
            if let Err(error) = exit.result {
                errors.push(format!("service {} failed: {error}", exit.name));
                return true;
            }
        }
        Err(error) => {
            errors.push(format!("service task failed: {error}"));
            return true;
        }
    }
    false
}

/// Runs every service's stop hook, bounded by `timeout`, and records failures.
///
/// Failures are appended to `errors` as each hook finishes, so hooks that had
/// already failed are still reported when the overall wait times out. On
/// timeout an additional "timed out" entry is appended; the hook tasks that
/// are still running are aborted when their `JoinSet` is dropped together
/// with the cancelled future.
async fn stop_services(services: &[Arc<dyn Service>], timeout: Duration, errors: &mut Vec<String>) {
    if tokio::time::timeout(timeout, run_stop_hooks(services, errors))
        .await
        .is_err()
    {
        errors.push(format!(
            "service group stop hooks timed out after {:?}",
            timeout
        ));
    }
}

/// Spawns all stop hooks concurrently and appends each failure to `errors`.
///
/// Hooks are spawned in reverse registration order, but because they run as
/// independent tasks their completion order is not defined. Errors are pushed
/// one by one (rather than returned at the end) so that nothing is lost if the
/// caller drops this future early, e.g. on timeout.
async fn run_stop_hooks(services: &[Arc<dyn Service>], errors: &mut Vec<String>) {
    let mut tasks = JoinSet::new();
    for service in services.iter().rev() {
        let service = Arc::clone(service);
        let name = service.name().to_string();
        tasks.spawn(async move {
            let result = service.stop().await;
            (name, result)
        });
    }

    while let Some(joined) = tasks.join_next().await {
        match joined {
            Ok((_name, Ok(()))) => {}
            Ok((name, Err(error))) => errors.push(format!("service {name} stop failed: {error}")),
            Err(error) => errors.push(format!("service stop task failed: {error}")),
        }
    }
}

async fn wait_for_tasks(
    tasks: &mut JoinSet<ServiceTaskExit>,
    active: usize,
    timeout: Duration,
    errors: &mut Vec<String>,
) {
    let wait = async {
        let mut remaining = active;
        while remaining > 0 {
            let Some(joined) = tasks.join_next().await else {
                break;
            };
            remaining -= 1;
            match joined {
                Ok(exit) => {
                    if let Err(error) = exit.result {
                        errors.push(format!(
                            "service {} failed during shutdown: {error}",
                            exit.name
                        ));
                    }
                }
                Err(error) => errors.push(format!("service task failed during shutdown: {error}")),
            }
        }
    };

    if tokio::time::timeout(timeout, wait).await.is_err() {
        tasks.abort_all();
        errors.push(format!(
            "service group shutdown timed out after {:?}",
            timeout
        ));
    }
}

/// Converts the collected error messages into the group's final result.
///
/// An empty list yields `Ok(())`; otherwise the messages are joined with
/// `"; "` into a single [`CoreError::Service`].
fn into_group_result(errors: Vec<String>) -> CoreResult<()> {
    match errors.as_slice() {
        [] => Ok(()),
        messages => Err(CoreError::Service(messages.join("; "))),
    }
}