rten_tensor/iterators/
parallel.rs

1use rayon::prelude::*;
2use rten_base::iter::{ParIter, SplitIterator};
3
4use super::{
5    AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterBase, InnerIterMut, Iter,
6    IterMut, LaneRanges, Lanes, LanesMut, Offsets, OffsetsKind,
7};
8use crate::layout::RemoveDim;
9use crate::{Layout, MutLayout, Storage};
10
/// Expands to the associated types and method of an [`IntoParallelIterator`]
/// impl whose parallel iterator type is [`ParIter<Self>`](ParIter).
///
/// The implementing type must be an [`Iterator`] (its `Item` is reused) and
/// must be convertible into `ParIter<Self>` via `Into`.
macro_rules! impl_parallel_iterator {
    () => {
        type Item = <Self as Iterator>::Item;
        type Iter = ParIter<Self>;

        fn into_par_iter(self) -> Self::Iter {
            self.into()
        }
    };
}
23
24impl SplitIterator for Offsets {
25    fn split_at(self, index: usize) -> (Self, Self) {
26        assert!(index <= self.len());
27        let (left_kind, right_kind) = match self.base {
28            OffsetsKind::Range(r) => {
29                let left = r.start..r.start + index;
30                let right = r.start + index..r.end;
31                (OffsetsKind::Range(left), OffsetsKind::Range(right))
32            }
33            OffsetsKind::Indexing(base) => {
34                let (left, right) = base.split_at(index);
35                (OffsetsKind::Indexing(left), OffsetsKind::Indexing(right))
36            }
37        };
38        (Offsets { base: left_kind }, Offsets { base: right_kind })
39    }
40}
41
42impl<L: Layout + Clone> SplitIterator for InnerIterBase<L> {
43    fn split_at(self, index: usize) -> (Self, Self) {
44        let (left_offsets, right_offsets) = self.outer_offsets.split_at(index);
45        let left = Self {
46            outer_offsets: left_offsets,
47            inner_layout: self.inner_layout.clone(),
48            inner_data_len: self.inner_data_len,
49        };
50        let right = Self {
51            outer_offsets: right_offsets,
52            inner_layout: self.inner_layout,
53            inner_data_len: self.inner_data_len,
54        };
55        (left, right)
56    }
57}
58
59impl<'a, T, L: MutLayout + Send + Sync> SplitIterator for InnerIter<'a, T, L> {
60    fn split_at(self, index: usize) -> (Self, Self) {
61        let (left_base, right_base) = self.base.split_at(index);
62        let left = Self {
63            base: left_base,
64            data: self.data,
65        };
66        let right = Self {
67            base: right_base,
68            data: self.data,
69        };
70        (left, right)
71    }
72}
73
74impl<'a, T, L: MutLayout + Send + Sync> IntoParallelIterator for InnerIter<'a, T, L> {
75    impl_parallel_iterator!();
76}
77
78impl<'a, T, L: MutLayout + Send + Sync> SplitIterator for InnerIterMut<'a, T, L> {
79    fn split_at(self, index: usize) -> (Self, Self) {
80        let (left_base, right_base) = self.base.split_at(index);
81        let len = self.data.len();
82
83        // The left/right splits use the same storage. We rely on the left/right
84        // layouts being logically disjoint to ensure we don't create multiple
85        // mutable references to the same elements.
86        let (left_data, right_data) = self.data.split_mut(0..len, 0..len);
87
88        let left = Self {
89            base: left_base,
90            data: left_data,
91        };
92        let right = Self {
93            base: right_base,
94            data: right_data,
95        };
96        (left, right)
97    }
98}
99
100impl<'a, T, L: MutLayout + Send + Sync> IntoParallelIterator for InnerIterMut<'a, T, L> {
101    impl_parallel_iterator!();
102}
103
104impl<'a, T, L: MutLayout + RemoveDim> SplitIterator for AxisIter<'a, T, L> {
105    fn split_at(self, index: usize) -> (Self, Self) {
106        let (left_view, right_view) = self.view.split_at(self.axis, index);
107        let left = AxisIter::new(&left_view, self.axis);
108        let right = AxisIter::new(&right_view, self.axis);
109        (left, right)
110    }
111}
112
113impl<'a, T, L: MutLayout + RemoveDim + Send> IntoParallelIterator for AxisIter<'a, T, L>
114where
115    <L as RemoveDim>::Output: Send,
116{
117    impl_parallel_iterator!();
118}
119
120impl<'a, T, L: MutLayout + RemoveDim> SplitIterator for AxisIterMut<'a, T, L> {
121    fn split_at(self, index: usize) -> (Self, Self) {
122        let (left_view, right_view) = self.view.split_at_mut(self.axis, index);
123        let left = AxisIterMut::new(left_view, self.axis);
124        let right = AxisIterMut::new(right_view, self.axis);
125        (left, right)
126    }
127}
128
129impl<'a, T, L: MutLayout + RemoveDim + Send> IntoParallelIterator for AxisIterMut<'a, T, L>
130where
131    <L as RemoveDim>::Output: Send,
132{
133    impl_parallel_iterator!();
134}
135
136impl<'a, T, L: MutLayout> SplitIterator for AxisChunks<'a, T, L> {
137    fn split_at(mut self, index: usize) -> (Self, Self) {
138        let (left_remainder, right_remainder) = if let Some(remainder) = self.remainder.take() {
139            let (l, r) = remainder.split_at(self.axis, self.chunk_size * index);
140            (Some(l), Some(r))
141        } else {
142            (None, None)
143        };
144
145        let left = AxisChunks {
146            remainder: left_remainder,
147            axis: self.axis,
148            chunk_size: self.chunk_size,
149        };
150        let right = AxisChunks {
151            remainder: right_remainder,
152            axis: self.axis,
153            chunk_size: self.chunk_size,
154        };
155
156        (left, right)
157    }
158}
159
160impl<'a, T, L: MutLayout + Send> IntoParallelIterator for AxisChunks<'a, T, L> {
161    impl_parallel_iterator!();
162}
163
164impl<'a, T, L: MutLayout> SplitIterator for AxisChunksMut<'a, T, L> {
165    fn split_at(mut self, index: usize) -> (Self, Self) {
166        let (left_remainder, right_remainder) = if let Some(remainder) = self.remainder.take() {
167            let (l, r) = remainder.split_at_mut(self.axis, self.chunk_size * index);
168            (Some(l), Some(r))
169        } else {
170            (None, None)
171        };
172
173        let left = Self {
174            remainder: left_remainder,
175            axis: self.axis,
176            chunk_size: self.chunk_size,
177        };
178        let right = Self {
179            remainder: right_remainder,
180            axis: self.axis,
181            chunk_size: self.chunk_size,
182        };
183
184        (left, right)
185    }
186}
187
188impl<'a, T, L: MutLayout + Send> IntoParallelIterator for AxisChunksMut<'a, T, L> {
189    impl_parallel_iterator!();
190}
191
192impl<'a, T> SplitIterator for Iter<'a, T> {
193    fn split_at(self, index: usize) -> (Self, Self) {
194        let (left_offsets, right_offsets) = self.offsets.split_at(index);
195        let left = Self {
196            offsets: left_offsets,
197            data: self.data,
198        };
199        let right = Self {
200            offsets: right_offsets,
201            data: self.data,
202        };
203        (left, right)
204    }
205}
206
207impl<'a, T: Sync> IntoParallelIterator for Iter<'a, T> {
208    impl_parallel_iterator!();
209}
210
211impl<'a, T> SplitIterator for IterMut<'a, T> {
212    fn split_at(self, index: usize) -> (Self, Self) {
213        let (left_offsets, right_offsets) = self.offsets.split_at(index);
214        let len = self.data.len();
215        let (left_data, right_data) = self.data.split_mut(0..len, 0..len);
216        let left = Self {
217            offsets: left_offsets,
218            data: left_data,
219        };
220        let right = Self {
221            offsets: right_offsets,
222            data: right_data,
223        };
224        (left, right)
225    }
226}
227
228impl<'a, T: Sync + Send> IntoParallelIterator for IterMut<'a, T> {
229    impl_parallel_iterator!();
230}
231
232impl SplitIterator for LaneRanges {
233    fn split_at(self, index: usize) -> (Self, Self) {
234        let (left_offsets, right_offsets) = self.offsets.split_at(index);
235        let left = LaneRanges {
236            offsets: left_offsets,
237            dim_size: self.dim_size,
238            dim_stride: self.dim_stride,
239        };
240        let right = LaneRanges {
241            offsets: right_offsets,
242            dim_size: self.dim_size,
243            dim_stride: self.dim_stride,
244        };
245        (left, right)
246    }
247}
248
249impl<'a, T> SplitIterator for Lanes<'a, T> {
250    fn split_at(self, index: usize) -> (Self, Self) {
251        let (left_range, right_range) = self.ranges.split_at(index);
252
253        let left = Lanes {
254            data: self.data,
255            ranges: left_range,
256            lane_layout: self.lane_layout,
257        };
258        let right = Lanes {
259            data: self.data,
260            ranges: right_range,
261            lane_layout: self.lane_layout,
262        };
263
264        (left, right)
265    }
266}
267
268impl<'a, T: Sync + Send> IntoParallelIterator for Lanes<'a, T> {
269    impl_parallel_iterator!();
270}
271
272impl<'a, T> SplitIterator for LanesMut<'a, T> {
273    fn split_at(self, index: usize) -> (Self, Self) {
274        let (left_range, right_range) = self.ranges.split_at(index);
275        let len = self.data.len();
276
277        // Safety note: `split_mut` relies on the caller to ensure that
278        // associated layouts do not overlap.
279        let (left_data, right_data) = self.data.split_mut(0..len, 0..len);
280
281        let left = Self {
282            data: left_data,
283            ranges: left_range,
284            lane_layout: self.lane_layout,
285        };
286        let right = Self {
287            data: right_data,
288            ranges: right_range,
289            lane_layout: self.lane_layout,
290        };
291
292        (left, right)
293    }
294}
295
296impl<T: Sync + Send> IntoParallelIterator for LanesMut<'_, T> {
297    impl_parallel_iterator!();
298}
299
#[cfg(test)]
mod tests {
    use rayon::prelude::*;

