noir_compute/operator/iteration/replay.rs

1use std::any::TypeId;
2use std::fmt::Display;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5
6use crate::block::{BlockStructure, NextStrategy, OperatorReceiver, OperatorStructure};
7
8use crate::network::Coord;
9use crate::operator::end::End;
10use crate::operator::iteration::iteration_end::IterationEnd;
11use crate::operator::iteration::leader::IterationLeader;
12use crate::operator::iteration::state_handler::IterationStateHandler;
13use crate::operator::iteration::{
14    IterationResult, IterationStateHandle, IterationStateLock, StateFeedback,
15};
16use crate::operator::{Data, ExchangeData, Operator, StreamElement};
17use crate::scheduler::{BlockId, ExecutionMetadata};
18use crate::stream::Stream;
19
20/// This is the first operator of the chain of blocks inside an iteration.
21///
22/// If a new iteration should start, the initial dataset is replayed.
23#[derive(Debug, Clone)]
24pub struct Replay<Out: Data, State: ExchangeData, OperatorChain>
25where
26    OperatorChain: Operator<Out = Out>,
27{
28    /// The coordinate of this replica.
29    coord: Coord,
30
31    /// Helper structure that manages the iteration's state.
32    state: IterationStateHandler<State>,
33
34    /// The chain of previous operators where the dataset to replay is read from.
35    prev: OperatorChain,
36
37    /// The content of the stream to replay.
38    content: Vec<StreamElement<Out>>,
39    /// The index inside `content` of the first message to be sent.
40    content_index: usize,
41
42    /// Whether the input stream has ended or not.
43    input_finished: bool,
44}
45
46impl<Out: Data, State: ExchangeData, OperatorChain> Display for Replay<Out, State, OperatorChain>
47where
48    OperatorChain: Operator<Out = Out>,
49{
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(
52            f,
53            "{} -> Replay<{}>",
54            self.prev,
55            std::any::type_name::<Out>()
56        )
57    }
58}
59
60impl<Out: Data, State: ExchangeData, OperatorChain> Replay<Out, State, OperatorChain>
61where
62    OperatorChain: Operator<Out = Out>,
63{
64    fn new(
65        prev: OperatorChain,
66        state_ref: IterationStateHandle<State>,
67        leader_block_id: BlockId,
68        state_lock: Arc<IterationStateLock>,
69    ) -> Self {
70        Self {
71            // these fields will be set inside the `setup` method
72            coord: Coord::new(0, 0, 0),
73
74            prev,
75            content: Default::default(),
76            content_index: 0,
77            input_finished: false,
78            state: IterationStateHandler::new(leader_block_id, state_ref, state_lock),
79        }
80    }
81
82    fn input_next(&mut self) -> Option<StreamElement<Out>> {
83        if self.input_finished {
84            return None;
85        }
86
87        let item = match self.prev.next() {
88            StreamElement::FlushAndRestart => {
89                log::debug!(
90                    "Replay at {} received all the input: {} elements total",
91                    self.coord,
92                    self.content.len()
93                );
94                self.input_finished = true;
95                self.content.push(StreamElement::FlushAndRestart);
96                // the first iteration has already happened
97                self.content_index = self.content.len();
98                // since this moment accessing the state for the next iteration must wait
99                self.state.lock();
100                StreamElement::FlushAndRestart
101            }
102            // messages to save for the replay
103            el @ StreamElement::Item(_)
104            | el @ StreamElement::Timestamped(_, _)
105            | el @ StreamElement::Watermark(_) => {
106                self.content.push(el.clone());
107                el
108            }
109            // messages to forward without replaying
110            StreamElement::FlushBatch => StreamElement::FlushBatch,
111            StreamElement::Terminate => {
112                log::debug!("Replay at {} is terminating", self.coord);
113                StreamElement::Terminate
114            }
115        };
116        Some(item)
117    }
118
119    fn wait_update(&mut self) -> StateFeedback<State> {
120        let state_receiver = self.state.state_receiver().unwrap();
121        // TODO: check if affected by deadlock like iterate was in commit eb481da525850febe7cfb0963c6f3285252ecfaa
122        // If there is the possibility of input staying still in the channel
123        // waiting for the state, the iteration may deadlock
124        // to solve instead of blocking on the state receiver,
125        // a select must be performed allowing inputs to be stashed and
126        // be pulled off the channel
127        loop {
128            let message = state_receiver.recv().unwrap();
129            assert!(message.num_items() == 1);
130
131            match message.into_iter().next().unwrap() {
132                StreamElement::Item((should_continue, new_state)) => {
133                    return (should_continue, new_state);
134                }
135                StreamElement::FlushBatch => {}
136                StreamElement::FlushAndRestart => {}
137                m => unreachable!(
138                    "Iterate received invalid message from IterationLeader: {}",
139                    m.variant()
140                ),
141            }
142        }
143    }
144}
145
146impl<Out: Data, State: ExchangeData + Sync, OperatorChain> Operator
147    for Replay<Out, State, OperatorChain>
148where
149    OperatorChain: Operator<Out = Out>,
150{
151    type Out = Out;
152
153    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
154        self.coord = metadata.coord;
155        self.prev.setup(metadata);
156        self.state.setup(metadata);
157    }
158
159    fn next(&mut self) -> StreamElement<Out> {
160        loop {
161            if let Some(value) = self.input_next() {
162                return value;
163            }
164            // replay
165
166            // this iteration has not ended yet
167            if self.content_index < self.content.len() {
168                let item = self.content[self.content_index].clone();
169                self.content_index += 1;
170                if matches!(item, StreamElement::FlushAndRestart) {
171                    // since this moment accessing the state for the next iteration must wait
172                    self.state.lock();
173                }
174                return item;
175            }
176
177            log::debug!("Replay at {} has ended the iteration", self.coord);
178
179            self.content_index = 0;
180
181            let state_update = self.wait_update();
182
183            if let IterationResult::Finished = self.state.wait_sync_state(state_update) {
184                log::debug!("Replay block at {} ended the iteration", self.coord);
185                // cleanup so that if this is a nested iteration next time we'll be good to start again
186                self.content.clear();
187                self.input_finished = false;
188            }
189
190            // This iteration has ended but FlushAndRestart has already been sent. To avoid sending
191            // twice the FlushAndRestart repeat.
192        }
193    }
194
195    fn structure(&self) -> BlockStructure {
196        let mut operator = OperatorStructure::new::<Out, _>("Replay");
197        operator
198            .receivers
199            .push(OperatorReceiver::new::<StateFeedback<State>>(
200                self.state.leader_block_id,
201            ));
202        self.prev.structure().add_operator(operator)
203    }
204}
205
206impl<Out: Data, OperatorChain> Stream<OperatorChain>
207where
208    OperatorChain: Operator<Out = Out> + 'static,
209{
210    /// Construct an iterative dataflow where the input stream is repeatedly fed inside a cycle,
211    /// i.e. what comes into the cycle is _replayed_ at every iteration.
212    ///
213    /// This iteration is stateful, this means that all the replicas have a read-only access to the
214    /// _iteration state_. The initial value of the state is given as parameter. When an iteration
215    /// ends all the elements are reduced locally at each replica producing a `DeltaUpdate`. Those
216    /// delta updates are later reduced on a single node that, using the `global_fold` function will
217    /// compute the state for the next iteration. This state is also used in `loop_condition` to
218    /// check whether the next iteration should start or not. `loop_condition` is also allowed to
219    /// mutate the state.
220    ///
221    /// The initial value of `DeltaUpdate` is initialized with [`Default::default()`].
222    ///
223    /// The content of the loop has a new scope: it's defined by the `body` function that takes as
224    /// parameter the stream of data coming inside the iteration and a reference to the state. This
225    /// function should return the stream of the data that exits from the loop (that will be fed
226    /// back).
227    ///
228    /// This construct produces a single stream with a single element: the final state of the
229    /// iteration.
230    ///
231    /// **Note**: due to an internal limitation, it's not currently possible to add an iteration
232    /// operator when the stream has limited parallelism. This means, for example, that after a
233    /// non-parallel source you have to add a shuffle.
234    ///
235    /// **Note**: this operator will split the current block.
236    ///
237    /// ## Example
238    /// ```
239    /// # use noir_compute::{StreamContext, RuntimeConfig};
240    /// # use noir_compute::operator::source::IteratorSource;
241    /// # let mut env = StreamContext::new(RuntimeConfig::local(1));
242    /// let s = env.stream_iter(0..3).shuffle();
243    /// let state = s.replay(
244    ///     3, // at most 3 iterations
245    ///     0, // the initial state is zero
246    ///     |s, state| s.map(|n| n + 10),
247    ///     |delta: &mut i32, n| *delta += n,
248    ///     |state, delta| *state += delta,
249    ///     |_state| true,
250    /// );
251    /// let state = state.collect_vec();
252    /// env.execute_blocking();
253    ///
254    /// assert_eq!(state.get().unwrap(), vec![3 * (10 + 11 + 12)]);
255    /// ```
256    pub fn replay<
257        Body,
258        DeltaUpdate: ExchangeData + Default,
259        State: ExchangeData + Sync,
260        OperatorChain2,
261    >(
262        self,
263        num_iterations: usize,
264        initial_state: State,
265        body: Body,
266        local_fold: impl Fn(&mut DeltaUpdate, Out) + Send + Clone + 'static,
267        global_fold: impl Fn(&mut State, DeltaUpdate) + Send + Clone + 'static,
268        loop_condition: impl Fn(&mut State) -> bool + Send + Clone + 'static,
269    ) -> Stream<impl Operator<Out = State>>
270    where
271        Body: FnOnce(
272            Stream<Replay<Out, State, OperatorChain>>,
273            IterationStateHandle<State>,
274        ) -> Stream<OperatorChain2>,
275        OperatorChain2: Operator<Out = Out> + 'static,
276    {
277        // this is required because if the iteration block is not present on all the hosts, the ones
278        // without it won't receive the state updates.
279        assert!(
280            self.block.scheduling.replication.is_unlimited(),
281            "Cannot have an iteration block with limited parallelism"
282        );
283
284        let state = IterationStateHandle::new(initial_state.clone());
285        let state_clone = state.clone();
286        let env = self.ctx.clone();
287
288        // the id of the block where IterationEnd is. At this moment we cannot know it, so we
289        // store a fake value inside this and as soon as we know it we set it to the right value.
290        let feedback_block_id = Arc::new(AtomicUsize::new(0));
291
292        let output_block = env.lock().new_block(
293            IterationLeader::new(
294                initial_state,
295                num_iterations,
296                global_fold,
297                loop_condition,
298                feedback_block_id.clone(),
299            ),
300            Default::default(),
301            self.block.iteration_ctx.clone(),
302        );
303        let output_id = output_block.id;
304        // the output stream is outside this loop, so it doesn't have the lock for this state
305
306        // the lock for synchronizing the access to the state of this iteration
307        let state_lock = Arc::new(IterationStateLock::default());
308
309        let mut iter_start =
310            self.add_operator(|prev| Replay::new(prev, state, output_id, state_lock.clone()));
311        let replay_block_id = iter_start.block.id;
312
313        // save the stack of the iteration for checking the stream returned by the body
314        iter_start.block.iteration_ctx.push(state_lock);
315        let pre_iter_stack = iter_start.block.iteration_ctx();
316
317        let mut iter_end = body(iter_start, state_clone)
318            .key_by(|_| ())
319            .fold(DeltaUpdate::default(), local_fold)
320            .drop_key();
321
322        let post_iter_stack = iter_end.block.iteration_ctx();
323        if pre_iter_stack != post_iter_stack {
324            panic!("The body of the iteration should return the stream given as parameter");
325        }
326        iter_end.block.iteration_ctx.pop().unwrap();
327
328        let iter_end = iter_end.add_operator(|prev| IterationEnd::new(prev, output_id));
329        let iteration_end_block_id = iter_end.block.id;
330
331        let mut ctx_lock = iter_end.ctx.lock();
332        let scheduler = ctx_lock.scheduler_mut();
333        // connect the IterationEnd to the IterationLeader
334        scheduler.connect_blocks(
335            iteration_end_block_id,
336            output_id,
337            TypeId::of::<DeltaUpdate>(),
338        );
339        scheduler.connect_blocks(
340            output_id,
341            replay_block_id,
342            TypeId::of::<StateFeedback<State>>(),
343        );
344        scheduler.schedule_block(iter_end.block);
345        drop(ctx_lock);
346
347        // store the id of the block containing the IterationEnd
348        feedback_block_id.store(iteration_end_block_id as usize, Ordering::Release);
349
350        // TODO: check parallelism and make sure the blocks are spawned on the same replicas
351
352        // FIXME: this add_block is here just to make sure that the NextStrategy of output_stream
353        //        is not changed by the following operators. This because the next strategy affects
354        //        the connections made by the scheduler and if accidentally set to OnlyOne will
355        //        break the connections.
356        Stream::new(env, output_block).split_block(End::new, NextStrategy::random())
357    }
358}