1use crate::durable::{OrchestrationFuture, OrchestrationState};
2use futures::future::FutureExt;
3use std::{
4 cell::RefCell,
5 future::Future,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11#[must_use = "futures do nothing unless you `.await` or poll them"]
13pub struct SelectAll<F> {
14 inner: Vec<F>,
15 state: Rc<RefCell<OrchestrationState>>,
16 event_index: Option<usize>,
17 is_inner: bool,
18}
19
20impl<F> SelectAll<F>
21where
22 F: OrchestrationFuture,
23{
24 pub(crate) fn new<T>(state: Rc<RefCell<OrchestrationState>>, iter: T) -> Self
25 where
26 T: IntoIterator<Item = F>,
27 F: OrchestrationFuture,
28 {
29 let inner: Vec<_> = iter
30 .into_iter()
31 .map(|mut f| {
32 f.notify_inner();
33 f
34 })
35 .collect();
36
37 assert!(!inner.is_empty());
38
39 let event_index = inner.iter().filter_map(|f| f.event_index()).min();
41
42 SelectAll {
43 inner,
44 state,
45 event_index,
46 is_inner: false,
47 }
48 }
49}
50
51impl<F> Unpin for SelectAll<F> where F: Unpin {}
52
53impl<F> Future for SelectAll<F>
54where
55 F: OrchestrationFuture + Unpin,
56{
57 type Output = (F::Output, usize, Vec<F>);
58
59 fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
60 let event_index = self.event_index;
61 if event_index.is_none() {
62 return Poll::Pending;
63 }
64
65 let item = self.inner.iter_mut().enumerate().find_map(|(i, f)| {
66 if f.event_index() != event_index {
67 return None;
68 }
69 match f.poll_unpin(context) {
70 Poll::Pending => None,
71 Poll::Ready(e) => Some((i, e)),
72 }
73 });
74
75 match item {
76 Some((idx, res)) => {
77 self.inner.remove(idx);
78 let rest = std::mem::replace(&mut self.inner, Vec::new());
79
80 if !self.is_inner {
81 self.state.borrow_mut().update(event_index.unwrap());
82 }
83
84 Poll::Ready((res, idx, rest))
85 }
86 None => Poll::Pending,
87 }
88 }
89}
90
91impl<F> OrchestrationFuture for SelectAll<F>
92where
93 F: OrchestrationFuture + Unpin,
94{
95 fn notify_inner(&mut self) {
96 self.is_inner = true;
97 }
98
99 fn event_index(&self) -> Option<usize> {
100 self.event_index
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::durable::{
        tests::{create_event, poll},
        ActionFuture, EventType,
    };
    use serde_json::{from_str, json, Value};
    use std::task::Poll;

    /// Builds a state whose history schedules the "hello" and "world" tasks
    /// and records each completion after a further
    /// `OrchestratorCompleted`/`OrchestratorStarted` pair.
    fn state_with_separate_completions() -> OrchestrationState {
        OrchestrationState::new(vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ])
    }

    /// Marks the scheduled and completed events of the task called `name` as
    /// processed, returning the deserialized result of the completion event
    /// and the index of that event.
    fn complete_task(state: &mut OrchestrationState, name: &str) -> (Option<Value>, Option<usize>) {
        let (idx, event) = state
            .find_start_event(name, EventType::TaskScheduled)
            .unwrap();
        event.is_processed = true;

        let (idx, event) = state
            .find_end_event(idx, EventType::TaskCompleted, Some(EventType::TaskFailed))
            .unwrap();
        event.is_processed = true;

        let result = from_str(event.result.as_ref().unwrap()).unwrap();
        (Some(result), Some(idx))
    }

    #[test]
    fn it_polls_pending_without_a_result() {
        let history = vec![create_event(
            EventType::OrchestratorStarted,
            -1,
            None,
            None,
            None,
        )];

        let state = Rc::new(RefCell::new(OrchestrationState::new(history)));
        let future1 = ActionFuture::<()>::new(None, state.clone(), None);
        let future2 = ActionFuture::<()>::new(None, state.clone(), None);
        let select = SelectAll::new(state.clone(), vec![future1, future2]);

        assert_eq!(select.event_index(), None);
        assert!(
            poll(select).is_pending(),
            "a selection without any result must stay pending"
        );
    }

    #[test]
    fn it_polls_pending_with_a_result() {
        let history = vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ];

        let mut state = OrchestrationState::new(history);
        let (result1, idx1) = complete_task(&mut state, "hello");
        let (result2, idx2) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        // "world" is listed first so that the expected position of "hello" is 1.
        let select = SelectAll::new(state.clone(), vec![future2, future1]);

        assert_eq!(select.event_index(), idx1);
        match poll(select) {
            Poll::Ready((r, i, mut remaining)) => {
                assert_eq!(r, json!("hello"));
                assert_eq!(i, 1);
                assert_eq!(remaining.len(), 1);
                assert_eq!(poll(remaining.pop().unwrap()), Poll::Ready(json!("world")));
            }
            Poll::Pending => panic!("expected the selection to be ready"),
        };
    }

    #[test]
    fn it_updates_state() {
        let mut state = state_with_separate_completions();
        assert!(state.is_replaying());

        let (result1, idx1) = complete_task(&mut state, "hello");
        let (result2, idx2) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let select = SelectAll::new(state.clone(), vec![future2, future1]);

        assert_eq!(select.event_index(), idx1);
        match poll(select) {
            Poll::Ready((r, i, remaining)) => {
                assert_eq!(r, json!("hello"));
                assert_eq!(i, 1);
                assert_eq!(remaining.len(), 1);

                match poll(SelectAll::new(state.clone(), remaining)) {
                    Poll::Ready((r, i, remaining)) => {
                        assert_eq!(r, json!("world"));
                        assert_eq!(i, 0);
                        assert!(remaining.is_empty());
                        assert!(!state.borrow().is_replaying());
                    }
                    Poll::Pending => panic!("expected the second selection to be ready"),
                };
            }
            Poll::Pending => panic!("expected the first selection to be ready"),
        };
    }

    #[test]
    fn it_does_not_update_state_when_an_inner_future() {
        let mut state = state_with_separate_completions();
        assert!(state.is_replaying());

        let (result1, idx1) = complete_task(&mut state, "hello");
        let (result2, idx2) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let future1 = ActionFuture::new(result1, state.clone(), idx1);
        let future2 = ActionFuture::new(result2, state.clone(), idx2);
        let mut select = SelectAll::new(state.clone(), vec![future2, future1]);
        select.notify_inner();

        assert_eq!(select.event_index(), idx1);
        match poll(select) {
            Poll::Ready((r, i, remaining)) => {
                assert_eq!(r, json!("hello"));
                assert_eq!(i, 1);
                assert_eq!(remaining.len(), 1);

                let mut select = SelectAll::new(state.clone(), remaining);
                select.notify_inner();

                match poll(select) {
                    Poll::Ready((r, i, remaining)) => {
                        assert_eq!(r, json!("world"));
                        assert_eq!(i, 0);
                        assert!(remaining.is_empty());
                        // Both selections were marked as inner, so the state
                        // is still expected to report replaying.
                        assert!(state.borrow().is_replaying());
                    }
                    Poll::Pending => panic!("expected the second selection to be ready"),
                };
            }
            Poll::Pending => panic!("expected the first selection to be ready"),
        };
    }
}