noir_compute/operator/iteration/
iterate_delta.rs1use serde::{Deserialize, Serialize};
2
3use crate::block::{BlockStructure, OperatorStructure};
4use crate::operator::iteration::IterationStateHandle;
5use crate::operator::{
6 ExchangeData, ExchangeDataKey, Operator, SimpleStartOperator, StreamElement,
7};
8use crate::scheduler::ExecutionMetadata;
9use crate::KeyedStream;
10
11#[derive(Clone, Serialize, Deserialize, Default, Debug)]
12struct TerminationCond {
13 something_changed: bool,
14 last_iteration: bool,
15 iter: usize,
16}
17
18#[derive(Clone, Serialize, Deserialize, Debug)]
19enum Msg<I, U, D, O> {
20 Init(I),
21 Update(U),
22 Delta(D),
23 Output(O),
24}
25
26impl<I, U, D, O> Msg<I, U, D, O> {
27 #[must_use]
31 fn is_update(&self) -> bool {
32 matches!(self, Self::Update(..))
33 }
34
35 #[must_use]
39 fn is_output(&self) -> bool {
40 matches!(self, Self::Output(..))
41 }
42
43 fn unwrap_update(self) -> U {
44 if let Self::Update(v) = self {
45 v
46 } else {
47 panic!("unwrap on wrong iteration message type")
48 }
49 }
50
51 fn unwrap_output(self) -> O {
52 if let Self::Output(v) = self {
53 v
54 } else {
55 panic!("unwrap on wrong iteration message type")
56 }
57 }
58}
59
60#[derive(Clone)]
61pub struct DeltaIterate<
62 Key: ExchangeData,
63 I: ExchangeData,
64 U: ExchangeData,
65 D: ExchangeData,
66 O: ExchangeData,
67> {
68 prev: SimpleStartOperator<(Key, Msg<I, U, D, O>)>,
69}
70
71impl<Key: ExchangeData, I: ExchangeData, U: ExchangeData, D: ExchangeData, O: ExchangeData> Operator
72 for DeltaIterate<Key, I, U, D, O>
73{
74 type Out = (Key, U);
75
76 fn setup(&mut self, metadata: &mut ExecutionMetadata) {
77 self.prev.setup(metadata);
78 }
79
80 fn next(&mut self) -> StreamElement<(Key, U)> {
81 self.prev.next().map(|(k, v)| (k, v.unwrap_update()))
82 }
83
84 fn structure(&self) -> BlockStructure {
85 self.prev
86 .structure()
87 .add_operator(OperatorStructure::new::<(Key, U), _>("DeltaIterate"))
88 }
89}
90
91impl<Key: ExchangeData, I: ExchangeData, U: ExchangeData, D: ExchangeData, O: ExchangeData>
92 std::fmt::Display for DeltaIterate<Key, I, U, D, O>
93{
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 write!(f, "UpdateStart")
96 }
97}
98
99impl<Key: ExchangeDataKey, In: ExchangeData + Default, OperatorChain> KeyedStream<OperatorChain>
100where
101 OperatorChain: Operator<Out = (Key, In)> + 'static,
102{
103 pub fn delta_iterate<U: ExchangeData, D: ExchangeData, O: ExchangeData, Body, BodyOperator>(
105 self,
106 num_iterations: usize,
107 process_delta: impl Fn(&Key, &mut In, D) + Clone + Send + 'static,
108 make_update: impl Fn(&Key, &mut In) -> U + Clone + Send + 'static,
109 make_output: impl Fn(&Key, In) -> O + Clone + Send + 'static,
110 condition: impl Fn(&D) -> bool + Clone + Send + 'static,
111 body: Body,
112 ) -> KeyedStream<impl Operator<Out = (Key, O)>>
113 where
114 Body: FnOnce(KeyedStream<DeltaIterate<Key, In, U, D, O>>) -> KeyedStream<BodyOperator>
115 + 'static,
116 BodyOperator: Operator<Out = (Key, D)> + 'static,
117 {
118 let (state, out) = self.map(|(_, v)| Msg::Init(v)).unkey().iterate(
119 num_iterations,
120 TerminationCond {
121 something_changed: false,
122 last_iteration: false,
123 iter: 0,
124 },
125 move |s, state: IterationStateHandle<TerminationCond>| {
126 let mut routes = s
127 .to_keyed()
128 .rich_map({
129 let mut local_state: In = Default::default();
130 move |(k, msg): (_, Msg<_, _, _, _>)| {
131 let state = state.get();
132 if state.last_iteration || state.iter == num_iterations - 2 {
133 return Msg::Output(make_output(
134 k,
135 std::mem::take(&mut local_state),
136 ));
137 }
138
139 match msg {
140 Msg::Init(init) => local_state = init,
141 Msg::Delta(delta) => process_delta(k, &mut local_state, delta),
142 _ => unreachable!("invalid message at DeltaIterate start"),
143 }
144
145 Msg::Update(make_update(k, &mut local_state))
146 }
147 })
148 .unkey()
149 .route()
150 .add_route(|(_, v)| v.is_update())
151 .add_route(|(_, v)| v.is_output())
152 .build_inner()
153 .into_iter();
154
155 let update_stream = body(
156 routes
157 .next()
158 .unwrap()
159 .to_keyed()
160 .add_operator(|prev| DeltaIterate { prev }),
161 )
162 .map(|(_, v)| Msg::Delta(v))
163 .unkey();
164 let output_stream = routes.next().unwrap();
165
166 update_stream.merge(output_stream)
167 },
168 move |changed: &mut TerminationCond, x| match x.1 {
169 Msg::Delta(u) if (condition)(&u) => changed.something_changed = true,
170 Msg::Delta(_) => {}
171 Msg::Output(_) => changed.last_iteration = true,
172 _ => unreachable!(),
173 },
174 |global, local| {
175 global.something_changed |= local.something_changed;
176 global.last_iteration |= local.last_iteration;
177 },
178 |s| {
179 let cond = !s.last_iteration;
180 if !s.something_changed {
181 s.last_iteration = true;
182 }
183 s.something_changed = false;
184 s.iter += 1;
185 cond
186 },
187 );
188
189 state.for_each(std::mem::drop);
190 out.to_keyed().map(|(_, v)| v.unwrap_output())
191 }
192}