    use crate::rng::XorShiftRng;
    use crate::{AsView, Tensor};

    // The helpers below are macros rather than functions because it is hard
    // to express, in a closure signature, the lifetime relationship between
    // an `&Tensor` input and a returned `impl Iterator + IntoParallelIterator`.

    // Check that collecting an iterator serially and in parallel produces the
    // same sequence of items.
    macro_rules! test_parallel_iterator {
        ($x:ident, $iter:expr) => {
            let mut rng = XorShiftRng::new(1234);
            let $x = Tensor::<f32>::rand(&[4, 8, 16, 32], &mut rng);
            let expected: Vec<_> = $iter.collect();
            let actual: Vec<_> = $iter.into_par_iter().collect();
            assert_eq!(expected, actual);
        };
    }

    // Check that a mutable iterator visits the same items when run serially
    // and in parallel, by comparing a sum computed from each item.
    macro_rules! test_parallel_iterator_mut {
        ($x:ident, $iter:expr, $item_sum:expr) => {
            let mut rng = XorShiftRng::new(1234);

            // Integer elements are used so the total does not depend on the
            // order in which items are summed; parallel iteration may visit
            // items in a different order to serial iteration.
            let mut $x =
                Tensor::<i32>::from_simple_fn(&[4, 8, 16, 32], || (rng.next_f32() * 100.) as i32);
            let expected: i32 = $iter.map($item_sum).sum();
            let actual: i32 = $iter.into_par_iter().map($item_sum).sum();

            assert_eq!(expected, actual);
        };
    }

