Skip to main content

dapr_durabletask/task/when_all.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use crate::api::DurableTaskError;
6
7use super::completable_task::{CompletableTask, TaskResult};
8
9/// A future that completes when all tasks complete, or fails if any task fails.
10/// Returns a `Vec` of JSON-serialised results on success.
11pub struct WhenAllTask {
12    tasks: Vec<CompletableTask>,
13}
14
15impl WhenAllTask {
16    pub fn new(tasks: Vec<CompletableTask>) -> Self {
17        Self { tasks }
18    }
19}
20
21impl Future for WhenAllTask {
22    type Output = crate::api::Result<Vec<Option<String>>>;
23
24    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
25        let this = self.get_mut();
26
27        // Single pass: poll each task, short-circuit on failure, track readiness.
28        let mut all_complete = true;
29        for task in &mut this.tasks {
30            match Pin::new(task).poll(cx) {
31                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
32                Poll::Ready(Ok(_)) => {}
33                Poll::Pending => {
34                    all_complete = false;
35                }
36            }
37        }
38
39        if all_complete {
40            let results: crate::api::Result<Vec<Option<String>>> = this
41                .tasks
42                .iter()
43                .map(|t| match t.get_result() {
44                    Some(TaskResult::Completed(v)) => Ok(v),
45                    Some(TaskResult::Failed(d)) => Err(DurableTaskError::TaskFailed {
46                        message: d.message.clone(),
47                        failure_details: Some(d),
48                    }),
49                    None => Err(DurableTaskError::Other(
50                        "internal error: task state inconsistency in when_all".to_string(),
51                    )),
52                })
53                .collect();
54            Poll::Ready(results)
55        } else {
56            Poll::Pending
57        }
58    }
59}
60
61/// Wait for all tasks to complete. Fails if any task fails.
62pub fn when_all(tasks: Vec<CompletableTask>) -> WhenAllTask {
63    WhenAllTask::new(tasks)
64}
65
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::FailureDetails;
    use std::task::Waker;

    /// Polls `fut` once with a waker that ignores wake-ups; these tests drive
    /// progress manually by completing tasks between polls.
    fn poll_once(fut: &mut WhenAllTask) -> Poll<crate::api::Result<Vec<Option<String>>>> {
        let mut cx = Context::from_waker(Waker::noop());
        Pin::new(fut).poll(&mut cx)
    }

    /// Builds a `FailureDetails` with the given message and a generic error type.
    fn failure(message: &str) -> FailureDetails {
        FailureDetails {
            message: message.to_string(),
            error_type: "Error".to_string(),
            stack_trace: None,
        }
    }

    #[test]
    fn test_when_all_empty() {
        let mut fut = when_all(vec![]);
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => assert!(results.is_empty()),
            other => panic!("expected Ready(Ok([])), got {:?}", other),
        }
    }

    #[test]
    fn test_when_all_all_complete() {
        let t1 = CompletableTask::new();
        let t2 = CompletableTask::new();
        t1.complete(Some("1".to_string()));
        t2.complete(Some("2".to_string()));

        let mut fut = when_all(vec![t1, t2]);
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => {
                assert_eq!(results, vec![Some("1".to_string()), Some("2".to_string())]);
            }
            other => panic!("expected Ready(Ok), got {:?}", other),
        }
    }

    #[test]
    fn test_when_all_pending_then_complete() {
        let t1 = CompletableTask::new();
        let t2 = CompletableTask::new();
        t1.complete(Some("1".to_string()));

        let mut fut = when_all(vec![t1, t2.clone()]);
        assert!(poll_once(&mut fut).is_pending());

        t2.complete(Some("2".to_string()));
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => {
                assert_eq!(results, vec![Some("1".to_string()), Some("2".to_string())]);
            }
            other => panic!("expected Ready(Ok), got {:?}", other),
        }
    }

    #[test]
    fn test_when_all_preserves_input_order() {
        // Tasks complete in reverse order; results must still follow input order.
        let t1 = CompletableTask::new();
        let t2 = CompletableTask::new();
        let mut fut = when_all(vec![t1.clone(), t2.clone()]);

        t2.complete(Some("2".to_string()));
        assert!(poll_once(&mut fut).is_pending());

        t1.complete(Some("1".to_string()));
        match poll_once(&mut fut) {
            Poll::Ready(Ok(results)) => {
                assert_eq!(results, vec![Some("1".to_string()), Some("2".to_string())]);
            }
            other => panic!("expected Ready(Ok), got {:?}", other),
        }
    }

    #[test]
    fn test_when_all_fails_on_any_failure() {
        let t1 = CompletableTask::new();
        let t2 = CompletableTask::new();
        t1.complete(Some("1".to_string()));
        t2.fail(failure("boom"));

        let mut fut = when_all(vec![t1, t2]);
        match poll_once(&mut fut) {
            Poll::Ready(Err(DurableTaskError::TaskFailed { message, .. })) => {
                assert_eq!(message, "boom");
            }
            other => panic!("expected Ready(Err(TaskFailed)), got {:?}", other),
        }
    }

    #[test]
    fn test_when_all_fails_fast_while_other_pending() {
        // A failure must surface even though an earlier task is still pending.
        let pending = CompletableTask::new();
        let failed = CompletableTask::new();
        failed.fail(failure("boom"));

        let mut fut = when_all(vec![pending, failed]);
        match poll_once(&mut fut) {
            Poll::Ready(Err(DurableTaskError::TaskFailed { message, .. })) => {
                assert_eq!(message, "boom");
            }
            other => panic!("expected Ready(Err(TaskFailed)), got {:?}", other),
        }
    }
}