1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
use std::{collections::HashMap, time::Duration};
use crate::utils::try_extract_panic_message;
use futures::future::Fuse;
use tokio::{runtime::Runtime, sync::watch, task::JoinHandle};
pub use self::{
context::ServiceContext,
context_traits::{FromContext, IntoContext},
error::{ServiceError, TaskError},
shutdown_hook::ShutdownHook,
stop_receiver::StopReceiver,
};
use crate::{
resource::{ResourceId, StoredResource},
service::{
named_future::NamedFuture,
runnables::{NamedBoxFuture, Runnables, TaskReprs},
},
task::TaskId,
wiring_layer::{WireFn, WiringError, WiringLayer, WiringLayerExt},
};
mod context;
mod context_traits;
mod error;
mod named_future;
mod runnables;
mod shutdown_hook;
mod stop_receiver;
#[cfg(test)]
mod tests;
/// Upper bound on how long each remaining task is awaited after the stop signal has been
/// sent, and on how long each individual shutdown hook may run, before it is reported as
/// timed out. A reasonable amount of time for any task to finish the shutdown process.
const TASK_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(30_000);
/// A builder for [`Service`].
///
/// Collects wiring layers (deduplicated by layer name) together with the Tokio runtime
/// that the built [`Service`] will use to spawn and drive its tasks.
#[derive(Debug)]
pub struct ServiceBuilder {
    /// List of wiring layers, each paired with its layer name.
    // Note: It has to be a `Vec` and not e.g. `HashMap` because the order in which we
    // iterate through it matters.
    layers: Vec<(&'static str, WireFn)>,
    /// Tokio runtime used to spawn tasks.
    runtime: Runtime,
}
impl ServiceBuilder {
    /// Creates a new builder backed by a freshly built multi-threaded Tokio runtime
    /// with all drivers enabled.
    ///
    /// # Errors
    ///
    /// Returns [`ServiceError::RuntimeDetected`] if called within a Tokio runtime context.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be built (i.e. `tokio::runtime::Builder::build`
    /// returns an I/O error).
    pub fn new() -> Result<Self, ServiceError> {
        if tokio::runtime::Handle::try_current().is_ok() {
            return Err(ServiceError::RuntimeDetected);
        }
        // Runtime construction failure is not mapped to a `ServiceError` here, so this remains
        // a panic — but with a message stating what failed instead of a bare `unwrap()`.
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("failed to build the Tokio runtime for the service");
        Ok(Self::on_runtime(runtime))
    }

    /// Creates a new builder with the provided Tokio runtime.
    /// This method can be used if asynchronous tasks must be performed before the service is built.
    ///
    /// However, it is not recommended to use this method to spawn any tasks that will not be managed
    /// by the service itself, so whenever it can be avoided, using [`ServiceBuilder::new`] is preferred.
    pub fn on_runtime(runtime: Runtime) -> Self {
        Self {
            layers: Vec::new(),
            runtime,
        }
    }

    /// Returns a handle to the Tokio runtime used by the service.
    pub fn runtime_handle(&self) -> tokio::runtime::Handle {
        self.runtime.handle().clone()
    }

    /// Adds a wiring layer.
    ///
    /// During the [`run`](Service::run) call the service will invoke
    /// `wire` method of every layer in the order they were added.
    ///
    /// This method may be invoked multiple times with the same layer type, but the
    /// layer will only be stored once (meaning that 2nd attempt to add the same layer will be ignored).
    /// This may be useful if the same layer is a prerequisite for multiple other layers: it is safe
    /// to add it multiple times, and it will only be wired once.
    pub fn add_layer<T: WiringLayer>(&mut self, layer: T) -> &mut Self {
        let name = layer.layer_name();
        // Deduplicate by layer name, keeping the first registration (and thus its position).
        if !self
            .layers
            .iter()
            .any(|(existing_name, _)| name == *existing_name)
        {
            self.layers.push((name, layer.into_wire_fn()));
        }
        self
    }

