1use rayon::prelude::*;
2use rten_base::iter::{ParIter, SplitIterator};
3
4use super::{
5 AxisChunks, AxisChunksMut, AxisIter, AxisIterMut, InnerIter, InnerIterBase, InnerIterMut, Iter,
6 IterMut, LaneRanges, Lanes, LanesMut, Offsets, OffsetsKind,
7};
8use crate::layout::RemoveDim;
9use crate::{Layout, MutLayout, Storage};
10
// Expands to the body of an `IntoParallelIterator` impl for a type that is
// itself a serial `Iterator` and converts into `ParIter<Self>`.
macro_rules! impl_parallel_iterator {
    () => {
        // Parallel items are exactly the items of the serial iterator.
        type Item = <Self as Iterator>::Item;
        type Iter = ParIter<Self>;

        fn into_par_iter(self) -> Self::Iter {
            Into::into(self)
        }
    };
}
23
24impl SplitIterator for Offsets {
25 fn split_at(self, index: usize) -> (Self, Self) {
26 assert!(index <= self.len());
27 let (left_kind, right_kind) = match self.base {
28 OffsetsKind::Range(r) => {
29 let left = r.start..r.start + index;
30 let right = r.start + index..r.end;
31 (OffsetsKind::Range(left), OffsetsKind::Range(right))
32 }
33 OffsetsKind::Indexing(base) => {
34 let (left, right) = base.split_at(index);
35 (OffsetsKind::Indexing(left), OffsetsKind::Indexing(right))
36 }
37 };
38 (Offsets { base: left_kind }, Offsets { base: right_kind })
39 }
40}
41
42impl<L: Layout + Clone> SplitIterator for InnerIterBase<L> {
43 fn split_at(self, index: usize) -> (Self, Self) {
44 let (left_offsets, right_offsets) = self.outer_offsets.split_at(index);
45 let left = Self {
46 outer_offsets: left_offsets,
47 inner_layout: self.inner_layout.clone(),
48 inner_data_len: self.inner_data_len,
49 };
50 let right = Self {
51 outer_offsets: right_offsets,
52 inner_layout: self.inner_layout,
53 inner_data_len: self.inner_data_len,
54 };
55 (left, right)
56 }
57}
58
59impl<'a, T, L: MutLayout + Send + Sync> SplitIterator for InnerIter<'a, T, L> {
60 fn split_at(self, index: usize) -> (Self, Self) {
61 let (left_base, right_base) = self.base.split_at(index);
62 let left = Self {
63 base: left_base,
64 data: self.data,
65 };
66 let right = Self {
67 base: right_base,
68 data: self.data,
69 };
70 (left, right)
71 }
72}
73
74impl<'a, T, L: MutLayout + Send + Sync> IntoParallelIterator for InnerIter<'a, T, L> {
75 impl_parallel_iterator!();
76}
77
78impl<'a, T, L: MutLayout + Send + Sync> SplitIterator for InnerIterMut<'a, T, L> {
79 fn split_at(self, index: usize) -> (Self, Self) {
80 let (left_base, right_base) = self.base.split_at(index);
81 let len = self.data.len();
82
83 let (left_data, right_data) = self.data.split_mut(0..len, 0..len);
87
88 let left = Self {
89 base: left_base,
90 data: left_data,
91 };
92 let right = Self {
93 base: right_base,
94 data: right_data,
95 };
96 (left, right)
97 }
98}
99
100impl<'a, T, L: MutLayout + Send + Sync> IntoParallelIterator for InnerIterMut<'a, T, L> {
101 impl_parallel_iterator!();
102}
103
104impl<'a, T, L: MutLayout + RemoveDim> SplitIterator for AxisIter<'a, T, L> {
105 fn split_at(self, index: usize) -> (Self, Self) {
106 let (left_view, right_view) = self.view.split_at(self.axis, index);
107 let left = AxisIter::new(&left_view, self.axis);
108 let right = AxisIter::new(&right_view, self.axis);
109 (left, right)
110 }
111}
112
113impl<'a, T, L: MutLayout + RemoveDim + Send> IntoParallelIterator for AxisIter<'a, T, L>
114where
115 <L as RemoveDim>::Output: Send,
116{
117 impl_parallel_iterator!();
118}
119
120impl<'a, T, L: MutLayout + RemoveDim> SplitIterator for AxisIterMut<'a, T, L> {
121 fn split_at(self, index: usize) -> (Self, Self) {
122 let (left_view, right_view) = self.view.split_at_mut(self.axis, index);
123 let left = AxisIterMut::new(left_view, self.axis);
124 let right = AxisIterMut::new(right_view, self.axis);
125 (left, right)
126 }
127}
128
129impl<'a, T, L: MutLayout + RemoveDim + Send> IntoParallelIterator for AxisIterMut<'a, T, L>
130where
131 <L as RemoveDim>::Output: Send,
132{
133 impl_parallel_iterator!();
134}
135
136impl<'a, T, L: MutLayout> SplitIterator for AxisChunks<'a, T, L> {
137 fn split_at(mut self, index: usize) -> (Self, Self) {
138 let (left_remainder, right_remainder) = if let Some(remainder) = self.remainder.take() {
139 let (l, r) = remainder.split_at(self.axis, self.chunk_size * index);
140 (Some(l), Some(r))
141 } else {
142 (None, None)
143 };
144
145 let left = AxisChunks {
146 remainder: left_remainder,
147 axis: self.axis,
148 chunk_size: self.chunk_size,
149 };
150 let right = AxisChunks {
151 remainder: right_remainder,
152 axis: self.axis,
153 chunk_size: self.chunk_size,
154 };
155
156 (left, right)
157 }
158}
159
160impl<'a, T, L: MutLayout + Send> IntoParallelIterator for AxisChunks<'a, T, L> {
161 impl_parallel_iterator!();
162}
163
164impl<'a, T, L: MutLayout> SplitIterator for AxisChunksMut<'a, T, L> {
165 fn split_at(mut self, index: usize) -> (Self, Self) {
166 let (left_remainder, right_remainder) = if let Some(remainder) = self.remainder.take() {
167 let (l, r) = remainder.split_at_mut(self.axis, self.chunk_size * index);
168 (Some(l), Some(r))
169 } else {
170 (None, None)
171 };
172
173 let left = Self {
174 remainder: left_remainder,
175 axis: self.axis,
176 chunk_size: self.chunk_size,
177 };
178 let right = Self {
179 remainder: right_remainder,
180 axis: self.axis,
181 chunk_size: self.chunk_size,
182 };
183
184 (left, right)
185 }
186}
187
188impl<'a, T, L: MutLayout + Send> IntoParallelIterator for AxisChunksMut<'a, T, L> {
189 impl_parallel_iterator!();
190}
191
192impl<'a, T> SplitIterator for Iter<'a, T> {
193 fn split_at(self, index: usize) -> (Self, Self) {
194 let (left_offsets, right_offsets) = self.offsets.split_at(index);
195 let left = Self {
196 offsets: left_offsets,
197 data: self.data,
198 };
199 let right = Self {
200 offsets: right_offsets,
201 data: self.data,
202 };
203 (left, right)
204 }
205}
206
207impl<'a, T: Sync> IntoParallelIterator for Iter<'a, T> {
208 impl_parallel_iterator!();
209}
210
211impl<'a, T> SplitIterator for IterMut<'a, T> {
212 fn split_at(self, index: usize) -> (Self, Self) {
213 let (left_offsets, right_offsets) = self.offsets.split_at(index);
214 let len = self.data.len();
215 let (left_data, right_data) = self.data.split_mut(0..len, 0..len);
216 let left = Self {
217 offsets: left_offsets,
218 data: left_data,
219 };
220 let right = Self {
221 offsets: right_offsets,
222 data: right_data,
223 };
224 (left, right)
225 }
226}
227
228impl<'a, T: Sync + Send> IntoParallelIterator for IterMut<'a, T> {
229 impl_parallel_iterator!();
230}
231
232impl SplitIterator for LaneRanges {
233 fn split_at(self, index: usize) -> (Self, Self) {
234 let (left_offsets, right_offsets) = self.offsets.split_at(index);
235 let left = LaneRanges {
236 offsets: left_offsets,
237 dim_size: self.dim_size,
238 dim_stride: self.dim_stride,
239 };
240 let right = LaneRanges {
241 offsets: right_offsets,
242 dim_size: self.dim_size,
243 dim_stride: self.dim_stride,
244 };
245 (left, right)
246 }
247}
248
249impl<'a, T> SplitIterator for Lanes<'a, T> {
250 fn split_at(self, index: usize) -> (Self, Self) {
251 let (left_range, right_range) = self.ranges.split_at(index);
252
253 let left = Lanes {
254 data: self.data,
255 ranges: left_range,
256 lane_layout: self.lane_layout,
257 };
258 let right = Lanes {
259 data: self.data,
260 ranges: right_range,
261 lane_layout: self.lane_layout,
262 };
263
264 (left, right)
265 }
266}
267
268impl<'a, T: Sync + Send> IntoParallelIterator for Lanes<'a, T> {
269 impl_parallel_iterator!();
270}
271
272impl<'a, T> SplitIterator for LanesMut<'a, T> {
273 fn split_at(self, index: usize) -> (Self, Self) {
274 let (left_range, right_range) = self.ranges.split_at(index);
275 let len = self.data.len();
276
277 let (left_data, right_data) = self.data.split_mut(0..len, 0..len);
280
281 let left = Self {
282 data: left_data,
283 ranges: left_range,
284 lane_layout: self.lane_layout,
285 };
286 let right = Self {
287 data: right_data,
288 ranges: right_range,
289 lane_layout: self.lane_layout,
290 };
291
292 (left, right)
293 }
294}
295
296impl<T: Sync + Send> IntoParallelIterator for LanesMut<'_, T> {
297 impl_parallel_iterator!();
298}
299
#[cfg(test)]
mod tests {
    use rayon::prelude::*;

