Skip to main content

altair_concurrent/
executor.rs

1//! Concurrent execution entry point.
2
3use crate::error::{Error, Result};
4use crate::task_map::TaskMap;
5use std::collections::HashMap;
6use std::time::Duration;
7use tokio::task::JoinSet;
8use tokio_util::sync::CancellationToken;
9use tracing::{Instrument, instrument};
10
/// Type-erased, thread-safe error produced by a task.
///
/// It is the `Err` side of every per-task result and the `source` carried by
/// a fail-fast task failure.
pub type BoxedError = Box<dyn std::error::Error + Send + Sync + 'static>;

/// What a spawned task hands back: its name paired with its own result.
type TaskOutcome<T> = (&'static str, std::result::Result<T, BoxedError>);

/// Map from task name to that task's own result, as produced by
/// [`PartialExecutor`].
pub type PartialResults<T> = HashMap<&'static str, std::result::Result<T, BoxedError>>;
17
18/// Fail-fast executor. Constructed by [`execute_concurrently`].
19///
20/// On the first task error: cancels the rest and returns [`Error::TaskFailed`].
21pub struct Executor<T> {
22    tasks: TaskMap<T>,
23    cancellation: Option<CancellationToken>,
24    timeout: Option<Duration>,
25}
26
27impl<T> Executor<T>
28where
29    T: Send + 'static,
30{
31    /// Attach a cancellation token. Cancelling it requests all running
32    /// tasks to stop.
33    ///
34    /// # Cooperative cancellation
35    ///
36    /// Each task receives the token via the closure argument and is
37    /// responsible for `.cancelled().await`-ing it. Tasks that ignore
38    /// the token will not be interrupted — `JoinSet::shutdown` aborts
39    /// their `JoinHandle`s, but a CPU-bound task that never yields cannot
40    /// be preempted. Design tasks to check the token at await points.
41    #[must_use]
42    pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
43        self.cancellation = Some(token);
44        self
45    }
46
47    /// Apply an overall timeout. If the timeout elapses, [`Error::Timeout`]
48    /// is returned and the internal cancellation token is signalled so
49    /// remaining tasks observe cancellation. Tasks that ignore the token
50    /// continue running until they yield (see [`Self::with_cancellation`]).
51    #[must_use]
52    pub fn with_timeout(mut self, timeout: Duration) -> Self {
53        self.timeout = Some(timeout);
54        self
55    }
56
57    /// Switch to partial-results mode: every task is awaited; each task's
58    /// `Result` appears in the returned [`PartialResults`] map.
59    ///
60    /// The outer [`Result`] still reports infrastructure errors ([`Error::Timeout`],
61    /// [`Error::Cancelled`], [`Error::Join`]).
62    #[must_use]
63    pub fn with_partial_results(self) -> PartialExecutor<T> {
64        PartialExecutor {
65            tasks: self.tasks,
66            cancellation: self.cancellation,
67            timeout: self.timeout,
68        }
69    }
70}
71
72impl<T> std::future::IntoFuture for Executor<T>
73where
74    T: Send + 'static,
75{
76    type Output = Result<HashMap<&'static str, T>>;
77    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;
78
79    fn into_future(self) -> Self::IntoFuture {
80        Box::pin(async move { run_fail_fast(self).await })
81    }
82}
83
84/// Partial-results executor — every task runs to completion; per-task
85/// success/failure exposed in the returned map.
86///
87/// Constructed via [`Executor::with_partial_results`].
88pub struct PartialExecutor<T> {
89    tasks: TaskMap<T>,
90    cancellation: Option<CancellationToken>,
91    timeout: Option<Duration>,
92}
93
94impl<T> PartialExecutor<T>
95where
96    T: Send + 'static,
97{
98    /// Attach a cancellation token. Cancelling it causes all tasks to abort.
99    #[must_use]
100    pub fn with_cancellation(mut self, token: CancellationToken) -> Self {
101        self.cancellation = Some(token);
102        self
103    }
104
105    /// Apply an overall timeout.
106    #[must_use]
107    pub fn with_timeout(mut self, timeout: Duration) -> Self {
108        self.timeout = Some(timeout);
109        self
110    }
111}
112
113impl<T> std::future::IntoFuture for PartialExecutor<T>
114where
115    T: Send + 'static,
116{
117    type Output = Result<PartialResults<T>>;
118    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;
119
120    fn into_future(self) -> Self::IntoFuture {
121        Box::pin(async move { run_partial(self).await })
122    }
123}
124
125fn spawn_tasks<T>(tasks: TaskMap<T>, token: &CancellationToken) -> JoinSet<TaskOutcome<T>>
126where
127    T: Send + 'static,
128{
129    let mut set: JoinSet<TaskOutcome<T>> = JoinSet::new();
130    for (name, task_fn) in tasks.tasks {
131        let child_token = token.clone();
132        let span = tracing::info_span!("concurrent.task", task.name = name);
133        set.spawn(
134            async move {
135                let result = task_fn(child_token).await;
136                (name, result)
137            }
138            .instrument(span),
139        );
140    }
141    set
142}
143
144#[instrument(skip(executor), fields(task_count = executor.tasks.len()))]
145async fn run_fail_fast<T>(executor: Executor<T>) -> Result<HashMap<&'static str, T>>
146where
147    T: Send + 'static,
148{
149    let token = executor.cancellation.unwrap_or_default();
150    let mut set = spawn_tasks(executor.tasks, &token);
151    let mut results: HashMap<&'static str, T> = HashMap::new();
152    let timeout = executor.timeout;
153
154    loop {
155        let outcome = next_outcome(&mut set, &token, timeout).await?;
156        match outcome {
157            None => break,
158            Some((name, Ok(v))) => {
159                results.insert(name, v);
160            }
161            Some((name, Err(e))) => {
162                token.cancel();
163                set.shutdown().await;
164                return Err(Error::TaskFailed { name, source: e });
165            }
166        }
167        if token.is_cancelled() && set.is_empty() {
168            return Err(Error::Cancelled);
169        }
170    }
171
172    Ok(results)
173}
174
175#[instrument(skip(executor), fields(task_count = executor.tasks.len()))]
176async fn run_partial<T>(executor: PartialExecutor<T>) -> Result<PartialResults<T>>
177where
178    T: Send + 'static,
179{
180    let token = executor.cancellation.unwrap_or_default();
181    let mut set = spawn_tasks(executor.tasks, &token);
182    let mut results: PartialResults<T> = HashMap::new();
183    let timeout = executor.timeout;
184
185    loop {
186        let outcome = next_outcome(&mut set, &token, timeout).await?;
187        match outcome {
188            None => break,
189            Some((name, result)) => {
190                results.insert(name, result);
191            }
192        }
193    }
194
195    Ok(results)
196}
197
198async fn next_outcome<T>(
199    set: &mut JoinSet<TaskOutcome<T>>,
200    token: &CancellationToken,
201    timeout: Option<Duration>,
202) -> Result<Option<TaskOutcome<T>>>
203where
204    T: Send + 'static,
205{
206    let next = async { set.join_next().await };
207    let raw = if let Some(d) = timeout {
208        if let Ok(v) = tokio::time::timeout(d, next).await {
209            v
210        } else {
211            token.cancel();
212            set.shutdown().await;
213            return Err(Error::Timeout);
214        }
215    } else {
216        next.await
217    };
218
219    match raw {
220        None => Ok(None),
221        Some(Ok(outcome)) => Ok(Some(outcome)),
222        Some(Err(e)) => Err(Error::Join(e)),
223    }
224}
225
226/// Run a [`TaskMap`] concurrently in fail-fast mode.
227///
228/// Returns an [`Executor`] that resolves to a `HashMap<&'static str, T>` when awaited.
229/// Call [`Executor::with_partial_results`] to switch to "run all, return per-task results" mode.
230#[must_use]
231pub fn execute_concurrently<T>(tasks: TaskMap<T>) -> Executor<T> {
232    Executor {
233        tasks,
234        cancellation: None,
235        timeout: None,
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// With nothing to run, the executor resolves to an empty map.
    #[tokio::test]
    async fn empty_map_resolves_to_empty_results() {
        let tasks: TaskMap<u32> = TaskMap::new();
        let results = execute_concurrently(tasks).await.unwrap();
        assert!(results.is_empty());
    }

    /// Both values come back under their task names.
    #[tokio::test]
    async fn two_tasks_complete() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("b", |_| async { Ok::<_, std::io::Error>(2) });
        let results = execute_concurrently(tasks).await.unwrap();
        assert_eq!(results.get("a"), Some(&1));
        assert_eq!(results.get("b"), Some(&2));
    }

    /// Fail-fast mode names the task that failed.
    #[tokio::test]
    async fn failing_task_returns_task_failed_error() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("ok", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("bad", |_| async {
                Err::<u32, std::io::Error>(std::io::Error::other("boom"))
            });
        match execute_concurrently(tasks).await.unwrap_err() {
            Error::TaskFailed { name, .. } => assert_eq!(name, "bad"),
            other => panic!("expected TaskFailed, got {other:?}"),
        }
    }

    /// A task that outlives the limit makes the run fail with `Timeout`.
    #[tokio::test]
    async fn timeout_returns_timeout_error() {
        let tasks: TaskMap<u32> = TaskMap::new().insert("slow", |_| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok::<_, std::io::Error>(1)
        });
        let limited = execute_concurrently(tasks).with_timeout(Duration::from_millis(50));
        let err = limited.await.unwrap_err();
        assert!(matches!(err, Error::Timeout));
    }

    /// Cancelling the caller's token from outside ends the run with an error.
    #[tokio::test]
    async fn external_cancellation_causes_cancelled_error() {
        let token = CancellationToken::new();
        let canceller = token.clone();
        let tasks: TaskMap<u32> = TaskMap::new().insert("waiter", move |task_token| async move {
            task_token.cancelled().await;
            Err::<u32, std::io::Error>(std::io::Error::other("cancelled"))
        });
        let run = tokio::spawn(async move {
            execute_concurrently(tasks).with_cancellation(token).await
        });
        // Give the run a moment to start before cancelling it.
        tokio::time::sleep(Duration::from_millis(20)).await;
        canceller.cancel();
        let err = run.await.unwrap().unwrap_err();
        // Either TaskFailed or Cancelled is acceptable depending on order.
        assert!(matches!(err, Error::TaskFailed { .. } | Error::Cancelled));
    }

    /// Partial mode keeps every task's own result, failures included.
    #[tokio::test]
    async fn partial_results_returns_per_task_results() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("ok", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("bad", |_| async {
                Err::<u32, std::io::Error>(std::io::Error::other("boom"))
            })
            .insert("also_ok", |_| async { Ok::<_, std::io::Error>(2) });
        let results = execute_concurrently(tasks)
            .with_partial_results()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert!(matches!(results.get("ok"), Some(Ok(_))));
        assert!(matches!(results.get("bad"), Some(Err(_))));
        assert!(matches!(results.get("also_ok"), Some(Ok(_))));
    }

    /// The timeout is still an outer error in partial mode.
    #[tokio::test]
    async fn partial_timeout_still_propagates() {
        let tasks: TaskMap<u32> = TaskMap::new().insert("slow", |_| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok::<_, std::io::Error>(1)
        });
        let limited = execute_concurrently(tasks)
            .with_partial_results()
            .with_timeout(Duration::from_millis(20));
        let err = limited.await.unwrap_err();
        assert!(matches!(err, Error::Timeout));
    }
}