1use crate::durable::{OrchestrationFuture, OrchestrationState};
2use futures::future::{join_all, FutureExt};
3use std::{
4 cell::RefCell,
5 future::Future,
6 pin::Pin,
7 rc::Rc,
8 task::{Context, Poll},
9};
10
11#[must_use = "futures do nothing unless you `.await` or poll them"]
13pub struct JoinAll<F>
14where
15 F: OrchestrationFuture,
16{
17 inner: futures::future::JoinAll<F>,
18 state: Rc<RefCell<OrchestrationState>>,
19 event_index: Option<usize>,
20 is_inner: bool,
21}
22
23impl<F> JoinAll<F>
24where
25 F: OrchestrationFuture,
26{
27 pub(crate) fn new<T>(state: Rc<RefCell<OrchestrationState>>, iter: T) -> Self
28 where
29 T: IntoIterator<Item = F>,
30 F: OrchestrationFuture,
31 {
32 let inner: Vec<_> = iter
33 .into_iter()
34 .map(|mut f| {
35 f.notify_inner();
36 f
37 })
38 .collect();
39
40 let event_index = inner
42 .iter()
43 .try_fold(None, |i, f| {
44 let next = f.event_index();
45 if next.is_none() {
46 Err(())
47 } else if i < next {
48 Ok(next)
49 } else {
50 Ok(i)
51 }
52 })
53 .unwrap_or(None);
54
55 JoinAll {
56 inner: join_all(inner),
57 state,
58 event_index,
59 is_inner: false,
60 }
61 }
62}
63
64impl<F> Future for JoinAll<F>
65where
66 F: OrchestrationFuture,
67{
68 type Output = Vec<F::Output>;
69
70 fn poll(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
71 let result = self.inner.poll_unpin(context);
72
73 if !self.is_inner {
74 if let Poll::Ready(_) = &result {
75 self.state.borrow_mut().update(self.event_index.unwrap());
76 }
77 }
78
79 result
80 }
81}
82
83impl<F> OrchestrationFuture for JoinAll<F>
84where
85 F: OrchestrationFuture,
86{
87 fn notify_inner(&mut self) {
88 self.is_inner = true;
89 }
90
91 fn event_index(&self) -> Option<usize> {
92 self.event_index
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::durable::{
        tests::{create_event, poll},
        ActionFuture, EventType,
    };
    use serde_json::{from_str, json, Value};
    use std::task::Poll;

    /// State whose history schedules and completes "hello" and "world" within a single
    /// orchestrator episode.
    fn single_episode_state() -> OrchestrationState {
        OrchestrationState::new(vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ])
    }

    /// State whose history schedules "hello" and "world" in one episode and completes them in
    /// two later episodes; the tests expect it to start out replaying.
    fn multi_episode_state() -> OrchestrationState {
        OrchestrationState::new(vec![
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskScheduled,
                0,
                Some("hello".to_string()),
                None,
                None,
            ),
            create_event(
                EventType::TaskScheduled,
                1,
                Some("world".to_string()),
                None,
                None,
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("hello".to_string()),
                Some(json!("hello").to_string()),
                Some(0),
            ),
            create_event(EventType::OrchestratorCompleted, -1, None, None, None),
            create_event(EventType::OrchestratorStarted, -1, None, None, None),
            create_event(
                EventType::TaskCompleted,
                -1,
                Some("world".to_string()),
                Some(json!("world").to_string()),
                Some(1),
            ),
        ])
    }

    /// Marks the `TaskScheduled` and `TaskCompleted` events of task `name` as processed and
    /// returns the completion's deserialized result and event index.
    fn complete_task(
        state: &mut OrchestrationState,
        name: &str,
    ) -> (Option<Value>, Option<usize>) {
        let (idx, event) = state
            .find_start_event(name, EventType::TaskScheduled)
            .unwrap();
        event.is_processed = true;

        let (idx, event) = state
            .find_end_event(idx, EventType::TaskCompleted, Some(EventType::TaskFailed))
            .unwrap();
        event.is_processed = true;

        let result = from_str(event.result.as_ref().unwrap()).unwrap();
        (Some(result), Some(idx))
    }

    /// Completes "hello" then "world" in `state` and joins their futures in reverse order
    /// ("world" first). Returns the shared state, the join, and the "world" completion index.
    fn join_world_and_hello(
        mut state: OrchestrationState,
    ) -> (
        Rc<RefCell<OrchestrationState>>,
        JoinAll<ActionFuture<Value>>,
        Option<usize>,
    ) {
        let (hello_result, hello_idx) = complete_task(&mut state, "hello");
        let (world_result, world_idx) = complete_task(&mut state, "world");

        let state = Rc::new(RefCell::new(state));
        let hello = ActionFuture::new(hello_result, state.clone(), hello_idx);
        let world = ActionFuture::new(world_result, state.clone(), world_idx);
        let join = JoinAll::new(state.clone(), vec![world, hello]);

        (state, join, world_idx)
    }

    #[test]
    fn it_polls_pending_without_a_result() {
        let history = vec![create_event(
            EventType::OrchestratorStarted,
            -1,
            None,
            None,
            None,
        )];

        let state = Rc::new(RefCell::new(OrchestrationState::new(history)));
        let future1 = ActionFuture::<()>::new(None, state.clone(), None);
        let future2 = ActionFuture::<()>::new(None, state.clone(), None);
        let join = JoinAll::new(state.clone(), vec![future1, future2]);

        assert_eq!(join.event_index(), None);
        assert_eq!(poll(join), Poll::Pending);
    }

    #[test]
    fn it_polls_pending_with_a_result() {
        let (_state, join, world_idx) = join_world_and_hello(single_episode_state());

        assert_eq!(join.event_index(), world_idx);
        assert_eq!(
            poll(join),
            Poll::Ready(vec![json!("world"), json!("hello")])
        );
    }

    #[test]
    fn it_updates_state() {
        let state = multi_episode_state();
        assert!(state.is_replaying());

        let (state, join, world_idx) = join_world_and_hello(state);

        assert_eq!(join.event_index(), world_idx);
        assert_eq!(
            poll(join),
            Poll::Ready(vec![json!("world"), json!("hello")])
        );
        assert!(!state.borrow().is_replaying());
    }

    #[test]
    fn it_does_not_update_state_when_an_inner_future() {
        let state = multi_episode_state();
        assert!(state.is_replaying());

        let (state, mut join, world_idx) = join_world_and_hello(state);
        join.notify_inner();

        assert_eq!(join.event_index(), world_idx);
        assert_eq!(
            poll(join),
            Poll::Ready(vec![json!("world"), json!("hello")])
        );
        assert!(state.borrow().is_replaying());
    }
}