    use crate::rng::XorShiftRng;
    use crate::{AsView, Tensor};

    // Builds a random `f32` tensor bound to `$x`, evaluates `$iter` (which is
    // expected to refer to `$x`) once serially and once via `into_par_iter`,
    // and checks that both collect to equal `Vec`s.
    macro_rules! test_parallel_iterator {
        ($x:ident, $iter:expr) => {
            let mut rng = XorShiftRng::new(1234);
            let $x = Tensor::<f32>::rand(&[4, 8, 16, 32], &mut rng);
            let expected: Vec<_> = $iter.collect();
            let actual: Vec<_> = $iter.into_par_iter().collect();
            assert_eq!(expected, actual);
        };
    }

    // Variant for iterators over a mutable tensor. Each item is reduced to an
    // `i32` with `$item_sum`, and the serial and parallel totals are compared.
    macro_rules! test_parallel_iterator_mut {
        ($x:ident, $iter:expr, $item_sum:expr) => {
            let mut rng = XorShiftRng::new(1234);

            let mut $x =
                Tensor::<i32>::from_simple_fn(&[4, 8, 16, 32], || (rng.next_f32() * 100.) as i32);
            let expected = $iter.map($item_sum).sum::<i32>();
            let actual = $iter.into_par_iter().map($item_sum).sum::<i32>();

            assert_eq!(expected, actual);
        };
    }

