Skip to main content

rs_zero/core/
service_group.rs

1use std::{future::Future, sync::Arc, time::Duration};
2
3use tokio::task::JoinSet;
4
5use crate::core::{CoreError, CoreResult, FnService, Service, ShutdownToken, shutdown_signal};
6
/// Tunables that govern how a service group shuts down and reacts to errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServiceGroupConfig {
    /// Upper bound applied while running stop hooks and while waiting for
    /// service tasks to finish during shutdown.
    pub shutdown_timeout: Duration,
    /// When `true`, the first service error cancels every other service.
    pub stop_on_first_error: bool,
}

impl Default for ServiceGroupConfig {
    /// Returns a 30 second shutdown timeout with stop-on-first-error enabled.
    fn default() -> Self {
        let shutdown_timeout = Duration::from_secs(30);
        Self {
            shutdown_timeout,
            stop_on_first_error: true,
        }
    }
}
24
25/// Handle that can request shutdown for a running [`ServiceGroup`].
26#[derive(Debug, Clone)]
27pub struct ServiceGroupHandle {
28    shutdown: ShutdownToken,
29}
30
31impl ServiceGroupHandle {
32    /// Requests group shutdown. Calling this method multiple times is safe.
33    pub fn stop(&self) {
34        self.shutdown.cancel();
35    }
36
37    /// Returns whether shutdown has been requested.
38    pub fn is_stopped(&self) -> bool {
39        self.shutdown.is_cancelled()
40    }
41}
42
43/// A group of async services that start together and stop together.
44///
45/// Services are started concurrently. Like go-zero's service group, startup
46/// order must not be relied on. When a process shutdown signal, explicit handle
47/// stop, or service error occurs, the group broadcasts a shutdown token, runs
48/// stop hooks, waits for tasks and returns aggregated errors if any.
49pub struct ServiceGroup {
50    config: ServiceGroupConfig,
51    services: Vec<Arc<dyn Service>>,
52    shutdown: ShutdownToken,
53}
54
55impl Default for ServiceGroup {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61impl ServiceGroup {
62    /// Creates an empty service group with default controls.
63    pub fn new() -> Self {
64        Self::with_config(ServiceGroupConfig::default())
65    }
66
67    /// Creates an empty service group with custom controls.
68    pub fn with_config(config: ServiceGroupConfig) -> Self {
69        Self {
70            config,
71            services: Vec::new(),
72            shutdown: ShutdownToken::new(),
73        }
74    }
75
76    /// Adds a service to the group.
77    pub fn add<S>(&mut self, service: S) -> &mut Self
78    where
79        S: Service,
80    {
81        self.services.push(Arc::new(service));
82        self
83    }
84
85    /// Adds an already shared service to the group.
86    pub fn add_arc<S>(&mut self, service: Arc<S>) -> &mut Self
87    where
88        S: Service,
89    {
90        self.services.push(service);
91        self
92    }
93
94    /// Adds a service backed by an async start function.
95    pub fn add_fn<F, Fut>(&mut self, name: impl Into<String>, start: F) -> &mut Self
96    where
97        F: Fn(ShutdownToken) -> Fut + Send + Sync + 'static,
98        Fut: Future<Output = CoreResult<()>> + Send + 'static,
99    {
100        self.add(FnService::new(name, start))
101    }
102
103    /// Returns a handle that can stop the group while it is running.
104    pub fn handle(&self) -> ServiceGroupHandle {
105        ServiceGroupHandle {
106            shutdown: self.shutdown.clone(),
107        }
108    }
109
110    /// Starts all services and waits for Ctrl-C or explicit stop.
111    pub async fn start(self) -> CoreResult<()> {
112        self.start_with_shutdown(shutdown_signal()).await
113    }
114
115    /// Starts all services and waits for the supplied shutdown future.
116    pub async fn start_with_shutdown<F>(self, shutdown: F) -> CoreResult<()>
117    where
118        F: Future<Output = ()> + Send,
119    {
120        if self.services.is_empty() {
121            return Ok(());
122        }
123
124        let mut tasks = spawn_services(&self.services, &self.shutdown);
125        let mut active = self.services.len();
126        let mut errors = Vec::new();
127        tokio::pin!(shutdown);
128
129        while active > 0 {
130            tokio::select! {
131                _ = &mut shutdown => {
132                    self.shutdown.cancel();
133                    break;
134                }
135                _ = self.shutdown.cancelled() => {
136                    break;
137                }
138                joined = tasks.join_next() => {
139                    let Some(joined) = joined else { break; };
140                    active -= 1;
141                    if handle_service_exit(joined, &mut errors) && self.config.stop_on_first_error {
142                        self.shutdown.cancel();
143                        break;
144                    }
145                }
146            }
147        }
148
149        if self.shutdown.is_cancelled() || active > 0 {
150            self.shutdown.cancel();
151            stop_services(&self.services, self.config.shutdown_timeout, &mut errors).await;
152            wait_for_tasks(
153                &mut tasks,
154                active,
155                self.config.shutdown_timeout,
156                &mut errors,
157            )
158            .await;
159        }
160
161        into_group_result(errors)
162    }
163}
164
165struct ServiceTaskExit {
166    name: String,
167    result: CoreResult<()>,
168}
169
170fn spawn_services(
171    services: &[Arc<dyn Service>],
172    shutdown: &ShutdownToken,
173) -> JoinSet<ServiceTaskExit> {
174    let mut tasks = JoinSet::new();
175    for service in services {
176        let service = Arc::clone(service);
177        let token = shutdown.clone();
178        let name = service.name().to_string();
179        tasks.spawn(async move {
180            let result = service.start(token).await;
181            ServiceTaskExit { name, result }
182        });
183    }
184    tasks
185}
186
187fn handle_service_exit(
188    joined: Result<ServiceTaskExit, tokio::task::JoinError>,
189    errors: &mut Vec<String>,
190) -> bool {
191    match joined {
192        Ok(exit) => {
193            if let Err(error) = exit.result {
194                errors.push(format!("service {} failed: {error}", exit.name));
195                return true;
196            }
197        }
198        Err(error) => {
199            errors.push(format!("service task failed: {error}"));
200            return true;
201        }
202    }
203    false
204}
205
206async fn stop_services(services: &[Arc<dyn Service>], timeout: Duration, errors: &mut Vec<String>) {
207    match tokio::time::timeout(timeout, run_stop_hooks(services)).await {
208        Ok(stop_errors) => errors.extend(stop_errors),
209        Err(_) => errors.push(format!(
210            "service group stop hooks timed out after {:?}",
211            timeout
212        )),
213    }
214}
215
216async fn run_stop_hooks(services: &[Arc<dyn Service>]) -> Vec<String> {
217    let mut tasks = JoinSet::new();
218    for service in services.iter().rev() {
219        let service = Arc::clone(service);
220        let name = service.name().to_string();
221        tasks.spawn(async move {
222            let result = service.stop().await;
223            (name, result)
224        });
225    }
226
227    let mut errors = Vec::new();
228    while let Some(joined) = tasks.join_next().await {
229        match joined {
230            Ok((_name, Ok(()))) => {}
231            Ok((name, Err(error))) => errors.push(format!("service {name} stop failed: {error}")),
232            Err(error) => errors.push(format!("service stop task failed: {error}")),
233        }
234    }
235    errors
236}
237
238async fn wait_for_tasks(
239    tasks: &mut JoinSet<ServiceTaskExit>,
240    active: usize,
241    timeout: Duration,
242    errors: &mut Vec<String>,
243) {
244    let wait = async {
245        let mut remaining = active;
246        while remaining > 0 {
247            let Some(joined) = tasks.join_next().await else {
248                break;
249            };
250            remaining -= 1;
251            match joined {
252                Ok(exit) => {
253                    if let Err(error) = exit.result {
254                        errors.push(format!(
255                            "service {} failed during shutdown: {error}",
256                            exit.name
257                        ));
258                    }
259                }
260                Err(error) => errors.push(format!("service task failed during shutdown: {error}")),
261            }
262        }
263    };
264
265    if tokio::time::timeout(timeout, wait).await.is_err() {
266        tasks.abort_all();
267        errors.push(format!(
268            "service group shutdown timed out after {:?}",
269            timeout
270        ));
271    }
272}
273
274fn into_group_result(errors: Vec<String>) -> CoreResult<()> {
275    if errors.is_empty() {
276        Ok(())
277    } else {
278        Err(CoreError::Service(errors.join("; ")))
279    }
280}