my_ecs/ecs/sched/
par.rs

1use super::{
2    ctrl::SUB_CONTEXT,
3    task::{ParTask, ParTaskHolder, TaskId},
4};
5use crate::global;
6use rayon::iter::{
7    plumbing::{
8        Consumer, Folder, Producer, ProducerCallback, Reducer, UnindexedConsumer, UnindexedProducer,
9    },
10    IndexedParallelIterator, ParallelIterator,
11};
12
13// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/rayon-core/src/lib.rs#L851
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15#[repr(transparent)]
16pub(crate) struct FnContext {
17    migrated: bool,
18}
19
20impl FnContext {
21    pub(super) const MIGRATED: Self = Self { migrated: true };
22    pub(super) const NOT_MIGRATED: Self = Self { migrated: false };
23}
24
25// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/src/iter/plumbing/mod.rs#L255C1-L255C18
26#[derive(Debug, Clone, Copy)]
27#[repr(transparent)]
28struct Splitter {
29    splits: usize,
30}
31
32impl Splitter {
33    fn new() -> Option<Self> {
34        let ptr = SUB_CONTEXT.get();
35        (!ptr.is_dangling()).then(|| {
36            // Safety: `ptr` is a valid pointer.
37            Self {
38                splits: unsafe { ptr.as_ref().get_comm().num_siblings() },
39            }
40        })
41    }
42
43    fn try_split(&mut self, migrated: bool) -> bool {
44        // rayon implementation.
45        //
46        // If a task is migrated, i.e. took it from another worker's local queue,
47        // that could mean that some workers are hungry for tasks.
48        // Therefore, it would be good for keeping processors busy
49        // to split this task into enough small pieces.
50        // if migrated {
51        //     let num_workers = unsafe { SUB_CONTEXT.get().as_ref().siblings.len() };
52        //     self.splits = num_workers.max(self.splits / 2);
53        //     true
54        // } else {
55        //     let res = self.splits > 0;
56        //     self.splits /= 2;
57        //     res
58        // }
59
60        // my implementation.
61        //
62        // Currently, I need to optimize work distribution.
63        // Meanwhile, this could hide poor performance.
64        if !migrated {
65            self.splits /= 2;
66        }
67        self.splits > 0
68    }
69}
70
71// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/src/iter/plumbing/mod.rs#L293
72#[derive(Debug, Clone, Copy)]
73struct LengthSplitter {
74    inner: Splitter,
75    min: usize,
76}
77
78impl LengthSplitter {
79    fn new(min: usize, max: usize, len: usize) -> Option<Self> {
80        let mut inner = Splitter::new()?;
81        let min_splits = len / max.max(1);
82        inner.splits = inner.splits.max(min_splits);
83
84        Some(Self {
85            inner,
86            min: min.max(1),
87        })
88    }
89
90    fn try_split(&mut self, len: usize, migrated: bool) -> bool {
91        len / 2 >= self.min && self.inner.try_split(migrated)
92    }
93}
94
95// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/src/iter/plumbing/mod.rs#L350
96fn bridge<I, C>(par_iter: I, consumer: C) -> C::Result
97where
98    I: IndexedParallelIterator,
99    C: Consumer<I::Item>,
100{
101    global::stat::increase_parallel_task_count();
102
103    let len = par_iter.len();
104    return par_iter.with_producer(Callback { len, consumer });
105
106    struct Callback<C> {
107        len: usize,
108        consumer: C,
109    }
110
111    impl<C, I> ProducerCallback<I> for Callback<C>
112    where
113        C: Consumer<I>,
114    {
115        type Output = C::Result;
116        fn callback<P>(self, producer: P) -> C::Result
117        where
118            P: Producer<Item = I>,
119        {
120            bridge_producer_consumer(self.len, producer, self.consumer)
121        }
122    }
123}
124
125// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/src/iter/plumbing/mod.rs#L390
126fn bridge_producer_consumer<P, C>(len: usize, producer: P, consumer: C) -> C::Result
127where
128    P: Producer,
129    C: Consumer<P::Item>,
130{
131    const MIGRATED: bool = false;
132    let res =
133        if let Some(splitter) = LengthSplitter::new(producer.min_len(), producer.max_len(), len) {
134            helper(len, MIGRATED, splitter, producer, consumer)
135        } else {
136            helper_no_split(producer, consumer)
137        };
138    return res;
139
140    // === Internal helper functions ===
141
142    fn helper<P, C>(
143        len: usize,
144        migrated: bool,
145        mut splitter: LengthSplitter,
146        producer: P,
147        consumer: C,
148    ) -> C::Result
149    where
150        P: Producer,
151        C: Consumer<P::Item>,
152    {
153        if consumer.full() {
154            consumer.into_folder().complete()
155        } else if splitter.try_split(len, migrated) {
156            let mid = len / 2;
157            let (l_producer, r_producer) = producer.split_at(mid);
158            let (l_consumer, r_consumer, reducer) = consumer.split_at(mid);
159            let (l_result, r_result) = join_context(
160                |f_cx: FnContext| helper(mid, f_cx.migrated, splitter, l_producer, l_consumer),
161                |f_cx: FnContext| {
162                    helper(len - mid, f_cx.migrated, splitter, r_producer, r_consumer)
163                },
164            );
165            reducer.reduce(l_result, r_result)
166        } else {
167            producer.fold_with(consumer.into_folder()).complete()
168        }
169    }
170
171    fn helper_no_split<P, C>(producer: P, consumer: C) -> C::Result
172    where
173        P: Producer,
174        C: Consumer<P::Item>,
175    {
176        if consumer.full() {
177            consumer.into_folder().complete()
178        } else {
179            producer.fold_with(consumer.into_folder()).complete()
180        }
181    }
182}
183
184// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/src/iter/plumbing/mod.rs#L445
185#[allow(dead_code)] // For future use
186fn bridge_unindexed<P, C>(producer: P, consumer: C) -> C::Result
187where
188    P: UnindexedProducer,
189    C: UnindexedConsumer<P::Item>,
190{
191    global::stat::increase_parallel_task_count();
192    let splitter = Splitter::new().unwrap();
193    bridge_unindexed_producer_consumer(false, splitter, producer, consumer)
194}
195
196// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/src/iter/plumbing/mod.rs#L454
197fn bridge_unindexed_producer_consumer<P, C>(
198    migrated: bool,
199    mut splitter: Splitter,
200    producer: P,
201    consumer: C,
202) -> C::Result
203where
204    P: UnindexedProducer,
205    C: UnindexedConsumer<P::Item>,
206{
207    if consumer.full() {
208        consumer.into_folder().complete()
209    } else if splitter.try_split(migrated) {
210        match producer.split() {
211            (l_producer, Some(r_producer)) => {
212                let (reducer, l_consumer, r_consumer) =
213                    (consumer.to_reducer(), consumer.split_off_left(), consumer);
214                let bridge = bridge_unindexed_producer_consumer;
215                let (l_result, r_result) = join_context(
216                    |f_cx: FnContext| bridge(f_cx.migrated, splitter, l_producer, l_consumer),
217                    |f_cx: FnContext| bridge(f_cx.migrated, splitter, r_producer, r_consumer),
218                );
219                reducer.reduce(l_result, r_result)
220            }
221            (producer, None) => producer.fold_with(consumer.into_folder()).complete(),
222        }
223    } else {
224        producer.fold_with(consumer.into_folder()).complete()
225    }
226}
227
228// ref: https://github.com/rayon-rs/rayon/blob/7543ed40c9a017dee32b3dc72b3ae819820e8366/rayon-core/src/join/mod.rs#L115
229fn join_context<L, R, Lr, Rr>(l_f: L, r_f: R) -> (Lr, Rr)
230where
231    L: FnOnce(FnContext) -> Lr + Send,
232    R: FnOnce(FnContext) -> Rr + Send,
233    Lr: Send,
234    Rr: Send,
235{
236    let cx = unsafe { SUB_CONTEXT.get().as_ref() };
237
238    let r_holder = ParTaskHolder::new(r_f);
239    let r_task = unsafe { ParTask::new(&r_holder) };
240    let r_task_id = TaskId::Parallel(r_task);
241
242    // Puts 'Right task' into local queue then notifies it to another worker.
243    //
244    // TODO: Optimize.
245    // compare to rayon, too many steal operations take place.
246    // The more unnecessary steal operations, the poorer performance.
247    // I guess this frequent notification is one of the reasons.
248    // Anyway, I mitigated it by reducing split count.
249    // See `Splitter::try_split`.
250    cx.get_comm().push_parallel_task(r_task);
251    cx.get_comm().signal().sub().notify_one();
252
253    // Executes 'Left task'.
254    #[cfg(not(target_arch = "wasm32"))]
255    let l_res = {
256        let executor = std::panic::AssertUnwindSafe(move || l_f(FnContext::NOT_MIGRATED));
257        match std::panic::catch_unwind(executor) {
258            Ok(l_res) => l_res,
259            Err(payload) => {
260                // Panicked in `Left task`.
261                // But we need to hold `Right task` until it's finished
262                // if it was stolen by another worker.
263                if let Some(task) = cx.get_comm().pop_local() {
264                    debug_assert_eq!(task.id(), r_task_id);
265                } else {
266                    while !r_holder.is_executed() {
267                        std::thread::yield_now();
268                    }
269                }
270                std::panic::resume_unwind(payload);
271            }
272        }
273    };
274
275    // In web, we don't have a way to hold `Right task` when `Left task` panics.
276    #[cfg(target_arch = "wasm32")]
277    let l_res = l_f(FnContext::NOT_MIGRATED);
278
279    // If we could find a task from the local queue, it must be 'Right task',
280    // because the queue is LIFO fashion.
281    if let Some(task) = cx.get_comm().pop_local() {
282        debug_assert_eq!(task.id(), r_task_id);
283        let wid = cx.get_comm().worker_id();
284        r_task.execute(wid, FnContext::NOT_MIGRATED);
285    } else {
286        // We couldn't find a task from local queue.
287        // That means that another worker has stolen 'Right task'.
288        // While we wait for 'Right task' to be finished by the another worker,
289        // steals some tasks and executes them.
290        while !r_holder.is_executed() {
291            let mut steal = cx.get_comm().search();
292            cx.work(&mut steal);
293            // TODO: Busy waiting if it failed to steal tasks.
294        }
295    }
296
297    match unsafe { r_holder.return_or_panic_unchecked() } {
298        Ok(r_res) => (l_res, r_res),
299        Err(payload) => std::panic::resume_unwind(payload),
300    }
301}
302
303/// A trait for wrapping Rayon's parallel iterators in [`EcsPar`] in order to
304/// intercept function call to a Rayon API then to execute them in the ECS
305/// context.
306pub trait IntoEcsPar: ParallelIterator {
307    /// Wraps Rayon's parallel iterator in [`EcsPar`].
308    ///
309    /// `EcsPar` calls an ECS function to make use of ECS workers instead of
310    /// Rayon's workers.
311    #[inline]
312    fn into_ecs_par(self) -> EcsPar<Self> {
313        EcsPar(self)
314    }
315
316    /// Implementations must call [`bridge`] or [`bridge_unindexed`] instead of
317    /// [`rayon::iter::plumbing::bridge`] or
318    /// [`rayon::iter::plumbing::bridge_unindexed`].
319    #[doc(hidden)]
320    fn drive_unindexed<C>(self, consumer: C) -> C::Result
321    where
322        C: UnindexedConsumer<Self::Item>;
323}
324
325impl<I: IndexedParallelIterator> IntoEcsPar for I {
326    #[inline]
327    fn drive_unindexed<C>(self, consumer: C) -> C::Result
328    where
329        C: UnindexedConsumer<Self::Item>,
330    {
331        // Intercepts.
332        bridge(self, consumer)
333    }
334}
335
336/// A wrapper type of Rayon's parallel iterator.
337///
338/// Rayon's parallel iterator basically uses its own worker registry. It means
339/// that Rayon will spawn new workers regardless of living workers in ECS
340/// instance, which is a behavior you may not want. To use workers of ECS
341/// instance instead, just wrap the Rayon's parallel iterator in this wrapper.
342/// Then, this wrapper will intercept calls to Rayon's functions.
343///
344/// # Limitation
345///
346/// This wrapper currently requires implmentation of
347/// [`IndexedParallelIterator`].
348#[derive(Clone)]
349#[repr(transparent)]
350pub struct EcsPar<I>(pub I);
351
352impl<I: IntoEcsPar> ParallelIterator for EcsPar<I> {
353    type Item = I::Item;
354
355    fn drive_unindexed<C>(self, consumer: C) -> C::Result
356    where
357        C: UnindexedConsumer<Self::Item>,
358    {
359        // Intercepts
360        IntoEcsPar::drive_unindexed(self.0, consumer)
361    }
362}
363
364impl<I> IndexedParallelIterator for EcsPar<I>
365where
366    I: IntoEcsPar + IndexedParallelIterator,
367{
368    #[inline]
369    fn len(&self) -> usize {
370        self.0.len()
371    }
372
373    #[inline]
374    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
375        // Intercepts
376        bridge(self, consumer)
377    }
378
379    #[inline]
380    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
381        self.0.with_producer(callback)
382    }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::iter::IntoParallelIterator;

    /// Checks that several kinds of rayon parallel iterators can be wrapped
    /// in the interceptor. This only has to compile; nothing is driven.
    #[test]
    fn test_into_ecs_par() {
        // Array
        let array_iter: rayon::array::IntoIter<i32, 2> = [0, 1].into_par_iter();
        let _wrapped = array_iter.into_ecs_par();

        // Range
        let range_iter: rayon::range::Iter<i32> = (0..2).into_par_iter();
        let _wrapped = range_iter.into_ecs_par();

        // Slice
        let slice_iter: rayon::slice::Iter<'_, i32> = [0, 1][..].into_par_iter();
        let _wrapped = slice_iter.into_ecs_par();

        // Zip
        let lhs: rayon::range::Iter<i32> = (0..2).into_par_iter();
        let rhs: rayon::range::Iter<i32> = (0..2).into_par_iter();
        let _wrapped = lhs.zip(rhs).into_ecs_par();
    }
}