// rstsr_core/tensor/iterator_axes.rs

1#![allow(clippy::missing_transmute_annotations)]
2
3use crate::prelude_dev::*;
4use core::mem::transmute;
5
6/* #region axes iter view iterator */
7
8pub struct IterAxesView<'a, T, B>
9where
10    B: DeviceAPI<T>,
11{
12    axes_iter: IterLayout<IxD>,
13    view: TensorView<'a, T, B, IxD>,
14}
15
16impl<T, B> IterAxesView<'_, T, B>
17where
18    B: DeviceAPI<T>,
19{
20    pub fn update_offset(&mut self, offset: usize) {
21        unsafe { self.view.layout.set_offset(offset) };
22    }
23}
24
25impl<'a, T, B> Iterator for IterAxesView<'a, T, B>
26where
27    B: DeviceAPI<T>,
28{
29    type Item = TensorView<'a, T, B, IxD>;
30
31    fn next(&mut self) -> Option<Self::Item> {
32        self.axes_iter.next().map(|offset| {
33            self.update_offset(offset);
34            unsafe { transmute(self.view.view()) }
35        })
36    }
37}
38
39impl<T, B> DoubleEndedIterator for IterAxesView<'_, T, B>
40where
41    B: DeviceAPI<T>,
42{
43    fn next_back(&mut self) -> Option<Self::Item> {
44        self.axes_iter.next_back().map(|offset| {
45            self.update_offset(offset);
46            unsafe { transmute(self.view.view()) }
47        })
48    }
49}
50
51impl<T, B> ExactSizeIterator for IterAxesView<'_, T, B>
52where
53    B: DeviceAPI<T>,
54{
55    fn len(&self) -> usize {
56        self.axes_iter.len()
57    }
58}
59
60impl<T, B> IterSplitAtAPI for IterAxesView<'_, T, B>
61where
62    B: DeviceAPI<T>,
63{
64    fn split_at(self, index: usize) -> (Self, Self) {
65        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
66        let view_lhs = unsafe { transmute(self.view.view()) };
67        let lhs = IterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
68        let rhs = IterAxesView { axes_iter: rhs_axes_iter, view: self.view };
69        return (lhs, rhs);
70    }
71}
72
73impl<'a, R, T, B, D> TensorAny<R, T, B, D>
74where
75    T: Clone,
76    R: DataCloneAPI<Data = B::Raw>,
77    D: DimAPI,
78    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
79{
80    pub fn axes_iter_with_order_f<I>(&self, axes: I, order: TensorIterOrder) -> Result<IterAxesView<'a, T, B>>
81    where
82        I: TryInto<AxesIndex<isize>, Error = Error>,
83    {
84        // convert axis to negative indexes and sort
85        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
86        let axes: Vec<isize> =
87            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
88        let mut axes_check = axes.clone();
89        axes_check.sort();
90        // check no two axis are the same, and no negative index too small
91        if axes.first().is_some_and(|&v| v < 0) {
92            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
93        }
94        for i in 0..axes_check.len() - 1 {
95            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
96        }
97
98        // get full layout
99        let layout = self.layout().to_dim::<IxD>()?;
100        let shape_full = layout.shape();
101        let stride_full = layout.stride();
102        let offset = layout.offset();
103
104        // get layout for axes_iter
105        let mut shape_axes = vec![];
106        let mut stride_axes = vec![];
107        for &idx in &axes {
108            shape_axes.push(shape_full[idx as usize]);
109            stride_axes.push(stride_full[idx as usize]);
110        }
111        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
112
113        // get layout for inner view
114        let mut shape_inner = vec![];
115        let mut stride_inner = vec![];
116        for idx in 0..ndim {
117            if !axes.contains(&idx) {
118                shape_inner.push(shape_full[idx as usize]);
119                stride_inner.push(stride_full[idx as usize]);
120            }
121        }
122        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
123
124        // create axes iter
125        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
126        let mut view = self.view().into_dyn();
127        view.layout = layout_inner.clone();
128        let iter = IterAxesView { axes_iter, view: unsafe { transmute(view) } };
129        Ok(iter)
130    }
131
132    pub fn axes_iter_f<I>(&self, axes: I) -> Result<IterAxesView<'a, T, B>>
133    where
134        I: TryInto<AxesIndex<isize>, Error = Error>,
135    {
136        self.axes_iter_with_order_f(axes, TensorIterOrder::default())
137    }
138
139    pub fn axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IterAxesView<'a, T, B>
140    where
141        I: TryInto<AxesIndex<isize>, Error = Error>,
142    {
143        self.axes_iter_with_order_f(axes, order).rstsr_unwrap()
144    }
145
146    pub fn axes_iter<I>(&self, axes: I) -> IterAxesView<'a, T, B>
147    where
148        I: TryInto<AxesIndex<isize>, Error = Error>,
149    {
150        self.axes_iter_f(axes).rstsr_unwrap()
151    }
152}
153
154/* #endregion */
155
156/* #region axes iter mut iterator */
157
158pub struct IterAxesMut<'a, T, B>
159where
160    B: DeviceAPI<T>,
161{
162    axes_iter: IterLayout<IxD>,
163    view: TensorMut<'a, T, B, IxD>,
164}
165
166impl<T, B> IterAxesMut<'_, T, B>
167where
168    B: DeviceAPI<T>,
169{
170    pub fn update_offset(&mut self, offset: usize) {
171        unsafe { self.view.layout.set_offset(offset) };
172    }
173}
174
175impl<'a, T, B> Iterator for IterAxesMut<'a, T, B>
176where
177    B: DeviceAPI<T>,
178{
179    type Item = TensorMut<'a, T, B, IxD>;
180
181    fn next(&mut self) -> Option<Self::Item> {
182        self.axes_iter.next().map(|offset| {
183            self.update_offset(offset);
184            unsafe { transmute(self.view.view_mut()) }
185        })
186    }
187}
188
189impl<T, B> DoubleEndedIterator for IterAxesMut<'_, T, B>
190where
191    B: DeviceAPI<T>,
192{
193    fn next_back(&mut self) -> Option<Self::Item> {
194        self.axes_iter.next_back().map(|offset| {
195            self.update_offset(offset);
196            unsafe { transmute(self.view.view_mut()) }
197        })
198    }
199}
200
201impl<T, B> ExactSizeIterator for IterAxesMut<'_, T, B>
202where
203    B: DeviceAPI<T>,
204{
205    fn len(&self) -> usize {
206        self.axes_iter.len()
207    }
208}
209
210impl<T, B> IterSplitAtAPI for IterAxesMut<'_, T, B>
211where
212    B: DeviceAPI<T>,
213{
214    fn split_at(mut self, index: usize) -> (Self, Self) {
215        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
216        let view_lhs = unsafe { transmute(self.view.view_mut()) };
217        let lhs = IterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
218        let rhs = IterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
219        return (lhs, rhs);
220    }
221}
222
223impl<'a, R, T, B, D> TensorAny<R, T, B, D>
224where
225    T: Clone,
226    R: DataMutAPI<Data = B::Raw>,
227    D: DimAPI,
228    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
229{
230    pub fn axes_iter_mut_with_order_f<I>(&'a mut self, axes: I, order: TensorIterOrder) -> Result<IterAxesMut<'a, T, B>>
231    where
232        I: TryInto<AxesIndex<isize>, Error = Error>,
233    {
234        // convert axis to negative indexes and sort
235        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
236        let axes: Vec<isize> =
237            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
238        let mut axes_check = axes.clone();
239        axes_check.sort();
240        // check no two axis are the same, and no negative index too small
241        if axes.first().is_some_and(|&v| v < 0) {
242            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
243        }
244        for i in 0..axes_check.len() - 1 {
245            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
246        }
247
248        // get full layout
249        let layout = self.layout().to_dim::<IxD>()?;
250        let shape_full = layout.shape();
251        let stride_full = layout.stride();
252        let offset = layout.offset();
253
254        // get layout for axes_iter
255        let mut shape_axes = vec![];
256        let mut stride_axes = vec![];
257        for &idx in &axes {
258            shape_axes.push(shape_full[idx as usize]);
259            stride_axes.push(stride_full[idx as usize]);
260        }
261        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
262
263        // get layout for inner view
264        let mut shape_inner = vec![];
265        let mut stride_inner = vec![];
266        for idx in 0..ndim {
267            if !axes.contains(&idx) {
268                shape_inner.push(shape_full[idx as usize]);
269                stride_inner.push(stride_full[idx as usize]);
270            }
271        }
272        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
273
274        // create axes iter
275        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
276        let mut view = self.view_mut().into_dyn();
277        view.layout = layout_inner.clone();
278        let iter = IterAxesMut { axes_iter, view };
279        Ok(iter)
280    }
281
282    pub fn axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IterAxesMut<'a, T, B>>
283    where
284        I: TryInto<AxesIndex<isize>, Error = Error>,
285    {
286        self.axes_iter_mut_with_order_f(axes, TensorIterOrder::default())
287    }
288
289    pub fn axes_iter_mut_with_order<I>(&'a mut self, axes: I, order: TensorIterOrder) -> IterAxesMut<'a, T, B>
290    where
291        I: TryInto<AxesIndex<isize>, Error = Error>,
292    {
293        self.axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
294    }
295
296    pub fn axes_iter_mut<I>(&'a mut self, axes: I) -> IterAxesMut<'a, T, B>
297    where
298        I: TryInto<AxesIndex<isize>, Error = Error>,
299    {
300        self.axes_iter_mut_f(axes).rstsr_unwrap()
301    }
302}
303
304/* #endregion */
305
306/* #region indexed axes iter view iterator */
307
308pub struct IndexedIterAxesView<'a, T, B>
309where
310    B: DeviceAPI<T>,
311{
312    axes_iter: IterLayout<IxD>,
313    view: TensorView<'a, T, B, IxD>,
314}
315
316impl<T, B> IndexedIterAxesView<'_, T, B>
317where
318    B: DeviceAPI<T>,
319{
320    pub fn update_offset(&mut self, offset: usize) {
321        unsafe { self.view.layout.set_offset(offset) };
322    }
323}
324
325impl<'a, T, B> Iterator for IndexedIterAxesView<'a, T, B>
326where
327    B: DeviceAPI<T>,
328{
329    type Item = (IxD, TensorView<'a, T, B, IxD>);
330
331    fn next(&mut self) -> Option<Self::Item> {
332        let index = match &self.axes_iter {
333            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
334            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
335        };
336        self.axes_iter.next().map(|offset| {
337            self.update_offset(offset);
338            (index, unsafe { transmute(self.view.view()) })
339        })
340    }
341}
342
343impl<T, B> DoubleEndedIterator for IndexedIterAxesView<'_, T, B>
344where
345    B: DeviceAPI<T>,
346{
347    fn next_back(&mut self) -> Option<Self::Item> {
348        let index = match &self.axes_iter {
349            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
350            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
351        };
352        self.axes_iter.next_back().map(|offset| {
353            self.update_offset(offset);
354            (index, unsafe { transmute(self.view.view()) })
355        })
356    }
357}
358
359impl<T, B> ExactSizeIterator for IndexedIterAxesView<'_, T, B>
360where
361    B: DeviceAPI<T>,
362{
363    fn len(&self) -> usize {
364        self.axes_iter.len()
365    }
366}
367
368impl<T, B> IterSplitAtAPI for IndexedIterAxesView<'_, T, B>
369where
370    B: DeviceAPI<T>,
371{
372    fn split_at(self, index: usize) -> (Self, Self) {
373        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
374        let view_lhs = unsafe { transmute(self.view.view()) };
375        let lhs = IndexedIterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
376        let rhs = IndexedIterAxesView { axes_iter: rhs_axes_iter, view: self.view };
377        return (lhs, rhs);
378    }
379}
380
381impl<'a, R, T, B, D> TensorAny<R, T, B, D>
382where
383    T: Clone,
384    R: DataCloneAPI<Data = B::Raw>,
385    D: DimAPI,
386    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
387{
388    pub fn indexed_axes_iter_with_order_f<I>(
389        &self,
390        axes: I,
391        order: TensorIterOrder,
392    ) -> Result<IndexedIterAxesView<'a, T, B>>
393    where
394        I: TryInto<AxesIndex<isize>, Error = Error>,
395    {
396        use TensorIterOrder::*;
397        // this function only accepts c/f iter order currently
398        match order {
399            C | F => (),
400            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
401        };
402        // convert axis to negative indexes and sort
403        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
404        let axes: Vec<isize> =
405            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
406        let mut axes_check = axes.clone();
407        axes_check.sort();
408        // check no two axis are the same, and no negative index too small
409        if axes.first().is_some_and(|&v| v < 0) {
410            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
411        }
412        for i in 0..axes_check.len() - 1 {
413            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
414        }
415
416        // get full layout
417        let layout = self.layout().to_dim::<IxD>()?;
418        let shape_full = layout.shape();
419        let stride_full = layout.stride();
420        let offset = layout.offset();
421
422        // get layout for axes_iter
423        let mut shape_axes = vec![];
424        let mut stride_axes = vec![];
425        for &idx in &axes {
426            shape_axes.push(shape_full[idx as usize]);
427            stride_axes.push(stride_full[idx as usize]);
428        }
429        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
430
431        // get layout for inner view
432        let mut shape_inner = vec![];
433        let mut stride_inner = vec![];
434        for idx in 0..ndim {
435            if !axes.contains(&idx) {
436                shape_inner.push(shape_full[idx as usize]);
437                stride_inner.push(stride_full[idx as usize]);
438            }
439        }
440        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
441
442        // create axes iter
443        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
444        let mut view = self.view().into_dyn();
445        view.layout = layout_inner.clone();
446        let iter = IndexedIterAxesView { axes_iter, view: unsafe { transmute(view) } };
447        Ok(iter)
448    }
449
450    pub fn indexed_axes_iter_f<I>(&self, axes: I) -> Result<IndexedIterAxesView<'a, T, B>>
451    where
452        I: TryInto<AxesIndex<isize>, Error = Error>,
453    {
454        let default_order = self.device().default_order();
455        let order = match default_order {
456            RowMajor => TensorIterOrder::C,
457            ColMajor => TensorIterOrder::F,
458        };
459        self.indexed_axes_iter_with_order_f(axes, order)
460    }
461
462    pub fn indexed_axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IndexedIterAxesView<'a, T, B>
463    where
464        I: TryInto<AxesIndex<isize>, Error = Error>,
465    {
466        self.indexed_axes_iter_with_order_f(axes, order).rstsr_unwrap()
467    }
468
469    pub fn indexed_axes_iter<I>(&self, axes: I) -> IndexedIterAxesView<'a, T, B>
470    where
471        I: TryInto<AxesIndex<isize>, Error = Error>,
472    {
473        self.indexed_axes_iter_f(axes).rstsr_unwrap()
474    }
475}
476
477/* #endregion */
478
479/* #region axes iter mut iterator */
480
481pub struct IndexedIterAxesMut<'a, T, B>
482where
483    B: DeviceAPI<T>,
484{
485    axes_iter: IterLayout<IxD>,
486    view: TensorMut<'a, T, B, IxD>,
487}
488
489impl<T, B> IndexedIterAxesMut<'_, T, B>
490where
491    B: DeviceAPI<T>,
492{
493    pub fn update_offset(&mut self, offset: usize) {
494        unsafe { self.view.layout.set_offset(offset) };
495    }
496}
497
498impl<'a, T, B> Iterator for IndexedIterAxesMut<'a, T, B>
499where
500    B: DeviceAPI<T>,
501{
502    type Item = (IxD, TensorMut<'a, T, B, IxD>);
503
504    fn next(&mut self) -> Option<Self::Item> {
505        let index = match &self.axes_iter {
506            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
507            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
508        };
509        self.axes_iter.next().map(|offset| {
510            self.update_offset(offset);
511            unsafe { transmute((index, self.view.view_mut())) }
512        })
513    }
514}
515
516impl<T, B> DoubleEndedIterator for IndexedIterAxesMut<'_, T, B>
517where
518    B: DeviceAPI<T>,
519{
520    fn next_back(&mut self) -> Option<Self::Item> {
521        let index = match &self.axes_iter {
522            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
523            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
524        };
525        self.axes_iter.next_back().map(|offset| {
526            self.update_offset(offset);
527            unsafe { transmute((index, self.view.view_mut())) }
528        })
529    }
530}
531
532impl<T, B> ExactSizeIterator for IndexedIterAxesMut<'_, T, B>
533where
534    B: DeviceAPI<T>,
535{
536    fn len(&self) -> usize {
537        self.axes_iter.len()
538    }
539}
540
541impl<T, B> IterSplitAtAPI for IndexedIterAxesMut<'_, T, B>
542where
543    B: DeviceAPI<T>,
544{
545    fn split_at(mut self, index: usize) -> (Self, Self) {
546        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
547        let view_lhs = unsafe { transmute(self.view.view_mut()) };
548        let lhs = IndexedIterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
549        let rhs = IndexedIterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
550        return (lhs, rhs);
551    }
552}
553
554impl<'a, R, T, B, D> TensorAny<R, T, B, D>
555where
556    T: Clone,
557    R: DataMutAPI<Data = B::Raw>,
558    D: DimAPI,
559    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
560{
561    pub fn indexed_axes_iter_mut_with_order_f<I>(
562        &'a mut self,
563        axes: I,
564        order: TensorIterOrder,
565    ) -> Result<IndexedIterAxesMut<'a, T, B>>
566    where
567        I: TryInto<AxesIndex<isize>, Error = Error>,
568    {
569        // convert axis to negative indexes and sort
570        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
571        let axes: Vec<isize> =
572            axes.try_into()?.as_ref().iter().map(|&v| if v >= 0 { v } else { v + ndim }).collect::<Vec<isize>>();
573        let mut axes_check = axes.clone();
574        axes_check.sort();
575        // check no two axis are the same, and no negative index too small
576        if axes.first().is_some_and(|&v| v < 0) {
577            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
578        }
579        for i in 0..axes_check.len() - 1 {
580            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
581        }
582
583        // get full layout
584        let layout = self.layout().to_dim::<IxD>()?;
585        let shape_full = layout.shape();
586        let stride_full = layout.stride();
587        let offset = layout.offset();
588
589        // get layout for axes_iter
590        let mut shape_axes = vec![];
591        let mut stride_axes = vec![];
592        for &idx in &axes {
593            shape_axes.push(shape_full[idx as usize]);
594            stride_axes.push(stride_full[idx as usize]);
595        }
596        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
597
598        // get layout for inner view
599        let mut shape_inner = vec![];
600        let mut stride_inner = vec![];
601        for idx in 0..ndim {
602            if !axes.contains(&idx) {
603                shape_inner.push(shape_full[idx as usize]);
604                stride_inner.push(stride_full[idx as usize]);
605            }
606        }
607        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
608
609        // create axes iter
610        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
611        let mut view = self.view_mut().into_dyn();
612        view.layout = layout_inner.clone();
613        let iter = IndexedIterAxesMut { axes_iter, view };
614        Ok(iter)
615    }
616
617    pub fn indexed_axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IndexedIterAxesMut<'a, T, B>>
618    where
619        I: TryInto<AxesIndex<isize>, Error = Error>,
620    {
621        let default_order = self.device().default_order();
622        let order = match default_order {
623            RowMajor => TensorIterOrder::C,
624            ColMajor => TensorIterOrder::F,
625        };
626        self.indexed_axes_iter_mut_with_order_f(axes, order)
627    }
628
629    pub fn indexed_axes_iter_mut_with_order<I>(
630        &'a mut self,
631        axes: I,
632        order: TensorIterOrder,
633    ) -> IndexedIterAxesMut<'a, T, B>
634    where
635        I: TryInto<AxesIndex<isize>, Error = Error>,
636    {
637        self.indexed_axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
638    }
639
640    pub fn indexed_axes_iter_mut<I>(&'a mut self, axes: I) -> IndexedIterAxesMut<'a, T, B>
641    where
642        I: TryInto<AxesIndex<isize>, Error = Error>,
643    {
644        self.indexed_axes_iter_mut_f(axes).rstsr_unwrap()
645    }
646}
647
648/* #endregion */
649
#[cfg(test)]
mod tests_serial {
    use super::*;

    // Iterating axes 0 and 2 of a (2, 3, 4, 5) tensor visits 2 * 4 positions;
    // each yielded view spans the remaining axes (3, 5).
    #[test]
    fn test_axes_iter() {
        let tensor = arange(120).into_shape([2, 3, 4, 5]);

        let mut picked = Vec::new();
        for view in tensor.axes_iter_f([0, 2]).unwrap() {
            println!("{view:3}");
            picked.push(view[[1, 2]]);
        }

        #[cfg(not(feature = "col_major"))]
        {
            // import numpy as np
            // a = np.arange(120).reshape(2, 3, 4, 5)
            // a[:, 1, :, 2].reshape(-1)
            assert_eq!(picked, vec![22, 27, 32, 37, 82, 87, 92, 97]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 119) |> collect;
            // a = reshape(a, (2, 3, 4, 5));
            // reshape(a[:, 2, :, 3], 8)'
            assert_eq!(picked, vec![50, 51, 56, 57, 62, 63, 68, 69]);
        }
    }

    // Same traversal in C order, incrementing every yielded view in place first.
    #[test]
    fn test_axes_iter_mut() {
        let mut tensor = arange(120).into_shape([2, 3, 4, 5]);

        let mut picked = Vec::new();
        for mut view in tensor.axes_iter_mut_with_order_f([0, 2], TensorIterOrder::C).unwrap() {
            view += 1;
            println!("{view:3}");
            picked.push(view[[1, 2]]);
        }
        println!("{picked:?}");

        #[cfg(not(feature = "col_major"))]
        {
            // import numpy as np
            // a = np.arange(120).reshape(2, 3, 4, 5)
            // a[:, 1, :, 2].reshape(-1) + 1
            assert_eq!(picked, vec![23, 28, 33, 38, 83, 88, 93, 98]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 119) |> collect;
            // a = reshape(a, (2, 3, 4, 5));
            // reshape(a[:, 2, :, 3]', 8)' .+ 1
            assert_eq!(picked, vec![51, 57, 63, 69, 52, 58, 64, 70]);
        }
    }

    // Each item carries the multi-index of its position along the iterated axes.
    #[test]
    fn test_indexed_axes_iter() {
        let tensor = arange(120).into_shape([2, 3, 4, 5]);

        let mut picked = Vec::new();
        for (index, view) in tensor.indexed_axes_iter([0, 2]) {
            println!("{index:?}");
            println!("{view:3}");
            picked.push((index, view[[1, 2]]));
        }

        #[cfg(not(feature = "col_major"))]
        {
            // import numpy as np
            // a = np.arange(120).reshape(2, 3, 4, 5)
            // a[:, 1, :, 2].reshape(-1)
            assert_eq!(picked, vec![
                (vec![0, 0], 22),
                (vec![0, 1], 27),
                (vec![0, 2], 32),
                (vec![0, 3], 37),
                (vec![1, 0], 82),
                (vec![1, 1], 87),
                (vec![1, 2], 92),
                (vec![1, 3], 97)
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 119) |> collect;
            // a = reshape(a, (2, 3, 4, 5));
            // reshape(a[:, 2, :, 3], 8)'
            assert_eq!(picked, vec![
                (vec![0, 0], 50),
                (vec![1, 0], 51),
                (vec![0, 1], 56),
                (vec![1, 1], 57),
                (vec![0, 2], 62),
                (vec![1, 2], 63),
                (vec![0, 3], 68),
                (vec![1, 3], 69)
            ]);
        }
    }
}
756
#[cfg(test)]
#[cfg(feature = "rayon")]
mod tests_parallel {
    use super::*;
    use rayon::prelude::*;

    // Parallel mutable iteration over axes 0 and 2; the assertion below relies
    // on `collect` returning items in iteration order.
    #[test]
    fn test_axes_iter() {
        let mut tensor = arange(65536).into_shape([16, 16, 16, 16]);

        let res: Vec<_> = tensor
            .axes_iter_mut([0, 2])
            .into_par_iter()
            .map(|mut view| {
                view += 1;
                println!("{view:6}");
                view[[1, 2]]
            })
            .collect();
        println!("{res:?}");

        #[cfg(not(feature = "col_major"))]
        {
            // a = np.arange(65536).reshape(16, 16, 16, 16)
            // a[:, 1, :, 2].reshape(-1)[:17] + 1
            assert_eq!(res[..17], vec![
                259, 275, 291, 307, 323, 339, 355, 371, 387, 403, 419, 435, 451, 467, 483, 499, 4355
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 65535) |> collect;
            // a = reshape(a, (16, 16, 16, 16))
            // (reshape(a[:, 2, :, 3], 16 * 16) .+ 1)[1:17]
            assert_eq!(res[..17], vec![
                8209, 8210, 8211, 8212, 8213, 8214, 8215, 8216, 8217, 8218, 8219, 8220, 8221, 8222, 8223, 8224, 8465
            ]);
        }
    }
}