noir_compute/operator/iteration/
iterate.rs

1use std::any::TypeId;
2use std::collections::VecDeque;
3use std::fmt::Display;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6
7use crate::block::{
8    BlockStructure, Connection, NextStrategy, OperatorReceiver, OperatorStructure, Replication,
9};
10use crate::channel::RecvError::Disconnected;
11use crate::channel::SelectResult;
12
13use crate::network::{Coord, NetworkMessage, NetworkReceiver, NetworkSender, ReceiverEndpoint};
14use crate::operator::end::End;
15use crate::operator::iteration::iteration_end::IterationEnd;
16use crate::operator::iteration::leader::IterationLeader;
17use crate::operator::iteration::state_handler::IterationStateHandler;
18use crate::operator::iteration::{
19    IterationResult, IterationStateHandle, IterationStateLock, StateFeedback,
20};
21use crate::operator::source::Source;
22use crate::operator::start::Start;
23use crate::operator::{ExchangeData, Operator, StreamElement};
24use crate::scheduler::{BlockId, ExecutionMetadata};
25use crate::stream::Stream;
26
/// Ignores the value it is given and produces `T`'s default instead.
///
/// It is meant to be plugged in as a `clone_with` hook, so that a field is reset to its default
/// rather than duplicated when the owning value is cloned.
fn clone_with_default<T>(_ignored: &T) -> T
where
    T: Default,
{
    Default::default()
}
30
/// This is the first operator of the chain of blocks inside an iteration.
///
/// After an iteration what comes out of the loop will come back inside for the next iteration.
///
/// The operator listens on two data channels (the outside input and the feedback of the loop)
/// plus the state channel owned by `state`, and writes the final content of the loop to
/// `output_sender`.
#[derive(Derivative)]
#[derivative(Debug, Clone)]
pub struct Iterate<Out: ExchangeData, State: ExchangeData> {
    /// The coordinate of this replica.
    ///
    /// Holds a placeholder until `setup` overwrites it.
    coord: Coord,

    /// Helper structure that manages the iteration's state.
    state: IterationStateHandler<State>,

    /// The receiver of the data entering the loop from outside, i.e. from the block identified
    /// by `input_block_id`.
    ///
    /// It is `None` before `setup` and it is set back to `None` once the channel disconnects.
    /// Cloning the operator resets it to `None`.
    #[derivative(Clone(clone_with = "clone_with_default"))]
    input_receiver: Option<NetworkReceiver<Out>>,

    /// The receiver of the data fed back by the loop, i.e. from the block identified by
    /// `feedback_end_block_id`.
    ///
    /// It is `None` before `setup`. Cloning the operator resets it to `None`.
    #[derivative(Clone(clone_with = "clone_with_default"))]
    feedback_receiver: Option<NetworkReceiver<Out>>,

    /// The id of the block that handles the feedback connection.
    ///
    /// It is shared because it is read (in `setup` and `structure`) rather than known at
    /// construction time.
    feedback_end_block_id: Arc<AtomicUsize>,
    /// The id of the block that provides the outside input of the iteration.
    input_block_id: BlockId,
    /// The sender that will feed the data to the output of the iteration.
    ///
    /// It is `None` before `setup`.
    output_sender: Option<NetworkSender<Out>>,
    /// The id of the block where the output of the iteration comes out.
    output_block_id: Arc<AtomicUsize>,

    /// The content of the stream to put back in the loop.
    ///
    /// It is filled by swapping it with `feedback_content` once an iteration is over.
    content: VecDeque<StreamElement<Out>>,

    /// Used to store the outside input until it is consumed, including the input that arrives
    /// early while waiting for a state update.
    input_stash: VecDeque<StreamElement<Out>>,
    /// The content to feed in the loop in the next iteration.
    feedback_content: VecDeque<StreamElement<Out>>,

