noir_compute/operator/iteration/
mod.rs

1//! Utilities for iteration operators
2
3use std::cell::UnsafeCell;
4use std::fmt::{Debug, Formatter};
5use std::sync::{Arc, Condvar, Mutex};
6
7use serde::{Deserialize, Serialize};
8
9mod iterate;
10mod iterate_delta;
11mod iteration_end;
12mod leader;
13mod replay;
14mod state_handler;
15
16#[derive(Debug, Serialize, Deserialize, Clone)]
17pub(crate) enum IterationResult {
18    /// Continue iterating
19    Continue,
20    /// The iteration cycle has finished
21    Finished,
22}
23impl IterationResult {
24    pub(crate) fn from_condition(should_continue: bool) -> IterationResult {
25        if should_continue {
26            Self::Continue
27        } else {
28            Self::Finished
29        }
30    }
31}
32
33/// The information about the new state of an iteration:
34///
35/// - a boolean indicating if a new iteration should start
36/// - the new state for the next iteration
37pub(crate) type StateFeedback<State> = (IterationResult, State);
38
/// Shared, unsynchronized storage for the state of an iteration.
///
/// Every clone points to the same underlying state, which is shared between all the replicas
/// inside the host. This type performs no synchronization by itself: callers must provide it
/// externally before touching the state through `set` or `get`.
#[derive(Clone, Debug)]
pub(crate) struct IterationStateRef<State> {
    /// The storage for the state.
    ///
    /// Only touch it through `set` and `get`, and only while external synchronization is held.
    state: Arc<UnsafeCell<State>>,
}

impl<State> IterationStateRef<State> {
    /// Wrap `init` into a fresh shared cell.
    fn new(init: State) -> Self {
        let cell = UnsafeCell::new(init);
        Self {
            state: Arc::new(cell),
        }
    }

    /// Overwrite the stored state with `new_state`, dropping the previous value in place.
    ///
    /// ## Safety
    ///
    /// The caller must guarantee exclusive access for the whole call: every reference obtained
    /// through `get` must already be dropped, and no other thread may be running `set` at the
    /// same time.
    unsafe fn set(&self, new_state: State) {
        // SAFETY: the caller guarantees that no shared reference is alive and that no other
        // writer is active, so writing through the raw pointer cannot alias anything.
        unsafe {
            *self.state.get() = new_state;
        }
    }

    /// Borrow the stored state.
    ///
    /// ## Safety
    ///
    /// The caller must guarantee that `set` is never called while the returned reference is
    /// still in use.
    unsafe fn get(&self) -> &State {
        // SAFETY: the caller guarantees that no `set` overlaps with the lifetime of this shared
        // borrow, so no mutable access can alias it.
        unsafe { &*self.state.get() }
    }
}

