noir_compute/operator/iteration/
iterate_delta.rs

1use serde::{Deserialize, Serialize};
2
3use crate::block::{BlockStructure, OperatorStructure};
4use crate::operator::iteration::IterationStateHandle;
5use crate::operator::{
6    ExchangeData, ExchangeDataKey, Operator, SimpleStartOperator, StreamElement,
7};
8use crate::scheduler::ExecutionMetadata;
9use crate::KeyedStream;
10
11#[derive(Clone, Serialize, Deserialize, Default, Debug)]
12struct TerminationCond {
13    something_changed: bool,
14    last_iteration: bool,
15    iter: usize,
16}
17
18#[derive(Clone, Serialize, Deserialize, Debug)]
19enum Msg<I, U, D, O> {
20    Init(I),
21    Update(U),
22    Delta(D),
23    Output(O),
24}
25
26impl<I, U, D, O> Msg<I, U, D, O> {
27    /// Returns `true` if the msg is [`Update`].
28    ///
29    /// [`Update`]: Msg::Update
30    #[must_use]
31    fn is_update(&self) -> bool {
32        matches!(self, Self::Update(..))
33    }
34
35    /// Returns `true` if the msg is [`Output`].
36    ///
37    /// [`Output`]: Msg::Output
38    #[must_use]
39    fn is_output(&self) -> bool {
40        matches!(self, Self::Output(..))
41    }
42
43    fn unwrap_update(self) -> U {
44        if let Self::Update(v) = self {
45            v
46        } else {
47            panic!("unwrap on wrong iteration message type")
48        }
49    }
50
51    fn unwrap_output(self) -> O {
52        if let Self::Output(v) = self {
53            v
54        } else {
55            panic!("unwrap on wrong iteration message type")
56        }
57    }
58}
59
60#[derive(Clone)]
61pub struct DeltaIterate<
62    Key: ExchangeData,
63    I: ExchangeData,
64    U: ExchangeData,
65    D: ExchangeData,
66    O: ExchangeData,
67> {
68    prev: SimpleStartOperator<(Key, Msg<I, U, D, O>)>,
69}
70
71impl<Key: ExchangeData, I: ExchangeData, U: ExchangeData, D: ExchangeData, O: ExchangeData> Operator
72    for DeltaIterate<Key, I, U, D, O>
73{
74    type Out = (Key, U);
75
76    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
77        self.prev.setup(metadata);
78    }
79
80    fn next(&mut self) -> StreamElement<(Key, U)> {
81        self.prev.next().map(|(k, v)| (k, v.unwrap_update()))
82    }
83
84    fn structure(&self) -> BlockStructure {
85        self.prev
86            .structure()
87            .add_operator(OperatorStructure::new::<(Key, U), _>("DeltaIterate"))
88    }
89}
90
91impl<Key: ExchangeData, I: ExchangeData, U: ExchangeData, D: ExchangeData, O: ExchangeData>
92    std::fmt::Display for DeltaIterate<Key, I, U, D, O>
93{
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        write!(f, "UpdateStart")
96    }
97}
98
99impl<Key: ExchangeDataKey, In: ExchangeData + Default, OperatorChain> KeyedStream<OperatorChain>
100where
101    OperatorChain: Operator<Out = (Key, In)> + 'static,
102{
103    /// TODO DOCS
104    pub fn delta_iterate<U: ExchangeData, D: ExchangeData, O: ExchangeData, Body, BodyOperator>(
105        self,
106        num_iterations: usize,
107        process_delta: impl Fn(&Key, &mut In, D) + Clone + Send + 'static,
108        make_update: impl Fn(&Key, &mut In) -> U + Clone + Send + 'static,
109        make_output: impl Fn(&Key, In) -> O + Clone + Send + 'static,
110        condition: impl Fn(&D) -> bool + Clone + Send + 'static,
111        body: Body,
112    ) -> KeyedStream<impl Operator<Out = (Key, O)>>
113    where
114        Body: FnOnce(KeyedStream<DeltaIterate<Key, In, U, D, O>>) -> KeyedStream<BodyOperator>
115            + 'static,
116        BodyOperator: Operator<Out = (Key, D)> + 'static,
117    {
118        let (state, out) = self.map(|(_, v)| Msg::Init(v)).unkey().iterate(
119            num_iterations,
120            TerminationCond {
121                something_changed: false,
122                last_iteration: false,
123                iter: 0,
124            },
125            move |s, state: IterationStateHandle<TerminationCond>| {
126                let mut routes = s
127                    .to_keyed()
128                    .rich_map({
129                        let mut local_state: In = Default::default();
130                        move |(k, msg): (_, Msg<_, _, _, _>)| {
131                            let state = state.get();
132                            if state.last_iteration || state.iter == num_iterations - 2 {
133                                return Msg::Output(make_output(
134                                    k,
135                                    std::mem::take(&mut local_state),
136                                ));
137                            }
138
139                            match msg {
140                                Msg::Init(init) => local_state = init,
141                                Msg::Delta(delta) => process_delta(k, &mut local_state, delta),
142                                _ => unreachable!("invalid message at DeltaIterate start"),
143                            }
144
145                            Msg::Update(make_update(k, &mut local_state))
146                        }
147                    })
148                    .unkey()
149                    .route()
150                    .add_route(|(_, v)| v.is_update())
151                    .add_route(|(_, v)| v.is_output())
152                    .build_inner()
153                    .into_iter();
154
155                let update_stream = body(
156                    routes
157                        .next()
158                        .unwrap()
159                        .to_keyed()
160                        .add_operator(|prev| DeltaIterate { prev }),
161                )
162                .map(|(_, v)| Msg::Delta(v))
163                .unkey();
164                let output_stream = routes.next().unwrap();
165
166                update_stream.merge(output_stream)
167            },
168            move |changed: &mut TerminationCond, x| match x.1 {
169                Msg::Delta(u) if (condition)(&u) => changed.something_changed = true,
170                Msg::Delta(_) => {}
171                Msg::Output(_) => changed.last_iteration = true,
172                _ => unreachable!(),
173            },
174            |global, local| {
175                global.something_changed |= local.something_changed;
176                global.last_iteration |= local.last_iteration;
177            },
178            |s| {
179                let cond = !s.last_iteration;
180                if !s.something_changed {
181                    s.last_iteration = true;
182                }
183                s.something_changed = false;
184                s.iter += 1;
185                cond
186            },
187        );
188
189        state.for_each(std::mem::drop);
190        out.to_keyed().map(|(_, v)| v.unwrap_output())
191    }
192}