1use std::any::TypeId;
2use std::collections::VecDeque;
3use std::fmt::Display;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6
7use crate::block::{
8 BlockStructure, Connection, NextStrategy, OperatorReceiver, OperatorStructure, Replication,
9};
10use crate::channel::RecvError::Disconnected;
11use crate::channel::SelectResult;
12
13use crate::network::{Coord, NetworkMessage, NetworkReceiver, NetworkSender, ReceiverEndpoint};
14use crate::operator::end::End;
15use crate::operator::iteration::iteration_end::IterationEnd;
16use crate::operator::iteration::leader::IterationLeader;
17use crate::operator::iteration::state_handler::IterationStateHandler;
18use crate::operator::iteration::{
19 IterationResult, IterationStateHandle, IterationStateLock, StateFeedback,
20};
21use crate::operator::source::Source;
22use crate::operator::start::Start;
23use crate::operator::{ExchangeData, Operator, StreamElement};
24use crate::scheduler::{BlockId, ExecutionMetadata};
25use crate::stream::Stream;
26
/// Produces a fresh default value of `T`, disregarding the value it is handed.
///
/// It is referenced through `clone_with` on the receiver fields of `Iterate`, so that cloning
/// the operator resets those fields to their default instead of duplicating them.
fn clone_with_default<T>(_source: &T) -> T
where
    T: Default,
{
    <T as Default>::default()
}
30
/// Operator placed at the head of the body of an iteration.
///
/// It receives elements from two channels: the input of the iteration and the feedback coming
/// from the end of the loop body. The elements left over when the iteration finishes are sent
/// to a separate output block.
///
/// Cloning this operator resets both receivers to `None` (see `clone_with_default`); they are
/// populated by `setup`.
#[derive(Derivative)]
#[derivative(Debug, Clone)]
pub struct Iterate<Out: ExchangeData, State: ExchangeData> {
    /// Coordinates of this replica. A placeholder until `setup` copies them from the execution
    /// metadata.
    coord: Coord,

    /// Handler of the iteration state: it is locked whenever a `FlushAndRestart` is emitted, it
    /// provides the receiver for the messages of the leader and it synchronizes the state at the
    /// end of each iteration.
    state: IterationStateHandler<State>,

    /// Receiver of the elements entering the iteration from the block `input_block_id`.
    ///
    /// It is set back to `None` as soon as the channel is found disconnected.
    #[derivative(Clone(clone_with = "clone_with_default"))]
    input_receiver: Option<NetworkReceiver<Out>>,

    /// Receiver of the elements sent back by the block whose id is stored in
    /// `feedback_end_block_id`.
    #[derivative(Clone(clone_with = "clone_with_default"))]
    feedback_receiver: Option<NetworkReceiver<Out>>,

    /// Id of the block that sends the feedback. It is shared through an atomic because its
    /// value is stored by `Stream::iterate` only after the body has been built.
    feedback_end_block_id: Arc<AtomicUsize>,
    /// Id of the block that feeds the input of the iteration.
    input_block_id: BlockId,
    /// Sender towards the output block of the iteration, created by `setup`.
    output_sender: Option<NetworkSender<Out>>,
    /// Id of the output block, shared through an atomic for the same reason as
    /// `feedback_end_block_id`.
    output_block_id: Arc<AtomicUsize>,

    /// Elements produced by the previous iteration, emitted again by `next` as the content of
    /// the following one (or sent to the output when the iteration is finished).
    content: VecDeque<StreamElement<Out>>,

    /// Elements received from the input that have not been emitted yet.
    input_stash: VecDeque<StreamElement<Out>>,
    /// Elements received from the feedback during the current iteration.
    feedback_content: VecDeque<StreamElement<Out>>,

