noir_compute/operator/source/parallel_iterator.rs

use std::fmt::Display;
use std::ops::Range;

use crate::block::{BlockStructure, OperatorKind, OperatorStructure, Replication};
use crate::operator::source::Source;
use crate::operator::{Operator, StreamElement};
use crate::scheduler::ExecutionMetadata;
use crate::{CoordUInt, Stream};

/// Trait for values that can be turned into a per-replica iterator.
///
/// `generate_iterator` receives the 0-based index of the replica and the total number of
/// replicas, and returns the iterator yielding the items assigned to that replica.
pub trait IntoParallelSource: Clone + Send {
    type Iter: Iterator;
    fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter;
}

impl<It, G> IntoParallelSource for G
where
    It: Iterator + Send + 'static,
    G: FnOnce(CoordUInt, CoordUInt) -> It + Send + Clone,
{
    type Iter = It;

    fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter {
        self(index, peers)
    }
}
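
// A minimal sketch (test-only), assuming `CoordUInt` is a plain unsigned integer type: any
// `FnOnce(CoordUInt, CoordUInt) -> impl Iterator` closure satisfies `IntoParallelSource`
// through the blanket impl above.
#[cfg(test)]
mod closure_generator_sketch {
    use super::*;

    #[test]
    fn closure_is_a_parallel_source() {
        // Replica `index` of `peers` builds its own iterator from the two arguments.
        let generator = |index: CoordUInt, peers: CoordUInt| index..peers;
        let items: Vec<CoordUInt> = generator.generate_iterator(1, 4).collect();
        let expected: Vec<CoordUInt> = vec![1, 2, 3];
        assert_eq!(items, expected);
    }
}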

impl IntoParallelSource for Range<u64> {
    type Iter = Range<u64>;

    fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter {
        // Split the range into `peers` contiguous chunks of at most `ceil(n / peers)` elements
        // each, clamping the last chunks to the end of the range.
        let n = self.end - self.start;
        let chunk_size = (n.saturating_add(peers - 1)) / peers;
        let start = self.start.saturating_add(index * chunk_size);
        let end = (start.saturating_add(chunk_size))
            .min(self.end)
            .max(self.start);

        start..end
    }
}
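
// A minimal sketch (test-only) of the chunking above: the range is split into `peers`
// contiguous chunks of `ceil(n / peers)` elements, with the last peer taking whatever is left.
#[cfg(test)]
mod u64_range_partition_sketch {
    use super::*;

    #[test]
    fn splits_into_contiguous_chunks() {
        // n = 10, peers = 3 => chunk_size = 4.
        assert_eq!((0u64..10).generate_iterator(0, 3), 0..4);
        assert_eq!((0u64..10).generate_iterator(1, 3), 4..8);
        assert_eq!((0u64..10).generate_iterator(2, 3), 8..10);
    }
}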

// Same chunking as the `Range<u64>` impl above, but the arithmetic is carried out in `i64` so
// that it also works for signed and narrower element types.
macro_rules! impl_into_parallel_source_range {
    ($t:ty) => {
        impl IntoParallelSource for Range<$t> {
            type Iter = Range<$t>;

            fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter {
                let index: i64 = index.try_into().unwrap();
                let peers: i64 = peers.try_into().unwrap();
                let n = self.end as i64 - self.start as i64;
                let chunk_size = (n.saturating_add(peers - 1)) / peers;
                let start = (self.start as i64).saturating_add(index * chunk_size);
                let end = (start.saturating_add(chunk_size))
                    .min(self.end as i64)
                    .max(self.start as i64);

                let (start, end) = (start.try_into().unwrap(), end.try_into().unwrap());
                start..end
            }
        }
    };
}

impl_into_parallel_source_range!(u8);
impl_into_parallel_source_range!(u16);
impl_into_parallel_source_range!(u32);

impl_into_parallel_source_range!(usize);

impl_into_parallel_source_range!(i8);
impl_into_parallel_source_range!(i16);
impl_into_parallel_source_range!(i32);
impl_into_parallel_source_range!(i64);
impl_into_parallel_source_range!(isize);
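
// A minimal sketch (test-only) of one macro-generated impl: for signed element types the chunk
// arithmetic runs in `i64` and is converted back to the element type at the end.
#[cfg(test)]
mod signed_range_partition_sketch {
    use super::*;

    #[test]
    fn splits_a_signed_range() {
        // n = 10, peers = 2 => chunk_size = 5.
        assert_eq!((-5i32..5).generate_iterator(0, 2), -5..0);
        assert_eq!((-5i32..5).generate_iterator(1, 2), 0..5);
    }
}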

/// This enum wraps either an `Iterator` that yields the items, or a generator function that
/// produces such an iterator.
///
/// This enum is `Clone` only _before_ generating the iterator. The generator function must be
/// `Clone`, but the resulting iterator doesn't have to be.
enum IteratorGenerator<Source: IntoParallelSource> {
    /// The function that generates the iterator.
    Generator(Source),
    /// The actual iterator that produces the items.
    Iterator(Source::Iter),
    /// An extra variant used while the generator is moved out of the enum and before the iterator
    /// is put back in. This makes the enum panic-safe in the `generate` method.
    Generating,
}

impl<Source: IntoParallelSource> IteratorGenerator<Source> {
    /// Consume the generator function and store the produced iterator.
    ///
    /// This method can be called only once.
    fn generate(&mut self, global_id: CoordUInt, instances: CoordUInt) {
        let gen = std::mem::replace(self, IteratorGenerator::Generating);
        let iter = match gen {
            IteratorGenerator::Generator(gen) => gen.generate_iterator(global_id, instances),
            _ => unreachable!("generate on non-Generator variant"),
        };
        *self = IteratorGenerator::Iterator(iter);
    }

    /// If the `generate` method has been called, get the next element from the iterator.
    fn next(&mut self) -> Option<<Source::Iter as Iterator>::Item> {
        match self {
            IteratorGenerator::Iterator(iter) => iter.next(),
            _ => unreachable!("next on non-Iterator variant"),
        }
    }
}

impl<Source: IntoParallelSource> Clone for IteratorGenerator<Source> {
    fn clone(&self) -> Self {
        match self {
            Self::Generator(gen) => Self::Generator(gen.clone()),
            _ => panic!("Can clone only before generating the iterator"),
        }
    }
}
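
// A minimal sketch (test-only) of the two-phase use of `IteratorGenerator`: it can only be
// cloned while it still holds the generator, and `generate` replaces it with the iterator.
#[cfg(test)]
mod iterator_generator_sketch {
    use super::*;

    #[test]
    fn clone_then_generate_then_next() {
        let generator = IteratorGenerator::Generator(0u64..4);
        // Cloning is allowed only before `generate` is called.
        let mut replica = generator.clone();
        // Replica 1 of 2: chunk_size = 2, so it is assigned 2..4.
        replica.generate(1, 2);
        assert_eq!(replica.next(), Some(2));
        assert_eq!(replica.next(), Some(3));
        assert_eq!(replica.next(), None);
    }
}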

/// Source that ingests items into a stream using the maximum parallelism. The items are taken
/// from the iterators returned by a generating function.
///
/// Each replica (i.e. each core) will have a different iterator. The iterators are produced by a
/// generating function passed to the [`ParallelIteratorSource::new`] method.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct ParallelIteratorSource<Source>
where
    Source: IntoParallelSource,
{
    #[derivative(Debug = "ignore")]
    inner: IteratorGenerator<Source>,
    terminated: bool,
}

impl<Source> Display for ParallelIteratorSource<Source>
where
    Source: IntoParallelSource,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "ParallelIteratorSource<{}>",
            std::any::type_name::<<Source::Iter as Iterator>::Item>()
        )
    }
}

impl<S> Operator for ParallelIteratorSource<S>
where
    S: IntoParallelSource,
    S::Iter: Send,
    <S::Iter as Iterator>::Item: Send,
{
    type Out = <S::Iter as Iterator>::Item;

    fn setup(&mut self, metadata: &mut ExecutionMetadata) {
        self.inner.generate(
            metadata.global_id,
            metadata
                .replicas
                .len()
                .try_into()
                .expect("Num replicas > max id"),
        );
    }

    fn next(&mut self) -> StreamElement<Self::Out> {
        if self.terminated {
            return StreamElement::Terminate;
        }
        // TODO: with adaptive batching this does not work since it never emits FlushBatch messages
        match self.inner.next() {
            Some(t) => StreamElement::Item(t),
            None => {
                self.terminated = true;
                StreamElement::FlushAndRestart
            }
        }
    }

    fn structure(&self) -> BlockStructure {
        let mut operator =
            OperatorStructure::new::<<S::Iter as Iterator>::Item, _>("ParallelIteratorSource");
        operator.kind = OperatorKind::Source;
        BlockStructure::default().add_operator(operator)
    }
}
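
// A minimal sketch (test-only) of the element protocol of this source: items until the iterator
// is exhausted, then a single FlushAndRestart, then Terminate from then on. The private `inner`
// field is generated directly here instead of going through `setup`, which would need a full
// `ExecutionMetadata`.
#[cfg(test)]
mod emission_protocol_sketch {
    use super::*;

    #[test]
    fn emits_flush_and_restart_then_terminate() {
        let mut source = ParallelIteratorSource::new(0u64..2);
        // Act as replica 0 of 1, as `setup` would.
        source.inner.generate(0, 1);
        assert!(matches!(source.next(), StreamElement::Item(0)));
        assert!(matches!(source.next(), StreamElement::Item(1)));
        assert!(matches!(source.next(), StreamElement::FlushAndRestart));
        assert!(matches!(source.next(), StreamElement::Terminate));
    }
}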

impl<S> Clone for ParallelIteratorSource<S>
where
    S: IntoParallelSource,
{
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            terminated: false,
        }
    }
}

impl crate::StreamContext {
    /// Convenience method that creates a `ParallelIteratorSource` and makes a stream using
    /// `StreamContext::stream`.
    ///
    /// # Example:
    /// ```
    /// use noir_compute::prelude::*;
    ///
    /// let env = StreamContext::default();
    ///
    /// env.stream_par_iter(0..10)
    ///     .for_each(|q| println!("a: {q}"));
    ///
    /// let n = 10;
    /// env.stream_par_iter(
    ///     move |id, instances| {
    ///         let chunk_size = (n + instances - 1) / instances;
    ///         let remaining = n - n.min(chunk_size * id);
    ///         let range = remaining.min(chunk_size);
    ///
    ///         let start = id * chunk_size;
    ///         let stop = id * chunk_size + range;
    ///         start..stop
    ///     })
    ///     .for_each(|q| println!("b: {q}"));
    ///
    /// env.execute_blocking();
    /// ```
    pub fn stream_par_iter<Source>(
        &self,
        generator: Source,
    ) -> Stream<ParallelIteratorSource<Source>>
    where
        Source: IntoParallelSource + 'static,
        Source::Iter: Send,
        <Source::Iter as Iterator>::Item: Send,
    {
        let source = ParallelIteratorSource::new(generator);
        self.stream(source)
    }
}

impl<S> ParallelIteratorSource<S>
where
    S: IntoParallelSource,
{
    /// Create a new source that ingests items into the stream using the maximum parallelism
    /// available.
    ///
    /// The function passed as argument is cloned to each core, and called to get the iterator for
    /// that replica. The first parameter passed to the function is a 0-based index of the replica,
    /// while the second is the total number of replicas.
    ///
    /// ## Example
    ///
    /// ```
    /// # use noir_compute::{StreamContext, RuntimeConfig};
    /// # use noir_compute::operator::source::ParallelIteratorSource;
    /// # let mut env = StreamContext::new(RuntimeConfig::local(1));
    /// // generate the numbers from 0 to 99 using multiple replicas
    /// let n = 100;
    /// let source = ParallelIteratorSource::new(move |id, instances| {
    ///     let chunk_size = (n + instances - 1) / instances;
    ///     let remaining = n - n.min(chunk_size * id);
    ///     let range = remaining.min(chunk_size);
    ///
    ///     let start = id * chunk_size;
    ///     let stop = id * chunk_size + range;
    ///     start..stop
    /// });
    /// let s = env.stream(source);
    /// ```
    pub fn new(generator: S) -> Self {
        Self {
            inner: IteratorGenerator::Generator(generator),
            terminated: false,
        }
    }
}

impl<S> Source for ParallelIteratorSource<S>
where
    S: IntoParallelSource,
    S::Iter: Send,
    <S::Iter as Iterator>::Item: Send,
{
    fn replication(&self) -> Replication {
        Replication::Unlimited
    }
}