noir_compute/operator/iteration/
replay.rs1use std::any::TypeId;
2use std::fmt::Display;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5
6use crate::block::{BlockStructure, NextStrategy, OperatorReceiver, OperatorStructure};
7
8use crate::network::Coord;
9use crate::operator::end::End;
10use crate::operator::iteration::iteration_end::IterationEnd;
11use crate::operator::iteration::leader::IterationLeader;
12use crate::operator::iteration::state_handler::IterationStateHandler;
13use crate::operator::iteration::{
14 IterationResult, IterationStateHandle, IterationStateLock, StateFeedback,
15};
16use crate::operator::{Data, ExchangeData, Operator, StreamElement};
17use crate::scheduler::{BlockId, ExecutionMetadata};
18use crate::stream::Stream;
19
20#[derive(Debug, Clone)]
24pub struct Replay<Out: Data, State: ExchangeData, OperatorChain>
25where
26 OperatorChain: Operator<Out = Out>,
27{
28 coord: Coord,
30
31 state: IterationStateHandler<State>,
33
34 prev: OperatorChain,
36
37 content: Vec<StreamElement<Out>>,
39 content_index: usize,
41
42 input_finished: bool,
44}
45
46impl<Out: Data, State: ExchangeData, OperatorChain> Display for Replay<Out, State, OperatorChain>
47where
48 OperatorChain: Operator<Out = Out>,
49{
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(
52 f,
53 "{} -> Replay<{}>",
54 self.prev,
55 std::any::type_name::<Out>()
56 )
57 }
58}
59
60impl<Out: Data, State: ExchangeData, OperatorChain> Replay<Out, State, OperatorChain>
61where
62 OperatorChain: Operator<Out = Out>,
63{
64 fn new(
65 prev: OperatorChain,
66 state_ref: IterationStateHandle<State>,
67 leader_block_id: BlockId,
68 state_lock: Arc<IterationStateLock>,
69 ) -> Self {
70 Self {
71 coord: Coord::new(0, 0, 0),
73
74 prev,
75 content: Default::default(),
76 content_index: 0,
77 input_finished: false,
78 state: IterationStateHandler::new(leader_block_id, state_ref, state_lock),
79 }
80 }
81
82 fn input_next(&mut self) -> Option<StreamElement<Out>> {
83 if self.input_finished {
84 return None;
85 }
86
87 let item = match self.prev.next() {
88 StreamElement::FlushAndRestart => {
89 log::debug!(
90 "Replay at {} received all the input: {} elements total",
91 self.coord,
92 self.content.len()
93 );
94 self.input_finished = true;
95 self.content.push(StreamElement::FlushAndRestart);
96 self.content_index = self.content.len();
98 self.state.lock();
100 StreamElement::FlushAndRestart
101 }
102 el @ StreamElement::Item(_)
104 | el @ StreamElement::Timestamped(_, _)
105 | el @ StreamElement::Watermark(_) => {
106 self.content.push(el.clone());
107 el
108 }
109 StreamElement::FlushBatch => StreamElement::FlushBatch,
111 StreamElement::Terminate => {
112 log::debug!("Replay at {} is terminating", self.coord);
113 StreamElement::Terminate
114 }
115 };
116 Some(item)
117 }
118
119 fn wait_update(&mut self) -> StateFeedback<State> {
120 let state_receiver = self.state.state_receiver().unwrap();
121 loop {
128 let message = state_receiver.recv().unwrap();
129 assert!(message.num_items() == 1);
130
131 match message.into_iter().next().unwrap() {
132 StreamElement::Item((should_continue, new_state)) => {
133 return (should_continue, new_state);
134 }
135 StreamElement::FlushBatch => {}
136 StreamElement::FlushAndRestart => {}
137 m => unreachable!(
138 "Iterate received invalid message from IterationLeader: {}",
139 m.variant()
140 ),
141 }
142 }
143 }
144}
145
146impl<Out: Data, State: ExchangeData + Sync, OperatorChain> Operator
147 for Replay<Out, State, OperatorChain>
148where
149 OperatorChain: Operator<Out = Out>,
150{
151 type Out = Out;
152
153 fn setup(&mut self, metadata: &mut ExecutionMetadata) {
154 self.coord = metadata.coord;
155 self.prev.setup(metadata);
156 self.state.setup(metadata);
157 }
158
159 fn next(&mut self) -> StreamElement<Out> {
160 loop {
161 if let Some(value) = self.input_next() {
162 return value;
163 }
164 if self.content_index < self.content.len() {
168 let item = self.content[self.content_index].clone();
169 self.content_index += 1;
170 if matches!(item, StreamElement::FlushAndRestart) {
171 self.state.lock();
173 }
174 return item;
175 }
176
177 log::debug!("Replay at {} has ended the iteration", self.coord);
178
179 self.content_index = 0;
180
181 let state_update = self.wait_update();
182
183 if let IterationResult::Finished = self.state.wait_sync_state(state_update) {
184 log::debug!("Replay block at {} ended the iteration", self.coord);
185 self.content.clear();
187 self.input_finished = false;
188 }
189
190 }
193 }
194
195 fn structure(&self) -> BlockStructure {
196 let mut operator = OperatorStructure::new::<Out, _>("Replay");
197 operator
198 .receivers
199 .push(OperatorReceiver::new::<StateFeedback<State>>(
200 self.state.leader_block_id,
201 ));
202 self.prev.structure().add_operator(operator)
203 }
204}
205
206impl<Out: Data, OperatorChain> Stream<OperatorChain>
207where
208 OperatorChain: Operator<Out = Out> + 'static,
209{
210 pub fn replay<
257 Body,
258 DeltaUpdate: ExchangeData + Default,
259 State: ExchangeData + Sync,
260 OperatorChain2,
261 >(
262 self,
263 num_iterations: usize,
264 initial_state: State,
265 body: Body,
266 local_fold: impl Fn(&mut DeltaUpdate, Out) + Send + Clone + 'static,
267 global_fold: impl Fn(&mut State, DeltaUpdate) + Send + Clone + 'static,
268 loop_condition: impl Fn(&mut State) -> bool + Send + Clone + 'static,
269 ) -> Stream<impl Operator<Out = State>>
270 where
271 Body: FnOnce(
272 Stream<Replay<Out, State, OperatorChain>>,
273 IterationStateHandle<State>,
274 ) -> Stream<OperatorChain2>,
275 OperatorChain2: Operator<Out = Out> + 'static,
276 {
277 assert!(
280 self.block.scheduling.replication.is_unlimited(),
281 "Cannot have an iteration block with limited parallelism"
282 );
283
284 let state = IterationStateHandle::new(initial_state.clone());
285 let state_clone = state.clone();
286 let env = self.ctx.clone();
287
288 let feedback_block_id = Arc::new(AtomicUsize::new(0));
291
292 let output_block = env.lock().new_block(
293 IterationLeader::new(
294 initial_state,
295 num_iterations,
296 global_fold,
297 loop_condition,
298 feedback_block_id.clone(),
299 ),
300 Default::default(),
301 self.block.iteration_ctx.clone(),
302 );
303 let output_id = output_block.id;
304 let state_lock = Arc::new(IterationStateLock::default());
308
309 let mut iter_start =
310 self.add_operator(|prev| Replay::new(prev, state, output_id, state_lock.clone()));
311 let replay_block_id = iter_start.block.id;
312
313 iter_start.block.iteration_ctx.push(state_lock);
315 let pre_iter_stack = iter_start.block.iteration_ctx();
316
317 let mut iter_end = body(iter_start, state_clone)
318 .key_by(|_| ())
319 .fold(DeltaUpdate::default(), local_fold)
320 .drop_key();
321
322 let post_iter_stack = iter_end.block.iteration_ctx();
323 if pre_iter_stack != post_iter_stack {
324 panic!("The body of the iteration should return the stream given as parameter");
325 }
326 iter_end.block.iteration_ctx.pop().unwrap();
327
328 let iter_end = iter_end.add_operator(|prev| IterationEnd::new(prev, output_id));
329 let iteration_end_block_id = iter_end.block.id;
330
331 let mut ctx_lock = iter_end.ctx.lock();
332 let scheduler = ctx_lock.scheduler_mut();
333 scheduler.connect_blocks(
335 iteration_end_block_id,
336 output_id,
337 TypeId::of::<DeltaUpdate>(),
338 );
339 scheduler.connect_blocks(
340 output_id,
341 replay_block_id,
342 TypeId::of::<StateFeedback<State>>(),
343 );
344 scheduler.schedule_block(iter_end.block);
345 drop(ctx_lock);
346
347 feedback_block_id.store(iteration_end_block_id as usize, Ordering::Release);
349
350 Stream::new(env, output_block).split_block(End::new, NextStrategy::random())
357 }
358}