    /// Set when a `FlushAndRestart` is read from the input; cleared again when the whole
    /// iteration finishes.
    input_finished: bool,
}
69
70impl<Out: ExchangeData, State: ExchangeData> Iterate<Out, State> {
71 fn new(
72 state_ref: IterationStateHandle<State>,
73 input_block_id: BlockId,
74 leader_block_id: BlockId,
75 feedback_end_block_id: Arc<AtomicUsize>,
76 output_block_id: Arc<AtomicUsize>,
77 state_lock: Arc<IterationStateLock>,
78 ) -> Self {
79 Self {
80 coord: Coord::new(0, 0, 0),
82 input_receiver: None,
83 feedback_receiver: None,
84 feedback_end_block_id,
85 input_block_id,
86 output_sender: None,
87 output_block_id,
88
89 content: Default::default(),
90 input_stash: Default::default(),
91 feedback_content: Default::default(),
92 input_finished: false,
93 state: IterationStateHandler::new(leader_block_id, state_ref, state_lock),
94 }
95 }
96
97 fn next_input(&mut self) -> Option<StreamElement<Out>> {
98 let item = self.input_stash.pop_front()?;
99
100 let el = match &item {
101 StreamElement::FlushAndRestart => {
102 log::debug!("input finished for iterate {}", self.coord);
103 self.input_finished = true;
104 self.state.lock();
106 StreamElement::FlushAndRestart
107 }
108 StreamElement::Item(_)
109 | StreamElement::Timestamped(_, _)
110 | StreamElement::Watermark(_)
111 | StreamElement::FlushBatch => item,
112 StreamElement::Terminate => {
113 log::debug!("Iterate at {} is terminating", self.coord);
114 let message = NetworkMessage::new_single(StreamElement::Terminate, self.coord);
115 self.output_sender.as_ref().unwrap().send(message).unwrap();
116 item
117 }
118 };
119 Some(el)
120 }
121
122 fn next_stored(&mut self) -> Option<StreamElement<Out>> {
123 let item = self.content.pop_front()?;
124 if matches!(item, StreamElement::FlushAndRestart) {
125 self.state.lock();
127 }
128 Some(item)
129 }
130
131 fn feedback_finished(&self) -> bool {
132 matches!(
133 self.feedback_content.back(),
134 Some(StreamElement::FlushAndRestart)
135 )
136 }
137
138 pub(crate) fn input_or_feedback(&mut self) {
139 let rx_feedback = self.feedback_receiver.as_ref().unwrap();
140
141 if let Some(rx_input) = self.input_receiver.as_ref() {
142 match rx_input.select(rx_feedback) {
143 SelectResult::A(Ok(msg)) => {
144 self.input_stash.extend(msg);
145 }
146 SelectResult::B(Ok(msg)) => {
147 self.feedback_content.extend(msg);
148 }
149 SelectResult::A(Err(Disconnected)) => {
150 self.input_receiver = None;
151 self.input_or_feedback();
152 }
153 SelectResult::B(Err(Disconnected)) => {
154 log::error!("feedback_receiver disconnected!");
155 panic!("feedback_receiver disconnected!");
156 }
157 }
158 } else {
159 self.feedback_content.extend(rx_feedback.recv().unwrap());
160 }
161 }
162
163 pub(crate) fn wait_update(&mut self) -> StateFeedback<State> {
164 let rx_state = self.state.state_receiver().unwrap();
167 loop {
168 let state_msg = if let Some(rx_input) = self.input_receiver.as_ref() {
169 match rx_state.select(rx_input) {
170 SelectResult::A(Ok(state_msg)) => state_msg,
171 SelectResult::A(Err(Disconnected)) => {
172 log::error!("state_receiver disconnected!");
173 panic!("state_receiver disconnected!");
174 }
175 SelectResult::B(Ok(msg)) => {
176 self.input_stash.extend(msg);
177 continue;
178 }
179 SelectResult::B(Err(Disconnected)) => {
180 self.input_receiver = None;
181 continue;
182 }
183 }
184 } else {
185 rx_state.recv().unwrap()
186 };
187
188 assert!(state_msg.num_items() == 1);
189
190 match state_msg.into_iter().next().unwrap() {
191 StreamElement::Item((should_continue, new_state)) => {
192 return (should_continue, new_state);
193 }
194 StreamElement::FlushBatch => {}
195 StreamElement::FlushAndRestart => {}
196 m => unreachable!(
197 "Iterate received invalid message from IterationLeader: {}",
198 m.variant()
199 ),
200 }
201 }
202 }
203}
204
205impl<Out: ExchangeData, State: ExchangeData + Sync> Operator for Iterate<Out, State> {
206 type Out = Out;
207
208 fn setup(&mut self, metadata: &mut ExecutionMetadata) {
209 self.coord = metadata.coord;
210
211 let endpoint = ReceiverEndpoint::new(metadata.coord, self.input_block_id);
212 self.input_receiver = Some(metadata.network.get_receiver(endpoint));
213
214 let feedback_end_block_id = self.feedback_end_block_id.load(Ordering::Acquire) as BlockId;
215 let feedback_endpoint = ReceiverEndpoint::new(metadata.coord, feedback_end_block_id);
216 self.feedback_receiver = Some(metadata.network.get_receiver(feedback_endpoint));
217
218 let output_block_id = self.output_block_id.load(Ordering::Acquire) as BlockId;
219 let output_endpoint = ReceiverEndpoint::new(
220 Coord::new(
221 output_block_id,
222 metadata.coord.host_id,
223 metadata.coord.replica_id,
224 ),
225 metadata.coord.block_id,
226 );
227 self.output_sender = Some(metadata.network.get_sender(output_endpoint));
228
229 self.state.setup(metadata);
230 }
231
232 fn next(&mut self) -> StreamElement<Out> {
233 loop {
234 while let Ok(message) = self.feedback_receiver.as_ref().unwrap().try_recv() {
236 self.feedback_content.extend(&mut message.into_iter());
237 }
238
239 if !self.input_finished {
240 while self.input_stash.is_empty() {
241 self.input_or_feedback();
242 }
243
244 return self.next_input().unwrap();
245 }
246
247 if !self.content.is_empty() {
248 return self.next_stored().unwrap();
249 }
250
251 while !self.feedback_finished() {
252 self.input_or_feedback();
253 }
254
255 log::debug!("Iterate at {} has finished the iteration", self.coord);
258 assert!(self.content.is_empty());
259 std::mem::swap(&mut self.content, &mut self.feedback_content);
260
261 let state_update = self.wait_update();
262
263 if let IterationResult::Finished = self.state.wait_sync_state(state_update) {
264 log::debug!("Iterate block at {} finished", self.coord,);
265 self.input_finished = false;
267
268 let message =
269 NetworkMessage::new_batch(self.content.drain(..).collect(), self.coord);
270 self.output_sender.as_ref().unwrap().send(message).unwrap();
271 }
272
273 }
276 }
277
278 fn structure(&self) -> BlockStructure {
279 let mut operator = OperatorStructure::new::<Out, _>("Iterate");
280 operator
281 .receivers
282 .push(OperatorReceiver::new::<StateFeedback<State>>(
283 self.state.leader_block_id,
284 ));
285 operator.receivers.push(OperatorReceiver::new::<Out>(
286 self.feedback_end_block_id.load(Ordering::Acquire) as BlockId,
287 ));
288 operator
289 .receivers
290 .push(OperatorReceiver::new::<Out>(self.input_block_id));
291 let output_block_id = self.output_block_id.load(Ordering::Acquire);
292 operator.connections.push(Connection::new::<Out, _>(
293 output_block_id as BlockId,
294 &NextStrategy::only_one(),
295 ));
296 BlockStructure::default().add_operator(operator)
297 }
298}
299
300impl<Out: ExchangeData, State: ExchangeData + Sync> Display for Iterate<Out, State> {
301 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302 write!(f, "Iterate<{}>", std::any::type_name::<Out>())
303 }
304}
305
impl<Out: ExchangeData, OperatorChain> Stream<OperatorChain>
where
    OperatorChain: Operator<Out = Out> + 'static,
{
    /// Builds an iteration around this stream.
    ///
    /// The current stream becomes the input of an [`Iterate`] operator. `body` receives the
    /// stream produced by that operator together with a handle to the iteration state, and must
    /// return the stream whose elements are fed back to the `Iterate` operator.
    ///
    /// # Parameters
    ///
    /// - `num_iterations`: forwarded to the `IterationLeader` (presumably the maximum number of
    ///   iterations to run -- confirm against `IterationLeader`).
    /// - `initial_state`: the initial value of the state, given both to the state handle and to
    ///   the leader.
    /// - `body`: builds the body of the loop.
    /// - `local_fold`: folds the elements coming out of the body into a `StateUpdate`, starting
    ///   from `StateUpdate::default()`.
    /// - `global_fold`, `loop_condition`: forwarded to the `IterationLeader` (presumably used to
    ///   merge the `StateUpdate`s into the state and to decide whether to continue -- confirm
    ///   against `IterationLeader`).
    ///
    /// # Returns
    ///
    /// A pair of streams: the first one starts from the leader block and carries `State`, the
    /// second one starts from the output block fed by the `Iterate` operator and carries `Out`.
    ///
    /// # Panics
    ///
    /// - If the replication of the current block is not unlimited.
    /// - If the stream returned by `body` does not have the same iteration context stack as the
    ///   stream that was passed to it.
    pub fn iterate<
        Body,
        StateUpdate: ExchangeData + Default,
        State: ExchangeData + Sync,
        OperatorChain2,
    >(
        self,
        num_iterations: usize,
        initial_state: State,
        body: Body,
        local_fold: impl Fn(&mut StateUpdate, Out) + Send + Clone + 'static,
        global_fold: impl Fn(&mut State, StateUpdate) + Send + Clone + 'static,
        loop_condition: impl Fn(&mut State) -> bool + Send + Clone + 'static,
    ) -> (
        Stream<impl Operator<Out = State>>,
        Stream<impl Operator<Out = Out>>,
    )
    where
        Body: FnOnce(
            Stream<Iterate<Out, State>>,
            IterationStateHandle<State>,
        ) -> Stream<OperatorChain2>,
        OperatorChain2: Operator<Out = Out> + 'static,
    {
        assert!(
            self.block.scheduling.replication.is_unlimited(),
            "Cannot have an iteration block with limited parallelism"
        );

        let state = IterationStateHandle::new(initial_state.clone());
        let state_clone = state.clone();
        let batch_mode = self.block.batch_mode;
        let ctx = self.ctx;

        // Ids of blocks that do not exist yet: the operators are created holding these shared
        // cells, whose actual values are stored below once all the blocks have been built.
        let shared_state_update_id = Arc::new(AtomicUsize::new(0));
        let shared_feedback_id = Arc::new(AtomicUsize::new(0));
        let shared_output_id = Arc::new(AtomicUsize::new(0));

        // The block of the leader, created with the iteration context of the current block.
        let leader_block = ctx.lock().new_block(
            IterationLeader::new(
                initial_state,
                num_iterations,
                global_fold,
                loop_condition,
                shared_state_update_id.clone(),
            ),
            batch_mode,
            self.block.iteration_ctx.clone(),
        );
        let state_lock = Arc::new(IterationStateLock::default());

        // Close the current block with an `End`: it becomes the input of the iteration.
        let mut input_block = self
            .block
            .add_operator(|prev| End::new(prev, NextStrategy::only_one(), batch_mode));
        input_block.is_only_one_strategy = true;

        // The block that opens the body of the loop, starting with the `Iterate` operator.
        let iter_source = Iterate::new(
            state,
            input_block.id,
            leader_block.id,
            shared_feedback_id.clone(),
            shared_output_id.clone(),
            state_lock.clone(),
        );
        let mut iter_block =
            ctx.lock()
                .new_block(iter_source, batch_mode, input_block.iteration_ctx.clone());
        let iter_id = iter_block.id;

        // The body is one level deeper in the iteration context stack; remember the stack to
        // compare it with the one of the stream returned by `body`.
        iter_block.iteration_ctx.push(state_lock.clone());
        let pre_iter_stack = iter_block.iteration_ctx();

        // The block that receives the output of the iteration from the `Iterate` block.
        let output_block = ctx.lock().new_block(
            Start::single(iter_block.id, iter_block.iteration_ctx.last().cloned()),
            batch_mode,
            Default::default(),
        );
        let output_id = output_block.id;

        let iter_stream = Stream::new(ctx.clone(), iter_block);
        let body_stream = body(iter_stream, state_clone);

        // Close the body with an `End` told to ignore the output block as destination.
        // NOTE(review): presumably because the output block is fed directly by `Iterate`
        // through its own sender -- confirm against `End::ignore_destination`.
        let mut body_stream = body_stream.split_block(
            move |prev, next_strategy, batch_mode| {
                let mut end = End::new(prev, next_strategy, batch_mode);
                end.ignore_destination(output_id);
                end
            },
            NextStrategy::only_one(),
        );
        let body_id = body_stream.block.id;

        let post_iter_stack = body_stream.block.iteration_ctx();
        assert_eq!(
            pre_iter_stack, post_iter_stack,
            "The body of the iteration should return the stream given as parameter"
        );

        // Leave the scope of this iteration.
        body_stream.block.iteration_ctx.pop().unwrap();

        // The block that folds the elements coming out of the body into a `StateUpdate` and
        // ends with an `IterationEnd` referring to the leader block.
        let state_block = ctx.lock().new_block(
            Start::single(body_stream.block.id, Some(state_lock)),
            batch_mode,
            Default::default(),
        );
        let state_stream = Stream::new(ctx.clone(), state_block);
        let state_stream = state_stream
            .key_by(|_| ())
            .fold(StateUpdate::default(), local_fold)
            .drop_key()
            .add_operator(|prev| IterationEnd::new(prev, leader_block.id));

        // The block that sends the output of the body back to the `Iterate` operator.
        let batch_mode = body_stream.block.batch_mode;
        let mut feedback_stream = body_stream.add_operator(|prev| {
            let mut end = End::new(prev, NextStrategy::only_one(), batch_mode);
            end.mark_feedback(iter_id);
            end
        });
        feedback_stream.block.is_only_one_strategy = true;

        // Register the connections between the blocks:
        //   input -> iterate, body -> state, state -> leader, leader -> iterate,
        //   feedback -> iterate, iterate -> output.
        let mut ctx_lock = ctx.lock();
        let scheduler = ctx_lock.scheduler_mut();
        scheduler.connect_blocks(input_block.id, iter_id, TypeId::of::<Out>());
        scheduler.connect_blocks(body_id, state_stream.block.id, TypeId::of::<Out>());
        scheduler.connect_blocks(
            state_stream.block.id,
            leader_block.id,
            TypeId::of::<StateUpdate>(),
        );
        scheduler.connect_blocks(
            leader_block.id,
            iter_id,
            TypeId::of::<StateFeedback<State>>(),
        );
        scheduler.connect_blocks(feedback_stream.block.id, iter_id, TypeId::of::<Out>());
        // NOTE(review): the meaning of a "fragile" connection is not visible here -- see
        // `connect_blocks_fragile`.
        scheduler.connect_blocks_fragile(iter_id, output_block.id, TypeId::of::<Out>());

        // Publish the ids that were not known when the operators were created. `Iterate` reads
        // the feedback and output ids with `Ordering::Acquire` in `setup` and `structure`.
        shared_state_update_id.store(state_stream.block.id as usize, Ordering::Release);
        shared_feedback_id.store(feedback_stream.block.id as usize, Ordering::Release);
        shared_output_id.store(output_block.id as usize, Ordering::Release);

        // These three blocks are complete; the leader and output blocks are handed back to the
        // caller inside the returned streams.
        scheduler.schedule_block(state_stream.block);
        scheduler.schedule_block(feedback_stream.block);
        scheduler.schedule_block(input_block);

        // Release the lock on the context before building the returned streams.
        drop(ctx_lock);
        (
            Stream::new(ctx.clone(), leader_block).split_block(End::new, NextStrategy::random()),
            Stream::new(ctx, output_block),
        )
    }
}
540
541impl<Out: ExchangeData, State: ExchangeData + Sync> Source for Iterate<Out, State> {
542 fn replication(&self) -> Replication {
543 Replication::Unlimited
544 }
545}