azure_functions/durable/join_all.rs

1use crate::durable::{OrchestrationFuture, OrchestrationState};
2use futures::future::{join_all, FutureExt};
3use std::{
4    cell::RefCell,
5    future::Future,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11/// Future for the `DurableOrchestrationContext::join_all` function.
12#[must_use = "futures do nothing unless you `.await` or poll them"]
13pub struct JoinAll<F>
14where
15    F: OrchestrationFuture,
16{
17    inner: futures::future::JoinAll<F>,
18    state: Rc<RefCell<OrchestrationState>>,
19    event_index: Option<usize>,
20    is_inner: bool,
21}
22
23impl<F> JoinAll<F>
24where
25    F: OrchestrationFuture,
26{
27    pub(crate) fn new<T>(state: Rc<RefCell<OrchestrationState>>, iter: T) -> Self
28    where
29        T: IntoIterator<Item = F>,
30        F: OrchestrationFuture,
31    {
32        let inner: Vec<_> = iter
33            .into_iter()
34            .map(|mut f| {
35                f.notify_inner();
36                f
37            })
38            .collect();
39
40        // The event index of a join is the maximum of the sequence, provided all inner futures have event indexes
41        let event_index = inner
42            .iter()
43            .try_fold(None, |i, f| {
44                let next = f.event_index();
45                if next.is_none() {
46                    Err(())
47                } else if i < next {
48                    Ok(next)
49                } else {
50                    Ok(i)
51                }
52            })
53            .unwrap_or(None);
54
55        JoinAll {
56            inner: join_all(inner),
57            state,
58            event_index,
59            is_inner: false,
60        }
61    }
62}
63
64impl<F> Future for JoinAll<F>
65where
66    F: OrchestrationFuture,
67{
68    type Output = Vec<F::Output>;
69
70    fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
71        let result = self.inner.poll_unpin(context);
72
73        if !self.is_inner {
74            if let Poll::Ready(_) = &result {
75                self.state.borrow_mut().update(self.event_index.unwrap());
76            }
77        }
78
79        result
80    }
81}
82
83impl<F> OrchestrationFuture for JoinAll<F>
84where
85    F: OrchestrationFuture,
86{
87    fn notify_inner(&mut self) {
88        self.is_inner = true;
89    }
90
91    fn event_index(&self) -> Option<usize> {
92        self.event_index
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::durable::{
        tests::{create_event, poll},
        ActionFuture, EventType,
    };
    use serde_json::{from_str, json, Value};
    use std::task::Poll;

    /// Marks the `TaskScheduled` event for activity `name` and its matching
    /// `TaskCompleted` event as processed.
    ///
    /// Returns the completion's deserialized result and the completion event's index, in
    /// the form `ActionFuture::new` expects.
    ///
    /// Panics if either event is missing, or if the completion has no valid JSON result.
    fn complete_task(
        state: &mut OrchestrationState,
        name: &str,
    ) -> (Option<Value>, Option<usize>) {
        let (idx, event) = state
            .find_start_event(name, EventType::TaskScheduled)
            .unwrap();
        event.is_processed = true;

        let (idx, event) = state
            .find_end_event(idx, EventType::TaskCompleted, Some(EventType::TaskFailed))
            .unwrap();
        event.is_processed = true;

        let result = from_str(&event.result.as_ref().unwrap()).unwrap();
        (Some(result), Some(idx))
    }

    #[test]
    fn it_polls_pending_without_a_result() {
        let history = vec![create_event(
            EventType::OrchestratorStarted,
            -1,
            None,
            None,
            None,
        )];

        let state = Rc::new(RefCell::new(OrchestrationState::new(history)));
        let future1 = ActionFuture::<()>::new(None, state.clone(), None);
        let future2 = ActionFuture::<()>::new(None, state.clone(), None);
        let join = JoinAll::new(state.clone(), vec![future1, future2]);

        assert_eq!(join.event_index(), None);
        assert_eq!(poll(join), Poll::Pending);
    }

    #[test]
    fn it_polls_ready_with_a_result() {
        let history = vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ];

        let mut state = OrchestrationState::new(history);
        let (result1, idx1) = complete_task(&mut state, "hello");
        let (result2, idx2) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let join = JoinAll::new(state.clone(), vec![future2, future1]);

        assert_eq!(join.event_index(), idx2);
        assert_eq!(
            poll(join),
            Poll::Ready(vec![json!("world"), json!("hello")])
        );
    }

    #[test]
    fn it_updates_state() {
        let history = vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ];

        let mut state = OrchestrationState::new(history);
        assert!(state.is_replaying());

        let (result1, idx1) = complete_task(&mut state, "hello");
        let (result2, idx2) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let join = JoinAll::new(state.clone(), vec![future2, future1]);

        assert_eq!(join.event_index(), idx2);
        assert_eq!(
            poll(join),
            Poll::Ready(vec![json!("world"), json!("hello")])
        );
        assert!(!state.borrow().is_replaying());
    }

    #[test]
    fn it_does_not_update_state_when_an_inner_future() {
        let history = vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ];

        let mut state = OrchestrationState::new(history);
        assert!(state.is_replaying());

        let (result1, idx1) = complete_task(&mut state, "hello");
        let (result2, idx2) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let mut join = JoinAll::new(state.clone(), vec![future2, future1]);
        join.notify_inner();

        assert_eq!(join.event_index(), idx2);
        assert_eq!(
            poll(join),
            Poll::Ready(vec![json!("world"), json!("hello")])
        );
        // A nested join must leave the replay state for its parent to advance.
        assert!(state.borrow().is_replaying());
    }
}