miden_node_utils/
tasks.rs1use std::collections::HashMap;
2use std::future::Future;
3
4use anyhow::Context;
5use tokio::task::{Id, JoinError, JoinSet};
6
7use crate::shutdown::CancellationToken;
8
9pub struct Tasks {
13 handles: JoinSet<anyhow::Result<()>>,
14 names: HashMap<Id, String>,
15}
16
17impl Default for Tasks {
18 fn default() -> Self {
19 Self {
20 handles: JoinSet::new(),
21 names: HashMap::new(),
22 }
23 }
24}
25
26impl Tasks {
27 pub fn new() -> Self {
29 Self::default()
30 }
31
32 pub fn spawn(
34 &mut self,
35 name: impl Into<String>,
36 task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
37 ) -> Id {
38 let id = self.handles.spawn(task).id();
39 self.names.insert(id, name.into());
40 id
41 }
42
43 pub fn spawn_infallible(
45 &mut self,
46 name: impl Into<String>,
47 task: impl Future<Output = ()> + Send + 'static,
48 ) -> Id {
49 self.spawn(name, async move {
50 task.await;
51 Ok(())
52 })
53 }
54
55 pub async fn join_next(&mut self) -> Option<(String, Result<anyhow::Result<()>, JoinError>)> {
57 let result = self.handles.join_next_with_id().await?;
58 let id = match &result {
59 Ok((id, _)) => *id,
60 Err(err) => err.id(),
61 };
62 let name = self.names.remove(&id).unwrap_or_else(|| "unknown".to_string());
63 let result = result.map(|(_, output)| output);
64
65 Some((name, result))
66 }
67
68 pub fn is_empty(&self) -> bool {
70 self.handles.is_empty()
71 }
72
73 pub fn len(&self) -> usize {
75 self.handles.len()
76 }
77
78 pub async fn join_next_as_error(&mut self) -> anyhow::Result<()> {
82 let Some((task, result)) = self.join_next().await else {
83 anyhow::bail!("task set is empty");
84 };
85
86 Self::unexpected_completion(&task, result)
87 }
88
89 pub async fn join_next_or_cancelled(&mut self, token: CancellationToken) -> anyhow::Result<()> {
95 while !token.is_cancelled() {
96 tokio::select! {
97 biased;
98 () = token.cancelled() => break,
99 result = self.join_next() => {
100 let Some((task, result)) = result else {
101 anyhow::bail!("task set is empty");
102 };
103 Self::unexpected_completion(&task, result)?;
104 },
105 }
106 }
107
108 while let Some((task, result)) = self.join_next().await {
109 Self::shutdown_completion(&task, result)?;
110 }
111
112 Ok(())
113 }
114
115 fn unexpected_completion(
116 task: &str,
117 result: Result<anyhow::Result<()>, JoinError>,
118 ) -> anyhow::Result<()> {
119 match result {
120 Ok(Ok(())) => anyhow::bail!("task {task} completed unexpectedly"),
121 Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed")),
122 Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
123 }
124 }
125
126 fn shutdown_completion(
127 task: &str,
128 result: Result<anyhow::Result<()>, JoinError>,
129 ) -> anyhow::Result<()> {
130 match result {
131 Ok(Ok(())) => Ok(()),
132 Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed during shutdown")),
133 Err(err) if err.is_cancelled() => Ok(()),
134 Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Test worker: parks until `token` is cancelled, optionally lingers for `linger`
    /// afterwards, and then exits cleanly.
    async fn exit_after_cancel(
        token: crate::shutdown::CancellationToken,
        linger: Option<Duration>,
    ) -> anyhow::Result<()> {
        token.cancelled().await;
        if let Some(delay) = linger {
            tokio::time::sleep(delay).await;
        }
        Ok(())
    }

    #[tokio::test]
    async fn join_next_or_cancelled_accepts_clean_task_completion_after_cancellation() {
        let shutdown = crate::shutdown::CancellationToken::new();
        let mut set = Tasks::new();
        set.spawn("worker", exit_after_cancel(shutdown.clone(), None));

        // Cancel before joining so the worker's clean exit happens during shutdown.
        shutdown.cancel();

        let outcome = set.join_next_or_cancelled(shutdown).await;
        outcome.expect("clean shutdown should not be treated as an error");
    }

    #[tokio::test]
    async fn join_next_or_cancelled_treats_task_completion_before_cancellation_as_error() {
        let shutdown = crate::shutdown::CancellationToken::new();
        let mut set = Tasks::new();
        set.spawn("worker", async { Ok(()) });

        // The token is never cancelled, so the worker's exit must be reported as unexpected.
        let outcome = set.join_next_or_cancelled(shutdown).await;
        let failure =
            outcome.expect_err("unexpected task completion should fail before shutdown");

        assert_eq!(failure.to_string(), "task worker completed unexpectedly");
    }

    #[tokio::test]
    async fn join_next_or_cancelled_waits_for_all_tasks_to_complete_after_cancellation() {
        let shutdown = crate::shutdown::CancellationToken::new();
        let mut set = Tasks::new();
        set.spawn("worker-a", exit_after_cancel(shutdown.clone(), None));
        // The second worker outlives the first one by a short delay after cancellation.
        set.spawn(
            "worker-b",
            exit_after_cancel(shutdown.clone(), Some(Duration::from_millis(10))),
        );

        shutdown.cancel();

        let outcome = set.join_next_or_cancelled(shutdown).await;
        outcome.expect("shutdown should wait for all clean task exits");
        assert!(set.is_empty());
    }
}