rs_zero/core/
service_group.rs1use std::{future::Future, sync::Arc, time::Duration};
2
3use tokio::task::JoinSet;
4
5use crate::core::{CoreError, CoreResult, FnService, Service, ShutdownToken, shutdown_signal};
6
/// Behaviour settings for a [`ServiceGroup`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub struct ServiceGroupConfig {
    /// Time budget applied during shutdown: it bounds how long the group waits
    /// for stop hooks and for service tasks to exit.
    pub shutdown_timeout: Duration,
    /// When `true`, the first service (or service task) that fails causes the
    /// whole group to begin shutting down.
    pub stop_on_first_error: bool,
}
15
16impl Default for ServiceGroupConfig {
17 fn default() -> Self {
18 Self {
19 shutdown_timeout: Duration::from_secs(30),
20 stop_on_first_error: true,
21 }
22 }
23}
24
25#[derive(Debug, Clone)]
27pub struct ServiceGroupHandle {
28 shutdown: ShutdownToken,
29}
30
31impl ServiceGroupHandle {
32 pub fn stop(&self) {
34 self.shutdown.cancel();
35 }
36
37 pub fn is_stopped(&self) -> bool {
39 self.shutdown.is_cancelled()
40 }
41}
42
43pub struct ServiceGroup {
50 config: ServiceGroupConfig,
51 services: Vec<Arc<dyn Service>>,
52 shutdown: ShutdownToken,
53}
54
55impl Default for ServiceGroup {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl ServiceGroup {
62 pub fn new() -> Self {
64 Self::with_config(ServiceGroupConfig::default())
65 }
66
67 pub fn with_config(config: ServiceGroupConfig) -> Self {
69 Self {
70 config,
71 services: Vec::new(),
72 shutdown: ShutdownToken::new(),
73 }
74 }
75
76 pub fn add<S>(&mut self, service: S) -> &mut Self
78 where
79 S: Service,
80 {
81 self.services.push(Arc::new(service));
82 self
83 }
84
85 pub fn add_arc<S>(&mut self, service: Arc<S>) -> &mut Self
87 where
88 S: Service,
89 {
90 self.services.push(service);
91 self
92 }
93
94 pub fn add_fn<F, Fut>(&mut self, name: impl Into<String>, start: F) -> &mut Self
96 where
97 F: Fn(ShutdownToken) -> Fut + Send + Sync + 'static,
98 Fut: Future<Output = CoreResult<()>> + Send + 'static,
99 {
100 self.add(FnService::new(name, start))
101 }
102
103 pub fn handle(&self) -> ServiceGroupHandle {
105 ServiceGroupHandle {
106 shutdown: self.shutdown.clone(),
107 }
108 }
109
110 pub async fn start(self) -> CoreResult<()> {
112 self.start_with_shutdown(shutdown_signal()).await
113 }
114
115 pub async fn start_with_shutdown<F>(self, shutdown: F) -> CoreResult<()>
117 where
118 F: Future<Output = ()> + Send,
119 {
120 if self.services.is_empty() {
121 return Ok(());
122 }
123
124 let mut tasks = spawn_services(&self.services, &self.shutdown);
125 let mut active = self.services.len();
126 let mut errors = Vec::new();
127 tokio::pin!(shutdown);
128
129 while active > 0 {
130 tokio::select! {
131 _ = &mut shutdown => {
132 self.shutdown.cancel();
133 break;
134 }
135 _ = self.shutdown.cancelled() => {
136 break;
137 }
138 joined = tasks.join_next() => {
139 let Some(joined) = joined else { break; };
140 active -= 1;
141 if handle_service_exit(joined, &mut errors) && self.config.stop_on_first_error {
142 self.shutdown.cancel();
143 break;
144 }
145 }
146 }
147 }
148
149 if self.shutdown.is_cancelled() || active > 0 {
150 self.shutdown.cancel();
151 stop_services(&self.services, self.config.shutdown_timeout, &mut errors).await;
152 wait_for_tasks(
153 &mut tasks,
154 active,
155 self.config.shutdown_timeout,
156 &mut errors,
157 )
158 .await;
159 }
160
161 into_group_result(errors)
162 }
163}
164
165struct ServiceTaskExit {
166 name: String,
167 result: CoreResult<()>,
168}
169
170fn spawn_services(
171 services: &[Arc<dyn Service>],
172 shutdown: &ShutdownToken,
173) -> JoinSet<ServiceTaskExit> {
174 let mut tasks = JoinSet::new();
175 for service in services {
176 let service = Arc::clone(service);
177 let token = shutdown.clone();
178 let name = service.name().to_string();
179 tasks.spawn(async move {
180 let result = service.start(token).await;
181 ServiceTaskExit { name, result }
182 });
183 }
184 tasks
185}
186
187fn handle_service_exit(
188 joined: Result<ServiceTaskExit, tokio::task::JoinError>,
189 errors: &mut Vec<String>,
190) -> bool {
191 match joined {
192 Ok(exit) => {
193 if let Err(error) = exit.result {
194 errors.push(format!("service {} failed: {error}", exit.name));
195 return true;
196 }
197 }
198 Err(error) => {
199 errors.push(format!("service task failed: {error}"));
200 return true;
201 }
202 }
203 false
204}
205
206async fn stop_services(services: &[Arc<dyn Service>], timeout: Duration, errors: &mut Vec<String>) {
207 match tokio::time::timeout(timeout, run_stop_hooks(services)).await {
208 Ok(stop_errors) => errors.extend(stop_errors),
209 Err(_) => errors.push(format!(
210 "service group stop hooks timed out after {:?}",
211 timeout
212 )),
213 }
214}
215
216async fn run_stop_hooks(services: &[Arc<dyn Service>]) -> Vec<String> {
217 let mut tasks = JoinSet::new();
218 for service in services.iter().rev() {
219 let service = Arc::clone(service);
220 let name = service.name().to_string();
221 tasks.spawn(async move {
222 let result = service.stop().await;
223 (name, result)
224 });
225 }
226
227 let mut errors = Vec::new();
228 while let Some(joined) = tasks.join_next().await {
229 match joined {
230 Ok((_name, Ok(()))) => {}
231 Ok((name, Err(error))) => errors.push(format!("service {name} stop failed: {error}")),
232 Err(error) => errors.push(format!("service stop task failed: {error}")),
233 }
234 }
235 errors
236}
237
238async fn wait_for_tasks(
239 tasks: &mut JoinSet<ServiceTaskExit>,
240 active: usize,
241 timeout: Duration,
242 errors: &mut Vec<String>,
243) {
244 let wait = async {
245 let mut remaining = active;
246 while remaining > 0 {
247 let Some(joined) = tasks.join_next().await else {
248 break;
249 };
250 remaining -= 1;
251 match joined {
252 Ok(exit) => {
253 if let Err(error) = exit.result {
254 errors.push(format!(
255 "service {} failed during shutdown: {error}",
256 exit.name
257 ));
258 }
259 }
260 Err(error) => errors.push(format!("service task failed during shutdown: {error}")),
261 }
262 }
263 };
264
265 if tokio::time::timeout(timeout, wait).await.is_err() {
266 tasks.abort_all();
267 errors.push(format!(
268 "service group shutdown timed out after {:?}",
269 timeout
270 ));
271 }
272}
273
274fn into_group_result(errors: Vec<String>) -> CoreResult<()> {
275 if errors.is_empty() {
276 Ok(())
277 } else {
278 Err(CoreError::Service(errors.join("; ")))
279 }
280}