azure_functions/durable/select_all.rs

1use crate::durable::{OrchestrationFuture, OrchestrationState};
2use futures::future::FutureExt;
3use std::{
4    cell::RefCell,
5    future::Future,
6    pin::Pin,
7    rc::Rc,
8    task::{Context, Poll},
9};
10
11/// Future for the `DurableOrchestrationContext::select_all` function.
12#[must_use = "futures do nothing unless you `.await` or poll them"]
13pub struct SelectAll<F> {
14    inner: Vec<F>,
15    state: Rc<RefCell<OrchestrationState>>,
16    event_index: Option<usize>,
17    is_inner: bool,
18}
19
20impl<F> SelectAll<F>
21where
22    F: OrchestrationFuture,
23{
24    pub(crate) fn new<T>(state: Rc<RefCell<OrchestrationState>>, iter: T) -> Self
25    where
26        T: IntoIterator<Item = F>,
27        F: OrchestrationFuture,
28    {
29        let inner: Vec<_> = iter
30            .into_iter()
31            .map(|mut f| {
32                f.notify_inner();
33                f
34            })
35            .collect();
36
37        assert!(!inner.is_empty());
38
39        // The event index of a select is the minimum present index of the sequence
40        let event_index = inner.iter().filter_map(|f| f.event_index()).min();
41
42        SelectAll {
43            inner,
44            state,
45            event_index,
46            is_inner: false,
47        }
48    }
49}
50
51impl<F> Unpin for SelectAll<F> where F: Unpin {}
52
53impl<F> Future for SelectAll<F>
54where
55    F: OrchestrationFuture + Unpin,
56{
57    type Output = (F::Output, usize, Vec<F>);
58
59    fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
60        let event_index = self.event_index;
61        if event_index.is_none() {
62            return Poll::Pending;
63        }
64
65        let item = self.inner.iter_mut().enumerate().find_map(|(i, f)| {
66            if f.event_index() != event_index {
67                return None;
68            }
69            match f.poll_unpin(context) {
70                Poll::Pending => None,
71                Poll::Ready(e) => Some((i, e)),
72            }
73        });
74
75        match item {
76            Some((idx, res)) => {
77                self.inner.remove(idx);
78                let rest = std::mem::replace(&mut self.inner, Vec::new());
79
80                if !self.is_inner {
81                    self.state.borrow_mut().update(event_index.unwrap());
82                }
83
84                Poll::Ready((res, idx, rest))
85            }
86            None => Poll::Pending,
87        }
88    }
89}
90
91impl<F> OrchestrationFuture for SelectAll<F>
92where
93    F: OrchestrationFuture + Unpin,
94{
95    fn notify_inner(&mut self) {
96        self.is_inner = true;
97    }
98
99    fn event_index(&self) -> Option<usize> {
100        self.event_index
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::durable::{
        tests::{create_event, poll},
        ActionFuture, EventType,
    };
    use serde_json::{from_str, json, Value};
    use std::task::Poll;

    /// Returns the value of a ready poll, panicking with a message naming `what` otherwise.
    fn expect_ready<T>(result: Poll<T>, what: &str) -> T {
        match result {
            Poll::Ready(value) => value,
            Poll::Pending => panic!("expected {} to be ready", what),
        }
    }

    /// Marks the `TaskScheduled` and `TaskCompleted` events of task `name` as processed.
    ///
    /// Returns the task's deserialized result and the index of its completion event.
    fn resolve_task(state: &mut OrchestrationState, name: &str) -> (Option<Value>, Option<usize>) {
        let (idx, event) = state
            .find_start_event(name, EventType::TaskScheduled)
            .unwrap();
        event.is_processed = true;

        let (idx, event) = state
            .find_end_event(idx, EventType::TaskCompleted, Some(EventType::TaskFailed))
            .unwrap();
        event.is_processed = true;

        (Some(from_str(&event.result.as_ref().unwrap()).unwrap()), Some(idx))
    }

    /// State whose history schedules and completes "hello" and "world" in a single episode.
    fn single_episode_state() -> OrchestrationState {
        OrchestrationState::new(vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ])
    }

    /// State whose history completes "hello" and "world" in separate later episodes.
    /// The tests below expect it to report `is_replaying()` until the last task's update.
    fn multi_episode_state() -> OrchestrationState {
        OrchestrationState::new(vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ])
    }

    #[test]
    fn it_polls_pending_without_a_result() {
        let history = vec![create_event(
            EventType::OrchestratorStarted,
            -1,
            None,
            None,
            None,
        )];

        let state = Rc::new(RefCell::new(OrchestrationState::new(history)));
        let future1 = ActionFuture::<()>::new(None, state.clone(), None);
        let future2 = ActionFuture::<()>::new(None, state.clone(), None);
        let select = SelectAll::new(state.clone(), vec![future1, future2]);

        // No inner future has completed, so the select has no event index and stays pending.
        assert_eq!(select.event_index(), None);
        assert!(poll(select).is_pending());
    }

    #[test]
    fn it_polls_pending_with_a_result() {
        let mut state = single_episode_state();
        let (result1, idx1) = resolve_task(&mut state, "hello");
        let (result2, idx2) = resolve_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        // "hello" is placed second so the returned index checks position tracking.
        let select = SelectAll::new(state.clone(), vec![future2, future1]);

        assert_eq!(select.event_index(), idx1);
        let (r, i, mut remaining) = expect_ready(poll(select), "the select");
        assert_eq!(r, json!("hello"));
        assert_eq!(i, 1);
        assert_eq!(remaining.len(), 1);
        assert_eq!(poll(remaining.pop().unwrap()), Poll::Ready(json!("world")));
    }

    #[test]
    fn it_updates_state() {
        let mut state = multi_episode_state();
        assert!(state.is_replaying());

        let (result1, idx1) = resolve_task(&mut state, "hello");
        let (result2, idx2) = resolve_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let select = SelectAll::new(state.clone(), vec![future2, future1]);

        assert_eq!(select.event_index(), idx1);
        let (r, i, remaining) = expect_ready(poll(select), "the first select");
        assert_eq!(r, json!("hello"));
        assert_eq!(i, 1);
        assert_eq!(remaining.len(), 1);

        let (r, i, remaining) = expect_ready(
            poll(SelectAll::new(state.clone(), remaining)),
            "the second select",
        );
        assert_eq!(r, json!("world"));
        assert_eq!(i, 0);
        assert!(remaining.is_empty());
        // A top-level select updates the state, so replay has caught up.
        assert!(!state.borrow().is_replaying());
    }

    #[test]
    fn it_does_not_update_state_when_an_inner_future() {
        let mut state = multi_episode_state();
        assert!(state.is_replaying());

        let (result1, idx1) = resolve_task(&mut state, "hello");
        let (result2, idx2) = resolve_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let mut select = SelectAll::new(state.clone(), vec![future2, future1]);
        select.notify_inner();

        assert_eq!(select.event_index(), idx1);
        let (r, i, remaining) = expect_ready(poll(select), "the first inner select");
        assert_eq!(r, json!("hello"));
        assert_eq!(i, 1);
        assert_eq!(remaining.len(), 1);

        let mut select = SelectAll::new(state.clone(), remaining);
        select.notify_inner();

        let (r, i, remaining) = expect_ready(poll(select), "the second inner select");
        assert_eq!(r, json!("world"));
        assert_eq!(i, 0);
        assert!(remaining.is_empty());
        // Inner selects never update the state, so it is still replaying.
        assert!(state.borrow().is_replaying());
    }
}
404}