Skip to main content

miden_node_utils/
tasks.rs

1use std::collections::HashMap;
2use std::future::Future;
3
4use anyhow::Context;
5use tokio::task::{Id, JoinError, JoinSet};
6
7use crate::shutdown::CancellationToken;
8
9/// A named task set for supervising concurrently-running Tokio tasks.
10///
11/// Dropping a task set aborts all tasks that are still running.
12pub struct Tasks {
13    handles: JoinSet<anyhow::Result<()>>,
14    names: HashMap<Id, String>,
15}
16
17impl Default for Tasks {
18    fn default() -> Self {
19        Self {
20            handles: JoinSet::new(),
21            names: HashMap::new(),
22        }
23    }
24}
25
26impl Tasks {
27    /// Creates an empty task set.
28    pub fn new() -> Self {
29        Self::default()
30    }
31
32    /// Spawns a named task into the set.
33    pub fn spawn(
34        &mut self,
35        name: impl Into<String>,
36        task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
37    ) -> Id {
38        let id = self.handles.spawn(task).id();
39        self.names.insert(id, name.into());
40        id
41    }
42
43    /// Spawns a named task that does not return an error.
44    pub fn spawn_infallible(
45        &mut self,
46        name: impl Into<String>,
47        task: impl Future<Output = ()> + Send + 'static,
48    ) -> Id {
49        self.spawn(name, async move {
50            task.await;
51            Ok(())
52        })
53    }
54
55    /// Waits for the next task to complete.
56    pub async fn join_next(&mut self) -> Option<(String, Result<anyhow::Result<()>, JoinError>)> {
57        let result = self.handles.join_next_with_id().await?;
58        let id = match &result {
59            Ok((id, _)) => *id,
60            Err(err) => err.id(),
61        };
62        let name = self.names.remove(&id).unwrap_or_else(|| "unknown".to_string());
63        let result = result.map(|(_, output)| output);
64
65        Some((name, result))
66    }
67
68    /// Returns `true` if no tasks are currently in the set.
69    pub fn is_empty(&self) -> bool {
70        self.handles.is_empty()
71    }
72
73    /// Returns the number of tasks currently in the set.
74    pub fn len(&self) -> usize {
75        self.handles.len()
76    }
77
78    /// Waits for the next task to complete, treating that completion as an error.
79    ///
80    /// This is intended for supervised task sets where every task is expected to run indefinitely.
81    pub async fn join_next_as_error(&mut self) -> anyhow::Result<()> {
82        let Some((task, result)) = self.join_next().await else {
83            anyhow::bail!("task set is empty");
84        };
85
86        Self::unexpected_completion(&task, result)
87    }
88
89    /// Waits for either an unexpected task completion or a shutdown request.
90    ///
91    /// Before shutdown, any task completion is treated as fatal because this type supervises
92    /// long-running tasks. Once `token` is cancelled, clean task exits are accepted and this method
93    /// waits for all tracked tasks to finish.
94    pub async fn join_next_or_cancelled(&mut self, token: CancellationToken) -> anyhow::Result<()> {
95        while !token.is_cancelled() {
96            tokio::select! {
97                biased;
98                () = token.cancelled() => break,
99                result = self.join_next() => {
100                    let Some((task, result)) = result else {
101                        anyhow::bail!("task set is empty");
102                    };
103                    Self::unexpected_completion(&task, result)?;
104                },
105            }
106        }
107
108        while let Some((task, result)) = self.join_next().await {
109            Self::shutdown_completion(&task, result)?;
110        }
111
112        Ok(())
113    }
114
115    fn unexpected_completion(
116        task: &str,
117        result: Result<anyhow::Result<()>, JoinError>,
118    ) -> anyhow::Result<()> {
119        match result {
120            Ok(Ok(())) => anyhow::bail!("task {task} completed unexpectedly"),
121            Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed")),
122            Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
123        }
124    }
125
126    fn shutdown_completion(
127        task: &str,
128        result: Result<anyhow::Result<()>, JoinError>,
129    ) -> anyhow::Result<()> {
130        match result {
131            Ok(Ok(())) => Ok(()),
132            Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed during shutdown")),
133            Err(err) if err.is_cancelled() => Ok(()),
134            Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds a worker that blocks until `token` is cancelled, optionally lingers for `linger`,
    /// and then exits cleanly.
    fn cancellable_worker(
        token: &crate::shutdown::CancellationToken,
        linger: Option<Duration>,
    ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static {
        let token = token.clone();
        async move {
            token.cancelled().await;
            if let Some(delay) = linger {
                tokio::time::sleep(delay).await;
            }
            Ok(())
        }
    }

    #[tokio::test]
    async fn join_next_or_cancelled_accepts_clean_task_completion_after_cancellation() {
        let token = crate::shutdown::CancellationToken::new();
        let mut tasks = Tasks::new();
        tasks.spawn("worker", cancellable_worker(&token, None));

        token.cancel();

        let outcome = tasks.join_next_or_cancelled(token).await;
        outcome.expect("clean shutdown should not be treated as an error");
    }

    #[tokio::test]
    async fn join_next_or_cancelled_treats_task_completion_before_cancellation_as_error() {
        let token = crate::shutdown::CancellationToken::new();
        let mut tasks = Tasks::new();
        tasks.spawn("worker", async { Ok(()) });

        let outcome = tasks.join_next_or_cancelled(token).await;
        let err = outcome.expect_err("unexpected task completion should fail before shutdown");

        assert_eq!(err.to_string(), "task worker completed unexpectedly");
    }

    #[tokio::test]
    async fn join_next_or_cancelled_waits_for_all_tasks_to_complete_after_cancellation() {
        let token = crate::shutdown::CancellationToken::new();
        let mut tasks = Tasks::new();
        tasks.spawn("worker-a", cancellable_worker(&token, None));
        tasks.spawn("worker-b", cancellable_worker(&token, Some(Duration::from_millis(10))));

        token.cancel();

        let outcome = tasks.join_next_or_cancelled(token).await;
        outcome.expect("shutdown should wait for all clean task exits");
        assert!(tasks.is_empty());
    }
}