    /// Builds the service.
    ///
    /// The stop channel starts at `false`.
    pub fn build(self) -> Service {
        // The initial receiver is not needed: tasks obtain their own receivers later via
        // `stop_sender.subscribe()`.
        let (stop_sender, _stop_receiver) = watch::channel(false);
        Service {
            layers: self.layers,
            resources: Default::default(),
            runnables: Default::default(),
            stop_sender,
            runtime: self.runtime,
            errors: Vec::new(),
        }
    }
}
/// "Manager" class for a set of tasks. Collects all the resources and tasks,
/// then runs tasks until one of them exits, after which the remaining tasks are
/// stopped and the shutdown hooks are run.
#[derive(Debug)]
pub struct Service {
    /// Cache of resources that have been requested by at least one task.
    // Emptied once wiring completes, so the service no longer holds references to them.
    resources: HashMap<ResourceId, Box<dyn StoredResource>>,
    /// List of wiring layers. Taken out (leaving this empty) when wiring starts.
    layers: Vec<(&'static str, WireFn)>,
    /// Different kinds of tasks for the service.
    runnables: Runnables,
    /// Sender used to stop the tasks: `true` is sent once shutdown begins.
    stop_sender: watch::Sender<bool>,
    /// Tokio runtime used to spawn tasks.
    runtime: Runtime,
    /// Collector for the task errors met during the service execution.
    errors: Vec<TaskError>,
}
/// A spawned task's join handle, fused and wrapped in a [`NamedFuture`] carrying the task's id
/// (the id is used to report which task exited, failed, panicked or timed out).
type TaskFuture = NamedFuture<Fuse<JoinHandle<eyre::Result<()>>>>;
impl Service {
    /// Runs the system.
    ///
    /// In case of errors during wiring phase, will return the list of all the errors that happened, in the order
    /// of their occurrence.
    pub fn run(self) -> Result<(), ServiceError> {
        self.run_with_guard(())
    }

    /// Runs the system.
    ///
    /// In case of errors during wiring phase, will return the list of all the errors that happened, in the order
    /// of their occurrence. Otherwise runs the tasks until one of them exits, stops the remaining ones, runs the
    /// shutdown hooks, and returns [`ServiceError::Task`] if any task or hook failed, panicked or timed out.
    ///
    /// `observability_guard` will be used to deinitialize the observability subsystem
    /// as the very last step before exiting the node. It is dropped inside the Tokio runtime
    /// context on every exit path, including wiring failures.
    pub fn run_with_guard<G>(mut self, observability_guard: G) -> Result<(), ServiceError> {
        // Compute the outcome first and release the guard afterwards regardless of it, so that
        // an early return on wiring errors cannot skip the runtime-context drop below.
        let result = self.wire_and_run();
        if std::mem::needs_drop::<G>() {
            // Make sure that the shutdown happens in the `tokio` context.
            let _guard = self.runtime.enter();
            drop(observability_guard);
        }
        result
    }

    /// Wires the service, runs the tasks, stops the remaining ones and runs the shutdown hooks.
    ///
    /// Returns wiring errors as-is; otherwise returns [`ServiceError::Task`] with all the errors
    /// collected during execution, if there were any.
    fn wire_and_run(&mut self) -> Result<(), ServiceError> {
        self.wire()?;
        let TaskReprs {
            tasks,
            shutdown_hooks,
        } = self.prepare_tasks();
        let remaining = self.run_tasks(tasks);
        self.shutdown_tasks(remaining);
        self.run_shutdown_hooks(shutdown_hooks);
        tracing::info!("Exiting the service");
        if self.errors.is_empty() {
            Ok(())
        } else {
            Err(ServiceError::Task(std::mem::take(&mut self.errors).into()))
        }
    }

    /// Performs wiring of the service.
    ///
    /// Invokes every wiring layer sequentially, in the order they were added. All wiring errors are
    /// collected (not just the first one) and returned together as [`ServiceError::Wiring`].
    /// Returns [`ServiceError::NoTasks`] if wiring succeeded but no runnables were registered.
    ///
    /// After a successful call, the collected tasks are stored in `self.runnables` and the
    /// resource cache is cleared.
    fn wire(&mut self) -> Result<(), ServiceError> {
        // Take the layers out so the loop does not borrow `self` while `self` is handed to each
        // layer's `ServiceContext`.
        let wiring_layers = std::mem::take(&mut self.layers);
        let mut errors: Vec<(String, WiringError)> = Vec::new();
        let runtime_handle = self.runtime.handle().clone();
        for (name, WireFn(wire_fn)) in wiring_layers {
            // We must process wiring layers sequentially and in the same order as they were added.
            let mut context = ServiceContext::new(name, self);
            let task_result = wire_fn(&runtime_handle, &mut context);
            if let Err(err) = task_result {
                // We don't want to bail on the first error, since it'll provide worse DevEx:
                // people likely want to fix as many problems as they can in one go, rather than
                // having to fix them one by one.
                errors.push((name.to_string(), err));
            }
        }

        // Report all the errors we've met during the init.
        if !errors.is_empty() {
            for (layer, error) in &errors {
                tracing::error!("Wiring layer {layer} can't be initialized: {error:?}");
            }
            return Err(ServiceError::Wiring(errors));
        }

        if self.runnables.is_empty() {
            return Err(ServiceError::NoTasks);
        }

        // Wiring is now complete: notify every stored resource.
        for resource in self.resources.values_mut() {
            resource.stored_resource_wired();
        }
        self.resources = HashMap::default(); // Decrement reference counters for resources.
        tracing::info!("Wiring complete");

        Ok(())
    }

