// noir_compute/operator/iteration/mod.rs
1//! Utilities for iteration operators
2
3use std::cell::UnsafeCell;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Condvar, Mutex};
6
7use serde::{Deserialize, Serialize};
8
9mod iterate;
10mod iterate_delta;
11mod iteration_end;
12mod leader;
13mod replay;
14mod state_handler;
15
16#[derive(Debug, Serialize, Deserialize, Clone)]
17pub(crate) enum IterationResult {
18 /// Continue iterating
19 Continue,
20 /// The iteration cycle has finished
21 Finished,
22}
23impl IterationResult {
24 pub(crate) fn from_condition(should_continue: bool) -> IterationResult {
25 if should_continue {
26 Self::Continue
27 } else {
28 Self::Finished
29 }
30 }
31}
32
/// The information about the new state of an iteration:
///
/// - an [`IterationResult`] telling whether a new iteration should start or the cycle has finished
/// - the new state for the next iteration
pub(crate) type StateFeedback<State> = (IterationResult, State);
38
/// A shared reference to the state of an iteration.
///
/// Every replica running inside the host shares this same state. The reference by itself provides
/// no synchronization: callers must coordinate access externally before touching it.
#[derive(Clone, Debug)]
pub(crate) struct IterationStateRef<State> {
    /// The storage for the state.
    ///
    /// Access it via `set` and `get` **with additional synchronization** in place.
    state: Arc<UnsafeCell<State>>,
}

impl<State> IterationStateRef<State> {
    /// Wrap the initial state into a new shared reference.
    fn new(init: State) -> Self {
        let state = Arc::new(UnsafeCell::new(init));
        Self { state }
    }

    /// Overwrite the stored state with `new_state`.
    ///
    /// ## Safety
    ///
    /// This writes through a raw pointer without any locking. The caller must guarantee that every
    /// reference obtained via `get` has been dropped, and that no two threads call this
    /// simultaneously; additional synchronization must be put in place before calling it.
    unsafe fn set(&self, new_state: State) {
        // SAFETY: exclusivity of this mutable access is the caller's obligation (see above).
        *self.state.get() = new_state;
    }

    /// Borrow the current state.
    ///
    /// ## Safety
    ///
    /// This borrows through a raw pointer without any locking. The returned reference must not be
    /// used while `set` is being called; additional synchronization must be put in place before
    /// calling it.
    unsafe fn get(&self) -> &State {
        // SAFETY: absence of a concurrent `set` is the caller's obligation (see above).
        &*self.state.get()
    }
}

// SAFETY: the users of this struct promise to put enough synchronization in place to avoid
// undefined behaviour, so the reference may safely be sent to another thread.
unsafe impl<State: Send + Sync> Send for IterationStateRef<State> {}
85
86/// Handle to the state of the iteration.
87#[derive(Clone)]
88pub struct IterationStateHandle<T> {
89 /// A reference to the output state that will be accessible by the user.
90 result: IterationStateRef<T>,
91}
92
93impl<T: Clone> IterationStateHandle<T> {
94 pub(crate) fn new(init: T) -> Self {
95 Self {
96 result: IterationStateRef::new(init),
97 }
98 }
99
100 /// Set the new value for the state of the iteration.
101 ///
102 /// ## Safety
103 ///
104 /// It's important that the operator that manages this state takes extra care ensuring that no
105 /// undefined behaviour due to data-race is present. Calling `set` while there are still some
106 /// references from `get` around is forbidden. Calling `set` while another thread is calling
107 /// `set` is forbidden.
108 pub(crate) unsafe fn set(&self, new_state: T) {
109 self.result.set(new_state);
110 }
111
112 /// Obtain a reference to the global iteration state.
113 ///
114 /// ## Safety
115 ///
116 /// The returned reference must not be stored in any way, especially between iteration
117 /// boundaries (when the state is updated).
118 pub fn get(&self) -> &T {
119 unsafe { self.result.get() }
120 }
121}
122
123impl<T> Debug for IterationStateHandle<T> {
124 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
125 f.debug_struct("IterationStateHandle").finish()
126 }
127}
128
/// When the iteration block sends the `FlushAndRestart` message, the state of this host is in a
/// critical state: the iteration block does not update it until it receives the new state from the
/// leader and the downstream operators may access the state of the current iteration.
///
/// When the leader sends the new state we know for sure that all the downstream operators have
/// ended their computation and won't access the state until they receive the first message of the
/// next iteration. When there are shuffles this may lead to problems since the first message a
/// block receives can be from another host that received the state before this host. When this
/// happens the downstream operator will access the old state and this causes undefined behaviour
/// and data-race with the local leader.
///
/// To avoid this race we put in place a lock between iterations: when a block receives a
/// `FlushAndRestart`, before accepting any other message it waits for the local leader to update
/// the local state.
///
/// We cannot simply use a `Mutex<bool>` since it's possible that the state is unlocked, the stream
/// is consumed very quickly and the state is locked immediately after. If this happens fast enough
/// the downstream operators may fail to notice that the state got unlocked, and deadlock the
/// stream.
///
/// To avoid this problem, instead of keeping a _locked_ boolean, we keep the _generation_ of the
/// state: every time an iteration ends the generation is incremented. This allows the downstream
/// operators to know if they _skipped_ an iteration.
///
/// The value of the generation has this meaning:
/// - even value: the state is clean and can be accessed safely
/// - odd value: the state is locked, it cannot be locked again. The state can be accessed safely
///   with a generation lower or equal to this generation.
///
/// This means that locking and unlocking the state increments the generation by 2.
#[derive(Debug, Default)]
#[allow(clippy::mutex_atomic)] // the Mutex is required: a Condvar must pair with a Mutex
pub(crate) struct IterationStateLock {
    /// The index of the generation.
    generation: Mutex<usize>,
    /// A conditional variable for notifying the downstream operators that the state got unlocked.
    cond_var: Condvar,
}

#[allow(clippy::mutex_atomic)]
impl IterationStateLock {
    /// Lock the state.
    ///
    /// This operation is idempotent until `unlock` is called.
    pub fn lock(&self) {
        let mut generation = self.generation.lock().unwrap();
        // An even generation means "unlocked": advance to the next odd value. Calling this again
        // while already locked (odd) changes nothing, which makes the method idempotent.
        if *generation % 2 == 0 {
            *generation += 1;
        }
    }

    /// Unlock a locked state.
    ///
    /// This will notify all the operators that are waiting for the state to be updated.
    pub fn unlock(&self) {
        let mut generation = self.generation.lock().unwrap();
        assert_eq!(*generation % 2, 1, "cannot unlock a non-locked lock");
        *generation += 1;
        // Wake every operator blocked inside `wait_for_update`.
        self.cond_var.notify_all();
    }

    /// Block the thread while the current generation of the lock is lower than the requested one.
    pub fn wait_for_update(&self, generation: usize) {
        let guard = self.generation.lock().unwrap();
        let _guard = self
            .cond_var
            .wait_while(guard, |current| *current < generation)
            .unwrap();
    }
}
197}