Skip to main content

rstsr_core/tensor/
iterator_axes.rs

1#![allow(clippy::missing_transmute_annotations)]
2
3use crate::prelude_dev::*;
4use core::mem::transmute;
5
6/* #region axes iter view iterator */
7
8pub struct IterAxesView<'a, T, B>
9where
10    B: DeviceAPI<T>,
11{
12    axes_iter: IterLayout<IxD>,
13    view: TensorView<'a, T, B, IxD>,
14}
15
16impl<T, B> IterAxesView<'_, T, B>
17where
18    B: DeviceAPI<T>,
19{
20    pub fn update_offset(&mut self, offset: usize) {
21        unsafe { self.view.layout.set_offset(offset) };
22    }
23}
24
25impl<'a, T, B> Iterator for IterAxesView<'a, T, B>
26where
27    B: DeviceAPI<T>,
28{
29    type Item = TensorView<'a, T, B, IxD>;
30
31    fn next(&mut self) -> Option<Self::Item> {
32        self.axes_iter.next().map(|offset| {
33            self.update_offset(offset);
34            unsafe { transmute(self.view.view()) }
35        })
36    }
37}
38
39impl<T, B> DoubleEndedIterator for IterAxesView<'_, T, B>
40where
41    B: DeviceAPI<T>,
42{
43    fn next_back(&mut self) -> Option<Self::Item> {
44        self.axes_iter.next_back().map(|offset| {
45            self.update_offset(offset);
46            unsafe { transmute(self.view.view()) }
47        })
48    }
49}
50
51impl<T, B> ExactSizeIterator for IterAxesView<'_, T, B>
52where
53    B: DeviceAPI<T>,
54{
55    fn len(&self) -> usize {
56        self.axes_iter.len()
57    }
58}
59
60impl<T, B> IterSplitAtAPI for IterAxesView<'_, T, B>
61where
62    B: DeviceAPI<T>,
63{
64    fn split_at(self, index: usize) -> (Self, Self) {
65        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
66        let view_lhs = unsafe { transmute(self.view.view()) };
67        let lhs = IterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
68        let rhs = IterAxesView { axes_iter: rhs_axes_iter, view: self.view };
69        return (lhs, rhs);
70    }
71}
72
73impl<'a, R, T, B, D> TensorAny<R, T, B, D>
74where
75    T: Clone,
76    R: DataCloneAPI<Data = B::Raw>,
77    D: DimAPI,
78    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
79{
80    pub fn axes_iter_with_order_f<I>(&self, axes: I, order: TensorIterOrder) -> Result<IterAxesView<'a, T, B>>
81    where
82        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
83    {
84        // convert axis to negative indexes and sort
85        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
86        let axes: Vec<isize> = axes
87            .try_into()
88            .map_err(Into::into)?
89            .as_ref()
90            .iter()
91            .map(|&v| if v >= 0 { v } else { v + ndim })
92            .collect::<Vec<isize>>();
93        let mut axes_check = axes.clone();
94        axes_check.sort();
95        // check no two axis are the same, and no negative index too small
96        if axes.first().is_some_and(|&v| v < 0) {
97            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
98        }
99        for i in 0..axes_check.len() - 1 {
100            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
101        }
102
103        // get full layout
104        let layout = self.layout().to_dim::<IxD>()?;
105        let shape_full = layout.shape();
106        let stride_full = layout.stride();
107        let offset = layout.offset();
108
109        // get layout for axes_iter
110        let mut shape_axes = vec![];
111        let mut stride_axes = vec![];
112        for &idx in &axes {
113            shape_axes.push(shape_full[idx as usize]);
114            stride_axes.push(stride_full[idx as usize]);
115        }
116        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
117
118        // get layout for inner view
119        let mut shape_inner = vec![];
120        let mut stride_inner = vec![];
121        for idx in 0..ndim {
122            if !axes.contains(&idx) {
123                shape_inner.push(shape_full[idx as usize]);
124                stride_inner.push(stride_full[idx as usize]);
125            }
126        }
127        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
128
129        // create axes iter
130        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
131        let mut view = self.view().into_dyn();
132        view.layout = layout_inner.clone();
133        let iter = IterAxesView { axes_iter, view: unsafe { transmute(view) } };
134        Ok(iter)
135    }
136
137    pub fn axes_iter_f<I>(&self, axes: I) -> Result<IterAxesView<'a, T, B>>
138    where
139        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
140    {
141        self.axes_iter_with_order_f(axes, TensorIterOrder::default())
142    }
143
144    pub fn axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IterAxesView<'a, T, B>
145    where
146        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
147    {
148        self.axes_iter_with_order_f(axes, order).rstsr_unwrap()
149    }
150
151    pub fn axes_iter<I>(&self, axes: I) -> IterAxesView<'a, T, B>
152    where
153        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
154    {
155        self.axes_iter_f(axes).rstsr_unwrap()
156    }
157}
158
159/* #endregion */
160
161/* #region axes iter mut iterator */
162
163pub struct IterAxesMut<'a, T, B>
164where
165    B: DeviceAPI<T>,
166{
167    axes_iter: IterLayout<IxD>,
168    view: TensorMut<'a, T, B, IxD>,
169}
170
171impl<T, B> IterAxesMut<'_, T, B>
172where
173    B: DeviceAPI<T>,
174{
175    pub fn update_offset(&mut self, offset: usize) {
176        unsafe { self.view.layout.set_offset(offset) };
177    }
178}
179
180impl<'a, T, B> Iterator for IterAxesMut<'a, T, B>
181where
182    B: DeviceAPI<T>,
183{
184    type Item = TensorMut<'a, T, B, IxD>;
185
186    fn next(&mut self) -> Option<Self::Item> {
187        self.axes_iter.next().map(|offset| {
188            self.update_offset(offset);
189            unsafe { transmute(self.view.view_mut()) }
190        })
191    }
192}
193
194impl<T, B> DoubleEndedIterator for IterAxesMut<'_, T, B>
195where
196    B: DeviceAPI<T>,
197{
198    fn next_back(&mut self) -> Option<Self::Item> {
199        self.axes_iter.next_back().map(|offset| {
200            self.update_offset(offset);
201            unsafe { transmute(self.view.view_mut()) }
202        })
203    }
204}
205
206impl<T, B> ExactSizeIterator for IterAxesMut<'_, T, B>
207where
208    B: DeviceAPI<T>,
209{
210    fn len(&self) -> usize {
211        self.axes_iter.len()
212    }
213}
214
215impl<T, B> IterSplitAtAPI for IterAxesMut<'_, T, B>
216where
217    B: DeviceAPI<T>,
218{
219    fn split_at(mut self, index: usize) -> (Self, Self) {
220        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
221        let view_lhs = unsafe { transmute(self.view.view_mut()) };
222        let lhs = IterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
223        let rhs = IterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
224        return (lhs, rhs);
225    }
226}
227
228impl<'a, R, T, B, D> TensorAny<R, T, B, D>
229where
230    T: Clone,
231    R: DataMutAPI<Data = B::Raw>,
232    D: DimAPI,
233    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
234{
235    pub fn axes_iter_mut_with_order_f<I>(&'a mut self, axes: I, order: TensorIterOrder) -> Result<IterAxesMut<'a, T, B>>
236    where
237        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
238    {
239        // convert axis to negative indexes and sort
240        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
241        let axes: Vec<isize> = axes
242            .try_into()
243            .map_err(Into::into)?
244            .as_ref()
245            .iter()
246            .map(|&v| if v >= 0 { v } else { v + ndim })
247            .collect::<Vec<isize>>();
248        let mut axes_check = axes.clone();
249        axes_check.sort();
250        // check no two axis are the same, and no negative index too small
251        if axes.first().is_some_and(|&v| v < 0) {
252            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
253        }
254        for i in 0..axes_check.len() - 1 {
255            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
256        }
257
258        // get full layout
259        let layout = self.layout().to_dim::<IxD>()?;
260        let shape_full = layout.shape();
261        let stride_full = layout.stride();
262        let offset = layout.offset();
263
264        // get layout for axes_iter
265        let mut shape_axes = vec![];
266        let mut stride_axes = vec![];
267        for &idx in &axes {
268            shape_axes.push(shape_full[idx as usize]);
269            stride_axes.push(stride_full[idx as usize]);
270        }
271        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
272
273        // get layout for inner view
274        let mut shape_inner = vec![];
275        let mut stride_inner = vec![];
276        for idx in 0..ndim {
277            if !axes.contains(&idx) {
278                shape_inner.push(shape_full[idx as usize]);
279                stride_inner.push(stride_full[idx as usize]);
280            }
281        }
282        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
283
284        // create axes iter
285        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
286        let mut view = self.view_mut().into_dyn();
287        view.layout = layout_inner.clone();
288        let iter = IterAxesMut { axes_iter, view };
289        Ok(iter)
290    }
291
292    pub fn axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IterAxesMut<'a, T, B>>
293    where
294        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
295    {
296        self.axes_iter_mut_with_order_f(axes, TensorIterOrder::default())
297    }
298
299    pub fn axes_iter_mut_with_order<I>(&'a mut self, axes: I, order: TensorIterOrder) -> IterAxesMut<'a, T, B>
300    where
301        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
302    {
303        self.axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
304    }
305
306    pub fn axes_iter_mut<I>(&'a mut self, axes: I) -> IterAxesMut<'a, T, B>
307    where
308        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
309    {
310        self.axes_iter_mut_f(axes).rstsr_unwrap()
311    }
312}
313
314/* #endregion */
315
316/* #region indexed axes iter view iterator */
317
318pub struct IndexedIterAxesView<'a, T, B>
319where
320    B: DeviceAPI<T>,
321{
322    axes_iter: IterLayout<IxD>,
323    view: TensorView<'a, T, B, IxD>,
324}
325
326impl<T, B> IndexedIterAxesView<'_, T, B>
327where
328    B: DeviceAPI<T>,
329{
330    pub fn update_offset(&mut self, offset: usize) {
331        unsafe { self.view.layout.set_offset(offset) };
332    }
333}
334
335impl<'a, T, B> Iterator for IndexedIterAxesView<'a, T, B>
336where
337    B: DeviceAPI<T>,
338{
339    type Item = (IxD, TensorView<'a, T, B, IxD>);
340
341    fn next(&mut self) -> Option<Self::Item> {
342        let index = match &self.axes_iter {
343            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
344            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
345        };
346        self.axes_iter.next().map(|offset| {
347            self.update_offset(offset);
348            (index, unsafe { transmute(self.view.view()) })
349        })
350    }
351}
352
353impl<T, B> DoubleEndedIterator for IndexedIterAxesView<'_, T, B>
354where
355    B: DeviceAPI<T>,
356{
357    fn next_back(&mut self) -> Option<Self::Item> {
358        let index = match &self.axes_iter {
359            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
360            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
361        };
362        self.axes_iter.next_back().map(|offset| {
363            self.update_offset(offset);
364            (index, unsafe { transmute(self.view.view()) })
365        })
366    }
367}
368
369impl<T, B> ExactSizeIterator for IndexedIterAxesView<'_, T, B>
370where
371    B: DeviceAPI<T>,
372{
373    fn len(&self) -> usize {
374        self.axes_iter.len()
375    }
376}
377
378impl<T, B> IterSplitAtAPI for IndexedIterAxesView<'_, T, B>
379where
380    B: DeviceAPI<T>,
381{
382    fn split_at(self, index: usize) -> (Self, Self) {
383        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
384        let view_lhs = unsafe { transmute(self.view.view()) };
385        let lhs = IndexedIterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
386        let rhs = IndexedIterAxesView { axes_iter: rhs_axes_iter, view: self.view };
387        return (lhs, rhs);
388    }
389}
390
391impl<'a, R, T, B, D> TensorAny<R, T, B, D>
392where
393    T: Clone,
394    R: DataCloneAPI<Data = B::Raw>,
395    D: DimAPI,
396    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
397{
398    pub fn indexed_axes_iter_with_order_f<I>(
399        &self,
400        axes: I,
401        order: TensorIterOrder,
402    ) -> Result<IndexedIterAxesView<'a, T, B>>
403    where
404        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
405    {
406        use TensorIterOrder::*;
407        // this function only accepts c/f iter order currently
408        match order {
409            C | F => (),
410            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
411        };
412        // convert axis to negative indexes and sort
413        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
414        let axes: Vec<isize> = axes
415            .try_into()
416            .map_err(Into::into)?
417            .as_ref()
418            .iter()
419            .map(|&v| if v >= 0 { v } else { v + ndim })
420            .collect::<Vec<isize>>();
421        let mut axes_check = axes.clone();
422        axes_check.sort();
423        // check no two axis are the same, and no negative index too small
424        if axes.first().is_some_and(|&v| v < 0) {
425            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
426        }
427        for i in 0..axes_check.len() - 1 {
428            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
429        }
430
431        // get full layout
432        let layout = self.layout().to_dim::<IxD>()?;
433        let shape_full = layout.shape();
434        let stride_full = layout.stride();
435        let offset = layout.offset();
436
437        // get layout for axes_iter
438        let mut shape_axes = vec![];
439        let mut stride_axes = vec![];
440        for &idx in &axes {
441            shape_axes.push(shape_full[idx as usize]);
442            stride_axes.push(stride_full[idx as usize]);
443        }
444        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
445
446        // get layout for inner view
447        let mut shape_inner = vec![];
448        let mut stride_inner = vec![];
449        for idx in 0..ndim {
450            if !axes.contains(&idx) {
451                shape_inner.push(shape_full[idx as usize]);
452                stride_inner.push(stride_full[idx as usize]);
453            }
454        }
455        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
456
457        // create axes iter
458        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
459        let mut view = self.view().into_dyn();
460        view.layout = layout_inner.clone();
461        let iter = IndexedIterAxesView { axes_iter, view: unsafe { transmute(view) } };
462        Ok(iter)
463    }
464
465    pub fn indexed_axes_iter_f<I>(&self, axes: I) -> Result<IndexedIterAxesView<'a, T, B>>
466    where
467        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
468    {
469        let default_order = self.device().default_order();
470        let order = match default_order {
471            RowMajor => TensorIterOrder::C,
472            ColMajor => TensorIterOrder::F,
473        };
474        self.indexed_axes_iter_with_order_f(axes, order)
475    }
476
477    pub fn indexed_axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IndexedIterAxesView<'a, T, B>
478    where
479        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
480    {
481        self.indexed_axes_iter_with_order_f(axes, order).rstsr_unwrap()
482    }
483
484    pub fn indexed_axes_iter<I>(&self, axes: I) -> IndexedIterAxesView<'a, T, B>
485    where
486        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
487    {
488        self.indexed_axes_iter_f(axes).rstsr_unwrap()
489    }
490}
491
492/* #endregion */
493
/* #region indexed axes iter mut iterator */
495
496pub struct IndexedIterAxesMut<'a, T, B>
497where
498    B: DeviceAPI<T>,
499{
500    axes_iter: IterLayout<IxD>,
501    view: TensorMut<'a, T, B, IxD>,
502}
503
504impl<T, B> IndexedIterAxesMut<'_, T, B>
505where
506    B: DeviceAPI<T>,
507{
508    pub fn update_offset(&mut self, offset: usize) {
509        unsafe { self.view.layout.set_offset(offset) };
510    }
511}
512
513impl<'a, T, B> Iterator for IndexedIterAxesMut<'a, T, B>
514where
515    B: DeviceAPI<T>,
516{
517    type Item = (IxD, TensorMut<'a, T, B, IxD>);
518
519    fn next(&mut self) -> Option<Self::Item> {
520        let index = match &self.axes_iter {
521            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
522            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
523        };
524        self.axes_iter.next().map(|offset| {
525            self.update_offset(offset);
526            unsafe { transmute((index, self.view.view_mut())) }
527        })
528    }
529}
530
531impl<T, B> DoubleEndedIterator for IndexedIterAxesMut<'_, T, B>
532where
533    B: DeviceAPI<T>,
534{
535    fn next_back(&mut self) -> Option<Self::Item> {
536        let index = match &self.axes_iter {
537            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
538            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
539        };
540        self.axes_iter.next_back().map(|offset| {
541            self.update_offset(offset);
542            unsafe { transmute((index, self.view.view_mut())) }
543        })
544    }
545}
546
547impl<T, B> ExactSizeIterator for IndexedIterAxesMut<'_, T, B>
548where
549    B: DeviceAPI<T>,
550{
551    fn len(&self) -> usize {
552        self.axes_iter.len()
553    }
554}
555
556impl<T, B> IterSplitAtAPI for IndexedIterAxesMut<'_, T, B>
557where
558    B: DeviceAPI<T>,
559{
560    fn split_at(mut self, index: usize) -> (Self, Self) {
561        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
562        let view_lhs = unsafe { transmute(self.view.view_mut()) };
563        let lhs = IndexedIterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
564        let rhs = IndexedIterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
565        return (lhs, rhs);
566    }
567}
568
569impl<'a, R, T, B, D> TensorAny<R, T, B, D>
570where
571    T: Clone,
572    R: DataMutAPI<Data = B::Raw>,
573    D: DimAPI,
574    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
575{
576    pub fn indexed_axes_iter_mut_with_order_f<I>(
577        &'a mut self,
578        axes: I,
579        order: TensorIterOrder,
580    ) -> Result<IndexedIterAxesMut<'a, T, B>>
581    where
582        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
583    {
584        // convert axis to negative indexes and sort
585        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
586        let axes: Vec<isize> = axes
587            .try_into()
588            .map_err(Into::into)?
589            .as_ref()
590            .iter()
591            .map(|&v| if v >= 0 { v } else { v + ndim })
592            .collect::<Vec<isize>>();
593        let mut axes_check = axes.clone();
594        axes_check.sort();
595        // check no two axis are the same, and no negative index too small
596        if axes.first().is_some_and(|&v| v < 0) {
597            return Err(rstsr_error!(InvalidValue, "Some negative index is too small."));
598        }
599        for i in 0..axes_check.len() - 1 {
600            rstsr_assert!(axes_check[i] != axes_check[i + 1], InvalidValue, "Same axes is not allowed here.")?;
601        }
602
603        // get full layout
604        let layout = self.layout().to_dim::<IxD>()?;
605        let shape_full = layout.shape();
606        let stride_full = layout.stride();
607        let offset = layout.offset();
608
609        // get layout for axes_iter
610        let mut shape_axes = vec![];
611        let mut stride_axes = vec![];
612        for &idx in &axes {
613            shape_axes.push(shape_full[idx as usize]);
614            stride_axes.push(stride_full[idx as usize]);
615        }
616        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };
617
618        // get layout for inner view
619        let mut shape_inner = vec![];
620        let mut stride_inner = vec![];
621        for idx in 0..ndim {
622            if !axes.contains(&idx) {
623                shape_inner.push(shape_full[idx as usize]);
624                stride_inner.push(stride_full[idx as usize]);
625            }
626        }
627        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };
628
629        // create axes iter
630        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
631        let mut view = self.view_mut().into_dyn();
632        view.layout = layout_inner.clone();
633        let iter = IndexedIterAxesMut { axes_iter, view };
634        Ok(iter)
635    }
636
637    pub fn indexed_axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IndexedIterAxesMut<'a, T, B>>
638    where
639        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
640    {
641        let default_order = self.device().default_order();
642        let order = match default_order {
643            RowMajor => TensorIterOrder::C,
644            ColMajor => TensorIterOrder::F,
645        };
646        self.indexed_axes_iter_mut_with_order_f(axes, order)
647    }
648
649    pub fn indexed_axes_iter_mut_with_order<I>(
650        &'a mut self,
651        axes: I,
652        order: TensorIterOrder,
653    ) -> IndexedIterAxesMut<'a, T, B>
654    where
655        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
656    {
657        self.indexed_axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
658    }
659
660    pub fn indexed_axes_iter_mut<I>(&'a mut self, axes: I) -> IndexedIterAxesMut<'a, T, B>
661    where
662        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
663    {
664        self.indexed_axes_iter_mut_f(axes).rstsr_unwrap()
665    }
666}
667
668/* #endregion */
669
#[cfg(test)]
mod tests_serial {
    use super::*;

    /// Fixing axes 0 and 2 of a (2, 3, 4, 5) tensor gives eight sub-views over the
    /// remaining (3, 5) axes; element `[1, 2]` of each is compared to a reference.
    #[test]
    fn test_axes_iter() {
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let res: Vec<_> = a
            .axes_iter_f([0, 2])
            .unwrap()
            .map(|view| {
                println!("{view:3}");
                view[[1, 2]]
            })
            .collect();
        #[cfg(not(feature = "col_major"))]
        {
            // import numpy as np
            // a = np.arange(120).reshape(2, 3, 4, 5)
            // a[:, 1, :, 2].reshape(-1)
            assert_eq!(res, vec![22, 27, 32, 37, 82, 87, 92, 97]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 119) |> collect;
            // a = reshape(a, (2, 3, 4, 5));
            // reshape(a[:, 2, :, 3], 8)'
            assert_eq!(res, vec![50, 51, 56, 57, 62, 63, 68, 69]);
        }
    }

    /// Same slicing as `test_axes_iter`, but each mutable sub-view is incremented in
    /// place before element `[1, 2]` is read back.
    #[test]
    fn test_axes_iter_mut() {
        let mut a = arange(120).into_shape([2, 3, 4, 5]);
        let res: Vec<_> = a
            .axes_iter_mut_with_order_f([0, 2], TensorIterOrder::C)
            .unwrap()
            .map(|mut view| {
                view += 1;
                println!("{view:3}");
                view[[1, 2]]
            })
            .collect();
        println!("{res:?}");
        #[cfg(not(feature = "col_major"))]
        {
            // import numpy as np
            // a = np.arange(120).reshape(2, 3, 4, 5)
            // a[:, 1, :, 2].reshape(-1) + 1
            assert_eq!(res, vec![23, 28, 33, 38, 83, 88, 93, 98]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 119) |> collect;
            // a = reshape(a, (2, 3, 4, 5));
            // reshape(a[:, 2, :, 3]', 8)' .+ 1
            assert_eq!(res, vec![51, 57, 63, 69, 52, 58, 64, 70]);
        }
    }

    /// Indexed variant: every sub-view is paired with its (axis 0, axis 2) index.
    #[test]
    fn test_indexed_axes_iter() {
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let res: Vec<_> = a
            .indexed_axes_iter([0, 2])
            .map(|(index, view)| {
                println!("{index:?}");
                println!("{view:3}");
                (index, view[[1, 2]])
            })
            .collect();
        #[cfg(not(feature = "col_major"))]
        {
            // import numpy as np
            // a = np.arange(120).reshape(2, 3, 4, 5)
            // a[:, 1, :, 2].reshape(-1)
            assert_eq!(res, vec![
                (vec![0, 0], 22),
                (vec![0, 1], 27),
                (vec![0, 2], 32),
                (vec![0, 3], 37),
                (vec![1, 0], 82),
                (vec![1, 1], 87),
                (vec![1, 2], 92),
                (vec![1, 3], 97)
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 119) |> collect;
            // a = reshape(a, (2, 3, 4, 5));
            // reshape(a[:, 2, :, 3], 8)'
            assert_eq!(res, vec![
                (vec![0, 0], 50),
                (vec![1, 0], 51),
                (vec![0, 1], 56),
                (vec![1, 1], 57),
                (vec![0, 2], 62),
                (vec![1, 2], 63),
                (vec![0, 3], 68),
                (vec![1, 3], 69)
            ]);
        }
    }
}
776
#[cfg(all(test, feature = "rayon"))]
mod tests_parallel {
    use super::*;
    use rayon::prelude::*;

    /// Drives the mutable axes iterator through rayon: each (16, 16) sub-view is incremented
    /// in parallel and the first 17 sampled elements are checked in order.
    #[test]
    fn test_axes_iter() {
        let mut a = arange(65536).into_shape([16, 16, 16, 16]);
        let res: Vec<_> = a
            .axes_iter_mut([0, 2])
            .into_par_iter()
            .map(|mut view| {
                view += 1;
                println!("{view:6}");
                view[[1, 2]]
            })
            .collect();
        println!("{res:?}");
        #[cfg(not(feature = "col_major"))]
        {
            // a = np.arange(65536).reshape(16, 16, 16, 16)
            // a[:, 1, :, 2].reshape(-1)[:17] + 1
            assert_eq!(res[..17], vec![
                259, 275, 291, 307, 323, 339, 355, 371, 387, 403, 419, 435, 451, 467, 483, 499, 4355
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            // a = range(0, 65535) |> collect;
            // a = reshape(a, (16, 16, 16, 16))
            // (reshape(a[:, 2, :, 3], 16 * 16) .+ 1)[1:17]
            assert_eq!(res[..17], vec![
                8209, 8210, 8211, 8212, 8213, 8214, 8215, 8216, 8217, 8218, 8219, 8220, 8221, 8222, 8223, 8224, 8465
            ]);
        }
    }
}