#![allow(clippy::missing_transmute_annotations)]

use crate::prelude_dev::*;
use core::mem::transmute;

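/// Iterator that yields [`TensorView`]s of the sub-tensors obtained by fixing
/// a set of axes of a tensor: `axes_iter` walks the offsets of the selected
/// axes, while `view` carries the layout of the remaining axes.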
pub struct IterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorView<'a, T, B, IxD>,
}

impl<T, B> IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    /// Set the offset of the stored view to that of the next sub-tensor.
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = TensorView<'a, T, B, IxD>;

    fn next(&mut self) -> Option<Self::Item> {
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            // SAFETY: transmute only extends the lifetime of the view from
            // that of `&self` to `'a`; the underlying data outlives this
            // iterator.
            unsafe { transmute(self.view.view()) }
        })
    }
}

impl<T, B> DoubleEndedIterator for IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            unsafe { transmute(self.view.view()) }
        })
    }
}

impl<T, B> ExactSizeIterator for IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
        let view_lhs = unsafe { transmute(self.view.view()) };
        let lhs = IterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IterAxesView { axes_iter: rhs_axes_iter, view: self.view };
        return (lhs, rhs);
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataCloneAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
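    /// Iterates over views obtained by fixing the given `axes`, visiting them
    /// in the requested `order`; each yielded view spans the remaining axes.
    /// Fallible variant: invalid `axes` produce an `Err` instead of a panic.
    ///
    /// A minimal sketch of the intended usage (mirrors the serial tests below;
    /// assumes the dev-prelude items `arange`/`into_shape` are in scope, hence
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let a = arange(120).into_shape([2, 3, 4, 5]);
    /// // fix axes 0 and 2; each view covers the remaining axes (shape [3, 5])
    /// let mut iter = a.axes_iter_with_order_f([0, 2], TensorIterOrder::C)?;
    /// let first = iter.next().unwrap();
    /// assert_eq!(first[[1, 2]], 22); // row-major build
    /// ```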
    pub fn axes_iter_with_order_f<I>(&self, axes: I, order: TensorIterOrder) -> Result<IterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        // normalize negative axis indices against ndim
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> = axes
            .try_into()
            .map_err(Into::into)?
            .as_ref()
            .iter()
            .map(|&v| if v >= 0 { v } else { v + ndim })
            .collect::<Vec<isize>>();
        // validate axes on a sorted copy: all indices in range, no duplicates
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) || axes_check.last().is_some_and(|&v| v >= ndim) {
            return Err(rstsr_error!(InvalidValue, "Axis index out of range."));
        }
        for window in axes_check.windows(2) {
            rstsr_assert!(window[0] != window[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        // layout of the selected axes; iterating it yields the offset of each sub-tensor
        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        // layout of the remaining axes; this becomes the layout of every yielded view
        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view().into_dyn();
        view.layout = layout_inner;
        let iter = IterAxesView { axes_iter, view: unsafe { transmute(view) } };
        Ok(iter)
    }

    pub fn axes_iter_f<I>(&self, axes: I) -> Result<IterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.axes_iter_with_order_f(axes, TensorIterOrder::default())
    }

    pub fn axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.axes_iter_with_order_f(axes, order).rstsr_unwrap()
    }

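    /// Panicking convenience wrapper of [`Self::axes_iter_f`], using the
    /// default iteration order.
    ///
    /// Sketch taken from the serial test below (row-major build; dev prelude
    /// assumed in scope, hence not a doctest):
    ///
    /// ```ignore
    /// let a = arange(120).into_shape([2, 3, 4, 5]);
    /// let res = a.axes_iter([0, 2]).map(|view| view[[1, 2]]).collect::<Vec<_>>();
    /// assert_eq!(res, vec![22, 27, 32, 37, 82, 87, 92, 97]);
    /// ```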
    pub fn axes_iter<I>(&self, axes: I) -> IterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.axes_iter_f(axes).rstsr_unwrap()
    }
}

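/// Mutable counterpart of [`IterAxesView`]: yields [`TensorMut`] sub-tensors,
/// allowing in-place modification.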
pub struct IterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorMut<'a, T, B, IxD>,
}

impl<T, B> IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    /// Set the offset of the stored view to that of the next sub-tensor.
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = TensorMut<'a, T, B, IxD>;

    fn next(&mut self) -> Option<Self::Item> {
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            // SAFETY: transmute only extends the lifetime of the mutable view
            // from that of `&mut self` to `'a`; each offset from the layout
            // iterator is yielded exactly once.
            unsafe { transmute(self.view.view_mut()) }
        })
    }
}

impl<T, B> DoubleEndedIterator for IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            unsafe { transmute(self.view.view_mut()) }
        })
    }
}

impl<T, B> ExactSizeIterator for IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(mut self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
        let view_lhs = unsafe { transmute(self.view.view_mut()) };
        let lhs = IterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
        return (lhs, rhs);
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
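    /// Mutable counterpart of [`Self::axes_iter_with_order_f`]: yields mutable
    /// views over the remaining axes, so each sub-tensor can be modified in
    /// place. Fallible variant.
    ///
    /// Sketch based on the serial test below (dev prelude assumed in scope,
    /// hence not a doctest):
    ///
    /// ```ignore
    /// let mut a = arange(120).into_shape([2, 3, 4, 5]);
    /// for mut view in a.axes_iter_mut_with_order_f([0, 2], TensorIterOrder::C)? {
    ///     view += 1; // every element of `a` is visited exactly once
    /// }
    /// ```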
    pub fn axes_iter_mut_with_order_f<I>(&'a mut self, axes: I, order: TensorIterOrder) -> Result<IterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        // normalize negative axis indices against ndim
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> = axes
            .try_into()
            .map_err(Into::into)?
            .as_ref()
            .iter()
            .map(|&v| if v >= 0 { v } else { v + ndim })
            .collect::<Vec<isize>>();
        // validate axes on a sorted copy: all indices in range, no duplicates
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) || axes_check.last().is_some_and(|&v| v >= ndim) {
            return Err(rstsr_error!(InvalidValue, "Axis index out of range."));
        }
        for window in axes_check.windows(2) {
            rstsr_assert!(window[0] != window[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        // layout of the selected axes; iterating it yields the offset of each sub-tensor
        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        // layout of the remaining axes; this becomes the layout of every yielded view
        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view_mut().into_dyn();
        view.layout = layout_inner;
        let iter = IterAxesMut { axes_iter, view };
        Ok(iter)
    }

    pub fn axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.axes_iter_mut_with_order_f(axes, TensorIterOrder::default())
    }

    pub fn axes_iter_mut_with_order<I>(&'a mut self, axes: I, order: TensorIterOrder) -> IterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn axes_iter_mut<I>(&'a mut self, axes: I) -> IterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.axes_iter_mut_f(axes).rstsr_unwrap()
    }
}

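/// Counterpart of [`IterAxesView`] that also yields the index of each view
/// along the selected axes.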
pub struct IndexedIterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorView<'a, T, B, IxD>,
}

impl<T, B> IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    /// Set the offset of the stored view to that of the next sub-tensor.
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IndexedIterAxesView<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = (IxD, TensorView<'a, T, B, IxD>);

    fn next(&mut self) -> Option<Self::Item> {
        // record the current front index before advancing the layout iterator
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view()) })
        })
    }
}

impl<T, B> DoubleEndedIterator for IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view()) })
        })
    }
}

impl<T, B> ExactSizeIterator for IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IndexedIterAxesView<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.split_at(index);
        let view_lhs = unsafe { transmute(self.view.view()) };
        let lhs = IndexedIterAxesView { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IndexedIterAxesView { axes_iter: rhs_axes_iter, view: self.view };
        return (lhs, rhs);
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataCloneAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
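    /// Indexed counterpart of [`Self::axes_iter_with_order_f`]: each item is a
    /// `(index, view)` pair, where `index` locates the view along the selected
    /// axes. Only [`TensorIterOrder::C`] and [`TensorIterOrder::F`] are
    /// accepted, since the index is read from the underlying layout iterator.
    ///
    /// Sketch based on the serial test below (row-major build; dev prelude
    /// assumed in scope, hence not a doctest):
    ///
    /// ```ignore
    /// let a = arange(120).into_shape([2, 3, 4, 5]);
    /// let mut iter = a.indexed_axes_iter_with_order_f([0, 2], TensorIterOrder::C)?;
    /// let (index, view) = iter.next().unwrap();
    /// assert_eq!(index, vec![0, 0]);
    /// assert_eq!(view[[1, 2]], 22);
    /// ```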
    pub fn indexed_axes_iter_with_order_f<I>(
        &self,
        axes: I,
        order: TensorIterOrder,
    ) -> Result<IndexedIterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        // indices are read from C/F layout iterators, so only these orders are accepted
        use TensorIterOrder::*;
        match order {
            C | F => (),
            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
        };
        // normalize negative axis indices against ndim
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> = axes
            .try_into()
            .map_err(Into::into)?
            .as_ref()
            .iter()
            .map(|&v| if v >= 0 { v } else { v + ndim })
            .collect::<Vec<isize>>();
        // validate axes on a sorted copy: all indices in range, no duplicates
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) || axes_check.last().is_some_and(|&v| v >= ndim) {
            return Err(rstsr_error!(InvalidValue, "Axis index out of range."));
        }
        for window in axes_check.windows(2) {
            rstsr_assert!(window[0] != window[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        // layout of the selected axes; iterating it yields the offset of each sub-tensor
        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        // layout of the remaining axes; this becomes the layout of every yielded view
        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view().into_dyn();
        view.layout = layout_inner;
        let iter = IndexedIterAxesView { axes_iter, view: unsafe { transmute(view) } };
        Ok(iter)
    }

    pub fn indexed_axes_iter_f<I>(&self, axes: I) -> Result<IndexedIterAxesView<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        // choose the C/F order matching the device's default order
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.indexed_axes_iter_with_order_f(axes, order)
    }

    pub fn indexed_axes_iter_with_order<I>(&self, axes: I, order: TensorIterOrder) -> IndexedIterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.indexed_axes_iter_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn indexed_axes_iter<I>(&self, axes: I) -> IndexedIterAxesView<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.indexed_axes_iter_f(axes).rstsr_unwrap()
    }
}

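/// Mutable counterpart of [`IndexedIterAxesView`].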
pub struct IndexedIterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    axes_iter: IterLayout<IxD>,
    view: TensorMut<'a, T, B, IxD>,
}

impl<T, B> IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    /// Set the offset of the stored view to that of the next sub-tensor.
    pub fn update_offset(&mut self, offset: usize) {
        unsafe { self.view.layout.set_offset(offset) };
    }
}

impl<'a, T, B> Iterator for IndexedIterAxesMut<'a, T, B>
where
    B: DeviceAPI<T>,
{
    type Item = (IxD, TensorMut<'a, T, B, IxD>);

    fn next(&mut self) -> Option<Self::Item> {
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view_mut()) })
        })
    }
}

impl<T, B> DoubleEndedIterator for IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        let index = match &self.axes_iter {
            IterLayout::ColMajor(iter_inner) => iter_inner.index_start().clone(),
            IterLayout::RowMajor(iter_inner) => iter_inner.index_start().clone(),
        };
        self.axes_iter.next_back().map(|offset| {
            self.update_offset(offset);
            (index, unsafe { transmute(self.view.view_mut()) })
        })
    }
}

impl<T, B> ExactSizeIterator for IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn len(&self) -> usize {
        self.axes_iter.len()
    }
}

impl<T, B> IterSplitAtAPI for IndexedIterAxesMut<'_, T, B>
where
    B: DeviceAPI<T>,
{
    fn split_at(mut self, index: usize) -> (Self, Self) {
        let (lhs_axes_iter, rhs_axes_iter) = self.axes_iter.clone().split_at(index);
        let view_lhs = unsafe { transmute(self.view.view_mut()) };
        let lhs = IndexedIterAxesMut { axes_iter: lhs_axes_iter, view: view_lhs };
        let rhs = IndexedIterAxesMut { axes_iter: rhs_axes_iter, view: self.view };
        return (lhs, rhs);
    }
}

impl<'a, R, T, B, D> TensorAny<R, T, B, D>
where
    T: Clone,
    R: DataMutAPI<Data = B::Raw>,
    D: DimAPI,
    B: DeviceAPI<T, Raw = Vec<T>> + 'a,
{
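    /// Mutable counterpart of [`Self::indexed_axes_iter_with_order_f`]: yields
    /// `(index, mutable view)` pairs over the selected axes. Fallible variant;
    /// like the immutable version, only C/F iteration orders are accepted.
    ///
    /// A minimal sketch (dev prelude assumed in scope, hence not a doctest):
    ///
    /// ```ignore
    /// let mut a = arange(120).into_shape([2, 3, 4, 5]);
    /// for (index, mut view) in a.indexed_axes_iter_mut_with_order_f([0, 2], TensorIterOrder::C)? {
    ///     println!("{index:?}");
    ///     view += 1;
    /// }
    /// ```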
    pub fn indexed_axes_iter_mut_with_order_f<I>(
        &'a mut self,
        axes: I,
        order: TensorIterOrder,
    ) -> Result<IndexedIterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        // indices are read from C/F layout iterators, so restrict the accepted
        // orders (mirrors the immutable `indexed_axes_iter_with_order_f`)
        use TensorIterOrder::*;
        match order {
            C | F => (),
            _ => rstsr_invalid!(order, "This function only accepts TensorIterOrder::C|F.",)?,
        };
        // normalize negative axis indices against ndim
        let ndim: isize = TryInto::<isize>::try_into(self.ndim())?;
        let axes: Vec<isize> = axes
            .try_into()
            .map_err(Into::into)?
            .as_ref()
            .iter()
            .map(|&v| if v >= 0 { v } else { v + ndim })
            .collect::<Vec<isize>>();
        // validate axes on a sorted copy: all indices in range, no duplicates
        let mut axes_check = axes.clone();
        axes_check.sort();
        if axes_check.first().is_some_and(|&v| v < 0) || axes_check.last().is_some_and(|&v| v >= ndim) {
            return Err(rstsr_error!(InvalidValue, "Axis index out of range."));
        }
        for window in axes_check.windows(2) {
            rstsr_assert!(window[0] != window[1], InvalidValue, "Duplicate axes are not allowed.")?;
        }

        let layout = self.layout().to_dim::<IxD>()?;
        let shape_full = layout.shape();
        let stride_full = layout.stride();
        let offset = layout.offset();

        // layout of the selected axes; iterating it yields the offset of each sub-tensor
        let mut shape_axes = vec![];
        let mut stride_axes = vec![];
        for &idx in &axes {
            shape_axes.push(shape_full[idx as usize]);
            stride_axes.push(stride_full[idx as usize]);
        }
        let layout_axes = unsafe { Layout::new_unchecked(shape_axes, stride_axes, offset) };

        // layout of the remaining axes; this becomes the layout of every yielded view
        let mut shape_inner = vec![];
        let mut stride_inner = vec![];
        for idx in 0..ndim {
            if !axes.contains(&idx) {
                shape_inner.push(shape_full[idx as usize]);
                stride_inner.push(stride_full[idx as usize]);
            }
        }
        let layout_inner = unsafe { Layout::new_unchecked(shape_inner, stride_inner, offset) };

        let axes_iter = IterLayout::<IxD>::new(&layout_axes, order)?;
        let mut view = self.view_mut().into_dyn();
        view.layout = layout_inner;
        let iter = IndexedIterAxesMut { axes_iter, view };
        Ok(iter)
    }

    pub fn indexed_axes_iter_mut_f<I>(&'a mut self, axes: I) -> Result<IndexedIterAxesMut<'a, T, B>>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        // choose the C/F order matching the device's default order
        let default_order = self.device().default_order();
        let order = match default_order {
            RowMajor => TensorIterOrder::C,
            ColMajor => TensorIterOrder::F,
        };
        self.indexed_axes_iter_mut_with_order_f(axes, order)
    }

    pub fn indexed_axes_iter_mut_with_order<I>(
        &'a mut self,
        axes: I,
        order: TensorIterOrder,
    ) -> IndexedIterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.indexed_axes_iter_mut_with_order_f(axes, order).rstsr_unwrap()
    }

    pub fn indexed_axes_iter_mut<I>(&'a mut self, axes: I) -> IndexedIterAxesMut<'a, T, B>
    where
        I: TryInto<AxesIndex<isize>, Error: Into<Error>>,
    {
        self.indexed_axes_iter_mut_f(axes).rstsr_unwrap()
    }
}

#[cfg(test)]
mod tests_serial {
    use super::*;

    #[test]
    fn test_axes_iter() {
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let iter = a.axes_iter_f([0, 2]).unwrap();

        let res = iter
            .map(|view| {
                println!("{view:3}");
                view[[1, 2]]
            })
            .collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res, vec![22, 27, 32, 37, 82, 87, 92, 97]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res, vec![50, 51, 56, 57, 62, 63, 68, 69]);
        }
    }

    #[test]
    fn test_axes_iter_mut() {
        let mut a = arange(120).into_shape([2, 3, 4, 5]);
        let iter = a.axes_iter_mut_with_order_f([0, 2], TensorIterOrder::C).unwrap();

        let res = iter
            .map(|mut view| {
                view += 1;
                println!("{view:3}");
                view[[1, 2]]
            })
            .collect::<Vec<_>>();
        println!("{res:?}");
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res, vec![23, 28, 33, 38, 83, 88, 93, 98]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res, vec![51, 57, 63, 69, 52, 58, 64, 70]);
        }
    }

    #[test]
    fn test_indexed_axes_iter() {
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let iter = a.indexed_axes_iter([0, 2]);

        let res = iter
            .map(|(index, view)| {
                println!("{index:?}");
                println!("{view:3}");
                (index, view[[1, 2]])
            })
            .collect::<Vec<_>>();
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res, vec![
                (vec![0, 0], 22),
                (vec![0, 1], 27),
                (vec![0, 2], 32),
                (vec![0, 3], 37),
                (vec![1, 0], 82),
                (vec![1, 1], 87),
                (vec![1, 2], 92),
                (vec![1, 3], 97)
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res, vec![
                (vec![0, 0], 50),
                (vec![1, 0], 51),
                (vec![0, 1], 56),
                (vec![1, 1], 57),
                (vec![0, 2], 62),
                (vec![1, 2], 63),
                (vec![0, 3], 68),
                (vec![1, 3], 69)
            ]);
        }
    }
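
    #[test]
    fn test_axes_iter_split_at() {
        // Sketch of IterSplitAtAPI usage (assumption: `split_at` partitions
        // the iteration range, so both halves together reproduce the full
        // sweep in order; the parallel test below relies on the same
        // property).
        let a = arange(120).into_shape([2, 3, 4, 5]);
        let full = a.axes_iter([0, 2]).map(|view| view[[1, 2]]).collect::<Vec<_>>();
        let (lhs, rhs) = a.axes_iter([0, 2]).split_at(3);
        assert_eq!(lhs.len(), 3);
        assert_eq!(rhs.len(), full.len() - 3);
        let mut recombined = lhs.map(|view| view[[1, 2]]).collect::<Vec<_>>();
        recombined.extend(rhs.map(|view| view[[1, 2]]));
        assert_eq!(recombined, full);
    }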
}

#[cfg(test)]
#[cfg(feature = "rayon")]
mod tests_parallel {
    use super::*;
    use rayon::prelude::*;

    #[test]
    fn test_axes_iter() {
        let mut a = arange(65536).into_shape([16, 16, 16, 16]);
        let iter = a.axes_iter_mut([0, 2]);

        let res = iter
            .into_par_iter()
            .map(|mut view| {
                view += 1;
                println!("{view:6}");
                view[[1, 2]]
            })
            .collect::<Vec<_>>();
        println!("{res:?}");
        #[cfg(not(feature = "col_major"))]
        {
            assert_eq!(res[..17], vec![
                259, 275, 291, 307, 323, 339, 355, 371, 387, 403, 419, 435, 451, 467, 483, 499, 4355
            ]);
        }
        #[cfg(feature = "col_major")]
        {
            assert_eq!(res[..17], vec![
                8209, 8210, 8211, 8212, 8213, 8214, 8215, 8216, 8217, 8218, 8219, 8220, 8221, 8222, 8223, 8224, 8465
            ]);
        }
    }
}