    /// Whether the input stream has ended or not.
    ///
    /// Set when a `FlushAndRestart` is taken from the input stash, cleared when the iteration
    /// finishes.
    input_finished: bool,
}
69
70impl<Out: ExchangeData, State: ExchangeData> Iterate<Out, State> {
71    fn new(
72        state_ref: IterationStateHandle<State>,
73        input_block_id: BlockId,
74        leader_block_id: BlockId,
75        feedback_end_block_id: Arc<AtomicUsize>,
76        output_block_id: Arc<AtomicUsize>,
77        state_lock: Arc<IterationStateLock>,
78    ) -> Self {
79        Self {
80            // these fields will be set inside the `setup` method
81            coord: Coord::new(0, 0, 0),
82            input_receiver: None,
83            feedback_receiver: None,
84            feedback_end_block_id,
85            input_block_id,
86            output_sender: None,
87            output_block_id,
88
89            content: Default::default(),
90            input_stash: Default::default(),
91            feedback_content: Default::default(),
92            input_finished: false,
93            state: IterationStateHandler::new(leader_block_id, state_ref, state_lock),
94        }
95    }
96
97    fn next_input(&mut self) -> Option<StreamElement<Out>> {
98        let item = self.input_stash.pop_front()?;
99
100        let el = match &item {
101            StreamElement::FlushAndRestart => {
102                log::debug!("input finished for iterate {}", self.coord);
103                self.input_finished = true;
104                // since this moment accessing the state for the next iteration must wait
105                self.state.lock();
106                StreamElement::FlushAndRestart
107            }
108            StreamElement::Item(_)
109            | StreamElement::Timestamped(_, _)
110            | StreamElement::Watermark(_)
111            | StreamElement::FlushBatch => item,
112            StreamElement::Terminate => {
113                log::debug!("Iterate at {} is terminating", self.coord);
114                let message = NetworkMessage::new_single(StreamElement::Terminate, self.coord);
115                self.output_sender.as_ref().unwrap().send(message).unwrap();
116                item
117            }
118        };
119        Some(el)
120    }
121
122    fn next_stored(&mut self) -> Option<StreamElement<Out>> {
123        let item = self.content.pop_front()?;
124        if matches!(item, StreamElement::FlushAndRestart) {
125            // since this moment accessing the state for the next iteration must wait
126            self.state.lock();
127        }
128        Some(item)
129    }
130
131    fn feedback_finished(&self) -> bool {
132        matches!(
133            self.feedback_content.back(),
134            Some(StreamElement::FlushAndRestart)
135        )
136    }
137
138    pub(crate) fn input_or_feedback(&mut self) {
139        let rx_feedback = self.feedback_receiver.as_ref().unwrap();
140
141        if let Some(rx_input) = self.input_receiver.as_ref() {
142            match rx_input.select(rx_feedback) {
143                SelectResult::A(Ok(msg)) => {
144                    self.input_stash.extend(msg);
145                }
146                SelectResult::B(Ok(msg)) => {
147                    self.feedback_content.extend(msg);
148                }
149                SelectResult::A(Err(Disconnected)) => {
150                    self.input_receiver = None;
151                    self.input_or_feedback();
152                }
153                SelectResult::B(Err(Disconnected)) => {
154                    log::error!("feedback_receiver disconnected!");
155                    panic!("feedback_receiver disconnected!");
156                }
157            }
158        } else {
159            self.feedback_content.extend(rx_feedback.recv().unwrap());
160        }
161    }
162
163    pub(crate) fn wait_update(&mut self) -> StateFeedback<State> {
164        // We need to stash inputs that arrive early to avoid deadlocks
165
166        let rx_state = self.state.state_receiver().unwrap();
167        loop {
168            let state_msg = if let Some(rx_input) = self.input_receiver.as_ref() {
169                match rx_state.select(rx_input) {
170                    SelectResult::A(Ok(state_msg)) => state_msg,
171                    SelectResult::A(Err(Disconnected)) => {
172                        log::error!("state_receiver disconnected!");
173                        panic!("state_receiver disconnected!");
174                    }
175                    SelectResult::B(Ok(msg)) => {
176                        self.input_stash.extend(msg);
177                        continue;
178                    }
179                    SelectResult::B(Err(Disconnected)) => {
180                        self.input_receiver = None;
181                        continue;
182                    }
183                }
184            } else {
185                rx_state.recv().unwrap()
186            };
187
188            assert!(state_msg.num_items() == 1);
189
190            match state_msg.into_iter().next().unwrap() {
191                StreamElement::Item((should_continue, new_state)) => {
192                    return (should_continue, new_state);
193                }
194                StreamElement::FlushBatch => {}
195                StreamElement::FlushAndRestart => {}
196                m => unreachable!(
197                    "Iterate received invalid message from IterationLeader: {}",
198                    m.variant()
199                ),
200            }
201        }
202    }
203}
204
205impl<Out: ExchangeData, State: ExchangeData + Sync> Operator for Iterate<Out, State> {
206    type Out = Out;
207
208    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
209        self.coord = metadata.coord;
210
211        let endpoint = ReceiverEndpoint::new(metadata.coord, self.input_block_id);
212        self.input_receiver = Some(metadata.network.get_receiver(endpoint));
213
214        let feedback_end_block_id = self.feedback_end_block_id.load(Ordering::Acquire) as BlockId;
215        let feedback_endpoint = ReceiverEndpoint::new(metadata.coord, feedback_end_block_id);
216        self.feedback_receiver = Some(metadata.network.get_receiver(feedback_endpoint));
217
218        let output_block_id = self.output_block_id.load(Ordering::Acquire) as BlockId;
219        let output_endpoint = ReceiverEndpoint::new(
220            Coord::new(
221                output_block_id,
222                metadata.coord.host_id,
223                metadata.coord.replica_id,
224            ),
225            metadata.coord.block_id,
226        );
227        self.output_sender = Some(metadata.network.get_sender(output_endpoint));
228
229        self.state.setup(metadata);
230    }
231
232    fn next(&mut self) -> StreamElement<Out> {
233        loop {
234            // try to make progress on the feedback
235            while let Ok(message) = self.feedback_receiver.as_ref().unwrap().try_recv() {
236                self.feedback_content.extend(&mut message.into_iter());
237            }
238
239            if !self.input_finished {
240                while self.input_stash.is_empty() {
241                    self.input_or_feedback();
242                }
243
244                return self.next_input().unwrap();
245            }
246
247            if !self.content.is_empty() {
248                return self.next_stored().unwrap();
249            }
250
251            while !self.feedback_finished() {
252                self.input_or_feedback();
253            }
254
255            // All feedback received
256
257            log::debug!("Iterate at {} has finished the iteration", self.coord);
258            assert!(self.content.is_empty());
259            std::mem::swap(&mut self.content, &mut self.feedback_content);
260
261            let state_update = self.wait_update();
262
263            if let IterationResult::Finished = self.state.wait_sync_state(state_update) {
264                log::debug!("Iterate block at {} finished", self.coord,);
265                // cleanup so that if this is a nested iteration next time we'll be good to start again
266                self.input_finished = false;
267
268                let message =
269                    NetworkMessage::new_batch(self.content.drain(..).collect(), self.coord);
270                self.output_sender.as_ref().unwrap().send(message).unwrap();
271            }
272
273            // This iteration has ended but FlushAndRestart has already been sent. To avoid sending
274            // twice the FlushAndRestart repeat.
275        }
276    }
277
278    fn structure(&self) -> BlockStructure {
279        let mut operator = OperatorStructure::new::<Out, _>("Iterate");
280        operator
281            .receivers
282            .push(OperatorReceiver::new::<StateFeedback<State>>(
283                self.state.leader_block_id,
284            ));
285        operator.receivers.push(OperatorReceiver::new::<Out>(
286            self.feedback_end_block_id.load(Ordering::Acquire) as BlockId,
287        ));
288        operator
289            .receivers
290            .push(OperatorReceiver::new::<Out>(self.input_block_id));
291        let output_block_id = self.output_block_id.load(Ordering::Acquire);
292        operator.connections.push(Connection::new::<Out, _>(
293            output_block_id as BlockId,
294            &NextStrategy::only_one(),
295        ));
296        BlockStructure::default().add_operator(operator)
297    }
298}
299
300impl<Out: ExchangeData, State: ExchangeData + Sync> Display for Iterate<Out, State> {
301    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
302        write!(f, "Iterate<{}>", std::any::type_name::<Out>())
303    }
304}
305
impl<Out: ExchangeData, OperatorChain> Stream<OperatorChain>
where
    OperatorChain: Operator<Out = Out> + 'static,
{
    /// Construct an iterative dataflow where the input stream is fed inside a cycle. What comes
    /// out of the loop will be fed back at the next iteration.
    ///
    /// This iteration is stateful: this means that all the replicas have a read-only access to
    /// the _iteration state_. The initial value of the state is given as parameter. When an
    /// iteration ends all the elements are reduced locally at each replica producing a
    /// `DeltaUpdate`. Those delta updates are later reduced on a single node that, using the
    /// `global_fold` function, will compute the state for the next iteration. This state is also
    /// used in `loop_condition` to check whether the next iteration should start or not.
    /// `loop_condition` is also allowed to mutate the state.
    ///
    /// The initial value of `DeltaUpdate` is initialized with [`Default::default()`].
    ///
    /// The content of the loop has a new scope: it's defined by the `body` function that takes as
    /// parameter the stream of data coming inside the iteration and a reference to the state. This
    /// function should return the stream of the data that exits from the loop (that will be fed
    /// back).
    ///
    /// This construct produces two streams:
    ///
    /// - the first is a stream with a single item: the final state of the iteration
    /// - the second is the set of elements that exited the loop during the last iteration (i.e.
    ///   the ones that should have been fed back in the next iteration).
    ///
    /// **Note**: due to an internal limitation, it's not currently possible to add an iteration
    /// operator when the stream has limited parallelism. This means, for example, that after a
    /// non-parallel source you have to add a shuffle.
    ///
    /// **Note**: this operator will split the current block.
    ///
    /// ## Panics
    ///
    /// - if the replication of the current block is not unlimited;
    /// - if the iteration context of the stream returned by `body` differs from the one of the
    ///   stream `body` received.
    ///
    /// ## Example
    /// ```
    /// # use noir_compute::{StreamContext, RuntimeConfig};
    /// # use noir_compute::operator::source::IteratorSource;
    /// # let mut env = StreamContext::new(RuntimeConfig::local(1));
    /// let s = env.stream_iter(0..3).shuffle();
    /// let (state, items) = s.iterate(
    ///     3, // at most 3 iterations
    ///     0, // the initial state is zero
    ///     |s, state| s.map(|n| n + 10),
    ///     |delta: &mut i32, n| *delta += n,
    ///     |state, delta| *state += delta,
    ///     |_state| true,
    /// );
    /// let state = state.collect_vec();
    /// let items = items.collect_vec();
    /// env.execute_blocking();
    ///
    /// assert_eq!(state.get().unwrap(), vec![10 + 11 + 12 + 20 + 21 + 22 + 30 + 31 + 32]);
    /// assert_eq!(items.get().unwrap(), vec![30, 31, 32]);
    /// ```
    pub fn iterate<
        Body,
        StateUpdate: ExchangeData + Default,
        State: ExchangeData + Sync,
        OperatorChain2,
    >(
        self,
        num_iterations: usize,
        initial_state: State,
        body: Body,
        local_fold: impl Fn(&mut StateUpdate, Out) + Send + Clone + 'static,
        global_fold: impl Fn(&mut State, StateUpdate) + Send + Clone + 'static,
        loop_condition: impl Fn(&mut State) -> bool + Send + Clone + 'static,
    ) -> (
        Stream<impl Operator<Out = State>>,
        Stream<impl Operator<Out = Out>>,
    )
    where
        Body: FnOnce(
            Stream<Iterate<Out, State>>,
            IterationStateHandle<State>,
        ) -> Stream<OperatorChain2>,
        OperatorChain2: Operator<Out = Out> + 'static,
    {
        // this is required because if the iteration block is not present on all the hosts, the ones
        // without it won't receive the state updates.
        assert!(
            self.block.scheduling.replication.is_unlimited(),
            "Cannot have an iteration block with limited parallelism"
        );

        // one handle goes to the Iterate operator, the clone is handed to the body of the loop
        let state = IterationStateHandle::new(initial_state.clone());
        let state_clone = state.clone();
        let batch_mode = self.block.batch_mode;
        let ctx = self.ctx;

        // the ids of the blocks that are created later on (state update, feedback and output).
        // At this moment we cannot know them, so we store a fake value inside these and as soon
        // as we know them we set them to the right value.
        let shared_state_update_id = Arc::new(AtomicUsize::new(0));
        let shared_feedback_id = Arc::new(AtomicUsize::new(0));
        let shared_output_id = Arc::new(AtomicUsize::new(0));

        // prepare the stream with the IterationLeader block, this will provide the state output
        let leader_block = ctx.lock().new_block(
            IterationLeader::new(
                initial_state,
                num_iterations,
                global_fold,
                loop_condition,
                shared_state_update_id.clone(),
            ),
            batch_mode,
            self.block.iteration_ctx.clone(),
        );
        // the output stream is outside this loop, so it doesn't have the lock for this state

        // the lock for synchronizing the access to the state of this iteration
        let state_lock = Arc::new(IterationStateLock::default());

        // close the current block: its content is the outside input of the Iterate operator
        let mut input_block = self
            .block
            .add_operator(|prev| End::new(prev, NextStrategy::only_one(), batch_mode));
        input_block.is_only_one_strategy = true;

        let iter_source = Iterate::new(
            state,
            input_block.id,
            leader_block.id,
            shared_feedback_id.clone(),
            shared_output_id.clone(),
            state_lock.clone(),
        );
        let mut iter_block =
            ctx.lock()
                .new_block(iter_source, batch_mode, input_block.iteration_ctx.clone());
        let iter_id = iter_block.id;

        // the body of the loop lives one level deeper in the stack of the iterations
        iter_block.iteration_ctx.push(state_lock.clone());
        // save the stack of the iteration for checking the stream returned by the body
        let pre_iter_stack = iter_block.iteration_ctx();

        // prepare the stream that will output the content of the loop
        let output_block = ctx.lock().new_block(
            Start::single(iter_block.id, iter_block.iteration_ctx.last().cloned()),
            batch_mode,
            Default::default(),
        );
        let output_id = output_block.id;

        let iter_stream = Stream::new(ctx.clone(), iter_block);
        // attach the body of the loop to the Iterate operator
        let body_stream = body(iter_stream, state_clone);

        // Split the body of the loop in 2: the end block of the loop must ignore the output stream
        // since it's manually handled by the Iterate operator.
        let mut body_stream = body_stream.split_block(
            move |prev, next_strategy, batch_mode| {
                let mut end = End::new(prev, next_strategy, batch_mode);
                end.ignore_destination(output_id);
                end
            },
            NextStrategy::only_one(),
        );
        let body_id = body_stream.block.id;

        let post_iter_stack = body_stream.block.iteration_ctx();
        assert_eq!(
            pre_iter_stack, post_iter_stack,
            "The body of the iteration should return the stream given as parameter"
        );

        // leave the scope of this iteration: remove the entry pushed on `iter_block` above
        body_stream.block.iteration_ctx.pop().unwrap();

        // First split of the body: the data will be reduced into delta updates
        let state_block = ctx.lock().new_block(
            Start::single(body_stream.block.id, Some(state_lock)),
            batch_mode,
            Default::default(),
        );
        let state_stream = Stream::new(ctx.clone(), state_block);
        let state_stream = state_stream
            .key_by(|_| ())
            .fold(StateUpdate::default(), local_fold)
            .drop_key()
            .add_operator(|prev| IterationEnd::new(prev, leader_block.id));

        // Second split of the body: the data will be fed back to the Iterate block
        let batch_mode = body_stream.block.batch_mode;
        let mut feedback_stream = body_stream.add_operator(|prev| {
            let mut end = End::new(prev, NextStrategy::only_one(), batch_mode);
            end.mark_feedback(iter_id);
            end
        });
        feedback_stream.block.is_only_one_strategy = true;

        // the context stays locked until all the blocks are connected and scheduled
        let mut ctx_lock = ctx.lock();
        let scheduler = ctx_lock.scheduler_mut();
        // connect the outside input to the Iterate
        scheduler.connect_blocks(input_block.id, iter_id, TypeId::of::<Out>());
        // connect the end of the loop to the IterationEnd
        scheduler.connect_blocks(body_id, state_stream.block.id, TypeId::of::<Out>());
        // connect the IterationEnd to the IterationLeader
        scheduler.connect_blocks(
            state_stream.block.id,
            leader_block.id,
            TypeId::of::<StateUpdate>(),
        );
        // connect the IterationLeader to the Iterate
        scheduler.connect_blocks(
            leader_block.id,
            iter_id,
            TypeId::of::<StateFeedback<State>>(),
        );
        // connect the feedback
        scheduler.connect_blocks(feedback_stream.block.id, iter_id, TypeId::of::<Out>());
        // connect the output stream
        scheduler.connect_blocks_fragile(iter_id, output_block.id, TypeId::of::<Out>());

        // store the id of the blocks we now know
        shared_state_update_id.store(state_stream.block.id as usize, Ordering::Release);
        shared_feedback_id.store(feedback_stream.block.id as usize, Ordering::Release);
        shared_output_id.store(output_block.id as usize, Ordering::Release);

        scheduler.schedule_block(state_stream.block);
        scheduler.schedule_block(feedback_stream.block);
        scheduler.schedule_block(input_block);

        drop(ctx_lock);
        // TODO: check parallelism and make sure the blocks are spawned on the same replicas

        // FIXME: this add_block is here just to make sure that the NextStrategy of output_stream
        //        is not changed by the following operators. This because the next strategy affects
        //        the connections made by the scheduler and if accidentally set to OnlyOne will
        //        break the connections.
        // NOTE(review): the only split visible here is the `split_block` applied to the leader
        //        stream below; confirm which stream the FIXME above refers to.
        (
            Stream::new(ctx.clone(), leader_block).split_block(End::new, NextStrategy::random()),
            Stream::new(ctx, output_block),
        )
    }
}
540
impl<Out: ExchangeData, State: ExchangeData + Sync> Source for Iterate<Out, State> {
    /// The `Iterate` operator always reports an unlimited replication.
    fn replication(&self) -> Replication {
        Replication::Unlimited
    }
}