// ndarray_parallel/par.rs

2use rayon::iter::ParallelIterator;
3use rayon::iter::IndexedParallelIterator;
4use rayon::iter::plumbing::{Consumer, UnindexedConsumer};
5use rayon::iter::plumbing::bridge;
6use rayon::iter::plumbing::ProducerCallback;
7use rayon::iter::plumbing::Producer;
8use rayon::iter::plumbing::UnindexedProducer;
9use rayon::iter::plumbing::bridge_unindexed;
10use rayon::iter::plumbing::Folder;
11
12use ndarray::iter::AxisIter;
13use ndarray::iter::AxisIterMut;
14use ndarray::{Dimension};
15use ndarray::{ArrayView, ArrayViewMut};
16
17use super::NdarrayIntoParallelIterator;
18
/// Parallel iterator wrapper.
///
/// `Parallel<I>` wraps one of the ndarray axis iterators, array views, or
/// `Zip` values that this module implements `NdarrayIntoParallelIterator`
/// for, and exposes it as a rayon `ParallelIterator`. Creating the wrapper
/// performs no work: elements are only produced once it is driven by a
/// rayon consumer, so an unused `Parallel` is almost certainly a mistake.
#[must_use = "parallel iterators are lazy and do nothing unless consumed"]
#[derive(Copy, Clone, Debug)]
pub struct Parallel<I> {
    // The wrapped sequential iterator / view / `Zip`.
    iter: I,
}
24
/// Internal rayon producer holding one piece of a parallel computation.
///
/// The wrapped value is the axis iterator, array view, or `Zip` covering
/// that piece; splitting the producer splits the wrapped value in two.
#[derive(Clone, Copy, Debug)]
struct ParallelProducer<I>(I);
28
29macro_rules! par_iter_wrapper {
30    // thread_bounds are either Sync or Send + Sync
31    ($iter_name:ident, [$($thread_bounds:tt)*]) => {
32    impl<'a, A, D> NdarrayIntoParallelIterator for $iter_name<'a, A, D>
33        where D: Dimension,
34              A: $($thread_bounds)*,
35    {
36        type Item = <Self as Iterator>::Item;
37        type Iter = Parallel<Self>;
38        fn into_par_iter(self) -> Self::Iter {
39            Parallel {
40                iter: self,
41            }
42        }
43    }
44
45    impl<'a, A, D> ParallelIterator for Parallel<$iter_name<'a, A, D>>
46        where D: Dimension,
47              A: $($thread_bounds)*,
48    {
49        type Item = <$iter_name<'a, A, D> as Iterator>::Item;
50        fn drive_unindexed<C>(self, consumer: C) -> C::Result
51            where C: UnindexedConsumer<Self::Item>
52        {
53            bridge(self, consumer)
54        }
55
56        fn opt_len(&self) -> Option<usize> {
57            Some(self.iter.len())
58        }
59    }
60
61    impl<'a, A, D> IndexedParallelIterator for Parallel<$iter_name<'a, A, D>>
62        where D: Dimension,
63              A: $($thread_bounds)*,
64    {
65        fn with_producer<Cb>(self, callback: Cb) -> Cb::Output
66            where Cb: ProducerCallback<Self::Item>
67        {
68            callback.callback(ParallelProducer(self.iter))
69        }
70
71        fn len(&self) -> usize {
72            ExactSizeIterator::len(&self.iter)
73        }
74
75        fn drive<C>(self, consumer: C) -> C::Result
76            where C: Consumer<Self::Item>
77        {
78            bridge(self, consumer)
79        }
80    }
81
82    impl<'a, A, D> IntoIterator for ParallelProducer<$iter_name<'a, A, D>>
83        where D: Dimension,
84    {
85        type IntoIter = $iter_name<'a, A, D>;
86        type Item = <Self::IntoIter as Iterator>::Item;
87
88        fn into_iter(self) -> Self::IntoIter {
89            self.0
90        }
91    }
92
93    // This is the real magic, I guess
94    impl<'a, A, D> Producer for ParallelProducer<$iter_name<'a, A, D>>
95        where D: Dimension,
96              A: $($thread_bounds)*,
97    {
98        type IntoIter = $iter_name<'a, A, D>;
99        type Item = <Self::IntoIter as Iterator>::Item;
100
101        fn into_iter(self) -> Self::IntoIter {
102            self.0
103        }
104
105        fn split_at(self, i: usize) -> (Self, Self) {
106            let (a, b) = self.0.split_at(i);
107            (ParallelProducer(a), ParallelProducer(b))
108        }
109    }
110
111    }
112}
113
114
115par_iter_wrapper!(AxisIter, [Sync]);
116par_iter_wrapper!(AxisIterMut, [Send + Sync]);
117
118
119
120macro_rules! par_iter_view_wrapper {
121    // thread_bounds are either Sync or Send + Sync
122    ($view_name:ident, [$($thread_bounds:tt)*]) => {
123    impl<'a, A, D> NdarrayIntoParallelIterator for $view_name<'a, A, D>
124        where D: Dimension,
125              A: $($thread_bounds)*,
126    {
127        type Item = <Self as IntoIterator>::Item;
128        type Iter = Parallel<Self>;
129        fn into_par_iter(self) -> Self::Iter {
130            Parallel {
131                iter: self,
132            }
133        }
134    }
135
136
137    impl<'a, A, D> ParallelIterator for Parallel<$view_name<'a, A, D>>
138        where D: Dimension,
139              A: $($thread_bounds)*,
140    {
141        type Item = <$view_name<'a, A, D> as IntoIterator>::Item;
142        fn drive_unindexed<C>(self, consumer: C) -> C::Result
143            where C: UnindexedConsumer<Self::Item>
144        {
145            bridge_unindexed(ParallelProducer(self.iter), consumer)
146        }
147
148        fn opt_len(&self) -> Option<usize> {
149            None
150        }
151    }
152
153    impl<'a, A, D> UnindexedProducer for ParallelProducer<$view_name<'a, A, D>>
154        where D: Dimension,
155              A: $($thread_bounds)*,
156    {
157        type Item = <$view_name<'a, A, D> as IntoIterator>::Item;
158        fn split(self) -> (Self, Option<Self>) {
159            if self.0.len() <= 1 {
160                return (self, None)
161            }
162            let array = self.0;
163            let max_axis = array.max_stride_axis();
164            let mid = array.len_of(max_axis) / 2;
165            let (a, b) = array.split_at(max_axis, mid);
166            (ParallelProducer(a), Some(ParallelProducer(b)))
167        }
168
169        fn fold_with<F>(self, folder: F) -> F
170            where F: Folder<Self::Item>,
171        {
172            self.into_iter().fold(folder, move |f, elt| f.consume(elt))
173        }
174    }
175
176    impl<'a, A, D> IntoIterator for ParallelProducer<$view_name<'a, A, D>>
177        where D: Dimension,
178              A: $($thread_bounds)*,
179    {
180        type Item = <$view_name<'a, A, D> as IntoIterator>::Item;
181        type IntoIter = <$view_name<'a, A, D> as IntoIterator>::IntoIter;
182        fn into_iter(self) -> Self::IntoIter {
183            self.0.into_iter()
184        }
185    }
186
187    }
188}
189
190par_iter_view_wrapper!(ArrayView, [Sync]);
191par_iter_view_wrapper!(ArrayViewMut, [Sync + Send]);
192
193
194use ndarray::{Zip, NdProducer, FoldWhile};
195
196macro_rules! zip_impl {
197    ($([$($p:ident)*],)+) => {
198        $(
199        #[allow(non_snake_case)]
200        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> NdarrayIntoParallelIterator for Zip<($($p,)*), Dim>
201            where $($p::Item : Send , )*
202                  $($p : Send , )*
203        {
204            type Item = ($($p::Item ,)*);
205            type Iter = Parallel<Self>;
206            fn into_par_iter(self) -> Self::Iter {
207                Parallel {
208                    iter: self,
209                }
210            }
211        }
212
213        #[allow(non_snake_case)]
214        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ParallelIterator for Parallel<Zip<($($p,)*), Dim>>
215            where $($p::Item : Send , )*
216                  $($p : Send , )*
217        {
218            type Item = ($($p::Item ,)*);
219
220            fn drive_unindexed<Cons>(self, consumer: Cons) -> Cons::Result
221                where Cons: UnindexedConsumer<Self::Item>
222            {
223                bridge_unindexed(ParallelProducer(self.iter), consumer)
224            }
225
226            fn opt_len(&self) -> Option<usize> {
227                None
228            }
229        }
230
231        #[allow(non_snake_case)]
232        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> UnindexedProducer for ParallelProducer<Zip<($($p,)*), Dim>>
233            where $($p : Send , )*
234                  $($p::Item : Send , )*
235        {
236            type Item = ($($p::Item ,)*);
237
238            fn split(self) -> (Self, Option<Self>) {
239                if self.0.size() <= 1 {
240                    return (self, None)
241                }
242                let (a, b) = self.0.split();
243                (ParallelProducer(a), Some(ParallelProducer(b)))
244            }
245
246            fn fold_with<Fold>(self, folder: Fold) -> Fold
247                where Fold: Folder<Self::Item>,
248            {
249                self.0.fold_while(folder, |mut folder, $($p),*| {
250                    folder = folder.consume(($($p ,)*));
251                    if folder.full() {
252                        FoldWhile::Done(folder)
253                    } else {
254                        FoldWhile::Continue(folder)
255                    }
256                }).into_inner()
257            }
258        }
259        )+
260    }
261}
262
263zip_impl!{
264    [P1],
265    [P1 P2],
266    [P1 P2 P3],
267    [P1 P2 P3 P4],
268    [P1 P2 P3 P4 P5],
269    [P1 P2 P3 P4 P5 P6],
270}