noir_compute/operator/source/
parallel_iterator.rs1use std::fmt::Display;
2use std::ops::Range;
3
4use crate::block::{BlockStructure, OperatorKind, OperatorStructure, Replication};
5use crate::operator::source::Source;
6use crate::operator::{Operator, StreamElement};
7use crate::scheduler::ExecutionMetadata;
8use crate::{CoordUInt, Stream};
9
10pub trait IntoParallelSource: Clone + Send {
11 type Iter: Iterator;
12 fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter;
13}
14
15impl<It, G> IntoParallelSource for G
16where
17 It: Iterator + Send + 'static,
18 G: FnOnce(CoordUInt, CoordUInt) -> It + Send + Clone,
19{
20 type Iter = It;
21
22 fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter {
23 self(index, peers)
24 }
25}
26
27impl IntoParallelSource for Range<u64> {
28 type Iter = Range<u64>;
29
30 fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter {
31 let n = self.end - self.start;
32 let chunk_size = (n.saturating_add(peers - 1)) / peers;
33 let start = self.start.saturating_add(index * chunk_size);
34 let end = (start.saturating_add(chunk_size))
35 .min(self.end)
36 .max(self.start);
37
38 start..end
39 }
40}
41
42macro_rules! impl_into_parallel_source_range {
43 ($t:ty) => {
44 impl IntoParallelSource for Range<$t> {
45 type Iter = Range<$t>;
46
47 fn generate_iterator(self, index: CoordUInt, peers: CoordUInt) -> Self::Iter {
48 let index: i64 = index.try_into().unwrap();
49 let peers: i64 = peers.try_into().unwrap();
50 let n = self.end as i64 - self.start as i64;
51 let chunk_size = (n.saturating_add(peers - 1)) / peers;
52 let start = (self.start as i64).saturating_add(index * chunk_size);
53 let end = (start.saturating_add(chunk_size))
54 .min(self.end as i64)
55 .max(self.start as i64);
56
57 let (start, end) = (start.try_into().unwrap(), end.try_into().unwrap());
58 start..end
59 }
60 }
61 };
62}
63
64impl_into_parallel_source_range!(u8);
65impl_into_parallel_source_range!(u16);
66impl_into_parallel_source_range!(u32);
67
68impl_into_parallel_source_range!(usize);
69
70impl_into_parallel_source_range!(i8);
71impl_into_parallel_source_range!(i16);
72impl_into_parallel_source_range!(i32);
73impl_into_parallel_source_range!(i64);
74impl_into_parallel_source_range!(isize);
75
76enum IteratorGenerator<Source: IntoParallelSource> {
82 Generator(Source),
84 Iterator(Source::Iter),
86 Generating,
89}
90
91impl<Source: IntoParallelSource> IteratorGenerator<Source> {
92 fn generate(&mut self, global_id: CoordUInt, instances: CoordUInt) {
96 let gen = std::mem::replace(self, IteratorGenerator::Generating);
97 let iter = match gen {
98 IteratorGenerator::Generator(gen) => gen.generate_iterator(global_id, instances),
99 _ => unreachable!("generate on non-Generator variant"),
100 };
101 *self = IteratorGenerator::Iterator(iter);
102 }
103
104 fn next(&mut self) -> Option<<Source::Iter as Iterator>::Item> {
106 match self {
107 IteratorGenerator::Iterator(iter) => iter.next(),
108 _ => unreachable!("next on non-Iterator variant"),
109 }
110 }
111}
112
113impl<Source: IntoParallelSource> Clone for IteratorGenerator<Source> {
114 fn clone(&self) -> Self {
115 match self {
116 Self::Generator(gen) => Self::Generator(gen.clone()),
117 _ => panic!("Can clone only before generating the iterator"),
118 }
119 }
120}
121
122#[derive(Derivative)]
128#[derivative(Debug)]
129pub struct ParallelIteratorSource<Source>
130where
131 Source: IntoParallelSource,
132{
133 #[derivative(Debug = "ignore")]
134 inner: IteratorGenerator<Source>,
135 terminated: bool,
136}
137
138impl<Source> Display for ParallelIteratorSource<Source>
139where
140 Source: IntoParallelSource,
141{
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(
144 f,
145 "ParallelIteratorSource<{}>",
146 std::any::type_name::<<Source::Iter as Iterator>::Item>()
147 )
148 }
149}
150
151impl<S> Operator for ParallelIteratorSource<S>
152where
153 S: IntoParallelSource,
154 S::Iter: Send,
155 <S::Iter as Iterator>::Item: Send,
156{
157 type Out = <S::Iter as Iterator>::Item;
158
159 fn setup(&mut self, metadata: &mut ExecutionMetadata) {
160 self.inner.generate(
161 metadata.global_id,
162 metadata
163 .replicas
164 .len()
165 .try_into()
166 .expect("Num replicas > max id"),
167 );
168 }
169
170 fn next(&mut self) -> StreamElement<Self::Out> {
171 if self.terminated {
172 return StreamElement::Terminate;
173 }
174 match self.inner.next() {
176 Some(t) => StreamElement::Item(t),
177 None => {
178 self.terminated = true;
179 StreamElement::FlushAndRestart
180 }
181 }
182 }
183
184 fn structure(&self) -> BlockStructure {
185 let mut operator =
186 OperatorStructure::new::<<S::Iter as Iterator>::Item, _>("ParallelIteratorSource");
187 operator.kind = OperatorKind::Source;
188 BlockStructure::default().add_operator(operator)
189 }
190}
191
192impl<S> Clone for ParallelIteratorSource<S>
193where
194 S: IntoParallelSource,
195{
196 fn clone(&self) -> Self {
197 Self {
198 inner: self.inner.clone(),
199 terminated: false,
200 }
201 }
202}
203
204impl crate::StreamContext {
205 pub fn stream_par_iter<Source>(
231 &self,
232 generator: Source,
233 ) -> Stream<ParallelIteratorSource<Source>>
234 where
235 Source: IntoParallelSource + 'static,
236 Source::Iter: Send,
237 <Source::Iter as Iterator>::Item: Send,
238 {
239 let source = ParallelIteratorSource::new(generator);
240 self.stream(source)
241 }
242}
243
244impl<S> ParallelIteratorSource<S>
245where
246 S: IntoParallelSource,
247{
248 pub fn new(generator: S) -> Self {
275 Self {
276 inner: IteratorGenerator::Generator(generator),
277 terminated: false,
278 }
279 }
280}
281
282impl<S> Source for ParallelIteratorSource<S>
283where
284 S: IntoParallelSource,
285 S::Iter: Send,
286 <S::Iter as Iterator>::Item: Send,
287{
288 fn replication(&self) -> Replication {
289 Replication::Unlimited
290 }
291}