1use crate::error::{Error, Result};
4use crate::task_map::TaskMap;
5use std::collections::HashMap;
6use std::time::Duration;
7use tokio::task::JoinSet;
8use tokio_util::sync::CancellationToken;
9use tracing::{Instrument, instrument};
10
/// Boxed, thread-safe error type that every task failure is converted into.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// What one spawned task reports back: its name paired with its own result.
type TaskOutcome<T> = (&'static str, std::result::Result<T, BoxedError>);

/// Per-task results keyed by task name, produced when running in partial-results mode.
pub type PartialResults<T> = HashMap<&'static str, std::result::Result<T, BoxedError>>;
17
18pub struct Executor<T> {
22 tasks: TaskMap<T>,
23 cancellation: Option<CancellationToken>,
24 timeout: Option<Duration>,
25}
26
27impl<T> Executor<T>
28where
29 T: Send + 'static,
30{
31 #[must_use]
42 pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
43 self.cancellation = Some(token);
44 self
45 }
46
47 #[must_use]
52 pub fn with_timeout(mut self, timeout: Duration) -> Self {
53 self.timeout = Some(timeout);
54 self
55 }
56
57 #[must_use]
63 pub fn with_partial_results(self) -> PartialExecutor<T> {
64 PartialExecutor {
65 tasks: self.tasks,
66 cancellation: self.cancellation,
67 timeout: self.timeout,
68 }
69 }
70}
71
72impl<T> std::future::IntoFuture for Executor<T>
73where
74 T: Send + 'static,
75{
76 type Output = Result<HashMap<&'static str, T>>;
77 type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;
78
79 fn into_future(self) -> Self::IntoFuture {
80 Box::pin(async move { run_fail_fast(self).await })
81 }
82}
83
84pub struct PartialExecutor<T> {
89 tasks: TaskMap<T>,
90 cancellation: Option<CancellationToken>,
91 timeout: Option<Duration>,
92}
93
94impl<T> PartialExecutor<T>
95where
96 T: Send + 'static,
97{
98 #[must_use]
100 pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
101 self.cancellation = Some(token);
102 self
103 }
104
105 #[must_use]
107 pub fn with_timeout(mut self, timeout: Duration) -> Self {
108 self.timeout = Some(timeout);
109 self
110 }
111}
112
113impl<T> std::future::IntoFuture for PartialExecutor<T>
114where
115 T: Send + 'static,
116{
117 type Output = Result<PartialResults<T>>;
118 type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;
119
120 fn into_future(self) -> Self::IntoFuture {
121 Box::pin(async move { run_partial(self).await })
122 }
123}
124
125fn spawn_tasks<T>(tasks: TaskMap<T>, token: &CancellationToken) -> JoinSet<TaskOutcome<T>>
126where
127 T: Send + 'static,
128{
129 let mut set: JoinSet<TaskOutcome<T>> = JoinSet::new();
130 for (name, task_fn) in tasks.tasks {
131 let child_token = token.clone();
132 let span = tracing::info_span!("concurrent.task", task.name = name);
133 set.spawn(
134 async move {
135 let result = task_fn(child_token).await;
136 (name, result)
137 }
138 .instrument(span),
139 );
140 }
141 set
142}
143
144#[instrument(skip(executor), fields(task_count = executor.tasks.len()))]
145async fn run_fail_fast<T>(executor: Executor<T>) -> Result<HashMap<&'static str, T>>
146where
147 T: Send + 'static,
148{
149 let token = executor.cancellation.unwrap_or_default();
150 let mut set = spawn_tasks(executor.tasks, &token);
151 let mut results: HashMap<&'static str, T> = HashMap::new();
152 let timeout = executor.timeout;
153
154 loop {
155 let outcome = next_outcome(&mut set, &token, timeout).await?;
156 match outcome {
157 None => break,
158 Some((name, Ok(v))) => {
159 results.insert(name, v);
160 }
161 Some((name, Err(e))) => {
162 token.cancel();
163 set.shutdown().await;
164 return Err(Error::TaskFailed { name, source: e });
165 }
166 }
167 if token.is_cancelled() && set.is_empty() {
168 return Err(Error::Cancelled);
169 }
170 }
171
172 Ok(results)
173}
174
175#[instrument(skip(executor), fields(task_count = executor.tasks.len()))]
176async fn run_partial<T>(executor: PartialExecutor<T>) -> Result<PartialResults<T>>
177where
178 T: Send + 'static,
179{
180 let token = executor.cancellation.unwrap_or_default();
181 let mut set = spawn_tasks(executor.tasks, &token);
182 let mut results: PartialResults<T> = HashMap::new();
183 let timeout = executor.timeout;
184
185 loop {
186 let outcome = next_outcome(&mut set, &token, timeout).await?;
187 match outcome {
188 None => break,
189 Some((name, result)) => {
190 results.insert(name, result);
191 }
192 }
193 }
194
195 Ok(results)
196}
197
198async fn next_outcome<T>(
199 set: &mut JoinSet<TaskOutcome<T>>,
200 token: &CancellationToken,
201 timeout: Option<Duration>,
202) -> Result<Option<TaskOutcome<T>>>
203where
204 T: Send + 'static,
205{
206 let next = async { set.join_next().await };
207 let raw = if let Some(d) = timeout {
208 if let Ok(v) = tokio::time::timeout(d, next).await {
209 v
210 } else {
211 token.cancel();
212 set.shutdown().await;
213 return Err(Error::Timeout);
214 }
215 } else {
216 next.await
217 };
218
219 match raw {
220 None => Ok(None),
221 Some(Ok(outcome)) => Ok(Some(outcome)),
222 Some(Err(e)) => Err(Error::Join(e)),
223 }
224}
225
226#[must_use]
231pub fn execute_concurrently<T>(tasks: TaskMap<T>) -> Executor<T> {
232 Executor {
233 tasks,
234 cancellation: None,
235 timeout: None,
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[tokio::test]
    async fn empty_map_resolves_to_empty_results() {
        let tasks: TaskMap<u32> = TaskMap::new();
        let results = execute_concurrently(tasks).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn two_tasks_complete() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("b", |_| async { Ok::<_, std::io::Error>(2) });
        let results = execute_concurrently(tasks).await.unwrap();
        for (name, expected) in [("a", 1), ("b", 2)] {
            assert_eq!(results[name], expected);
        }
    }

    #[tokio::test]
    async fn failing_task_returns_task_failed_error() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("ok", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("bad", |_| async {
                Err::<u32, std::io::Error>(std::io::Error::other("boom"))
            });
        match execute_concurrently(tasks).await.unwrap_err() {
            Error::TaskFailed { name, .. } => assert_eq!(name, "bad"),
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn timeout_returns_timeout_error() {
        let tasks: TaskMap<u32> = TaskMap::new().insert("slow", |_| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok::<_, std::io::Error>(1)
        });
        let error = execute_concurrently(tasks)
            .with_timeout(Duration::from_millis(50))
            .await
            .unwrap_err();
        assert!(matches!(error, Error::Timeout));
    }

    #[tokio::test]
    async fn external_cancellation_causes_cancelled_error() {
        let token = CancellationToken::new();
        let trigger = token.clone();
        let tasks: TaskMap<u32> = TaskMap::new().insert("waiter", move |ct| async move {
            ct.cancelled().await;
            Err::<u32, std::io::Error>(std::io::Error::other("cancelled"))
        });
        let handle =
            tokio::spawn(async move { execute_concurrently(tasks).with_cancellation(token).await });
        tokio::time::sleep(Duration::from_millis(20)).await;
        trigger.cancel();
        let error = handle.await.unwrap().unwrap_err();
        // The waiter reacts to cancellation by failing, so either error is acceptable.
        assert!(matches!(error, Error::TaskFailed { .. } | Error::Cancelled));
    }

    #[tokio::test]
    async fn partial_results_returns_per_task_results() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("ok", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("bad", |_| async {
                Err::<u32, std::io::Error>(std::io::Error::other("boom"))
            })
            .insert("also_ok", |_| async { Ok::<_, std::io::Error>(2) });
        let results = execute_concurrently(tasks)
            .with_partial_results()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert!(results["ok"].is_ok());
        assert!(results["bad"].is_err());
        assert!(results["also_ok"].is_ok());
    }

    #[tokio::test]
    async fn partial_timeout_still_propagates() {
        let tasks: TaskMap<u32> = TaskMap::new().insert("slow", |_| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok::<_, std::io::Error>(1)
        });
        let error = execute_concurrently(tasks)
            .with_partial_results()
            .with_timeout(Duration::from_millis(20))
            .await
            .unwrap_err();
        assert!(matches!(error, Error::Timeout));
    }
}