Skip to main content

miden_node_utils/
tasks.rs

1use std::collections::HashMap;
2use std::future::Future;
3
4use anyhow::Context;
5use tokio::task::{Id, JoinError, JoinSet};
6
7/// A named task set for supervising concurrently-running Tokio tasks.
8///
9/// Dropping a task set aborts all tasks that are still running.
10pub struct Tasks {
11    handles: JoinSet<anyhow::Result<()>>,
12    names: HashMap<Id, String>,
13}
14
15impl Default for Tasks {
16    fn default() -> Self {
17        Self {
18            handles: JoinSet::new(),
19            names: HashMap::new(),
20        }
21    }
22}
23
24impl Tasks {
25    /// Creates an empty task set.
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    /// Spawns a named task into the set.
31    pub fn spawn(
32        &mut self,
33        name: impl Into<String>,
34        task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
35    ) -> Id {
36        let id = self.handles.spawn(task).id();
37        self.names.insert(id, name.into());
38        id
39    }
40
41    /// Spawns a named task that does not return an error.
42    pub fn spawn_infallible(
43        &mut self,
44        name: impl Into<String>,
45        task: impl Future<Output = ()> + Send + 'static,
46    ) -> Id {
47        self.spawn(name, async move {
48            task.await;
49            Ok(())
50        })
51    }
52
53    /// Waits for the next task to complete.
54    pub async fn join_next(&mut self) -> Option<(String, Result<anyhow::Result<()>, JoinError>)> {
55        let result = self.handles.join_next_with_id().await?;
56        let id = match &result {
57            Ok((id, _)) => *id,
58            Err(err) => err.id(),
59        };
60        let name = self.names.remove(&id).unwrap_or_else(|| "unknown".to_string());
61        let result = result.map(|(_, output)| output);
62
63        Some((name, result))
64    }
65
66    /// Returns `true` if no tasks are currently in the set.
67    pub fn is_empty(&self) -> bool {
68        self.handles.is_empty()
69    }
70
71    /// Returns the number of tasks currently in the set.
72    pub fn len(&self) -> usize {
73        self.handles.len()
74    }
75
76    /// Waits for the next task to complete, treating that completion as an error.
77    ///
78    /// This is intended for supervised task sets where every task is expected to run indefinitely.
79    pub async fn join_next_as_error(&mut self) -> anyhow::Result<()> {
80        let Some((task, result)) = self.join_next().await else {
81            anyhow::bail!("task set is empty");
82        };
83
84        match result {
85            Ok(Ok(())) => anyhow::bail!("task {task} completed unexpectedly"),
86            Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed")),
87            Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
88        }
89    }
90}