/// `Arc<UnsafeCell<State>>` is not `Send` on its own, because `UnsafeCell` is never `Sync`.
// SAFETY: a clone sent to another thread shares the state with the sender, hence the
// `State: Send + Sync` bound; the users of this struct provide the synchronization that rules
// out data races (see the contracts of `set` and `get`).
unsafe impl<State: Send + Sync> Send for IterationStateRef<State> {}
85
86/// Handle to the state of the iteration.
87#[derive(Clone)]
88pub struct IterationStateHandle<T> {
89    /// A reference to the output state that will be accessible by the user.
90    result: IterationStateRef<T>,
91}
92
93impl<T: Clone> IterationStateHandle<T> {
94    pub(crate) fn new(init: T) -> Self {
95        Self {
96            result: IterationStateRef::new(init),
97        }
98    }
99
100    /// Set the new value for the state of the iteration.
101    ///
102    /// ## Safety
103    ///
104    /// It's important that the operator that manages this state takes extra care ensuring that no
105    /// undefined behaviour due to data-race is present. Calling `set` while there are still some
106    /// references from `get` around is forbidden. Calling `set` while another thread is calling
107    /// `set` is forbidden.
108    pub(crate) unsafe fn set(&self, new_state: T) {
109        self.result.set(new_state);
110    }
111
112    /// Obtain a reference to the global iteration state.
113    ///
114    /// ## Safety
115    ///
116    /// The returned reference must not be stored in any way, especially between iteration
117    /// boundaries (when the state is updated).
118    pub fn get(&self) -> &T {
119        unsafe { self.result.get() }
120    }
121}
122
123impl<T> Debug for IterationStateHandle<T> {
124    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
125        f.debug_struct("IterationStateHandle").finish()
126    }
127}
128
129/// When the iteration block sends the `FlushAndRestart` message, the state of this host is in a
130/// critical state: the iteration block does not update it until it receives the new state from the
131/// leader and the downstream operators may access the state of the current iteration.
132///
133/// When the leader sends the new state we know for sure that all the downstream operators have
134/// ended their computation and won't access the state until they receive the first message of the
135/// next iteration. When there are shuffles this may lead to problems since the first message a
136/// block receives can be from another host that received the state before this host. When this
/// happens, the downstream operator would access the old state, causing undefined behaviour
/// and a data race with the local leader.
139///
140/// To avoid this race we put in place a lock between iterations: when a block receives a
141/// `FlushAndRestart`, before accepting any other message it waits for the local leader to update
142/// the local state.
143///
144/// We cannot simply use a `Mutex<bool>` since it's possible that the state is unlocked, the stream
145/// is consumed very quickly and the state is locked immediately after. If this happens fast enough
146/// the downstream operators may fail to notice that the state got unlocked, and deadlock the
147/// stream.
148///
149/// To avoid this problem, instead of keeping a _locked_ boolean, we keep the _generation_ of the
150/// state: every time an iteration ends the generation is incremented. This allows the downstream
151/// operators to know if they _skipped_ an iteration.
152///
153/// The value of the generation has this meaning:
154/// - even value: the state is clean and can be accessed safely
155/// - odd value: the state is locked, it cannot be locked again. The state can be accessed safely
///   with a generation lower than or equal to this generation.
157///
158/// This means that locking and unlocking the state increments the generation by 2.
#[derive(Debug, Default)]
// `clippy::mutex_atomic` would suggest an atomic instead of `Mutex<usize>`, but the mutex is
// required here: a `Condvar` can only wait on a `MutexGuard`.
#[allow(clippy::mutex_atomic)]
pub(crate) struct IterationStateLock {
    /// Generation counter: even while the state is unlocked, odd while it is locked.
    generation: Mutex<usize>,
    /// Wakes the downstream operators blocked in `wait_for_update` once the state is unlocked.
    cond_var: Condvar,
}

// See the note on the struct: the `Mutex` is paired with a `Condvar`, so it cannot be an atomic.
#[allow(clippy::mutex_atomic)]
impl IterationStateLock {
    /// Lock the state by moving to the next (odd) generation.
    ///
    /// Calling this again before `unlock` leaves the generation untouched.
    ///
    /// ## Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn lock(&self) {
        let mut generation = self.generation.lock().unwrap();
        let is_unlocked = *generation % 2 == 0;
        if is_unlocked {
            *generation += 1;
        }
    }

    /// Unlock a locked state, moving to the next (even) generation, and wake every thread
    /// waiting in `wait_for_update`.
    ///
    /// ## Panics
    ///
    /// Panics if the state is not currently locked, or if the inner mutex is poisoned.
    pub fn unlock(&self) {
        let mut generation = self.generation.lock().unwrap();
        assert_eq!(*generation % 2, 1, "cannot unlock a non-locked lock");
        *generation += 1;
        self.cond_var.notify_all();
    }

    /// Block the calling thread until the generation of the lock is at least `generation`.
    ///
    /// Returns immediately if that generation has already been reached.
    ///
    /// ## Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn wait_for_update(&self, generation: usize) {
        let guard = self.generation.lock().unwrap();
        let reached = self
            .cond_var
            .wait_while(guard, |current| *current < generation)
            .unwrap();
        drop(reached);
    }
}