Skip to main content

dapr_durabletask/task/
when_all.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use crate::api::DurableTaskError;
6
7use super::completable_task::{CompletableTask, TaskResult};
8
9/// A future that completes when all tasks complete, or fails if any task fails.
10/// Returns a `Vec` of JSON-serialised results on success.
11pub struct WhenAllTask {
12    pub(crate) tasks: Vec<CompletableTask>,
13}
14
15impl Future for WhenAllTask {
16    type Output = crate::api::Result<Vec<Option<String>>>;
17
18    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
19        let this = self.get_mut();
20
21        // Single pass: poll each task, short-circuit on failure, track readiness.
22        let mut all_complete = true;
23        for task in &mut this.tasks {
24            match Pin::new(task).poll(cx) {
25                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
26                Poll::Ready(Ok(_)) => {}
27                Poll::Pending => {
28                    all_complete = false;
29                }
30            }
31        }
32
33        if all_complete {
34            let results: crate::api::Result<Vec<Option<String>>> = this
35                .tasks
36                .iter()
37                .map(|t| match t.get_result() {
38                    Some(TaskResult::Completed(v)) => Ok(v),
39                    Some(TaskResult::Failed(d)) => Err(DurableTaskError::TaskFailed {
40                        message: d.message.clone(),
41                        failure_details: Some(d),
42                    }),
43                    None => Err(DurableTaskError::Other(
44                        "internal error: task state inconsistency in when_all".to_string(),
45                    )),
46                })
47                .collect();
48            Poll::Ready(results)
49        } else {
50            Poll::Pending
51        }
52    }
53}
54
55/// Wait for all tasks to complete. Fails if any task fails.
56pub fn when_all(tasks: Vec<CompletableTask>) -> WhenAllTask {
57    WhenAllTask { tasks }
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::FailureDetails;
    use std::task::Waker;

    /// Polls `fut` exactly once using a waker that does nothing when woken.
    fn poll_once(fut: &mut WhenAllTask) -> Poll<crate::api::Result<Vec<Option<String>>>> {
        let mut cx = Context::from_waker(Waker::noop());
        Pin::new(fut).poll(&mut cx)
    }

    /// Creates a task that has already completed with `value` as its result.
    fn completed(value: &str) -> CompletableTask {
        let task = CompletableTask::new();
        task.complete(Some(value.to_string()));
        task
    }

    /// With no tasks at all, the future is immediately ready with no results.
    #[test]
    fn test_when_all_empty() {
        let mut fut = when_all(Vec::new());
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => assert!(results.is_empty()),
            other => panic!("expected Ready(Ok([])), got {other:?}"),
        }
    }

    /// Already-completed tasks yield their results in input order.
    #[test]
    fn test_when_all_all_complete() {
        let mut fut = when_all(vec![completed("1"), completed("2")]);
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => {
                assert_eq!(results.len(), 2);
                assert_eq!(results[0], Some("1".to_string()));
                assert_eq!(results[1], Some("2".to_string()));
            }
            other => panic!("expected Ready(Ok), got {other:?}"),
        }
    }

    /// The future stays pending until the last outstanding task completes.
    #[test]
    fn test_when_all_pending_then_complete() {
        let done = completed("1");
        let late = CompletableTask::new();

        let mut fut = when_all(vec![done, late.clone()]);
        assert!(poll_once(&mut fut).is_pending());

        late.complete(Some("2".to_string()));
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => assert_eq!(results.len(), 2),
            other => panic!("expected Ready(Ok), got {other:?}"),
        }
    }

    /// A single failed task makes the whole future fail with that task's message.
    #[test]
    fn test_when_all_fails_on_any_failure() {
        let ok = completed("1");
        let failed = CompletableTask::new();
        failed.fail(FailureDetails {
            message: "boom".to_string(),
            error_type: "Error".to_string(),
            stack_trace: None,
        });

        let mut fut = when_all(vec![ok, failed]);
        match poll_once(&mut fut) {
            Poll::Ready(Err(DurableTaskError::TaskFailed { message, .. })) => {
                assert_eq!(message, "boom");
            }
            other => panic!("expected Ready(Err(TaskFailed)), got {other:?}"),
        }
    }
}