    // Variant for iterators whose items are themselves iterators over `&f32`.
    // The collected items are flattened into `Vec<f32>` before comparison.
    macro_rules! test_parallel_iterator_flatten {
        ($x:ident, $iter:expr) => {
            let mut rng = XorShiftRng::new(1234);
            let $x = Tensor::<f32>::rand(&[4, 8, 16, 32], &mut rng);

            let expected: Vec<_> = $iter.collect();
            let actual: Vec<_> = $iter.into_par_iter().collect();

            let expected_items: Vec<f32> = expected.into_iter().flatten().copied().collect();
            let actual_items: Vec<f32> = actual.into_iter().flatten().copied().collect();
            assert_eq!(expected_items, actual_items);
        };
    }

    // All tests below are skipped when running under Miri.

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_inner_iter_parallel() {
        test_parallel_iterator!(x, x.inner_iter::<2>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_inner_iter_mut_parallel() {
        test_parallel_iterator_mut!(x, x.inner_iter_mut::<2>(), |view| view
            .iter()
            .sum::<i32>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iter_parallel() {
        test_parallel_iterator!(x, x.iter());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_iter_mut_parallel() {
        test_parallel_iterator_mut!(x, x.iter_mut(), |elem| *elem);
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_chunks_parallel() {
        test_parallel_iterator!(x, x.axis_chunks(0, 2));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_chunks_mut_parallel() {
        test_parallel_iterator_mut!(x, x.axis_chunks_mut(0, 2), |chunk| chunk
            .iter()
            .sum::<i32>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_iter_parallel() {
        test_parallel_iterator!(x, x.axis_iter(0));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_axis_iter_mut_parallel() {
        test_parallel_iterator_mut!(x, x.axis_iter_mut(0), |slice| slice.iter().sum::<i32>());
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_lanes_parallel() {
        test_parallel_iterator_flatten!(x, x.lanes(0));
    }

    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_lanes_mut_parallel() {
        test_parallel_iterator_mut!(x, x.lanes_mut(0), |lane| lane
            .map(|elem| *elem)
            .sum::<i32>());
    }
}