    /// Prepares collected tasks for running.
    fn prepare_tasks(&mut self) -> TaskReprs {
        // Barrier that will only be lifted once all the preconditions are met.
        // It will be awaited by the tasks before they start running and by the preconditions once they are fulfilled.
        let task_barrier = self.runnables.task_barrier();

        // Collect long-running tasks.
        let stop_receiver = StopReceiver(self.stop_sender.subscribe());
        self.runnables
            .prepare_tasks(task_barrier.clone(), stop_receiver.clone())
    }

    /// Spawns the provided tasks and runs them until at least one task exits, then returns the list
    /// of remaining tasks.
    /// Adds the error of the exited task, if any, to the `errors` vector.
    fn run_tasks(&mut self, tasks: Vec<NamedBoxFuture<eyre::Result<()>>>) -> Vec<TaskFuture> {
        // Prepare tasks for running.
        let rt_handle = self.runtime.handle().clone();
        let join_handles: Vec<_> = tasks
            .into_iter()
            .map(|task| task.spawn(&rt_handle).fuse())
            .collect();

        // Collect names for remaining tasks for reporting purposes.
        // Indices match `join_handles`, which is what `select_all` reports `resolved_idx` against.
        let mut tasks_names: Vec<_> = join_handles.iter().map(|task| task.id()).collect();

        // Run the tasks until one of them exits.
        let (resolved, resolved_idx, remaining) = self
            .runtime
            .block_on(futures::future::select_all(join_handles));

        // Extract the result and report it to logs early, before waiting for any other task to shutdown.
        // We will also collect the errors from the remaining tasks, hence a vector.
        let task_name = tasks_names.swap_remove(resolved_idx);
        self.handle_task_exit(resolved, task_name);
        tracing::info!("One of the task has exited, shutting down the node");

        remaining
    }

    /// Sends the stop signal and waits for the remaining tasks to finish.
    fn shutdown_tasks(&mut self, remaining: Vec<TaskFuture>) {
        // Send stop signal to remaining tasks and wait for them to finish.
        self.stop_sender.send(true).ok();

        // Collect names for remaining tasks for reporting purposes.
        // We have to re-collect, because `select_all` does not guarantee the order of the returned remaining futures.
        let remaining_tasks_names: Vec<_> = remaining.iter().map(|task| task.id()).collect();
        // NOTE: a task that times out is not aborted; dropping its `JoinHandle` only detaches it.
        let remaining_tasks_with_timeout: Vec<_> = remaining
            .into_iter()
            .map(|task| async { tokio::time::timeout(TASK_SHUTDOWN_TIMEOUT, task).await })
            .collect();

        let execution_results = self
            .runtime
            .block_on(futures::future::join_all(remaining_tasks_with_timeout));

        // Report the results of the remaining tasks.
        for (name, result) in remaining_tasks_names.into_iter().zip(execution_results) {
            match result {
                Ok(resolved) => {
                    self.handle_task_exit(resolved, name);
                }
                Err(_) => {
                    tracing::error!("Task {name} timed out");
                    self.errors.push(TaskError::TaskShutdownTimedOut(name));
                }
            }
        }
    }

    /// Runs the provided shutdown hooks.
    fn run_shutdown_hooks(&mut self, shutdown_hooks: Vec<NamedBoxFuture<eyre::Result<()>>>) {
        // Run shutdown hooks sequentially.
        for hook in shutdown_hooks {
            let name = hook.id().clone();
            // Limit each shutdown hook to the same timeout as the tasks.
            let hook_with_timeout =
                async move { tokio::time::timeout(TASK_SHUTDOWN_TIMEOUT, hook).await };
            match self.runtime.block_on(hook_with_timeout) {
                Ok(Ok(())) => {
                    tracing::info!("Shutdown hook {name} completed");
                }
                Ok(Err(err)) => {
                    tracing::error!("Shutdown hook {name} failed: {err:?}");
                    self.errors.push(TaskError::ShutdownHookFailed(name, err));
                }
                Err(_) => {
                    tracing::error!("Shutdown hook {name} timed out");
                    self.errors.push(TaskError::ShutdownHookTimedOut(name));
                }
            }
        }
    }

    /// Checks the result of the task execution, logs the result, and stores the error if any.
    fn handle_task_exit(
        &mut self,
        task_result: Result<eyre::Result<()>, tokio::task::JoinError>,
        task_name: TaskId,
    ) {
        match task_result {
            Ok(Ok(())) => {
                tracing::info!("Task {task_name} finished");
            }
            Ok(Err(err)) => {
                tracing::error!("Task {task_name} failed: {err:?}");
                self.errors.push(TaskError::TaskFailed(task_name, err));
            }
            Err(panic_err) => {
                let panic_msg = try_extract_panic_message(panic_err);
                tracing::error!("Task {task_name} panicked: {panic_msg}");
                self.errors
                    .push(TaskError::TaskPanicked(task_name, panic_msg));
            }
        }
    }
}