    // Check that collecting an iterator serially and in parallel produces the
    // same items, for iterators whose items are themselves iterators. The
    // nested items are flattened before comparison.
    macro_rules! test_parallel_iterator_flatten {
        ($x:ident, $iter:expr) => {
            let mut rng = XorShiftRng::new(1234);
            let $x = Tensor::<f32>::rand(&[4, 8, 16, 32], &mut rng);

            let expected: Vec<_> = $iter.collect();
            let actual: Vec<_> = $iter.into_par_iter().collect();

            let expected_items: Vec<f32> = expected.into_iter().flatten().copied().collect();
            let actual_items: Vec<f32> = actual.into_iter().flatten().copied().collect();
            assert_eq!(expected_items, actual_items);
        };
    }

    // All parallel tests are ignored under Miri because of
    // https://github.com/crossbeam-rs/crossbeam/issues/1181.

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_inner_iter_parallel() {
        test_parallel_iterator!(tensor, tensor.inner_iter::<2>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_inner_iter_mut_parallel() {
        test_parallel_iterator_mut!(tensor, tensor.inner_iter_mut::<2>(), |view| view
            .iter()
            .sum::<i32>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iter_parallel() {
        test_parallel_iterator!(tensor, tensor.iter());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iter_mut_parallel() {
        test_parallel_iterator_mut!(tensor, tensor.iter_mut(), |elem| *elem);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_chunks_parallel() {
        test_parallel_iterator!(tensor, tensor.axis_chunks(0, 2));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_chunks_mut_parallel() {
        test_parallel_iterator_mut!(tensor, tensor.axis_chunks_mut(0, 2), |chunk| chunk
            .iter()
            .sum::<i32>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_iter_parallel() {
        test_parallel_iterator!(tensor, tensor.axis_iter(0));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_iter_mut_parallel() {
        test_parallel_iterator_mut!(tensor, tensor.axis_iter_mut(0), |slice| slice
            .iter()
            .sum::<i32>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_lanes_parallel() {
        test_parallel_iterator_flatten!(tensor, tensor.lanes(0));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_lanes_mut_parallel() {
        test_parallel_iterator_mut!(tensor, tensor.lanes_mut(0), |lane| lane
            .map(|elem| *elem)
            .sum::<i32>());
    }
}