ndarray/parallel/impl_par_methods.rs

use crate::AssignElem;
use crate::{Array, ArrayRef, Dimension, IntoNdProducer, NdProducer, Zip};

use super::send_producer::SendProducer;
use crate::parallel::par::ParallelSplits;
use crate::parallel::prelude::*;

use crate::partial::Partial;

/// # Parallel methods
impl<A, D> ArrayRef<A, D>
where
    D: Dimension,
    A: Send + Sync,
{
    /// Parallel version of `map_inplace`.
    ///
    /// Modify the array in place by calling `f` by mutable reference on each element.
    ///
    /// Elements are visited in arbitrary order.
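    ///
    /// ## Examples
    ///
    /// A minimal sketch, assuming crate feature `rayon` is enabled:
    ///
    /// ```rust
    /// use ndarray::Array2;
    ///
    /// let mut a = Array2::<f64>::ones((64, 64));
    /// // Double each element through a mutable reference, in parallel.
    /// a.par_map_inplace(|x| *x *= 2.0);
    /// assert!(a.iter().all(|&x| x == 2.0));
    /// ```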
    pub fn par_map_inplace<F>(&mut self, f: F)
    where F: Fn(&mut A) + Sync + Send
    {
        self.view_mut().into_par_iter().for_each(f)
    }

    /// Parallel version of `mapv_inplace`.
    ///
    /// Modify the array in place by calling `f` by **v**alue on each element.
    /// The array is updated with the new values.
    ///
    /// Elements are visited in arbitrary order.
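    ///
    /// ## Examples
    ///
    /// A minimal sketch, assuming crate feature `rayon` is enabled:
    ///
    /// ```rust
    /// use ndarray::Array2;
    ///
    /// let mut a = Array2::from_elem((64, 64), 2.0);
    /// // Square each element by value, in parallel.
    /// a.par_mapv_inplace(|x| x * x);
    /// assert!(a.iter().all(|&x| x == 4.0));
    /// ```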
    pub fn par_mapv_inplace<F>(&mut self, f: F)
    where
        F: Fn(A) -> A + Sync + Send,
        A: Clone,
    {
        self.view_mut()
            .into_par_iter()
            .for_each(move |x| *x = f(x.clone()))
    }
}

// Zip

const COLLECT_MAX_SPLITS: usize = 10;

macro_rules! zip_impl {
    ($([$notlast:ident $($p:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<D, $($p),*> Zip<($($p,)*), D>
            where $($p::Item : Send , )*
                  $($p : Send , )*
                  D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            /// The `par_for_each` method for `Zip`.
            ///
            /// This is a shorthand for using `.into_par_iter().for_each()` on
            /// `Zip`.
            ///
            /// Requires crate feature `rayon`.
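            ///
            /// ## Examples
            ///
            /// A minimal sketch, assuming crate feature `rayon` is enabled:
            ///
            /// ```rust
            /// use ndarray::{Array2, Zip};
            ///
            /// let a = Array2::<f64>::ones((32, 32));
            /// let b = Array2::<f64>::ones((32, 32));
            /// let mut sum = Array2::<f64>::zeros((32, 32));
            /// // Apply the closure to each lock-step triple of elements, in parallel.
            /// Zip::from(&mut sum).and(&a).and(&b).par_for_each(|s, &a, &b| *s = a + b);
            /// assert!(sum.iter().all(|&x| x == 2.0));
            /// ```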
            pub fn par_for_each<F>(self, function: F)
                where F: Fn($($p::Item),*) + Sync + Send
            {
                self.into_par_iter().for_each(move |($($p,)*)| function($($p),*))
            }

            expand_if!(@bool [$notlast]

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c-order, or all are f-order, that layout is preserved in the output.
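            ///
            /// ## Examples
            ///
            /// A minimal sketch, assuming crate feature `rayon` is enabled:
            ///
            /// ```rust
            /// use ndarray::{Array2, Zip};
            ///
            /// let a = Array2::<f64>::ones((32, 32));
            /// let b = Array2::<f64>::ones((32, 32));
            /// // Compute the element-wise sum in parallel and collect it into a new array.
            /// let sum = Zip::from(&a).and(&b).par_map_collect(|&a, &b| a + b);
            /// assert!(sum.iter().all(|&x| x == 2.0));
            /// ```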
            pub fn par_map_collect<R>(self, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
                -> Array<R, D>
                where R: Send
            {
                let mut output = self.uninitialized_for_current_layout::<R>();
                let total_len = output.len();

                // Create a parallel iterator that produces chunks of the zip with the output
                // array. It's crucial that both parts split in the same way, and in a way
                // that keeps the chunks of the output contiguous.
                //
                // Use a raw view so that we can alias the output data here and in the partial
                // result.
                let splits = unsafe {
                    ParallelSplits {
                        iter: self.and(SendProducer::new(output.raw_view_mut().cast::<R>())),
                        // Keep it from splitting the Zip down too small
                        max_splits: COLLECT_MAX_SPLITS,
                    }
                };

                let collect_result = splits.map(move |zip| {
                    // Apply the mapping function on this chunk of the zip
                    // Create a partial result for the contiguous slice of data being written to
                    unsafe {
                        zip.collect_with_partial(&f)
                    }
                })
                .reduce(Partial::stub, Partial::try_merge);

                if std::mem::needs_drop::<R>() {
                    debug_assert_eq!(total_len, collect_result.len,
                        "collect len is not correct, expected {}", total_len);
                    assert!(collect_result.len == total_len,
                        "Collect: Expected number of writes not completed");
                }

                // Here the collect result is complete, and we release its ownership and transfer
                // it to the output array.
                collect_result.release_ownership();
                unsafe {
                    output.assume_init()
                }
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
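            ///
            /// ## Examples
            ///
            /// A minimal sketch, assuming crate feature `rayon` is enabled:
            ///
            /// ```rust
            /// use ndarray::{Array2, Zip};
            ///
            /// let a = Array2::<f64>::ones((32, 32));
            /// let b = Array2::<f64>::ones((32, 32));
            /// let mut sum = Array2::<f64>::zeros((32, 32));
            /// // Write the element-wise sum of `a` and `b` into `sum`, in parallel.
            /// Zip::from(&a).and(&b).par_map_assign_into(&mut sum, |&a, &b| a + b);
            /// assert!(sum.iter().all(|&x| x == 2.0));
            /// ```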
            pub fn par_map_assign_into<R, Q>(self, into: Q, f: impl Fn($($p::Item,)* ) -> R + Sync + Send)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R> + Send,
                      Q::Output: Send,
            {
                self.and(into)
                    .par_for_each(move |$($p, )* output_| {
                        output_.assign_elem(f($($p ),*));
                    });
            }

            /// Parallel version of `fold`.
            ///
            /// Splits the producer into multiple tasks which each accumulate a single value
            /// using the `fold` closure. Those tasks are executed in parallel and their results
            /// are then combined into a single value using the `reduce` closure.
            ///
            /// The `identity` closure provides the initial values for each of the tasks and
            /// for the final reduction.
            ///
            /// This is a shorthand for calling `self.into_par_iter().fold(...).reduce(...)`.
            ///
            /// Note that it is often more efficient to parallelize not per element but over
            /// larger chunks of an array, such as generalized rows, operating on each chunk
            /// with a sequential variant of the accumulation.
            /// For example, sum each row sequentially while processing the rows in parallel,
            /// taking advantage of locality and vectorization within each task, and then
            /// reduce the row sums to the sum of the whole matrix.
            ///
            /// Also note that the splitting of the producer into multiple tasks is _not_
            /// deterministic, which needs to be considered when analyzing the accuracy of
            /// such an operation.
            ///
            /// ## Examples
            ///
            /// ```rust
            /// use ndarray::{Array, Zip};
            ///
            /// let a = Array::<usize, _>::ones((128, 1024));
            /// let b = Array::<usize, _>::ones(128);
            ///
            /// let weighted_sum = Zip::from(a.rows()).and(&b).par_fold(
            ///     || 0,
            ///     |sum, row, factor| sum + row.sum() * factor,
            ///     |sum, other_sum| sum + other_sum,
            /// );
            ///
            /// assert_eq!(weighted_sum, a.len());
            /// ```
            pub fn par_fold<ID, F, R, T>(self, identity: ID, fold: F, reduce: R) -> T
            where
                ID: Fn() -> T + Send + Sync + Clone,
                F: Fn(T, $($p::Item),*) -> T + Send + Sync,
                R: Fn(T, T) -> T + Send + Sync,
                T: Send
            {
                self.into_par_iter()
                    .fold(identity.clone(), move |accumulator, ($($p,)*)| {
                        fold(accumulator, $($p),*)
                    })
                    .reduce(identity, reduce)
            }

            );
        }
        )+
    };
}

zip_impl! {
    [true P1],
    [true P1 P2],
    [true P1 P2 P3],
    [true P1 P2 P3 P4],
    [true P1 P2 P3 P4 P5],
    [false P1 P2 P3 P4 P5 P6],
}