1use super::{
2 ctrl::SUB_CONTEXT,
3 task::{ParTask, ParTaskHolder, TaskId},
4};
5use crate::global;
6use rayon::iter::{
7 plumbing::{
8 Consumer, Folder, Producer, ProducerCallback, Reducer, UnindexedConsumer, UnindexedProducer,
9 },
10 IndexedParallelIterator, ParallelIterator,
11};
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15#[repr(transparent)]
16pub(crate) struct FnContext {
17 migrated: bool,
18}
19
20impl FnContext {
21 pub(super) const MIGRATED: Self = Self { migrated: true };
22 pub(super) const NOT_MIGRATED: Self = Self { migrated: false };
23}
24
25#[derive(Debug, Clone, Copy)]
27#[repr(transparent)]
28struct Splitter {
29 splits: usize,
30}
31
32impl Splitter {
33 fn new() -> Option<Self> {
34 let ptr = SUB_CONTEXT.get();
35 (!ptr.is_dangling()).then(|| {
36 Self {
38 splits: unsafe { ptr.as_ref().get_comm().num_siblings() },
39 }
40 })
41 }
42
43 fn try_split(&mut self, migrated: bool) -> bool {
44 if !migrated {
65 self.splits /= 2;
66 }
67 self.splits > 0
68 }
69}
70
71#[derive(Debug, Clone, Copy)]
73struct LengthSplitter {
74 inner: Splitter,
75 min: usize,
76}
77
78impl LengthSplitter {
79 fn new(min: usize, max: usize, len: usize) -> Option<Self> {
80 let mut inner = Splitter::new()?;
81 let min_splits = len / max.max(1);
82 inner.splits = inner.splits.max(min_splits);
83
84 Some(Self {
85 inner,
86 min: min.max(1),
87 })
88 }
89
90 fn try_split(&mut self, len: usize, migrated: bool) -> bool {
91 len / 2 >= self.min && self.inner.try_split(migrated)
92 }
93}
94
95fn bridge<I, C>(par_iter: I, consumer: C) -> C::Result
97where
98 I: IndexedParallelIterator,
99 C: Consumer<I::Item>,
100{
101 global::stat::increase_parallel_task_count();
102
103 let len = par_iter.len();
104 return par_iter.with_producer(Callback { len, consumer });
105
106 struct Callback<C> {
107 len: usize,
108 consumer: C,
109 }
110
111 impl<C, I> ProducerCallback<I> for Callback<C>
112 where
113 C: Consumer<I>,
114 {
115 type Output = C::Result;
116 fn callback<P>(self, producer: P) -> C::Result
117 where
118 P: Producer<Item = I>,
119 {
120 bridge_producer_consumer(self.len, producer, self.consumer)
121 }
122 }
123}
124
125fn bridge_producer_consumer<P, C>(len: usize, producer: P, consumer: C) -> C::Result
127where
128 P: Producer,
129 C: Consumer<P::Item>,
130{
131 const MIGRATED: bool = false;
132 let res =
133 if let Some(splitter) = LengthSplitter::new(producer.min_len(), producer.max_len(), len) {
134 helper(len, MIGRATED, splitter, producer, consumer)
135 } else {
136 helper_no_split(producer, consumer)
137 };
138 return res;
139
140 fn helper<P, C>(
143 len: usize,
144 migrated: bool,
145 mut splitter: LengthSplitter,
146 producer: P,
147 consumer: C,
148 ) -> C::Result
149 where
150 P: Producer,
151 C: Consumer<P::Item>,
152 {
153 if consumer.full() {
154 consumer.into_folder().complete()
155 } else if splitter.try_split(len, migrated) {
156 let mid = len / 2;
157 let (l_producer, r_producer) = producer.split_at(mid);
158 let (l_consumer, r_consumer, reducer) = consumer.split_at(mid);
159 let (l_result, r_result) = join_context(
160 |f_cx: FnContext| helper(mid, f_cx.migrated, splitter, l_producer, l_consumer),
161 |f_cx: FnContext| {
162 helper(len - mid, f_cx.migrated, splitter, r_producer, r_consumer)
163 },
164 );
165 reducer.reduce(l_result, r_result)
166 } else {
167 producer.fold_with(consumer.into_folder()).complete()
168 }
169 }
170
171 fn helper_no_split<P, C>(producer: P, consumer: C) -> C::Result
172 where
173 P: Producer,
174 C: Consumer<P::Item>,
175 {
176 if consumer.full() {
177 consumer.into_folder().complete()
178 } else {
179 producer.fold_with(consumer.into_folder()).complete()
180 }
181 }
182}
183
184#[allow(dead_code)] fn bridge_unindexed<P, C>(producer: P, consumer: C) -> C::Result
187where
188 P: UnindexedProducer,
189 C: UnindexedConsumer<P::Item>,
190{
191 global::stat::increase_parallel_task_count();
192 let splitter = Splitter::new().unwrap();
193 bridge_unindexed_producer_consumer(false, splitter, producer, consumer)
194}
195
196fn bridge_unindexed_producer_consumer<P, C>(
198 migrated: bool,
199 mut splitter: Splitter,
200 producer: P,
201 consumer: C,
202) -> C::Result
203where
204 P: UnindexedProducer,
205 C: UnindexedConsumer<P::Item>,
206{
207 if consumer.full() {
208 consumer.into_folder().complete()
209 } else if splitter.try_split(migrated) {
210 match producer.split() {
211 (l_producer, Some(r_producer)) => {
212 let (reducer, l_consumer, r_consumer) =
213 (consumer.to_reducer(), consumer.split_off_left(), consumer);
214 let bridge = bridge_unindexed_producer_consumer;
215 let (l_result, r_result) = join_context(
216 |f_cx: FnContext| bridge(f_cx.migrated, splitter, l_producer, l_consumer),
217 |f_cx: FnContext| bridge(f_cx.migrated, splitter, r_producer, r_consumer),
218 );
219 reducer.reduce(l_result, r_result)
220 }
221 (producer, None) => producer.fold_with(consumer.into_folder()).complete(),
222 }
223 } else {
224 producer.fold_with(consumer.into_folder()).complete()
225 }
226}
227
/// Runs `l_f` on the current thread while exposing `r_f` as a task other workers may pick up.
/// Returns `(result of l_f, result of r_f)`.
///
/// Protocol, as implemented below:
/// 1. `r_f` is wrapped in a `ParTaskHolder` that lives in this stack frame.
///    - The holder is turned into a `ParTask` and pushed with `push_parallel_task`.
///    - `notify_one` is then called on `signal().sub()`.
/// 2. `l_f` runs inline with `FnContext::NOT_MIGRATED`.
/// 3. If `pop_local` returns a task (debug-asserted to be the right task), the right task is
///    executed inline with `FnContext::NOT_MIGRATED`. Otherwise this thread repeatedly calls
///    `search()` / `work()` until `r_holder.is_executed()` becomes true.
/// 4. The right outcome is read from `r_holder`; an `Err(payload)` is re-raised with
///    `resume_unwind`.
///
/// # Panics
///
/// - Propagates a panic from `l_f`.
/// - Propagates any `Err` payload stored in the holder (presumably a panic captured from `r_f`).
/// - On non-wasm32 targets, a panic from `l_f` is caught first. The right task is then taken
///   back with `pop_local`, or spin-waited for with `yield_now` if that returns `None`, before
///   unwinding resumes.
///
/// # Caller contract (not checked here)
///
/// `SUB_CONTEXT` must hold a valid, non-dangling pointer on the calling thread.
/// - In this module, the top-level call of each bridge is preceded by a successful
///   `Splitter::new` / `LengthSplitter::new`, which checks `is_dangling()`.
/// - Nested calls made from tasks running on other workers assume those workers also have a
///   sub-context. TODO confirm.
fn join_context<L, R, Lr, Rr>(l_f: L, r_f: R) -> (Lr, Rr)
where
    L: FnOnce(FnContext) -> Lr + Send,
    R: FnOnce(FnContext) -> Rr + Send,
    Lr: Send,
    Rr: Send,
{
    // SAFETY (assumed): see the caller contract above; the pointer is expected to be
    // non-dangling. This is not re-checked here.
    let cx = unsafe { SUB_CONTEXT.get().as_ref() };

    // `r_holder` owns `r_f` on this stack frame, and `r_task` is built from `&r_holder`.
    // Presumably the task keeps a raw pointer to the holder. That is why every exit path below
    // first reclaims the task or waits for `is_executed()` before the holder can go out of scope.
    let r_holder = ParTaskHolder::new(r_f);
    let r_task = unsafe { ParTask::new(&r_holder) };
    let r_task_id = TaskId::Parallel(r_task);

    // Publish the right task on this worker's queue and wake one waiter so that it can
    // presumably be picked up by a sibling worker.
    cx.get_comm().push_parallel_task(r_task);
    cx.get_comm().signal().sub().notify_one();

    #[cfg(not(target_arch = "wasm32"))]
    let l_res = {
        let executor = std::panic::AssertUnwindSafe(move || l_f(FnContext::NOT_MIGRATED));
        match std::panic::catch_unwind(executor) {
            Ok(l_res) => l_res,
            Err(payload) => {
                // Before unwinding past `r_holder`, make sure the right task is no longer in
                // flight: either take it back from the local queue (it is then never executed)
                // or spin until whoever took it has finished running it.
                if let Some(task) = cx.get_comm().pop_local() {
                    debug_assert_eq!(task.id(), r_task_id);
                } else {
                    while !r_holder.is_executed() {
                        std::thread::yield_now();
                    }
                }
                std::panic::resume_unwind(payload);
            }
        }
    };

    // No `catch_unwind` on wasm32.
    // NOTE(review): a panic in `l_f` would unwind past `r_holder` while the right task may still
    // be queued. Presumably this is acceptable because wasm32 builds abort on panic; confirm.
    #[cfg(target_arch = "wasm32")]
    let l_res = l_f(FnContext::NOT_MIGRATED);

    if let Some(task) = cx.get_comm().pop_local() {
        // The right task was still on our local queue: run it here, not migrated.
        debug_assert_eq!(task.id(), r_task_id);
        let wid = cx.get_comm().worker_id();
        r_task.execute(wid, FnContext::NOT_MIGRATED);
    } else {
        // Presumably taken by another worker: keep this thread busy with other work until the
        // right task reports completion.
        while !r_holder.is_executed() {
            let mut steal = cx.get_comm().search();
            cx.work(&mut steal);
        }
    }

    // At this point the right task has either been executed inline above or observed as
    // executed, which is what the unchecked read relies on.
    match unsafe { r_holder.return_or_panic_unchecked() } {
        Ok(r_res) => (l_res, r_res),
        Err(payload) => std::panic::resume_unwind(payload),
    }
}
302
303pub trait IntoEcsPar: ParallelIterator {
307 #[inline]
312 fn into_ecs_par(self) -> EcsPar<Self> {
313 EcsPar(self)
314 }
315
316 #[doc(hidden)]
320 fn drive_unindexed<C>(self, consumer: C) -> C::Result
321 where
322 C: UnindexedConsumer<Self::Item>;
323}
324
325impl<I: IndexedParallelIterator> IntoEcsPar for I {
326 #[inline]
327 fn drive_unindexed<C>(self, consumer: C) -> C::Result
328 where
329 C: UnindexedConsumer<Self::Item>,
330 {
331 bridge(self, consumer)
333 }
334}
335
336#[derive(Clone)]
349#[repr(transparent)]
350pub struct EcsPar<I>(pub I);
351
352impl<I: IntoEcsPar> ParallelIterator for EcsPar<I> {
353 type Item = I::Item;
354
355 fn drive_unindexed<C>(self, consumer: C) -> C::Result
356 where
357 C: UnindexedConsumer<Self::Item>,
358 {
359 IntoEcsPar::drive_unindexed(self.0, consumer)
361 }
362}
363
364impl<I> IndexedParallelIterator for EcsPar<I>
365where
366 I: IntoEcsPar + IndexedParallelIterator,
367{
368 #[inline]
369 fn len(&self) -> usize {
370 self.0.len()
371 }
372
373 #[inline]
374 fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
375 bridge(self, consumer)
377 }
378
379 #[inline]
380 fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
381 self.0.with_producer(callback)
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::iter::IntoParallelIterator;

    /// Compile-level check that several concrete rayon indexed iterators (array, range,
    /// slice, zip) can be wrapped with `into_ecs_par`.
    #[test]
    fn test_into_ecs_par() {
        let array_iter: rayon::array::IntoIter<i32, 2> = [0, 1].into_par_iter();
        let _wrapped_array = array_iter.into_ecs_par();

        let range_iter: rayon::range::Iter<i32> = (0..2).into_par_iter();
        let _wrapped_range = range_iter.into_ecs_par();

        let slice_iter: rayon::slice::Iter<'_, i32> = [0, 1][..].into_par_iter();
        let _wrapped_slice = slice_iter.into_ecs_par();

        let lhs: rayon::range::Iter<i32> = (0..2).into_par_iter();
        let rhs: rayon::range::Iter<i32> = (0..2).into_par_iter();
        let zipped = lhs.zip(rhs);
        let _wrapped_zip = zipped.into_ecs_par();
    }
}
408}